├── .github └── FUNDING.yml ├── .gitignore ├── LICENCE ├── README.md ├── data ├── content-images │ ├── figures.jpg │ ├── lion.jpg │ └── taj_mahal.jpg ├── examples │ ├── candy_model │ │ ├── figures_width_500_model_candy_resize_230.jpg │ │ └── taj_mahal_width_800_model_candy_resize_230.jpg │ ├── edtaonisl_model │ │ ├── figures_width_500_model_edtaonisl_9e5_33k_resized_230.jpg │ │ └── taj_mahal_width_500_model_edtaonisl_9e5_33k_resized_230.jpg │ ├── figures_width_550_model_mosaic_4e5_e2.jpg │ ├── lion_width_360_model_mosaic_4e5_e2.jpg │ ├── mosaic_model │ │ ├── figures_width_500_model_mosaic_4e5_e2_resized_230.jpg │ │ └── taj_mahal_width_500_model_mosaic_4e5_e2_resized_230.jpg │ ├── readme_pics │ │ ├── loss_curves.PNG │ │ ├── monitor_img1.jpg │ │ ├── monitor_img2.jpg │ │ ├── spike.png │ │ └── statistics.PNG │ ├── starry_night_model │ │ ├── figures_width_500_model_starry_v3_resize_230.jpg │ │ └── taj_mahal_width_500_model_starry_v3_resize_230.jpg │ └── style_images │ │ ├── candy_resize_230.jpg │ │ ├── edtaonisl_crop_resized_230.jpg │ │ ├── mosaic_crop_resized_230.jpg │ │ └── vg_starry_night_resized_230.jpg └── style-images │ ├── candy.jpg │ ├── edtaonisl.jpg │ ├── mosaic.jpg │ └── vg_starry_night.jpg ├── environment.yml ├── models └── definitions │ ├── __init__.py │ ├── perceptual_loss_net.py │ └── transformer_net.py ├── stylization_script.py ├── training_script.py └── utils ├── __init__.py ├── resource_downloader.py └── utils.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | patreon: theaiepiphany 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Aleksa Gordić 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Fast Neural Style Transfer (feed-forward method) :zap::computer: + :art: = :heart: 2 | This repo contains a concise PyTorch implementation of the original feed-forward NST paper (:link: [Johnson et al.](https://arxiv.org/pdf/1603.08155.pdf)). 
3 | 4 | Check out my implementation of the original NST (optimization method) paper ([Gatys et al.](https://github.com/gordicaleksa/pytorch-neural-style-transfer)). 5 | 6 | It's an accompanying repo for [this video series on YouTube](https://www.youtube.com/watch?v=S78LQebx6jo&list=PLBoQnSflObcmbfshq9oNs41vODgXG-608). 7 | 8 |

9 | *[NST Intro - example stylized images]* 11 |

12 | 13 | ### Why yet another Fast NST (feed-forward method) repo? 14 | It's the **cleanest and most concise** NST repo that I know of + it's written in **PyTorch!** :heart: 15 | 16 | My idea :bulb: is to make the code so simple and well commented that you can use it as a **first step on your NST learning journey** before any other blog, course, book or research paper. :books: 17 | 18 | I've included an automatic download script for the pretrained models and the MS COCO dataset - so you can either **instantaneously run it** and get the results (:art: stylized images) using pretrained models **or start training/experimenting with your own models**. :rocket: 19 | 20 | ## Examples 21 | 22 | Here are some examples with the [4 pretrained models](https://www.dropbox.com/s/fb39gscd1b42px1/pretrained_models.zip?dl=0) (automatic download enabled - look at the [usage section](#usage)): 23 | 24 |

25 | *[table of example stylizations produced by the 4 pretrained models]* 40 |

41 | 42 | *Note:* keep in mind that I still need to improve these models; 3 of them (the last 3 rows) only saw 33k images from MS COCO. 43 | 44 | ## Setup 45 | 46 | 1. Open Anaconda Prompt and navigate into the project directory: `cd path_to_repo` 47 | 2. Run `conda env create` from the project directory (this will create a brand new conda environment). 48 | 3. Run `activate pytorch-nst-fast` (if you want to run scripts from your console; otherwise set the interpreter in your IDE) 49 | 50 | That's it! It should work out of the box - the environment.yml file takes care of the dependencies. 51 | 52 | ----- 53 | 54 | The PyTorch package will pull some version of CUDA with it, but it is highly recommended that you install system-wide CUDA beforehand, mostly because of GPU drivers. I also recommend using the Miniconda installer to get conda on your system. 55 | 56 | Follow through points 1 and 2 of [this setup](https://github.com/Petlja/PSIML/blob/master/docs/MachineSetup.md) and use the most up-to-date versions of Miniconda and CUDA/cuDNN (I recommend CUDA 10.1 or 10.2, as those are compatible with PyTorch 1.5, which is used in this repo, and the newest compatible cuDNN). 57 | 58 | ## Usage 59 | 60 | Go through this section to run the project, but if you are still having problems take a look at the accompanying YouTube videos: [this one (stylization)](https://www.youtube.com/watch?v=lOR-LncQlk8&list=PLBoQnSflObcmbfshq9oNs41vODgXG-608&index=5) and [this one (training)](https://www.youtube.com/watch?v=EuXd-aO77A0&list=PLBoQnSflObcmbfshq9oNs41vODgXG-608&index=6). 61 | 62 | ### Stylization 63 | 64 | 1. Download the pretrained models: run `python utils/resource_downloader.py` 65 | 2. Run `python stylization_script.py` (it has a default content image and model set) 66 | 67 | That's it! If you want more flexibility (and I guess you do) here are a couple more nuggets of info. 68 | 69 | A more expressive command is:
70 | `python stylization_script.py --content_input <content image or directory> --img_width <new width> --model_name <model binary name>` 71 | 72 | If you pass a directory into `--content_input` it will perform batch stylization.
73 | You can control the batch size (in case you have VRAM problems) with the `batch_size` param. 74 | 75 | ----- 76 | 77 | You just need to specify the **names** - the repo automatically finds content images and models in default directories: 78 | 1. content images default dir: `/data/content-images/` 79 | 2. model binaries default dir: `/models/binaries/` 80 | 81 | So all **you** have to do is place images and models there and you can use them. Output will be dumped to `/data/output-images/`. 82 | 83 | After you run the resource_downloader.py script, the binaries dir will be pre-populated with 4 pretrained models. 84 | 85 | Go ahead, play with it and make some art! 86 | 87 |
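If you'd rather drive the stylization from your own Python code than from the CLI, here is a minimal sketch of calling `stylize_static_image` from `stylization_script.py` directly. The dictionary mirrors the config the script builds from argparse; the concrete image/model names and the assumption that you run this from the repo root are mine, so adjust them to your setup.

```python
import os
from stylization_script import stylize_static_image

repo_root = '.'  # assumption: this snippet is run from the repo root

inference_config = {
    'content_input': 'taj_mahal.jpg',    # an image name, or a directory for batch stylization
    'batch_size': 5,                     # only used when content_input is a directory
    'img_width': 500,
    'model_name': 'mosaic_4e5_e2.pth',   # any model binary sitting in models/binaries/
    'should_not_display': False,         # don't pop up a matplotlib window with the result
    'verbose': True,                     # also prints the training metadata of the model
    'redirected_output': None,           # set to a directory to override the default output location
    'content_images_path': os.path.join(repo_root, 'data', 'content-images'),
    'output_images_path': os.path.join(repo_root, 'data', 'output-images'),
    'model_binaries_path': os.path.join(repo_root, 'models', 'binaries'),
}

os.makedirs(inference_config['output_images_path'], exist_ok=True)
stylize_static_image(inference_config)
```

Functionally this is the same as the CLI call above - it's just handy if you want to embed the repo in a bigger pipeline.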

88 | *[example stylized images]* 90 |

91 | 92 | ### Training your own models 93 | 94 | 1. Download the MS COCO dataset: run `python utils/resource_downloader.py -r mscoco_dataset` (it's a 12.5 GB file) 95 | 2. Run `python training_script.py --style_img_name <style image name>` 96 | 97 | Now that will probably actually work! 98 | 99 | It will periodically dump checkpoint models to `/models/checkpoints/` and the final model to `/models/binaries/` by default. 100 | 101 | I strongly recommend playing with these 2 params: 102 | 1. **style_weight** - I always kept it in the [1e5, 9e5] range; you may have to tweak it a little bit for your specific style image 103 | 2. **subset_size** - Usually 32k images do the job (that's 8k batches) - you'll need to monitor **tensorboard** to figure out whether your curves are saturating at that point or not. If they are still going down, set the number higher 104 | 105 | That brings us to the next section! 106 | 107 | ### Tensorboard Visualizations 108 | 109 | To **start tensorboard** just run: `tensorboard --logdir=runs --samples_per_plugin images=50` from your conda console. 110 | 111 | `samples_per_plugin images=` sets the number of images you'll be able to see when moving the image slider. 112 | 113 | There are basically **2 things you want to monitor** during your training (not counting console output <- redundant if you use tensorboard). 114 | 115 | #### Monitor your loss/statistics curves 116 | 117 | You want to keep `content-loss` and `style-loss` going down, or at least one of them (style loss usually saturates first). 118 | 119 | I usually set the tv weight to 0 - that's why you see 0 on the `tv-loss` curve. You should use it only if you see that your images are having smoothness problems ([check this out](https://github.com/gordicaleksa/pytorch-neural-style-transfer#impact-of-total-variation-tv-loss) for a visualization of what exactly the tv weight does). 120 | 121 |
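If it helps to see what those three curves actually measure, here is a condensed, runnable sketch that mirrors steps 3-6 of `training_script.py`. The VGG16 feature maps and Gram matrices are replaced with dummy random tensors purely so the snippet is self-contained - the weights are the script's defaults.

```python
import torch
import utils.utils as utils  # for total_variation (run this from the repo root)

mse = torch.nn.MSELoss(reduction='mean')
content_weight, style_weight, tv_weight = 1e0, 4e5, 0.  # defaults from training_script.py

# Stand-ins for the VGG16 relu2_2 feature maps, the Gram matrices and the transformer net output
target_relu2_2, current_relu2_2 = torch.randn(4, 128, 128, 128), torch.randn(4, 128, 128, 128)
target_grams = [torch.randn(4, 128, 128) for _ in range(4)]
current_grams = [torch.randn(4, 128, 128) for _ in range(4)]
stylized_batch = torch.randn(4, 3, 256, 256)

content_loss = content_weight * mse(target_relu2_2, current_relu2_2)
style_loss = style_weight * sum(mse(gt, hat) for gt, hat in zip(target_grams, current_grams)) / len(target_grams)
tv_loss = tv_weight * utils.total_variation(stylized_batch)  # stays at 0 whenever tv_weight is 0

total_loss = content_loss + style_loss + tv_loss  # this is what gets backpropagated
print(content_loss.item(), style_loss.item(), tv_loss.item())
```

The `content-loss`, `style-loss` and `tv-loss` curves in tensorboard are exactly these three values, logged once per batch.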

122 | *[tensorboard screenshot: loss curves]* 123 |

124 | 125 | Statistics curves let me understand how the stylized image coming out of the transformer net behaves. 126 | 127 |
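For reference, the statistics curves come from a single `add_scalars` call in `training_script.py`. The sketch below reproduces it with a dummy batch (the random tensor stands in for the transformer net output) so you can see exactly what gets plotted:

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()                       # logs to ./runs/ by default
stylized_batch = torch.randn(4, 3, 256, 256)   # stand-in for transformer_net(content_batch)
global_step = 1                                # len(train_loader) * epoch + batch_id + 1 in the training script

writer.add_scalars('Statistics/min-max-mean-median',
                   {'min': torch.min(stylized_batch), 'max': torch.max(stylized_batch),
                    'mean': torch.mean(stylized_batch), 'median': torch.median(stylized_batch)},
                   global_step)
writer.close()
```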

128 | *[tensorboard screenshot: statistics curves]* 129 |

130 | 131 | If the max or min intensities start diverging, or the mean/median start going too far away from 0, that's a good indicator that your style weight is probably off. You can keep the content weight constant and just tweak the style weight. 132 | 133 | #### Monitor your intermediate stylized images 134 | 135 | This one helps immensely, as it lets you manually early-stop the training if you don't like the stylized output you see. 136 | 137 |
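And this is how those intermediate stylized images end up in tensorboard - a small, self-contained sketch of the `add_image` logging done in `training_script.py` (again, the random tensor stands in for a real stylized batch):

```python
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import utils.utils as utils  # run this from the repo root

writer = SummaryWriter()
stylized_batch = torch.randn(4, 3, 256, 256)   # stand-in for transformer_net(content_batch)
global_step = 1

stylized = utils.post_process_image(stylized_batch[0].detach().to('cpu').numpy())  # de-normalize -> uint8 HWC
stylized = np.moveaxis(stylized, 2, 0)         # add_image expects a channel-first (CHW) image
writer.add_image('stylized_img', stylized, global_step)
writer.close()
```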

138 | *[tensorboard screenshots: intermediate stylized images]* 139 |

141 | 142 | In the beginning the stylized images look kinda rubbish, like the one on the left. As the training progresses you'll get more meaningful images popping out (like the one on the right). 143 | 144 | ## Debugging 145 | Q: My style/content loss curves just spiked in the middle of training?
146 | A: 2 options: a) rerun the training (the optimizer got into a bad state) or b) if that doesn't work, lower your style weight 147 | 148 |

149 | *[screenshot: example of a loss spike]* 150 |

151 | 152 | Q: How can I see the exact parameters that you used to train your models?
153 | A: Just run the model through `stylization_script.py` with the `--verbose` flag - the training metadata will be printed out to the console. 154 | 155 | ## Further experimentation (advanced, for researchers) 156 | 157 | There are a couple of things you could experiment with (assuming fixed net architectures); here are some ideas: 158 | 1. Try setting the MSE to `sum` reduction for the style loss. I used that method [here](https://github.com/gordicaleksa/pytorch-neural-style-transfer/blob/master/neural_style_transfer.py) and it gave nice results. You'll have to play with the style weight afterwards to get it running. This will effectively give bigger weight to deeper style representations, because the Gram matrices coming out of deeper layers are bigger. Meaning you'll give an advantage to high-level style features (broad spatial characteristics of the style) over low-level style features (smaller neighborhood characteristics like fine brush strokes). 159 | 2. The original paper used a `tanh` activation at the output - figure out how you can get it to work; you may have to add some scaling. There is this magic constant 150 that Johnson originally used to scale tanh. I created [this issue](https://github.com/jcjohnson/fast-neural-style/issues/186) as it is unclear how it came to be and whether it was just experimentally figured out. 160 | 3. The pretrained PyTorch VGG16 model was trained on 0..1 range, ImageNet-normalized images. Try working with 0..255 range, ImageNet mean-only-normalized images - that will also work! It worked [here](https://github.com/gordicaleksa/pytorch-neural-style-transfer/blob/master/utils/utils.py), and if you feed such an image into VGG16 (as a classifier) it will still give you correct predictions! 161 | 4. [This repo](https://github.com/pytorch/examples/tree/master/fast_neural_style) used 0..255 images (no normalization) as the input to the transformer net - play with that. You'll have to normalize the transformer net's output before feeding it to VGG16. 162 | 163 | Some of these may further improve the visual quality that you get from these models! If you find something interesting I'd like to hear from you! 164 | 165 | ## Acknowledgements 166 | 167 | I found these repos useful (while developing this one): 168 | * [fast_neural_style](https://github.com/pytorch/examples/tree/master/fast_neural_style) (PyTorch, feed-forward method) 169 | * [pytorch-neural-style-transfer](https://github.com/gordicaleksa/pytorch-neural-style-transfer) (PyTorch, optimization method) 170 | * [original J.Johnson's repo](https://github.com/jcjohnson/fast-neural-style) (Torch, written in Lua) 171 | 172 | I found some of the content/style images I was using here: 173 | * [style/artistic images](https://www.rawpixel.com/board/537381/vincent-van-gogh-free-original-public-domain-paintings?sort=curated&mode=shop&page=1) 174 | * [awesome figures pic](https://www.pexels.com/photo/action-android-device-electronics-595804/) 175 | 176 | Other images are now already classics in the NST world. 
177 | 178 | ## Citation 179 | 180 | If you find this code useful for your research, please cite the following: 181 | 182 | ``` 183 | @misc{Gordić2020nst-fast, 184 | author = {Gordić, Aleksa}, 185 | title = {pytorch-nst-feedforward}, 186 | year = {2020}, 187 | publisher = {GitHub}, 188 | journal = {GitHub repository}, 189 | howpublished = {\url{https://github.com/gordicaleksa/pytorch-nst-feedforward}}, 190 | } 191 | ``` 192 | 193 | ## Connect with me 194 | 195 | If you'd love to have some more AI-related content in your life :nerd_face:, consider: 196 | * Subscribing to my YouTube channel [The AI Epiphany](https://www.youtube.com/c/TheAiEpiphany) :bell: 197 | * Follow me on [LinkedIn](https://www.linkedin.com/in/aleksagordic/) and [Twitter](https://twitter.com/gordic_aleksa) :bulb: 198 | * Follow me on [Medium](https://gordicaleksa.medium.com/) :books: :heart: 199 | 200 | ## Licence 201 | 202 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://github.com/gordicaleksa/pytorch-nst-feedforward/blob/master/LICENCE) -------------------------------------------------------------------------------- /data/content-images/figures.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/content-images/figures.jpg -------------------------------------------------------------------------------- /data/content-images/lion.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/content-images/lion.jpg -------------------------------------------------------------------------------- /data/content-images/taj_mahal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/content-images/taj_mahal.jpg -------------------------------------------------------------------------------- /data/examples/candy_model/figures_width_500_model_candy_resize_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/candy_model/figures_width_500_model_candy_resize_230.jpg -------------------------------------------------------------------------------- /data/examples/candy_model/taj_mahal_width_800_model_candy_resize_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/candy_model/taj_mahal_width_800_model_candy_resize_230.jpg -------------------------------------------------------------------------------- /data/examples/edtaonisl_model/figures_width_500_model_edtaonisl_9e5_33k_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/edtaonisl_model/figures_width_500_model_edtaonisl_9e5_33k_resized_230.jpg -------------------------------------------------------------------------------- 
/data/examples/edtaonisl_model/taj_mahal_width_500_model_edtaonisl_9e5_33k_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/edtaonisl_model/taj_mahal_width_500_model_edtaonisl_9e5_33k_resized_230.jpg -------------------------------------------------------------------------------- /data/examples/figures_width_550_model_mosaic_4e5_e2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/figures_width_550_model_mosaic_4e5_e2.jpg -------------------------------------------------------------------------------- /data/examples/lion_width_360_model_mosaic_4e5_e2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/lion_width_360_model_mosaic_4e5_e2.jpg -------------------------------------------------------------------------------- /data/examples/mosaic_model/figures_width_500_model_mosaic_4e5_e2_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/mosaic_model/figures_width_500_model_mosaic_4e5_e2_resized_230.jpg -------------------------------------------------------------------------------- /data/examples/mosaic_model/taj_mahal_width_500_model_mosaic_4e5_e2_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/mosaic_model/taj_mahal_width_500_model_mosaic_4e5_e2_resized_230.jpg -------------------------------------------------------------------------------- /data/examples/readme_pics/loss_curves.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/loss_curves.PNG -------------------------------------------------------------------------------- /data/examples/readme_pics/monitor_img1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/monitor_img1.jpg -------------------------------------------------------------------------------- /data/examples/readme_pics/monitor_img2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/monitor_img2.jpg -------------------------------------------------------------------------------- /data/examples/readme_pics/spike.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/spike.png -------------------------------------------------------------------------------- /data/examples/readme_pics/statistics.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/statistics.PNG -------------------------------------------------------------------------------- /data/examples/starry_night_model/figures_width_500_model_starry_v3_resize_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/starry_night_model/figures_width_500_model_starry_v3_resize_230.jpg -------------------------------------------------------------------------------- /data/examples/starry_night_model/taj_mahal_width_500_model_starry_v3_resize_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/starry_night_model/taj_mahal_width_500_model_starry_v3_resize_230.jpg -------------------------------------------------------------------------------- /data/examples/style_images/candy_resize_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/candy_resize_230.jpg -------------------------------------------------------------------------------- /data/examples/style_images/edtaonisl_crop_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/edtaonisl_crop_resized_230.jpg -------------------------------------------------------------------------------- /data/examples/style_images/mosaic_crop_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/mosaic_crop_resized_230.jpg -------------------------------------------------------------------------------- /data/examples/style_images/vg_starry_night_resized_230.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/vg_starry_night_resized_230.jpg -------------------------------------------------------------------------------- /data/style-images/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/candy.jpg -------------------------------------------------------------------------------- /data/style-images/edtaonisl.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/edtaonisl.jpg -------------------------------------------------------------------------------- /data/style-images/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/mosaic.jpg -------------------------------------------------------------------------------- /data/style-images/vg_starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/vg_starry_night.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch-nst-fast 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - python==3.8.3 7 | - pip==20.0.2 8 | - matplotlib==3.1.3 9 | - pytorch==1.5.0 10 | - torchvision==0.6.0 11 | - pip: 12 | - numpy==1.18.4 13 | - opencv-python==4.2.0.32 14 | - GitPython==3.1.2 15 | - tensorboard==2.2.2 16 | 17 | -------------------------------------------------------------------------------- /models/definitions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/models/definitions/__init__.py -------------------------------------------------------------------------------- /models/definitions/perceptual_loss_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | VGG16 deep learning model is used as the perceptual loss (network). 
3 | More detail about the VGG architecture (if you want to understand magic/hardcoded numbers) can be found here: 4 | 5 | https://github.com/pytorch/vision/blob/3c254fb7af5f8af252c24e89949c54a3461ff0be/torchvision/models/vgg.py 6 | """ 7 | 8 | from collections import namedtuple 9 | import torch 10 | from torchvision import models 11 | 12 | 13 | class Vgg16(torch.nn.Module): 14 | """Only those layers are exposed which have already proven to work nicely.""" 15 | def __init__(self, requires_grad=False, show_progress=False): 16 | super().__init__() 17 | # Keeping eval() mode only for consistency - it only affects BatchNorm and Dropout both of which we won't use 18 | vgg16 = models.vgg16(pretrained=True, progress=show_progress).eval() 19 | vgg_pretrained_features = vgg16.features 20 | self.layer_names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'] 21 | 22 | self.slice1 = torch.nn.Sequential() 23 | self.slice2 = torch.nn.Sequential() 24 | self.slice3 = torch.nn.Sequential() 25 | self.slice4 = torch.nn.Sequential() 26 | for x in range(4): 27 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 28 | for x in range(4, 9): 29 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 30 | for x in range(9, 16): 31 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 32 | for x in range(16, 23): 33 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 34 | 35 | # Set these to False so that PyTorch won't be including them in it's autograd engine - eating up precious memory 36 | if not requires_grad: 37 | for param in self.parameters(): 38 | param.requires_grad = False 39 | 40 | def forward(self, x): 41 | x = self.slice1(x) 42 | relu1_2 = x 43 | x = self.slice2(x) 44 | relu2_2 = x 45 | x = self.slice3(x) 46 | relu3_3 = x 47 | x = self.slice4(x) 48 | relu4_3 = x 49 | vgg_outputs = namedtuple("VggOutputs", self.layer_names) 50 | out = vgg_outputs(relu1_2, relu2_2, relu3_3, relu4_3) 51 | return out 52 | 53 | 54 | # Set the perceptual loss network to be VGG16 55 | PerceptualLossNet = Vgg16 56 | -------------------------------------------------------------------------------- /models/definitions/transformer_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modifications to the original J.Johnson's architecture: 3 | 1. Instance normalization is used instead of batch normalization * 4 | 2. Instead of learned up-sampling use nearest-neighbor up-sampling followed by convolution ** 5 | 3. 
No scaled tanh at the output of the network *** 6 | 7 | * Ulyanov showed that this gives better results, checkout the paper here: https://arxiv.org/pdf/1607.08022.pdf 8 | ** Distill pub blog showed this to have better results: http://distill.pub/2016/deconv-checkerboard/ 9 | *** I tried using it even opened an issue on the original Johnson's repo (written in Lua) - no improvements 10 | 11 | Note: checkout the details about original Johnson's architecture here: 12 | https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf 13 | """ 14 | 15 | import torch 16 | 17 | 18 | class TransformerNet(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | # Non-linearity 22 | self.relu = torch.nn.ReLU() 23 | 24 | # Down-sampling convolution layers 25 | num_of_channels = [3, 32, 64, 128] 26 | kernel_sizes = [9, 3, 3] 27 | stride_sizes = [1, 2, 2] 28 | self.conv1 = ConvLayer(num_of_channels[0], num_of_channels[1], kernel_size=kernel_sizes[0], stride=stride_sizes[0]) 29 | self.in1 = torch.nn.InstanceNorm2d(num_of_channels[1], affine=True) 30 | self.conv2 = ConvLayer(num_of_channels[1], num_of_channels[2], kernel_size=kernel_sizes[1], stride=stride_sizes[1]) 31 | self.in2 = torch.nn.InstanceNorm2d(num_of_channels[2], affine=True) 32 | self.conv3 = ConvLayer(num_of_channels[2], num_of_channels[3], kernel_size=kernel_sizes[2], stride=stride_sizes[2]) 33 | self.in3 = torch.nn.InstanceNorm2d(num_of_channels[3], affine=True) 34 | 35 | # Residual layers 36 | res_block_num_of_filters = 128 37 | self.res1 = ResidualBlock(res_block_num_of_filters) 38 | self.res2 = ResidualBlock(res_block_num_of_filters) 39 | self.res3 = ResidualBlock(res_block_num_of_filters) 40 | self.res4 = ResidualBlock(res_block_num_of_filters) 41 | self.res5 = ResidualBlock(res_block_num_of_filters) 42 | 43 | # Up-sampling convolution layers 44 | num_of_channels.reverse() 45 | kernel_sizes.reverse() 46 | stride_sizes.reverse() 47 | self.up1 = UpsampleConvLayer(num_of_channels[0], num_of_channels[1], kernel_size=kernel_sizes[0], stride=stride_sizes[0]) 48 | self.in4 = torch.nn.InstanceNorm2d(num_of_channels[1], affine=True) 49 | self.up2 = UpsampleConvLayer(num_of_channels[1], num_of_channels[2], kernel_size=kernel_sizes[1], stride=stride_sizes[1]) 50 | self.in5 = torch.nn.InstanceNorm2d(num_of_channels[2], affine=True) 51 | self.up3 = ConvLayer(num_of_channels[2], num_of_channels[3], kernel_size=kernel_sizes[2], stride=stride_sizes[2]) 52 | 53 | def forward(self, x): 54 | y = self.relu(self.in1(self.conv1(x))) 55 | y = self.relu(self.in2(self.conv2(y))) 56 | y = self.relu(self.in3(self.conv3(y))) 57 | y = self.res1(y) 58 | y = self.res2(y) 59 | y = self.res3(y) 60 | y = self.res4(y) 61 | y = self.res5(y) 62 | y = self.relu(self.in4(self.up1(y))) 63 | y = self.relu(self.in5(self.up2(y))) 64 | # No tanh activation here as originally proposed by J.Johnson, I didn't get any improvements by using it, 65 | # if you get better results using it feel free to make a PR 66 | return self.up3(y) 67 | 68 | 69 | class ConvLayer(torch.nn.Module): 70 | """ 71 | A small wrapper around nn.Conv2d, so as to make the code cleaner and allow for experimentation with padding 72 | """ 73 | def __init__(self, in_channels, out_channels, kernel_size, stride): 74 | super().__init__() 75 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size//2, padding_mode='reflect') 76 | 77 | def forward(self, x): 78 | return self.conv2d(x) 79 | 80 | 81 | class ResidualBlock(torch.nn.Module): 82 | """ 83 | 
Originally introduced in (Microsoft Research Asia, He et al.): https://arxiv.org/abs/1512.03385 84 | Modified architecture according to suggestions in this blog: http://torch.ch/blog/2016/02/04/resnets.html 85 | 86 | The only difference from the original is: There is no ReLU layer after the addition of identity and residual 87 | """ 88 | 89 | def __init__(self, channels): 90 | super(ResidualBlock, self).__init__() 91 | kernel_size = 3 92 | stride_size = 1 93 | self.conv1 = ConvLayer(channels, channels, kernel_size=kernel_size, stride=stride_size) 94 | self.in1 = torch.nn.InstanceNorm2d(channels, affine=True) 95 | self.conv2 = ConvLayer(channels, channels, kernel_size=kernel_size, stride=stride_size) 96 | self.in2 = torch.nn.InstanceNorm2d(channels, affine=True) 97 | self.relu = torch.nn.ReLU() 98 | 99 | def forward(self, x): 100 | residual = x 101 | out = self.relu(self.in1(self.conv1(x))) 102 | out = self.in2(self.conv2(out)) 103 | return out + residual # modification: no ReLu after the addition 104 | 105 | 106 | class UpsampleConvLayer(torch.nn.Module): 107 | """ 108 | Nearest-neighbor up-sampling followed by a convolution 109 | Appears to give better results than learned up-sampling aka transposed conv (avoids the checkerboard artifact) 110 | 111 | Initially proposed on distill pub: http://distill.pub/2016/deconv-checkerboard/ 112 | """ 113 | 114 | def __init__(self, in_channels, out_channels, kernel_size, stride): 115 | super().__init__() 116 | self.upsampling_factor = stride 117 | self.conv2d = ConvLayer(in_channels, out_channels, kernel_size, stride=1) 118 | 119 | def forward(self, x): 120 | if self.upsampling_factor > 1: 121 | x = torch.nn.functional.interpolate(x, scale_factor=self.upsampling_factor, mode='nearest') 122 | return self.conv2d(x) 123 | 124 | -------------------------------------------------------------------------------- /stylization_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | import utils.utils as utils 10 | from models.definitions.transformer_net import TransformerNet 11 | 12 | 13 | def stylize_static_image(inference_config): 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | # Prepare the model - load the weights and put the model into evaluation mode 17 | stylization_model = TransformerNet().to(device) 18 | training_state = torch.load(os.path.join(inference_config["model_binaries_path"], inference_config["model_name"])) 19 | state_dict = training_state["state_dict"] 20 | stylization_model.load_state_dict(state_dict, strict=True) 21 | stylization_model.eval() 22 | 23 | if inference_config['verbose']: 24 | utils.print_model_metadata(training_state) 25 | 26 | with torch.no_grad(): 27 | if os.path.isdir(inference_config['content_input']): # do a batch stylization (every image in the directory) 28 | img_dataset = utils.SimpleDataset(inference_config['content_input'], inference_config['img_width']) 29 | img_loader = DataLoader(img_dataset, batch_size=inference_config['batch_size']) 30 | 31 | try: 32 | processed_imgs_cnt = 0 33 | for batch_id, img_batch in enumerate(img_loader): 34 | processed_imgs_cnt += len(img_batch) 35 | if inference_config['verbose']: 36 | print(f'Processing batch {batch_id + 1} ({processed_imgs_cnt}/{len(img_dataset)} processed images).') 37 | 38 | img_batch = img_batch.to(device) 39 | stylized_imgs = stylization_model(img_batch).to('cpu').numpy() 40 | 
for stylized_img in stylized_imgs: 41 | utils.save_and_maybe_display_image(inference_config, stylized_img, should_display=False) 42 | except Exception as e: 43 | print(e) 44 | print(f'Consider making the batch_size (current = {inference_config["batch_size"]} images) or img_width (current = {inference_config["img_width"]} px) smaller') 45 | exit(1) 46 | 47 | else: # do stylization for a single image 48 | content_img_path = os.path.join(inference_config['content_images_path'], inference_config['content_input']) 49 | content_image = utils.prepare_img(content_img_path, inference_config['img_width'], device) 50 | stylized_img = stylization_model(content_image).to('cpu').numpy()[0] 51 | utils.save_and_maybe_display_image(inference_config, stylized_img, should_display=inference_config['should_not_display']) 52 | 53 | 54 | if __name__ == "__main__": 55 | # 56 | # Fixed args - don't change these unless you have a good reason 57 | # 58 | content_images_path = os.path.join(os.path.dirname(__file__), 'data', 'content-images') 59 | output_images_path = os.path.join(os.path.dirname(__file__), 'data', 'output-images') 60 | model_binaries_path = os.path.join(os.path.dirname(__file__), 'models', 'binaries') 61 | 62 | assert utils.dir_contains_only_models(model_binaries_path), f'Model directory should contain only model binaries.' 63 | os.makedirs(output_images_path, exist_ok=True) 64 | 65 | # 66 | # Modifiable args - feel free to play with these 67 | # 68 | parser = argparse.ArgumentParser() 69 | # Put image name or directory containing images (if you'd like to do a batch stylization on all those images) 70 | parser.add_argument("--content_input", type=str, help="Content image(s) to stylize", default='taj_mahal.jpg') 71 | parser.add_argument("--batch_size", type=int, help="Batch size used only if you set content_input to a directory", default=5) 72 | parser.add_argument("--img_width", type=int, help="Resize content image to this width", default=500) 73 | parser.add_argument("--model_name", type=str, help="Model binary to use for stylization", default='mosaic_4e5_e2.pth') 74 | 75 | # Less frequently used arguments 76 | parser.add_argument("--should_not_display", action='store_false', help="Should display the stylized result") 77 | parser.add_argument("--verbose", action='store_true', help="Print model metadata (how the model was trained) and where the resulting stylized image was saved") 78 | parser.add_argument("--redirected_output", type=str, help="Overwrite default output dir. 
Useful when this project is used as a submodule", default=None) 79 | args = parser.parse_args() 80 | 81 | # if redirected output is not set when doing batch stylization set to default image output location 82 | if os.path.isdir(args.content_input) and args.redirected_output is None: 83 | args.redirected_output = output_images_path 84 | 85 | # Wrapping inference configuration into a dictionary 86 | inference_config = dict() 87 | for arg in vars(args): 88 | inference_config[arg] = getattr(args, arg) 89 | inference_config['content_images_path'] = content_images_path 90 | inference_config['output_images_path'] = output_images_path 91 | inference_config['model_binaries_path'] = model_binaries_path 92 | 93 | stylize_static_image(inference_config) 94 | -------------------------------------------------------------------------------- /training_script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | 5 | import torch 6 | from torch.optim import Adam 7 | from torch.utils.tensorboard import SummaryWriter 8 | import numpy as np 9 | 10 | from models.definitions.perceptual_loss_net import PerceptualLossNet 11 | from models.definitions.transformer_net import TransformerNet 12 | import utils.utils as utils 13 | 14 | 15 | def train(training_config): 16 | writer = SummaryWriter() # (tensorboard) writer will output to ./runs/ directory by default 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | # prepare data loader 20 | train_loader = utils.get_training_data_loader(training_config) 21 | 22 | # prepare neural networks 23 | transformer_net = TransformerNet().train().to(device) 24 | perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device) 25 | 26 | optimizer = Adam(transformer_net.parameters()) 27 | 28 | # Calculate style image's Gram matrices (style representation) 29 | # Built over feature maps as produced by the perceptual net - VGG16 30 | style_img_path = os.path.join(training_config['style_images_path'], training_config['style_img_name']) 31 | style_img = utils.prepare_img(style_img_path, target_shape=None, device=device, batch_size=training_config['batch_size']) 32 | style_img_set_of_feature_maps = perceptual_loss_net(style_img) 33 | target_style_representation = [utils.gram_matrix(x) for x in style_img_set_of_feature_maps] 34 | 35 | utils.print_header(training_config) 36 | # Tracking loss metrics, NST is ill-posed we can only track loss and visual appearance of the stylized images 37 | acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.] 
38 | ts = time.time() 39 | for epoch in range(training_config['num_of_epochs']): 40 | for batch_id, (content_batch, _) in enumerate(train_loader): 41 | # step1: Feed content batch through transformer net 42 | content_batch = content_batch.to(device) 43 | stylized_batch = transformer_net(content_batch) 44 | 45 | # step2: Feed content and stylized batch through perceptual net (VGG16) 46 | content_batch_set_of_feature_maps = perceptual_loss_net(content_batch) 47 | stylized_batch_set_of_feature_maps = perceptual_loss_net(stylized_batch) 48 | 49 | # step3: Calculate content representations and content loss 50 | target_content_representation = content_batch_set_of_feature_maps.relu2_2 51 | current_content_representation = stylized_batch_set_of_feature_maps.relu2_2 52 | content_loss = training_config['content_weight'] * torch.nn.MSELoss(reduction='mean')(target_content_representation, current_content_representation) 53 | 54 | # step4: Calculate style representation and style loss 55 | style_loss = 0.0 56 | current_style_representation = [utils.gram_matrix(x) for x in stylized_batch_set_of_feature_maps] 57 | for gram_gt, gram_hat in zip(target_style_representation, current_style_representation): 58 | style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt, gram_hat) 59 | style_loss /= len(target_style_representation) 60 | style_loss *= training_config['style_weight'] 61 | 62 | # step5: Calculate total variation loss - enforces image smoothness 63 | tv_loss = training_config['tv_weight'] * utils.total_variation(stylized_batch) 64 | 65 | # step6: Combine losses and do a backprop 66 | total_loss = content_loss + style_loss + tv_loss 67 | total_loss.backward() 68 | optimizer.step() 69 | 70 | optimizer.zero_grad() # clear gradients for the next round 71 | 72 | # 73 | # Logging and checkpoint creation 74 | # 75 | acc_content_loss += content_loss.item() 76 | acc_style_loss += style_loss.item() 77 | acc_tv_loss += tv_loss.item() 78 | 79 | if training_config['enable_tensorboard']: 80 | # log scalars 81 | writer.add_scalar('Loss/content-loss', content_loss.item(), len(train_loader) * epoch + batch_id + 1) 82 | writer.add_scalar('Loss/style-loss', style_loss.item(), len(train_loader) * epoch + batch_id + 1) 83 | writer.add_scalar('Loss/tv-loss', tv_loss.item(), len(train_loader) * epoch + batch_id + 1) 84 | writer.add_scalars('Statistics/min-max-mean-median', {'min': torch.min(stylized_batch), 'max': torch.max(stylized_batch), 'mean': torch.mean(stylized_batch), 'median': torch.median(stylized_batch)}, len(train_loader) * epoch + batch_id + 1) 85 | # log stylized image 86 | if batch_id % training_config['image_log_freq'] == 0: 87 | stylized = utils.post_process_image(stylized_batch[0].detach().to('cpu').numpy()) 88 | stylized = np.moveaxis(stylized, 2, 0) # writer expects channel first image 89 | writer.add_image('stylized_img', stylized, len(train_loader) * epoch + batch_id + 1) 90 | 91 | if training_config['console_log_freq'] is not None and batch_id % training_config['console_log_freq'] == 0: 92 | print(f'time elapsed={(time.time()-ts)/60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}') 93 | acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.] 
94 | 95 | if training_config['checkpoint_freq'] is not None and (batch_id + 1) % training_config['checkpoint_freq'] == 0: 96 | training_state = utils.get_training_metadata(training_config) 97 | training_state["state_dict"] = transformer_net.state_dict() 98 | training_state["optimizer_state"] = optimizer.state_dict() 99 | ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth" 100 | torch.save(training_state, os.path.join(training_config['checkpoints_path'], ckpt_model_name)) 101 | 102 | # 103 | # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc. 104 | # 105 | training_state = utils.get_training_metadata(training_config) 106 | training_state["state_dict"] = transformer_net.state_dict() 107 | training_state["optimizer_state"] = optimizer.state_dict() 108 | model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth" 109 | torch.save(training_state, os.path.join(training_config['model_binaries_path'], model_name)) 110 | 111 | 112 | if __name__ == "__main__": 113 | # 114 | # Fixed args - don't change these unless you have a good reason 115 | # 116 | dataset_path = os.path.join(os.path.dirname(__file__), 'data', 'mscoco') 117 | style_images_path = os.path.join(os.path.dirname(__file__), 'data', 'style-images') 118 | model_binaries_path = os.path.join(os.path.dirname(__file__), 'models', 'binaries') 119 | checkpoints_root_path = os.path.join(os.path.dirname(__file__), 'models', 'checkpoints') 120 | image_size = 256 # training images from MS COCO are resized to image_size x image_size 121 | batch_size = 4 122 | 123 | assert os.path.exists(dataset_path), f'MS COCO missing. Download the dataset using resource_downloader.py script.' 
124 | os.makedirs(model_binaries_path, exist_ok=True) 125 | 126 | # 127 | # Modifiable args - feel free to play with these (only a small subset is exposed by design to avoid cluttering) 128 | # 129 | parser = argparse.ArgumentParser() 130 | # training related 131 | parser.add_argument("--style_img_name", type=str, help="style image name that will be used for training", default='edtaonisl.jpg') 132 | parser.add_argument("--content_weight", type=float, help="weight factor for content loss", default=1e0) # you don't need to change this one just play with style loss 133 | parser.add_argument("--style_weight", type=float, help="weight factor for style loss", default=4e5) 134 | parser.add_argument("--tv_weight", type=float, help="weight factor for total variation loss", default=0) 135 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs ", default=2) 136 | parser.add_argument("--subset_size", type=int, help="number of MS COCO images (NOT BATCHES) to use, default is all (~83k)(specified by None)", default=None) 137 | # logging/debugging/checkpoint related (helps a lot with experimentation) 138 | parser.add_argument("--enable_tensorboard", type=bool, help="enable tensorboard logging (scalars + images)", default=True) 139 | parser.add_argument("--image_log_freq", type=int, help="tensorboard image logging (batch) frequency - enable_tensorboard must be True to use", default=100) 140 | parser.add_argument("--console_log_freq", type=int, help="logging to output console (batch) frequency", default=500) 141 | parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (batch) frequency", default=2000) 142 | args = parser.parse_args() 143 | 144 | checkpoints_path = os.path.join(checkpoints_root_path, args.style_img_name.split('.')[0]) 145 | if args.checkpoint_freq is not None: 146 | os.makedirs(checkpoints_path, exist_ok=True) 147 | 148 | # Wrapping training configuration into a dictionary 149 | training_config = dict() 150 | for arg in vars(args): 151 | training_config[arg] = getattr(args, arg) 152 | training_config['dataset_path'] = dataset_path 153 | training_config['style_images_path'] = style_images_path 154 | training_config['model_binaries_path'] = model_binaries_path 155 | training_config['checkpoints_path'] = checkpoints_path 156 | training_config['image_size'] = image_size 157 | training_config['batch_size'] = batch_size 158 | 159 | # Original J.Johnson's training with improved transformer net architecture 160 | train(training_config) 161 | 162 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | -------------------------------------------------------------------------------- /utils/resource_downloader.py: -------------------------------------------------------------------------------- 1 | import zipfile 2 | from torch.hub import download_url_to_file 3 | import argparse 4 | import os 5 | 6 | # If the link is broken you can download the MS COCO 2014 dataset manually from http://cocodataset.org/#download 7 | MS_COCO_2014_TRAIN_DATASET_PATH = r'http://images.cocodataset.org/zips/train2014.zip' # ~13 GB after unzipping 8 | 9 | PRETRAINED_MODELS_PATH = r'https://www.dropbox.com/s/fb39gscd1b42px1/pretrained_models.zip?dl=1' 10 | 11 | DOWNLOAD_DICT = { 12 | 'pretrained_models': PRETRAINED_MODELS_PATH, 13 | 'mscoco_dataset': MS_COCO_2014_TRAIN_DATASET_PATH, 14 | } 15 | download_choices = 
list(DOWNLOAD_DICT.keys()) 16 | 17 | 18 | if __name__ == '__main__': 19 | # 20 | # Choose whether you want to download pretrained models or MSCOCO 2014 dataset 21 | # 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--resource", "-r", type=str, choices=download_choices, 24 | help="specify whether you want to download ms coco dataset or pretrained models", 25 | default=download_choices[0]) 26 | args = parser.parse_args() 27 | 28 | # step1: download the resource to local filesystem 29 | remote_resource_path = DOWNLOAD_DICT[args.resource] 30 | print(f'Downloading from {remote_resource_path}') 31 | resource_tmp_path = args.resource + '.zip' 32 | download_url_to_file(remote_resource_path, resource_tmp_path) 33 | 34 | # step2: unzip the resource 35 | print(f'Started unzipping...') 36 | with zipfile.ZipFile(resource_tmp_path) as zf: 37 | local_resource_path = os.path.join(os.path.dirname(__file__), os.pardir) 38 | if args.resource == 'pretrained_models': 39 | local_resource_path = os.path.join(local_resource_path, 'models', 'binaries') 40 | else: 41 | local_resource_path = os.path.join(local_resource_path, 'data', 'mscoco') 42 | os.makedirs(local_resource_path, exist_ok=True) 43 | zf.extractall(path=local_resource_path) 44 | print(f'Unzipping to: {local_resource_path} finished.') 45 | 46 | # step3: remove the temporary resource file 47 | os.remove(resource_tmp_path) 48 | print(f'Removing tmp file {resource_tmp_path}.') 49 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | 5 | import cv2 as cv 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from torchvision import transforms 9 | from torchvision import datasets 10 | from torch.utils.data import Dataset, DataLoader, Sampler 11 | import torch 12 | import git 13 | 14 | 15 | IMAGENET_MEAN_1 = np.array([0.485, 0.456, 0.406]) 16 | IMAGENET_STD_1 = np.array([0.229, 0.224, 0.225]) 17 | IMAGENET_MEAN_255 = np.array([123.675, 116.28, 103.53]) 18 | # Usually when normalizing 0..255 images only mean-normalization is performed -> that's why standard dev is all 1s here 19 | IMAGENET_STD_NEUTRAL = np.array([1, 1, 1]) 20 | 21 | 22 | class SimpleDataset(Dataset): 23 | def __init__(self, img_dir, target_width): 24 | self.img_dir = img_dir 25 | self.img_paths = [os.path.join(img_dir, img_name) for img_name in os.listdir(img_dir)] 26 | 27 | h, w = load_image(self.img_paths[0]).shape[:2] 28 | img_height = int(h * (target_width / w)) 29 | self.target_width = target_width 30 | self.target_height = img_height 31 | 32 | self.transform = transforms.Compose([ 33 | transforms.ToTensor(), 34 | transforms.Normalize(mean=IMAGENET_MEAN_1, std=IMAGENET_STD_1) 35 | ]) 36 | 37 | def __len__(self): 38 | return len(self.img_paths) 39 | 40 | def __getitem__(self, idx): 41 | img = load_image(self.img_paths[idx], target_shape=(self.target_height, self.target_width)) 42 | tensor = self.transform(img) 43 | return tensor 44 | 45 | 46 | def load_image(img_path, target_shape=None): 47 | if not os.path.exists(img_path): 48 | raise Exception(f'Path does not exist: {img_path}') 49 | img = cv.imread(img_path)[:, :, ::-1] # [:, :, ::-1] converts BGR (opencv format...) 
into RGB 50 | 51 | if target_shape is not None: # resize section 52 | if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the width 53 | current_height, current_width = img.shape[:2] 54 | new_width = target_shape 55 | new_height = int(current_height * (new_width / current_width)) 56 | img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC) 57 | else: # set both dimensions to target shape 58 | img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC) 59 | 60 | # this need to go after resizing - otherwise cv.resize will push values outside of [0,1] range 61 | img = img.astype(np.float32) # convert from uint8 to float32 62 | img /= 255.0 # get to [0, 1] range 63 | return img 64 | 65 | 66 | def prepare_img(img_path, target_shape, device, batch_size=1, should_normalize=True, is_255_range=False): 67 | img = load_image(img_path, target_shape=target_shape) 68 | 69 | transform_list = [transforms.ToTensor()] 70 | if is_255_range: 71 | transform_list.append(transforms.Lambda(lambda x: x.mul(255))) 72 | if should_normalize: 73 | transform_list.append(transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL) if is_255_range else transforms.Normalize(mean=IMAGENET_MEAN_1, std=IMAGENET_STD_1)) 74 | transform = transforms.Compose(transform_list) 75 | 76 | img = transform(img).to(device) 77 | img = img.repeat(batch_size, 1, 1, 1) 78 | 79 | return img 80 | 81 | 82 | def post_process_image(dump_img): 83 | assert isinstance(dump_img, np.ndarray), f'Expected numpy image got {type(dump_img)}' 84 | 85 | mean = IMAGENET_MEAN_1.reshape(-1, 1, 1) 86 | std = IMAGENET_STD_1.reshape(-1, 1, 1) 87 | dump_img = (dump_img * std) + mean # de-normalize 88 | dump_img = (np.clip(dump_img, 0., 1.) * 255).astype(np.uint8) 89 | dump_img = np.moveaxis(dump_img, 0, 2) 90 | return dump_img 91 | 92 | 93 | def get_next_available_name(input_dir): 94 | img_name_pattern = re.compile(r'[0-9]{6}\.jpg') 95 | candidates = [candidate for candidate in os.listdir(input_dir) if re.fullmatch(img_name_pattern, candidate)] 96 | 97 | if len(candidates) == 0: 98 | return '000000.jpg' 99 | else: 100 | latest_file = sorted(candidates)[-1] 101 | prefix_int = int(latest_file.split('.')[0]) 102 | return f'{str(prefix_int + 1).zfill(6)}.jpg' 103 | 104 | 105 | def save_and_maybe_display_image(inference_config, dump_img, should_display=False): 106 | assert isinstance(dump_img, np.ndarray), f'Expected numpy array got {type(dump_img)}.' 107 | 108 | dump_img = post_process_image(dump_img) 109 | if inference_config['img_width'] is None: 110 | inference_config['img_width'] = dump_img.shape[0] 111 | 112 | if inference_config['redirected_output'] is None: 113 | dump_dir = inference_config['output_images_path'] 114 | dump_img_name = os.path.basename(inference_config['content_input']).split('.')[0] + '_width_' + str(inference_config['img_width']) + '_model_' + inference_config['model_name'].split('.')[0] + '.jpg' 115 | else: # useful when this repo is used as a utility submodule in some other repo like pytorch-naive-video-nst 116 | dump_dir = inference_config['redirected_output'] 117 | os.makedirs(dump_dir, exist_ok=True) 118 | dump_img_name = get_next_available_name(inference_config['redirected_output']) 119 | 120 | cv.imwrite(os.path.join(dump_dir, dump_img_name), dump_img[:, :, ::-1]) # ::-1 because opencv expects BGR (and not RGB) format... 
121 | 122 | # Don't print this information in batch stylization mode 123 | if inference_config['verbose'] and not os.path.isdir(inference_config['content_input']): 124 | print(f'Saved image to {dump_dir}.') 125 | 126 | if should_display: 127 | plt.imshow(dump_img) 128 | plt.show() 129 | 130 | 131 | class SequentialSubsetSampler(Sampler): 132 | r"""Samples elements sequentially, always in the same order from a subset defined by size. 133 | 134 | Arguments: 135 | data_source (Dataset): dataset to sample from 136 | subset_size: defines the subset from which to sample from 137 | """ 138 | 139 | def __init__(self, data_source, subset_size): 140 | assert isinstance(data_source, Dataset) or isinstance(data_source, datasets.ImageFolder) 141 | self.data_source = data_source 142 | 143 | if subset_size is None: # if None -> use the whole dataset 144 | subset_size = len(data_source) 145 | assert 0 < subset_size <= len(data_source), f'Subset size should be between (0, {len(data_source)}].' 146 | self.subset_size = subset_size 147 | 148 | def __iter__(self): 149 | return iter(range(self.subset_size)) 150 | 151 | def __len__(self): 152 | return self.subset_size 153 | 154 | 155 | def get_training_data_loader(training_config, should_normalize=True, is_255_range=False): 156 | """ 157 | There are multiple ways to make this feed-forward NST working, 158 | including using 0..255 range (without any normalization) images during transformer net training, 159 | keeping the options if somebody wants to play and get better results. 160 | """ 161 | transform_list = [transforms.Resize(training_config['image_size']), 162 | transforms.CenterCrop(training_config['image_size']), 163 | transforms.ToTensor()] 164 | if is_255_range: 165 | transform_list.append(transforms.Lambda(lambda x: x.mul(255))) 166 | if should_normalize: 167 | transform_list.append(transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL) if is_255_range else transforms.Normalize(mean=IMAGENET_MEAN_1, std=IMAGENET_STD_1)) 168 | transform = transforms.Compose(transform_list) 169 | 170 | train_dataset = datasets.ImageFolder(training_config['dataset_path'], transform) 171 | sampler = SequentialSubsetSampler(train_dataset, training_config['subset_size']) 172 | training_config['subset_size'] = len(sampler) # update in case it was None 173 | train_loader = DataLoader(train_dataset, batch_size=training_config['batch_size'], sampler=sampler, drop_last=True) 174 | print(f'Using {len(train_loader)*training_config["batch_size"]*training_config["num_of_epochs"]} datapoints ({len(train_loader)*training_config["num_of_epochs"]} batches) (MS COCO images) for transformer network training.') 175 | return train_loader 176 | 177 | 178 | def gram_matrix(x, should_normalize=True): 179 | (b, ch, h, w) = x.size() 180 | features = x.view(b, ch, w * h) 181 | features_t = features.transpose(1, 2) 182 | gram = features.bmm(features_t) 183 | if should_normalize: 184 | gram /= ch * h * w 185 | return gram 186 | 187 | 188 | # Not used atm, you'd want to use this if you choose to go with 0..255 images in the training loader 189 | def normalize_batch(batch): 190 | batch /= 255.0 191 | mean = batch.new_tensor(IMAGENET_MEAN_1).view(-1, 1, 1) 192 | std = batch.new_tensor(IMAGENET_STD_1).view(-1, 1, 1) 193 | return (batch - mean) / std 194 | 195 | 196 | def total_variation(img_batch): 197 | batch_size = img_batch.shape[0] 198 | return (torch.sum(torch.abs(img_batch[:, :, :, :-1] - img_batch[:, :, :, 1:])) + 199 | torch.sum(torch.abs(img_batch[:, :, :-1, :] - img_batch[:, :, 
1:, :]))) / batch_size 200 | 201 | 202 | def print_header(training_config): 203 | print(f'Learning the style of {training_config["style_img_name"]} style image.') 204 | print('*' * 80) 205 | print(f'Hyperparams: content_weight={training_config["content_weight"]}, style_weight={training_config["style_weight"]} and tv_weight={training_config["tv_weight"]}') 206 | print('*' * 80) 207 | 208 | if training_config["console_log_freq"]: 209 | print(f'Logging to console every {training_config["console_log_freq"]} batches.') 210 | else: 211 | print(f'Console logging disabled. Change console_log_freq if you want to use it.') 212 | 213 | if training_config["checkpoint_freq"]: 214 | print(f'Saving checkpoint models every {training_config["checkpoint_freq"]} batches.') 215 | else: 216 | print(f'Checkpoint models saving disabled.') 217 | 218 | if training_config['enable_tensorboard']: 219 | print('Tensorboard enabled.') 220 | print('Run "tensorboard --logdir=runs --samples_per_plugin images=50" from your conda env') 221 | print('Open http://localhost:6006/ in your browser and you\'re ready to use tensorboard!') 222 | else: 223 | print('Tensorboard disabled.') 224 | print('*' * 80) 225 | 226 | 227 | def get_training_metadata(training_config): 228 | num_of_datapoints = training_config['subset_size'] * training_config['num_of_epochs'] 229 | training_metadata = { 230 | "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha, 231 | "content_weight": training_config['content_weight'], 232 | "style_weight": training_config['style_weight'], 233 | "tv_weight": training_config['tv_weight'], 234 | "num_of_datapoints": num_of_datapoints 235 | } 236 | return training_metadata 237 | 238 | 239 | def print_model_metadata(training_state): 240 | print('Model training metadata:') 241 | for key, value in training_state.items(): 242 | if key != 'state_dict' and key != 'optimizer_state': 243 | print(key, ':', value) 244 | 245 | 246 | def dir_contains_only_models(path): 247 | assert os.path.exists(path), f'Provided path: {path} does not exist.' 248 | assert os.path.isdir(path), f'Provided path: {path} is not a directory.' 249 | list_of_files = os.listdir(path) 250 | assert len(list_of_files) > 0, f'No models found, use training_script.py to train a model or download pretrained models via resource_downloader.py.' 251 | for f in list_of_files: 252 | if not (f.endswith('.pt') or f.endswith('.pth')): 253 | return False 254 | 255 | return True 256 | 257 | 258 | # Count how many trainable weights the model has <- just for having a feeling for how big the model is 259 | def count_parameters(model): 260 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 261 | --------------------------------------------------------------------------------