├── .gitignore ├── README.md ├── configs └── AvatarNet_config.yml ├── data ├── contents │ ├── images │ │ ├── avril.jpg │ │ ├── brad_pitt.jpg │ │ ├── cornell.jpg │ │ ├── flowers.jpg │ │ ├── modern.jpg │ │ └── woman_side_portrait.jpg │ └── sequences │ │ ├── frame_0001.png │ │ ├── frame_0002.png │ │ ├── frame_0003.png │ │ ├── frame_0004.png │ │ ├── frame_0005.png │ │ ├── frame_0006.png │ │ ├── frame_0007.png │ │ ├── frame_0008.png │ │ ├── frame_0009.png │ │ ├── frame_0010.png │ │ ├── frame_0011.png │ │ ├── frame_0012.png │ │ ├── frame_0013.png │ │ ├── frame_0014.png │ │ ├── frame_0015.png │ │ ├── frame_0016.png │ │ ├── frame_0017.png │ │ ├── frame_0018.png │ │ ├── frame_0019.png │ │ ├── frame_0020.png │ │ ├── frame_0021.png │ │ ├── frame_0022.png │ │ ├── frame_0023.png │ │ ├── frame_0024.png │ │ ├── frame_0025.png │ │ ├── frame_0026.png │ │ ├── frame_0027.png │ │ ├── frame_0028.png │ │ ├── frame_0029.png │ │ ├── frame_0030.png │ │ ├── frame_0031.png │ │ ├── frame_0032.png │ │ ├── frame_0033.png │ │ ├── frame_0034.png │ │ ├── frame_0035.png │ │ ├── frame_0036.png │ │ ├── frame_0037.png │ │ ├── frame_0038.png │ │ ├── frame_0039.png │ │ ├── frame_0040.png │ │ ├── frame_0041.png │ │ ├── frame_0042.png │ │ ├── frame_0043.png │ │ ├── frame_0044.png │ │ ├── frame_0045.png │ │ ├── frame_0046.png │ │ ├── frame_0047.png │ │ ├── frame_0048.png │ │ ├── frame_0049.png │ │ └── frame_0050.png └── styles │ ├── brushstrokers.jpg │ ├── candy.jpg │ ├── la_muse.jpg │ ├── plum_flower.jpg │ └── woman_in_peasant_dress_cropped.jpg ├── datasets ├── __init__.py ├── __init__.pyc ├── convert_mscoco_to_tfexamples.py ├── dataset_utils.py └── dataset_utils.pyc ├── docs ├── _config.yml ├── figures │ ├── closed_ups.png │ ├── image_results.png │ ├── network_architecture_with_comparison.png │ ├── result_comparison.png │ ├── snapshot.png │ ├── style_decorator.png │ ├── style_interpolation.png │ ├── teaser.png │ └── trade_off.png └── index.md ├── evaluate_style_transfer.py ├── models ├── __init__.py ├── __init__.pyc ├── autoencoder.py ├── avatar_net.py ├── avatar_net.pyc ├── losses.py ├── losses.pyc ├── models_factory.py ├── models_factory.pyc ├── network_ops.py ├── network_ops.pyc ├── preprocessing.py ├── preprocessing.pyc ├── vgg.py ├── vgg.pyc ├── vgg_decoder.py └── vgg_decoder.pyc ├── scripts ├── evaluate_style_transfer.sh └── train_image_reconstruction.sh └── train_image_reconstruction.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/* 2 | *.py[cod] 3 | ./results/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration 2 | 3 | This repository contains the code (in [TensorFlow](https://www.tensorflow.org/)) for the paper: 4 | 5 | [__Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration__](https://arxiv.org/abs/1805.03857) 6 |
7 | [Lu Sheng](http://www.ee.cuhk.edu.hk/~lsheng/), [Ziyi Lin](mailto:linziyi@sensetime.com), [Jing Shao](http://www.ee.cuhk.edu.hk/~jshao/), [Xiaogang Wang](http://www.ee.cuhk.edu.hk/~xgwang/) 8 |
9 | CVPR 2018 10 | 11 | ## Overview 12 | 13 | In this repository, we propose an efficient and effective Avatar-Net that enables visually plausible multi-scale transfer of _arbitrary_ styles in real time. The key ingredient is a __style decorator__ that reassembles the content features from semantically aligned style features; it not only holistically matches their feature distributions but also preserves detailed style patterns in the decorated features. By embedding this module into an image reconstruction network that fuses multi-scale style abstractions, Avatar-Net renders multi-scale stylization for any style image in a single feed-forward pass. 14 | 15 | [teaser]: ./docs/figures/teaser.png 16 | ![teaser] 17 | 18 | ## Examples 19 | [image_results]: ./docs/figures/image_results.png 20 | ![image_results] 21 | 22 | ## Comparison with Prior Arts 23 | 24 |

25 | 26 | - The result by Avatar-Net reproduces concrete multi-scale style patterns (e.g. the color distribution, brush strokes and circular patterns in the _candy_ image). 27 | - [WCT](https://arxiv.org/abs/1705.08086) distorts the brush strokes and circular patterns. [AdaIN](https://arxiv.org/abs/1703.06868) cannot even keep the color distribution, while [Style-Swap](https://arxiv.org/abs/1612.04337) fails in this example. 28 | 29 | #### Execution Efficiency 30 | | Method | Gatys et al. | AdaIN | WCT | Style-Swap | __Avatar-Net__ | 31 | | :---: | :---: | :---: | :---: | :---: | :---: | 32 | | __256x256 (sec)__ | 12.18 | 0.053 | 0.62 | 0.064 | __0.071__ | 33 | | __512x512 (sec)__ | 43.25 | 0.11 | 0.93 | 0.23 | __0.28__ | 34 | 35 | - Avatar-Net has an execution time comparable to AdaIN and GPU-accelerated Style-Swap, and is much faster than WCT and the optimization-based style transfer of [Gatys _et al._](https://arxiv.org/abs/1508.06576). 36 | - The reference methods and the proposed Avatar-Net are implemented on the same TensorFlow platform with the same VGG network as the backbone. 37 | 38 | ## Dependencies 39 | - [TensorFlow](https://www.tensorflow.org/) (version >= 1.0; only tested on TensorFlow 1.0). 40 | - The code heavily depends on [TF-Slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim) and its [model repository](https://github.com/tensorflow/models/tree/master/research/slim). 41 | 42 | ## Download 43 | - The trained Avatar-Net model can be downloaded from [Google Drive](https://drive.google.com/open?id=1_7x93xwZMhCL-kLrz4B2iZ01Y8Q7SlTX). 44 | - Training the style transfer network requires pretrained [VGG](https://arxiv.org/abs/1409.1556) networks, which can be obtained from the [TF-Slim model repository](https://github.com/tensorflow/models/tree/master/research/slim). The encoding layers of Avatar-Net are also borrowed from pretrained VGG models. 45 | - The [MSCOCO](http://cocodataset.org/#home) dataset is used to train the proposed image reconstruction network. 46 | 47 | ## Usage 48 | 49 | ### Basic Usage 50 | 51 | Simply use the bash script `./scripts/evaluate_style_transfer.sh` to apply Avatar-Net to all content images in `CONTENT_DIR` with every style image in `STYLE_DIR`. For example, 52 | 53 | bash ./scripts/evaluate_style_transfer.sh gpu_id CONTENT_DIR STYLE_DIR EVAL_DIR 54 | 55 | - `gpu_id`: the GPU ID mounted for the TensorFlow session. 56 | - `CONTENT_DIR`: the directory of the content images. It can be `./data/contents/images` for multiple exemplar content images, or `./data/contents/sequences` for an exemplar content video. 57 | - `STYLE_DIR`: the directory of the style images. It can be `./data/styles` for multiple exemplar style images. 58 | - `EVAL_DIR`: the output directory. It contains one subdirectory per style image, named after that style image. 59 | 60 | More detailed evaluation options are exposed as command-line flags in `evaluate_style_transfer.py` (a fuller example invocation is given after the configuration options below), such as 61 | 62 | python evaluate_style_transfer.py 63 | 64 | ### Configuration 65 | 66 | The detailed configuration of Avatar-Net is listed in `configs/AvatarNet_config.yml`, including the training specifications and network hyper-parameters. The style decorator has three options: 67 | 68 | - `patch_size`: the patch size for the normalized cross-correlation, `5` by default. 69 | - `style_coding`: the projection and reconstruction method, either `ZCA` or `AdaIN`. 70 | - `style_interp`: the interpolation option between the transferred features and the content features, either `normalized` or `biased`.
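For reference, a complete invocation of `evaluate_style_transfer.py` might look like the following. This is only a sketch built from the flags defined in that script (`--checkpoint_dir`, `--model_config_path`, `--content_dataset_dir`, `--style_dataset_dir`, `--eval_dir`, `--inter_weight`); `MODEL_DIR` and `EVAL_DIR` are placeholders for your own paths, and the values shown are examples rather than required settings.

    python evaluate_style_transfer.py \
        --checkpoint_dir=MODEL_DIR \
        --model_config_path=configs/AvatarNet_config.yml \
        --content_dataset_dir=./data/contents/images \
        --style_dataset_dir=./data/styles \
        --eval_dir=EVAL_DIR \
        --inter_weight=1.0

The style-decorator options above live in the YAML configuration file, while per-run choices such as the content-style blending weight (`--inter_weight`) are passed as command-line flags.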
71 | 72 | The style transfer is actually performed in `AvatarNet.transfer_styles(self, inputs, styles, inter_weight, intra_weights)`, in which 73 | 74 | - `inputs`: the content images. 75 | - `styles`: a list of style images (`len(styles) >= 2` for multiple style interpolation). 76 | - `inter_weight`: the weight balancing the style and the content images. 77 | - `intra_weights`: a list of weights balancing the effects of the different styles. 78 | 79 | Users may modify the evaluation script for multiple style interpolation or the content-style trade-off. 80 | 81 | ### Training 82 | 83 | 1. Download the [MSCOCO](http://cocodataset.org/#home) dataset and convert the raw images into `tfexamples` using the Python script `./datasets/convert_mscoco_to_tfexamples.py`. 84 | 2. Use `bash ./scripts/train_image_reconstruction.sh gpu_id DATASET_DIR MODEL_DIR` to start training with the default hyper-parameters. `gpu_id` is the GPU mounted for the TensorFlow session. Replace `DATASET_DIR` with the path to the converted MSCOCO training data and `MODEL_DIR` with the Avatar-Net model directory. 85 | 86 | ## Citation 87 | 88 | If you find this code useful for your research, please cite the paper: 89 | 90 | Lu Sheng, Ziyi Lin, Jing Shao and Xiaogang Wang, "Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration", in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018. [[arXiv](https://arxiv.org/abs/1805.03857)] 91 | 92 | ``` 93 | @inproceedings{sheng2018avatar, 94 | title = {Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration}, 95 | author = {Sheng, Lu and Lin, Ziyi and Shao, Jing and Wang, Xiaogang}, 96 | booktitle = {Computer Vision and Pattern Recognition (CVPR), 2018 IEEE Conference on}, 97 | pages = {1--9}, 98 | year = {2018} 99 | } 100 | ``` 101 | 102 | ## Acknowledgement 103 | 104 | This project is inspired by several style-agnostic style transfer methods, including [AdaIN](https://arxiv.org/abs/1703.06868), [WCT](https://arxiv.org/abs/1705.08086) and [Style-Swap](https://arxiv.org/abs/1612.04337), in both their papers and their released code.
105 | 106 | ## Contact 107 | 108 | If you have any questions or suggestions about this paper, feel free to contact me ([lsheng@ee.cuhk.edu.hk](mailto:lsheng@ee.cuhk.edu.hk)) 109 | -------------------------------------------------------------------------------- /configs/AvatarNet_config.yml: -------------------------------------------------------------------------------- 1 | # name of the applied model 2 | model_name: 'AvatarNet' 3 | 4 | # the input sizes 5 | content_size: 512 6 | style_size: 512 7 | 8 | # perceptual loss configurations 9 | network_name: 'vgg_19' 10 | checkpoint_path: '/DATA/lsheng/model_zoo/VGG/vgg_19.ckpt' 11 | checkpoint_exclude_scopes: 'vgg_19/fc' 12 | ignore_missing_vars: True 13 | 14 | # style loss layers 15 | style_loss_layers: 16 | - 'conv1/conv1_1' 17 | - 'conv2/conv2_1' 18 | - 'conv3/conv3_1' 19 | - 'conv4/conv4_1' 20 | 21 | ################################# 22 | # style decorator specification # 23 | ################################# 24 | # patch size for style decorator 25 | patch_size: 5 26 | 27 | # style encoding method 28 | style_coding: 'ZCA' # 'AdaIN' 29 | 30 | # style interpolation 31 | style_interp: 'normalized' 32 | 33 | #################### 34 | # training routine # 35 | #################### 36 | training_image_size: 256 37 | weight_decay: 0.0005 38 | trainable_scopes: 'combined_decoder' 39 | 40 | # loss weights 41 | content_weight: 1.0 42 | recons_weight: 10.0 43 | tv_weight: 10.0 44 | -------------------------------------------------------------------------------- /data/contents/images/avril.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/avril.jpg -------------------------------------------------------------------------------- /data/contents/images/brad_pitt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/brad_pitt.jpg -------------------------------------------------------------------------------- /data/contents/images/cornell.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/cornell.jpg -------------------------------------------------------------------------------- /data/contents/images/flowers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/flowers.jpg -------------------------------------------------------------------------------- /data/contents/images/modern.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/modern.jpg -------------------------------------------------------------------------------- /data/contents/images/woman_side_portrait.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/images/woman_side_portrait.jpg -------------------------------------------------------------------------------- /data/contents/sequences/frame_0001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0001.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0002.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0003.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0004.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0005.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0006.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0007.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0008.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0009.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0010.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0011.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0011.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0012.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0013.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0013.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0014.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0015.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0016.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0017.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0018.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0019.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0020.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0021.png 
-------------------------------------------------------------------------------- /data/contents/sequences/frame_0022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0022.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0023.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0024.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0025.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0026.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0027.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0027.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0028.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0028.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0029.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0029.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0030.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0030.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0031.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0032.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0032.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0033.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0033.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0034.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0034.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0035.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0035.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0036.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0037.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0037.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0038.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0038.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0039.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0039.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0040.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0040.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0041.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0042.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0042.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0043.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0043.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0044.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0045.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0045.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0046.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0047.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0048.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0049.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0049.png -------------------------------------------------------------------------------- /data/contents/sequences/frame_0050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/contents/sequences/frame_0050.png -------------------------------------------------------------------------------- /data/styles/brushstrokers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/brushstrokers.jpg -------------------------------------------------------------------------------- /data/styles/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/candy.jpg -------------------------------------------------------------------------------- 
/data/styles/la_muse.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/la_muse.jpg -------------------------------------------------------------------------------- /data/styles/plum_flower.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/plum_flower.jpg -------------------------------------------------------------------------------- /data/styles/woman_in_peasant_dress_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/data/styles/woman_in_peasant_dress_cropped.jpg -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/datasets/__init__.pyc -------------------------------------------------------------------------------- /datasets/convert_mscoco_to_tfexamples.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import os 7 | import random 8 | import sys 9 | 10 | from datasets import dataset_utils 11 | 12 | import tensorflow as tf 13 | 14 | _NUM_SHARDS = 5 15 | _RANDOM_SEED = 0 16 | 17 | FLAGS = tf.app.flags.FLAGS 18 | 19 | tf.app.flags.DEFINE_string( 20 | 'output_dataset_dir', None, 21 | 'The directory where the outputs TFRecords and temporary files are saved') 22 | 23 | tf.app.flags.DEFINE_string( 24 | 'input_dataset_dir', None, 25 | 'The directory where the input files are saved.') 26 | 27 | 28 | def _get_filenames(dataset_dir): 29 | split_dirs = ['train2014', 'val2014', 'test2014'] 30 | 31 | # get the full path to each image 32 | train_dir = os.path.join(dataset_dir, split_dirs[0]) 33 | validation_dir = os.path.join(dataset_dir, split_dirs[1]) 34 | test_dir = os.path.join(dataset_dir, split_dirs[2]) 35 | 36 | train_image_filenames = [] 37 | for filename in os.listdir(train_dir): 38 | file_path = os.path.join(train_dir, filename) 39 | train_image_filenames.append(file_path) 40 | 41 | validation_image_filenames = [] 42 | for filename in os.listdir(validation_dir): 43 | file_path = os.path.join(validation_dir, filename) 44 | validation_image_filenames.append(file_path) 45 | 46 | test_image_filenames = [] 47 | for filename in os.listdir(test_dir): 48 | file_path = os.path.join(test_dir, filename) 49 | test_image_filenames.append(file_path) 50 | 51 | print('Statistics in MSCOCO dataset...') 52 | print('There are %d images in train dataset' % len(train_image_filenames)) 53 | print('There are %d images in validation dataset' % len(validation_image_filenames)) 54 | print('There are %d images in test dataset' % len(test_image_filenames)) 55 | 56 | return train_image_filenames, 
validation_image_filenames, test_image_filenames 57 | 58 | 59 | def _get_dataset_filename(dataset_dir, split_name, shard_id): 60 | output_filename = 'MSCOCO_%s_%05d-of-%05d.tfrecord' % ( 61 | split_name, shard_id, _NUM_SHARDS) 62 | return os.path.join(dataset_dir, output_filename) 63 | 64 | 65 | def _convert_dataset(split_name, image_filenames, dataset_dir): 66 | assert split_name in ['train', 'validation', 'test'] 67 | 68 | num_per_shard = int(math.ceil(len(image_filenames) / float(_NUM_SHARDS))) 69 | 70 | with tf.Graph().as_default(): 71 | image_reader = dataset_utils.ImageReader() 72 | 73 | with tf.Session('') as sess: 74 | for shard_id in range(_NUM_SHARDS): 75 | output_filename = _get_dataset_filename( 76 | dataset_dir, split_name, shard_id) 77 | with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: 78 | start_ndx = shard_id * num_per_shard 79 | end_ndx = min((shard_id+1) * num_per_shard, len(image_filenames)) 80 | for i in range(start_ndx, end_ndx): 81 | sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( 82 | i + 1, len(image_filenames), shard_id)) 83 | sys.stdout.flush() 84 | # read the image 85 | img_filename = image_filenames[i] 86 | img_data = tf.gfile.FastGFile(img_filename, 'r').read() 87 | img_shape = image_reader.read_image_dims(sess, img_data) 88 | example = dataset_utils.image_to_tfexample( 89 | img_data, img_filename[-3:], img_shape, img_filename) 90 | tfrecord_writer.write(example.SerializeToString()) 91 | sys.stdout.write('\n') 92 | sys.stdout.flush() 93 | 94 | 95 | def _dataset_exists(dataset_dir): 96 | for split_name in ['train', 'validation', 'test']: 97 | for shard_id in range(_NUM_SHARDS): 98 | output_filename = _get_dataset_filename( 99 | dataset_dir, split_name, shard_id) 100 | if not tf.gfile.Exists(output_filename): 101 | return False 102 | return True 103 | 104 | 105 | def run(input_dataset_dir, output_dataset_dir): 106 | if not tf.gfile.Exists(output_dataset_dir): 107 | tf.gfile.MakeDirs(output_dataset_dir) 108 | 109 | if _dataset_exists(output_dataset_dir): 110 | print('Dataset files already exist. 
Exiting without re-creating them.') 111 | return 112 | 113 | train_image_filenames, validation_image_filenames, test_image_filenames = \ 114 | _get_filenames(input_dataset_dir) 115 | 116 | # randomize the datasets 117 | random.seed(_RANDOM_SEED) 118 | random.shuffle(train_image_filenames) 119 | random.shuffle(validation_image_filenames) 120 | random.shuffle(test_image_filenames) 121 | 122 | num_train = len(train_image_filenames) 123 | num_validation = len(validation_image_filenames) 124 | num_test = len(test_image_filenames) 125 | num_samples = num_train + num_validation + num_test 126 | 127 | # store the dataset meta data 128 | dataset_meta_data = { 129 | 'dataset_name': 'MSCOCO', 130 | 'source_dataset_dir': input_dataset_dir, 131 | 'num_of_samples': num_samples, 132 | 'num_of_train': num_train, 133 | 'num_of_validation': num_validation, 134 | 'num_of_test': num_test, 135 | 'train_image_filenames': train_image_filenames, 136 | 'validation_image_filenames': validation_image_filenames, 137 | 'test_image_filenames': test_image_filenames} 138 | dataset_utils.write_dataset_meta_data(output_dataset_dir, dataset_meta_data) 139 | 140 | _convert_dataset('train', train_image_filenames, output_dataset_dir) 141 | _convert_dataset('validation', validation_image_filenames, output_dataset_dir) 142 | _convert_dataset('test', test_image_filenames, output_dataset_dir) 143 | 144 | 145 | def main(_): 146 | if not FLAGS.input_dataset_dir: 147 | raise ValueError('You must supply the dataset directory with --dataset_dir') 148 | run(FLAGS.input_dataset_dir, FLAGS.output_dataset_dir) 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run() 153 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import yaml 7 | 8 | import tensorflow as tf 9 | 10 | slim = tf.contrib.slim 11 | 12 | _META_DATA_FILENAME = 'dataset_meta_data.txt' 13 | 14 | _FILE_PATTERN = '%s_%s_*.tfrecord' 15 | 16 | _ITEMS_TO_DESCRIPTIONS = { 17 | 'image': 'A color image of varying size.', 18 | 'shape': 'The shape of the image.' 
19 | } 20 | 21 | 22 | def int64_feature(values): 23 | if not isinstance(values, (tuple, list)): 24 | values = [values] 25 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 26 | 27 | 28 | def bytes_feature(values): 29 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 30 | 31 | 32 | def image_to_tfexample(image_data, image_format, image_shape, image_filename): 33 | return tf.train.Example(features=tf.train.Features(feature={ 34 | 'image/encoded': bytes_feature(image_data), 35 | 'image/format': bytes_feature(image_format), 36 | 'image/shape': int64_feature(image_shape), 37 | 'image/filename': bytes_feature(image_filename), 38 | })) 39 | 40 | 41 | def write_dataset_meta_data(dataset_dir, dataset_meta_data, 42 | filename=_META_DATA_FILENAME): 43 | meta_filename = os.path.join(dataset_dir, filename) 44 | with open(meta_filename, 'wb') as f: 45 | yaml.dump(dataset_meta_data, f) 46 | print('Finish writing the dataset meta data.') 47 | 48 | 49 | def has_dataset_meta_data_file(dataset_dir, filename=_META_DATA_FILENAME): 50 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 51 | 52 | 53 | def read_dataset_meta_data(dataset_dir, filename=_META_DATA_FILENAME): 54 | meta_filename = os.path.join(dataset_dir, filename) 55 | with open(meta_filename, 'rb') as f: 56 | dataset_meta_data = yaml.load(f) 57 | print('Finish loading the dataset meta data of [%s].' % 58 | dataset_meta_data.get('dataset_name')) 59 | return dataset_meta_data 60 | 61 | 62 | def get_split(dataset_name, 63 | split_name, 64 | dataset_dir, 65 | file_pattern=None, 66 | reader=None): 67 | if split_name not in ['train', 'validation']: 68 | raise ValueError('split name %s was not recognized.' % split_name) 69 | 70 | if not file_pattern: 71 | file_pattern = _FILE_PATTERN 72 | file_pattern = os.path.join(dataset_dir, file_pattern % ( 73 | dataset_name, split_name)) 74 | 75 | # read the dataset meta data 76 | if has_dataset_meta_data_file(dataset_dir): 77 | dataset_meta_data = read_dataset_meta_data(dataset_dir) 78 | num_samples = dataset_meta_data.get('num_of_' + split_name) 79 | else: 80 | raise ValueError('No dataset_meta_data file available in %s' % dataset_dir) 81 | 82 | if reader is None: 83 | reader = tf.TFRecordReader 84 | 85 | keys_to_features = { 86 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 87 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 88 | 'image/shape': tf.FixedLenFeature((3,), tf.int64, default_value=(224, 224, 3)), 89 | 'image/filename': tf.FixedLenFeature([], tf.string, default_value=''), 90 | } 91 | 92 | items_to_handlers = { 93 | 'image': slim.tfexample_decoder.Image( 94 | 'image/encoded', 'image/format'), 95 | 'shape': slim.tfexample_decoder.Tensor('image/shape'), 96 | 'filename': slim.tfexample_decoder.Tensor('image/filename') 97 | } 98 | 99 | decoder = slim.tfexample_decoder.TFExampleDecoder( 100 | keys_to_features, items_to_handlers) 101 | 102 | return slim.dataset.Dataset( 103 | data_sources=file_pattern, 104 | reader=reader, 105 | decoder=decoder, 106 | num_samples=num_samples, 107 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS) 108 | 109 | 110 | class ImageReader(object): 111 | """helper class that provides tensorflow image coding utilities.""" 112 | def __init__(self): 113 | self._decode_data = tf.placeholder(dtype=tf.string) 114 | self._decode_image = tf.image.decode_image(self._decode_data, channels=0) 115 | 116 | def read_image_dims(self, sess, image_data): 117 | image = self.decode_image(sess, 
image_data) 118 | return image.shape 119 | 120 | def decode_image(self, sess, image_data): 121 | image = sess.run(self._decode_image, 122 | feed_dict={self._decode_data: image_data}) 123 | assert len(image.shape) == 3 124 | assert image.shape[2] == 3 125 | return image 126 | 127 | 128 | class ImageCoder(object): 129 | """helper class that provides Tensorflow Image coding utilities, 130 | also works for corrupted data with incorrected extension 131 | """ 132 | def __init__(self): 133 | self._decode_data = tf.placeholder(dtype=tf.string) 134 | self._decode_image = tf.image.decode_image(self._decode_data, channels=0) 135 | self._encode_jpeg = tf.image.encode_jpeg(self._decode_image, format='rgb', quality=100) 136 | 137 | def decode_image(self, sess, image_data): 138 | # verify the image from the image_data 139 | status = False 140 | try: 141 | # decode image and verify the data 142 | image = sess.run(self._decode_image, 143 | feed_dict={self._decode_data: image_data}) 144 | image_shape = image.shape 145 | assert len(image_shape) == 3 146 | assert image_shape[2] == 3 147 | # encode as RGB JPEG image string and return 148 | image_string = sess.run(self._encode_jpeg, feed_dict={self._decode_data: image_data}) 149 | status = True 150 | except BaseException: 151 | image_shape, image_string = None, None 152 | return status, image_string, image_shape 153 | -------------------------------------------------------------------------------- /datasets/dataset_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/datasets/dataset_utils.pyc -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman 2 | title: Avatar-Net 3 | description: Multi-scale Zero-shot Style Transfer by Feature Decoration 4 | show_downloads: true 5 | -------------------------------------------------------------------------------- /docs/figures/closed_ups.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/closed_ups.png -------------------------------------------------------------------------------- /docs/figures/image_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/image_results.png -------------------------------------------------------------------------------- /docs/figures/network_architecture_with_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/network_architecture_with_comparison.png -------------------------------------------------------------------------------- /docs/figures/result_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/result_comparison.png -------------------------------------------------------------------------------- /docs/figures/snapshot.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/snapshot.png -------------------------------------------------------------------------------- /docs/figures/style_decorator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/style_decorator.png -------------------------------------------------------------------------------- /docs/figures/style_interpolation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/style_interpolation.png -------------------------------------------------------------------------------- /docs/figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/teaser.png -------------------------------------------------------------------------------- /docs/figures/trade_off.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/docs/figures/trade_off.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | [teaser]: ./figures/teaser.png 2 | ![teaser] 3 |

4 | Exemplar stylized results by the proposed Avatar-Net, which faithfully transfers the lena image by arbitrary style. 5 |

6 | 7 | ## Overview 8 | 9 | Zero-shot artistic style transfer is an important image synthesis problem that aims to transfer an arbitrary style onto content images. However, the trade-off between generalization and efficiency in existing methods impedes high-quality zero-shot style transfer in real time. In this repository, we resolve this dilemma and propose an efficient yet effective Avatar-Net that enables visually plausible multi-scale transfer of arbitrary styles. 10 | 11 | The key ingredient of our method is a __style decorator__ that reassembles the content features from semantically aligned style features of an arbitrary style image; it not only holistically matches their feature distributions but also preserves detailed style patterns in the decorated features. 12 | 13 | [style_decorator]: ./figures/style_decorator.png 14 | ![style_decorator] 15 |

16 | Comparison of feature distribution transformation by different feature transfer modules. (a) Adaptive Instance Normalization, (b) Whitening and Coloring Transform, (c) Style-Swap, and (d) the proposed style decorator. 17 |
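To make the patch-based matching concrete, here is a toy NumPy sketch of the core idea behind the style decorator: normalized cross-correlation between content-feature patches and style-feature patches, followed by patch reassembly. It is only an illustration under simplified assumptions — the function name and shapes are hypothetical, and it omits the ZCA/AdaIN projection, the hourglass decoder, and the multi-scale fusion used by the actual implementation in `models/avatar_net.py`.

```python
import numpy as np

def patch_swap_sketch(content_feat, style_feat, patch_size=5):
    """Toy illustration: for each content-feature patch, find the style-feature
    patch with the highest normalized cross-correlation and paste it back,
    averaging overlapping patches. Feature maps have shape (H, W, C)."""
    ph = patch_size // 2

    # Collect all style patches and L2-normalize them for the correlation test.
    patches = []
    for i in range(ph, style_feat.shape[0] - ph):
        for j in range(ph, style_feat.shape[1] - ph):
            patches.append(style_feat[i - ph:i + ph + 1, j - ph:j + ph + 1, :])
    patches = np.stack(patches)                                  # (N, k, k, C)
    flat = patches.reshape(len(patches), -1)
    unit = flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-8)

    H, W, _ = content_feat.shape
    out = np.zeros_like(content_feat, dtype=np.float64)
    hits = np.zeros((H, W, 1))
    for i in range(ph, H - ph):
        for j in range(ph, W - ph):
            q = content_feat[i - ph:i + ph + 1, j - ph:j + ph + 1, :].reshape(-1)
            best = np.argmax(unit @ q)      # normalized cross-correlation score
            out[i - ph:i + ph + 1, j - ph:j + ph + 1, :] += patches[best]
            hits[i - ph:i + ph + 1, j - ph:j + ph + 1, :] += 1.0
    return out / np.maximum(hits, 1.0)

# Example: swap patches between two random 32x32x64 "feature maps".
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    swapped = patch_swap_sketch(rng.normal(size=(32, 32, 64)),
                                rng.normal(size=(32, 32, 64)))
    print(swapped.shape)  # (32, 32, 64)
```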

18 | 19 | By embedding this module into a reconstruction network that fuses multi-scale style abstractions, the Avatar-Net renders multi-scale stylization for any style image in one feed-forward pass. 20 | 21 | [network]: ./figures/network_architecture_with_comparison.png 22 | ![network] 23 |

24 | (a) Stylization comparison by autoencoder and style-augmented hourglass network. (b) The network architecture of the proposed method. 25 |

26 | 27 | ## Results 28 | 29 | [image_results]: ./figures/image_results.png 30 | ![image_results] 31 |

Exemplar stylized results by the proposed Avatar-Net.

32 | 33 | We demonstrate the state-of-the-art effectiveness and efficiency of the proposed method in generating high-quality stylized images, with a series of successful applications including multiple style integration and video stylization. 34 | 35 | #### Comparison with Prior Arts 36 | 37 |

38 | 39 | - The result by Avatar-Net reproduces concrete multi-scale style patterns (e.g. the color distribution, brush strokes and circular patterns in the style image). 40 | - WCT distorts the brush strokes and circular patterns. AdaIN cannot even keep the color distribution, while Style-Swap fails in this example. 41 | 42 | #### Execution Efficiency 43 | 44 |
| Method | Gatys et al. | AdaIN | WCT | Style-Swap | __Avatar-Net__ |
| :---: | :---: | :---: | :---: | :---: | :---: |
| __256x256 (sec)__ | 12.18 | 0.053 | 0.62 | 0.064 | __0.071__ |
| __512x512 (sec)__ | 43.25 | 0.11 | 0.93 | 0.23 | __0.28__ |
74 | 75 | - Avatar-Net has an execution time comparable to AdaIN and GPU-accelerated Style-Swap, and is much faster than WCT and the optimization-based style transfer of Gatys _et al._ 76 | - The reference methods and the proposed Avatar-Net are implemented on the same TensorFlow platform with the same VGG network as the backbone. 77 | 78 | ### Applications 79 | #### Multi-style Interpolation 80 | [style_interpolation]: ./figures/style_interpolation.png 81 | ![style_interpolation] 82 | 83 | #### Content and Style Trade-off 84 | [trade_off]: ./figures/trade_off.png 85 | ![trade_off] 86 | 87 | #### Video Stylization ([the YouTube link](https://youtu.be/amaeqbw6TeA)) 88 | 89 |
90 | 91 |
92 | 93 | ## Code 94 | 95 | Please refer to the [GitHub repository](https://github.com/LucasSheng/avatar-net) for more details. 96 | 97 | ## Publication 98 | 99 | Lu Sheng, Ziyi Lin, Jing Shao and Xiaogang Wang, "Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration", in IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018. [[Arxiv](https://arxiv.org/abs/1805.03857)] 100 | 101 | ``` 102 | @inproceedings{sheng2018avatar, 103 | Title = {Avatar-Net: Multi-scale Zero-shot Style Transfer by Feature Decoration}, 104 | author = {Sheng, Lu and Lin, Ziyi and Shao, Jing and Wang, Xiaogang}, 105 | Booktitle = {Computer Vision and Pattern Recognition (CVPR), 2018 IEEE Conference on}, 106 | pages={1--9}, 107 | year={2018} 108 | } 109 | ``` 110 | -------------------------------------------------------------------------------- /evaluate_style_transfer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import time 7 | import scipy.misc 8 | import numpy as np 9 | import tensorflow as tf 10 | 11 | from models import models_factory 12 | from models import preprocessing 13 | 14 | from PIL import Image 15 | 16 | slim = tf.contrib.slim 17 | 18 | 19 | tf.app.flags.DEFINE_string( 20 | 'checkpoint_dir', 'tmp/tfmodel', 21 | 'The directory where the model was written to or an absolute path to a ' 22 | 'checkpoint file.') 23 | tf.app.flags.DEFINE_string( 24 | 'eval_dir', 'tmp/tfmodel', 25 | 'Directory where the results are saved to.') 26 | tf.app.flags.DEFINE_string( 27 | 'content_dataset_dir', None, 28 | 'The content directory where the test images are stored.') 29 | tf.app.flags.DEFINE_string( 30 | 'style_dataset_dir', None, 31 | 'The style directory where the style images are stored.') 32 | 33 | # choose the model configuration file 34 | tf.app.flags.DEFINE_string( 35 | 'model_config_path', None, 36 | 'The path of the model configuration file.') 37 | tf.app.flags.DEFINE_float( 38 | 'inter_weight', 1.0, 39 | 'The blending weight of the style patterns in the stylized image') 40 | 41 | FLAGS = tf.app.flags.FLAGS 42 | 43 | 44 | def get_image_filenames(dataset_dir): 45 | """helper fn that provides the full image filenames from the dataset_dir""" 46 | image_filenames = [] 47 | for filename in os.listdir(dataset_dir): 48 | file_path = os.path.join(dataset_dir, filename) 49 | image_filenames.append(file_path) 50 | return image_filenames 51 | 52 | 53 | def image_reader(filename): 54 | """help fn that provides numpy image coding utilities""" 55 | img = scipy.misc.imread(filename).astype(np.float) 56 | if len(img.shape) == 2: 57 | img = np.dstack((img, img, img)) 58 | elif img.shape[2] == 4: 59 | img = img[:, :, :3] 60 | return img 61 | 62 | 63 | def imsave(filename, img): 64 | img = np.clip(img, 0, 255).astype(np.uint8) 65 | Image.fromarray(img).save(filename, quality=95) 66 | 67 | 68 | def main(_): 69 | if not FLAGS.content_dataset_dir: 70 | raise ValueError('You must supply the content dataset directory ' 71 | 'with --content_dataset_dir') 72 | if not FLAGS.style_dataset_dir: 73 | raise ValueError('You must supply the style dataset directory ' 74 | 'with --style_dataset_dir') 75 | 76 | if not FLAGS.checkpoint_dir: 77 | raise ValueError('You must supply the checkpoints directory with ' 78 | '--checkpoint_dir') 79 | 80 | if tf.gfile.IsDirectory(FLAGS.checkpoint_dir): 81 | checkpoint_dir = 
tf.train.latest_checkpoint(FLAGS.checkpoint_dir) 82 | else: 83 | checkpoint_dir = FLAGS.checkpoint_dir 84 | 85 | if not tf.gfile.Exists(FLAGS.eval_dir): 86 | tf.gfile.MakeDirs(FLAGS.eval_dir) 87 | 88 | tf.logging.set_verbosity(tf.logging.INFO) 89 | with tf.Graph().as_default(): 90 | # define the model 91 | style_model, options = models_factory.get_model(FLAGS.model_config_path) 92 | 93 | # predict the stylized image 94 | inp_content_image = tf.placeholder(tf.float32, shape=(None, None, 3)) 95 | inp_style_image = tf.placeholder(tf.float32, shape=(None, None, 3)) 96 | 97 | # preprocess the content and style images 98 | content_image = preprocessing.mean_image_subtraction(inp_content_image) 99 | content_image = tf.expand_dims(content_image, axis=0) 100 | # style resizing and cropping 101 | style_image = preprocessing.preprocessing_image( 102 | inp_style_image, 103 | 448, 104 | 448, 105 | style_model.style_size) 106 | style_image = tf.expand_dims(style_image, axis=0) 107 | 108 | # style transfer 109 | stylized_image = style_model.transfer_styles( 110 | content_image, 111 | style_image, 112 | inter_weight=FLAGS.inter_weight) 113 | stylized_image = tf.squeeze(stylized_image, axis=0) 114 | 115 | # gather the test image filenames and style image filenames 116 | style_image_filenames = get_image_filenames(FLAGS.style_dataset_dir) 117 | content_image_filenames = get_image_filenames(FLAGS.content_dataset_dir) 118 | 119 | # starting inference of the images 120 | init_fn = slim.assign_from_checkpoint_fn( 121 | checkpoint_dir, slim.get_model_variables(), ignore_missing_vars=True) 122 | with tf.Session() as sess: 123 | # initialize the graph 124 | init_fn(sess) 125 | 126 | nn = 0.0 127 | total_time = 0.0 128 | # style transfer for each image based on one style image 129 | for i in range(len(style_image_filenames)): 130 | # gather the storage folder for the style transfer 131 | style_label = style_image_filenames[i].split('/')[-1] 132 | style_label = style_label.split('.')[0] 133 | style_dir = os.path.join(FLAGS.eval_dir, style_label) 134 | 135 | if not tf.gfile.Exists(style_dir): 136 | tf.gfile.MakeDirs(style_dir) 137 | 138 | # get the style image 139 | np_style_image = image_reader(style_image_filenames[i]) 140 | print('Starting transferring the style of [%s]' % style_label) 141 | 142 | for j in range(len(content_image_filenames)): 143 | # gather the content image 144 | np_content_image = image_reader(content_image_filenames[j]) 145 | 146 | start_time = time.time() 147 | np_stylized_image = sess.run(stylized_image, 148 | feed_dict={inp_content_image: np_content_image, 149 | inp_style_image: np_style_image}) 150 | incre_time = time.time() - start_time 151 | nn += 1.0 152 | total_time += incre_time 153 | print("---%s seconds ---" % (total_time/nn)) 154 | 155 | output_filename = os.path.join( 156 | style_dir, content_image_filenames[j].split('/')[-1]) 157 | imsave(output_filename, np_stylized_image) 158 | print('Style [%s]: Finish transfer the image [%s]' % ( 159 | style_label, content_image_filenames[j])) 160 | 161 | 162 | if __name__ == '__main__': 163 | tf.app.run() 164 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/__init__.py -------------------------------------------------------------------------------- /models/__init__.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/__init__.pyc -------------------------------------------------------------------------------- /models/autoencoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import tensorflow as tf 6 | 7 | from models import losses 8 | from models import preprocessing 9 | from models import vgg 10 | from models import vgg_decoder 11 | 12 | slim = tf.contrib.slim 13 | 14 | network_map = { 15 | 'vgg_16': vgg.vgg_16, 16 | 'vgg_19': vgg.vgg_19, 17 | } 18 | 19 | 20 | class AutoEncoder(object): 21 | def __init__(self, options): 22 | self.weight_decay = options.get('weight_decay') 23 | 24 | self.default_size = options.get('default_size') 25 | self.content_size = options.get('content_size') 26 | 27 | # network architecture 28 | self.network_name = options.get('network_name') 29 | 30 | # the loss layers for content and style similarity 31 | self.content_layers = options.get('content_layers') 32 | 33 | # the weights for the losses when trains the invertible network 34 | self.content_weight = options.get('content_weight') 35 | self.recons_weight = options.get('recons_weight') 36 | self.tv_weight = options.get('tv_weight') 37 | 38 | # gather the summaries and initialize the losses 39 | self.summaries = None 40 | self.total_loss = 0.0 41 | self.recons_loss = {} 42 | self.content_loss = {} 43 | self.tv_loss = {} 44 | self.train_op = None 45 | 46 | def auto_encoder(self, inputs, content_layer=2, reuse=True): 47 | # extract the content features 48 | image_features = losses.extract_image_features(inputs, self.network_name) 49 | content_features = losses.compute_content_features(image_features, self.content_layers) 50 | 51 | # used content feature 52 | selected_layer = self.content_layers[content_layer] 53 | content_feature = content_features[selected_layer] 54 | input_content_features = {selected_layer: content_feature} 55 | 56 | # reconstruct the images 57 | with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)): 58 | outputs = vgg_decoder.vgg_decoder( 59 | content_feature, 60 | self.network_name, 61 | selected_layer, 62 | reuse=reuse, 63 | scope='decoder_%d' % content_layer) 64 | return outputs, input_content_features 65 | 66 | def build_train_graph(self, inputs): 67 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 68 | for i in range(len(self.content_layers)): 69 | # skip some networks 70 | if i < 3: 71 | continue 72 | 73 | selected_layer = self.content_layers[i] 74 | 75 | outputs, inputs_content_features = self.auto_encoder( 76 | inputs, content_layer=i, reuse=False) 77 | outputs = preprocessing.batch_mean_image_subtraction(outputs) 78 | 79 | ######################## 80 | # construct the losses # 81 | ######################## 82 | # 1) reconstruction loss 83 | recons_loss = tf.losses.mean_squared_error( 84 | inputs, outputs, scope='recons_loss/decoder_%d' % i) 85 | self.recons_loss[selected_layer] = recons_loss 86 | self.total_loss += self.recons_weight * recons_loss 87 | summaries.add(tf.summary.scalar( 88 | 'recons_loss/decoder_%d' % i, recons_loss)) 89 | 90 | # 2) content loss 91 | outputs_image_features = losses.extract_image_features( 92 | outputs, self.network_name) 93 | outputs_content_features = losses.compute_content_features( 94 
| outputs_image_features, [selected_layer]) 95 | content_loss = losses.compute_content_loss( 96 | outputs_content_features, inputs_content_features, [selected_layer]) 97 | self.content_loss[selected_layer] = content_loss 98 | self.total_loss += self.content_weight * content_loss 99 | summaries.add(tf.summary.scalar( 100 | 'content_loss/decoder_%d' % i, content_loss)) 101 | 102 | # 3) total variation loss 103 | tv_loss = losses.compute_total_variation_loss_l1(outputs) 104 | self.tv_loss[selected_layer] = tv_loss 105 | self.total_loss += self.tv_weight * tv_loss 106 | summaries.add(tf.summary.scalar( 107 | 'tv_loss/decoder_%d' % i, tv_loss)) 108 | 109 | image_tiles = tf.concat([inputs, outputs], axis=2) 110 | image_tiles = preprocessing.batch_mean_image_summation(image_tiles) 111 | image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0), tf.uint8) 112 | summaries.add(tf.summary.image( 113 | 'image_comparison/decoder_%d' % i, image_tiles, max_outputs=8)) 114 | 115 | self.summaries = summaries 116 | return self.total_loss 117 | 118 | def get_training_operations(self, optimizer, global_step, 119 | variables_to_train=tf.trainable_variables()): 120 | # gather the variable summaries 121 | variables_summaries = [] 122 | for var in variables_to_train: 123 | variables_summaries.append(tf.summary.histogram(var.op.name, var)) 124 | variables_summaries = set(variables_summaries) 125 | 126 | # add the training operations 127 | train_ops = [] 128 | 129 | grads_and_vars = optimizer.compute_gradients( 130 | self.total_loss, var_list=variables_to_train) 131 | train_op = optimizer.apply_gradients( 132 | grads_and_vars, global_step=global_step) 133 | train_ops.append(train_op) 134 | 135 | self.summaries |= variables_summaries 136 | self.train_op = tf.group(*train_ops) 137 | return self.train_op 138 | -------------------------------------------------------------------------------- /models/avatar_net.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import tensorflow as tf 6 | 7 | from models import losses 8 | from models import network_ops 9 | from models import vgg 10 | from models import vgg_decoder 11 | from models import preprocessing 12 | 13 | slim = tf.contrib.slim 14 | 15 | network_map = { 16 | 'vgg_16': vgg.vgg_16, 17 | 'vgg_19': vgg.vgg_19, 18 | } 19 | 20 | 21 | class AvatarNet(object): 22 | def __init__(self, options): 23 | self.training_image_size = options.get('training_image_size') 24 | self.content_size = options.get('content_size') 25 | self.style_size = options.get('style_size') 26 | 27 | # network architecture 28 | self.network_name = options.get('network_name') 29 | 30 | # the loss layers for content and style similarity 31 | self.style_loss_layers = options.get('style_loss_layers') 32 | 33 | ########################## 34 | # style decorator option # 35 | ########################## 36 | # style coding method 37 | self.style_coding = options.get('style_coding') 38 | 39 | # style interpolation method 40 | self.style_interp = options.get('style_interp') 41 | 42 | # window size 43 | self.patch_size = options.get('patch_size') 44 | 45 | ####################### 46 | # training quantities # 47 | ####################### 48 | self.content_weight = options.get('content_weight') 49 | self.recons_weight = options.get('recons_weight') 50 | self.tv_weight = options.get('tv_weight') 51 | self.weight_decay = options.get('weight_decay') 52 | 
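        # The options above are read from the YAML model configuration file
        # (passed via --model_config_path and parsed in models_factory.get_model).
        # A hypothetical sketch of the keys this constructor consumes -- the
        # values in the shipped config may differ:
        #
        #   model_name: 'AvatarNet'
        #   network_name: 'vgg_19'
        #   style_loss_layers: ['conv1/conv1_1', 'conv2/conv2_1',
        #                       'conv3/conv3_1', 'conv4/conv4_1']
        #   style_coding: 'ZCA'            # or 'AdaIN'
        #   style_interp: 'normalized'     # or 'biased'
        #   patch_size: 3
        #   training_image_size: 256
        #   content_size: 512
        #   style_size: 512
        #   content_weight: 1.0
        #   recons_weight: 10.0
        #   tv_weight: 0.0
        #   weight_decay: 0.0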
53 | ############################################## 54 | # gather summaries and initialize the losses # 55 | ############################################## 56 | self.total_loss = 0.0 57 | self.recons_loss = None 58 | self.content_loss = None 59 | self.tv_loss = None 60 | 61 | ############################ 62 | # summary and training ops # 63 | ############################ 64 | self.train_op = None 65 | self.summaries = None 66 | 67 | def transfer_styles(self, 68 | inputs, 69 | styles, 70 | inter_weight=1.0, 71 | intra_weights=(1,)): 72 | """transfer the content image by style images 73 | 74 | Args: 75 | inputs: input images [batch_size, height, width, channel] 76 | styles: a list of input styles, in default the size is 1 77 | inter_weight: the blending weight between the content and style 78 | intra_weights: a list of blending weights among the styles, 79 | in default it is (1,) 80 | 81 | Returns: 82 | outputs: the stylized images [batch_size, height, width, channel] 83 | """ 84 | if not isinstance(styles, (list, tuple)): 85 | styles = [styles] 86 | 87 | if not isinstance(intra_weights, (list, tuple)): 88 | intra_weights = [intra_weights] 89 | 90 | # 1) extract the style features 91 | styles_features = [] 92 | for style in styles: 93 | style_image_features = losses.extract_image_features( 94 | style, self.network_name) 95 | style_features = losses.compute_content_features( 96 | style_image_features, self.style_loss_layers) 97 | styles_features.append(style_features) 98 | 99 | # 2) content features 100 | inputs_image_features = losses.extract_image_features( 101 | inputs, self.network_name) 102 | inputs_features = losses.compute_content_features( 103 | inputs_image_features, self.style_loss_layers) 104 | 105 | # 3) style decorator 106 | # the applied content feature from the content input 107 | selected_layer = self.style_loss_layers[-1] 108 | hidden_feature = inputs_features[selected_layer] 109 | 110 | # applying the style decorator 111 | blended_feature = 0.0 112 | n = 0 113 | for style_features in styles_features: 114 | swapped_feature = style_decorator( 115 | hidden_feature, 116 | style_features[selected_layer], 117 | style_coding=self.style_coding, 118 | style_interp=self.style_interp, 119 | ratio_interp=inter_weight, 120 | patch_size=self.patch_size) 121 | blended_feature += intra_weights[n] * swapped_feature 122 | n += 1 123 | 124 | # 4) decode the hidden feature to the output image 125 | with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope()): 126 | outputs = vgg_decoder.vgg_multiple_combined_decoder( 127 | blended_feature, 128 | styles_features, 129 | intra_weights, 130 | fusion_fn=network_ops.adaptive_instance_normalization, 131 | network_name=self.network_name, 132 | starting_layer=selected_layer) 133 | return outputs 134 | 135 | def hierarchical_autoencoder(self, inputs, reuse=True): 136 | """hierarchical autoencoder for content reconstruction""" 137 | # extract the content features 138 | image_features = losses.extract_image_features( 139 | inputs, self.network_name) 140 | content_features = losses.compute_content_features( 141 | image_features, self.style_loss_layers) 142 | 143 | # the applied content feature for the decode network 144 | selected_layer = self.style_loss_layers[-1] 145 | hidden_feature = content_features[selected_layer] 146 | 147 | # decode the hidden feature to the output image 148 | with slim.arg_scope(vgg_decoder.vgg_decoder_arg_scope(self.weight_decay)): 149 | outputs = vgg_decoder.vgg_combined_decoder( 150 | hidden_feature, 151 | content_features, 152 | 
fusion_fn=network_ops.adaptive_instance_normalization, 153 | network_name=self.network_name, 154 | starting_layer=selected_layer, 155 | reuse=reuse) 156 | return outputs 157 | 158 | def build_train_graph(self, inputs): 159 | """build the training graph for the training of the hierarchical autoencoder""" 160 | outputs = self.hierarchical_autoencoder(inputs, reuse=False) 161 | outputs = preprocessing.batch_mean_image_subtraction(outputs) 162 | 163 | # summaries 164 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 165 | 166 | ######################## 167 | # construct the losses # 168 | ######################## 169 | # 1) reconstruction loss 170 | if self.recons_weight > 0.0: 171 | recons_loss = tf.losses.mean_squared_error( 172 | inputs, outputs, weights=self.recons_weight, scope='recons_loss') 173 | self.recons_loss = recons_loss 174 | self.total_loss += recons_loss 175 | summaries.add(tf.summary.scalar('losses/recons_loss', recons_loss)) 176 | 177 | # 2) content loss 178 | if self.content_weight > 0.0: 179 | outputs_image_features = losses.extract_image_features( 180 | outputs, self.network_name) 181 | outputs_content_features = losses.compute_content_features( 182 | outputs_image_features, self.style_loss_layers) 183 | 184 | inputs_image_features = losses.extract_image_features( 185 | inputs, self.network_name) 186 | inputs_content_features = losses.compute_content_features( 187 | inputs_image_features, self.style_loss_layers) 188 | 189 | content_loss = losses.compute_content_loss( 190 | outputs_content_features, inputs_content_features, 191 | content_loss_layers=self.style_loss_layers, weights=self.content_weight) 192 | self.content_loss = content_loss 193 | self.total_loss += content_loss 194 | summaries.add(tf.summary.scalar('losses/content_loss', content_loss)) 195 | 196 | # 3) total variation loss 197 | if self.tv_weight > 0.0: 198 | tv_loss = losses.compute_total_variation_loss_l1(outputs, self.tv_weight) 199 | self.tv_loss = tv_loss 200 | self.total_loss += tv_loss 201 | summaries.add(tf.summary.scalar('losses/tv_loss', tv_loss)) 202 | 203 | image_tiles = tf.concat([inputs, outputs], axis=2) 204 | image_tiles = preprocessing.batch_mean_image_summation(image_tiles) 205 | image_tiles = tf.cast(tf.clip_by_value(image_tiles, 0.0, 255.0), tf.uint8) 206 | summaries.add(tf.summary.image('image_comparison', image_tiles, max_outputs=8)) 207 | 208 | self.summaries = summaries 209 | return self.total_loss 210 | 211 | def get_training_operations(self, 212 | optimizer, 213 | global_step, 214 | variables_to_train=tf.trainable_variables()): 215 | # gather the variable summaries 216 | variables_summaries = [] 217 | for var in variables_to_train: 218 | variables_summaries.append(tf.summary.histogram(var.op.name, var)) 219 | variables_summaries = set(variables_summaries) 220 | 221 | # add the training operations 222 | train_ops = [] 223 | grads_and_vars = optimizer.compute_gradients( 224 | self.total_loss, var_list=variables_to_train) 225 | train_op = optimizer.apply_gradients( 226 | grads_and_vars=grads_and_vars, 227 | global_step=global_step) 228 | train_ops.append(train_op) 229 | 230 | self.summaries |= variables_summaries 231 | self.train_op = tf.group(*train_ops) 232 | return self.train_op 233 | 234 | 235 | def style_decorator(content_features, 236 | style_features, 237 | style_coding='ZCA', 238 | style_interp='normalized', 239 | ratio_interp=1.0, 240 | patch_size=3): 241 | """style decorator for high-level feature interaction 242 | 243 | Args: 244 | content_features: a tensor of 
size [batch_size, height, width, channel] 245 | style_features: a tensor of size [batch_size, height, width, channel] 246 | style_coding: projection and reconstruction method for style coding 247 | style_interp: interpolation option 248 | ratio_interp: interpolation ratio 249 | patch_size: a 0D tensor or int about the size of the patch 250 | """ 251 | # feature projection 252 | projected_content_features, _, _ = \ 253 | project_features(content_features, projection_module=style_coding) 254 | projected_style_features, style_kernels, mean_style_features = \ 255 | project_features(style_features, projection_module=style_coding) 256 | 257 | # feature rearrangement 258 | rearranged_features = nearest_patch_swapping( 259 | projected_content_features, projected_style_features, patch_size=patch_size) 260 | if style_interp == 'normalized': 261 | rearranged_features = ratio_interp * rearranged_features + \ 262 | (1 - ratio_interp) * projected_content_features 263 | 264 | # feature reconstruction 265 | reconstructed_features = reconstruct_features( 266 | rearranged_features, 267 | style_kernels, 268 | mean_style_features, 269 | reconstruction_module=style_coding) 270 | 271 | if style_interp == 'biased': 272 | reconstructed_features = ratio_interp * reconstructed_features + \ 273 | (1 - ratio_interp) * content_features 274 | 275 | return reconstructed_features 276 | 277 | 278 | def project_features(features, projection_module='ZCA'): 279 | if projection_module == 'ZCA': 280 | return zca_normalization(features) 281 | elif projection_module == 'AdaIN': 282 | return adain_normalization(features) 283 | else: 284 | return features, None, None 285 | 286 | 287 | def reconstruct_features(projected_features, 288 | feature_kernels, 289 | mean_features, 290 | reconstruction_module='ZCA'): 291 | if reconstruction_module == 'ZCA': 292 | return zca_colorization(projected_features, feature_kernels, mean_features) 293 | elif reconstruction_module == 'AdaIN': 294 | return adain_colorization(projected_features, feature_kernels, mean_features) 295 | else: 296 | return projected_features 297 | 298 | 299 | def nearest_patch_swapping(content_features, style_features, patch_size=3): 300 | # channels for both the content and style, must be the same 301 | c_shape = tf.shape(content_features) 302 | s_shape = tf.shape(style_features) 303 | channel_assertion = tf.Assert( 304 | tf.equal(c_shape[3], s_shape[3]), ['number of channels must be the same']) 305 | 306 | with tf.control_dependencies([channel_assertion]): 307 | # spatial shapes for style and content features 308 | c_height, c_width, c_channel = c_shape[1], c_shape[2], c_shape[3] 309 | 310 | # convert the style features into convolutional kernels 311 | style_kernels = tf.extract_image_patches( 312 | style_features, ksizes=[1, patch_size, patch_size, 1], 313 | strides=[1, 1, 1, 1], rates=[1, 1, 1, 1], padding='SAME') 314 | style_kernels = tf.squeeze(style_kernels, axis=0) 315 | style_kernels = tf.transpose(style_kernels, perm=[2, 0, 1]) 316 | 317 | # gather the conv and deconv kernels 318 | v_height, v_width = style_kernels.get_shape().as_list()[1:3] 319 | deconv_kernels = tf.reshape( 320 | style_kernels, shape=(patch_size, patch_size, c_channel, v_height*v_width)) 321 | 322 | kernels_norm = tf.norm(style_kernels, axis=0, keep_dims=True) 323 | kernels_norm = tf.reshape(kernels_norm, shape=(1, 1, 1, v_height*v_width)) 324 | 325 | # calculate the normalization factor 326 | mask = tf.ones((c_height, c_width), tf.float32) 327 | fullmask = tf.zeros((c_height+patch_size-1, 
c_width+patch_size-1), tf.float32) 328 | for x in range(patch_size): 329 | for y in range(patch_size): 330 | paddings = [[x, patch_size-x-1], [y, patch_size-y-1]] 331 | padded_mask = tf.pad(mask, paddings=paddings, mode="CONSTANT") 332 | fullmask += padded_mask 333 | pad_width = int((patch_size-1)/2) 334 | deconv_norm = tf.slice(fullmask, [pad_width, pad_width], [c_height, c_width]) 335 | deconv_norm = tf.reshape(deconv_norm, shape=(1, c_height, c_width, 1)) 336 | 337 | ######################## 338 | # starting convolution # 339 | ######################## 340 | # padding operation 341 | pad_total = patch_size - 1 342 | pad_beg = pad_total // 2 343 | pad_end = pad_total - pad_beg 344 | paddings = [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]] 345 | 346 | # convolutional operations 347 | net = tf.pad(content_features, paddings=paddings, mode="REFLECT") 348 | net = tf.nn.conv2d( 349 | net, 350 | tf.div(deconv_kernels, kernels_norm+1e-7), 351 | strides=[1, 1, 1, 1], 352 | padding='VALID') 353 | # find the maximum locations 354 | best_match_ids = tf.argmax(net, axis=3) 355 | best_match_ids = tf.cast( 356 | tf.one_hot(best_match_ids, depth=v_height*v_width), dtype=tf.float32) 357 | 358 | # find the patches and warping the output 359 | unnormalized_output = tf.nn.conv2d_transpose( 360 | value=best_match_ids, 361 | filter=deconv_kernels, 362 | output_shape=(c_shape[0], c_height+pad_total, c_width+pad_total, c_channel), 363 | strides=[1, 1, 1, 1], 364 | padding='VALID') 365 | unnormalized_output = tf.slice(unnormalized_output, [0, pad_beg, pad_beg, 0], c_shape) 366 | output = tf.div(unnormalized_output, deconv_norm) 367 | output = tf.reshape(output, shape=c_shape) 368 | 369 | # output the swapped feature maps 370 | return output 371 | 372 | 373 | def zca_normalization(features): 374 | shape = tf.shape(features) 375 | 376 | # reshape the features to orderless feature vectors 377 | mean_features = tf.reduce_mean(features, axis=[1, 2], keep_dims=True) 378 | unbiased_features = tf.reshape(features - mean_features, shape=(shape[0], -1, shape[3])) 379 | 380 | # get the covariance matrix 381 | gram = tf.matmul(unbiased_features, unbiased_features, transpose_a=True) 382 | gram /= tf.reduce_prod(tf.cast(shape[1:3], tf.float32)) 383 | 384 | # converting the feature spaces 385 | s, u, v = tf.svd(gram, compute_uv=True) 386 | s = tf.expand_dims(s, axis=1) # let it be active in the last dimension 387 | 388 | # get the effective singular values 389 | valid_index = tf.cast(s > 0.00001, dtype=tf.float32) 390 | s_effective = tf.maximum(s, 0.00001) 391 | sqrt_s_effective = tf.sqrt(s_effective) * valid_index 392 | sqrt_inv_s_effective = tf.sqrt(1.0/s_effective) * valid_index 393 | 394 | # colorization functions 395 | colorization_kernel = tf.matmul(tf.multiply(u, sqrt_s_effective), v, transpose_b=True) 396 | 397 | # normalized features 398 | normalized_features = tf.matmul(unbiased_features, u) 399 | normalized_features = tf.multiply(normalized_features, sqrt_inv_s_effective) 400 | normalized_features = tf.matmul(normalized_features, v, transpose_b=True) 401 | normalized_features = tf.reshape(normalized_features, shape=shape) 402 | 403 | return normalized_features, colorization_kernel, mean_features 404 | 405 | 406 | def zca_colorization(normalized_features, colorization_kernel, mean_features): 407 | # broadcasting the tensors for matrix multiplication 408 | shape = tf.shape(normalized_features) 409 | normalized_features = tf.reshape( 410 | normalized_features, shape=(shape[0], -1, shape[3])) 411 | 
colorized_features = tf.matmul(normalized_features, colorization_kernel) 412 | colorized_features = tf.reshape(colorized_features, shape=shape) + mean_features 413 | return colorized_features 414 | 415 | 416 | def adain_normalization(features): 417 | epsilon = 1e-7 418 | mean_features, colorization_kernels = tf.nn.moments(features, [1, 2], keep_dims=True) 419 | normalized_features = tf.div( 420 | tf.subtract(features, mean_features), tf.sqrt(tf.add(colorization_kernels, epsilon))) 421 | return normalized_features, colorization_kernels, mean_features 422 | 423 | 424 | def adain_colorization(normalized_features, colorization_kernels, mean_features): 425 | return tf.sqrt(colorization_kernels) * normalized_features + mean_features 426 | -------------------------------------------------------------------------------- /models/avatar_net.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/avatar_net.pyc -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from models import vgg 8 | 9 | slim = tf.contrib.slim 10 | 11 | network_map = { 12 | 'vgg_16': vgg.vgg_16, 13 | 'vgg_19': vgg.vgg_19, 14 | } 15 | 16 | 17 | def compute_gram_matrix(feature): 18 | """compute the gram matrix for a layer of feature 19 | 20 | the gram matrix is normalized with respect to the samples and 21 | the dimensions of the input features 22 | 23 | """ 24 | shape = tf.shape(feature) 25 | feature_size = tf.reduce_prod(shape[1:]) 26 | vectorized_feature = tf.reshape( 27 | feature, [shape[0], -1, shape[3]]) 28 | gram_matrix = tf.matmul( 29 | vectorized_feature, vectorized_feature, transpose_a=True) 30 | gram_matrix /= tf.to_float(feature_size) 31 | return gram_matrix 32 | 33 | 34 | def compute_sufficient_statistics(feature): 35 | """compute the gram matrix for a layer of feature""" 36 | mean_feature, var_feature = tf.nn.moments(feature, [1, 2], keep_dims=True) 37 | std_feature = tf.sqrt(var_feature) 38 | sufficient_statistics = tf.concat([mean_feature, std_feature], axis=3) 39 | return sufficient_statistics 40 | 41 | 42 | def compute_content_features(features, content_loss_layers): 43 | """compute the content features from the end_point dict""" 44 | content_features = {} 45 | instance_label = features.keys()[0] 46 | instance_label = instance_label[:-14] # TODO: ugly code, need fix 47 | for layer in content_loss_layers: 48 | content_features[layer] = features[instance_label + '/' + layer] 49 | return content_features 50 | 51 | 52 | def compute_style_features(features, style_loss_layers): 53 | """compute the style features from the end_point dict""" 54 | style_features = {} 55 | instance_label = features.keys()[0] 56 | instance_label = instance_label[:-14] # TODO: ugly code, need fix 57 | for layer in style_loss_layers: 58 | style_features[layer] = compute_gram_matrix( 59 | features[instance_label + '/' + layer]) 60 | return style_features 61 | 62 | 63 | def compute_approximate_style_features(features, style_loss_layers): 64 | style_features = {} 65 | instance_label = features.keys()[0].split('/')[:-2] 66 | for layer in style_loss_layers: 67 | style_features[layer] = compute_sufficient_statistics( 68 | 
features[instance_label + '/' + layer]) 69 | return style_features 70 | 71 | 72 | def extract_image_features(inputs, network_name, reuse=True): 73 | """compute the dict of layer-wise image features from a given list of networks 74 | 75 | Args: 76 | inputs: the inputs image should be normalized between [-127.5, 127.5] 77 | network_name: the network name for the perceptual loss 78 | reuse: whether to reuse the parameters 79 | 80 | Returns: 81 | end_points: a dict for the image features of the inputs 82 | """ 83 | with slim.arg_scope(vgg.vgg_arg_scope()): 84 | _, end_points = network_map[network_name]( 85 | inputs, spatial_squeeze=False, is_training=False, reuse=reuse) 86 | return end_points 87 | 88 | 89 | def compute_content_and_style_features(inputs, 90 | network_name, 91 | content_loss_layers, 92 | style_loss_layers): 93 | """compute the content and style features from normalized image 94 | 95 | Args: 96 | inputs: input tensor of size [batch, height, width, channel] 97 | network_name: a string of the network name 98 | content_loss_layers: a dict about the layers for the content loss 99 | style_loss_layers: a dict about the layers for the style loss 100 | 101 | Returns: 102 | a dict of the features of the inputs 103 | """ 104 | end_points = extract_image_features(inputs, network_name) 105 | 106 | content_features = compute_content_features(end_points, content_loss_layers) 107 | style_features = compute_style_features(end_points, style_loss_layers) 108 | 109 | return content_features, style_features 110 | 111 | 112 | def compute_content_loss(content_features, target_features, 113 | content_loss_layers, weights=1, scope=None): 114 | """compute the content loss 115 | 116 | Args: 117 | content_features: a dict of the features of the input image 118 | target_features: a dict of the features of the output image 119 | content_loss_layers: a dict about the layers for the content loss 120 | weights: the weights for this loss 121 | scope: optional scope 122 | 123 | Returns: 124 | the content loss 125 | """ 126 | with tf.variable_scope(scope, 'content_loss', [content_features, target_features]): 127 | content_loss = 0 128 | for layer in content_loss_layers: 129 | content_feature = content_features[layer] 130 | target_feature = target_features[layer] 131 | content_loss += tf.losses.mean_squared_error( 132 | target_feature, content_feature, weights=weights, scope=layer) 133 | return content_loss 134 | 135 | 136 | def compute_style_loss(style_features, target_features, 137 | style_loss_layers, weights=1, scope=None): 138 | """compute the style loss 139 | 140 | Args: 141 | style_features: a dict of the Gram matrices of the style image 142 | target_features: a dict of the Gram matrices of the target image 143 | style_loss_layers: a dict of layers of features for the style loss 144 | weights: the weights for this loss 145 | scope: optional scope 146 | 147 | Returns: 148 | the style loss 149 | """ 150 | with tf.variable_scope(scope, 'style_loss', [style_features, target_features]): 151 | style_loss = 0 152 | for layer in style_loss_layers: 153 | style_feature = style_features[layer] 154 | target_feature = target_features[layer] 155 | style_loss += tf.losses.mean_squared_error( 156 | style_feature, target_feature, weights=weights, scope=layer) 157 | return style_loss 158 | 159 | 160 | def compute_approximate_style_loss(style_features, target_features, 161 | style_loss_layers, scope=None): 162 | """compute the approximate style loss 163 | 164 | Args: 165 | style_features: a dict of the sufficient statistics of 
the 166 | feature maps of the style image 167 | target_features: a dict of the sufficient statistics of the 168 | feature maps of the target image 169 | style_loss_layers: a dict of layers of features for the style loss 170 | scope: optional scope 171 | 172 | Returns: 173 | the style loss 174 | """ 175 | with tf.variable_scope(scope, 'approximated_style_loss', [style_features, target_features]): 176 | style_loss = 0 177 | for layer in style_loss_layers: 178 | style_feature = style_features[layer] 179 | target_feature = target_features[layer] 180 | # we only normalize with respect to the number of channel 181 | style_loss_per_layer = tf.reduce_sum(tf.square(style_feature-target_feature), axis=[1, 2, 3]) 182 | style_loss += tf.reduce_mean(style_loss_per_layer) 183 | return style_loss 184 | 185 | 186 | def compute_total_variation_loss_l2(inputs, weights=1, scope=None): 187 | """compute the total variation loss""" 188 | inputs_shape = tf.shape(inputs) 189 | height = inputs_shape[1] 190 | width = inputs_shape[2] 191 | 192 | with tf.variable_scope(scope, 'total_variation_loss', [inputs]): 193 | loss_y = tf.losses.mean_squared_error( 194 | tf.slice(inputs, [0, 0, 0, 0], [-1, height-1, -1, -1]), 195 | tf.slice(inputs, [0, 1, 0, 0], [-1, -1, -1, -1]), 196 | weights=weights, 197 | scope='loss_y') 198 | loss_x = tf.losses.mean_squared_error( 199 | tf.slice(inputs, [0, 0, 0, 0], [-1, -1, width-1, -1]), 200 | tf.slice(inputs, [0, 0, 1, 0], [-1, -1, -1, -1]), 201 | weights=weights, 202 | scope='loss_x') 203 | loss = loss_y + loss_x 204 | return loss 205 | 206 | 207 | def compute_total_variation_loss_l1(inputs, weights=1, scope=None): 208 | """compute the total variation loss L1 norm""" 209 | inputs_shape = tf.shape(inputs) 210 | height = inputs_shape[1] 211 | width = inputs_shape[2] 212 | 213 | with tf.variable_scope(scope, 'total_variation_loss', [inputs]): 214 | loss_y = tf.losses.absolute_difference( 215 | tf.slice(inputs, [0, 0, 0, 0], [-1, height-1, -1, -1]), 216 | tf.slice(inputs, [0, 1, 0, 0], [-1, -1, -1, -1]), 217 | weights=weights, 218 | scope='loss_y') 219 | loss_x = tf.losses.absolute_difference( 220 | tf.slice(inputs, [0, 0, 0, 0], [-1, -1, width-1, -1]), 221 | tf.slice(inputs, [0, 0, 1, 0], [-1, -1, -1, -1]), 222 | weights=weights, 223 | scope='loss_x') 224 | loss = loss_y + loss_x 225 | return loss 226 | -------------------------------------------------------------------------------- /models/losses.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/losses.pyc -------------------------------------------------------------------------------- /models/models_factory.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import yaml 7 | 8 | from models import avatar_net 9 | 10 | slim = tf.contrib.slim 11 | 12 | models_map = { 13 | 'AvatarNet': avatar_net.AvatarNet, 14 | } 15 | 16 | 17 | def get_model(filename): 18 | if not tf.gfile.Exists(filename): 19 | raise ValueError('The config file [%s] does not exist.' 
% filename) 20 | 21 | with open(filename, 'rb') as f: 22 | options = yaml.load(f) 23 | model_name = options.get('model_name') 24 | print('Finish loading the model [%s] configuration' % model_name) 25 | if model_name not in models_map: 26 | raise ValueError('Name of model [%s] unknown' % model_name) 27 | model = models_map[model_name](options) 28 | return model, options 29 | -------------------------------------------------------------------------------- /models/models_factory.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/models_factory.pyc -------------------------------------------------------------------------------- /models/network_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import tensorflow as tf 6 | 7 | slim = tf.contrib.slim 8 | 9 | 10 | # functions for neural network layers 11 | @slim.add_arg_scope 12 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 13 | """strided 2-D convolution with 'REFLECT' padding. 14 | 15 | Args: 16 | inputs: A 4-D tensor of size [batch, height, width, channel] 17 | num_outputs: An integer, the number of output filters 18 | kernel_size: An int with the kernel_size of the filters 19 | stride: An integer, the output stride 20 | rate: An integer, rate for atrous convolution 21 | scope: Optional scope 22 | 23 | Returns: 24 | output: A 4-D tensor of size [batch, height_out, width_out, channel] with 25 | the convolution output. 26 | """ 27 | if kernel_size == 1: 28 | return slim.conv2d(inputs, num_outputs, kernel_size=1, stride=stride, 29 | rate=rate, padding='SAME', scope=scope) 30 | else: 31 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 32 | pad_total = kernel_size_effective - 1 33 | pad_beg = pad_total // 2 34 | pad_end = pad_total - pad_beg 35 | paddings = [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]] 36 | inputs = tf.pad(inputs, paddings=paddings, mode="REFLECT") 37 | outputs = slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 38 | rate=rate, padding='VALID', scope=scope) 39 | return outputs 40 | 41 | 42 | @slim.add_arg_scope 43 | def conv2d_resize(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 44 | """deconvolution alternatively with the conv2d_transpose, where we 45 | first resize the inputs, and then convolve the results, see 46 | http://distill.pub/2016/deconv-checkerboard/ 47 | 48 | Args: 49 | inputs: A 4-D tensor of size [batch, height, width, channel] 50 | num_outputs: An integer, the number of output filters 51 | kernel_size: An int with the kernel_size of the filters 52 | stride: An integer, the output stride 53 | rate: An integer, rate for atrous convolution 54 | scope: Optional scope 55 | 56 | Returns: 57 | output: A 4-D tensor of size [batch, height_out, width_out, channel] with 58 | the convolution output. 
59 | """ 60 | if stride == 1: 61 | return conv2d_same(inputs, num_outputs, kernel_size, 62 | stride=1, rate=rate, scope=scope) 63 | else: 64 | stride_larger_than_one = tf.greater(stride, 1) 65 | height = tf.shape(inputs)[1] 66 | width = tf.shape(inputs)[2] 67 | new_height, new_width = tf.cond( 68 | stride_larger_than_one, 69 | lambda: (height*stride, width*stride), 70 | lambda: (height, width)) 71 | inputs_resize = tf.image.resize_nearest_neighbor(inputs, 72 | [new_height, new_width]) 73 | outputs = conv2d_same(inputs_resize, num_outputs, kernel_size, 74 | stride=1, rate=rate, scope=scope) 75 | return outputs 76 | 77 | 78 | @slim.add_arg_scope 79 | def lrelu(inputs, leak=0.2, scope=None): 80 | """customized leaky ReLU activation function 81 | https://github.com/tensorflow/tensorflow/issues/4079 82 | """ 83 | with tf.variable_scope(scope, 'lrelu'): 84 | f1 = 0.5 * (1 + leak) 85 | f2 = 0.5 * (1 - leak) 86 | return f1 * inputs + f2 * tf.abs(inputs) 87 | 88 | 89 | @slim.add_arg_scope 90 | def instance_norm(inputs, epsilon=1e-10): 91 | inst_mean, inst_var = tf.nn.moments(inputs, [1, 2], keep_dims=True) 92 | normalized_inputs = tf.div( 93 | tf.subtract(inputs, inst_mean), tf.sqrt(tf.add(inst_var, epsilon))) 94 | return normalized_inputs 95 | 96 | 97 | @slim.add_arg_scope 98 | def residual_unit_v0(inputs, depth, output_collections=None, scope=None): 99 | """Residual block version 0, the input and output has the same depth 100 | 101 | Args: 102 | inputs: a tensor of size [batch, height, width, channel] 103 | depth: the depth of the resnet unit output 104 | output_collections: collection to add the resnet unit output 105 | scope: optional variable_scope 106 | 107 | Returns: 108 | The resnet unit's output 109 | """ 110 | with tf.variable_scope(scope, 'res_unit_v0', [inputs]) as sc: 111 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 112 | if depth == depth_in: 113 | shortcut = inputs 114 | else: 115 | shortcut = slim.conv2d(inputs, depth, [1, 1], scope='shortcut') 116 | 117 | residual = conv2d_same(inputs, depth, 3, stride=1, scope='conv1') 118 | with slim.arg_scope([slim.conv2d], activation_fn=None): 119 | residual = conv2d_same(residual, depth, 3, stride=1, scope='conv2') 120 | 121 | output = tf.nn.relu(shortcut + residual) 122 | 123 | return slim.utils.collect_named_outputs( 124 | output_collections, sc.original_name_scope, output) 125 | 126 | 127 | @slim.add_arg_scope 128 | def residual_block_downsample(inputs, depth, stride, 129 | normalizer_fn=slim.layer_norm, 130 | activation_fn=tf.nn.relu, 131 | outputs_collections=None, scope=None): 132 | """Residual block version 2 for downsampling, with preactivation 133 | 134 | Args: 135 | inputs: a tensor of size [batch, height, width, channel] 136 | depth: the depth of the resnet unit output 137 | stride: the stride of the residual block 138 | normalizer_fn: normalizer function for the residual block 139 | activation_fn: activation function for the residual block 140 | outputs_collections: collection to add the resnet unit output 141 | scope: optional variable_scope 142 | 143 | Returns: 144 | The resnet unit's output 145 | """ 146 | with tf.variable_scope(scope, 'res_block_downsample', [inputs]) as sc: 147 | with slim.arg_scope([slim.conv2d], 148 | normalizer_fn=normalizer_fn, 149 | activation_fn=activation_fn): 150 | # preactivate the inputs 151 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 152 | preact = normalizer_fn(inputs, activation_fn=activation_fn, scope='preact') 153 | if depth == depth_in: 154 | 
shortcut = subsample(inputs, stride, scope='shortcut') 155 | else: 156 | with slim.arg_scope([slim.conv2d], 157 | normalizer_fn=None, activation_fn=None): 158 | shortcut = conv2d_same(preact, depth, 1, 159 | stride=stride, scope='shortcut') 160 | 161 | depth_botteneck = int(depth / 4) 162 | residual = slim.conv2d(preact, depth_botteneck, [1, 1], 163 | stride=1, scope='conv1') 164 | residual = conv2d_same(residual, depth_botteneck, 3, 165 | stride=stride, scope='conv2') 166 | residual = slim.conv2d(residual, depth, [1, 1], 167 | stride=1, normalizer_fn=None, 168 | activation_fn=None, scope='conv3') 169 | 170 | output = shortcut + residual 171 | 172 | return slim.utils.collect_named_outputs( 173 | outputs_collections, sc.original_name_scope, output) 174 | 175 | 176 | @slim.add_arg_scope 177 | def residual_block_upsample(inputs, depth, stride, 178 | normalizer_fn=slim.layer_norm, 179 | activation_fn=tf.nn.relu, 180 | outputs_collections=None, scope=None): 181 | """Residual block version 2 for upsampling, with preactivation 182 | 183 | Args: 184 | inputs: a tensor of size [batch, height, width, channel] 185 | depth: the depth of the resnet unit output 186 | stride: the stride of the residual block 187 | normalizer_fn: the normalizer function used in this block 188 | activation_fn: the activation function used in this block 189 | outputs_collections: collection to add the resnet unit output 190 | scope: optional variable_scope 191 | 192 | Returns: 193 | The resnet unit's output 194 | """ 195 | with tf.variable_scope(scope, 'res_block_upsample', [inputs]) as sc: 196 | with slim.arg_scope([slim.conv2d], 197 | normalizer_fn=normalizer_fn, 198 | activation_fn=activation_fn): 199 | # preactivate the inputs 200 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 201 | preact = normalizer_fn(inputs, activation_fn=activation_fn, scope='preact') 202 | if depth == depth_in: 203 | shortcut = upsample(inputs, stride, scope='shortcut') 204 | else: 205 | with slim.arg_scope([slim.conv2d], 206 | normalizer_fn=None, activation_fn=None): 207 | shortcut = conv2d_resize(preact, depth, 1, stride=stride, scope='shortcut') 208 | 209 | # calculate the residuals 210 | depth_botteneck = int(depth / 4) 211 | residual = slim.conv2d(preact, depth_botteneck, [1, 1], 212 | stride=1, scope='conv1') 213 | residual = conv2d_resize(residual, depth_botteneck, 3, 214 | stride=stride, scope='conv2') 215 | residual = slim.conv2d(residual, depth, [1, 1], 216 | stride=1, normalizer_fn=None, 217 | activation_fn=None, scope='conv3') 218 | 219 | output = shortcut + residual 220 | 221 | return slim.utils.collect_named_outputs( 222 | outputs_collections, sc.original_name_scope, output) 223 | 224 | 225 | def subsample(inputs, factor, scope=None): 226 | if factor == 1: 227 | return inputs 228 | else: 229 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 230 | 231 | 232 | def upsample(inputs, factor, scope=None): 233 | if factor == 1: 234 | return inputs 235 | else: 236 | factor_larger_than_one = tf.greater(factor, 1) 237 | height = tf.shape(inputs)[1] 238 | width = tf.shape(inputs)[2] 239 | new_height, new_width = tf.cond( 240 | factor_larger_than_one, 241 | lambda: (height*factor, width*factor), 242 | lambda: (height, width)) 243 | resized_inputs = tf.image.resize_nearest_neighbor( 244 | inputs, [new_height, new_width], name=scope) 245 | return resized_inputs 246 | 247 | 248 | def adaptive_instance_normalization(content_feature, style_feature): 249 | """adaptively transform the content feature by 
inverse instance normalization 250 | based on the 2nd order statistics of the style feature 251 | """ 252 | normalized_content_feature = instance_norm(content_feature) 253 | inst_mean, inst_var = tf.nn.moments(style_feature, [1, 2], keep_dims=True) 254 | return tf.sqrt(inst_var) * normalized_content_feature + inst_mean 255 | 256 | 257 | def whitening_colorization_transform(content_features, style_features): 258 | """transform the content feature based on the whitening and colorization transform""" 259 | content_shape = tf.shape(content_features) 260 | style_shape = tf.shape(style_features) 261 | 262 | # get the unbiased content and style features 263 | content_features = tf.reshape( 264 | content_features, shape=(content_shape[0], -1, content_shape[3])) 265 | style_features = tf.reshape( 266 | style_features, shape=(style_shape[0], -1, style_shape[3])) 267 | 268 | # get the covariance matrices 269 | content_gram = tf.matmul(content_features, content_features, transpose_a=True) 270 | content_gram /= tf.reduce_prod(tf.cast(content_shape[1:], tf.float32)) 271 | style_gram = tf.matmul(style_features, style_features, transpose_a=True) 272 | style_gram /= tf.reduce_prod(tf.cast(style_shape[1:], tf.float32)) 273 | 274 | ################################# 275 | # converting the feature spaces # 276 | ################################# 277 | s_c, u_c, v_c = tf.svd(content_gram, compute_uv=True) 278 | s_c = tf.expand_dims(s_c, axis=1) 279 | s_s, u_s, v_s = tf.svd(style_gram, compute_uv=True) 280 | s_s = tf.expand_dims(s_s, axis=1) 281 | 282 | # normalized features 283 | normalized_features = tf.matmul(content_features, u_c) 284 | normalized_features = tf.multiply(normalized_features, 1.0/(tf.sqrt(s_c+1e-5))) 285 | normalized_features = tf.matmul(normalized_features, v_c, transpose_b=True) 286 | 287 | # colorized features 288 | # broadcasting the tensors for matrix multiplication 289 | content_batch = tf.shape(u_c)[0] 290 | style_batch = tf.shape(u_s)[0] 291 | batch_multiplier = tf.cast(content_batch/style_batch, tf.int32) 292 | u_s = tf.tile(u_s, multiples=tf.stack([batch_multiplier, 1, 1])) 293 | v_s = tf.tile(v_s, multiples=tf.stack([batch_multiplier, 1, 1])) 294 | colorized_features = tf.matmul(normalized_features, u_s) 295 | colorized_features = tf.multiply(colorized_features, tf.sqrt(s_s+1e-5)) 296 | colorized_features = tf.matmul(colorized_features, v_s, transpose_b=True) 297 | 298 | # reshape the colorized features 299 | colorized_features = tf.reshape(colorized_features, shape=content_shape) 300 | return colorized_features 301 | -------------------------------------------------------------------------------- /models/network_ops.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/network_ops.pyc -------------------------------------------------------------------------------- /models/preprocessing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from tensorflow.python.ops import control_flow_ops 8 | 9 | slim = tf.contrib.slim 10 | 11 | _R_MEAN = 123.68 12 | _G_MEAN = 116.78 13 | _B_MEAN = 103.94 14 | 15 | _RESIZE_SIDE_MIN = 256 16 | _RESIZE_SIDE_MAX = 512 17 | 18 | 19 | def _crop(image, offset_height, offset_width, crop_height, crop_width): 20 | original_shape = 
tf.shape(image) 21 | 22 | rank_assertion = tf.Assert( 23 | tf.equal(tf.rank(image), 3), 24 | ['Rank of image must be equal to 3.']) 25 | cropped_shape = control_flow_ops.with_dependencies( 26 | [rank_assertion], 27 | tf.stack([crop_height, crop_width, original_shape[2]])) 28 | 29 | size_assertion = tf.Assert( 30 | tf.logical_and( 31 | tf.greater_equal(original_shape[0], crop_height), 32 | tf.greater_equal(original_shape[1], crop_width)), 33 | ['Crop size greater than the image size.']) 34 | 35 | offsets = tf.to_int32(tf.stack([offset_height, offset_width, 0])) 36 | 37 | # Use tf.slice instead of crop_to_bounding box as it accepts tensors to 38 | # define the crop size. 39 | image = control_flow_ops.with_dependencies( 40 | [size_assertion], 41 | tf.slice(image, offsets, cropped_shape)) 42 | return tf.reshape(image, cropped_shape) 43 | 44 | 45 | def _random_crop(image_list, crop_height, crop_width): 46 | if not image_list: 47 | raise ValueError('Empty image_list.') 48 | 49 | # Compute the rank assertions. 50 | rank_assertions = [] 51 | for i in range(len(image_list)): 52 | image_rank = tf.rank(image_list[i]) 53 | rank_assert = tf.Assert( 54 | tf.equal(image_rank, 3), 55 | ['Wrong rank for tensor %s [expected] [actual]', 56 | image_list[i].name, 3, image_rank]) 57 | rank_assertions.append(rank_assert) 58 | 59 | image_shape = control_flow_ops.with_dependencies( 60 | [rank_assertions[0]], 61 | tf.shape(image_list[0])) 62 | image_height = image_shape[0] 63 | image_width = image_shape[1] 64 | crop_size_assert = tf.Assert( 65 | tf.logical_and( 66 | tf.greater_equal(image_height, crop_height), 67 | tf.greater_equal(image_width, crop_width)), 68 | ['Crop size greater than the image size.']) 69 | 70 | asserts = [rank_assertions[0], crop_size_assert] 71 | 72 | for i in range(1, len(image_list)): 73 | image = image_list[i] 74 | asserts.append(rank_assertions[i]) 75 | shape = control_flow_ops.with_dependencies([rank_assertions[i]], 76 | tf.shape(image)) 77 | height = shape[0] 78 | width = shape[1] 79 | 80 | height_assert = tf.Assert( 81 | tf.equal(height, image_height), 82 | ['Wrong height for tensor %s [expected][actual]', 83 | image.name, height, image_height]) 84 | width_assert = tf.Assert( 85 | tf.equal(width, image_width), 86 | ['Wrong width for tensor %s [expected][actual]', 87 | image.name, width, image_width]) 88 | asserts.extend([height_assert, width_assert]) 89 | 90 | # Create a random bounding box. 91 | # 92 | # Use tf.random_uniform and not numpy.random.rand as doing the former would 93 | # generate random numbers at graph eval time, unlike the latter which 94 | # generates random numbers at graph definition time. 
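  # Worked example with assumed numbers: for a 480x640 input and a 256x256 crop,
  # max_offset_height = 480 - 256 + 1 = 225 and max_offset_width = 640 - 256 + 1 = 385,
  # so every offset drawn below keeps the crop window inside the image.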
95 | max_offset_height = control_flow_ops.with_dependencies( 96 | asserts, tf.reshape(image_height - crop_height + 1, [])) 97 | max_offset_width = control_flow_ops.with_dependencies( 98 | asserts, tf.reshape(image_width - crop_width + 1, [])) 99 | offset_height = tf.random_uniform( 100 | [], maxval=max_offset_height, dtype=tf.int32) 101 | offset_width = tf.random_uniform( 102 | [], maxval=max_offset_width, dtype=tf.int32) 103 | 104 | return [_crop(image, offset_height, offset_width, 105 | crop_height, crop_width) for image in image_list] 106 | 107 | 108 | def _central_crop(image_list, crop_height, crop_width): 109 | outputs = [] 110 | for image in image_list: 111 | image_height = tf.shape(image)[0] 112 | image_width = tf.shape(image)[1] 113 | 114 | offset_height = (image_height - crop_height) / 2 115 | offset_width = (image_width - crop_width) / 2 116 | 117 | outputs.append(_crop(image, offset_height, offset_width, 118 | crop_height, crop_width)) 119 | return outputs 120 | 121 | 122 | def _mean_image_subtraction(image, means=(_R_MEAN, _G_MEAN, _B_MEAN)): 123 | if image.get_shape().ndims != 3: 124 | raise ValueError('Input must be of size [height, width, C>0]') 125 | num_channels = image.get_shape().as_list()[-1] 126 | if len(means) != num_channels: 127 | raise ValueError('len(means) must match the number of channels') 128 | 129 | channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image) 130 | for i in range(num_channels): 131 | channels[i] -= means[i] 132 | return tf.concat(axis=2, values=channels) 133 | 134 | 135 | def _smallest_size_at_least(height, width, smallest_side): 136 | smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) 137 | 138 | height = tf.to_float(height) 139 | width = tf.to_float(width) 140 | smallest_side = tf.to_float(smallest_side) 141 | 142 | scale = tf.cond(tf.greater(height, width), 143 | lambda: smallest_side / width, 144 | lambda: smallest_side / height) 145 | new_height = tf.to_int32(height * scale) 146 | new_width = tf.to_int32(width * scale) 147 | return new_height, new_width 148 | 149 | 150 | def _aspect_preserving_resize(image, smallest_side): 151 | smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) 152 | 153 | shape = tf.shape(image) 154 | height = shape[0] 155 | width = shape[1] 156 | new_height, new_width = _smallest_size_at_least(height, width, smallest_side) 157 | image = tf.expand_dims(image, 0) 158 | resized_image = tf.image.resize_bilinear(image, [new_height, new_width], 159 | align_corners=False) 160 | resized_image = tf.squeeze(resized_image) 161 | resized_image.set_shape([None, None, 3]) 162 | return resized_image 163 | 164 | 165 | def preprocessing_for_train(image, output_height, output_width, resize_side): 166 | image = _aspect_preserving_resize(image, resize_side) 167 | image = _random_crop([image], output_height, output_width)[0] 168 | image.set_shape([output_height, output_width, 3]) 169 | image = tf.to_float(image) 170 | return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 171 | 172 | 173 | def preprocessing_for_eval(image, output_height, output_width, resize_side): 174 | image = _aspect_preserving_resize(image, resize_side) 175 | image = _central_crop([image], output_height, output_width)[0] 176 | image.set_shape([output_height, output_width, 3]) 177 | image = tf.to_float(image) 178 | return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN]) 179 | 180 | 181 | def preprocessing_image(image, output_height, output_width, 182 | resize_side=_RESIZE_SIDE_MIN, 
is_training=False): 183 | if is_training: 184 | return preprocessing_for_train(image, output_height, output_width, resize_side) 185 | else: 186 | return preprocessing_for_eval(image, output_height, output_width, resize_side) 187 | 188 | 189 | ######################### 190 | # personal modification # 191 | ######################### 192 | def mean_image_subtraction(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)): 193 | """works for one single image with dynamic shapes""" 194 | num_channels = 3 195 | channels = tf.split(images, num_channels, axis=2) 196 | for i in range(num_channels): 197 | channels[i] -= means[i] 198 | return tf.concat(channels, axis=2) 199 | 200 | 201 | def mean_image_summation(image, means=(_R_MEAN, _G_MEAN, _B_MEAN)): 202 | """works for one single image with dynamic shapes""" 203 | num_channels = 3 204 | channels = tf.split(image, num_channels, axis=2) 205 | for i in range(num_channels): 206 | channels[i] += means[i] 207 | return tf.concat(channels, axis=2) 208 | 209 | 210 | def batch_mean_image_subtraction(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)): 211 | if images.get_shape().ndims != 4: 212 | raise ValueError('Input must be of size [batch, height, width, C>0') 213 | num_channels = images.get_shape().as_list()[-1] 214 | if len(means) != num_channels: 215 | raise ValueError('len(means) must match the number of channels') 216 | channels = tf.split(images, num_channels, axis=3) 217 | for i in range(num_channels): 218 | channels[i] -= means[i] 219 | return tf.concat(channels, axis=3) 220 | 221 | 222 | def batch_mean_image_summation(images, means=(_R_MEAN, _G_MEAN, _B_MEAN)): 223 | if images.get_shape().ndims != 4: 224 | raise ValueError('Input must be of size [batch, height, width, C>0') 225 | num_channels = images.get_shape().as_list()[-1] 226 | if len(means) != num_channels: 227 | raise ValueError('len(means) must match the number of channels') 228 | channels = tf.split(images, num_channels, axis=3) 229 | for i in range(num_channels): 230 | channels[i] += means[i] 231 | return tf.concat(channels, axis=3) 232 | 233 | 234 | def image_normalization(images, means=(_R_MEAN, _G_MEAN, _B_MEAN), scale=127.5): 235 | """rescale the images so that their magnitude ranging from [-1, 1]""" 236 | if images.get_shape().ndims == 4: 237 | return tf.div(batch_mean_image_subtraction(images, means), scale) 238 | elif images.get_shape().ndims == 3: 239 | return tf.div(mean_image_subtraction(images, means), scale) 240 | else: 241 | raise ValueError('Input must be of dimensions 3 or 4') 242 | 243 | 244 | def aspect_preserving_resize(image, smallest_side): 245 | return _aspect_preserving_resize(image, smallest_side) 246 | -------------------------------------------------------------------------------- /models/preprocessing.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/preprocessing.pyc -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains models definitions for versions of the Oxford VGG network. 16 | 17 | These models definitions were introduced in the following technical report: 18 | 19 | Very Deep Convolutional Networks For Large-Scale Image Recognition 20 | Karen Simonyan and Andrew Zisserman 21 | arXiv technical report, 2015 22 | PDF: http://arxiv.org/pdf/1409.1556.pdf 23 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 24 | CC-BY-4.0 25 | 26 | More information can be obtained from the VGG website: 27 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 28 | 29 | Usage: 30 | with slim.arg_scope(vgg.vgg_arg_scope()): 31 | outputs, end_points = vgg.vgg_a(inputs) 32 | 33 | with slim.arg_scope(vgg.vgg_arg_scope()): 34 | outputs, end_points = vgg.vgg_16(inputs) 35 | 36 | @@vgg_a 37 | @@vgg_16 38 | @@vgg_19 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import tensorflow as tf 45 | 46 | slim = tf.contrib.slim 47 | 48 | 49 | def vgg_arg_scope(weight_decay=0.0005): 50 | """Defines the VGG arg scope. 51 | 52 | Args: 53 | weight_decay: The l2 regularization coefficient. 54 | 55 | Returns: 56 | An arg_scope. 57 | """ 58 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 59 | activation_fn=tf.nn.relu, 60 | weights_regularizer=slim.l2_regularizer(weight_decay), 61 | biases_initializer=tf.zeros_initializer()): 62 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 63 | return arg_sc 64 | 65 | 66 | def vgg_a(inputs, 67 | num_classes=1000, 68 | is_training=True, 69 | dropout_keep_prob=0.5, 70 | spatial_squeeze=True, 71 | scope='vgg_a'): 72 | """Oxford Net VGG 11-Layers version A Example. 73 | 74 | Note: All the fully_connected layers have been transformed to conv2d layers. 75 | To use in classification mode, resize input to 224x224. 76 | 77 | Args: 78 | inputs: a tensor of size [batch_size, height, width, channels]. 79 | num_classes: number of predicted classes. 80 | is_training: whether or not the models is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | outputs. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | 87 | Returns: 88 | the last op containing the log predictions and end_points dict. 89 | """ 90 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc: 91 | end_points_collection = sc.name + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
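    # The end_points dict returned below is keyed by the op scope names gathered
    # in this collection, e.g. 'vgg_a/conv1/conv1_1' or 'vgg_a/pool1'; losses.py
    # indexes the features the same way when it assembles content and style layers.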
93 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 96 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 97 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 98 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 99 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 100 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 101 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 102 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 103 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 104 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 105 | # Use conv2d instead of fully_connected layers. 106 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 107 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 108 | scope='dropout6') 109 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout7') 112 | net = slim.conv2d(net, num_classes, [1, 1], 113 | activation_fn=None, 114 | normalizer_fn=None, 115 | scope='fc8') 116 | # Convert end_points_collection into a end_point dict. 117 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 118 | if spatial_squeeze: 119 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 120 | end_points[sc.name + '/fc8'] = net 121 | return net, end_points 122 | 123 | 124 | vgg_a.default_image_size = 224 125 | 126 | 127 | def vgg_16(inputs, 128 | num_classes=1000, 129 | is_training=True, 130 | dropout_keep_prob=0.5, 131 | spatial_squeeze=True, 132 | reuse=True, 133 | scope='vgg_16'): 134 | """Oxford Net VGG 16-Layers version D Example. 135 | 136 | Note: All the fully_connected layers have been transformed to conv2d layers. 137 | To use in classification mode, resize input to 224x224. 138 | 139 | Args: 140 | inputs: a tensor of size [batch_size, height, width, channels]. 141 | num_classes: number of predicted classes. 142 | is_training: whether or not the models is being trained. 143 | dropout_keep_prob: the probability that activations are kept in the dropout 144 | layers during training. 145 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 146 | outputs. Useful to remove unnecessary dimensions for classification. 147 | reuse: whether to reuse the network parameters 148 | scope: Optional scope for the variables. 149 | 150 | Returns: 151 | the last op containing the log predictions and end_points dict. 152 | """ 153 | with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc: 154 | end_points_collection = sc.original_name_scope + '_end_points' 155 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
156 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 157 | outputs_collections=end_points_collection): 158 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 159 | # net = slim.max_pool2d(net, [2, 2], scope='pool1') 160 | net = slim.avg_pool2d(net, [2, 2], scope='pool1') 161 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 162 | # net = slim.max_pool2d(net, [2, 2], scope='pool2') 163 | net = slim.avg_pool2d(net, [2, 2], scope='pool2') 164 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 165 | # net = slim.max_pool2d(net, [2, 2], scope='pool3') 166 | net = slim.avg_pool2d(net, [2, 2], scope='pool3') 167 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 168 | # net = slim.max_pool2d(net, [2, 2], scope='pool4') 169 | net = slim.avg_pool2d(net, [2, 2], scope='pool4') 170 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 171 | # net = slim.max_pool2d(net, [2, 2], scope='pool5') 172 | net = slim.avg_pool2d(net, [2, 2], scope='pool5') 173 | # Use conv2d instead of fully_connected layers. 174 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 175 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 176 | scope='dropout6') 177 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 178 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 179 | scope='dropout7') 180 | net = slim.conv2d(net, num_classes, [1, 1], 181 | activation_fn=None, 182 | normalizer_fn=None, 183 | scope='fc8') 184 | # Convert end_points_collection into a end_point dict. 185 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 186 | if spatial_squeeze: 187 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 188 | end_points[sc.name + '/fc8'] = net 189 | return net, end_points 190 | 191 | 192 | vgg_16.default_image_size = 224 193 | 194 | 195 | def vgg_19(inputs, 196 | num_classes=1000, 197 | is_training=True, 198 | dropout_keep_prob=0.5, 199 | spatial_squeeze=True, 200 | reuse=True, 201 | scope='vgg_19'): 202 | """Oxford Net VGG 19-Layers version E Example. 203 | 204 | Note: All the fully_connected layers have been transformed to conv2d layers. 205 | To use in classification mode, resize input to 224x224. 206 | 207 | Args: 208 | inputs: a tensor of size [batch_size, height, width, channels]. 209 | num_classes: number of predicted classes. 210 | is_training: whether or not the models is being trained. 211 | dropout_keep_prob: the probability that activations are kept in the dropout 212 | layers during training. 213 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 214 | outputs. Useful to remove unnecessary dimensions for classification. 215 | reuse: whether to reuse the network parameters 216 | scope: Optional scope for the variables. 217 | 218 | Returns: 219 | the last op containing the log predictions and end_points dict. 220 | """ 221 | with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc: 222 | end_points_collection = sc.original_name_scope + '_end_points' 223 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
224 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 225 | outputs_collections=end_points_collection): 226 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 227 | # net = slim.max_pool2d(net, [2, 2], scope='pool1') 228 | net = slim.avg_pool2d(net, [2, 2], scope='pool1') 229 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 230 | # net = slim.max_pool2d(net, [2, 2], scope='pool2') 231 | net = slim.avg_pool2d(net, [2, 2], scope='pool2') 232 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') 233 | # net = slim.max_pool2d(net, [2, 2], scope='pool3') 234 | net = slim.avg_pool2d(net, [2, 2], scope='pool3') 235 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') 236 | # net = slim.max_pool2d(net, [2, 2], scope='pool4') 237 | net = slim.avg_pool2d(net, [2, 2], scope='pool4') 238 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') 239 | # net = slim.max_pool2d(net, [2, 2], scope='pool5') 240 | net = slim.avg_pool2d(net, [2, 2], scope='pool5') 241 | # Use conv2d instead of fully_connected layers. 242 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 243 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 244 | scope='dropout6') 245 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 246 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 247 | scope='dropout7') 248 | net = slim.conv2d(net, num_classes, [1, 1], 249 | activation_fn=None, 250 | normalizer_fn=None, 251 | scope='fc8') 252 | # Convert end_points_collection into a end_point dict. 253 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 254 | if spatial_squeeze: 255 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 256 | end_points[sc.name + '/fc8'] = net 257 | return net, end_points 258 | 259 | 260 | vgg_19.default_image_size = 224 261 | 262 | 263 | # Alias 264 | vgg_d = vgg_16 265 | vgg_e = vgg_19 266 | -------------------------------------------------------------------------------- /models/vgg.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/vgg.pyc -------------------------------------------------------------------------------- /models/vgg_decoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import tensorflow as tf 6 | 7 | from models import network_ops 8 | 9 | slim = tf.contrib.slim 10 | 11 | vgg_19_decoder_architecture = [ 12 | ('conv5/conv5_4', ('c', 512, 3)), 13 | ('conv5/conv5_3', ('c', 512, 3)), 14 | ('conv5/conv5_2', ('c', 512, 3)), 15 | ('conv5/conv5_1', ('c', 512, 3)), 16 | ('conv4/conv4_4', ('uc', 512, 3)), 17 | ('conv4/conv4_3', ('c', 512, 3)), 18 | ('conv4/conv4_2', ('c', 512, 3)), 19 | ('conv4/conv4_1', ('c', 256, 3)), 20 | ('conv3/conv3_4', ('uc', 256, 3)), 21 | ('conv3/conv3_3', ('c', 256, 3)), 22 | ('conv3/conv3_2', ('c', 256, 3)), 23 | ('conv3/conv3_1', ('c', 128, 3)), 24 | ('conv2/conv2_2', ('uc', 128, 3)), 25 | ('conv2/conv2_1', ('c', 64, 3)), 26 | ('conv1/conv1_2', ('uc', 64, 3)), 27 | ('conv1/conv1_1', ('c', 64, 3)), 28 | ] 29 | 30 | vgg_16_decoder_architecture = [ 31 | ('conv5/conv5_3', ('c', 512, 3)), 32 | ('conv5/conv5_2', ('c', 512, 3)), 33 | ('conv5/conv5_1', ('c', 512, 3)), 34 | ('conv4/conv4_3', ('uc', 512, 
3)), 35 | ('conv4/conv4_2', ('c', 512, 3)), 36 | ('conv4/conv4_1', ('c', 256, 3)), 37 | ('conv3/conv3_3', ('uc', 256, 3)), 38 | ('conv3/conv3_2', ('c', 256, 3)), 39 | ('conv3/conv3_1', ('c', 128, 3)), 40 | ('conv2/conv2_2', ('uc', 128, 3)), 41 | ('conv2/conv2_1', ('c', 64, 3)), 42 | ('conv1/conv1_2', ('uc', 64, 3)), 43 | ('conv1/conv1_1', ('c', 64, 3)), 44 | ] 45 | 46 | network_map = { 47 | 'vgg_19': vgg_19_decoder_architecture, 48 | 'vgg_16': vgg_16_decoder_architecture, 49 | } 50 | 51 | 52 | def vgg_decoder_arg_scope(weight_decay=0.0005): 53 | with slim.arg_scope( 54 | [slim.conv2d], 55 | padding='SAME', 56 | activation_fn=tf.nn.relu, 57 | normalizer_fn=None, 58 | weights_initializer=slim.xavier_initializer(uniform=False), 59 | weights_regularizer=slim.l2_regularizer(weight_decay)) as arg_sc: 60 | return arg_sc 61 | 62 | 63 | def vgg_decoder(inputs, 64 | network_name='vgg_16', 65 | starting_layer='conv1/conv1_1', 66 | reuse=False, 67 | scope=None): 68 | """construct the decoder network for the vgg models 69 | 70 | Args: 71 | inputs: input features [batch_size, height, width, channel] 72 | network_name: the type of the network, default is vgg_16 73 | starting_layer: the starting reflectance layer, default is 'conv1/conv1_1' 74 | reuse: (optional) whether to reuse the network 75 | scope: (optional) the scope of the network 76 | 77 | Returns: 78 | outputs: the decoded feature maps 79 | """ 80 | with tf.variable_scope(scope, 'image_decoder', reuse=reuse): 81 | # gather the output with identity mapping 82 | net = tf.identity(inputs) 83 | 84 | # starting inferring the network 85 | is_active = False 86 | for layer, layer_struct in network_map[network_name]: 87 | if layer == starting_layer: 88 | is_active = True 89 | if is_active: 90 | conv_type, num_outputs, kernel_size = layer_struct 91 | if conv_type == 'c': 92 | net = network_ops.conv2d_same(net, num_outputs, kernel_size, 1, scope=layer) 93 | elif conv_type == 'uc': 94 | net = network_ops.conv2d_resize(net, num_outputs, kernel_size, 2, scope=layer) 95 | with slim.arg_scope([slim.conv2d], normalizer_fn=None, activation_fn=tf.tanh): 96 | outputs = network_ops.conv2d_same(net, 3, 7, 1, scope='output') 97 | return outputs * 150.0 + 127.5 98 | 99 | 100 | def vgg_combined_decoder(inputs, 101 | additional_features, 102 | fusion_fn=None, 103 | network_name='vgg_16', 104 | starting_layer='conv1/conv1_1', 105 | reuse=False, 106 | scope=None): 107 | """construct the decoder network with additional feature combination 108 | 109 | Args: 110 | inputs: input features [batch_size, height, width, channel] 111 | additional_features: a dict contains the additional features 112 | fusion_fn: the fusion function to combine features 113 | network_name: the type of the network, default is vgg_16 114 | starting_layer: the starting reflectance layer, default is 'conv1/conv1_1' 115 | reuse: (optional) whether to reuse the network 116 | scope: (optional) the scope of the network 117 | 118 | Returns: 119 | outputs: the decoded feature maps 120 | """ 121 | with tf.variable_scope(scope, 'combined_decoder', reuse=reuse): 122 | # gather the output with identity mapping 123 | net = tf.identity(inputs) 124 | 125 | # starting inferring the network 126 | is_active = False 127 | for layer, layer_struct in network_map[network_name]: 128 | if layer == starting_layer: 129 | is_active = True 130 | if is_active: 131 | conv_type, num_outputs, kernel_size = layer_struct 132 | 133 | # combine the feature 134 | add_feature = additional_features.get(layer) 135 | if add_feature is not None 
and layer != starting_layer: 136 | net = fusion_fn(net, add_feature) 137 | 138 | if conv_type == 'c': 139 | net = network_ops.conv2d_same(net, num_outputs, kernel_size, 1, scope=layer) 140 | elif conv_type == 'uc': 141 | net = network_ops.conv2d_resize(net, num_outputs, kernel_size, 2, scope=layer) 142 | with slim.arg_scope([slim.conv2d], normalizer_fn=None, activation_fn=None): 143 | outputs = network_ops.conv2d_same(net, 3, 7, 1, scope='output') 144 | return outputs + 127.5 145 | 146 | 147 | def vgg_multiple_combined_decoder(inputs, 148 | additional_features, 149 | blending_weights, 150 | fusion_fn=None, 151 | network_name='vgg_16', 152 | starting_layer='conv1/conv1_1', 153 | reuse=False, 154 | scope=None): 155 | """construct the decoder network with additional feature combination 156 | 157 | Args: 158 | inputs: input features [batch_size, height, width, channel] 159 | additional_features: a dict contains the additional features 160 | blending_weights: the list of weights used for feature blending 161 | fusion_fn: the fusion function to combine features 162 | network_name: the type of the network, default is vgg_16 163 | starting_layer: the starting reflectance layer, default is 'conv1/conv1_1' 164 | reuse: (optional) whether to reuse the network 165 | scope: (optional) the scope of the network 166 | 167 | Returns: 168 | outputs: the decoded feature maps 169 | """ 170 | with tf.variable_scope(scope, 'combined_decoder', reuse=reuse): 171 | # gather the output with identity mapping 172 | net = tf.identity(inputs) 173 | 174 | # starting inferring the network 175 | is_active = False 176 | for layer, layer_struct in network_map[network_name]: 177 | if layer == starting_layer: 178 | is_active = True 179 | if is_active: 180 | conv_type, num_outputs, kernel_size = layer_struct 181 | 182 | # combine the feature 183 | add_feature = additional_features[0].get(layer) 184 | if add_feature is not None and layer != starting_layer: 185 | # fuse multiple styles 186 | n = 0 187 | layer_output = 0.0 188 | for additional_feature in additional_features: 189 | additional_layer_feature = additional_feature.get(layer) 190 | fused_layer_feature = fusion_fn(net, additional_layer_feature) 191 | layer_output += blending_weights[n] * fused_layer_feature 192 | n += 1 193 | net = layer_output 194 | 195 | if conv_type == 'c': 196 | net = network_ops.conv2d_same(net, num_outputs, kernel_size, 1, scope=layer) 197 | elif conv_type == 'uc': 198 | net = network_ops.conv2d_resize(net, num_outputs, kernel_size, 2, scope=layer) 199 | with slim.arg_scope([slim.conv2d], normalizer_fn=None, activation_fn=None): 200 | outputs = network_ops.conv2d_same(net, 3, 7, 1, scope='output') 201 | return outputs + 127.5 202 | -------------------------------------------------------------------------------- /models/vgg_decoder.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LucasSheng/avatar-net/8ee86d758efab378d570134366fe61adbc2d1030/models/vgg_decoder.pyc -------------------------------------------------------------------------------- /scripts/evaluate_style_transfer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_ID=$1 4 | # content image folders: 5 | # exemplar content images: ./data/contents/images/ 6 | # exemplar content videos: ./data/contents/sequences/ 7 | CONTENT_DATASET_DIR=$2 8 | # style image folders: ./data/styles/ 9 | STYLE_DATASET_DIR=$3 10 | # output image folders: 
./results/sequences/
11 | EVAL_DATASET_DIR=$4
12 | 
13 | # network configuration
14 | CONFIG_DIR=./configs/AvatarNet_config.yml
15 | 
16 | # the network path for the trained auto-encoding network (need to change accordingly)
17 | MODEL_DIR=/DATA/AvatarNet
18 | 
19 | CUDA_VISIBLE_DEVICES=${CUDA_ID} \
20 | python evaluate_style_transfer.py \
21 |     --checkpoint_dir=${MODEL_DIR} \
22 |     --model_config_path=${CONFIG_DIR} \
23 |     --content_dataset_dir=${CONTENT_DATASET_DIR} \
24 |     --style_dataset_dir=${STYLE_DATASET_DIR} \
25 |     --eval_dir=${EVAL_DATASET_DIR} \
26 |     --inter_weight=0.8
--------------------------------------------------------------------------------
/scripts/train_image_reconstruction.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | CUDA_ID=$1
4 | # MSCOCO tfexample dataset path
5 | DATASET_DIR=$2
6 | # model path
7 | MODEL_DIR=$3
8 | 
9 | # network configuration
10 | CONFIG_DIR=./configs/AvatarNet_config.yml
11 | 
12 | CUDA_VISIBLE_DEVICES=${CUDA_ID} \
13 | python train_image_reconstruction.py \
14 |     --train_dir=${MODEL_DIR} \
15 |     --model_config=${CONFIG_DIR} \
16 |     --dataset_dir=${DATASET_DIR} \
17 |     --dataset_name=MSCOCO \
18 |     --dataset_split_name=train \
19 |     --batch_size=8 \
20 |     --max_number_of_steps=120000 \
21 |     --optimizer=adam \
22 |     --learning_rate_decay_type=fixed \
23 |     --learning_rate=0.0001
--------------------------------------------------------------------------------
/train_image_reconstruction.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import tensorflow as tf
6 | 
7 | from tensorflow.python.ops import control_flow_ops
8 | from models.preprocessing import preprocessing_image
9 | from models import models_factory
10 | from datasets import dataset_utils
11 | 
12 | slim = tf.contrib.slim
13 | 
14 | 
15 | tf.app.flags.DEFINE_integer(
16 |     'num_readers', 4,
17 |     'The number of parallel readers that read data from the dataset.')
18 | tf.app.flags.DEFINE_integer(
19 |     'num_preprocessing_threads', 1,
20 |     'The number of threads used to create the batches.')
21 | 
22 | # ====================== #
23 | # Training specification #
24 | # ====================== #
25 | tf.app.flags.DEFINE_string(
26 |     'train_dir', '/tmp/tfmodel',
27 |     'Directory where checkpoints and event logs are written to.')
28 | tf.app.flags.DEFINE_integer(
29 |     'log_every_n_steps', 100,
30 |     'The frequency with which logs are printed, in steps.')
31 | tf.app.flags.DEFINE_integer(
32 |     'save_interval_secs', 600,
33 |     'The frequency with which the model is saved, in seconds.')
34 | tf.app.flags.DEFINE_integer(
35 |     'save_summaries_secs', 120,
36 |     'The frequency with which summaries are saved, in seconds.')
37 | tf.app.flags.DEFINE_integer(
38 |     'batch_size', 32, 'The number of samples in each batch.')
39 | tf.app.flags.DEFINE_integer(
40 |     'max_number_of_steps', None, 'The maximum number of training steps.')
41 | 
42 | # ============= #
43 | # Dataset Flags #
44 | # ============= #
45 | tf.app.flags.DEFINE_string(
46 |     'dataset_dir', None,
47 |     'The directory where the dataset files are stored.')
48 | tf.app.flags.DEFINE_string(
49 |     'dataset_name', None,
50 |     'The name of the dataset to load.')
51 | tf.app.flags.DEFINE_string(
52 |     'dataset_split_name', 'train',
53 |     'The name of the train/test split.')
54 | 
55 | #######################
56 | # Model specification #
57 | #######################
58 | tf.app.flags.DEFINE_string(
59 |     'model_config', None,
60 |     'Path to the YAML file that stores the model configuration.')
61 | 
62 | ######################
63 | # Optimization Flags #
64 | ######################
65 | tf.app.flags.DEFINE_string(
66 |     'optimizer', 'rmsprop',
67 |     'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
68 |     '"ftrl", "momentum", "sgd" or "rmsprop".')
69 | tf.app.flags.DEFINE_float(
70 |     'adadelta_rho', 0.95, 'The decay rate for adadelta.')
71 | tf.app.flags.DEFINE_float(
72 |     'adagrad_initial_accumulator_value', 0.1,
73 |     'Starting value for the AdaGrad accumulators.')
74 | tf.app.flags.DEFINE_float(
75 |     'adam_beta1', 0.9,
76 |     'The exponential decay rate for the 1st moment estimates.')
77 | tf.app.flags.DEFINE_float(
78 |     'adam_beta2', 0.999,
79 |     'The exponential decay rate for the 2nd moment estimates.')
80 | tf.app.flags.DEFINE_float(
81 |     'opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
82 | tf.app.flags.DEFINE_float(
83 |     'ftrl_learning_rate_power', -0.5, 'The learning rate power.')
84 | tf.app.flags.DEFINE_float(
85 |     'ftrl_initial_accumulator_value', 0.1,
86 |     'Starting value for the FTRL accumulators.')
87 | tf.app.flags.DEFINE_float(
88 |     'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
89 | tf.app.flags.DEFINE_float(
90 |     'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
91 | tf.app.flags.DEFINE_float(
92 |     'momentum', 0.9,
93 |     'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
94 | tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
95 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
96 | 
97 | #######################
98 | # Learning Rate Flags #
99 | #######################
100 | tf.app.flags.DEFINE_string(
101 |     'learning_rate_decay_type', 'exponential',
102 |     'Specifies how the learning rate is decayed. One of "fixed", '
103 |     '"exponential", or "polynomial".')
104 | tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
105 | tf.app.flags.DEFINE_float(
106 |     'end_learning_rate', 0.0001,
107 |     'The minimal end learning rate used by a polynomial decay learning rate.')
108 | tf.app.flags.DEFINE_float(
109 |     'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
110 | tf.app.flags.DEFINE_float(
111 |     'num_epochs_per_decay', 2.0,
112 |     'Number of epochs after which learning rate decays.')
113 | tf.app.flags.DEFINE_float(
114 |     'moving_average_decay', None,
115 |     'The decay to use for the moving average. If left as None, moving averages are not used.')
116 | 
117 | # ============================ #
118 | # Fine-Tuning Flags            #
119 | # ============================ #
120 | tf.app.flags.DEFINE_string(
121 |     'checkpoint_path', None,
122 |     'The path to a checkpoint from which to fine-tune.')
123 | tf.app.flags.DEFINE_string(
124 |     'checkpoint_exclude_scopes', None,
125 |     'Comma-separated list of scopes of variables to exclude when restoring '
126 |     'from a checkpoint.')
127 | tf.app.flags.DEFINE_string(
128 |     'trainable_scopes', None,
129 |     'Comma-separated list of scopes to filter the set of variables to train. '
130 |     'By default, None trains all the variables.')
131 | tf.app.flags.DEFINE_boolean(
132 |     'ignore_missing_vars', False,
133 |     'When restoring a checkpoint, ignore missing variables.')
134 | 
135 | FLAGS = tf.app.flags.FLAGS
136 | 
137 | 
138 | def _configure_learning_rate(num_samples_per_epoch, global_step):
139 |     """Configures the learning rate.
140 | 
141 |     Args:
142 |         num_samples_per_epoch: The number of samples in each epoch of training.
143 |         global_step: The global_step tensor.
144 | 
145 |     Returns:
146 |         A `Tensor` representing the learning rate.
147 | 
148 |     Raises:
149 |         ValueError: if FLAGS.learning_rate_decay_type is not recognized.
150 |     """
151 |     decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
152 |                       FLAGS.num_epochs_per_decay)
153 |     if FLAGS.learning_rate_decay_type == 'exponential':
154 |         return tf.train.exponential_decay(
155 |             FLAGS.learning_rate,
156 |             global_step,
157 |             decay_steps,
158 |             FLAGS.learning_rate_decay_factor,
159 |             staircase=True,
160 |             name='exponential_decay_learning_rate')
161 |     elif FLAGS.learning_rate_decay_type == 'fixed':
162 |         return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
163 |     elif FLAGS.learning_rate_decay_type == 'polynomial':
164 |         return tf.train.polynomial_decay(
165 |             FLAGS.learning_rate,
166 |             global_step,
167 |             decay_steps,
168 |             FLAGS.end_learning_rate,
169 |             power=1.0,
170 |             cycle=False,
171 |             name='polynomial_decay_learning_rate')
172 |     else:
173 |         raise ValueError('learning_rate_decay_type [%s] was not recognized' %
174 |                          FLAGS.learning_rate_decay_type)
175 | 
176 | 
177 | def _configure_optimizer(learning_rate):
178 |     """Configures the optimizer used for training.
179 | 
180 |     Args:
181 |         learning_rate: A scalar or `Tensor` learning rate.
182 | 
183 |     Returns:
184 |         An instance of an optimizer.
185 | 
186 |     Raises:
187 |         ValueError: if FLAGS.optimizer is not recognized.
188 |     """
189 |     if FLAGS.optimizer == 'adadelta':
190 |         optimizer = tf.train.AdadeltaOptimizer(
191 |             learning_rate, rho=FLAGS.adadelta_rho, epsilon=FLAGS.opt_epsilon)
192 |     elif FLAGS.optimizer == 'adagrad':
193 |         optimizer = tf.train.AdagradOptimizer(
194 |             learning_rate,
195 |             initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
196 |     elif FLAGS.optimizer == 'adam':
197 |         optimizer = tf.train.AdamOptimizer(
198 |             learning_rate,
199 |             beta1=FLAGS.adam_beta1,
200 |             beta2=FLAGS.adam_beta2,
201 |             epsilon=FLAGS.opt_epsilon)
202 |     elif FLAGS.optimizer == 'ftrl':
203 |         optimizer = tf.train.FtrlOptimizer(
204 |             learning_rate,
205 |             learning_rate_power=FLAGS.ftrl_learning_rate_power,
206 |             initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
207 |             l1_regularization_strength=FLAGS.ftrl_l1,
208 |             l2_regularization_strength=FLAGS.ftrl_l2)
209 |     elif FLAGS.optimizer == 'momentum':
210 |         optimizer = tf.train.MomentumOptimizer(
211 |             learning_rate,
212 |             momentum=FLAGS.momentum,
213 |             name='Momentum')
214 |     elif FLAGS.optimizer == 'rmsprop':
215 |         optimizer = tf.train.RMSPropOptimizer(
216 |             learning_rate,
217 |             decay=FLAGS.rmsprop_decay,
218 |             momentum=FLAGS.rmsprop_momentum,
219 |             epsilon=FLAGS.opt_epsilon)
220 |     elif FLAGS.optimizer == 'sgd':
221 |         optimizer = tf.train.GradientDescentOptimizer(learning_rate)
222 |     else:
223 |         raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
224 |     return optimizer
225 | 
226 | 
227 | def _get_variables_to_train(options):
228 |     """Returns a list of variables to train.
229 | 
230 |     Returns:
231 |         A list of variables to train by the optimizer.
232 | """ 233 | if options.get('trainable_scopes') is None: 234 | return tf.trainable_variables() 235 | else: 236 | scopes = [scope.strip() for scope in options.get('trainable_scopes').split(',')] 237 | 238 | variables_to_train = [] 239 | for scope in scopes: 240 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 241 | variables_to_train.extend(variables) 242 | return variables_to_train 243 | 244 | 245 | def _get_init_fn(options): 246 | """Returns a function to warm-start the training. 247 | 248 | Note that the init_fn is only run when initializing the models during the 249 | very first global step. 250 | 251 | Returns: 252 | An init function 253 | """ 254 | if options.get('checkpoint_path') is None: 255 | return None 256 | 257 | # Warn the user if a checkpoint exists in the train_dir. Then we'll be 258 | # ignoring the checkpoint anyway. 259 | if tf.train.latest_checkpoint(FLAGS.train_dir): 260 | tf.logging.info( 261 | 'Ignoring --checkpoint_path because a checkpoint already exists ' 262 | 'in %s' % FLAGS.train_dir) 263 | return None 264 | 265 | exclusions = [] 266 | if options.get('checkpoint_exclude_scopes'): 267 | # remove space and comma 268 | exclusions = [scope.strip() 269 | for scope in options.get('checkpoint_exclude_scopes').split(',')] 270 | 271 | variables_to_restore = [] 272 | for var in slim.get_model_variables(): 273 | excluded = False 274 | for exclusion in exclusions: 275 | if var.op.name.startswith(exclusion): 276 | excluded = True 277 | break 278 | if not excluded: 279 | variables_to_restore.append(var) 280 | 281 | if tf.gfile.IsDirectory(options.get('checkpoint_path')): 282 | checkpoint_path = tf.train.latest_checkpoint(options.get('checkpoint_path')) 283 | else: 284 | checkpoint_path = options.get('checkpoint_path') 285 | 286 | tf.logging.info('Fine-tuning from %s' % checkpoint_path) 287 | 288 | return slim.assign_from_checkpoint_fn( 289 | checkpoint_path, 290 | variables_to_restore, 291 | ignore_missing_vars=options.get('ignore_missing_vars')) 292 | 293 | 294 | def main(_): 295 | if not FLAGS.dataset_dir: 296 | raise ValueError('You must supply the dataset directory with' 297 | ' --dataset_dir') 298 | 299 | tf.logging.set_verbosity(tf.logging.INFO) 300 | with tf.Graph().as_default(): 301 | global_step = slim.create_global_step() # create the global step 302 | 303 | ###################### 304 | # select the dataset # 305 | ###################### 306 | dataset = dataset_utils.get_split( 307 | FLAGS.dataset_name, 308 | FLAGS.dataset_split_name, 309 | FLAGS.dataset_dir) 310 | 311 | ###################### 312 | # create the network # 313 | ###################### 314 | # parse the options from a yaml file 315 | model, options = models_factory.get_model(FLAGS.model_config) 316 | 317 | #################################################### 318 | # create a dataset provider that loads the dataset # 319 | #################################################### 320 | # dataset provider 321 | provider = slim.dataset_data_provider.DatasetDataProvider( 322 | dataset, 323 | num_readers=FLAGS.num_readers, 324 | common_queue_capacity=20*FLAGS.batch_size, 325 | common_queue_min=10*FLAGS.batch_size) 326 | [image] = provider.get(['image']) 327 | image_clip = preprocessing_image( 328 | image, 329 | model.training_image_size, 330 | model.training_image_size, 331 | model.content_size, 332 | is_training=True) 333 | image_clip_batch = tf.train.batch( 334 | [image_clip], 335 | batch_size=FLAGS.batch_size, 336 | num_threads=FLAGS.num_preprocessing_threads, 337 | 
capacity=5*FLAGS.batch_size)
338 | 
339 |         # prefetch queue for the input batches
340 |         batch_queue = slim.prefetch_queue.prefetch_queue([image_clip_batch])
341 | 
342 |         ##########################################
343 |         # build the model based on the given data #
344 |         ##########################################
345 |         images = batch_queue.dequeue()
346 |         total_loss = model.build_train_graph(images)
347 | 
348 |         ####################################################
349 |         # gather the operations for training and summaries #
350 |         ####################################################
351 |         summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
352 |         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
353 | 
354 |         # configure the moving averages
355 |         if FLAGS.moving_average_decay:
356 |             moving_average_variables = slim.get_model_variables()
357 |             variable_averages = tf.train.ExponentialMovingAverage(
358 |                 FLAGS.moving_average_decay, global_step)
359 |         else:
360 |             moving_average_variables, variable_averages = None, None
361 | 
362 |         # gather the optimizer operations
363 |         learning_rate = _configure_learning_rate(
364 |             dataset.num_samples, global_step)
365 |         optimizer = _configure_optimizer(learning_rate)
366 |         summaries.add(tf.summary.scalar('learning_rate', learning_rate))
367 | 
368 |         if FLAGS.moving_average_decay:
369 |             update_ops.append(variable_averages.apply(moving_average_variables))
370 | 
371 |         # training operations
372 |         train_op = model.get_training_operations(
373 |             optimizer, global_step, _get_variables_to_train(options))
374 |         update_ops.append(train_op)
375 | 
376 |         # gather the training summaries
377 |         summaries |= set(model.summaries)
378 | 
379 |         # group the update operations
380 |         update_op = tf.group(*update_ops)
381 |         watched_loss = control_flow_ops.with_dependencies(
382 |             [update_op], total_loss, name='train_op')
383 | 
384 |         # merge the summaries
385 |         summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES))
386 |         summary_op = tf.summary.merge(list(summaries), name='summary_op')
387 | 
388 |         ##############################
389 |         # start the training process #
390 |         ##############################
391 |         slim.learning.train(
392 |             watched_loss,
393 |             logdir=FLAGS.train_dir,
394 |             init_fn=_get_init_fn(options),
395 |             summary_op=summary_op,
396 |             number_of_steps=FLAGS.max_number_of_steps,
397 |             log_every_n_steps=FLAGS.log_every_n_steps,
398 |             save_summaries_secs=FLAGS.save_summaries_secs,
399 |             save_interval_secs=FLAGS.save_interval_secs)
400 | 
401 | 
402 | if __name__ == '__main__':
403 |     tf.app.run()
404 | 
--------------------------------------------------------------------------------
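
Usage sketch for the two scripts above: they are driven purely by positional arguments, so a minimal invocation could look like the lines below. The GPU id, the MSCOCO tfexample directory, and the ./results/... output folders are illustrative placeholders (only ./data/... and ./configs/AvatarNet_config.yml ship with the repository), and evaluate_style_transfer.sh additionally reads its checkpoint from the hard-coded MODEL_DIR=/DATA/AvatarNet, which has to be edited to point at the trained auto-encoder.

# train the image reconstruction (auto-encoding) network on MSCOCO tfexamples
# args: <CUDA id> <tfexample dataset dir> <model dir>
bash scripts/train_image_reconstruction.sh 0 /path/to/MSCOCO_tfexamples /DATA/AvatarNet

# stylize the bundled example images with the bundled styles
# args: <CUDA id> <content dir> <style dir> <output dir>
mkdir -p ./results/images   # in case the evaluation script does not create the output folder itself
bash scripts/evaluate_style_transfer.sh 0 ./data/contents/images/ ./data/styles/ ./results/images/

# stylize the bundled example frame sequence
mkdir -p ./results/sequences
bash scripts/evaluate_style_transfer.sh 0 ./data/contents/sequences/ ./data/styles/ ./results/sequences/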