├── .github
│   └── FUNDING.yml
├── .gitignore
├── LICENCE
├── README.md
├── data
│   ├── content-images
│   │   ├── figures.jpg
│   │   ├── lion.jpg
│   │   └── taj_mahal.jpg
│   ├── examples
│   │   ├── candy_model
│   │   │   ├── figures_width_500_model_candy_resize_230.jpg
│   │   │   └── taj_mahal_width_800_model_candy_resize_230.jpg
│   │   ├── edtaonisl_model
│   │   │   ├── figures_width_500_model_edtaonisl_9e5_33k_resized_230.jpg
│   │   │   └── taj_mahal_width_500_model_edtaonisl_9e5_33k_resized_230.jpg
│   │   ├── figures_width_550_model_mosaic_4e5_e2.jpg
│   │   ├── lion_width_360_model_mosaic_4e5_e2.jpg
│   │   ├── mosaic_model
│   │   │   ├── figures_width_500_model_mosaic_4e5_e2_resized_230.jpg
│   │   │   └── taj_mahal_width_500_model_mosaic_4e5_e2_resized_230.jpg
│   │   ├── readme_pics
│   │   │   ├── loss_curves.PNG
│   │   │   ├── monitor_img1.jpg
│   │   │   ├── monitor_img2.jpg
│   │   │   ├── spike.png
│   │   │   └── statistics.PNG
│   │   ├── starry_night_model
│   │   │   ├── figures_width_500_model_starry_v3_resize_230.jpg
│   │   │   └── taj_mahal_width_500_model_starry_v3_resize_230.jpg
│   │   └── style_images
│   │       ├── candy_resize_230.jpg
│   │       ├── edtaonisl_crop_resized_230.jpg
│   │       ├── mosaic_crop_resized_230.jpg
│   │       └── vg_starry_night_resized_230.jpg
│   └── style-images
│       ├── candy.jpg
│       ├── edtaonisl.jpg
│       ├── mosaic.jpg
│       └── vg_starry_night.jpg
├── environment.yml
├── models
│   └── definitions
│       ├── __init__.py
│       ├── perceptual_loss_net.py
│       └── transformer_net.py
├── stylization_script.py
├── training_script.py
└── utils
    ├── __init__.py
    ├── resource_downloader.py
    └── utils.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | patreon: theaiepiphany
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
--------------------------------------------------------------------------------
/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Aleksa Gordić
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Fast Neural Style Transfer (feed-forward method) :zap::computer: + :art: = :heart:
2 | This repo contains a concise PyTorch implementation of the original feed-forward NST paper (:link: [Johnson et al.](https://arxiv.org/pdf/1603.08155.pdf)).
3 |
4 | Check out my implementation of the original NST (optimization method) paper ([Gatys et al.](https://github.com/gordicaleksa/pytorch-neural-style-transfer)).
5 |
6 | It's an accompanying repo for [this video series on YouTube](https://www.youtube.com/watch?v=S78LQebx6jo&list=PLBoQnSflObcmbfshq9oNs41vODgXG-608).
7 |
8 |
9 |
11 |
12 |
13 | ### Why yet another Fast NST (feed-forward method) repo?
14 | It's the **cleanest and most concise** NST repo that I know of + it's written in **PyTorch!** :heart:
15 |
16 | My idea :bulb: is to make the code so simple and well commented, that you can use it as a **first step on your NST learning journey** before any other blog, course, book or research paper. :books:
17 |
18 | I've included an automatic download script for the pretrained models and the MS COCO dataset - so you can either **instantaneously run it** and get the results (:art: stylized images) using the pretrained models, **or start training/experimenting with your own models**. :rocket:
19 |
20 | ## Examples
21 |
22 | Here are some examples with the [4 pretrained models](https://www.dropbox.com/s/fb39gscd1b42px1/pretrained_models.zip?dl=0) (automatic download enabled - look at [usage section](#usage)):
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 | *Note:* keep in mind that I still need to improve these models - 3 of them (the last 3 rows) only saw 33k images from MS COCO.
43 |
44 | ## Setup
45 |
46 | 1. Open Anaconda Prompt and navigate into the project directory: `cd path_to_repo`
47 | 2. Run `conda env create` from project directory (this will create a brand new conda environment).
48 | 3. Run `activate pytorch-nst-fast` (if you want to run scripts from your console; otherwise set the interpreter in your IDE)
49 |
50 | That's it! It should work out of the box - the environment.yml file takes care of the dependencies.
51 |
52 | -----
53 |
54 | The PyTorch package will pull some version of CUDA with it, but it is highly recommended that you install system-wide CUDA beforehand, mostly because of GPU drivers. I also recommend using the Miniconda installer to get conda onto your system.
55 |
56 | Follow through points 1 and 2 of [this setup](https://github.com/Petlja/PSIML/blob/master/docs/MachineSetup.md) and use the most up-to-date versions of Miniconda and CUDA/cuDNN (I recommend CUDA 10.1 or 10.2 as those are compatible with PyTorch 1.5, which is used in this repo, and newest compatible cuDNN).
57 |
58 | ## Usage
59 |
60 | Go through this section to run the project, but if you are still having problems take a look at [this (stylization)](https://www.youtube.com/watch?v=lOR-LncQlk8&list=PLBoQnSflObcmbfshq9oNs41vODgXG-608&index=5) and [this (training)](https://www.youtube.com/watch?v=EuXd-aO77A0&list=PLBoQnSflObcmbfshq9oNs41vODgXG-608&index=6) accompanying YouTube videos.
61 |
62 | ### Stylization
63 |
64 | 1. Download the pretrained models: run `python utils/resource_downloader.py`
65 | 2. Run `python stylization_script.py` (it's got a default content image and model set)
66 |
67 | That's it! If you want more flexibility (and I guess you do) there are a couple more nuggets of info.
68 |
69 | A more expressive command is:
70 | `python stylization_script.py --content_input <content image name or directory> --img_width <width> --model_name <model binary name>`
71 |
72 | If you pass a directory to `--content_input` it will perform batch stylization.
73 | You can control the batch size (in case you run into VRAM problems) with the `--batch_size` param.
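If you prefer calling it from Python (e.g. from a notebook), here's a minimal sketch that mirrors what the CLI does. It assumes you run it from the repo root and that you've already downloaded the pretrained models; the config keys are the same ones `stylization_script.py` sets up:

```python
import os

from stylization_script import stylize_static_image

repo_root = os.getcwd()  # assumption: you're running this from the repo root

inference_config = {
    # default locations used by the repo
    'content_images_path': os.path.join(repo_root, 'data', 'content-images'),
    'output_images_path': os.path.join(repo_root, 'data', 'output-images'),
    'model_binaries_path': os.path.join(repo_root, 'models', 'binaries'),
    # the same knobs as the CLI args above
    'content_input': 'taj_mahal.jpg',   # or a directory -> batch stylization
    'img_width': 500,
    'model_name': 'mosaic_4e5_e2.pth',
    'batch_size': 5,                     # only relevant when content_input is a directory
    # less frequently used args
    'should_not_display': False,         # passed as should_display for single-image runs
    'verbose': True,
    'redirected_output': None,           # set to a dir if you want the output dumped elsewhere
}

os.makedirs(inference_config['output_images_path'], exist_ok=True)
stylize_static_image(inference_config)
```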
74 |
75 | -----
76 |
77 | You just need to specify the **names** - the repo automatically finds content images and models in the default directories:
78 | 1. content images default dir: `/data/content-images/`
79 | 2. model binaries default dir: `/models/binaries/`
80 |
81 | So all **you** have to do is place images and models there and you can use them. Output will be dumped to `/data/output-images/`.
82 |
83 | After you run the resource_downloader.py script, the binaries dir will be pre-populated with 4 pretrained models.
84 |
85 | Go ahead, play with it and make some art!
86 |
87 |
88 |
89 |
90 |
91 |
92 | ### Training your own models
93 |
94 | 1. Download the MS COCO dataset - run `python utils/resource_downloader.py -r mscoco_dataset` (it's a 12.5 GB file)
95 | 2. Run `python training_script.py --style_img_name <style image name>`
96 |
97 | Now that will probably actually work!
98 |
99 | It will periodically dump checkpoint models to `/models/checkpoints/` and the final model to `/models/binaries/` by default.
100 |
101 | I strongly recommend playing with these 2 params (a sketch of how they plug into a training run follows this list):
102 | 1. **style_weight** - I always kept it in the [1e5, 9e5] range; you may have to tweak it a little for your specific style image
103 | 2. **subset_size** - Usually 32k images do the job (that's 8k batches) - you'll need to monitor **tensorboard** to figure out whether your curves are saturating at that point or not. If they are still going down, set the number higher
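Here is the promised sketch - a programmatic equivalent of a training run, assuming you're in the repo root and have already downloaded MS COCO via resource_downloader.py. It just mirrors the defaults from `training_script.py`, with `style_weight` and `subset_size` pulled out as the knobs to play with:

```python
import os

from training_script import train

repo_root = os.getcwd()  # assumption: you're running this from the repo root

training_config = {
    # the 2 params worth tweaking first
    'style_weight': 4e5,     # try values in the [1e5, 9e5] range
    'subset_size': 32000,    # number of MS COCO images to use; None -> all (~83k)
    # everything below mirrors the defaults in training_script.py
    'style_img_name': 'edtaonisl.jpg',
    'content_weight': 1e0,
    'tv_weight': 0,
    'num_of_epochs': 2,
    'enable_tensorboard': True,
    'image_log_freq': 100,
    'console_log_freq': 500,
    'checkpoint_freq': 2000,
    'dataset_path': os.path.join(repo_root, 'data', 'mscoco'),
    'style_images_path': os.path.join(repo_root, 'data', 'style-images'),
    'model_binaries_path': os.path.join(repo_root, 'models', 'binaries'),
    'checkpoints_path': os.path.join(repo_root, 'models', 'checkpoints', 'edtaonisl'),
    'image_size': 256,  # MS COCO images get resized/cropped to 256x256 during training
    'batch_size': 4,
}

os.makedirs(training_config['model_binaries_path'], exist_ok=True)
os.makedirs(training_config['checkpoints_path'], exist_ok=True)
train(training_config)
```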
104 |
105 | That brings us to the next section!
106 |
107 | ### Tensorboard Visualizations
108 |
109 | To **start tensorboard** just run: `tensorboard --logdir=runs --samples_per_plugin images=50` from your conda console.
110 |
111 | `samples_per_plugin images=<number>` sets the number of images you'll be able to see when moving the image slider.
112 |
113 | There are basically **2 things you want to monitor** during your training (not counting the console output, which is redundant if you use tensorboard):
114 |
115 | #### Monitor your loss/statistics curves
116 |
117 | You want to keep `content-loss` and `style-loss` going down or at least one of them (style loss usually saturates first).
118 |
119 | I usually set tv weight to 0, so that's why you see 0 on the `tv-loss` curve. You should use it only if you see that your images have smoothness problems ([check this out](https://github.com/gordicaleksa/pytorch-neural-style-transfer#impact-of-total-variation-tv-loss) for a visualization of what exactly tv weight does).
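For reference, this is all the TV term computes (same as `utils.total_variation` further down in this repo) - it penalizes differences between neighboring pixels, which is why cranking up `tv_weight` smooths (and eventually blurs) the output:

```python
import torch

def total_variation(img_batch):
    # sum of absolute differences between horizontally and vertically neighboring pixels,
    # averaged over the batch
    batch_size = img_batch.shape[0]
    return (torch.sum(torch.abs(img_batch[:, :, :, :-1] - img_batch[:, :, :, 1:])) +
            torch.sum(torch.abs(img_batch[:, :, :-1, :] - img_batch[:, :, 1:, :]))) / batch_size

# in the training loop this enters the total loss scaled by tv_weight (0 disables it):
# tv_loss = training_config['tv_weight'] * total_variation(stylized_batch)
```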
120 |
121 |
122 |
123 |
124 |
125 | The statistics curves let me understand how the stylized images coming out of the transformer net behave.
126 |
127 |
128 |
129 |
130 |
131 | If max or min intensities start diverging, or the mean/median start drifting too far away from 0, that's a good indicator that your style weight is (probably) not good. You can keep the content weight constant and just tweak the style weight.
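These curves come from a single `add_scalars` call in `training_script.py`; here's a tiny self-contained sketch of the same logging (the random tensor just stands in for a real stylized batch):

```python
import torch
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # logs to ./runs/ by default

stylized_batch = torch.randn(4, 3, 256, 256)  # stand-in for one transformer net output batch
global_step = 1

# same per-batch statistics as in training_script.py - if min/max diverge or the
# mean/median drift far from 0, the style weight is probably off
writer.add_scalars('Statistics/min-max-mean-median', {
    'min': torch.min(stylized_batch),
    'max': torch.max(stylized_batch),
    'mean': torch.mean(stylized_batch),
    'median': torch.median(stylized_batch),
}, global_step)
```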
132 |
133 | #### Monitor your intermediate stylized images
134 |
135 | This one helps immensely - it lets you manually early-stop your training if you don't like the stylized output you're seeing.
136 |
137 |
138 |
139 |
140 |
141 |
142 | In the beginning, stylized images look kind of rubbish, like the one on the left. As the training progresses, you'll get more meaningful images popping out (like the one on the right).
143 |
144 | ## Debugging
145 | Q: My style/content loss curves just spiked in the middle of training - what now?
146 | A: There are 2 options: a) rerun the training (the optimizer got into a bad state), b) if that doesn't work, lower your style weight
147 |
148 |
149 |
150 |
151 |
152 | Q: How can I see the exact parameters that you used to train your models?
153 | A: Just run the model via `stylization_script.py` with the `--verbose` flag - the training metadata will be printed to the console.
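If you'd rather not run a stylization at all, you can also inspect a binary directly - a small sketch mirroring what `utils.print_model_metadata` does (the model name below is just an example; point it at whatever is in `models/binaries/`):

```python
import os
import torch

model_path = os.path.join('models', 'binaries', 'mosaic_4e5_e2.pth')  # assumed name/location
training_state = torch.load(model_path, map_location='cpu')

# everything except the weights is training metadata (commit hash, loss weights, etc.)
for key, value in training_state.items():
    if key not in ('state_dict', 'optimizer_state'):
        print(key, ':', value)
```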
154 |
155 | ## Further experimentation (advanced, for researchers)
156 |
157 | There are a couple of things you could experiment with (assuming fixed net architectures); here are some ideas:
158 | 1. Try setting MSE to `sum` reduction for the style loss. I used that method [here](https://github.com/gordicaleksa/pytorch-neural-style-transfer/blob/master/neural_style_transfer.py) and it gave nice results. You'll have to play with the style weight afterwards to get it running. This will effectively give bigger weight to deeper style representations, because Gram matrices coming out of deeper layers are bigger - meaning you'll favor high-level style features (broad spatial characteristics of the style) over low-level style features (smaller neighborhood characteristics like fine brush strokes). (See the sketch after this list.)
159 | 2. The original paper used a `tanh` activation at the output - figure out how you can get it to work with that; you may have to add some scaling. There is this magic constant 150 that Johnson originally used to scale tanh. I created [this issue](https://github.com/jcjohnson/fast-neural-style/issues/186) as it is unclear how it came to be and whether it was just figured out experimentally. (See the sketch after this list.)
160 | 3. The PyTorch pretrained VGG16 model was trained on 0..1 range, ImageNet-normalized images. Try working with 0..255 range, ImageNet mean-only-normalized images - that will also work! It worked [here](https://github.com/gordicaleksa/pytorch-neural-style-transfer/blob/master/utils/utils.py), and if you feed such an image into VGG16 (as a classifier) it will still give you correct predictions!
161 | 4. [This repo](https://github.com/pytorch/examples/tree/master/fast_neural_style) used 0..255 images (no normalization) as input to the transformer net - play with that. You'll have to normalize the transformer net output before feeding it to VGG16.
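For ideas 1 and 2, here's a rough sketch of where the changes would go - hedged, untested suggestions rather than what the shipped models were trained with:

```python
import torch

# Idea 1: 'sum' reduction for the style loss (training_script.py, step 4, currently uses 'mean').
# Deeper layers produce bigger Gram matrices, so they automatically contribute more to the loss.
def style_loss_sum(target_style_representation, current_style_representation, style_weight):
    style_loss = 0.0
    for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
        style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt, gram_hat)
    style_loss /= len(target_style_representation)
    return style_weight * style_loss  # expect to re-tune style_weight after this change

# Idea 2: scaled tanh at the output of TransformerNet.forward - the last line would become
# something like:
#     return 150 * torch.tanh(self.up3(y))
# where 150 is the magic constant Johnson used; you may need to experiment with it.
```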
162 |
163 | Some of these may further improve the visual quality that you get from these models! If you find something interesting I'd like to hear from you!
164 |
165 | ## Acknowledgements
166 |
167 | I found these repos useful (while developing this one):
168 | * [fast_neural_style](https://github.com/pytorch/examples/tree/master/fast_neural_style) (PyTorch, feed-forward method)
169 | * [pytorch-neural-style-transfer](https://github.com/gordicaleksa/pytorch-neural-style-transfer) (PyTorch, optimization method)
170 | * [original J.Johnson's repo](https://github.com/jcjohnson/fast-neural-style) (Torch, written in Lua)
171 |
172 | I found some of the content/style images I was using here:
173 | * [style/artistic images](https://www.rawpixel.com/board/537381/vincent-van-gogh-free-original-public-domain-paintings?sort=curated&mode=shop&page=1)
174 | * [awesome figures pic](https://www.pexels.com/photo/action-android-device-electronics-595804/)
175 |
176 | Other images are now already classics in the NST world.
177 |
178 | ## Citation
179 |
180 | If you find this code useful for your research, please cite the following:
181 |
182 | ```
183 | @misc{Gordić2020nst-fast,
184 | author = {Gordić, Aleksa},
185 | title = {pytorch-nst-feedforward},
186 | year = {2020},
187 | publisher = {GitHub},
188 | journal = {GitHub repository},
189 | howpublished = {\url{https://github.com/gordicaleksa/pytorch-nst-feedforward}},
190 | }
191 | ```
192 |
193 | ## Connect with me
194 |
195 | If you'd love to have some more AI-related content in your life :nerd_face:, consider:
196 | * Subscribing to my YouTube channel [The AI Epiphany](https://www.youtube.com/c/TheAiEpiphany) :bell:
197 | * Following me on [LinkedIn](https://www.linkedin.com/in/aleksagordic/) and [Twitter](https://twitter.com/gordic_aleksa) :bulb:
198 | * Following me on [Medium](https://gordicaleksa.medium.com/) :books: :heart:
199 |
200 | ## Licence
201 |
202 | [](https://github.com/gordicaleksa/pytorch-nst-feedforward/blob/master/LICENCE)
--------------------------------------------------------------------------------
/data/content-images/figures.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/content-images/figures.jpg
--------------------------------------------------------------------------------
/data/content-images/lion.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/content-images/lion.jpg
--------------------------------------------------------------------------------
/data/content-images/taj_mahal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/content-images/taj_mahal.jpg
--------------------------------------------------------------------------------
/data/examples/candy_model/figures_width_500_model_candy_resize_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/candy_model/figures_width_500_model_candy_resize_230.jpg
--------------------------------------------------------------------------------
/data/examples/candy_model/taj_mahal_width_800_model_candy_resize_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/candy_model/taj_mahal_width_800_model_candy_resize_230.jpg
--------------------------------------------------------------------------------
/data/examples/edtaonisl_model/figures_width_500_model_edtaonisl_9e5_33k_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/edtaonisl_model/figures_width_500_model_edtaonisl_9e5_33k_resized_230.jpg
--------------------------------------------------------------------------------
/data/examples/edtaonisl_model/taj_mahal_width_500_model_edtaonisl_9e5_33k_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/edtaonisl_model/taj_mahal_width_500_model_edtaonisl_9e5_33k_resized_230.jpg
--------------------------------------------------------------------------------
/data/examples/figures_width_550_model_mosaic_4e5_e2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/figures_width_550_model_mosaic_4e5_e2.jpg
--------------------------------------------------------------------------------
/data/examples/lion_width_360_model_mosaic_4e5_e2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/lion_width_360_model_mosaic_4e5_e2.jpg
--------------------------------------------------------------------------------
/data/examples/mosaic_model/figures_width_500_model_mosaic_4e5_e2_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/mosaic_model/figures_width_500_model_mosaic_4e5_e2_resized_230.jpg
--------------------------------------------------------------------------------
/data/examples/mosaic_model/taj_mahal_width_500_model_mosaic_4e5_e2_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/mosaic_model/taj_mahal_width_500_model_mosaic_4e5_e2_resized_230.jpg
--------------------------------------------------------------------------------
/data/examples/readme_pics/loss_curves.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/loss_curves.PNG
--------------------------------------------------------------------------------
/data/examples/readme_pics/monitor_img1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/monitor_img1.jpg
--------------------------------------------------------------------------------
/data/examples/readme_pics/monitor_img2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/monitor_img2.jpg
--------------------------------------------------------------------------------
/data/examples/readme_pics/spike.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/spike.png
--------------------------------------------------------------------------------
/data/examples/readme_pics/statistics.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/readme_pics/statistics.PNG
--------------------------------------------------------------------------------
/data/examples/starry_night_model/figures_width_500_model_starry_v3_resize_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/starry_night_model/figures_width_500_model_starry_v3_resize_230.jpg
--------------------------------------------------------------------------------
/data/examples/starry_night_model/taj_mahal_width_500_model_starry_v3_resize_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/starry_night_model/taj_mahal_width_500_model_starry_v3_resize_230.jpg
--------------------------------------------------------------------------------
/data/examples/style_images/candy_resize_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/candy_resize_230.jpg
--------------------------------------------------------------------------------
/data/examples/style_images/edtaonisl_crop_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/edtaonisl_crop_resized_230.jpg
--------------------------------------------------------------------------------
/data/examples/style_images/mosaic_crop_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/mosaic_crop_resized_230.jpg
--------------------------------------------------------------------------------
/data/examples/style_images/vg_starry_night_resized_230.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/examples/style_images/vg_starry_night_resized_230.jpg
--------------------------------------------------------------------------------
/data/style-images/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/candy.jpg
--------------------------------------------------------------------------------
/data/style-images/edtaonisl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/edtaonisl.jpg
--------------------------------------------------------------------------------
/data/style-images/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/mosaic.jpg
--------------------------------------------------------------------------------
/data/style-images/vg_starry_night.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/data/style-images/vg_starry_night.jpg
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: pytorch-nst-fast
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - python==3.8.3
7 | - pip==20.0.2
8 | - matplotlib==3.1.3
9 | - pytorch==1.5.0
10 | - torchvision==0.6.0
11 | - pip:
12 | - numpy==1.18.4
13 | - opencv-python==4.2.0.32
14 | - GitPython==3.1.2
15 | - tensorboard==2.2.2
16 |
17 |
--------------------------------------------------------------------------------
/models/definitions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gordicaleksa/pytorch-neural-style-transfer-johnson/00c96e8e3f1b0b7fb4c14254fd0c6f1281a29598/models/definitions/__init__.py
--------------------------------------------------------------------------------
/models/definitions/perceptual_loss_net.py:
--------------------------------------------------------------------------------
1 | """
2 | VGG16 deep learning model is used as the perceptual loss (network).
3 | More detail about the VGG architecture (if you want to understand magic/hardcoded numbers) can be found here:
4 |
5 | https://github.com/pytorch/vision/blob/3c254fb7af5f8af252c24e89949c54a3461ff0be/torchvision/models/vgg.py
6 | """
7 |
8 | from collections import namedtuple
9 | import torch
10 | from torchvision import models
11 |
12 |
13 | class Vgg16(torch.nn.Module):
14 | """Only those layers are exposed which have already proven to work nicely."""
15 | def __init__(self, requires_grad=False, show_progress=False):
16 | super().__init__()
17 | # Keeping eval() mode only for consistency - it only affects BatchNorm and Dropout, neither of which we use here
18 | vgg16 = models.vgg16(pretrained=True, progress=show_progress).eval()
19 | vgg_pretrained_features = vgg16.features
20 | self.layer_names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']
21 |
22 | self.slice1 = torch.nn.Sequential()
23 | self.slice2 = torch.nn.Sequential()
24 | self.slice3 = torch.nn.Sequential()
25 | self.slice4 = torch.nn.Sequential()
26 | for x in range(4):
27 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
28 | for x in range(4, 9):
29 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
30 | for x in range(9, 16):
31 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
32 | for x in range(16, 23):
33 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
34 |
35 | # Set these to False so that PyTorch won't include them in its autograd engine - eating up precious memory
36 | if not requires_grad:
37 | for param in self.parameters():
38 | param.requires_grad = False
39 |
40 | def forward(self, x):
41 | x = self.slice1(x)
42 | relu1_2 = x
43 | x = self.slice2(x)
44 | relu2_2 = x
45 | x = self.slice3(x)
46 | relu3_3 = x
47 | x = self.slice4(x)
48 | relu4_3 = x
49 | vgg_outputs = namedtuple("VggOutputs", self.layer_names)
50 | out = vgg_outputs(relu1_2, relu2_2, relu3_3, relu4_3)
51 | return out
52 |
53 |
54 | # Set the perceptual loss network to be VGG16
55 | PerceptualLossNet = Vgg16
56 |
--------------------------------------------------------------------------------
/models/definitions/transformer_net.py:
--------------------------------------------------------------------------------
1 | """
2 | Modifications to the original J.Johnson's architecture:
3 | 1. Instance normalization is used instead of batch normalization *
4 | 2. Instead of learned up-sampling use nearest-neighbor up-sampling followed by convolution **
5 | 3. No scaled tanh at the output of the network ***
6 |
7 | * Ulyanov showed that this gives better results, check out the paper here: https://arxiv.org/pdf/1607.08022.pdf
8 | ** Distill pub blog showed this to have better results: http://distill.pub/2016/deconv-checkerboard/
9 | *** I tried using it and even opened an issue on Johnson's original repo (written in Lua) - no improvements
10 |
11 | Note: check out the details of Johnson's original architecture here:
12 | https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf
13 | """
14 |
15 | import torch
16 |
17 |
18 | class TransformerNet(torch.nn.Module):
19 | def __init__(self):
20 | super().__init__()
21 | # Non-linearity
22 | self.relu = torch.nn.ReLU()
23 |
24 | # Down-sampling convolution layers
25 | num_of_channels = [3, 32, 64, 128]
26 | kernel_sizes = [9, 3, 3]
27 | stride_sizes = [1, 2, 2]
28 | self.conv1 = ConvLayer(num_of_channels[0], num_of_channels[1], kernel_size=kernel_sizes[0], stride=stride_sizes[0])
29 | self.in1 = torch.nn.InstanceNorm2d(num_of_channels[1], affine=True)
30 | self.conv2 = ConvLayer(num_of_channels[1], num_of_channels[2], kernel_size=kernel_sizes[1], stride=stride_sizes[1])
31 | self.in2 = torch.nn.InstanceNorm2d(num_of_channels[2], affine=True)
32 | self.conv3 = ConvLayer(num_of_channels[2], num_of_channels[3], kernel_size=kernel_sizes[2], stride=stride_sizes[2])
33 | self.in3 = torch.nn.InstanceNorm2d(num_of_channels[3], affine=True)
34 |
35 | # Residual layers
36 | res_block_num_of_filters = 128
37 | self.res1 = ResidualBlock(res_block_num_of_filters)
38 | self.res2 = ResidualBlock(res_block_num_of_filters)
39 | self.res3 = ResidualBlock(res_block_num_of_filters)
40 | self.res4 = ResidualBlock(res_block_num_of_filters)
41 | self.res5 = ResidualBlock(res_block_num_of_filters)
42 |
43 | # Up-sampling convolution layers
44 | num_of_channels.reverse()
45 | kernel_sizes.reverse()
46 | stride_sizes.reverse()
47 | self.up1 = UpsampleConvLayer(num_of_channels[0], num_of_channels[1], kernel_size=kernel_sizes[0], stride=stride_sizes[0])
48 | self.in4 = torch.nn.InstanceNorm2d(num_of_channels[1], affine=True)
49 | self.up2 = UpsampleConvLayer(num_of_channels[1], num_of_channels[2], kernel_size=kernel_sizes[1], stride=stride_sizes[1])
50 | self.in5 = torch.nn.InstanceNorm2d(num_of_channels[2], affine=True)
51 | self.up3 = ConvLayer(num_of_channels[2], num_of_channels[3], kernel_size=kernel_sizes[2], stride=stride_sizes[2])
52 |
53 | def forward(self, x):
54 | y = self.relu(self.in1(self.conv1(x)))
55 | y = self.relu(self.in2(self.conv2(y)))
56 | y = self.relu(self.in3(self.conv3(y)))
57 | y = self.res1(y)
58 | y = self.res2(y)
59 | y = self.res3(y)
60 | y = self.res4(y)
61 | y = self.res5(y)
62 | y = self.relu(self.in4(self.up1(y)))
63 | y = self.relu(self.in5(self.up2(y)))
64 | # No tanh activation here as originally proposed by J.Johnson, I didn't get any improvements by using it,
65 | # if you get better results using it feel free to make a PR
66 | return self.up3(y)
67 |
68 |
69 | class ConvLayer(torch.nn.Module):
70 | """
71 | A small wrapper around nn.Conv2d, so as to make the code cleaner and allow for experimentation with padding
72 | """
73 | def __init__(self, in_channels, out_channels, kernel_size, stride):
74 | super().__init__()
75 | self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size//2, padding_mode='reflect')
76 |
77 | def forward(self, x):
78 | return self.conv2d(x)
79 |
80 |
81 | class ResidualBlock(torch.nn.Module):
82 | """
83 | Originally introduced in (Microsoft Research Asia, He et al.): https://arxiv.org/abs/1512.03385
84 | Modified architecture according to suggestions in this blog: http://torch.ch/blog/2016/02/04/resnets.html
85 |
86 | The only difference from the original is: There is no ReLU layer after the addition of identity and residual
87 | """
88 |
89 | def __init__(self, channels):
90 | super(ResidualBlock, self).__init__()
91 | kernel_size = 3
92 | stride_size = 1
93 | self.conv1 = ConvLayer(channels, channels, kernel_size=kernel_size, stride=stride_size)
94 | self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
95 | self.conv2 = ConvLayer(channels, channels, kernel_size=kernel_size, stride=stride_size)
96 | self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
97 | self.relu = torch.nn.ReLU()
98 |
99 | def forward(self, x):
100 | residual = x
101 | out = self.relu(self.in1(self.conv1(x)))
102 | out = self.in2(self.conv2(out))
103 | return out + residual # modification: no ReLu after the addition
104 |
105 |
106 | class UpsampleConvLayer(torch.nn.Module):
107 | """
108 | Nearest-neighbor up-sampling followed by a convolution
109 | Appears to give better results than learned up-sampling aka transposed conv (avoids the checkerboard artifact)
110 |
111 | Initially proposed on distill pub: http://distill.pub/2016/deconv-checkerboard/
112 | """
113 |
114 | def __init__(self, in_channels, out_channels, kernel_size, stride):
115 | super().__init__()
116 | self.upsampling_factor = stride
117 | self.conv2d = ConvLayer(in_channels, out_channels, kernel_size, stride=1)
118 |
119 | def forward(self, x):
120 | if self.upsampling_factor > 1:
121 | x = torch.nn.functional.interpolate(x, scale_factor=self.upsampling_factor, mode='nearest')
122 | return self.conv2d(x)
123 |
124 |
--------------------------------------------------------------------------------
/stylization_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 |
8 |
9 | import utils.utils as utils
10 | from models.definitions.transformer_net import TransformerNet
11 |
12 |
13 | def stylize_static_image(inference_config):
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 | # Prepare the model - load the weights and put the model into evaluation mode
17 | stylization_model = TransformerNet().to(device)
18 | training_state = torch.load(os.path.join(inference_config["model_binaries_path"], inference_config["model_name"]))
19 | state_dict = training_state["state_dict"]
20 | stylization_model.load_state_dict(state_dict, strict=True)
21 | stylization_model.eval()
22 |
23 | if inference_config['verbose']:
24 | utils.print_model_metadata(training_state)
25 |
26 | with torch.no_grad():
27 | if os.path.isdir(inference_config['content_input']): # do a batch stylization (every image in the directory)
28 | img_dataset = utils.SimpleDataset(inference_config['content_input'], inference_config['img_width'])
29 | img_loader = DataLoader(img_dataset, batch_size=inference_config['batch_size'])
30 |
31 | try:
32 | processed_imgs_cnt = 0
33 | for batch_id, img_batch in enumerate(img_loader):
34 | processed_imgs_cnt += len(img_batch)
35 | if inference_config['verbose']:
36 | print(f'Processing batch {batch_id + 1} ({processed_imgs_cnt}/{len(img_dataset)} processed images).')
37 |
38 | img_batch = img_batch.to(device)
39 | stylized_imgs = stylization_model(img_batch).to('cpu').numpy()
40 | for stylized_img in stylized_imgs:
41 | utils.save_and_maybe_display_image(inference_config, stylized_img, should_display=False)
42 | except Exception as e:
43 | print(e)
44 | print(f'Consider making the batch_size (current = {inference_config["batch_size"]} images) or img_width (current = {inference_config["img_width"]} px) smaller')
45 | exit(1)
46 |
47 | else: # do stylization for a single image
48 | content_img_path = os.path.join(inference_config['content_images_path'], inference_config['content_input'])
49 | content_image = utils.prepare_img(content_img_path, inference_config['img_width'], device)
50 | stylized_img = stylization_model(content_image).to('cpu').numpy()[0]
51 | utils.save_and_maybe_display_image(inference_config, stylized_img, should_display=inference_config['should_not_display'])
52 |
53 |
54 | if __name__ == "__main__":
55 | #
56 | # Fixed args - don't change these unless you have a good reason
57 | #
58 | content_images_path = os.path.join(os.path.dirname(__file__), 'data', 'content-images')
59 | output_images_path = os.path.join(os.path.dirname(__file__), 'data', 'output-images')
60 | model_binaries_path = os.path.join(os.path.dirname(__file__), 'models', 'binaries')
61 |
62 | assert utils.dir_contains_only_models(model_binaries_path), f'Model directory should contain only model binaries.'
63 | os.makedirs(output_images_path, exist_ok=True)
64 |
65 | #
66 | # Modifiable args - feel free to play with these
67 | #
68 | parser = argparse.ArgumentParser()
69 | # Put image name or directory containing images (if you'd like to do a batch stylization on all those images)
70 | parser.add_argument("--content_input", type=str, help="Content image(s) to stylize", default='taj_mahal.jpg')
71 | parser.add_argument("--batch_size", type=int, help="Batch size used only if you set content_input to a directory", default=5)
72 | parser.add_argument("--img_width", type=int, help="Resize content image to this width", default=500)
73 | parser.add_argument("--model_name", type=str, help="Model binary to use for stylization", default='mosaic_4e5_e2.pth')
74 |
75 | # Less frequently used arguments
76 | parser.add_argument("--should_not_display", action='store_false', help="Should display the stylized result")
77 | parser.add_argument("--verbose", action='store_true', help="Print model metadata (how the model was trained) and where the resulting stylized image was saved")
78 | parser.add_argument("--redirected_output", type=str, help="Overwrite default output dir. Useful when this project is used as a submodule", default=None)
79 | args = parser.parse_args()
80 |
81 | # if redirected output is not set when doing batch stylization set to default image output location
82 | if os.path.isdir(args.content_input) and args.redirected_output is None:
83 | args.redirected_output = output_images_path
84 |
85 | # Wrapping inference configuration into a dictionary
86 | inference_config = dict()
87 | for arg in vars(args):
88 | inference_config[arg] = getattr(args, arg)
89 | inference_config['content_images_path'] = content_images_path
90 | inference_config['output_images_path'] = output_images_path
91 | inference_config['model_binaries_path'] = model_binaries_path
92 |
93 | stylize_static_image(inference_config)
94 |
--------------------------------------------------------------------------------
/training_script.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import time
4 |
5 | import torch
6 | from torch.optim import Adam
7 | from torch.utils.tensorboard import SummaryWriter
8 | import numpy as np
9 |
10 | from models.definitions.perceptual_loss_net import PerceptualLossNet
11 | from models.definitions.transformer_net import TransformerNet
12 | import utils.utils as utils
13 |
14 |
15 | def train(training_config):
16 | writer = SummaryWriter() # (tensorboard) writer will output to ./runs/ directory by default
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 |
19 | # prepare data loader
20 | train_loader = utils.get_training_data_loader(training_config)
21 |
22 | # prepare neural networks
23 | transformer_net = TransformerNet().train().to(device)
24 | perceptual_loss_net = PerceptualLossNet(requires_grad=False).to(device)
25 |
26 | optimizer = Adam(transformer_net.parameters())
27 |
28 | # Calculate style image's Gram matrices (style representation)
29 | # Built over feature maps as produced by the perceptual net - VGG16
30 | style_img_path = os.path.join(training_config['style_images_path'], training_config['style_img_name'])
31 | style_img = utils.prepare_img(style_img_path, target_shape=None, device=device, batch_size=training_config['batch_size'])
32 | style_img_set_of_feature_maps = perceptual_loss_net(style_img)
33 | target_style_representation = [utils.gram_matrix(x) for x in style_img_set_of_feature_maps]
34 |
35 | utils.print_header(training_config)
36 | # Tracking loss metrics; NST is ill-posed, so we can only track the loss and the visual appearance of the stylized images
37 | acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
38 | ts = time.time()
39 | for epoch in range(training_config['num_of_epochs']):
40 | for batch_id, (content_batch, _) in enumerate(train_loader):
41 | # step1: Feed content batch through transformer net
42 | content_batch = content_batch.to(device)
43 | stylized_batch = transformer_net(content_batch)
44 |
45 | # step2: Feed content and stylized batch through perceptual net (VGG16)
46 | content_batch_set_of_feature_maps = perceptual_loss_net(content_batch)
47 | stylized_batch_set_of_feature_maps = perceptual_loss_net(stylized_batch)
48 |
49 | # step3: Calculate content representations and content loss
50 | target_content_representation = content_batch_set_of_feature_maps.relu2_2
51 | current_content_representation = stylized_batch_set_of_feature_maps.relu2_2
52 | content_loss = training_config['content_weight'] * torch.nn.MSELoss(reduction='mean')(target_content_representation, current_content_representation)
53 |
54 | # step4: Calculate style representation and style loss
55 | style_loss = 0.0
56 | current_style_representation = [utils.gram_matrix(x) for x in stylized_batch_set_of_feature_maps]
57 | for gram_gt, gram_hat in zip(target_style_representation, current_style_representation):
58 | style_loss += torch.nn.MSELoss(reduction='mean')(gram_gt, gram_hat)
59 | style_loss /= len(target_style_representation)
60 | style_loss *= training_config['style_weight']
61 |
62 | # step5: Calculate total variation loss - enforces image smoothness
63 | tv_loss = training_config['tv_weight'] * utils.total_variation(stylized_batch)
64 |
65 | # step6: Combine losses and do a backprop
66 | total_loss = content_loss + style_loss + tv_loss
67 | total_loss.backward()
68 | optimizer.step()
69 |
70 | optimizer.zero_grad() # clear gradients for the next round
71 |
72 | #
73 | # Logging and checkpoint creation
74 | #
75 | acc_content_loss += content_loss.item()
76 | acc_style_loss += style_loss.item()
77 | acc_tv_loss += tv_loss.item()
78 |
79 | if training_config['enable_tensorboard']:
80 | # log scalars
81 | writer.add_scalar('Loss/content-loss', content_loss.item(), len(train_loader) * epoch + batch_id + 1)
82 | writer.add_scalar('Loss/style-loss', style_loss.item(), len(train_loader) * epoch + batch_id + 1)
83 | writer.add_scalar('Loss/tv-loss', tv_loss.item(), len(train_loader) * epoch + batch_id + 1)
84 | writer.add_scalars('Statistics/min-max-mean-median', {'min': torch.min(stylized_batch), 'max': torch.max(stylized_batch), 'mean': torch.mean(stylized_batch), 'median': torch.median(stylized_batch)}, len(train_loader) * epoch + batch_id + 1)
85 | # log stylized image
86 | if batch_id % training_config['image_log_freq'] == 0:
87 | stylized = utils.post_process_image(stylized_batch[0].detach().to('cpu').numpy())
88 | stylized = np.moveaxis(stylized, 2, 0) # writer expects channel first image
89 | writer.add_image('stylized_img', stylized, len(train_loader) * epoch + batch_id + 1)
90 |
91 | if training_config['console_log_freq'] is not None and batch_id % training_config['console_log_freq'] == 0:
92 | print(f'time elapsed={(time.time()-ts)/60:.2f}[min]|epoch={epoch + 1}|batch=[{batch_id + 1}/{len(train_loader)}]|c-loss={acc_content_loss / training_config["console_log_freq"]}|s-loss={acc_style_loss / training_config["console_log_freq"]}|tv-loss={acc_tv_loss / training_config["console_log_freq"]}|total loss={(acc_content_loss + acc_style_loss + acc_tv_loss) / training_config["console_log_freq"]}')
93 | acc_content_loss, acc_style_loss, acc_tv_loss = [0., 0., 0.]
94 |
95 | if training_config['checkpoint_freq'] is not None and (batch_id + 1) % training_config['checkpoint_freq'] == 0:
96 | training_state = utils.get_training_metadata(training_config)
97 | training_state["state_dict"] = transformer_net.state_dict()
98 | training_state["optimizer_state"] = optimizer.state_dict()
99 | ckpt_model_name = f"ckpt_style_{training_config['style_img_name'].split('.')[0]}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}_epoch_{epoch}_batch_{batch_id}.pth"
100 | torch.save(training_state, os.path.join(training_config['checkpoints_path'], ckpt_model_name))
101 |
102 | #
103 | # Save model with additional metadata - like which commit was used to train the model, style/content weights, etc.
104 | #
105 | training_state = utils.get_training_metadata(training_config)
106 | training_state["state_dict"] = transformer_net.state_dict()
107 | training_state["optimizer_state"] = optimizer.state_dict()
108 | model_name = f"style_{training_config['style_img_name'].split('.')[0]}_datapoints_{training_state['num_of_datapoints']}_cw_{str(training_config['content_weight'])}_sw_{str(training_config['style_weight'])}_tw_{str(training_config['tv_weight'])}.pth"
109 | torch.save(training_state, os.path.join(training_config['model_binaries_path'], model_name))
110 |
111 |
112 | if __name__ == "__main__":
113 | #
114 | # Fixed args - don't change these unless you have a good reason
115 | #
116 | dataset_path = os.path.join(os.path.dirname(__file__), 'data', 'mscoco')
117 | style_images_path = os.path.join(os.path.dirname(__file__), 'data', 'style-images')
118 | model_binaries_path = os.path.join(os.path.dirname(__file__), 'models', 'binaries')
119 | checkpoints_root_path = os.path.join(os.path.dirname(__file__), 'models', 'checkpoints')
120 | image_size = 256 # training images from MS COCO are resized to image_size x image_size
121 | batch_size = 4
122 |
123 | assert os.path.exists(dataset_path), f'MS COCO missing. Download the dataset using resource_downloader.py script.'
124 | os.makedirs(model_binaries_path, exist_ok=True)
125 |
126 | #
127 | # Modifiable args - feel free to play with these (only a small subset is exposed by design to avoid cluttering)
128 | #
129 | parser = argparse.ArgumentParser()
130 | # training related
131 | parser.add_argument("--style_img_name", type=str, help="style image name that will be used for training", default='edtaonisl.jpg')
132 | parser.add_argument("--content_weight", type=float, help="weight factor for content loss", default=1e0) # you don't need to change this one just play with style loss
133 | parser.add_argument("--style_weight", type=float, help="weight factor for style loss", default=4e5)
134 | parser.add_argument("--tv_weight", type=float, help="weight factor for total variation loss", default=0)
135 | parser.add_argument("--num_of_epochs", type=int, help="number of training epochs ", default=2)
136 | parser.add_argument("--subset_size", type=int, help="number of MS COCO images (NOT BATCHES) to use, default is all (~83k)(specified by None)", default=None)
137 | # logging/debugging/checkpoint related (helps a lot with experimentation)
138 | parser.add_argument("--enable_tensorboard", type=bool, help="enable tensorboard logging (scalars + images)", default=True)
139 | parser.add_argument("--image_log_freq", type=int, help="tensorboard image logging (batch) frequency - enable_tensorboard must be True to use", default=100)
140 | parser.add_argument("--console_log_freq", type=int, help="logging to output console (batch) frequency", default=500)
141 | parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (batch) frequency", default=2000)
142 | args = parser.parse_args()
143 |
144 | checkpoints_path = os.path.join(checkpoints_root_path, args.style_img_name.split('.')[0])
145 | if args.checkpoint_freq is not None:
146 | os.makedirs(checkpoints_path, exist_ok=True)
147 |
148 | # Wrapping training configuration into a dictionary
149 | training_config = dict()
150 | for arg in vars(args):
151 | training_config[arg] = getattr(args, arg)
152 | training_config['dataset_path'] = dataset_path
153 | training_config['style_images_path'] = style_images_path
154 | training_config['model_binaries_path'] = model_binaries_path
155 | training_config['checkpoints_path'] = checkpoints_path
156 | training_config['image_size'] = image_size
157 | training_config['batch_size'] = batch_size
158 |
159 | # Original J.Johnson's training with improved transformer net architecture
160 | train(training_config)
161 |
162 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
--------------------------------------------------------------------------------
/utils/resource_downloader.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | from torch.hub import download_url_to_file
3 | import argparse
4 | import os
5 |
6 | # If the link is broken you can download the MS COCO 2014 dataset manually from http://cocodataset.org/#download
7 | MS_COCO_2014_TRAIN_DATASET_PATH = r'http://images.cocodataset.org/zips/train2014.zip' # ~13 GB after unzipping
8 |
9 | PRETRAINED_MODELS_PATH = r'https://www.dropbox.com/s/fb39gscd1b42px1/pretrained_models.zip?dl=1'
10 |
11 | DOWNLOAD_DICT = {
12 | 'pretrained_models': PRETRAINED_MODELS_PATH,
13 | 'mscoco_dataset': MS_COCO_2014_TRAIN_DATASET_PATH,
14 | }
15 | download_choices = list(DOWNLOAD_DICT.keys())
16 |
17 |
18 | if __name__ == '__main__':
19 | #
20 | # Choose whether you want to download pretrained models or MSCOCO 2014 dataset
21 | #
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--resource", "-r", type=str, choices=download_choices,
24 | help="specify whether you want to download ms coco dataset or pretrained models",
25 | default=download_choices[0])
26 | args = parser.parse_args()
27 |
28 | # step1: download the resource to local filesystem
29 | remote_resource_path = DOWNLOAD_DICT[args.resource]
30 | print(f'Downloading from {remote_resource_path}')
31 | resource_tmp_path = args.resource + '.zip'
32 | download_url_to_file(remote_resource_path, resource_tmp_path)
33 |
34 | # step2: unzip the resource
35 | print(f'Started unzipping...')
36 | with zipfile.ZipFile(resource_tmp_path) as zf:
37 | local_resource_path = os.path.join(os.path.dirname(__file__), os.pardir)
38 | if args.resource == 'pretrained_models':
39 | local_resource_path = os.path.join(local_resource_path, 'models', 'binaries')
40 | else:
41 | local_resource_path = os.path.join(local_resource_path, 'data', 'mscoco')
42 | os.makedirs(local_resource_path, exist_ok=True)
43 | zf.extractall(path=local_resource_path)
44 | print(f'Unzipping to: {local_resource_path} finished.')
45 |
46 | # step3: remove the temporary resource file
47 | os.remove(resource_tmp_path)
48 | print(f'Removed tmp file {resource_tmp_path}.')
49 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 |
5 | import cv2 as cv
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | from torchvision import transforms
9 | from torchvision import datasets
10 | from torch.utils.data import Dataset, DataLoader, Sampler
11 | import torch
12 | import git
13 |
14 |
15 | IMAGENET_MEAN_1 = np.array([0.485, 0.456, 0.406])
16 | IMAGENET_STD_1 = np.array([0.229, 0.224, 0.225])
17 | IMAGENET_MEAN_255 = np.array([123.675, 116.28, 103.53])
18 | # Usually when normalizing 0..255 images only mean-normalization is performed -> that's why standard dev is all 1s here
19 | IMAGENET_STD_NEUTRAL = np.array([1, 1, 1])
20 |
21 |
22 | class SimpleDataset(Dataset):
23 | def __init__(self, img_dir, target_width):
24 | self.img_dir = img_dir
25 | self.img_paths = [os.path.join(img_dir, img_name) for img_name in os.listdir(img_dir)]
26 |
27 | h, w = load_image(self.img_paths[0]).shape[:2]
28 | img_height = int(h * (target_width / w))
29 | self.target_width = target_width
30 | self.target_height = img_height
31 |
32 | self.transform = transforms.Compose([
33 | transforms.ToTensor(),
34 | transforms.Normalize(mean=IMAGENET_MEAN_1, std=IMAGENET_STD_1)
35 | ])
36 |
37 | def __len__(self):
38 | return len(self.img_paths)
39 |
40 | def __getitem__(self, idx):
41 | img = load_image(self.img_paths[idx], target_shape=(self.target_height, self.target_width))
42 | tensor = self.transform(img)
43 | return tensor
44 |
45 |
46 | def load_image(img_path, target_shape=None):
47 | if not os.path.exists(img_path):
48 | raise Exception(f'Path does not exist: {img_path}')
49 | img = cv.imread(img_path)[:, :, ::-1] # [:, :, ::-1] converts BGR (opencv format...) into RGB
50 |
51 | if target_shape is not None: # resize section
52 | if isinstance(target_shape, int) and target_shape != -1: # scalar -> implicitly setting the width
53 | current_height, current_width = img.shape[:2]
54 | new_width = target_shape
55 | new_height = int(current_height * (new_width / current_width))
56 | img = cv.resize(img, (new_width, new_height), interpolation=cv.INTER_CUBIC)
57 | else: # set both dimensions to target shape
58 | img = cv.resize(img, (target_shape[1], target_shape[0]), interpolation=cv.INTER_CUBIC)
59 |
60 | # this needs to go after resizing - otherwise cv.resize would push values outside of the [0, 1] range
61 | img = img.astype(np.float32) # convert from uint8 to float32
62 | img /= 255.0 # get to [0, 1] range
63 | return img
64 |
65 |
66 | def prepare_img(img_path, target_shape, device, batch_size=1, should_normalize=True, is_255_range=False):
67 | img = load_image(img_path, target_shape=target_shape)
68 |
69 | transform_list = [transforms.ToTensor()]
70 | if is_255_range:
71 | transform_list.append(transforms.Lambda(lambda x: x.mul(255)))
72 | if should_normalize:
73 | transform_list.append(transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL) if is_255_range else transforms.Normalize(mean=IMAGENET_MEAN_1, std=IMAGENET_STD_1))
74 | transform = transforms.Compose(transform_list)
75 |
76 | img = transform(img).to(device)
77 | img = img.repeat(batch_size, 1, 1, 1)
78 |
79 | return img
80 |
81 |
82 | def post_process_image(dump_img):
83 | assert isinstance(dump_img, np.ndarray), f'Expected numpy image got {type(dump_img)}'
84 |
85 | mean = IMAGENET_MEAN_1.reshape(-1, 1, 1)
86 | std = IMAGENET_STD_1.reshape(-1, 1, 1)
87 | dump_img = (dump_img * std) + mean # de-normalize
88 | dump_img = (np.clip(dump_img, 0., 1.) * 255).astype(np.uint8)
89 | dump_img = np.moveaxis(dump_img, 0, 2)
90 | return dump_img
91 |
92 |
93 | def get_next_available_name(input_dir):
94 | img_name_pattern = re.compile(r'[0-9]{6}\.jpg')
95 | candidates = [candidate for candidate in os.listdir(input_dir) if re.fullmatch(img_name_pattern, candidate)]
96 |
97 | if len(candidates) == 0:
98 | return '000000.jpg'
99 | else:
100 | latest_file = sorted(candidates)[-1]
101 | prefix_int = int(latest_file.split('.')[0])
102 | return f'{str(prefix_int + 1).zfill(6)}.jpg'
103 |
104 |
105 | def save_and_maybe_display_image(inference_config, dump_img, should_display=False):
106 | assert isinstance(dump_img, np.ndarray), f'Expected numpy array got {type(dump_img)}.'
107 |
108 | dump_img = post_process_image(dump_img)
109 | if inference_config['img_width'] is None:
110 | inference_config['img_width'] = dump_img.shape[0]
111 |
112 | if inference_config['redirected_output'] is None:
113 | dump_dir = inference_config['output_images_path']
114 | dump_img_name = os.path.basename(inference_config['content_input']).split('.')[0] + '_width_' + str(inference_config['img_width']) + '_model_' + inference_config['model_name'].split('.')[0] + '.jpg'
115 | else: # useful when this repo is used as a utility submodule in some other repo like pytorch-naive-video-nst
116 | dump_dir = inference_config['redirected_output']
117 | os.makedirs(dump_dir, exist_ok=True)
118 | dump_img_name = get_next_available_name(inference_config['redirected_output'])
119 |
120 | cv.imwrite(os.path.join(dump_dir, dump_img_name), dump_img[:, :, ::-1]) # ::-1 because opencv expects BGR (and not RGB) format...
121 |
122 | # Don't print this information in batch stylization mode
123 | if inference_config['verbose'] and not os.path.isdir(inference_config['content_input']):
124 | print(f'Saved image to {dump_dir}.')
125 |
126 | if should_display:
127 | plt.imshow(dump_img)
128 | plt.show()
129 |
130 |
131 | class SequentialSubsetSampler(Sampler):
132 | r"""Samples elements sequentially, always in the same order from a subset defined by size.
133 |
134 | Arguments:
135 | data_source (Dataset): dataset to sample from
136 | subset_size: defines the size of the subset to sample from
137 | """
138 |
139 | def __init__(self, data_source, subset_size):
140 | assert isinstance(data_source, Dataset) or isinstance(data_source, datasets.ImageFolder)
141 | self.data_source = data_source
142 |
143 | if subset_size is None: # if None -> use the whole dataset
144 | subset_size = len(data_source)
145 | assert 0 < subset_size <= len(data_source), f'Subset size should be between (0, {len(data_source)}].'
146 | self.subset_size = subset_size
147 |
148 | def __iter__(self):
149 | return iter(range(self.subset_size))
150 |
151 | def __len__(self):
152 | return self.subset_size
153 |
154 |
155 | def get_training_data_loader(training_config, should_normalize=True, is_255_range=False):
156 | """
157 | There are multiple ways to make this feed-forward NST work,
158 | including using 0..255 range images (without any normalization) during transformer net training;
159 | keeping the options here in case somebody wants to play and get better results.
160 | """
161 | transform_list = [transforms.Resize(training_config['image_size']),
162 | transforms.CenterCrop(training_config['image_size']),
163 | transforms.ToTensor()]
164 | if is_255_range:
165 | transform_list.append(transforms.Lambda(lambda x: x.mul(255)))
166 | if should_normalize:
167 | transform_list.append(transforms.Normalize(mean=IMAGENET_MEAN_255, std=IMAGENET_STD_NEUTRAL) if is_255_range else transforms.Normalize(mean=IMAGENET_MEAN_1, std=IMAGENET_STD_1))
168 | transform = transforms.Compose(transform_list)
169 |
170 | train_dataset = datasets.ImageFolder(training_config['dataset_path'], transform)
171 | sampler = SequentialSubsetSampler(train_dataset, training_config['subset_size'])
172 | training_config['subset_size'] = len(sampler) # update in case it was None
173 | train_loader = DataLoader(train_dataset, batch_size=training_config['batch_size'], sampler=sampler, drop_last=True)
174 | print(f'Using {len(train_loader)*training_config["batch_size"]*training_config["num_of_epochs"]} datapoints ({len(train_loader)*training_config["num_of_epochs"]} batches) (MS COCO images) for transformer network training.')
175 | return train_loader
176 |
177 |
178 | def gram_matrix(x, should_normalize=True):
179 | (b, ch, h, w) = x.size()
180 | features = x.view(b, ch, w * h)
181 | features_t = features.transpose(1, 2)
182 | gram = features.bmm(features_t)
183 | if should_normalize:
184 | gram /= ch * h * w
185 | return gram
186 |
187 |
188 | # Not used atm, you'd want to use this if you choose to go with 0..255 images in the training loader
189 | def normalize_batch(batch):
190 | batch /= 255.0
191 | mean = batch.new_tensor(IMAGENET_MEAN_1).view(-1, 1, 1)
192 | std = batch.new_tensor(IMAGENET_STD_1).view(-1, 1, 1)
193 | return (batch - mean) / std
194 |
195 |
196 | def total_variation(img_batch):
197 | batch_size = img_batch.shape[0]
198 | return (torch.sum(torch.abs(img_batch[:, :, :, :-1] - img_batch[:, :, :, 1:])) +
199 | torch.sum(torch.abs(img_batch[:, :, :-1, :] - img_batch[:, :, 1:, :]))) / batch_size
200 |
201 |
202 | def print_header(training_config):
203 | print(f'Learning the style of {training_config["style_img_name"]} style image.')
204 | print('*' * 80)
205 | print(f'Hyperparams: content_weight={training_config["content_weight"]}, style_weight={training_config["style_weight"]} and tv_weight={training_config["tv_weight"]}')
206 | print('*' * 80)
207 |
208 | if training_config["console_log_freq"]:
209 | print(f'Logging to console every {training_config["console_log_freq"]} batches.')
210 | else:
211 | print(f'Console logging disabled. Change console_log_freq if you want to use it.')
212 |
213 | if training_config["checkpoint_freq"]:
214 | print(f'Saving checkpoint models every {training_config["checkpoint_freq"]} batches.')
215 | else:
216 | print(f'Checkpoint models saving disabled.')
217 |
218 | if training_config['enable_tensorboard']:
219 | print('Tensorboard enabled.')
220 | print('Run "tensorboard --logdir=runs --samples_per_plugin images=50" from your conda env')
221 | print('Open http://localhost:6006/ in your browser and you\'re ready to use tensorboard!')
222 | else:
223 | print('Tensorboard disabled.')
224 | print('*' * 80)
225 |
226 |
227 | def get_training_metadata(training_config):
228 | num_of_datapoints = training_config['subset_size'] * training_config['num_of_epochs']
229 | training_metadata = {
230 | "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha,
231 | "content_weight": training_config['content_weight'],
232 | "style_weight": training_config['style_weight'],
233 | "tv_weight": training_config['tv_weight'],
234 | "num_of_datapoints": num_of_datapoints
235 | }
236 | return training_metadata
237 |
238 |
239 | def print_model_metadata(training_state):
240 | print('Model training metadata:')
241 | for key, value in training_state.items():
242 | if key != 'state_dict' and key != 'optimizer_state':
243 | print(key, ':', value)
244 |
245 |
246 | def dir_contains_only_models(path):
247 | assert os.path.exists(path), f'Provided path: {path} does not exist.'
248 | assert os.path.isdir(path), f'Provided path: {path} is not a directory.'
249 | list_of_files = os.listdir(path)
250 | assert len(list_of_files) > 0, f'No models found, use training_script.py to train a model or download pretrained models via resource_downloader.py.'
251 | for f in list_of_files:
252 | if not (f.endswith('.pt') or f.endswith('.pth')):
253 | return False
254 |
255 | return True
256 |
257 |
258 | # Count how many trainable weights the model has <- just to get a feeling for how big the model is
259 | def count_parameters(model):
260 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
261 |
--------------------------------------------------------------------------------