25 |
26 | - The result by Avatar-Net preserves concrete multi-scale style patterns (e.g. the color distribution, brush strokes and circular patterns in the _candy_ image).
27 | - [WCT](https://arxiv.org/abs/1705.08086) distorts the brush strokes and circular patterns, [AdaIN](https://arxiv.org/abs/1703.06868) cannot even keep the color distribution, and [Style-Swap](https://arxiv.org/abs/1612.04337) fails entirely in this example.
28 |
29 | #### Execution Efficiency
30 | | Method | Gatys et al. | AdaIN | WCT | Style-Swap | __Avatar-Net__ |
31 | | :---: | :---: | :---: | :---: | :---: | :---: |
32 | | __256x256 (sec)__ | 12.18 | 0.053 | 0.62 | 0.064 | __0.071__ |
33 | | __512x512 (sec)__ | 43.25 | 0.11 | 0.93 | 0.23 | __0.28__ |
34 |
35 | - Avatar-Net has an execution time comparable to AdaIN and the GPU-accelerated Style-Swap, and is much faster than WCT and the optimization-based style transfer by [Gatys _et al._](https://arxiv.org/abs/1508.06576).
36 | - The reference methods and the proposed Avatar-Net are implemented on the same TensorFlow platform with the same VGG network as the backbone.
37 |
38 | ## Dependencies
39 | - [TensorFlow](https://www.tensorflow.org/) (version >= 1.0, but only tested on TensorFlow 1.0).
40 | - Heavily depends on [TF-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) and its [model repository](https://github.com/tensorflow/models/tree/master/research/slim).
41 |
42 | ## Download
43 | - The trained model of Avatar-Net can be downloaded from [Google Drive](https://drive.google.com/open?id=1_7x93xwZMhCL-kLrz4B2iZ01Y8Q7SlTX).
44 | - Training the style transfer network requires a pretrained [VGG](https://arxiv.org/abs/1409.1556) network, which can be obtained from the [TF-Slim model repository](https://github.com/tensorflow/models/tree/master/research/slim). The encoding layers of Avatar-Net are also borrowed from the pretrained VGG model.
45 | - The [MSCOCO](http://cocodataset.org/#home) dataset is used for training the proposed image reconstruction network.
46 |
47 | ## Usage
48 |
49 | ### Basic Usage
50 |
51 | Simply use the bash script `./scripts/evaluate_style_transfer.sh` to apply Avatar-Net to all content images in `CONTENT_DIR` with every style image in `STYLE_DIR`. For example,
52 |
53 | bash ./scripts/evaluate_style_transfer.sh gpu_id CONTENT_DIR STYLE_DIR EVAL_DIR
54 |
55 | - `gpu_id`: the mounted GPU ID for the TensorFlow session.
56 | - `CONTENT_DIR`: the directory of the content images. It can be `./data/contents/images` for multiple exemplar content images, or `./data/contents/sequences` for an exemplar content video.
57 | - `STYLE_DIR`: the directory of the style images. It can be `./data/styles` for multiple exemplar style images.
58 | - `EVAL_DIR`: the output directory. It contains one subdirectory for each style image, named after that style image.
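
For instance, using the bundled example data (the GPU ID `0` and the output directory `./results` below are placeholders):

```
bash ./scripts/evaluate_style_transfer.sh 0 ./data/contents/images ./data/styles ./results
```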
59 |
60 | More detailed evaluation options are exposed as command-line flags in `evaluate_style_transfer.py`, which can also be run directly:
61 |
62 | python evaluate_style_transfer.py
63 |
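The available flags defined in `evaluate_style_transfer.py` can also be set explicitly; the checkpoint and output paths below are placeholders:

```
python evaluate_style_transfer.py \
    --checkpoint_dir=MODEL_DIR \
    --model_config_path=configs/AvatarNet_config.yml \
    --content_dataset_dir=./data/contents/images \
    --style_dataset_dir=./data/styles \
    --eval_dir=./results \
    --inter_weight=1.0
```
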
64 | ### Configuration
65 |
66 | The detailed configuration of Avatar-Net is listed in `configs/AvatarNet_config.yml`, including the training specifications and network hyper-parameters. The style decorator has three options:
67 |
68 | - `patch_size`: the patch size for the normalized cross-correlation; the default is `5`.
69 | - `style_coding`: the projection and reconstruction method, either `ZCA` or `AdaIN`.
70 | - `style_interp`: the interpolation option between the transferred features and the content features, either `normalized` or `biased`.
71 |
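These options correspond to the style decorator section of `configs/AvatarNet_config.yml`:

```
# patch size for style decorator
patch_size: 5

# style encoding method
style_coding: 'ZCA'  # or 'AdaIN'

# style interpolation
style_interp: 'normalized'  # or 'biased'
```
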
72 | The style transfer is actually performed in `AvatarNet.transfer_styles(self, inputs, styles, inter_weight, intra_weights)`, in which
73 |
74 | - `inputs`: the content images.
75 | - `styles`: a list of style images (`len(styles)` >= 2 for multiple style interpolation).
76 | - `inter_weight`: the weight balancing the style and content images.
77 | - `intra_weights`: a list of weights balancing the effects from different styles.
78 |
79 | Users may modify the evaluation script for multiple style interpolation or the content-style trade-off, as in the sketch below.
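
A minimal sketch of such a modification, assuming the list form of `styles` described above (graph construction only; preprocessing and checkpoint restoring follow `evaluate_style_transfer.py`):

```
import tensorflow as tf
from models import models_factory

# build the model from its configuration file
style_model, options = models_factory.get_model('configs/AvatarNet_config.yml')

# batched content and style inputs (already mean-subtracted, as in the evaluation script)
content = tf.placeholder(tf.float32, shape=(1, None, None, 3))
style_a = tf.placeholder(tf.float32, shape=(1, None, None, 3))
style_b = tf.placeholder(tf.float32, shape=(1, None, None, 3))

# interpolate the two styles equally, keeping 80% stylization strength
stylized = style_model.transfer_styles(
    content,
    [style_a, style_b],        # a list of style images
    inter_weight=0.8,          # content-style trade-off
    intra_weights=[0.5, 0.5])  # relative weights of the two styles
```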
80 |
81 | ### Training
82 |
83 | 1. Download the [MSCOCO](http://cocodataset.org/#home) dataset and convert the raw images into TF-examples with the python script `./datasets/convert_mscoco_to_tfexamples.py`.
84 | 2. Use `bash ./scripts/train_image_reconstruction.sh gpu_id DATASET_DIR MODEL_DIR` to start training with the default hyper-parameters. `gpu_id` is the mounted GPU for the TensorFlow session, `DATASET_DIR` is the path to the converted MSCOCO dataset from step 1, and `MODEL_DIR` is the directory where the Avatar-Net model is saved. Example commands are given after this list.
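
For example (the dataset and model paths are placeholders; the converter flags are defined in `./datasets/convert_mscoco_to_tfexamples.py`):

```
python ./datasets/convert_mscoco_to_tfexamples.py \
    --input_dataset_dir=/path/to/mscoco \
    --output_dataset_dir=/path/to/mscoco_tfexamples

bash ./scripts/train_image_reconstruction.sh 0 /path/to/mscoco_tfexamples /path/to/avatar_net_model
```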
85 |
86 | ## Citation
87 |
88 | If you find this code useful for your research, please cite the paper:
89 |
90 | Lu Sheng, Ziyi Lin, Jing Shao and Xiaogang Wang, "Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration", in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018. [[Arxiv](https://arxiv.org/abs/1805.03857)]
91 |
92 | ```
93 | @inproceedings{sheng2018avatar,
94 |   title = {Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration},
95 |   author = {Sheng, Lu and Lin, Ziyi and Shao, Jing and Wang, Xiaogang},
96 |   booktitle = {Computer Vision and Pattern Recognition (CVPR), 2018 IEEE Conference on},
97 |   pages = {1--9},
98 |   year = {2018}
99 | }
100 | ```
101 |
102 | ## Acknowledgement
103 |
104 | This project is inspired by many style-agnostic style transfer methods, including [AdaIN](https://arxiv.org/abs/1703.06868), [WCT](https://arxiv.org/abs/1705.08086) and [Style-Swap](https://arxiv.org/abs/1612.04337), from both their papers and code.
105 |
106 | ## Contact
107 |
108 | If you have any questions or suggestions about this paper, feel free to contact me ([lsheng@ee.cuhk.edu.hk](mailto:lsheng@ee.cuhk.edu.hk)).
109 |
--------------------------------------------------------------------------------
/configs/AvatarNet_config.yml:
--------------------------------------------------------------------------------
1 | # name of the applied model
2 | model_name: 'AvatarNet'
3 |
4 | # the input sizes
5 | content_size: 512
6 | style_size: 512
7 |
8 | # perceptual loss configurations
9 | network_name: 'vgg_19'
10 | checkpoint_path: '/DATA/lsheng/model_zoo/VGG/vgg_19.ckpt'
11 | checkpoint_exclude_scopes: 'vgg_19/fc'
12 | ignore_missing_vars: True
13 |
14 | # style loss layers
15 | style_loss_layers:
16 | - 'conv1/conv1_1'
17 | - 'conv2/conv2_1'
18 | - 'conv3/conv3_1'
19 | - 'conv4/conv4_1'
20 |
21 | #################################
22 | # style decorator specification #
23 | #################################
24 | # patch size for style decorator
25 | patch_size: 5
26 |
27 | # style encoding method
28 | style_coding: 'ZCA' # 'AdaIN'
29 |
30 | # style interpolation
31 | style_interp: 'normalized'
32 |
33 | ####################
34 | # training routine #
35 | ####################
36 | training_image_size: 256
37 | weight_decay: 0.0005
38 | trainable_scopes: 'combined_decoder'
39 |
40 | # loss weights
41 | content_weight: 1.0
42 | recons_weight: 10.0
43 | tv_weight: 10.0
44 |
--------------------------------------------------------------------------------
/data/contents/images/avril.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/avril.jpg
--------------------------------------------------------------------------------
/data/contents/images/brad_pitt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/brad_pitt.jpg
--------------------------------------------------------------------------------
/data/contents/images/cornell.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/cornell.jpg
--------------------------------------------------------------------------------
/data/contents/images/flowers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/flowers.jpg
--------------------------------------------------------------------------------
/data/contents/images/modern.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/modern.jpg
--------------------------------------------------------------------------------
/data/contents/images/woman_side_portrait.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/woman_side_portrait.jpg
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0001.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0002.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0003.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0004.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0005.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0006.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0007.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0008.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0008.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0009.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0010.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0011.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0011.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0012.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0013.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0013.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0014.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0014.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0015.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0015.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0016.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0016.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0017.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0017.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0018.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0018.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0019.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0019.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0020.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0021.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0021.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0022.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0022.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0023.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0023.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0024.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0025.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0026.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0026.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0027.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0027.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0028.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0028.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0029.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0030.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0030.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0031.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0031.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0032.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0033.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0033.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0034.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0034.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0035.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0035.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0036.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0036.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0037.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0037.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0038.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0038.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0039.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0039.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0040.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0040.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0041.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0041.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0042.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0042.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0043.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0043.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0044.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0044.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0045.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0045.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0046.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0047.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0047.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0048.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0048.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0049.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0049.png
--------------------------------------------------------------------------------
/data/contents/sequences/frame_0050.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0050.png
--------------------------------------------------------------------------------
/data/styles/brushstrokers.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/brushstrokers.jpg
--------------------------------------------------------------------------------
/data/styles/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/candy.jpg
--------------------------------------------------------------------------------
/data/styles/la_muse.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/la_muse.jpg
--------------------------------------------------------------------------------
/data/styles/plum_flower.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/plum_flower.jpg
--------------------------------------------------------------------------------
/data/styles/woman_in_peasant_dress_cropped.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/woman_in_peasant_dress_cropped.jpg
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/datasets/__init__.pyc
--------------------------------------------------------------------------------
/datasets/convert_mscoco_to_tfexamples.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import math
6 | import os
7 | import random
8 | import sys
9 |
10 | from datasets import dataset_utils
11 |
12 | import tensorflow as tf
13 |
14 | _NUM_SHARDS = 5
15 | _RANDOM_SEED = 0
16 |
17 | FLAGS = tf.app.flags.FLAGS
18 |
19 | tf.app.flags.DEFINE_string(
20 | 'output_dataset_dir', None,
21 | 'The directory where the output TFRecords and temporary files are saved')
22 |
23 | tf.app.flags.DEFINE_string(
24 | 'input_dataset_dir', None,
25 | 'The directory where the input files are saved.')
26 |
27 |
28 | def _get_filenames(dataset_dir):
29 | split_dirs = ['train2014', 'val2014', 'test2014']
30 |
31 | # get the full path to each image
32 | train_dir = os.path.join(dataset_dir, split_dirs[0])
33 | validation_dir = os.path.join(dataset_dir, split_dirs[1])
34 | test_dir = os.path.join(dataset_dir, split_dirs[2])
35 |
36 | train_image_filenames = []
37 | for filename in os.listdir(train_dir):
38 | file_path = os.path.join(train_dir, filename)
39 | train_image_filenames.append(file_path)
40 |
41 | validation_image_filenames = []
42 | for filename in os.listdir(validation_dir):
43 | file_path = os.path.join(validation_dir, filename)
44 | validation_image_filenames.append(file_path)
45 |
46 | test_image_filenames = []
47 | for filename in os.listdir(test_dir):
48 | file_path = os.path.join(test_dir, filename)
49 | test_image_filenames.append(file_path)
50 |
51 | print('Statistics in MSCOCO dataset...')
52 | print('There are %d images in train dataset' % len(train_image_filenames))
53 | print('There are %d images in validation dataset' % len(validation_image_filenames))
54 | print('There are %d images in test dataset' % len(test_image_filenames))
55 |
56 | return train_image_filenames, validation_image_filenames, test_image_filenames
57 |
58 |
59 | def _get_dataset_filename(dataset_dir, split_name, shard_id):
60 | output_filename = 'MSCOCO_%s_%05d-of-%05d.tfrecord' % (
61 | split_name, shard_id, _NUM_SHARDS)
62 | return os.path.join(dataset_dir, output_filename)
63 |
64 |
65 | def _convert_dataset(split_name, image_filenames, dataset_dir):
66 | assert split_name in ['train', 'validation', 'test']
67 |
68 | num_per_shard = int(math.ceil(len(image_filenames) / float(_NUM_SHARDS)))
69 |
70 | with tf.Graph().as_default():
71 | image_reader = dataset_utils.ImageReader()
72 |
73 | with tf.Session('') as sess:
74 | for shard_id in range(_NUM_SHARDS):
75 | output_filename = _get_dataset_filename(
76 | dataset_dir, split_name, shard_id)
77 | with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
78 | start_ndx = shard_id * num_per_shard
79 | end_ndx = min((shard_id+1) * num_per_shard, len(image_filenames))
80 | for i in range(start_ndx, end_ndx):
81 | sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
82 | i + 1, len(image_filenames), shard_id))
83 | sys.stdout.flush()
84 | # read the image
85 | img_filename = image_filenames[i]
86 | img_data = tf.gfile.FastGFile(img_filename, 'r').read()
87 | img_shape = image_reader.read_image_dims(sess, img_data)
88 | example = dataset_utils.image_to_tfexample(
89 | img_data, img_filename[-3:], img_shape, img_filename)
90 | tfrecord_writer.write(example.SerializeToString())
91 | sys.stdout.write('\n')
92 | sys.stdout.flush()
93 |
94 |
95 | def _dataset_exists(dataset_dir):
96 | for split_name in ['train', 'validation', 'test']:
97 | for shard_id in range(_NUM_SHARDS):
98 | output_filename = _get_dataset_filename(
99 | dataset_dir, split_name, shard_id)
100 | if not tf.gfile.Exists(output_filename):
101 | return False
102 | return True
103 |
104 |
105 | def run(input_dataset_dir, output_dataset_dir):
106 | if not tf.gfile.Exists(output_dataset_dir):
107 | tf.gfile.MakeDirs(output_dataset_dir)
108 |
109 | if _dataset_exists(output_dataset_dir):
110 | print('Dataset files already exist. Exiting without re-creating them.')
111 | return
112 |
113 | train_image_filenames, validation_image_filenames, test_image_filenames = \
114 | _get_filenames(input_dataset_dir)
115 |
116 | # randomize the datasets
117 | random.seed(_RANDOM_SEED)
118 | random.shuffle(train_image_filenames)
119 | random.shuffle(validation_image_filenames)
120 | random.shuffle(test_image_filenames)
121 |
122 | num_train = len(train_image_filenames)
123 | num_validation = len(validation_image_filenames)
124 | num_test = len(test_image_filenames)
125 | num_samples = num_train + num_validation + num_test
126 |
127 | # store the dataset meta data
128 | dataset_meta_data = {
129 | 'dataset_name': 'MSCOCO',
130 | 'source_dataset_dir': input_dataset_dir,
131 | 'num_of_samples': num_samples,
132 | 'num_of_train': num_train,
133 | 'num_of_validation': num_validation,
134 | 'num_of_test': num_test,
135 | 'train_image_filenames': train_image_filenames,
136 | 'validation_image_filenames': validation_image_filenames,
137 | 'test_image_filenames': test_image_filenames}
138 | dataset_utils.write_dataset_meta_data(output_dataset_dir, dataset_meta_data)
139 |
140 | _convert_dataset('train', train_image_filenames, output_dataset_dir)
141 | _convert_dataset('validation', validation_image_filenames, output_dataset_dir)
142 | _convert_dataset('test', test_image_filenames, output_dataset_dir)
143 |
144 |
145 | def main(_):
146 | if not FLAGS.input_dataset_dir:
147 | raise ValueError('You must supply the input dataset directory with --input_dataset_dir')
148 | run(FLAGS.input_dataset_dir, FLAGS.output_dataset_dir)
149 |
150 |
151 | if __name__ == '__main__':
152 | tf.app.run()
153 |
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import yaml
7 |
8 | import tensorflow as tf
9 |
10 | slim = tf.contrib.slim
11 |
12 | _META_DATA_FILENAME = 'dataset_meta_data.txt'
13 |
14 | _FILE_PATTERN = '%s_%s_*.tfrecord'
15 |
16 | _ITEMS_TO_DESCRIPTIONS = {
17 | 'image': 'A color image of varying size.',
18 | 'shape': 'The shape of the image.'
19 | }
20 |
21 |
22 | def int64_feature(values):
23 | if not isinstance(values, (tuple, list)):
24 | values = [values]
25 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
26 |
27 |
28 | def bytes_feature(values):
29 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
30 |
31 |
32 | def image_to_tfexample(image_data, image_format, image_shape, image_filename):
33 | return tf.train.Example(features=tf.train.Features(feature={
34 | 'image/encoded': bytes_feature(image_data),
35 | 'image/format': bytes_feature(image_format),
36 | 'image/shape': int64_feature(image_shape),
37 | 'image/filename': bytes_feature(image_filename),
38 | }))
39 |
40 |
41 | def write_dataset_meta_data(dataset_dir, dataset_meta_data,
42 | filename=_META_DATA_FILENAME):
43 | meta_filename = os.path.join(dataset_dir, filename)
44 | with open(meta_filename, 'wb') as f:
45 | yaml.dump(dataset_meta_data, f)
46 | print('Finish writing the dataset meta data.')
47 |
48 |
49 | def has_dataset_meta_data_file(dataset_dir, filename=_META_DATA_FILENAME):
50 | return tf.gfile.Exists(os.path.join(dataset_dir, filename))
51 |
52 |
53 | def read_dataset_meta_data(dataset_dir, filename=_META_DATA_FILENAME):
54 | meta_filename = os.path.join(dataset_dir, filename)
55 | with open(meta_filename, 'rb') as f:
56 | dataset_meta_data = yaml.load(f)
57 | print('Finish loading the dataset meta data of [%s].' %
58 | dataset_meta_data.get('dataset_name'))
59 | return dataset_meta_data
60 |
61 |
62 | def get_split(dataset_name,
63 | split_name,
64 | dataset_dir,
65 | file_pattern=None,
66 | reader=None):
67 | if split_name not in ['train', 'validation']:
68 | raise ValueError('split name %s was not recognized.' % split_name)
69 |
70 | if not file_pattern:
71 | file_pattern = _FILE_PATTERN
72 | file_pattern = os.path.join(dataset_dir, file_pattern % (
73 | dataset_name, split_name))
74 |
75 | # read the dataset meta data
76 | if has_dataset_meta_data_file(dataset_dir):
77 | dataset_meta_data = read_dataset_meta_data(dataset_dir)
78 | num_samples = dataset_meta_data.get('num_of_' + split_name)
79 | else:
80 | raise ValueError('No dataset_meta_data file available in %s' % dataset_dir)
81 |
82 | if reader is None:
83 | reader = tf.TFRecordReader
84 |
85 | keys_to_features = {
86 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
87 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
88 | 'image/shape': tf.FixedLenFeature((3,), tf.int64, default_value=(224, 224, 3)),
89 | 'image/filename': tf.FixedLenFeature([], tf.string, default_value=''),
90 | }
91 |
92 | items_to_handlers = {
93 | 'image': slim.tfexample_decoder.Image(
94 | 'image/encoded', 'image/format'),
95 | 'shape': slim.tfexample_decoder.Tensor('image/shape'),
96 | 'filename': slim.tfexample_decoder.Tensor('image/filename')
97 | }
98 |
99 | decoder = slim.tfexample_decoder.TFExampleDecoder(
100 | keys_to_features, items_to_handlers)
101 |
102 | return slim.dataset.Dataset(
103 | data_sources=file_pattern,
104 | reader=reader,
105 | decoder=decoder,
106 | num_samples=num_samples,
107 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS)
108 |
109 |
110 | class ImageReader(object):
111 | """helper class that provides tensorflow image coding utilities."""
112 | def __init__(self):
113 | self._decode_data = tf.placeholder(dtype=tf.string)
114 | self._decode_image = tf.image.decode_image(self._decode_data, channels=0)
115 |
116 | def read_image_dims(self, sess, image_data):
117 | image = self.decode_image(sess, image_data)
118 | return image.shape
119 |
120 | def decode_image(self, sess, image_data):
121 | image = sess.run(self._decode_image,
122 | feed_dict={self._decode_data: image_data})
123 | assert len(image.shape) == 3
124 | assert image.shape[2] == 3
125 | return image
126 |
127 |
128 | class ImageCoder(object):
129 | """helper class that provides Tensorflow Image coding utilities,
130 | also works for corrupted data with incorrected extension
131 | """
132 | def __init__(self):
133 | self._decode_data = tf.placeholder(dtype=tf.string)
134 | self._decode_image = tf.image.decode_image(self._decode_data, channels=0)
135 | self._encode_jpeg = tf.image.encode_jpeg(self._decode_image, format='rgb', quality=100)
136 |
137 | def decode_image(self, sess, image_data):
138 | # verify the image from the image_data
139 | status = False
140 | try:
141 | # decode image and verify the data
142 | image = sess.run(self._decode_image,
143 | feed_dict={self._decode_data: image_data})
144 | image_shape = image.shape
145 | assert len(image_shape) == 3
146 | assert image_shape[2] == 3
147 | # encode as RGB JPEG image string and return
148 | image_string = sess.run(self._encode_jpeg, feed_dict={self._decode_data: image_data})
149 | status = True
150 | except BaseException:
151 | image_shape, image_string = None, None
152 | return status, image_string, image_shape
153 |
--------------------------------------------------------------------------------
/datasets/dataset_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/datasets/dataset_utils.pyc
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
2 | title: Avatar-Net
3 | description: Multi-scale Zero-shot Style Transfer by Feature Decoration
4 | show_downloads: true
5 |
--------------------------------------------------------------------------------
/docs/figures/closed_ups.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/closed_ups.png
--------------------------------------------------------------------------------
/docs/figures/image_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/image_results.png
--------------------------------------------------------------------------------
/docs/figures/network_architecture_with_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/network_architecture_with_comparison.png
--------------------------------------------------------------------------------
/docs/figures/result_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/result_comparison.png
--------------------------------------------------------------------------------
/docs/figures/snapshot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/snapshot.png
--------------------------------------------------------------------------------
/docs/figures/style_decorator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/style_decorator.png
--------------------------------------------------------------------------------
/docs/figures/style_interpolation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/style_interpolation.png
--------------------------------------------------------------------------------
/docs/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/teaser.png
--------------------------------------------------------------------------------
/docs/figures/trade_off.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/trade_off.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | [teaser]: ./figures/teaser.png
2 | ![teaser]
3 |
4 | Exemplar stylized results by the proposed Avatar-Net, which faithfully transfers arbitrary styles onto the _lena_ image.
5 |
6 |
7 | ## Overview
8 |
9 | Zero-shot artistic style transfer is an important image synthesis problem that aims at transferring arbitrary styles onto content images. However, the trade-off between generalization and efficiency in existing methods impedes high-quality zero-shot style transfer in real time. In this repository, we resolve this dilemma and propose an efficient yet effective Avatar-Net that enables visually plausible multi-scale transfer for arbitrary styles.
10 |
11 | The key ingredient of our method is a __style decorator__ that makes up the content features with semantically aligned style features from an arbitrary style image, which not only holistically matches their feature distributions but also preserves detailed style patterns in the decorated features.
12 |
13 | [style_decorator]: ./figures/style_decorator.png
14 | ![style_decorator]
15 |
18 |
19 | By embedding this module into a reconstruction network that fuses multi-scale style abstractions, the Avatar-Net renders multi-scale stylization for any style image in one feed-forward pass.
20 |
21 | [network]: ./figures/network_architecture_with_comparison.png
22 | ![network]
23 |
24 | (a) Stylization comparison between an autoencoder and the style-augmented hourglass network. (b) The network architecture of the proposed method.
25 |
Exemplar stylized results by the proposed Avatar-Net.
32 |
33 | We demonstrate the state-of-the-art effectiveness and efficiency of the proposed method in generating high-quality stylized images, with a series of successful applications including multiple style integration, video stylization, etc.
34 |
35 | #### Comparison with Prior Arts
36 |
37 |
38 |
39 | - The result by Avatar-Net preserves concrete multi-scale style patterns (e.g. the color distribution, brush strokes and circular patterns in the style image).
40 | - WCT distorts the brush strokes and circular patterns, AdaIN cannot even keep the color distribution, and Style-Swap fails entirely in this example.
41 |
42 | #### Execution Efficiency
43 |
44 | | Method | Gatys et al. | AdaIN | WCT | Style-Swap | __Avatar-Net__ |
45 | | :---: | :---: | :---: | :---: | :---: | :---: |
46 | | __256x256 (sec)__ | 12.18 | 0.053 | 0.62 | 0.064 | __0.071__ |
47 | | __512x512 (sec)__ | 43.25 | 0.11 | 0.93 | 0.23 | __0.28__ |
74 |
75 | - Avatar-Net has an execution time comparable to AdaIN and the GPU-accelerated Style-Swap, and is much faster than WCT and the optimization-based style transfer by Gatys _et al._.
76 | - The reference methods and the proposed Avatar-Net are implemented on the same TensorFlow platform with the same VGG network as the backbone.
77 |
78 | ### Applications
79 | #### Multi-style Interpolation
80 | [style_interpolation]: ./figures/style_interpolation.png
81 | ![style_interpolation]
82 |
83 | #### Content and Style Trade-off
84 | [trade_off]: ./figures/trade_off.png
85 | ![trade_off]
86 |
87 | #### Video Stylization ([YouTube link](https://youtu.be/amaeqbw6TeA))
88 |
89 |
90 |
91 |
92 |
93 | ## Code
94 |
95 | Please refer to the [GitHub repository](https://github.com/LucasSheng/avatar-net) for more details.
96 |
97 | ## Publication
98 |
99 | Lu Sheng, Ziyi Lin, Jing Shao and Xiaogang Wang, "Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration", in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018. [[Arxiv](https://arxiv.org/abs/1805.03857)]
100 |
101 | ```
102 | @inproceedings{sheng2018avatar,
103 |   title = {Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration},
104 |   author = {Sheng, Lu and Lin, Ziyi and Shao, Jing and Wang, Xiaogang},
105 |   booktitle = {Computer Vision and Pattern Recognition (CVPR), 2018 IEEE Conference on},
106 |   pages = {1--9},
107 |   year = {2018}
108 | }
109 | ```
110 |
--------------------------------------------------------------------------------
/evaluate_style_transfer.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import time
7 | import scipy.misc
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | from models import models_factory
12 | from models import preprocessing
13 |
14 | from PIL import Image
15 |
16 | slim = tf.contrib.slim
17 |
18 |
19 | tf.app.flags.DEFINE_string(
20 | 'checkpoint_dir', 'tmp/tfmodel',
21 | 'The directory where the model was written to or an absolute path to a '
22 | 'checkpoint file.')
23 | tf.app.flags.DEFINE_string(
24 | 'eval_dir', 'tmp/tfmodel',
25 | 'Directory where the results are saved to.')
26 | tf.app.flags.DEFINE_string(
27 | 'content_dataset_dir', None,
28 | 'The content directory where the test images are stored.')
29 | tf.app.flags.DEFINE_string(
30 | 'style_dataset_dir', None,
31 | 'The style directory where the style images are stored.')
32 |
33 | # choose the model configuration file
34 | tf.app.flags.DEFINE_string(
35 | 'model_config_path', None,
36 | 'The path of the model configuration file.')
37 | tf.app.flags.DEFINE_float(
38 | 'inter_weight', 1.0,
39 | 'The blending weight of the style patterns in the stylized image')
40 |
41 | FLAGS = tf.app.flags.FLAGS
42 |
43 |
44 | def get_image_filenames(dataset_dir):
45 | """helper fn that provides the full image filenames from the dataset_dir"""
46 | image_filenames = []
47 | for filename in os.listdir(dataset_dir):
48 | file_path = os.path.join(dataset_dir, filename)
49 | image_filenames.append(file_path)
50 | return image_filenames
51 |
52 |
53 | def image_reader(filename):
54 | """help fn that provides numpy image coding utilities"""
55 | img = scipy.misc.imread(filename).astype(np.float)
56 | if len(img.shape) == 2:
57 | img = np.dstack((img, img, img))
58 | elif img.shape[2] == 4:
59 | img = img[:, :, :3]
60 | return img
61 |
62 |
63 | def imsave(filename, img):
64 | img = np.clip(img, 0, 255).astype(np.uint8)
65 | Image.fromarray(img).save(filename, quality=95)
66 |
67 |
68 | def main(_):
69 | if not FLAGS.content_dataset_dir:
70 | raise ValueError('You must supply the content dataset directory '
71 | 'with --content_dataset_dir')
72 | if not FLAGS.style_dataset_dir:
73 | raise ValueError('You must supply the style dataset directory '
74 | 'with --style_dataset_dir')
75 |
76 | if not FLAGS.checkpoint_dir:
77 | raise ValueError('You must supply the checkpoints directory with '
78 | '--checkpoint_dir')
79 |
80 | if tf.gfile.IsDirectory(FLAGS.checkpoint_dir):
81 | checkpoint_dir = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
82 | else:
83 | checkpoint_dir = FLAGS.checkpoint_dir
84 |
85 | if not tf.gfile.Exists(FLAGS.eval_dir):
86 | tf.gfile.MakeDirs(FLAGS.eval_dir)
87 |
88 | tf.logging.set_verbosity(tf.logging.INFO)
89 | with tf.Graph().as_default():
90 | # define the model
91 | style_model, options = models_factory.get_model(FLAGS.model_config_path)
92 |
93 | # predict the stylized image
94 | inp_content_image = tf.placeholder(tf.float32, shape=(None, None, 3))
95 | inp_style_image = tf.placeholder(tf.float32, shape=(None, None, 3))
96 |
97 | # preprocess the content and style images
98 | content_image = preprocessing.mean_image_subtraction(inp_content_image)
99 | content_image = tf.expand_dims(content_image, axis=0)
100 | # style resizing and cropping
101 | style_image = preprocessing.preprocessing_image(
102 | inp_style_image,
103 | 448,
104 | 448,
105 | style_model.style_size)
106 | style_image = tf.expand_dims(style_image, axis=0)
107 |
108 | # style transfer
109 | stylized_image = style_model.transfer_styles(
110 | content_image,
111 | style_image,
112 | inter_weight=FLAGS.inter_weight)
113 | stylized_image = tf.squeeze(stylized_image, axis=0)
114 |
115 | # gather the test image filenames and style image filenames
116 | style_image_filenames = get_image_filenames(FLAGS.style_dataset_dir)
117 | content_image_filenames = get_image_filenames(FLAGS.content_dataset_dir)
118 |
119 | # starting inference of the images
120 | init_fn = slim.assign_from_checkpoint_fn(
121 | checkpoint_dir, slim.get_model_variables(), ignore_missing_vars=True)
122 | with tf.Session() as sess:
123 | # initialize the graph
124 | init_fn(sess)
125 |
126 | nn = 0.0
127 | total_time = 0.0
128 | # style transfer for each image based on one style image
129 | for i in range(len(style_image_filenames)):
130 | # gather the storage folder for the style transfer
131 | style_label = style_image_filenames[i].split('/')[-1]
132 | style_label = style_label.split('.')[0]
133 | style_dir = os.path.join(FLAGS.eval_dir, style_label)
134 |
135 | if not tf.gfile.Exists(style_dir):
136 | tf.gfile.MakeDirs(style_dir)
137 |
138 | # get the style image
139 | np_style_image = image_reader(style_image_filenames[i])
140 | print('Starting to transfer the style of [%s]' % style_label)
141 |
142 | for j in range(len(content_image_filenames)):
143 | # gather the content image
144 | np_content_image = image_reader(content_image_filenames[j])
145 |
146 | start_time = time.time()
147 | np_stylized_image = sess.run(stylized_image,
148 | feed_dict={inp_content_image: np_content_image,
149 | inp_style_image: np_style_image})
150 | incre_time = time.time() - start_time
151 | nn += 1.0
152 | total_time += incre_time
153 | print("---%s seconds ---" % (total_time/nn))
154 |
155 | output_filename = os.path.join(
156 | style_dir, content_image_filenames[j].split('/')[-1])
157 | imsave(output_filename, np_stylized_image)
158 | print('Style [%s]: Finished transferring the image [%s]' % (
159 | style_label, content_image_filenames[j]))
160 |
161 |
162 | if __name__ == '__main__':
163 | tf.app.run()
164 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/__init__.py
--------------------------------------------------------------------------------
/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/__init__.pyc
--------------------------------------------------------------------------------
/models/autoencoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import tensorflow as tf
6 |
7 | from models import losses
8 | from models import preprocessing
9 | from models import vgg
10 | from models import vgg_decoder
11 |
12 | slim = tf.contrib.slim
13 |
14 | network_map = {
15 | 'vgg_16': vgg.vgg_16,
16 | 'vgg_19': vgg.vgg_19,
17 | }
18 |
19 |
20 | class AutoEncoder(object):
21 | def __init__(self, options):
22 | self.weight_decay = options.get('weight_decay')
23 |
24 | self.default_size = options.get('default_size')
25 | self.content_size = options.get('content_size')
26 |
27 | # network architecture
28 | self.network_name = options.get('network_name')
29 |
30 | # the loss layers for content and style similarity
31 | self.content_layers = options.get('content_layers')
32 |
33 | # the weights for the losses when trains the invertible network
34 | self.content_weight = options.get('content_weight')
35 | self.recons_weight = options.get('recons_weight')
36 | self.tv_weight = options.get('tv_weight')
37 |
38 | # gather the summaries and initialize the losses
39 | self.summaries = None
40 | self.total_loss = 0.0
41 | self.recons_loss = {}
42 | self.content_loss = {}
43 | self.tv_loss = {}
44 | self.train_op = None
45 |
46 | def auto_encoder(self, inputs, content_layer=2, reuse=True):
47 | # extract the content features
48 | image_features = losses.extract_image_features(inputs, self.network_name)
49 | content_features = losses.compute_content_features(image_features, self.content_layers)
50 |
51 | # used content feature
52 | selected_layer = self.content_layers[content_layer]
53 | content_feature = content_features[selected_layer]
54 | input_content_features = {selected_layer: content_feature}
55 |
56 | # reconstruct the images
57 | with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
58 | outputs = vgg_decoder.vgg_decoder(
59 | content_feature,
60 | self.network_name,
61 | selected_layer,
62 | reuse=reuse,
63 | scope='decoder_%d' % content_layer)
64 | return outputs, input_content_features
65 |
66 | def build_train_graph(self, inputs):
67 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
68 | for i in range(len(self.content_layers)):
69 | # skip some networks
70 | if i < 3:
71 | continue
72 |
73 | selected_layer = self.content_layers[i]
74 |
75 | outputs, inputs_content_features = self.auto_encoder(
76 | inputs, content_layer=i, reuse=False)
77 | outputs = preprocessing.batch_mean_image_subtraction(outputs)
78 |
79 | ########################
80 | # construct the losses #
81 | ########################
82 | # 1) reconstruction loss
83 | recons_loss = tf.losses.mean_squared_error(
84 | inputs, outputs, scope='recons_loss/decoder_%d' % i)
85 | self.recons_loss[selected_layer] = recons_loss
86 | self.total_loss += self.recons_weight * recons_loss
87 | summaries.add(tf.summary.scalar(
88 | 'recons_loss/decoder_%d' % i, recons_loss))
89 |
90 | # 2) content loss
91 | outputs_image_features = losses.extract_image_features(
92 | outputs, self.network_name)
93 | outputs_content_features = losses.compute_content_features(
94 | outputs_image_features, [selected_layer])
95 | content_loss = losses.compute_content_loss(
96 | outputs_content_features, inputs_content_features, [selected_layer])
97 | self.content_loss[selected_layer] = content_loss
98 | self.total_loss += self.content_weight * content_loss
99 | summaries.add(tf.summary.scalar(
100 | 'content_loss/decoder_%d' % i, content_loss))
101 |
102 | # 3) total variation loss
103 | tv_loss = losses.compute_total_variation_loss_l1(outputs)
104 | self.tv_loss[selected_layer] = tv_loss
105 | self.total_loss += self.tv_weight * tv_loss
106 | summaries.add(tf.summary.scalar(
107 | 'tv_loss/decoder_%d' % i, tv_loss))
108 |
109 | image_tiles = tf.concat([inputs, outputs], axis=2)
110 | image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
111 | image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0), tf.uint8)
112 | summaries.add(tf.summary.image(
113 | 'image_comparison/decoder_%d' % i, image_tiles, max_outputs=8))
114 |
115 | self.summaries = summaries
116 | return self.total_loss
117 |
118 | def get_training_operations(self, optimizer, global_step,
119 | variables_to_train=tf.trainable_variables()):
120 | # gather the variable summaries
121 | variables_summaries = []
122 | for var in variables_to_train:
123 | variables_summaries.append(tf.summary.histogram(var.op.name, var))
124 | variables_summaries = set(variables_summaries)
125 |
126 | # add the training operations
127 | train_ops = []
128 |
129 | grads_and_vars = optimizer.compute_gradients(
130 | self.total_loss, var_list=variables_to_train)
131 | train_op = optimizer.apply_gradients(
132 | grads_and_vars, global_step=global_step)
133 | train_ops.append(train_op)
134 |
135 | self.summaries |= variables_summaries
136 | self.train_op = tf.group(*train_ops)
137 | return self.train_op
138 |
--------------------------------------------------------------------------------
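
For reference, the loss composition used by `build_train_graph` above can be written out as a minimal standalone sketch (TF 1.x). The tensors below are toy stand-ins for the real images, decoder outputs and VGG features, and the weight values are illustrative assumptions rather than the released configuration.

    import tensorflow as tf

    # illustrative weights; the real values come from the YAML configuration
    recons_weight, content_weight, tv_weight = 10.0, 1.0, 1e-4

    inputs = tf.random_normal([2, 64, 64, 3])          # stand-in for the input batch
    outputs = tf.random_normal([2, 64, 64, 3])         # stand-in for the decoder output
    inputs_feat = tf.random_normal([2, 16, 16, 256])   # stand-in for a VGG content feature
    outputs_feat = tf.random_normal([2, 16, 16, 256])

    # 1) pixel reconstruction loss
    recons_loss = tf.losses.mean_squared_error(inputs, outputs)
    # 2) content (perceptual) loss in feature space
    content_loss = tf.losses.mean_squared_error(inputs_feat, outputs_feat)
    # 3) L1 total variation loss on the reconstructed image
    tv_loss = (tf.reduce_mean(tf.abs(outputs[:, 1:] - outputs[:, :-1])) +
               tf.reduce_mean(tf.abs(outputs[:, :, 1:] - outputs[:, :, :-1])))

    total_loss = (recons_weight * recons_loss +
                  content_weight * content_loss +
                  tv_weight * tv_loss)
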
/models/avatar_net.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import tensorflow as tf
6 |
7 | from models import losses
8 | from models import network_ops
9 | from models import vgg
10 | from models import vgg_decoder
11 | from models import preprocessing
12 |
13 | slim = tf.contrib.slim
14 |
15 | network_map = {
16 | 'vgg_16': vgg.vgg_16,
17 | 'vgg_19': vgg.vgg_19,
18 | }
19 |
20 |
21 | class AvatarNet(object):
22 | def __init__(self, options):
23 | self.training_image_size = options.get('training_image_size')
24 | self.content_size = options.get('content_size')
25 | self.style_size = options.get('style_size')
26 |
27 | # network architecture
28 | self.network_name = options.get('network_name')
29 |
30 | # the loss layers for content and style similarity
31 | self.style_loss_layers = options.get('style_loss_layers')
32 |
33 | ##########################
34 | # style decorator option #
35 | ##########################
36 | # style coding method
37 | self.style_coding = options.get('style_coding')
38 |
39 | # style interpolation method
40 | self.style_interp = options.get('style_interp')
41 |
42 | # window size
43 | self.patch_size = options.get('patch_size')
44 |
45 | #######################
46 | # training quantities #
47 | #######################
48 | self.content_weight = options.get('content_weight')
49 | self.recons_weight = options.get('recons_weight')
50 | self.tv_weight = options.get('tv_weight')
51 | self.weight_decay = options.get('weight_decay')
52 |
53 | ##############################################
54 | # gather summaries and initialize the losses #
55 | ##############################################
56 | self.total_loss = 0.0
57 | self.recons_loss = None
58 | self.content_loss = None
59 | self.tv_loss = None
60 |
61 | ############################
62 | # summary and training ops #
63 | ############################
64 | self.train_op = None
65 | self.summaries = None
66 |
67 | def transfer_styles(self,
68 | inputs,
69 | styles,
70 | inter_weight=1.0,
71 | intra_weights=(1,)):
 72 |         """stylize the content images with the given style images
73 |
74 | Args:
75 | inputs: input images [batch_size, height, width, channel]
 76 |             styles: a list of style images, by default a single style
77 | inter_weight: the blending weight between the content and style
78 | intra_weights: a list of blending weights among the styles,
 79 |                 defaults to (1,)
80 |
81 | Returns:
82 | outputs: the stylized images [batch_size, height, width, channel]
83 | """
84 | if not isinstance(styles, (list, tuple)):
85 | styles = [styles]
86 |
87 | if not isinstance(intra_weights, (list, tuple)):
88 | intra_weights = [intra_weights]
89 |
90 | # 1) extract the style features
91 | styles_features = []
92 | for style in styles:
93 | style_image_features = losses.extract_image_features(
94 | style, self.network_name)
95 | style_features = losses.compute_content_features(
96 | style_image_features, self.style_loss_layers)
97 | styles_features.append(style_features)
98 |
99 | # 2) content features
100 | inputs_image_features = losses.extract_image_features(
101 | inputs, self.network_name)
102 | inputs_features = losses.compute_content_features(
103 | inputs_image_features, self.style_loss_layers)
104 |
105 | # 3) style decorator
 106 |         # the content feature (from the deepest style-loss layer) fed to the style decorator
107 | selected_layer = self.style_loss_layers[-1]
108 | hidden_feature = inputs_features[selected_layer]
109 |
110 | # applying the style decorator
111 | blended_feature = 0.0
112 | n = 0
113 | for style_features in styles_features:
114 | swapped_feature = style_decorator(
115 | hidden_feature,
116 | style_features[selected_layer],
117 | style_coding=self.style_coding,
118 | style_interp=self.style_interp,
119 | ratio_interp=inter_weight,
120 | patch_size=self.patch_size)
121 | blended_feature += intra_weights[n] * swapped_feature
122 | n += 1
123 |
124 | # 4) decode the hidden feature to the output image
125 | with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
126 | outputs = vgg_decoder.vgg_multiple_combined_decoder(
127 | blended_feature,
128 | styles_features,
129 | intra_weights,
130 | fusion_fn=network_ops.adaptive_instance_normalization,
131 | network_name=self.network_name,
132 | starting_layer=selected_layer)
133 | return outputs
134 |
135 | def hierarchical_autoencoder(self, inputs, reuse=True):
136 | """hierarchical autoencoder for content reconstruction"""
137 | # extract the content features
138 | image_features = losses.extract_image_features(
139 | inputs, self.network_name)
140 | content_features = losses.compute_content_features(
141 | image_features, self.style_loss_layers)
142 |
 143 |         # the content feature fed to the decoder network
144 | selected_layer = self.style_loss_layers[-1]
145 | hidden_feature = content_features[selected_layer]
146 |
147 | # decode the hidden feature to the output image
148 | with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)):
149 | outputs = vgg_decoder.vgg_combined_decoder(
150 | hidden_feature,
151 | content_features,
152 | fusion_fn=network_ops.adaptive_instance_normalization,
153 | network_name=self.network_name,
154 | starting_layer=selected_layer,
155 | reuse=reuse)
156 | return outputs
157 |
158 | def build_train_graph(self, inputs):
 159 |         """build the training graph for the hierarchical autoencoder"""
160 | outputs = self.hierarchical_autoencoder(inputs, reuse=False)
161 | outputs = preprocessing.batch_mean_image_subtraction(outputs)
162 |
163 | # summaries
164 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
165 |
166 | ########################
167 | # construct the losses #
168 | ########################
169 | # 1) reconstruction loss
170 | if self.recons_weight > 0.0:
171 | recons_loss = tf.losses.mean_squared_error(
172 | inputs, outputs, weights=self.recons_weight, scope='recons_loss')
173 | self.recons_loss = recons_loss
174 | self.total_loss += recons_loss
175 | summaries.add(tf.summary.scalar('losses/recons_loss', recons_loss))
176 |
177 | # 2) content loss
178 | if self.content_weight > 0.0:
179 | outputs_image_features = losses.extract_image_features(
180 | outputs, self.network_name)
181 | outputs_content_features = losses.compute_content_features(
182 | outputs_image_features, self.style_loss_layers)
183 |
184 | inputs_image_features = losses.extract_image_features(
185 | inputs, self.network_name)
186 | inputs_content_features = losses.compute_content_features(
187 | inputs_image_features, self.style_loss_layers)
188 |
189 | content_loss = losses.compute_content_loss(
190 | outputs_content_features, inputs_content_features,
191 | content_loss_layers=self.style_loss_layers, weights=self.content_weight)
192 | self.content_loss = content_loss
193 | self.total_loss += content_loss
194 | summaries.add(tf.summary.scalar('losses/content_loss', content_loss))
195 |
196 | # 3) total variation loss
197 | if self.tv_weight > 0.0:
198 | tv_loss = losses.compute_total_variation_loss_l1(outputs, self.tv_weight)
199 | self.tv_loss = tv_loss
200 | self.total_loss += tv_loss
201 | summaries.add(tf.summary.scalar('losses/tv_loss', tv_loss))
202 |
203 | image_tiles = tf.concat([inputs, outputs], axis=2)
204 | image_tiles = preprocessing.batch_mean_image_summation(image_tiles)
205 | image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0), tf.uint8)
206 | summaries.add(tf.summary.image('image_comparison', image_tiles, max_outputs=8))
207 |
208 | self.summaries = summaries
209 | return self.total_loss
210 |
211 | def get_training_operations(self,
212 | optimizer,
213 | global_step,
214 | variables_to_train=tf.trainable_variables()):
215 | # gather the variable summaries
216 | variables_summaries = []
217 | for var in variables_to_train:
218 | variables_summaries.append(tf.summary.histogram(var.op.name, var))
219 | variables_summaries = set(variables_summaries)
220 |
221 | # add the training operations
222 | train_ops = []
223 | grads_and_vars = optimizer.compute_gradients(
224 | self.total_loss, var_list=variables_to_train)
225 | train_op = optimizer.apply_gradients(
226 | grads_and_vars=grads_and_vars,
227 | global_step=global_step)
228 | train_ops.append(train_op)
229 |
230 | self.summaries |= variables_summaries
231 | self.train_op = tf.group(*train_ops)
232 | return self.train_op
233 |
234 |
235 | def style_decorator(content_features,
236 | style_features,
237 | style_coding='ZCA',
238 | style_interp='normalized',
239 | ratio_interp=1.0,
240 | patch_size=3):
241 | """style decorator for high-level feature interaction
242 |
243 | Args:
244 | content_features: a tensor of size [batch_size, height, width, channel]
245 | style_features: a tensor of size [batch_size, height, width, channel]
246 | style_coding: projection and reconstruction method for style coding
247 | style_interp: interpolation option
248 | ratio_interp: interpolation ratio
 249 |         patch_size: a 0-D tensor or int giving the patch size
250 | """
251 | # feature projection
252 | projected_content_features, _, _ = \
253 | project_features(content_features, projection_module=style_coding)
254 | projected_style_features, style_kernels, mean_style_features = \
255 | project_features(style_features, projection_module=style_coding)
256 |
257 | # feature rearrangement
258 | rearranged_features = nearest_patch_swapping(
259 | projected_content_features, projected_style_features, patch_size=patch_size)
260 | if style_interp == 'normalized':
261 | rearranged_features = ratio_interp * rearranged_features + \
262 | (1 - ratio_interp) * projected_content_features
263 |
264 | # feature reconstruction
265 | reconstructed_features = reconstruct_features(
266 | rearranged_features,
267 | style_kernels,
268 | mean_style_features,
269 | reconstruction_module=style_coding)
270 |
271 | if style_interp == 'biased':
272 | reconstructed_features = ratio_interp * reconstructed_features + \
273 | (1 - ratio_interp) * content_features
274 |
275 | return reconstructed_features
276 |
277 |
278 | def project_features(features, projection_module='ZCA'):
279 | if projection_module == 'ZCA':
280 | return zca_normalization(features)
281 | elif projection_module == 'AdaIN':
282 | return adain_normalization(features)
283 | else:
284 | return features, None, None
285 |
286 |
287 | def reconstruct_features(projected_features,
288 | feature_kernels,
289 | mean_features,
290 | reconstruction_module='ZCA'):
291 | if reconstruction_module == 'ZCA':
292 | return zca_colorization(projected_features, feature_kernels, mean_features)
293 | elif reconstruction_module == 'AdaIN':
294 | return adain_colorization(projected_features, feature_kernels, mean_features)
295 | else:
296 | return projected_features
297 |
298 |
299 | def nearest_patch_swapping(content_features, style_features, patch_size=3):
 300 |     # the content and style features must have the same number of channels
301 | c_shape = tf.shape(content_features)
302 | s_shape = tf.shape(style_features)
303 | channel_assertion = tf.Assert(
304 | tf.equal(c_shape[3], s_shape[3]), ['number of channels must be the same'])
305 |
306 | with tf.control_dependencies([channel_assertion]):
307 | # spatial shapes for style and content features
308 | c_height, c_width, c_channel = c_shape[1], c_shape[2], c_shape[3]
309 |
310 | # convert the style features into convolutional kernels
311 | style_kernels = tf.extract_image_patches(
312 | style_features, ksizes=[1, patch_size, patch_size, 1],
313 | strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME')
314 | style_kernels = tf.squeeze(style_kernels, axis=0)
315 | style_kernels = tf.transpose(style_kernels, perm=[2, 0, 1])
316 |
317 | # gather the conv and deconv kernels
318 | v_height, v_width = style_kernels.get_shape().as_list()[1:3]
319 | deconv_kernels = tf.reshape(
320 | style_kernels, shape=(patch_size, patch_size, c_channel, v_height*v_width))
321 |
322 | kernels_norm = tf.norm(style_kernels, axis=0, keep_dims=True)
323 | kernels_norm = tf.reshape(kernels_norm, shape=(1, 1, 1, v_height*v_width))
324 |
325 | # calculate the normalization factor
326 | mask = tf.ones((c_height, c_width), tf.float32)
327 | fullmask = tf.zeros((c_height+patch_size-1, c_width+patch_size-1), tf.float32)
328 | for x in range(patch_size):
329 | for y in range(patch_size):
330 | paddings = [[x, patch_size-x-1], [y, patch_size-y-1]]
331 | padded_mask = tf.pad(mask, paddings=paddings, mode="CONSTANT")
332 | fullmask += padded_mask
333 | pad_width = int((patch_size-1)/2)
334 | deconv_norm = tf.slice(fullmask, [pad_width, pad_width], [c_height, c_width])
335 | deconv_norm = tf.reshape(deconv_norm, shape=(1, c_height, c_width, 1))
336 |
337 | ########################
338 | # starting convolution #
339 | ########################
340 | # padding operation
341 | pad_total = patch_size - 1
342 | pad_beg = pad_total // 2
343 | pad_end = pad_total - pad_beg
344 | paddings = [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]
345 |
346 | # convolutional operations
347 | net = tf.pad(content_features, paddings=paddings, mode="REFLECT")
348 | net = tf.nn.conv2d(
349 | net,
350 | tf.div(deconv_kernels, kernels_norm+1e-7),
351 | strides=[1, 1, 1, 1],
352 | padding='VALID')
353 | # find the maximum locations
354 | best_match_ids = tf.argmax(net, axis=3)
355 | best_match_ids = tf.cast(
356 | tf.one_hot(best_match_ids, depth=v_height*v_width), dtype=tf.float32)
357 |
358 | # find the patches and warping the output
359 | unnormalized_output = tf.nn.conv2d_transpose(
360 | value=best_match_ids,
361 | filter=deconv_kernels,
362 | output_shape=(c_shape[0], c_height+pad_total, c_width+pad_total, c_channel),
363 | strides=[1, 1, 1, 1],
364 | padding='VALID')
365 | unnormalized_output = tf.slice(unnormalized_output, [0, pad_beg, pad_beg, 0], c_shape)
366 | output = tf.div(unnormalized_output, deconv_norm)
367 | output = tf.reshape(output, shape=c_shape)
368 |
369 | # output the swapped feature maps
370 | return output
371 |
372 |
373 | def zca_normalization(features):
374 | shape = tf.shape(features)
375 |
376 | # reshape the features to orderless feature vectors
377 | mean_features = tf.reduce_mean(features, axis=[1, 2], keep_dims=True)
378 | unbiased_features = tf.reshape(features - mean_features, shape=(shape[0], -1, shape[3]))
379 |
380 | # get the covariance matrix
381 | gram = tf.matmul(unbiased_features, unbiased_features, transpose_a=True)
382 | gram /= tf.reduce_prod(tf.cast(shape[1:3], tf.float32))
383 |
384 | # converting the feature spaces
385 | s, u, v = tf.svd(gram, compute_uv=True)
 386 |     s = tf.expand_dims(s, axis=1)  # shape [batch, 1, channel], broadcasts over the matrix rows
387 |
388 | # get the effective singular values
389 | valid_index = tf.cast(s > 0.00001, dtype=tf.float32)
390 | s_effective = tf.maximum(s, 0.00001)
391 | sqrt_s_effective = tf.sqrt(s_effective) * valid_index
392 | sqrt_inv_s_effective = tf.sqrt(1.0/s_effective) * valid_index
393 |
394 | # colorization functions
395 | colorization_kernel = tf.matmul(tf.multiply(u, sqrt_s_effective), v, transpose_b=True)
396 |
397 | # normalized features
398 | normalized_features = tf.matmul(unbiased_features, u)
399 | normalized_features = tf.multiply(normalized_features, sqrt_inv_s_effective)
400 | normalized_features = tf.matmul(normalized_features, v, transpose_b=True)
401 | normalized_features = tf.reshape(normalized_features, shape=shape)
402 |
403 | return normalized_features, colorization_kernel, mean_features
404 |
405 |
406 | def zca_colorization(normalized_features, colorization_kernel, mean_features):
407 | # broadcasting the tensors for matrix multiplication
408 | shape = tf.shape(normalized_features)
409 | normalized_features = tf.reshape(
410 | normalized_features, shape=(shape[0], -1, shape[3]))
411 | colorized_features = tf.matmul(normalized_features, colorization_kernel)
412 | colorized_features = tf.reshape(colorized_features, shape=shape) + mean_features
413 | return colorized_features
414 |
415 |
416 | def adain_normalization(features):
417 | epsilon = 1e-7
418 | mean_features, colorization_kernels = tf.nn.moments(features, [1, 2], keep_dims=True)
419 | normalized_features = tf.div(
420 | tf.subtract(features, mean_features), tf.sqrt(tf.add(colorization_kernels, epsilon)))
421 | return normalized_features, colorization_kernels, mean_features
422 |
423 |
424 | def adain_colorization(normalized_features, colorization_kernels, mean_features):
425 | return tf.sqrt(colorization_kernels) * normalized_features + mean_features
426 |
--------------------------------------------------------------------------------
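
As an aside, the ZCA projection/reconstruction pair used by `style_decorator` above can be illustrated with a standalone NumPy sketch: whiten the content feature with its own covariance, then re-color it with the style covariance and mean. The feature shapes and epsilon below are illustrative assumptions; the module performs the same computation in-graph with `tf.svd`, with the patch-swapping step in between.

    import numpy as np

    def zca_stats(feat, eps=1e-5):
        """Return the whitened feature, the coloring kernel and the channel mean."""
        c = feat.shape[-1]
        x = feat.reshape(-1, c)
        mean = x.mean(axis=0, keepdims=True)
        x = x - mean
        cov = x.T.dot(x) / x.shape[0]
        u, s, _ = np.linalg.svd(cov)
        whitened = x.dot(u).dot(np.diag(1.0 / np.sqrt(s + eps))).dot(u.T)
        coloring_kernel = u.dot(np.diag(np.sqrt(s + eps))).dot(u.T)
        return whitened, coloring_kernel, mean

    content = np.random.randn(16, 16, 32).astype(np.float32)
    style = np.random.randn(16, 16, 32).astype(np.float32) * 2.0 + 1.0

    content_whitened, _, _ = zca_stats(content)
    _, style_kernel, style_mean = zca_stats(style)

    # the patch-swapping step is skipped here (identity); re-color the whitened
    # content with the style statistics and fold it back into a feature map
    recolored = (content_whitened.dot(style_kernel) + style_mean).reshape(content.shape)
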
/models/losses.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 | from models import vgg
8 |
9 | slim = tf.contrib.slim
10 |
11 | network_map = {
12 | 'vgg_16': vgg.vgg_16,
13 | 'vgg_19': vgg.vgg_19,
14 | }
15 |
16 |
17 | def compute_gram_matrix(feature):
18 | """compute the gram matrix for a layer of feature
19 |
20 | the gram matrix is normalized with respect to the samples and
21 | the dimensions of the input features
22 |
23 | """
24 | shape = tf.shape(feature)
25 | feature_size = tf.reduce_prod(shape[1:])
26 | vectorized_feature = tf.reshape(
27 | feature, [shape[0], -1, shape[3]])
28 | gram_matrix = tf.matmul(
29 | vectorized_feature, vectorized_feature, transpose_a=True)
30 | gram_matrix /= tf.to_float(feature_size)
31 | return gram_matrix
32 |
33 |
34 | def compute_sufficient_statistics(feature):
 35 |     """compute the per-channel mean and standard deviation of a layer of features"""
36 | mean_feature, var_feature = tf.nn.moments(feature, [1, 2], keep_dims=True)
37 | std_feature = tf.sqrt(var_feature)
38 | sufficient_statistics = tf.concat([mean_feature, std_feature], axis=3)
39 | return sufficient_statistics
40 |
41 |
42 | def compute_content_features(features, content_loss_layers):
43 | """compute the content features from the end_point dict"""
44 | content_features = {}
45 | instance_label = features.keys()[0]
46 | instance_label = instance_label[:-14] # TODO: ugly code, need fix
47 | for layer in content_loss_layers:
48 | content_features[layer] = features[instance_label + '/' + layer]
49 | return content_features
50 |
51 |
52 | def compute_style_features(features, style_loss_layers):
53 | """compute the style features from the end_point dict"""
54 | style_features = {}
55 | instance_label = features.keys()[0]
56 | instance_label = instance_label[:-14] # TODO: ugly code, need fix
57 | for layer in style_loss_layers:
58 | style_features[layer] = compute_gram_matrix(
59 | features[instance_label + '/' + layer])
60 | return style_features
61 |
62 |
63 | def compute_approximate_style_features(features, style_loss_layers):
64 | style_features = {}
 65 |     instance_label = '/'.join(features.keys()[0].split('/')[:-2])  # keep the scope prefix as a string
66 | for layer in style_loss_layers:
67 | style_features[layer] = compute_sufficient_statistics(
68 | features[instance_label + '/' + layer])
69 | return style_features
70 |
71 |
72 | def extract_image_features(inputs, network_name, reuse=True):
 73 |     """compute the dict of layer-wise image features from the given backbone network
74 |
75 | Args:
 76 |         inputs: the input images, mean-subtracted so the values lie roughly in [-127.5, 127.5]
77 | network_name: the network name for the perceptual loss
78 | reuse: whether to reuse the parameters
79 |
80 | Returns:
81 | end_points: a dict for the image features of the inputs
82 | """
83 | with slim.arg_scope(vgg.vgg_arg_scope()):
84 | _, end_points = network_map[network_name](
85 | inputs, spatial_squeeze=False, is_training=False, reuse=reuse)
86 | return end_points
87 |
88 |
89 | def compute_content_and_style_features(inputs,
90 | network_name,
91 | content_loss_layers,
92 | style_loss_layers):
93 | """compute the content and style features from normalized image
94 |
95 | Args:
96 | inputs: input tensor of size [batch, height, width, channel]
97 | network_name: a string of the network name
 98 |         content_loss_layers: a list of layer names used for the content loss
 99 |         style_loss_layers: a list of layer names used for the style loss
100 |
101 | Returns:
102 | a dict of the features of the inputs
103 | """
104 | end_points = extract_image_features(inputs, network_name)
105 |
106 | content_features = compute_content_features(end_points, content_loss_layers)
107 | style_features = compute_style_features(end_points, style_loss_layers)
108 |
109 | return content_features, style_features
110 |
111 |
112 | def compute_content_loss(content_features, target_features,
113 | content_loss_layers, weights=1, scope=None):
114 | """compute the content loss
115 |
116 | Args:
117 | content_features: a dict of the features of the input image
118 | target_features: a dict of the features of the output image
 119 |         content_loss_layers: a list of layer names used for the content loss
120 | weights: the weights for this loss
121 | scope: optional scope
122 |
123 | Returns:
124 | the content loss
125 | """
126 | with tf.variable_scope(scope, 'content_loss', [content_features, target_features]):
127 | content_loss = 0
128 | for layer in content_loss_layers:
129 | content_feature = content_features[layer]
130 | target_feature = target_features[layer]
131 | content_loss += tf.losses.mean_squared_error(
132 | target_feature, content_feature, weights=weights, scope=layer)
133 | return content_loss
134 |
135 |
136 | def compute_style_loss(style_features, target_features,
137 | style_loss_layers, weights=1, scope=None):
138 | """compute the style loss
139 |
140 | Args:
141 | style_features: a dict of the Gram matrices of the style image
142 | target_features: a dict of the Gram matrices of the target image
 143 |         style_loss_layers: a list of layer names used for the style loss
144 | weights: the weights for this loss
145 | scope: optional scope
146 |
147 | Returns:
148 | the style loss
149 | """
150 | with tf.variable_scope(scope, 'style_loss', [style_features, target_features]):
151 | style_loss = 0
152 | for layer in style_loss_layers:
153 | style_feature = style_features[layer]
154 | target_feature = target_features[layer]
155 | style_loss += tf.losses.mean_squared_error(
156 | style_feature, target_feature, weights=weights, scope=layer)
157 | return style_loss
158 |
159 |
160 | def compute_approximate_style_loss(style_features, target_features,
161 | style_loss_layers, scope=None):
162 | """compute the approximate style loss
163 |
164 | Args:
165 | style_features: a dict of the sufficient statistics of the
166 | feature maps of the style image
167 | target_features: a dict of the sufficient statistics of the
168 | feature maps of the target image
 169 |         style_loss_layers: a list of layer names used for the style loss
170 | scope: optional scope
171 |
172 | Returns:
173 | the style loss
174 | """
175 | with tf.variable_scope(scope, 'approximated_style_loss', [style_features, target_features]):
176 | style_loss = 0
177 | for layer in style_loss_layers:
178 | style_feature = style_features[layer]
179 | target_feature = target_features[layer]
180 | # we only normalize with respect to the number of channel
181 | style_loss_per_layer = tf.reduce_sum(tf.square(style_feature-target_feature), axis=[1, 2, 3])
182 | style_loss += tf.reduce_mean(style_loss_per_layer)
183 | return style_loss
184 |
185 |
186 | def compute_total_variation_loss_l2(inputs, weights=1, scope=None):
187 | """compute the total variation loss"""
188 | inputs_shape = tf.shape(inputs)
189 | height = inputs_shape[1]
190 | width = inputs_shape[2]
191 |
192 | with tf.variable_scope(scope, 'total_variation_loss', [inputs]):
193 | loss_y = tf.losses.mean_squared_error(
194 | tf.slice(inputs, [0, 0, 0, 0], [-1, height-1, -1, -1]),
195 | tf.slice(inputs, [0, 1, 0, 0], [-1, -1, -1, -1]),
196 | weights=weights,
197 | scope='loss_y')
198 | loss_x = tf.losses.mean_squared_error(
199 | tf.slice(inputs, [0, 0, 0, 0], [-1, -1, width-1, -1]),
200 | tf.slice(inputs, [0, 0, 1, 0], [-1, -1, -1, -1]),
201 | weights=weights,
202 | scope='loss_x')
203 | loss = loss_y + loss_x
204 | return loss
205 |
206 |
207 | def compute_total_variation_loss_l1(inputs, weights=1, scope=None):
208 | """compute the total variation loss L1 norm"""
209 | inputs_shape = tf.shape(inputs)
210 | height = inputs_shape[1]
211 | width = inputs_shape[2]
212 |
213 | with tf.variable_scope(scope, 'total_variation_loss', [inputs]):
214 | loss_y = tf.losses.absolute_difference(
215 | tf.slice(inputs, [0, 0, 0, 0], [-1, height-1, -1, -1]),
216 | tf.slice(inputs, [0, 1, 0, 0], [-1, -1, -1, -1]),
217 | weights=weights,
218 | scope='loss_y')
219 | loss_x = tf.losses.absolute_difference(
220 | tf.slice(inputs, [0, 0, 0, 0], [-1, -1, width-1, -1]),
221 | tf.slice(inputs, [0, 0, 1, 0], [-1, -1, -1, -1]),
222 | weights=weights,
223 | scope='loss_x')
224 | loss = loss_y + loss_x
225 | return loss
226 |
--------------------------------------------------------------------------------
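
A standalone sketch (TF 1.x, toy feature maps) of the normalized Gram-matrix style loss defined by `compute_gram_matrix` and `compute_style_loss` above; the layer shape is an illustrative assumption.

    import tensorflow as tf

    def gram(feature):
        # normalized Gram matrix of size [batch, channel, channel]
        shape = tf.shape(feature)
        v = tf.reshape(feature, [shape[0], -1, shape[3]])
        g = tf.matmul(v, v, transpose_a=True)
        return g / tf.to_float(tf.reduce_prod(shape[1:]))

    stylized_feat = tf.random_normal([1, 32, 32, 64])   # stand-in for a VGG layer
    style_feat = tf.random_normal([1, 32, 32, 64])

    style_loss = tf.losses.mean_squared_error(gram(style_feat), gram(stylized_feat))
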
/models/models_factory.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 | import yaml
7 |
8 | from models import avatar_net
9 |
10 | slim = tf.contrib.slim
11 |
12 | models_map = {
13 | 'AvatarNet': avatar_net.AvatarNet,
14 | }
15 |
16 |
17 | def get_model(filename):
18 | if not tf.gfile.Exists(filename):
19 | raise ValueError('The config file [%s] does not exist.' % filename)
20 |
21 | with open(filename, 'rb') as f:
22 | options = yaml.load(f)
23 | model_name = options.get('model_name')
 24 |     print('Finished loading the model [%s] configuration' % model_name)
25 | if model_name not in models_map:
26 | raise ValueError('Name of model [%s] unknown' % model_name)
27 | model = models_map[model_name](options)
28 | return model, options
29 |
--------------------------------------------------------------------------------
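
A hedged usage sketch of the factory: load the options from a YAML file and instantiate the corresponding model class. It assumes the repository root is on `PYTHONPATH` and that the configuration path below points at the shipped `configs/AvatarNet.yml`.

    from models import models_factory

    # the config path is an assumption based on the repository layout
    model, options = models_factory.get_model('configs/AvatarNet.yml')
    print(options.get('model_name'), options.get('patch_size'), options.get('style_coding'))
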
/models/network_ops.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import tensorflow as tf
6 |
7 | slim = tf.contrib.slim
8 |
9 |
10 | # functions for neural network layers
11 | @slim.add_arg_scope
12 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
13 | """strided 2-D convolution with 'REFLECT' padding.
14 |
15 | Args:
16 | inputs: A 4-D tensor of size [batch, height, width, channel]
17 | num_outputs: An integer, the number of output filters
18 | kernel_size: An int with the kernel_size of the filters
19 | stride: An integer, the output stride
20 | rate: An integer, rate for atrous convolution
21 | scope: Optional scope
22 |
23 | Returns:
24 | output: A 4-D tensor of size [batch, height_out, width_out, channel] with
25 | the convolution output.
26 | """
27 | if kernel_size == 1:
28 | return slim.conv2d(inputs, num_outputs, kernel_size=1, stride=stride,
29 | rate=rate, padding='SAME', scope=scope)
30 | else:
31 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
32 | pad_total = kernel_size_effective - 1
33 | pad_beg = pad_total // 2
34 | pad_end = pad_total - pad_beg
35 | paddings = [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]
36 | inputs = tf.pad(inputs, paddings=paddings, mode="REFLECT")
37 | outputs = slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
38 | rate=rate, padding='VALID', scope=scope)
39 | return outputs
40 |
41 |
42 | @slim.add_arg_scope
43 | def conv2d_resize(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
 44 |     """upsampling as an alternative to conv2d_transpose: first resize the
 45 |     inputs, then convolve the result, see
46 | http://distill.pub/2016/deconv-checkerboard/
47 |
48 | Args:
49 | inputs: A 4-D tensor of size [batch, height, width, channel]
50 | num_outputs: An integer, the number of output filters
51 | kernel_size: An int with the kernel_size of the filters
52 | stride: An integer, the output stride
53 | rate: An integer, rate for atrous convolution
54 | scope: Optional scope
55 |
56 | Returns:
57 | output: A 4-D tensor of size [batch, height_out, width_out, channel] with
58 | the convolution output.
59 | """
60 | if stride == 1:
61 | return conv2d_same(inputs, num_outputs, kernel_size,
62 | stride=1, rate=rate, scope=scope)
63 | else:
64 | stride_larger_than_one = tf.greater(stride, 1)
65 | height = tf.shape(inputs)[1]
66 | width = tf.shape(inputs)[2]
67 | new_height, new_width = tf.cond(
68 | stride_larger_than_one,
69 | lambda: (height*stride, width*stride),
70 | lambda: (height, width))
71 | inputs_resize = tf.image.resize_nearest_neighbor(inputs,
72 | [new_height, new_width])
73 | outputs = conv2d_same(inputs_resize, num_outputs, kernel_size,
74 | stride=1, rate=rate, scope=scope)
75 | return outputs
76 |
77 |
78 | @slim.add_arg_scope
79 | def lrelu(inputs, leak=0.2, scope=None):
80 | """customized leaky ReLU activation function
81 | https://github.com/tensorflow/tensorflow/issues/4079
82 | """
83 | with tf.variable_scope(scope, 'lrelu'):
84 | f1 = 0.5 * (1 + leak)
85 | f2 = 0.5 * (1 - leak)
86 | return f1 * inputs + f2 * tf.abs(inputs)
87 |
88 |
89 | @slim.add_arg_scope
90 | def instance_norm(inputs, epsilon=1e-10):
91 | inst_mean, inst_var = tf.nn.moments(inputs, [1, 2], keep_dims=True)
92 | normalized_inputs = tf.div(
93 | tf.subtract(inputs, inst_mean), tf.sqrt(tf.add(inst_var, epsilon)))
94 | return normalized_inputs
95 |
96 |
97 | @slim.add_arg_scope
98 | def residual_unit_v0(inputs, depth, output_collections=None, scope=None):
 99 |     """Residual block version 0, the input and output have the same depth
100 |
101 | Args:
102 | inputs: a tensor of size [batch, height, width, channel]
103 | depth: the depth of the resnet unit output
104 | output_collections: collection to add the resnet unit output
105 | scope: optional variable_scope
106 |
107 | Returns:
108 | The resnet unit's output
109 | """
110 | with tf.variable_scope(scope, 'res_unit_v0', [inputs]) as sc:
111 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
112 | if depth == depth_in:
113 | shortcut = inputs
114 | else:
115 | shortcut = slim.conv2d(inputs, depth, [1, 1], scope='shortcut')
116 |
117 | residual = conv2d_same(inputs, depth, 3, stride=1, scope='conv1')
118 | with slim.arg_scope([slim.conv2d], activation_fn=None):
119 | residual = conv2d_same(residual, depth, 3, stride=1, scope='conv2')
120 |
121 | output = tf.nn.relu(shortcut + residual)
122 |
123 | return slim.utils.collect_named_outputs(
124 | output_collections, sc.original_name_scope, output)
125 |
126 |
127 | @slim.add_arg_scope
128 | def residual_block_downsample(inputs, depth, stride,
129 | normalizer_fn=slim.layer_norm,
130 | activation_fn=tf.nn.relu,
131 | outputs_collections=None, scope=None):
132 | """Residual block version 2 for downsampling, with preactivation
133 |
134 | Args:
135 | inputs: a tensor of size [batch, height, width, channel]
136 | depth: the depth of the resnet unit output
137 | stride: the stride of the residual block
138 | normalizer_fn: normalizer function for the residual block
139 | activation_fn: activation function for the residual block
140 | outputs_collections: collection to add the resnet unit output
141 | scope: optional variable_scope
142 |
143 | Returns:
144 | The resnet unit's output
145 | """
146 | with tf.variable_scope(scope, 'res_block_downsample', [inputs]) as sc:
147 | with slim.arg_scope([slim.conv2d],
148 | normalizer_fn=normalizer_fn,
149 | activation_fn=activation_fn):
150 | # preactivate the inputs
151 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
152 | preact = normalizer_fn(inputs, activation_fn=activation_fn, scope='preact')
153 | if depth == depth_in:
154 | shortcut = subsample(inputs, stride, scope='shortcut')
155 | else:
156 | with slim.arg_scope([slim.conv2d],
157 | normalizer_fn=None, activation_fn=None):
158 | shortcut = conv2d_same(preact, depth, 1,
159 | stride=stride, scope='shortcut')
160 |
 161 |             depth_bottleneck = int(depth / 4)
 162 |             residual = slim.conv2d(preact, depth_bottleneck, [1, 1],
163 | stride=1, scope='conv1')
 164 |             residual = conv2d_same(residual, depth_bottleneck, 3,
165 | stride=stride, scope='conv2')
166 | residual = slim.conv2d(residual, depth, [1, 1],
167 | stride=1, normalizer_fn=None,
168 | activation_fn=None, scope='conv3')
169 |
170 | output = shortcut + residual
171 |
172 | return slim.utils.collect_named_outputs(
173 | outputs_collections, sc.original_name_scope, output)
174 |
175 |
176 | @slim.add_arg_scope
177 | def residual_block_upsample(inputs, depth, stride,
178 | normalizer_fn=slim.layer_norm,
179 | activation_fn=tf.nn.relu,
180 | outputs_collections=None, scope=None):
181 | """Residual block version 2 for upsampling, with preactivation
182 |
183 | Args:
184 | inputs: a tensor of size [batch, height, width, channel]
185 | depth: the depth of the resnet unit output
186 | stride: the stride of the residual block
187 | normalizer_fn: the normalizer function used in this block
188 | activation_fn: the activation function used in this block
189 | outputs_collections: collection to add the resnet unit output
190 | scope: optional variable_scope
191 |
192 | Returns:
193 | The resnet unit's output
194 | """
195 | with tf.variable_scope(scope, 'res_block_upsample', [inputs]) as sc:
196 | with slim.arg_scope([slim.conv2d],
197 | normalizer_fn=normalizer_fn,
198 | activation_fn=activation_fn):
199 | # preactivate the inputs
200 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
201 | preact = normalizer_fn(inputs, activation_fn=activation_fn, scope='preact')
202 | if depth == depth_in:
203 | shortcut = upsample(inputs, stride, scope='shortcut')
204 | else:
205 | with slim.arg_scope([slim.conv2d],
206 | normalizer_fn=None, activation_fn=None):
207 | shortcut = conv2d_resize(preact, depth, 1, stride=stride, scope='shortcut')
208 |
209 | # calculate the residuals
 210 |             depth_bottleneck = int(depth / 4)
 211 |             residual = slim.conv2d(preact, depth_bottleneck, [1, 1],
212 | stride=1, scope='conv1')
 213 |             residual = conv2d_resize(residual, depth_bottleneck, 3,
214 | stride=stride, scope='conv2')
215 | residual = slim.conv2d(residual, depth, [1, 1],
216 | stride=1, normalizer_fn=None,
217 | activation_fn=None, scope='conv3')
218 |
219 | output = shortcut + residual
220 |
221 | return slim.utils.collect_named_outputs(
222 | outputs_collections, sc.original_name_scope, output)
223 |
224 |
225 | def subsample(inputs, factor, scope=None):
226 | if factor == 1:
227 | return inputs
228 | else:
229 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
230 |
231 |
232 | def upsample(inputs, factor, scope=None):
233 | if factor == 1:
234 | return inputs
235 | else:
236 | factor_larger_than_one = tf.greater(factor, 1)
237 | height = tf.shape(inputs)[1]
238 | width = tf.shape(inputs)[2]
239 | new_height, new_width = tf.cond(
240 | factor_larger_than_one,
241 | lambda: (height*factor, width*factor),
242 | lambda: (height, width))
243 | resized_inputs = tf.image.resize_nearest_neighbor(
244 | inputs, [new_height, new_width], name=scope)
245 | return resized_inputs
246 |
247 |
248 | def adaptive_instance_normalization(content_feature, style_feature):
249 | """adaptively transform the content feature by inverse instance normalization
250 | based on the 2nd order statistics of the style feature
251 | """
252 | normalized_content_feature = instance_norm(content_feature)
253 | inst_mean, inst_var = tf.nn.moments(style_feature, [1, 2], keep_dims=True)
254 | return tf.sqrt(inst_var) * normalized_content_feature + inst_mean
255 |
256 |
257 | def whitening_colorization_transform(content_features, style_features):
258 | """transform the content feature based on the whitening and colorization transform"""
259 | content_shape = tf.shape(content_features)
260 | style_shape = tf.shape(style_features)
261 |
262 | # get the unbiased content and style features
263 | content_features = tf.reshape(
264 | content_features, shape=(content_shape[0], -1, content_shape[3]))
265 | style_features = tf.reshape(
266 | style_features, shape=(style_shape[0], -1, style_shape[3]))
267 |
268 | # get the covariance matrices
269 | content_gram = tf.matmul(content_features, content_features, transpose_a=True)
270 | content_gram /= tf.reduce_prod(tf.cast(content_shape[1:], tf.float32))
271 | style_gram = tf.matmul(style_features, style_features, transpose_a=True)
272 | style_gram /= tf.reduce_prod(tf.cast(style_shape[1:], tf.float32))
273 |
274 | #################################
275 | # converting the feature spaces #
276 | #################################
277 | s_c, u_c, v_c = tf.svd(content_gram, compute_uv=True)
278 | s_c = tf.expand_dims(s_c, axis=1)
279 | s_s, u_s, v_s = tf.svd(style_gram, compute_uv=True)
280 | s_s = tf.expand_dims(s_s, axis=1)
281 |
282 | # normalized features
283 | normalized_features = tf.matmul(content_features, u_c)
284 | normalized_features = tf.multiply(normalized_features, 1.0/(tf.sqrt(s_c+1e-5)))
285 | normalized_features = tf.matmul(normalized_features, v_c, transpose_b=True)
286 |
287 | # colorized features
288 | # broadcasting the tensors for matrix multiplication
289 | content_batch = tf.shape(u_c)[0]
290 | style_batch = tf.shape(u_s)[0]
291 | batch_multiplier = tf.cast(content_batch/style_batch, tf.int32)
292 | u_s = tf.tile(u_s, multiples=tf.stack([batch_multiplier, 1, 1]))
293 | v_s = tf.tile(v_s, multiples=tf.stack([batch_multiplier, 1, 1]))
294 | colorized_features = tf.matmul(normalized_features, u_s)
295 | colorized_features = tf.multiply(colorized_features, tf.sqrt(s_s+1e-5))
296 | colorized_features = tf.matmul(colorized_features, v_s, transpose_b=True)
297 |
298 | # reshape the colorized features
299 | colorized_features = tf.reshape(colorized_features, shape=content_shape)
300 | return colorized_features
301 |
--------------------------------------------------------------------------------
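
A standalone sketch (TF 1.x, toy tensors) of the `adaptive_instance_normalization` operation defined above: instance-normalize the content feature, then rescale and shift it with the per-channel standard deviation and mean of the style feature. The shapes and epsilon are illustrative assumptions.

    import tensorflow as tf

    content = tf.random_normal([1, 32, 32, 128])
    style = tf.random_normal([1, 32, 32, 128]) * 2.0 + 3.0

    c_mean, c_var = tf.nn.moments(content, [1, 2], keep_dims=True)
    s_mean, s_var = tf.nn.moments(style, [1, 2], keep_dims=True)

    normalized = (content - c_mean) / tf.sqrt(c_var + 1e-10)
    adain_output = tf.sqrt(s_var) * normalized + s_mean
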
/models/preprocessing.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 | from tensorflow.python.ops import control_flow_ops
8 |
9 | slim = tf.contrib.slim
10 |
11 | _R_MEAN = 123.68
12 | _G_MEAN = 116.78
13 | _B_MEAN = 103.94
14 |
15 | _RESIZE_SIDE_MIN = 256
16 | _RESIZE_SIDE_MAX = 512
17 |
18 |
19 | def _crop(image, offset_height, offset_width, crop_height, crop_width):
20 | original_shape = tf.shape(image)
21 |
22 | rank_assertion = tf.Assert(
23 | tf.equal(tf.rank(image), 3),
24 | ['Rank of image must be equal to 3.'])
25 | cropped_shape = control_flow_ops.with_dependencies(
26 | [rank_assertion],
27 | tf.stack([crop_height, crop_width, original_shape[2]]))
28 |
29 | size_assertion = tf.Assert(
30 | tf.logical_and(
31 | tf.greater_equal(original_shape[0], crop_height),
32 | tf.greater_equal(original_shape[1], crop_width)),
33 | ['Crop size greater than the image size.'])
34 |
35 | offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0]))
36 |
37 | # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
38 | # define the crop size.
39 | image = control_flow_ops.with_dependencies(
40 | [size_assertion],
41 | tf.slice(image, offsets, cropped_shape))
42 | return tf.reshape(image, cropped_shape)
43 |
44 |
45 | def _random_crop(image_list, crop_height, crop_width):
46 | if not image_list:
47 | raise ValueError('Empty image_list.')
48 |
49 | # Compute the rank assertions.
50 | rank_assertions = []
51 | for i in range(len(image_list)):
52 | image_rank = tf.rank(image_list[i])
53 | rank_assert = tf.Assert(
54 | tf.equal(image_rank, 3),
55 | ['Wrong rank for tensor %s [expected] [actual]',
56 | image_list[i].name, 3, image_rank])
57 | rank_assertions.append(rank_assert)
58 |
59 | image_shape = control_flow_ops.with_dependencies(
60 | [rank_assertions[0]],
61 | tf.shape(image_list[0]))
62 | image_height = image_shape[0]
63 | image_width = image_shape[1]
64 | crop_size_assert = tf.Assert(
65 | tf.logical_and(
66 | tf.greater_equal(image_height, crop_height),
67 | tf.greater_equal(image_width, crop_width)),
68 | ['Crop size greater than the image size.'])
69 |
70 | asserts = [rank_assertions[0], crop_size_assert]
71 |
72 | for i in range(1, len(image_list)):
73 | image = image_list[i]
74 | asserts.append(rank_assertions[i])
75 | shape = control_flow_ops.with_dependencies([rank_assertions[i]],
76 | tf.shape(image))
77 | height = shape[0]
78 | width = shape[1]
79 |
80 | height_assert = tf.Assert(
81 | tf.equal(height, image_height),
82 | ['Wrong height for tensor %s [expected][actual]',
83 | image.name, height, image_height])
84 | width_assert = tf.Assert(
85 | tf.equal(width, image_width),
86 | ['Wrong width for tensor %s [expected][actual]',
87 | image.name, width, image_width])
88 | asserts.extend([height_assert, width_assert])
89 |
90 | # Create a random bounding box.
91 | #
92 | # Use tf.random_uniform and not numpy.random.rand as doing the former would
93 | # generate random numbers at graph eval time, unlike the latter which
94 | # generates random numbers at graph definition time.
95 | max_offset_height = control_flow_ops.with_dependencies(
96 | asserts, tf.reshape(image_height - crop_height + 1, []))
97 | max_offset_width = control_flow_ops.with_dependencies(
98 | asserts, tf.reshape(image_width - crop_width + 1, []))
99 | offset_height = tf.random_uniform(
100 | [], maxval=max_offset_height, dtype=tf.int32)
101 | offset_width = tf.random_uniform(
102 | [], maxval=max_offset_width, dtype=tf.int32)
103 |
104 | return [_crop(image, offset_height, offset_width,
105 | crop_height, crop_width) for image in image_list]
106 |
107 |
108 | def _central_crop(image_list, crop_height, crop_width):
109 | outputs = []
110 | for image in image_list:
111 | image_height = tf.shape(image)[0]
112 | image_width = tf.shape(image)[1]
113 |
114 | offset_height = (image_height - crop_height) / 2
115 | offset_width = (image_width - crop_width) / 2
116 |
117 | outputs.append(_crop(image, offset_height, offset_width,
118 | crop_height, crop_width))
119 | return outputs
120 |
121 |
122 | def _mean_image_subtraction(image, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
123 | if image.get_shape().ndims != 3:
124 | raise ValueError('Input must be of size [height, width, C>0]')
125 | num_channels = image.get_shape().as_list()[-1]
126 | if len(means) != num_channels:
127 | raise ValueError('len(means) must match the number of channels')
128 |
129 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
130 | for i in range(num_channels):
131 | channels[i] -= means[i]
132 | return tf.concat(axis=2, values=channels)
133 |
134 |
135 | def _smallest_size_at_least(height, width, smallest_side):
136 | smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
137 |
138 | height = tf.to_float(height)
139 | width = tf.to_float(width)
140 | smallest_side = tf.to_float(smallest_side)
141 |
142 | scale = tf.cond(tf.greater(height, width),
143 | lambda: smallest_side / width,
144 | lambda: smallest_side / height)
145 | new_height = tf.to_int32(height * scale)
146 | new_width = tf.to_int32(width * scale)
147 | return new_height, new_width
148 |
149 |
150 | def _aspect_preserving_resize(image, smallest_side):
151 | smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
152 |
153 | shape = tf.shape(image)
154 | height = shape[0]
155 | width = shape[1]
156 | new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
157 | image = tf.expand_dims(image, 0)
158 | resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
159 | align_corners=False)
160 | resized_image = tf.squeeze(resized_image)
161 | resized_image.set_shape([None, None, 3])
162 | return resized_image
163 |
164 |
165 | def preprocessing_for_train(image, output_height, output_width, resize_side):
166 | image = _aspect_preserving_resize(image, resize_side)
167 | image = _random_crop([image], output_height, output_width)[0]
168 | image.set_shape([output_height, output_width, 3])
169 | image = tf.to_float(image)
170 | return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
171 |
172 |
173 | def preprocessing_for_eval(image, output_height, output_width, resize_side):
174 | image = _aspect_preserving_resize(image, resize_side)
175 | image = _central_crop([image], output_height, output_width)[0]
176 | image.set_shape([output_height, output_width, 3])
177 | image = tf.to_float(image)
178 | return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
179 |
180 |
181 | def preprocessing_image(image, output_height, output_width,
182 | resize_side=_RESIZE_SIDE_MIN, is_training=False):
183 | if is_training:
184 | return preprocessing_for_train(image, output_height, output_width, resize_side)
185 | else:
186 | return preprocessing_for_eval(image, output_height, output_width, resize_side)
187 |
188 |
189 | #########################
190 | # personal modification #
191 | #########################
192 | def mean_image_subtraction(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
193 | """works for one single image with dynamic shapes"""
194 | num_channels = 3
195 | channels = tf.split(images, num_channels, axis=2)
196 | for i in range(num_channels):
197 | channels[i] -= means[i]
198 | return tf.concat(channels, axis=2)
199 |
200 |
201 | def mean_image_summation(image, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
202 | """works for one single image with dynamic shapes"""
203 | num_channels = 3
204 | channels = tf.split(image, num_channels, axis=2)
205 | for i in range(num_channels):
206 | channels[i] += means[i]
207 | return tf.concat(channels, axis=2)
208 |
209 |
210 | def batch_mean_image_subtraction(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
211 | if images.get_shape().ndims != 4:
 212 |         raise ValueError('Input must be of size [batch, height, width, C>0]')
213 | num_channels = images.get_shape().as_list()[-1]
214 | if len(means) != num_channels:
215 | raise ValueError('len(means) must match the number of channels')
216 | channels = tf.split(images, num_channels, axis=3)
217 | for i in range(num_channels):
218 | channels[i] -= means[i]
219 | return tf.concat(channels, axis=3)
220 |
221 |
222 | def batch_mean_image_summation(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)):
223 | if images.get_shape().ndims != 4:
 224 |         raise ValueError('Input must be of size [batch, height, width, C>0]')
225 | num_channels = images.get_shape().as_list()[-1]
226 | if len(means) != num_channels:
227 | raise ValueError('len(means) must match the number of channels')
228 | channels = tf.split(images, num_channels, axis=3)
229 | for i in range(num_channels):
230 | channels[i] += means[i]
231 | return tf.concat(channels, axis=3)
232 |
233 |
234 | def image_normalization(images, means=(_R_MEAN, _G_MEAN, _B_MEAN), scale=127.5):
 235 |     """rescale the images so that their values lie roughly in [-1, 1]"""
236 | if images.get_shape().ndims == 4:
237 | return tf.div(batch_mean_image_subtraction(images, means), scale)
238 | elif images.get_shape().ndims == 3:
239 | return tf.div(mean_image_subtraction(images, means), scale)
240 | else:
241 | raise ValueError('Input must be of dimensions 3 or 4')
242 |
243 |
244 | def aspect_preserving_resize(image, smallest_side):
245 | return _aspect_preserving_resize(image, smallest_side)
246 |
--------------------------------------------------------------------------------
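
A standalone sketch (TF 1.x) of the mean-subtraction convention used above: images stay in the [0, 255] range and are shifted by the per-channel ImageNet means before entering the VGG encoder, then shifted back before visualization. The batch shape is an illustrative assumption.

    import tensorflow as tf

    _R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94
    means = [_R_MEAN, _G_MEAN, _B_MEAN]

    images = tf.random_uniform([4, 256, 256, 3], 0.0, 255.0)

    # subtract the per-channel means before feeding the encoder
    channels = tf.split(images, 3, axis=3)
    centered = tf.concat([c - m for c, m in zip(channels, means)], axis=3)

    # add the means back before writing image summaries or PNG files
    restored = tf.concat(
        [c + m for c, m in zip(tf.split(centered, 3, axis=3), means)], axis=3)
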
/models/vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains models definitions for versions of the Oxford VGG network.
16 |
17 | These models definitions were introduced in the following technical report:
18 |
19 | Very Deep Convolutional Networks For Large-Scale Image Recognition
20 | Karen Simonyan and Andrew Zisserman
21 | arXiv technical report, 2015
22 | PDF: http://arxiv.org/pdf/1409.1556.pdf
23 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
24 | CC-BY-4.0
25 |
26 | More information can be obtained from the VGG website:
27 | www.robots.ox.ac.uk/~vgg/research/very_deep/
28 |
29 | Usage:
30 | with slim.arg_scope(vgg.vgg_arg_scope()):
31 | outputs, end_points = vgg.vgg_a(inputs)
32 |
33 | with slim.arg_scope(vgg.vgg_arg_scope()):
34 | outputs, end_points = vgg.vgg_16(inputs)
35 |
36 | @@vgg_a
37 | @@vgg_16
38 | @@vgg_19
39 | """
40 | from __future__ import absolute_import
41 | from __future__ import division
42 | from __future__ import print_function
43 |
44 | import tensorflow as tf
45 |
46 | slim = tf.contrib.slim
47 |
48 |
49 | def vgg_arg_scope(weight_decay=0.0005):
50 | """Defines the VGG arg scope.
51 |
52 | Args:
53 | weight_decay: The l2 regularization coefficient.
54 |
55 | Returns:
56 | An arg_scope.
57 | """
58 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
59 | activation_fn=tf.nn.relu,
60 | weights_regularizer=slim.l2_regularizer(weight_decay),
61 | biases_initializer=tf.zeros_initializer()):
62 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
63 | return arg_sc
64 |
65 |
66 | def vgg_a(inputs,
67 | num_classes=1000,
68 | is_training=True,
69 | dropout_keep_prob=0.5,
70 | spatial_squeeze=True,
71 | scope='vgg_a'):
72 | """Oxford Net VGG 11-Layers version A Example.
73 |
74 | Note: All the fully_connected layers have been transformed to conv2d layers.
75 | To use in classification mode, resize input to 224x224.
76 |
77 | Args:
78 | inputs: a tensor of size [batch_size, height, width, channels].
79 | num_classes: number of predicted classes.
 80 |     is_training: whether or not the model is being trained.
81 | dropout_keep_prob: the probability that activations are kept in the dropout
82 | layers during training.
83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the
84 | outputs. Useful to remove unnecessary dimensions for classification.
85 | scope: Optional scope for the variables.
86 |
87 | Returns:
88 | the last op containing the log predictions and end_points dict.
89 | """
90 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
91 | end_points_collection = sc.name + '_end_points'
92 | # Collect outputs for conv2d, fully_connected and max_pool2d.
93 | with slim.arg_scope([slim.conv2d, slim.max_pool2d],
94 | outputs_collections=end_points_collection):
95 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
96 | net = slim.max_pool2d(net, [2, 2], scope='pool1')
97 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
98 | net = slim.max_pool2d(net, [2, 2], scope='pool2')
99 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
100 | net = slim.max_pool2d(net, [2, 2], scope='pool3')
101 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
102 | net = slim.max_pool2d(net, [2, 2], scope='pool4')
103 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
104 | net = slim.max_pool2d(net, [2, 2], scope='pool5')
105 | # Use conv2d instead of fully_connected layers.
106 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
107 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
108 | scope='dropout6')
109 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
111 | scope='dropout7')
112 | net = slim.conv2d(net, num_classes, [1, 1],
113 | activation_fn=None,
114 | normalizer_fn=None,
115 | scope='fc8')
 116 |       # Convert end_points_collection into an end_point dict.
117 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
118 | if spatial_squeeze:
119 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
120 | end_points[sc.name + '/fc8'] = net
121 | return net, end_points
122 |
123 |
124 | vgg_a.default_image_size = 224
125 |
126 |
127 | def vgg_16(inputs,
128 | num_classes=1000,
129 | is_training=True,
130 | dropout_keep_prob=0.5,
131 | spatial_squeeze=True,
132 | reuse=True,
133 | scope='vgg_16'):
134 | """Oxford Net VGG 16-Layers version D Example.
135 |
136 | Note: All the fully_connected layers have been transformed to conv2d layers.
137 | To use in classification mode, resize input to 224x224.
138 |
139 | Args:
140 | inputs: a tensor of size [batch_size, height, width, channels].
141 | num_classes: number of predicted classes.
142 |     is_training: whether or not the model is being trained.
143 | dropout_keep_prob: the probability that activations are kept in the dropout
144 | layers during training.
145 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
146 |       outputs. Useful to remove unnecessary dimensions for classification.
147 |     reuse: whether to reuse the network parameters.
148 | scope: Optional scope for the variables.
149 |
150 | Returns:
151 | the last op containing the log predictions and end_points dict.
152 | """
153 | with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc:
154 | end_points_collection = sc.original_name_scope + '_end_points'
155 | # Collect outputs for conv2d, fully_connected and max_pool2d.
156 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
157 | outputs_collections=end_points_collection):
158 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
159 | # net = slim.max_pool2d(net, [2, 2], scope='pool1')
160 | net = slim.avg_pool2d(net, [2, 2], scope='pool1')
161 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
162 | # net = slim.max_pool2d(net, [2, 2], scope='pool2')
163 | net = slim.avg_pool2d(net, [2, 2], scope='pool2')
164 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
165 | # net = slim.max_pool2d(net, [2, 2], scope='pool3')
166 | net = slim.avg_pool2d(net, [2, 2], scope='pool3')
167 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
168 | # net = slim.max_pool2d(net, [2, 2], scope='pool4')
169 | net = slim.avg_pool2d(net, [2, 2], scope='pool4')
170 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
171 | # net = slim.max_pool2d(net, [2, 2], scope='pool5')
172 | net = slim.avg_pool2d(net, [2, 2], scope='pool5')
173 | # Use conv2d instead of fully_connected layers.
174 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
175 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
176 | scope='dropout6')
177 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
178 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
179 | scope='dropout7')
180 | net = slim.conv2d(net, num_classes, [1, 1],
181 | activation_fn=None,
182 | normalizer_fn=None,
183 | scope='fc8')
184 |         # Convert end_points_collection into an end_points dict.
185 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
186 | if spatial_squeeze:
187 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
188 | end_points[sc.name + '/fc8'] = net
189 | return net, end_points
190 |
191 |
192 | vgg_16.default_image_size = 224
193 |
194 |
195 | def vgg_19(inputs,
196 | num_classes=1000,
197 | is_training=True,
198 | dropout_keep_prob=0.5,
199 | spatial_squeeze=True,
200 | reuse=True,
201 | scope='vgg_19'):
202 | """Oxford Net VGG 19-Layers version E Example.
203 |
204 | Note: All the fully_connected layers have been transformed to conv2d layers.
205 | To use in classification mode, resize input to 224x224.
206 |
207 | Args:
208 | inputs: a tensor of size [batch_size, height, width, channels].
209 | num_classes: number of predicted classes.
210 |     is_training: whether or not the model is being trained.
211 | dropout_keep_prob: the probability that activations are kept in the dropout
212 | layers during training.
213 |     spatial_squeeze: whether or not to squeeze the spatial dimensions of the
214 |       outputs. Useful to remove unnecessary dimensions for classification.
215 |     reuse: whether to reuse the network parameters.
216 | scope: Optional scope for the variables.
217 |
218 | Returns:
219 | the last op containing the log predictions and end_points dict.
220 | """
221 | with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc:
222 | end_points_collection = sc.original_name_scope + '_end_points'
223 | # Collect outputs for conv2d, fully_connected and max_pool2d.
224 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
225 | outputs_collections=end_points_collection):
226 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
227 | # net = slim.max_pool2d(net, [2, 2], scope='pool1')
228 | net = slim.avg_pool2d(net, [2, 2], scope='pool1')
229 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
230 | # net = slim.max_pool2d(net, [2, 2], scope='pool2')
231 | net = slim.avg_pool2d(net, [2, 2], scope='pool2')
232 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3')
233 | # net = slim.max_pool2d(net, [2, 2], scope='pool3')
234 | net = slim.avg_pool2d(net, [2, 2], scope='pool3')
235 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4')
236 | # net = slim.max_pool2d(net, [2, 2], scope='pool4')
237 | net = slim.avg_pool2d(net, [2, 2], scope='pool4')
238 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5')
239 | # net = slim.max_pool2d(net, [2, 2], scope='pool5')
240 | net = slim.avg_pool2d(net, [2, 2], scope='pool5')
241 | # Use conv2d instead of fully_connected layers.
242 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
243 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
244 | scope='dropout6')
245 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
246 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
247 | scope='dropout7')
248 | net = slim.conv2d(net, num_classes, [1, 1],
249 | activation_fn=None,
250 | normalizer_fn=None,
251 | scope='fc8')
252 | # Convert end_points_collection into a end_point dict.
253 | end_points = slim.utils.convert_collection_to_dict(end_points_collection)
254 | if spatial_squeeze:
255 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
256 | end_points[sc.name + '/fc8'] = net
257 | return net, end_points
258 |
259 |
260 | vgg_19.default_image_size = 224
261 |
262 |
263 | # Alias
264 | vgg_d = vgg_16
265 | vgg_e = vgg_19
266 |
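
Note that, compared with the stock TF-Slim definitions, the `vgg_16` and `vgg_19` encoders above replace max pooling with average pooling (the original `max_pool2d` calls are kept as comments) and expose every convolutional activation through the returned `end_points` dictionary. As a minimal sketch (not part of the repository, and assuming the arg-scope helper at the top of this file is named `vgg_arg_scope`, as in stock TF-Slim), the available encoding layers can be inspected like this:

    import tensorflow as tf
    from models import vgg

    slim = tf.contrib.slim

    # Build the modified VGG-19 once; reuse=False so the variables are created.
    images = tf.placeholder(tf.float32, [1, 224, 224, 3], name='images')
    with slim.arg_scope(vgg.vgg_arg_scope()):
        _, end_points = vgg.vgg_19(images, is_training=False, reuse=False)

    # The keys follow the variable scopes created by slim.repeat, e.g. something
    # like 'vgg_19/conv4/conv4_1'; list them to locate the encoding layers.
    for name in sorted(end_points):
        print('%s %s' % (name, end_points[name].get_shape().as_list()))
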
--------------------------------------------------------------------------------
/models/vgg.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/vgg.pyc
--------------------------------------------------------------------------------
/models/vgg_decoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import tensorflow as tf
6 |
7 | from models import network_ops
8 |
9 | slim = tf.contrib.slim
10 |
11 | vgg_19_decoder_architecture = [
12 | ('conv5/conv5_4', ('c', 512, 3)),
13 | ('conv5/conv5_3', ('c', 512, 3)),
14 | ('conv5/conv5_2', ('c', 512, 3)),
15 | ('conv5/conv5_1', ('c', 512, 3)),
16 | ('conv4/conv4_4', ('uc', 512, 3)),
17 | ('conv4/conv4_3', ('c', 512, 3)),
18 | ('conv4/conv4_2', ('c', 512, 3)),
19 | ('conv4/conv4_1', ('c', 256, 3)),
20 | ('conv3/conv3_4', ('uc', 256, 3)),
21 | ('conv3/conv3_3', ('c', 256, 3)),
22 | ('conv3/conv3_2', ('c', 256, 3)),
23 | ('conv3/conv3_1', ('c', 128, 3)),
24 | ('conv2/conv2_2', ('uc', 128, 3)),
25 | ('conv2/conv2_1', ('c', 64, 3)),
26 | ('conv1/conv1_2', ('uc', 64, 3)),
27 | ('conv1/conv1_1', ('c', 64, 3)),
28 | ]
29 |
30 | vgg_16_decoder_architecture = [
31 | ('conv5/conv5_3', ('c', 512, 3)),
32 | ('conv5/conv5_2', ('c', 512, 3)),
33 | ('conv5/conv5_1', ('c', 512, 3)),
34 | ('conv4/conv4_3', ('uc', 512, 3)),
35 | ('conv4/conv4_2', ('c', 512, 3)),
36 | ('conv4/conv4_1', ('c', 256, 3)),
37 | ('conv3/conv3_3', ('uc', 256, 3)),
38 | ('conv3/conv3_2', ('c', 256, 3)),
39 | ('conv3/conv3_1', ('c', 128, 3)),
40 | ('conv2/conv2_2', ('uc', 128, 3)),
41 | ('conv2/conv2_1', ('c', 64, 3)),
42 | ('conv1/conv1_2', ('uc', 64, 3)),
43 | ('conv1/conv1_1', ('c', 64, 3)),
44 | ]
45 |
46 | network_map = {
47 | 'vgg_19': vgg_19_decoder_architecture,
48 | 'vgg_16': vgg_16_decoder_architecture,
49 | }
50 |
51 |
52 | def vgg_decoder_arg_scope(weight_decay=0.0005):
53 | with slim.arg_scope(
54 | [slim.conv2d],
55 | padding='SAME',
56 | activation_fn=tf.nn.relu,
57 | normalizer_fn=None,
58 | weights_initializer=slim.xavier_initializer(uniform=False),
59 | weights_regularizer=slim.l2_regularizer(weight_decay)) as arg_sc:
60 | return arg_sc
61 |
62 |
63 | def vgg_decoder(inputs,
64 | network_name='vgg_16',
65 | starting_layer='conv1/conv1_1',
66 | reuse=False,
67 | scope=None):
68 | """construct the decoder network for the vgg models
69 |
70 | Args:
71 | inputs: input features [batch_size, height, width, channel]
72 | network_name: the type of the network, default is vgg_16
73 | starting_layer: the starting reflectance layer, default is 'conv1/conv1_1'
74 | reuse: (optional) whether to reuse the network
75 | scope: (optional) the scope of the network
76 |
77 | Returns:
78 | outputs: the decoded feature maps
79 | """
80 | with tf.variable_scope(scope, 'image_decoder', reuse=reuse):
81 | # gather the output with identity mapping
82 | net = tf.identity(inputs)
83 |
84 | # starting inferring the network
85 | is_active = False
86 | for layer, layer_struct in network_map[network_name]:
87 | if layer == starting_layer:
88 | is_active = True
89 | if is_active:
90 | conv_type, num_outputs, kernel_size = layer_struct
91 | if conv_type == 'c':
92 | net = network_ops.conv2d_same(net, num_outputs, kernel_size, 1, scope=layer)
93 | elif conv_type == 'uc':
94 | net = network_ops.conv2d_resize(net, num_outputs, kernel_size, 2, scope=layer)
95 | with slim.arg_scope([slim.conv2d], normalizer_fn=None, activation_fn=tf.tanh):
96 | outputs = network_ops.conv2d_same(net, 3, 7, 1, scope='output')
97 | return outputs * 150.0 + 127.5
98 |
99 |
100 | def vgg_combined_decoder(inputs,
101 | additional_features,
102 | fusion_fn=None,
103 | network_name='vgg_16',
104 | starting_layer='conv1/conv1_1',
105 | reuse=False,
106 | scope=None):
107 | """construct the decoder network with additional feature combination
108 |
109 | Args:
110 | inputs: input features [batch_size, height, width, channel]
111 | additional_features: a dict contains the additional features
112 | fusion_fn: the fusion function to combine features
113 | network_name: the type of the network, default is vgg_16
114 | starting_layer: the starting reflectance layer, default is 'conv1/conv1_1'
115 | reuse: (optional) whether to reuse the network
116 | scope: (optional) the scope of the network
117 |
118 | Returns:
119 | outputs: the decoded feature maps
120 | """
121 | with tf.variable_scope(scope, 'combined_decoder', reuse=reuse):
122 | # gather the output with identity mapping
123 | net = tf.identity(inputs)
124 |
125 | # starting inferring the network
126 | is_active = False
127 | for layer, layer_struct in network_map[network_name]:
128 | if layer == starting_layer:
129 | is_active = True
130 | if is_active:
131 | conv_type, num_outputs, kernel_size = layer_struct
132 |
133 | # combine the feature
134 | add_feature = additional_features.get(layer)
135 | if add_feature is not None and layer != starting_layer:
136 | net = fusion_fn(net, add_feature)
137 |
138 | if conv_type == 'c':
139 | net = network_ops.conv2d_same(net, num_outputs, kernel_size, 1, scope=layer)
140 | elif conv_type == 'uc':
141 | net = network_ops.conv2d_resize(net, num_outputs, kernel_size, 2, scope=layer)
142 | with slim.arg_scope([slim.conv2d], normalizer_fn=None, activation_fn=None):
143 | outputs = network_ops.conv2d_same(net, 3, 7, 1, scope='output')
144 | return outputs + 127.5
145 |
146 |
147 | def vgg_multiple_combined_decoder(inputs,
148 | additional_features,
149 | blending_weights,
150 | fusion_fn=None,
151 | network_name='vgg_16',
152 | starting_layer='conv1/conv1_1',
153 | reuse=False,
154 | scope=None):
155 | """construct the decoder network with additional feature combination
156 |
157 | Args:
158 | inputs: input features [batch_size, height, width, channel]
159 | additional_features: a dict contains the additional features
160 | blending_weights: the list of weights used for feature blending
161 | fusion_fn: the fusion function to combine features
162 | network_name: the type of the network, default is vgg_16
163 | starting_layer: the starting reflectance layer, default is 'conv1/conv1_1'
164 | reuse: (optional) whether to reuse the network
165 | scope: (optional) the scope of the network
166 |
167 | Returns:
168 | outputs: the decoded feature maps
169 | """
170 | with tf.variable_scope(scope, 'combined_decoder', reuse=reuse):
171 | # gather the output with identity mapping
172 | net = tf.identity(inputs)
173 |
174 | # starting inferring the network
175 | is_active = False
176 | for layer, layer_struct in network_map[network_name]:
177 | if layer == starting_layer:
178 | is_active = True
179 | if is_active:
180 | conv_type, num_outputs, kernel_size = layer_struct
181 |
182 | # combine the feature
183 | add_feature = additional_features[0].get(layer)
184 | if add_feature is not None and layer != starting_layer:
185 | # fuse multiple styles
186 | n = 0
187 | layer_output = 0.0
188 | for additional_feature in additional_features:
189 | additional_layer_feature = additional_feature.get(layer)
190 | fused_layer_feature = fusion_fn(net, additional_layer_feature)
191 | layer_output += blending_weights[n] * fused_layer_feature
192 | n += 1
193 | net = layer_output
194 |
195 | if conv_type == 'c':
196 | net = network_ops.conv2d_same(net, num_outputs, kernel_size, 1, scope=layer)
197 | elif conv_type == 'uc':
198 | net = network_ops.conv2d_resize(net, num_outputs, kernel_size, 2, scope=layer)
199 | with slim.arg_scope([slim.conv2d], normalizer_fn=None, activation_fn=None):
200 | outputs = network_ops.conv2d_same(net, 3, 7, 1, scope='output')
201 | return outputs + 127.5
202 |
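
The decoder architectures above mirror the VGG encoders layer by layer: a `'c'` entry is a stride-1 convolution (`network_ops.conv2d_same`), while a `'uc'` entry calls `network_ops.conv2d_resize` with a factor of 2, one upsampling step for each pooling step in the encoder. `vgg_decoder` ends with a 7x7 convolution whose tanh output is scaled by 150 and shifted by 127.5 back into image range, whereas the combined decoders use a linear output shifted by 127.5. A minimal sketch of decoding a single feature map (not part of the repository; the shape is an assumption for a 256x256 input):

    import tensorflow as tf
    from models import vgg_decoder

    slim = tf.contrib.slim

    # A feature map shaped like the encoder's conv3/conv3_1 output for a
    # 256x256 image: downsampled twice spatially (256 / 4 = 64), 256 channels.
    features = tf.placeholder(tf.float32, [1, 64, 64, 256])

    with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()):
        image = vgg_decoder.vgg_decoder(features,
                                        network_name='vgg_16',
                                        starting_layer='conv3/conv3_1')
    # With each 'uc' step upsampling by 2, `image` comes out at [1, 256, 256, 3].
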
--------------------------------------------------------------------------------
/models/vgg_decoder.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/vgg_decoder.pyc
--------------------------------------------------------------------------------
/scripts/evaluate_style_transfer.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_ID=$1
4 | # content image folders:
5 | # exemplar content images: ./data/contents/images/
6 | # exemplar content videos: ./data/contents/sequences/
7 | CONTENT_DATASET_DIR=$2
8 | # style image folders: ./data/styles/
9 | STYLE_DATASET_DIR=$3
10 | # output image folders: ./results/sequences/
11 | EVAL_DATASET_DIR=$4
12 |
13 | # network configuration
14 | CONFIG_DIR=./configs/AvatarNet_config.yml
15 |
16 | # the network path for the trained auto-encoding network (need to change accordingly)
17 | MODEL_DIR=/DATA/AvatarNet
18 |
19 | CUDA_VISIBLE_DEVICES=${CUDA_ID} \
20 | python evaluate_style_transfer.py \
21 | --checkpoint_dir=${MODEL_DIR} \
22 | --model_config_path=${CONFIG_DIR} \
23 | --content_dataset_dir=${CONTENT_DATASET_DIR} \
24 | --style_dataset_dir=${STYLE_DATASET_DIR} \
25 | --eval_dir=${EVAL_DATASET_DIR} \
26 | --inter_weight=0.8
--------------------------------------------------------------------------------
/scripts/train_image_reconstruction.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_ID=$1
4 | # MSCOCO tfexample dataset path
5 | DATASET_DIR=$2
6 | # model path
7 | MODEL_DIR=$3
8 |
9 | # network configuration
10 | CONFIG_DIR=./configs/AvatarNet_config.yml
11 |
12 | CUDA_VISIBLE_DEVICES=${CUDA_ID} \
13 | python train_image_reconstruction.py \
14 | --train_dir=${MODEL_DIR} \
15 | --model_config=${CONFIG_DIR} \
16 | --dataset_dir=${DATASET_DIR} \
17 | --dataset_name=MSCOCO \
18 | --dataset_split_name=train \
19 | --batch_size=8 \
20 |     --max_number_of_steps=120000 \
21 | --optimizer=adam \
22 | --learning_rate_decay_type=fixed \
23 | --learning_rate=0.0001
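
As an example invocation (not from the repository documentation; the paths below are placeholders), training the reconstruction network on GPU 0 could look like:

    bash ./scripts/train_image_reconstruction.sh 0 \
        /path/to/MSCOCO_tfrecords /path/to/AvatarNet_model

where the second argument is the MSCOCO dataset converted into the tfexample format consumed by `dataset_utils.get_split`, and the third is the directory where checkpoints and event logs are written.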
--------------------------------------------------------------------------------
/train_image_reconstruction.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import tensorflow as tf
6 |
7 | from tensorflow.python.ops import control_flow_ops
8 | from models.preprocessing import preprocessing_image
9 | from models import models_factory
10 | from datasets import dataset_utils
11 |
12 | slim = tf.contrib.slim
13 |
14 |
15 | tf.app.flags.DEFINE_integer(
16 | 'num_readers', 4,
17 | 'The number of parallel readers that read data from the dataset.')
18 | tf.app.flags.DEFINE_integer(
19 | 'num_preprocessing_threads', 1,
20 | 'The number of threads used to create the batches.')
21 |
22 | # ====================== #
23 | # Training specification #
24 | # ====================== #
25 | tf.app.flags.DEFINE_string(
26 | 'train_dir', '/tmp/tfmodel',
27 | 'Directory where checkpoints and event logs are written to.')
28 | tf.app.flags.DEFINE_integer(
29 | 'log_every_n_steps', 100,
30 |     'The frequency with which logs are printed, in steps.')
31 | tf.app.flags.DEFINE_integer(
32 | 'save_interval_secs', 600,
33 |     'The frequency with which the model is saved, in seconds.')
34 | tf.app.flags.DEFINE_integer(
35 | 'save_summaries_secs', 120,
36 | 'The frequency with which summaries are saved, in seconds.')
37 | tf.app.flags.DEFINE_integer(
38 | 'batch_size', 32, 'The number of samples in each batch.')
39 | tf.app.flags.DEFINE_integer(
40 | 'max_number_of_steps', None, 'The maximum number of training steps.')
41 |
42 | # ============= #
43 | # Dataset Flags #
44 | # ============= #
45 | tf.app.flags.DEFINE_string(
46 | 'dataset_dir', None,
47 | 'The directory where the dataset files are stored.')
48 | tf.app.flags.DEFINE_string(
49 | 'dataset_name', None,
50 | 'The name of the dataset to load.')
51 | tf.app.flags.DEFINE_string(
52 | 'dataset_split_name', 'train',
53 | 'The name of the train/test split.')
54 |
55 | #######################
56 | # Model specification #
57 | #######################
58 | tf.app.flags.DEFINE_string(
59 | 'model_config', None,
60 |     'Directory where the configuration of the model is stored.')
61 |
62 | ######################
63 | # Optimization Flags #
64 | ######################
65 | tf.app.flags.DEFINE_string(
66 | 'optimizer', 'rmsprop',
67 |     'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
68 | '"ftrl", "momentum", "sgd" or "rmsprop".')
69 | tf.app.flags.DEFINE_float(
70 | 'adadelta_rho', 0.95, 'The decay rate for adadelta.')
71 | tf.app.flags.DEFINE_float(
72 | 'adagrad_initial_accumulator_value', 0.1,
73 | 'Starting value for the AdaGrad accumulators.')
74 | tf.app.flags.DEFINE_float(
75 | 'adam_beta1', 0.9,
76 | 'The exponential decay rate for the 1st moment estimates.')
77 | tf.app.flags.DEFINE_float(
78 | 'adam_beta2', 0.999,
79 | 'The exponential decay rate for the 2nd moment estimates.')
80 | tf.app.flags.DEFINE_float(
81 | 'opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
82 | tf.app.flags.DEFINE_float(
83 | 'ftrl_learning_rate_power', -0.5, 'The learning rate power.')
84 | tf.app.flags.DEFINE_float(
85 | 'ftrl_initial_accumulator_value', 0.1,
86 | 'Starting value for the FTRL accumulators.')
87 | tf.app.flags.DEFINE_float(
88 | 'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
89 | tf.app.flags.DEFINE_float(
90 | 'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
91 | tf.app.flags.DEFINE_float(
92 | 'momentum', 0.9,
93 | 'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
94 | tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
95 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
96 |
97 | #######################
98 | # Learning Rate Flags #
99 | #######################
100 | tf.app.flags.DEFINE_string(
101 | 'learning_rate_decay_type', 'exponential',
102 |     'Specifies how the learning rate is decayed. One of "fixed", '
103 | '"exponential", or "polynomial".')
104 | tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
105 | tf.app.flags.DEFINE_float(
106 | 'end_learning_rate', 0.0001,
107 | 'The minimal end learning rate used by a polynomial decay learning rate.')
108 | tf.app.flags.DEFINE_float(
109 | 'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
110 | tf.app.flags.DEFINE_float(
111 | 'num_epochs_per_decay', 2.0,
112 | 'Number of epochs after which learning rate decays.')
113 | tf.app.flags.DEFINE_float(
114 | 'moving_average_decay', None,
115 | 'If left as None, the moving averages are not used.')
116 |
117 | # ============================ #
118 | # Fine-Tuning Flags
119 | # ============================ #
120 | tf.app.flags.DEFINE_string(
121 | 'checkpoint_path', None,
122 | 'The path to a checkpoint from which to fine-tune.')
123 | tf.app.flags.DEFINE_string(
124 | 'checkpoint_exclude_scopes', None,
125 | 'Comma-separated list of scopes of variables to exclude when restoring '
126 | 'from a checkpoint.')
127 | tf.app.flags.DEFINE_string(
128 | 'trainable_scopes', None,
129 |     'Comma-separated list of scopes to filter the set of variables to train. '
130 | 'By default, None would train all the variables.')
131 | tf.app.flags.DEFINE_boolean(
132 | 'ignore_missing_vars', False,
133 |     'When restoring a checkpoint, ignore missing variables.')
134 |
135 | FLAGS = tf.app.flags.FLAGS
136 |
137 |
138 | def _configure_learning_rate(num_samples_per_epoch, global_step):
139 | """Configures the learning rate.
140 |
141 | Args:
142 | num_samples_per_epoch: The number of samples in each epoch of training
143 | global_step: The global_step tensor.
144 |
145 | Returns:
146 | A `Tensor` representing the learning rate
147 |
148 | Raises:
149 |     ValueError: if the learning rate decay type is not recognized.
150 | """
151 | decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
152 | FLAGS.num_epochs_per_decay)
153 | if FLAGS.learning_rate_decay_type == 'exponential':
154 | return tf.train.exponential_decay(
155 | FLAGS.learning_rate,
156 | global_step,
157 | decay_steps,
158 | FLAGS.learning_rate_decay_factor,
159 | staircase=True,
160 | name='exponential_decay_learning_rate')
161 | elif FLAGS.learning_rate_decay_type == 'fixed':
162 | return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
163 | elif FLAGS.learning_rate_decay_type == 'polynomial':
164 | return tf.train.polynomial_decay(
165 | FLAGS.learning_rate,
166 | global_step,
167 | decay_steps,
168 | FLAGS.end_learning_rate,
169 | power=1.0,
170 | cycle=False,
171 | name='polynomial_decay_learning_rate')
172 | else:
173 |     raise ValueError('learning_rate_decay_type [%s] was not recognized' %
174 | FLAGS.learning_rate_decay_type)
175 |
176 |
177 | def _configure_optimizer(learning_rate):
178 | """Configures the optimizer used for training.
179 |
180 | Args:
181 | learning_rate: A scalar or 'Tensor' learning rate
182 |
183 | Returns:
184 | An instance of an optimizer
185 |
186 | Raises:
187 | ValueError: if FLAGS.optimizer is not recognized
188 | """
189 | if FLAGS.optimizer == 'adadelta':
190 | optimizer = tf.train.AdadeltaOptimizer(
191 | learning_rate, rho=FLAGS.adadelta_rho, epsilon=FLAGS.opt_epsilon)
192 | elif FLAGS.optimizer == 'adagrad':
193 | optimizer = tf.train.AdagradOptimizer(
194 | learning_rate,
195 | initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
196 | elif FLAGS.optimizer == 'adam':
197 | optimizer = tf.train.AdamOptimizer(
198 | learning_rate,
199 | beta1=FLAGS.adam_beta1,
200 | beta2=FLAGS.adam_beta2,
201 | epsilon=FLAGS.opt_epsilon)
202 |   elif FLAGS.optimizer == 'ftrl':
203 | optimizer = tf.train.FtrlOptimizer(
204 | learning_rate,
205 | learning_rate_power=FLAGS.ftrl_learning_rate_power,
206 | initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
207 | l1_regularization_strength=FLAGS.ftrl_l1,
208 | l2_regularization_strength=FLAGS.ftrl_l2)
209 | elif FLAGS.optimizer == 'momentum':
210 | optimizer = tf.train.MomentumOptimizer(
211 | learning_rate,
212 | momentum=FLAGS.momentum,
213 | name='Momentum')
214 | elif FLAGS.optimizer == 'rmsprop':
215 | optimizer = tf.train.RMSPropOptimizer(
216 | learning_rate,
217 | decay=FLAGS.rmsprop_decay,
218 | momentum=FLAGS.rmsprop_momentum,
219 | epsilon=FLAGS.opt_epsilon)
220 | elif FLAGS.optimizer == 'sgd':
221 | optimizer = tf.train.GradientDescentOptimizer(learning_rate)
222 | else:
223 |     raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
224 | return optimizer
225 |
226 |
227 | def _get_variables_to_train(options):
228 | """Returns a list of variables to train.
229 |
230 | Args:
231 |     options: parsed model options; 'trainable_scopes' selects the scopes to train.
232 | """
233 | if options.get('trainable_scopes') is None:
234 | return tf.trainable_variables()
235 | else:
236 | scopes = [scope.strip() for scope in options.get('trainable_scopes').split(',')]
237 |
238 | variables_to_train = []
239 | for scope in scopes:
240 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
241 | variables_to_train.extend(variables)
242 | return variables_to_train
243 |
244 |
245 | def _get_init_fn(options):
246 | """Returns a function to warm-start the training.
247 |
248 |   Note that the init_fn is only run when initializing the model during the
249 | very first global step.
250 |
251 | Returns:
252 | An init function
253 | """
254 | if options.get('checkpoint_path') is None:
255 | return None
256 |
257 |   # If a checkpoint already exists in the train_dir, --checkpoint_path is
258 |   # ignored and training resumes from that checkpoint.
259 | if tf.train.latest_checkpoint(FLAGS.train_dir):
260 | tf.logging.info(
261 | 'Ignoring --checkpoint_path because a checkpoint already exists '
262 | 'in %s' % FLAGS.train_dir)
263 | return None
264 |
265 | exclusions = []
266 | if options.get('checkpoint_exclude_scopes'):
267 | # remove space and comma
268 | exclusions = [scope.strip()
269 | for scope in options.get('checkpoint_exclude_scopes').split(',')]
270 |
271 | variables_to_restore = []
272 | for var in slim.get_model_variables():
273 | excluded = False
274 | for exclusion in exclusions:
275 | if var.op.name.startswith(exclusion):
276 | excluded = True
277 | break
278 | if not excluded:
279 | variables_to_restore.append(var)
280 |
281 | if tf.gfile.IsDirectory(options.get('checkpoint_path')):
282 | checkpoint_path = tf.train.latest_checkpoint(options.get('checkpoint_path'))
283 | else:
284 | checkpoint_path = options.get('checkpoint_path')
285 |
286 | tf.logging.info('Fine-tuning from %s' % checkpoint_path)
287 |
288 | return slim.assign_from_checkpoint_fn(
289 | checkpoint_path,
290 | variables_to_restore,
291 | ignore_missing_vars=options.get('ignore_missing_vars'))
292 |
293 |
294 | def main(_):
295 | if not FLAGS.dataset_dir:
296 | raise ValueError('You must supply the dataset directory with'
297 | ' --dataset_dir')
298 |
299 | tf.logging.set_verbosity(tf.logging.INFO)
300 | with tf.Graph().as_default():
301 | global_step = slim.create_global_step() # create the global step
302 |
303 | ######################
304 | # select the dataset #
305 | ######################
306 | dataset = dataset_utils.get_split(
307 | FLAGS.dataset_name,
308 | FLAGS.dataset_split_name,
309 | FLAGS.dataset_dir)
310 |
311 | ######################
312 | # create the network #
313 | ######################
314 | # parse the options from a yaml file
315 | model, options = models_factory.get_model(FLAGS.model_config)
316 |
317 | ####################################################
318 | # create a dataset provider that loads the dataset #
319 | ####################################################
320 | # dataset provider
321 | provider = slim.dataset_data_provider.DatasetDataProvider(
322 | dataset,
323 | num_readers=FLAGS.num_readers,
324 | common_queue_capacity=20*FLAGS.batch_size,
325 | common_queue_min=10*FLAGS.batch_size)
326 | [image] = provider.get(['image'])
327 | image_clip = preprocessing_image(
328 | image,
329 | model.training_image_size,
330 | model.training_image_size,
331 | model.content_size,
332 | is_training=True)
333 | image_clip_batch = tf.train.batch(
334 | [image_clip],
335 | batch_size=FLAGS.batch_size,
336 | num_threads=FLAGS.num_preprocessing_threads,
337 | capacity=5*FLAGS.batch_size)
338 |
339 |         # prefetch-queue the inputs
340 | batch_queue = slim.prefetch_queue.prefetch_queue([image_clip_batch])
341 |
342 | ###########################################
343 | # build the models based on the given data #
344 | ###########################################
345 | images = batch_queue.dequeue()
346 | total_loss = model.build_train_graph(images)
347 |
348 | ####################################################
349 | # gather the operations for training and summaries #
350 | ####################################################
351 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
352 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
353 |
354 |         # configure the moving averages
355 | if FLAGS.moving_average_decay:
356 | moving_average_variables = slim.get_model_variables()
357 | variable_averages = tf.train.ExponentialMovingAverage(
358 | FLAGS.moving_average_decay, global_step)
359 | else:
360 | moving_average_variables, variable_averages = None, None
361 |
362 | # gather the optimizer operations
363 | learning_rate = _configure_learning_rate(
364 | dataset.num_samples, global_step)
365 | optimizer = _configure_optimizer(learning_rate)
366 | summaries.add(tf.summary.scalar('learning_rate', learning_rate))
367 |
368 | if FLAGS.moving_average_decay:
369 | update_ops.append(variable_averages.apply(moving_average_variables))
370 |
371 | # training operations
372 | train_op = model.get_training_operations(
373 | optimizer, global_step, _get_variables_to_train(options))
374 | update_ops.append(train_op)
375 |
376 | # gather the training summaries
377 | summaries |= set(model.summaries)
378 |
379 | # gather the update operation
380 | update_op = tf.group(*update_ops)
381 | watched_loss = control_flow_ops.with_dependencies(
382 | [update_op], total_loss, name='train_op')
383 |
384 | # merge the summaries
385 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
386 | summary_op = tf.summary.merge(list(summaries), name='summary_op')
387 |
388 | ##############################
389 | # start the training process #
390 | ##############################
391 | slim.learning.train(
392 | watched_loss,
393 | logdir=FLAGS.train_dir,
394 | init_fn=_get_init_fn(options),
395 | summary_op=summary_op,
396 | number_of_steps=FLAGS.max_number_of_steps,
397 | log_every_n_steps=FLAGS.log_every_n_steps,
398 | save_summaries_secs=FLAGS.save_summaries_secs,
399 | save_interval_secs=FLAGS.save_interval_secs)
400 |
401 |
402 | if __name__ == '__main__':
403 | tf.app.run()
404 |
--------------------------------------------------------------------------------