├── .gitignore ├── LICENSE ├── README.md ├── assets ├── PFFB.png ├── Pipeline.png └── network.png ├── data ├── data_loader.py ├── data_loader_sketch.py ├── prepare_data.py ├── prepare_data_sketch.py ├── thinplate │ ├── __init__.py │ ├── numpy.py │ ├── pytorch.py │ └── tests │ │ ├── __init__.py │ │ ├── test_tps_numpy.py │ │ └── test_tps_pytorch.py └── tps_transformation.py ├── discriminator.py ├── distributed.py ├── extractor ├── Open-Sans-Bold.ttf └── manga_panel_extractor.py ├── inference.py ├── models.py ├── requirement.txt ├── test_datasets ├── gray_test │ ├── 001_in.png │ ├── 001_ref_a.png │ ├── 001_ref_b.png │ ├── 002_in.jpeg │ ├── 002_in_ref_a.jpg │ ├── 002_in_ref_b.jpeg │ ├── 003_in.jpeg │ ├── 003_in_ref_a.jpg │ ├── 003_in_ref_b.jpg │ ├── 004_in.png │ ├── 004_ref_1.jpg │ ├── 004_ref_2.jpg │ ├── 005_in.png │ ├── 005_ref_1.jpeg │ ├── 005_ref_2.jpg │ ├── 005_ref_3.jpeg │ ├── 006_in.png │ ├── 006_ref.png │ └── out │ │ ├── 001_in_color_a.png │ │ ├── 001_in_color_b.png │ │ ├── 002_in_color_a.png │ │ ├── 002_in_color_b.png │ │ ├── 003_in_color_a.png │ │ ├── 003_in_color_b.png │ │ ├── 004_in_color.png │ │ ├── 005_in_color.png │ │ └── 006_in_color.png └── sketch_test │ ├── 001_in.jpg │ ├── 001_ref_a.jpg │ ├── 001_ref_b.jpg │ └── out │ ├── 001_in_color_a.png │ └── 001_in_color_b.png ├── train.py ├── train_all_gray.py ├── train_all_sketch.py ├── train_disc.py ├── utils.py └── vgg_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### Example user template template 2 | ### Example user template 3 | 4 | # IntelliJ project files 5 | .idea 6 | *.iml 7 | out 8 | gen 9 | 10 | # Debug file 11 | datacheck.py 12 | test_gray2color.py 13 | val.py 14 | 15 | experiments/ 16 | misc/ 17 | results/ 18 | test_datasets/* 19 | !/test_datasets/gray_test 20 | !/test_datasets/gray_test/out 21 | !/test_datasets/sketch_test 22 | !/test_datasets/sketch_test/out 23 | train_datasets/ 24 | training_logs/ 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reference-Image-Embed-Manga-Colorization 2 | 3 | An amazing manga colorization project 4 | 5 | You can colorize gray manga or character sketches using any reference image you like; the model faithfully retains the reference's color features and transfers them to your manga. This is useful when you want a character's hair or clothing color to stay consistent. 6 | 7 | If the project is helpful, please leave a ⭐ on this repo. Best of luck, my friend 😊
8 | 9 | ## Overview 10 |

11 | 12 |

13 | 14 | It's basically a cGAN (Conditional Generative Adversarial Network) architecture. 15 | 16 | ### Generator 17 | 18 | The generator is divided into two parts. 19 | 20 | The `Color Embedding Layer` consists of part of a pretrained VGG19 network and an MLP (Multilayer Perceptron); it is used to extract a `color embedding` from the reference image (during training, a preprocessed version of the ground-truth image). 21 | 22 | The other part is a U-Net-like network. Its encoder extracts a `content embedding` from the gray input image (which contains only L-channel information), and its decoder reconstructs the image by injecting the `color embedding` through PFFBs (Progressive Feature Formalization Blocks) and outputs the ab-channel information. 23 | 24 |
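To make the two-part structure concrete, here is a minimal, hedged PyTorch sketch of the idea. All module names, layer sizes and the VGG19 cut-off point are illustrative assumptions rather than the repository's actual `models.py`, and the PFFB step is reduced to a per-sample 1x1 dynamic convolution:

```
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19


class ColorEncoder(nn.Module):
    """Reference image -> color embedding (truncated VGG19 features + MLP)."""
    def __init__(self, embed_dim=512):
        super().__init__()
        # Illustrative cut through relu4_1; the real model loads pretrained VGG19 weights.
        self.backbone = vgg19(weights=None).features[:21]
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(512, embed_dim), nn.ReLU(inplace=True),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, ref_rgb):                  # (B, 3, H, W)
        return self.mlp(self.backbone(ref_rgb))  # (B, embed_dim)


class PFFBlock(nn.Module):
    """PFFB-style step (simplified): turn the color embedding into a
    per-sample 1x1 filter and convolve it with the content features."""
    def __init__(self, channels, embed_dim=512):
        super().__init__()
        self.to_filter = nn.Linear(embed_dim, channels * channels)

    def forward(self, feat, color_embed):        # feat: (B, C, H, W)
        b, c, h, w = feat.shape
        kernel = self.to_filter(color_embed).view(b * c, c, 1, 1)
        # Grouped convolution applies each sample's own kernel to its own features.
        out = F.conv2d(feat.reshape(1, b * c, h, w), kernel, groups=b)
        return out.reshape(b, c, h, w) + feat    # residual connection


class TinyColorizer(nn.Module):
    """Toy U-Net-ish generator: L channel in, ab channels out."""
    def __init__(self, embed_dim=512):
        super().__init__()
        self.color_encoder = ColorEncoder(embed_dim)
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, 3, 2, 1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(inplace=True),
        )
        self.pffb = PFFBlock(128, embed_dim)
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
            nn.Conv2d(128, 2, 3, 1, 1),          # predict ab channels
        )

    def forward(self, img_l, img_ref):
        color_embed = self.color_encoder(img_ref)
        feat = self.encoder(img_l)
        feat = self.pffb(feat, color_embed)
        return self.decoder(feat)                # (B, 2, H, W)


ab = TinyColorizer()(torch.randn(1, 1, 256, 256), torch.randn(1, 3, 256, 256))
print(ab.shape)  # torch.Size([1, 2, 256, 256])
```

The repository's generator is naturally more elaborate, but the data flow (reference -> color embedding -> dynamic filter applied to content features -> ab prediction) is the same idea.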

25 | 26 |

27 | 28 | The figure shows how the PFFB works. 29 | 30 | It generates a filter from the color embedding and then convolves it with the content features. The figure is from this [paper](https://arxiv.org/abs/2106.08017); check it for more details. 31 | 32 | ### Discriminator 33 | 34 | The discriminator is a PatchGAN, following [pix2pix](https://arxiv.org/abs/1611.07004v3). The difference is that two conditions are used as input alongside the generated image: the gray image awaiting colorization and the reference image providing the color information. 35 | 36 | ### Loss 37 | 38 | There are three losses in total: an `L1 loss`, a `perceptual loss` produced by the pretrained VGG19, and an `adversarial loss` produced by the discriminator, weighted `1 : 0.1 : 0.01` (i.e. `total = L1 + 0.1 * perceptual + 0.01 * adversarial`). 39 | 40 | ### Pipeline 41 | 42 |

43 | 44 |

45 | 46 | - a. Segment panels from the input manga image; the `Manga-Panel-Extractor` is from [here](https://github.com/pvnieo/Manga-Panel-Extractor). 47 | - b. Select a reference image for each panel, and the generator will colorize it. 48 | - c. Concatenate all colorized panels back into the original page layout. 49 | 50 | ## Results 51 | ### Gray model 52 | 53 | | Original | Reference | Colorization | 54 | |:----------:|:-----------:|:----------:| 55 | | | | | 56 | | | | | 57 | | | | | 58 | | | | | 59 | | | | | 60 | | | | | 61 | | | | | 62 | | | | | 63 | | | | | 64 | 65 | ### Sketch model 66 | 67 | | Original | Reference | Colorization | 68 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | 69 | | | | | 70 | | | | | 71 | 72 | 73 | 74 | ## Dependencies and Installation 75 | 76 | 1. Clone this GitHub repo. 77 | ``` 78 | git clone https://github.com/linSensiGit/Example_Based_Manga_Colorization---cGAN.git 79 | 80 | cd Example_Based_Manga_Colorization---cGAN 81 | ``` 82 | 83 | 2. Create Environment 84 | - Python >= 3.6 (recommended: [Anaconda](https://www.anaconda.com/download/#linux)) 85 | 86 | - [PyTorch >= 1.5.0](https://pytorch.org/) (Default GPU mode) 87 | 88 | ``` 89 | # My environment for reference 90 | - Python = 3.9.15 91 | - PyTorch = 1.13.0 92 | - Torchvision = 0.14.0 93 | - Cuda = 11.7 94 | - GPU = RTX 3060ti 95 | ``` 96 | 97 | 3. Install Dependencies 98 | 99 | ``` 100 | pip3 install -r requirement.txt 101 | ``` 102 | 103 | ## Get Started 104 | 105 | Once you've set up the environment, several things need to be done before colorization. 106 | 107 | ### Prepare pretrained models 108 | 109 | 1. Download a generator. I have trained two generators, one for [gray manga](https://drive.google.com/file/d/11RQGvBKySEtRcBdYD8O5ZLb54jB7SAgN/view?usp=drive_link) colorization and one for [sketch](https://drive.google.com/file/d/1I4XwOYIGAoQwMOicknZl0s6AWcwpARmR/view?usp=drive_link) colorization. Choose what you need. 110 | 111 | 2. Download the [VGG model](https://drive.google.com/file/d/1S7t3mD-tznEUrMmq5bRsLZk4fkN24QSV/view?usp=drive_link); it is part of the generator. 112 | 113 | 3. Download a discriminator, used when training [gray manga](https://drive.google.com/file/d/1DHHE9um_xOm0brTpbHb_R7K7J4mn37FS/view?usp=drive_link) colorization and [sketch](https://drive.google.com/file/d/1WgIPYY4b4GcpHW9EWFrFoTxL9SlilQbN/view?usp=drive_link) colorization. (Optional) 114 | 115 | 4. Put the pretrained models in the correct directories: 116 | 117 | ``` 118 | Colorful-Manga-GAN 119 | |- experiments 120 | |- Color2Manga_gray 121 | |- xxx000_gray.pt 122 | |- Color2Manga_sketch 123 | |- xxx000_sketch.pt 124 | |- Discriminator 125 | |- xxx000_d.pt 126 | |- VGG19 127 | |- vgg19-dcbb9e9d.pth 128 | ``` 129 | 130 | ### Quick test 131 | 132 | I have collected some test datasets containing manga pages and corresponding reference images; you can find them under `./test_datasets`. When you run `inference.py` to test, you may need to edit the input file path or the pretrained-weights path inside this file. 133 | 134 | ``` 135 | python inference.py 136 | 137 | # If you don't want to segment your manga 138 | python inference.py -ne 139 | ``` 140 | First, the `Manga-Panel-Extractor` will segment the manga page into panels. 141 | 142 | Then follow the instructions in the console and you will get the colorized image. A sketch of how these pieces fit together programmatically is shown below.
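As a rough illustration, the following hedged sketch wires the three pipeline steps together using the `PanelExtractor` class from `extractor/manga_panel_extractor.py`. The `colorize_panel` function is a placeholder for the generator forward pass that `inference.py` actually performs; the file paths point at samples shipped in `test_datasets/`.

```
from extractor.manga_panel_extractor import PanelExtractor


def colorize_panel(panel, ref_path):
    # Placeholder for the generator forward pass driven by inference.py
    # (gray L channel + reference image -> predicted ab channels).
    # Returning the panel unchanged keeps this sketch runnable end to end.
    return panel


img_path = "test_datasets/gray_test/001_in.png"

# a. segment the page into panels (defaults: panels between 2% and 90% of the page area)
extractor = PanelExtractor()
panels, masks, panel_masks = extractor.extract(img_path)

# b. colorize every panel with its own reference image
refs = ["test_datasets/gray_test/001_ref_a.png"] * len(panels)
colorized = [colorize_panel(p, r) for p, r in zip(panels, refs)]

# c. paste the colorized panels back into the original page layout
extractor.concatPanels(img_path, colorized, masks, panel_masks)
```

In practice you would simply run `python inference.py`, which performs these steps for you and prompts for reference images in the console.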
143 | 144 | ## Train your Own Model 145 | ### Prepare Datasets 146 | 147 | I used three datasets to train the models. 148 | 149 | For the gray model, the [Anime Face Dataset](https://www.kaggle.com/datasets/scribbless/another-anime-face-dataset) and the [Tagged Anime Illustrations Dataset](https://www.kaggle.com/datasets/mylesoneill/tagged-anime-illustrations) are used, and I only use the `danbooru-images` folder from the second dataset. 150 | 151 | For the sketch model, the [Anime Sketch Colorization Pair Dataset](https://www.kaggle.com/datasets/ktaebum/anime-sketch-colorization-pair) is used. 152 | 153 | All the datasets are from [Kaggle](https://www.kaggle.com/). 154 | 155 | The following instructions are based on my datasets, but feel free to use your own if you like. 156 | 157 | ### Preprocess training data 158 | 159 | ``` 160 | cd data 161 | python prepare_data.py 162 | ``` 163 | 164 | If you are using the `Anime Sketch Colorization Pair` dataset: 165 | 166 | ``` 167 | python prepare_data_sketch.py 168 | ``` 169 | 170 | Several arguments need to be assigned: 171 | 172 | ``` 173 | usage: prepare_data.py [-h] [--out OUT] [--size SIZE] [--n_worker N_WORKER] 174 | [--resample RESAMPLE] 175 | path 176 | positional arguments: 177 | path the path of the dataset 178 | optional arguments: 179 | -h, --help show this help message and exit 180 | --out OUT the path to save the generated lmdb 181 | --size SIZE compressed image size(s); any of 128, 256, 512, 1024, comma-separated 182 | --n_worker N_WORKER the number of worker processes, depending on your CPU 183 | --resample RESAMPLE resampling filter, lanczos (default) or bilinear 184 | ``` 185 | 186 | For instance, you can run the command like this: 187 | 188 | ``` 189 | python prepare_data.py --out ../train_datasets/Sketch_train_lmdb --n_worker 20 --size 256 E:/Dataset/animefaces256cleaner 190 | ``` 191 | 192 | ### Training 193 | 194 | There are four training scripts in total: 195 | 196 | `train.py` —— train only the generator 197 | 198 | `train_disc.py` —— train only the discriminator 199 | 200 | `train_all_gray.py` —— train both the generator and the discriminator on the regular dataset 201 | 202 | `train_all_sketch.py` —— train both the generator and the discriminator, specifically on the sketch pair dataset 203 | 204 | 205 | 206 | All of these scripts are driven by similar command-line arguments: 207 | 208 | ``` 209 | usage: train_all_gray.py [-h] [--datasets DATASETS] [--iter ITER] 210 | [--batch BATCH] [--size SIZE] [--ckpt CKPT] 211 | [--ckpt_disc CKPT_DISC] [--lr LR] [--lr_disc LR_DISC] 212 | [--experiment_name EXPERIMENT_NAME] [--wandb] 213 | [--local_rank LOCAL_RANK] 214 | optional arguments: 215 | -h, --help show this help message and exit 216 | --datasets DATASETS the path of the training dataset 217 | --iter ITER total number of iterations 218 | --batch BATCH batch size 219 | --size SIZE size of the images in the dataset, usually 256 220 | --ckpt CKPT path of the pretrained generator 221 | --ckpt_disc CKPT_DISC path of the pretrained discriminator 222 | --lr LR learning rate of the generator 223 | --lr_disc LR_DISC learning rate of the discriminator 224 | --experiment_name EXPERIMENT_NAME used to name training_logs and trained models 225 | --wandb 226 | --local_rank LOCAL_RANK 227 | ``` 228 | 229 | There may be slight differences between the scripts; check the code for more details. If you want to sanity-check the generated LMDB before training, see the sketch below.
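Before launching a long training run, you may want to verify the LMDB that `prepare_data.py` / `prepare_data_sketch.py` produced. A minimal read-back sketch (the dataset path is an example, and a single 256-px size is assumed) mirrors the keys used in `data/data_loader.py`:

```
import lmdb
from io import BytesIO
from PIL import Image

# Open the LMDB written by prepare_data.py (read-only, same flags as data_loader.py).
env = lmdb.open("../train_datasets/Sketch_train_lmdb", max_readers=32, readonly=True,
                lock=False, readahead=False, meminit=False)

with env.begin(write=False) as txn:
    # 'length' stores the number of records written during preparation.
    length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
    print('records:', length)

    # Image keys follow the f'{size}-{index:05d}' pattern, e.g. '256-00000'.
    img_bytes = txn.get('256-00000'.encode('utf-8'))
    Image.open(BytesIO(img_bytes)).show()
```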
230 | 231 | 232 | 233 | For instance, you can launch training with a command like this: 234 | 235 | ``` 236 | python train_all_gray.py --batch 8 --experiment_name Color2Manga_sketch --ckpt experiments/Color2Manga_sketch/078000.pt --datasets ./train_datasets/Sketch_train_lmdb --ckpt_disc experiments/Discriminator/078000_d.pt 237 | ``` 238 | 239 | ## Work in Progress 240 | - [ ] Add an SR model instead of direct interpolation upscaling 241 | - [ ] Optimize the generator network (add L-channel information to the output, which is essential for colorizing sketches) 242 | - [ ] Improve the manga-panel-extractor (the current segmentation is not precise enough) 243 | - [ ] Develop a front-end UI and add color hints so that users can adjust the color of a specific area 244 | 245 | ## 😁Contact 246 | 247 | If you have any questions, please feel free to contact me via `shizifeng0615@outlook.com`. 248 | 249 | ## 🙌 Acknowledgement 250 | Based on https://github.com/zhaohengyuan1/Color2Embed 251 | 252 | Thanks to https://github.com/pvnieo/Manga-Panel-Extractor 253 | 254 | ## Reference 255 | 256 | [1] Zhao, Hengyuan et al. “Color2Embed: Fast Exemplar-Based Image Colorization using Color Embeddings.” (2021). 257 | 258 | [2] Isola, Phillip et al. “Image-to-Image Translation with Conditional Adversarial Networks.” *2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR)* (2016): 5967-5976. 259 | 260 | [3] Furusawa, Chie et al. “Comicolorization: semi-automatic manga colorization.” *SIGGRAPH Asia 2017 Technical Briefs* (2017): n. pag. 261 | 262 | [4] Satoshi Iizuka, Edgar Simo-Serra, and Hiroshi Ishikawa. “Let there be Color!: Joint End-to-end Learning of Global and Local Image Priors for Automatic Image Colorization with Simultaneous Classification.” ACM Transactions on Graphics (Proc. of SIGGRAPH), 35(4):110, 2016.
263 | -------------------------------------------------------------------------------- /assets/PFFB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/assets/PFFB.png -------------------------------------------------------------------------------- /assets/Pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/assets/Pipeline.png -------------------------------------------------------------------------------- /assets/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/assets/network.png -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import numpy as np 4 | import lmdb 5 | from PIL import Image 6 | from skimage import color 7 | import torch 8 | from torch.utils.data import Dataset 9 | from data.tps_transformation import tps_transform 10 | 11 | def RGB2Lab(inputs): 12 | return color.rgb2lab(inputs) 13 | 14 | def Normalize(inputs): 15 | # output l [-50,50] ab[-128,128] 16 | l = inputs[:, :, 0:1] 17 | ab = inputs[:, :, 1:3] 18 | l = l - 50 19 | # ab = ab 20 | lab = np.concatenate((l, ab), 2) 21 | 22 | return lab.astype('float32') 23 | 24 | def selfnormalize(inputs): 25 | d = torch.max(inputs) - torch.min(inputs) 26 | out = (inputs) / d 27 | return out 28 | 29 | def to_gray(inputs): 30 | img_gray = np.clip((np.concatenate((inputs[:,:,:1], inputs[:,:,:1], inputs[:,:,:1]), 2)+50)/100*255, 0, 255).astype('uint8') 31 | 32 | return img_gray 33 | 34 | def numpy2tensor(inputs): 35 | out = torch.from_numpy(inputs.transpose(2,0,1)) 36 | return out 37 | 38 | class MultiResolutionDataset(Dataset): 39 | def __init__(self, path, transform, resolution=256): 40 | self.env = lmdb.open( 41 | path, 42 | max_readers=32, 43 | readonly=True, 44 | lock=False, 45 | readahead=False, 46 | meminit=False, 47 | ) 48 | 49 | if not self.env: 50 | raise IOError('Cannot open lmdb dataset', path) 51 | 52 | with self.env.begin(write=False) as txn: 53 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 54 | 55 | self.resolution = resolution 56 | self.transform = transform 57 | 58 | def __len__(self): 59 | return self.length 60 | 61 | def __getitem__(self, index): 62 | with self.env.begin(write=False) as txn: 63 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 64 | img_bytes = txn.get(key) 65 | 66 | buffer = BytesIO(img_bytes) 67 | img = Image.open(buffer) 68 | img_src = np.array(img) # [0,255] uint8 69 | 70 | # ima_a = img_src 71 | # ima_a = ima_a.astype('uint8') 72 | # ima_a = Image.fromarray(ima_a) 73 | # ima_a.show() 74 | 75 | ## add gaussian noise 76 | noise = np.random.uniform(-5, 5, np.shape(img_src)) 77 | img_ref = np.clip(np.array(img_src) + noise, 0, 255) 78 | 79 | 80 | img_ref = tps_transform(img_ref) # [0,255] uint8 81 | img_ref = np.clip(img_ref, 0, 255) 82 | img_ref = img_ref.astype('uint8') 83 | img_ref = Image.fromarray(img_ref) 84 | img_ref = np.array(self.transform(img_ref)) # [0,255] uint8 85 | 86 | img_lab = 
Normalize(RGB2Lab(img_src)) # l [-50,50] ab [-128, 128] 87 | 88 | img = img_src.astype('float32') # [0,255] float32 RGB 89 | img_ref = img_ref.astype('float32') # [0,255] float32 RGB 90 | 91 | img = numpy2tensor(img) 92 | img_ref = numpy2tensor(img_ref) # [B, 3, 256, 256] 93 | img_lab = numpy2tensor(img_lab) 94 | 95 | return img, img_ref, img_lab 96 | 97 | -------------------------------------------------------------------------------- /data/data_loader_sketch.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import numpy as np 4 | import lmdb 5 | from PIL import Image 6 | from skimage import color 7 | import torch 8 | from torch.utils.data import Dataset 9 | from data.tps_transformation import tps_transform 10 | 11 | def RGB2Lab(inputs): 12 | return color.rgb2lab(inputs) 13 | 14 | def Normalize(inputs): 15 | # output l [-50,50] ab[-128,128] 16 | l = inputs[:, :, 0:1] 17 | ab = inputs[:, :, 1:3] 18 | l = l - 50 19 | # ab = ab 20 | lab = np.concatenate((l, ab), 2) 21 | 22 | return lab.astype('float32') 23 | 24 | def selfnormalize(inputs): 25 | d = torch.max(inputs) - torch.min(inputs) 26 | out = (inputs) / d 27 | return out 28 | 29 | def to_gray(inputs): 30 | img_gray = np.clip((np.concatenate((inputs[:,:,:1], inputs[:,:,:1], inputs[:,:,:1]), 2)+50)/100*255, 0, 255).astype('uint8') 31 | 32 | return img_gray 33 | 34 | def numpy2tensor(inputs): 35 | out = torch.from_numpy(inputs.transpose(2,0,1)) 36 | return out 37 | 38 | class MultiResolutionDataset(Dataset): 39 | def __init__(self, path, transform, resolution=256): 40 | self.env = lmdb.open( 41 | path, 42 | max_readers=32, 43 | readonly=True, 44 | lock=False, 45 | readahead=False, 46 | meminit=False, 47 | ) 48 | 49 | if not self.env: 50 | raise IOError('Cannot open lmdb dataset', path) 51 | 52 | with self.env.begin(write=False) as txn: 53 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 54 | 55 | self.resolution = resolution 56 | self.transform = transform 57 | 58 | def __len__(self): 59 | return self.length 60 | 61 | def __getitem__(self, index): 62 | with self.env.begin(write=False) as txn: 63 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 64 | img_bytes = txn.get(key) 65 | 66 | buffer = BytesIO(img_bytes) 67 | img = Image.open(buffer) 68 | img_src = np.array(img) # [0,255] uint8 69 | 70 | # ima_a = img_src 71 | # ima_a = ima_a.astype('uint8') 72 | # ima_a = Image.fromarray(ima_a) 73 | # ima_a.show() 74 | 75 | # get the left color image 76 | img_ref = img_src[:, :256] 77 | ## add gaussian noise 78 | noise = np.random.uniform(-5, 5, np.shape(img_ref)) 79 | img_ref = np.clip(np.array(img_ref) + noise, 0, 255) 80 | 81 | 82 | img_ref = tps_transform(img_ref) # [0,255] uint8 83 | img_ref = np.clip(img_ref, 0, 255) 84 | img_ref = img_ref.astype('uint8') 85 | img_ref = Image.fromarray(img_ref) 86 | img_ref = np.array(self.transform(img_ref)) # [0,255] uint8 87 | 88 | img_lab = img_src[:, :256] 89 | img_lab = Normalize(RGB2Lab(img_lab)) # l [-50,50] ab [-128, 128] 90 | 91 | img_lab_sketch = img_src[:, 256:] 92 | img_lab_sketch = Normalize(RGB2Lab(img_lab_sketch)) # l [-50,50] ab [-128, 128] 93 | 94 | img = img_src[:, :256].astype('float32') # [0,255] float32 RGB 95 | img_ref = img_ref.astype('float32') # [0,255] float32 RGB 96 | 97 | # ima_a = img 98 | # ima_a = ima_a.astype('uint8') 99 | # ima_a = Image.fromarray(ima_a) 100 | # ima_a.show() 101 | # 102 | # ima_a = img_ref 103 | # ima_a = ima_a.astype('uint8') 104 | # ima_a = 
Image.fromarray(ima_a) 105 | # ima_a.show() 106 | # 107 | # ima_a = img_lab 108 | # ima_a = ima_a.astype('uint8') 109 | # ima_a = Image.fromarray(ima_a) 110 | # ima_a.show() 111 | 112 | 113 | img = numpy2tensor(img) 114 | img_ref = numpy2tensor(img_ref) # [B, 3, 256, 256] 115 | img_lab = numpy2tensor(img_lab) 116 | img_lab_sketch = numpy2tensor(img_lab_sketch) 117 | 118 | return img, img_ref, img_lab, img_lab_sketch 119 | 120 | -------------------------------------------------------------------------------- /data/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size, resample) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format='jpeg', quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100): 24 | imgs = [] 25 | 26 | for size in sizes: 27 | imgs.append(resize_and_convert(img, size, resample, quality)) 28 | 29 | return imgs 30 | 31 | 32 | def resize_worker(img_file, sizes, resample): 33 | i, file = img_file 34 | img = Image.open(file) 35 | img = img.convert('RGB') 36 | out = resize_multiple(img, sizes=sizes, resample=resample) 37 | 38 | return i, out 39 | 40 | 41 | def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): 42 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 43 | 44 | files = sorted(dataset.imgs, key=lambda x: x[0]) 45 | # print(files) 46 | # eixt() 47 | files = [(i, file) for i, (file, label) in enumerate(files)] 48 | total = 0 49 | 50 | with multiprocessing.Pool(n_worker) as pool: 51 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 52 | for size, img in zip(sizes, imgs): 53 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') 54 | 55 | with env.begin(write=True) as txn: 56 | txn.put(key, img) 57 | 58 | total += 1 59 | 60 | with env.begin(write=True) as txn: 61 | txn.put('length'.encode('utf-8'), str(total).encode('utf-8')) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--out', type=str) 67 | parser.add_argument('--size', type=str, default='128,256,512,1024') 68 | parser.add_argument('--n_worker', type=int, default=8) 69 | parser.add_argument('--resample', type=str, default='lanczos') 70 | parser.add_argument('path', type=str) 71 | 72 | args = parser.parse_args() 73 | 74 | resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR} 75 | resample = resample_map[args.resample] 76 | 77 | sizes = [int(s.strip()) for s in args.size.split(',')] 78 | 79 | print(f'Make dataset of image sizes:', ', '.join(str(s) for s in sizes)) 80 | 81 | imgset = datasets.ImageFolder(args.path) 82 | 83 | with lmdb.open(args.out, map_size=6 * 1024 * 1024 * 1024, readahead=False) as env: 84 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 85 | -------------------------------------------------------------------------------- /data/prepare_data_sketch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io 
import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size=[256, 512], interpolation=resample) 15 | img = trans_fn.center_crop(img, output_size=[256, 512]) 16 | buffer = BytesIO() 17 | img.save(buffer, format='jpeg', quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100): 24 | imgs = [] 25 | 26 | for size in sizes: 27 | imgs.append(resize_and_convert(img, size, resample, quality)) 28 | 29 | return imgs 30 | 31 | 32 | def resize_worker(img_file, sizes, resample): 33 | i, file = img_file 34 | img = Image.open(file) 35 | img = img.convert('RGB') 36 | out = resize_multiple(img, sizes=sizes, resample=resample) 37 | 38 | return i, out 39 | 40 | 41 | def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): 42 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 43 | 44 | files = sorted(dataset.imgs, key=lambda x: x[0]) 45 | # print(files) 46 | # eixt() 47 | files = [(i, file) for i, (file, label) in enumerate(files)] 48 | total = 0 49 | 50 | with multiprocessing.Pool(n_worker) as pool: 51 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 52 | for size, img in zip(sizes, imgs): 53 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') 54 | 55 | with env.begin(write=True) as txn: 56 | txn.put(key, img) 57 | 58 | total += 1 59 | 60 | with env.begin(write=True) as txn: 61 | txn.put('length'.encode('utf-8'), str(total).encode('utf-8')) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--out', type=str) 67 | parser.add_argument('--size', type=str, default='128,256,512,1024') 68 | parser.add_argument('--n_worker', type=int, default=8) 69 | parser.add_argument('--resample', type=str, default='lanczos') 70 | parser.add_argument('path', type=str) 71 | 72 | args = parser.parse_args() 73 | 74 | resample_map = {'lanczos': Image.LANCZOS, 'bilinear': Image.BILINEAR} 75 | resample = resample_map[args.resample] 76 | 77 | sizes = [int(s.strip()) for s in args.size.split(',')] 78 | 79 | print(f'Make dataset of image sizes:', ', '.join(str(s) for s in sizes)) 80 | 81 | imgset = datasets.ImageFolder(args.path) 82 | 83 | with lmdb.open(args.out, map_size=6 * 1024 * 1024 * 1024, readahead=False) as env: 84 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 85 | -------------------------------------------------------------------------------- /data/thinplate/__init__.py: -------------------------------------------------------------------------------- 1 | from data.thinplate.numpy import * 2 | 3 | try: 4 | import torch 5 | import data.thinplate.pytorch as torch 6 | except ImportError: 7 | pass 8 | 9 | __version__ = '1.0.0' -------------------------------------------------------------------------------- /data/thinplate/numpy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 
2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import numpy as np 7 | 8 | class TPS: 9 | @staticmethod 10 | def fit(c, lambd=0., reduced=False): 11 | n = c.shape[0] 12 | 13 | U = TPS.u(TPS.d(c, c)) 14 | K = U + np.eye(n, dtype=np.float32)*lambd 15 | 16 | P = np.ones((n, 3), dtype=np.float32) 17 | P[:, 1:] = c[:, :2] 18 | 19 | v = np.zeros(n+3, dtype=np.float32) 20 | v[:n] = c[:, -1] 21 | 22 | A = np.zeros((n+3, n+3), dtype=np.float32) 23 | A[:n, :n] = K 24 | A[:n, -3:] = P 25 | A[-3:, :n] = P.T 26 | 27 | theta = np.linalg.solve(A, v) # p has structure w,a 28 | return theta[1:] if reduced else theta 29 | 30 | @staticmethod 31 | def d(a, b): 32 | return np.sqrt(np.square(a[:, None, :2] - b[None, :, :2]).sum(-1)) 33 | 34 | @staticmethod 35 | def u(r): 36 | return r**2 * np.log(r + 1e-6) 37 | 38 | @staticmethod 39 | def z(x, c, theta): 40 | x = np.atleast_2d(x) 41 | U = TPS.u(TPS.d(x, c)) 42 | w, a = theta[:-3], theta[-3:] 43 | reduced = theta.shape[0] == c.shape[0] + 2 44 | if reduced: 45 | w = np.concatenate((-np.sum(w, keepdims=True), w)) 46 | b = np.dot(U, w) 47 | return a[0] + a[1]*x[:, 0] + a[2]*x[:, 1] + b 48 | 49 | def uniform_grid(shape): 50 | '''Uniform grid coordinates. 51 | 52 | Params 53 | ------ 54 | shape : tuple 55 | HxW defining the number of height and width dimension of the grid 56 | 57 | Returns 58 | ------- 59 | points: HxWx2 tensor 60 | Grid coordinates over [0,1] normalized image range. 61 | ''' 62 | 63 | H,W = shape[:2] 64 | c = np.empty((H, W, 2)) 65 | c[..., 0] = np.linspace(0, 1, W, dtype=np.float32) 66 | c[..., 1] = np.expand_dims(np.linspace(0, 1, H, dtype=np.float32), -1) 67 | 68 | return c 69 | 70 | def tps_theta_from_points(c_src, c_dst, reduced=False): 71 | delta = c_src - c_dst 72 | 73 | cx = np.column_stack((c_dst, delta[:, 0])) 74 | cy = np.column_stack((c_dst, delta[:, 1])) 75 | 76 | theta_dx = TPS.fit(cx, reduced=reduced) 77 | theta_dy = TPS.fit(cy, reduced=reduced) 78 | 79 | return np.stack((theta_dx, theta_dy), -1) 80 | 81 | 82 | def tps_grid(theta, c_dst, dshape): 83 | ugrid = uniform_grid(dshape) 84 | 85 | reduced = c_dst.shape[0] + 2 == theta.shape[0] 86 | 87 | dx = TPS.z(ugrid.reshape((-1, 2)), c_dst, theta[:, 0]).reshape(dshape[:2]) 88 | dy = TPS.z(ugrid.reshape((-1, 2)), c_dst, theta[:, 1]).reshape(dshape[:2]) 89 | dgrid = np.stack((dx, dy), -1) 90 | 91 | grid = dgrid + ugrid 92 | 93 | return grid # H'xW'x2 grid[i,j] in range [0..1] 94 | 95 | def tps_grid_to_remap(grid, sshape): 96 | '''Convert a dense grid to OpenCV's remap compatible maps. 97 | 98 | Params 99 | ------ 100 | grid : HxWx2 array 101 | Normalized flow field coordinates as computed by compute_densegrid. 102 | sshape : tuple 103 | Height and width of source image in pixels. 104 | 105 | 106 | Returns 107 | ------- 108 | mapx : HxW array 109 | mapy : HxW array 110 | ''' 111 | 112 | mx = (grid[:, :, 0] * sshape[1]).astype(np.float32) 113 | my = (grid[:, :, 1] * sshape[0]).astype(np.float32) 114 | 115 | return mx, my -------------------------------------------------------------------------------- /data/thinplate/pytorch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Christoph Heindl. 2 | # 3 | # Licensed under MIT License 4 | # ============================================================ 5 | 6 | import torch 7 | 8 | def tps(theta, ctrl, grid): 9 | '''Evaluate the thin-plate-spline (TPS) surface at xy locations arranged in a grid. 
10 | The TPS surface is a minimum bend interpolation surface defined by a set of control points. 11 | The function value for a x,y location is given by 12 | 13 | TPS(x,y) := theta[-3] + theta[-2]*x + theta[-1]*y + \sum_t=0,T theta[t] U(x,y,ctrl[t]) 14 | 15 | This method computes the TPS value for multiple batches over multiple grid locations for 2 16 | surfaces in one go. 17 | 18 | Params 19 | ------ 20 | theta: Nx(T+3)x2 tensor, or Nx(T+2)x2 tensor 21 | Batch size N, T+3 or T+2 (reduced form) model parameters for T control points in dx and dy. 22 | ctrl: NxTx2 tensor or Tx2 tensor 23 | T control points in normalized image coordinates [0..1] 24 | grid: NxHxWx3 tensor 25 | Grid locations to evaluate with homogeneous 1 in first coordinate. 26 | 27 | Returns 28 | ------- 29 | z: NxHxWx2 tensor 30 | Function values at each grid location in dx and dy. 31 | ''' 32 | 33 | N, H, W, _ = grid.size() 34 | 35 | if ctrl.dim() == 2: 36 | ctrl = ctrl.expand(N, *ctrl.size()) 37 | 38 | T = ctrl.shape[1] 39 | 40 | diff = grid[...,1:].unsqueeze(-2) - ctrl.unsqueeze(1).unsqueeze(1) 41 | D = torch.sqrt((diff**2).sum(-1)) 42 | U = (D**2) * torch.log(D + 1e-6) 43 | 44 | w, a = theta[:, :-3, :], theta[:, -3:, :] 45 | 46 | reduced = T + 2 == theta.shape[1] 47 | if reduced: 48 | w = torch.cat((-w.sum(dim=1, keepdim=True), w), dim=1) 49 | 50 | # U is NxHxWxT 51 | b = torch.bmm(U.view(N, -1, T), w).view(N,H,W,2) 52 | # b is NxHxWx2 53 | z = torch.bmm(grid.view(N,-1,3), a).view(N,H,W,2) + b 54 | 55 | return z 56 | 57 | def tps_grid(theta, ctrl, size): 58 | '''Compute a thin-plate-spline grid from parameters for sampling. 59 | 60 | Params 61 | ------ 62 | theta: Nx(T+3)x2 tensor 63 | Batch size N, T+3 model parameters for T control points in dx and dy. 64 | ctrl: NxTx2 tensor, or Tx2 tensor 65 | T control points in normalized image coordinates [0..1] 66 | size: tuple 67 | Output grid size as NxCxHxW. C unused. This defines the output image 68 | size when sampling. 69 | 70 | Returns 71 | ------- 72 | grid : NxHxWx2 tensor 73 | Grid suitable for sampling in pytorch containing source image 74 | locations for each output pixel. 75 | ''' 76 | N, _, H, W = size 77 | 78 | grid = theta.new(N, H, W, 3) 79 | grid[:, :, :, 0] = 1. 80 | grid[:, :, :, 1] = torch.linspace(0, 1, W) 81 | grid[:, :, :, 2] = torch.linspace(0, 1, H).unsqueeze(-1) 82 | 83 | z = tps(theta, ctrl, grid) 84 | return (grid[...,1:] + z)*2-1 # [-1,1] range required by F.sample_grid 85 | 86 | def tps_sparse(theta, ctrl, xy): 87 | if xy.dim() == 2: 88 | xy = xy.expand(theta.shape[0], *xy.size()) 89 | 90 | N, M = xy.shape[:2] 91 | grid = xy.new(N, M, 3) 92 | grid[..., 0] = 1. 93 | grid[..., 1:] = xy 94 | 95 | z = tps(theta, ctrl, grid.view(N,M,1,3)) 96 | return xy + z.view(N, M, 2) 97 | 98 | def uniform_grid(shape): 99 | '''Uniformly places control points aranged in grid accross normalized image coordinates. 100 | 101 | Params 102 | ------ 103 | shape : tuple 104 | HxW defining the number of control points in height and width dimension 105 | 106 | Returns 107 | ------- 108 | points: HxWx2 tensor 109 | Control points over [0,1] normalized image range. 
110 | ''' 111 | H,W = shape[:2] 112 | c = torch.zeros(H, W, 2) 113 | c[..., 0] = torch.linspace(0, 1, W) 114 | c[..., 1] = torch.linspace(0, 1, H).unsqueeze(-1) 115 | return c 116 | 117 | if __name__ == '__main__': 118 | c = torch.tensor([ 119 | [0., 0], 120 | [1., 0], 121 | [1., 1], 122 | [0, 1], 123 | ]).unsqueeze(0) 124 | theta = torch.zeros(1, 4+3, 2) 125 | size= (1,1,6,3) 126 | print(tps_grid(theta, c, size).shape) -------------------------------------------------------------------------------- /data/thinplate/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/data/thinplate/tests/__init__.py -------------------------------------------------------------------------------- /data/thinplate/tests/test_tps_numpy.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from numpy.testing import assert_allclose 4 | import thinplate as tps 5 | 6 | def test_numpy_fit(): 7 | c = np.array([ 8 | [0., 0, 0.0], 9 | [1., 0, 0.0], 10 | [1., 1, 0.0], 11 | [0, 1, 0.0], 12 | ]) 13 | 14 | theta = tps.TPS.fit(c) 15 | assert_allclose(theta, 0) 16 | assert_allclose(tps.TPS.z(c, c, theta), c[:, 2]) 17 | 18 | c = np.array([ 19 | [0., 0, 1.0], 20 | [1., 0, 1.0], 21 | [1., 1, 1.0], 22 | [0, 1, 1.0], 23 | ]) 24 | 25 | theta = tps.TPS.fit(c) 26 | assert_allclose(theta[:-3], 0) 27 | assert_allclose(theta[-3:], [1, 0, 0]) 28 | assert_allclose(tps.TPS.z(c, c, theta), c[:, 2], atol=1e-3) 29 | 30 | # reduced form 31 | theta = tps.TPS.fit(c, reduced=True) 32 | assert len(theta) == c.shape[0] + 2 33 | assert_allclose(theta[:-3], 0) 34 | assert_allclose(theta[-3:], [1, 0, 0]) 35 | assert_allclose(tps.TPS.z(c, c, theta), c[:, 2], atol=1e-3) 36 | 37 | c = np.array([ 38 | [0., 0, -.5], 39 | [1., 0, 0.5], 40 | [1., 1, 0.2], 41 | [0, 1, 0.8], 42 | ]) 43 | 44 | theta = tps.TPS.fit(c) 45 | assert_allclose(tps.TPS.z(c, c, theta), c[:, 2], atol=1e-3) 46 | 47 | def test_numpy_densegrid(): 48 | 49 | # enlarges a small rectangle to full view 50 | 51 | import cv2 52 | 53 | img = np.zeros((40, 40), dtype=np.uint8) 54 | img[10:21, 10:21] = 255 55 | 56 | c_dst = np.array([ 57 | [0., 0], 58 | [1., 0], 59 | [1, 1], 60 | [0, 1], 61 | ]) 62 | 63 | 64 | c_src = np.array([ 65 | [10., 10], 66 | [20., 10], 67 | [20, 20], 68 | [10, 20], 69 | ]) / 40. 70 | 71 | theta = tps.tps_theta_from_points(c_src, c_dst) 72 | theta_r = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 73 | 74 | grid = tps.tps_grid(theta, c_dst, (20,20)) 75 | grid_r = tps.tps_grid(theta_r, c_dst, (20,20)) 76 | 77 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 78 | warped = cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC) 79 | 80 | assert img.min() == 0. 81 | assert img.max() == 255. 82 | assert warped.shape == (20,20) 83 | assert warped.min() == 255. 84 | assert warped.max() == 255. 
85 | assert np.linalg.norm(grid.reshape(-1,2) - grid_r.reshape(-1,2)) < 1e-3 86 | -------------------------------------------------------------------------------- /data/thinplate/tests/test_tps_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | import thinplate as tps 7 | 8 | from numpy.testing import assert_allclose 9 | 10 | def test_pytorch_grid(): 11 | 12 | c_dst = np.array([ 13 | [0., 0], 14 | [1., 0], 15 | [1, 1], 16 | [0, 1], 17 | ], dtype=np.float32) 18 | 19 | 20 | c_src = np.array([ 21 | [10., 10], 22 | [20., 10], 23 | [20, 20], 24 | [10, 20], 25 | ], dtype=np.float32) / 40. 26 | 27 | theta = tps.tps_theta_from_points(c_src, c_dst) 28 | theta_r = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 29 | 30 | np_grid = tps.tps_grid(theta, c_dst, (20,20)) 31 | np_grid_r = tps.tps_grid(theta_r, c_dst, (20,20)) 32 | 33 | pth_theta = torch.tensor(theta).unsqueeze(0) 34 | pth_grid = tps.torch.tps_grid(pth_theta, torch.tensor(c_dst), (1, 1, 20, 20)).squeeze().numpy() 35 | pth_grid = (pth_grid + 1) / 2 # convert [-1,1] range to [0,1] 36 | 37 | pth_theta_r = torch.tensor(theta_r).unsqueeze(0) 38 | pth_grid_r = tps.torch.tps_grid(pth_theta_r, torch.tensor(c_dst), (1, 1, 20, 20)).squeeze().numpy() 39 | pth_grid_r = (pth_grid_r + 1) / 2 # convert [-1,1] range to [0,1] 40 | 41 | assert_allclose(np_grid, pth_grid) 42 | assert_allclose(np_grid_r, pth_grid_r) 43 | assert_allclose(np_grid_r, np_grid) -------------------------------------------------------------------------------- /data/tps_transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import data.thinplate as tps 3 | import cv2 4 | import random 5 | import math 6 | 7 | # Reference : https://github.com/cheind/py-thin-plate-spline 8 | 9 | def tps_transform(img, dshape=None): 10 | 11 | while True: 12 | point1 = round(random.uniform(0.3, 0.7), 2) 13 | point2 = round(random.uniform(0.3, 0.7), 2) 14 | range_1 = round(random.uniform(-0.25, 0.25), 2) 15 | range_2 = round(random.uniform(-0.25, 0.25), 2) 16 | if math.isclose(point1 + range_1, point2 + range_2): 17 | continue 18 | else: 19 | break 20 | 21 | c_src = np.array([ 22 | [0.0, 0.0], 23 | [1., 0], 24 | [1, 1], 25 | [0, 1], 26 | [point1, point1], 27 | [point2, point2], 28 | ]) 29 | 30 | c_dst = np.array([ 31 | [0., 0], 32 | [1., 0], 33 | [1, 1], 34 | [0, 1], 35 | [point1 + range_1, point1 + range_1], 36 | [point2 + range_2, point2 + range_2], 37 | ]) 38 | 39 | dshape = dshape or img.shape 40 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 41 | grid = tps.tps_grid(theta, c_dst, dshape) 42 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 43 | return cv2.remap(img, mapx, mapy, cv2.INTER_CUBIC) 44 | 45 | -------------------------------------------------------------------------------- /discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | 6 | class Discriminator(nn.Module): 7 | def __init__(self, in_channels=3): 8 | super(Discriminator, self).__init__() 9 | 10 | def discriminator_block(in_filters, out_filters, normalization=True): 11 | """Returns downsampling layers of each discriminator block""" 12 | layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)] 13 | if normalization: 14 | 
layers.append(nn.InstanceNorm2d(out_filters)) 15 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 16 | return layers 17 | 18 | self.model = nn.Sequential( 19 | *discriminator_block(in_channels * 3, 64, normalization=False), 20 | *discriminator_block(64, 128), 21 | *discriminator_block(128, 256), 22 | *discriminator_block(256, 512), 23 | nn.ZeroPad2d((1, 0, 1, 0)), 24 | nn.Conv2d(512, 1, 4, padding=1, bias=False) 25 | ) 26 | 27 | def forward(self, img_out, img_l, img_ref ): 28 | # Concatenate image and condition image by channels to produce input 29 | img_input = torch.cat((img_out, img_l, img_ref), 1) 30 | return self.model(img_input) 31 | 32 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | 
dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /extractor/Open-Sans-Bold.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/extractor/Open-Sans-Bold.ttf -------------------------------------------------------------------------------- /extractor/manga_panel_extractor.py: -------------------------------------------------------------------------------- 1 | # stdlib 2 | import argparse 3 | from argparse import RawTextHelpFormatter 4 | import os 5 | from os.path import splitext, basename, exists, join 6 | from os import makedirs 7 | # 3p 8 | from tqdm import tqdm 9 | import numpy as np 10 | from skimage import measure 11 | from PIL import Image 12 | from PIL import ImageFont 13 | from PIL import ImageDraw 14 | import cv2 15 | # project 16 | from utils import get_files, load_image 17 | from skimage import io 18 | 19 | 20 | class PanelExtractor: 21 | def __init__(self, min_pct_panel=2, max_pct_panel=90, paper_th=0.35): 22 | assert min_pct_panel < max_pct_panel, "Minimum percentage must be smaller than maximum percentage" 23 | self.min_panel = min_pct_panel / 100 24 | self.max_panel = max_pct_panel / 100 25 | self.paper_th = paper_th 26 | 27 | def _generate_panel_blocks(self, img): 28 | img = img if len(img.shape) == 2 else img[:, :, 0] 29 | blur = cv2.GaussianBlur(img, (5, 5), 0) 30 | thresh = cv2.threshold(blur, 230, 255, cv2.THRESH_BINARY)[1] 31 | cv2.rectangle(thresh, (0, 0), tuple(img.shape[::-1]), (0, 0, 0), 10) 32 | num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 4, cv2.CV_32S) 33 | ind = np.argsort(stats[:, 4], )[::-1][1] 34 | panel_block_mask = ((labels == ind) * 255).astype("uint8") 35 | # Image.fromarray(panel_block_mask).show() 36 | return panel_block_mask 37 | 38 | def generate_panels(self, img): 39 | block_mask = self._generate_panel_blocks(img) 40 | cv2.rectangle(block_mask, (0, 0), tuple(block_mask.shape[::-1]), (255, 255, 255), 10) 41 | # Image.fromarray(block_mask).show() 42 | 43 | # detect contours 44 | contours, hierarchy = cv2.findContours(block_mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 45 | panels = [] 46 | masks = [] 47 | panel_masks = [] 48 | # print(len(contours)) 49 | 50 | for i in range(len(contours)): 51 | area = cv2.contourArea(contours[i]) 52 | img_area = img.shape[0] * img.shape[1] 53 | 54 | # if the contour is very small or very big, it's likely wrongly detected 55 | if area < (self.min_panel * img_area) or area > (self.max_panel * img_area): 56 | continue 57 | 58 | x, y, w, h = cv2.boundingRect(contours[i]) 59 | masks.append(cv2.boundingRect(contours[i])) 60 | # create panel mask 61 | panel_mask = np.ones_like(block_mask, "int32") 62 | cv2.fillPoly(panel_mask, [contours[i].astype("int32")], color=(0, 0, 0)) 63 | # Image.fromarray(panel_mask).show() 64 | panel_mask = panel_mask[y:y + h, x:x + w].copy() 65 | # Image.fromarray(panel_mask).show() 66 | 67 | # apply panel mask 68 | panel = img[y:y + h, x:x + w].copy() 69 | # Image.fromarray(panel).show() 70 | panel[panel_mask == 1] = 255 71 | # Image.fromarray(panel).show() 72 | 73 | panels.append(panel) 74 | panel_masks.append(panel_mask) 75 | 76 | return panels, masks, panel_masks 77 | 78 | def 
extract(self, folder): 79 | print("Loading images ... ", end="") 80 | # image_list, _, _ = get_files(folder) 81 | image_list = [] 82 | image_list.append(folder) 83 | imgs = [load_image(x) for x in image_list] 84 | print("Done!") 85 | 86 | folder = os.path.dirname(folder) 87 | # create panels dir 88 | if not exists(join(folder, "panels")): 89 | makedirs(join(folder, "panels")) 90 | folder = join(folder, "panels") 91 | 92 | # remove images with paper texture, not well segmented 93 | paperless_imgs = [] 94 | for img in tqdm(imgs, desc="Removing images with paper texture"): 95 | hist, bins = np.histogram(img.copy().ravel(), 256, [0, 256]) 96 | if np.sum(hist[50:200]) / np.sum(hist) < self.paper_th: 97 | paperless_imgs.append(img) 98 | 99 | if not paperless_imgs: 100 | return imgs, [], [] 101 | for i, img in tqdm(enumerate(paperless_imgs), desc="extracting panels"): 102 | panels, masks, panel_masks = self.generate_panels(img) 103 | name, ext = splitext(basename(image_list[i])) 104 | for j, panel in enumerate(panels): 105 | cv2.imwrite(join(folder, f'{name}_{j}.{ext}'), panel) 106 | 107 | # show the order of colorized panels 108 | img = Image.fromarray(img) 109 | draw = ImageDraw.Draw(img) 110 | font = ImageFont.truetype('extractor/Open-Sans-Bold.ttf', 160) 111 | 112 | def flatten(l): 113 | for el in l: 114 | if isinstance(el, list): 115 | yield from flatten(el) 116 | else: 117 | yield el 118 | 119 | for i, bbox in enumerate(flatten(masks), start=1): 120 | w, h = draw.textsize(str(i), font=font) 121 | y = (bbox[1] + bbox[3] / 2 - h / 2) 122 | x = (bbox[0] + bbox[2] / 2 - w / 2) 123 | draw.text((x, y), str(i), (255, 215, 0), font=font) 124 | img.show() 125 | return panels, masks, panel_masks 126 | 127 | def concatPanels(self, img_file, fake_imgs, masks, panel_masks): 128 | img = io.imread(img_file) 129 | # out_imgs.append(f"D:\MyProject\Python\DL_learning\Manga-Panel-Extractor-master\out\in0_ref0.png") 130 | # out_imgs.append(f"D:\MyProject\Python\DL_learning\Manga-Panel-Extractor-master\out\in1_ref1.png") 131 | # out_imgs.append(f"D:\MyProject\Python\DL_learning\Manga-Panel-Extractor-master\out\in2_ref2.png") 132 | for i in range(len(fake_imgs)): 133 | x, y, w, h = masks[i] 134 | # fake_img = io.imread(fake_imgs[i]) 135 | # fake_img = np.array(fake_img) 136 | fake_img = fake_imgs[i] 137 | panel_mask = panel_masks[i] 138 | img[y:y + h, x:x + w][panel_mask == 0] = fake_img[panel_mask == 0] 139 | # Image.fromarray(img).show() 140 | out_folder = os.path.dirname(img_file) 141 | out_name = os.path.basename(img_file) 142 | out_name = os.path.splitext(out_name)[0] 143 | out_img_path = os.path.join(out_folder,'color',f'{out_name}_color.png') 144 | 145 | # show image 146 | Image.fromarray(img).show() 147 | # save image 148 | folder_path = os.path.join(out_folder, 'color') 149 | if not os.path.exists(folder_path): 150 | os.mkdir(folder_path) 151 | io.imsave(out_img_path, img) 152 | 153 | 154 | def main(args): 155 | panel_extractor = PanelExtractor(min_pct_panel=args.min_panel, max_pct_panel=args.max_panel) 156 | panels, masks, panel_masks = panel_extractor.extract(args.folder) 157 | panel_extractor.concatPanels(args.folder, [], masks, panel_masks) 158 | 159 | 160 | if __name__ == "__main__": 161 | parser = argparse.ArgumentParser( 162 | description="Implementation of a Manga Panel Extractor and dialogue bubble text eraser.", 163 | formatter_class=RawTextHelpFormatter 164 | ) 165 | parser.add_argument("-minp", "--min_panel", type=int, choices=range(1, 99), default=5, metavar="[1-99]", 166 | help="Percentage 
of minimum panel area in relation to total page area.")
167 | parser.add_argument("-maxp", "--max_panel", type=int, choices=range(1, 99), default=90, metavar="[1-99]",
168 | help="Percentage of maximum panel area in relation to total page area.")
169 | parser.add_argument("-f", '--folder', default='./images/002.png', type=str,
170 | help="""folder path to input manga pages.
171 | Panels will be saved to a directory named `panels` in this folder.""")
172 | 
173 | args = parser.parse_args()
174 | main(args)
175 | 
-------------------------------------------------------------------------------- /inference.py: --------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from skimage import color, io
4 | 
5 | import torch
6 | import torch.nn.functional as F
7 | 
8 | from PIL import Image
9 | from models import ColorEncoder, ColorUNet
10 | from extractor.manga_panel_extractor import PanelExtractor
11 | import argparse
12 | 
13 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
14 | 
15 | def mkdirs(path):
16 | if not os.path.exists(path):
17 | os.makedirs(path)
18 | 
19 | def Lab2RGB_out(img_lab):
20 | img_lab = img_lab.detach().cpu()
21 | img_l = img_lab[:,:1,:,:]
22 | img_ab = img_lab[:,1:,:,:]
23 | # print(torch.max(img_l), torch.min(img_l))
24 | # print(torch.max(img_ab), torch.min(img_ab))
25 | img_l = img_l + 50
26 | pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
27 | # grid_lab = utils.make_grid(pred_lab, nrow=1).numpy().astype("float64")
28 | # print(grid_lab.shape)
29 | out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
30 | return out
31 | 
32 | def RGB2Lab(inputs):
33 | return color.rgb2lab(inputs)
34 | 
35 | def Normalize(inputs):
36 | l = inputs[:, :, 0:1]
37 | ab = inputs[:, :, 1:3]
38 | l = l - 50
39 | lab = np.concatenate((l, ab), 2)
40 | 
41 | return lab.astype('float32')
42 | 
43 | def numpy2tensor(inputs):
44 | out = torch.from_numpy(inputs.transpose(2,0,1))
45 | return out
46 | 
47 | def tensor2numpy(inputs):
48 | out = inputs[0,...].detach().cpu().numpy().transpose(1,2,0)
49 | return out
50 | 
51 | def preprocessing(inputs):
52 | # input: rgb, [0, 255], uint8
53 | img_lab = Normalize(RGB2Lab(inputs))
54 | img = np.array(inputs, 'float32') # [0, 255]
55 | img = numpy2tensor(img)
56 | img_lab = numpy2tensor(img_lab)
57 | return img.unsqueeze(0), img_lab.unsqueeze(0)
58 | 
59 | if __name__ == "__main__":
60 | device = "cuda"
61 | 
62 | # model_name = 'Color2Manga_sketch'
63 | ckpt_path = 'experiments/Color2Manga_gray/074000_gray.pt'
64 | test_dir_path = 'test_datasets/gray_test'
65 | no_extractor = False
66 | # imgs_num = len(os.listdir(test_dir_path)) // 2
67 | imgsize = 256
68 | 
69 | parser = argparse.ArgumentParser()
70 | 
71 | parser.add_argument("--path", type=str, default=None, help="path of input image")
72 | parser.add_argument("--size", type=int, default=None)
73 | parser.add_argument("--ckpt", type=str, default=None, help="path of model weights")
74 | parser.add_argument("-ne", "--no_extractor", action='store_true',
75 | help="Do not segment the manga panels.")
76 | 
77 | args = parser.parse_args()
78 | 
79 | if args.path:
80 | test_dir_path = args.path
81 | if args.size:
82 | imgsize = args.size
83 | if args.ckpt:
84 | ckpt_path = args.ckpt
85 | if args.no_extractor:
86 | no_extractor = args.no_extractor
87 | 
88 | 
89 | ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
90 | 
91 | colorEncoder = ColorEncoder().to(device)
92 | colorEncoder.load_state_dict(ckpt["colorEncoder"])
93 | colorEncoder.eval()
94 | 
95 | colorUNet = ColorUNet().to(device)
96 | colorUNet.load_state_dict(ckpt["colorUNet"])
97 | colorUNet.eval()
98 | 
99 | imgs = []
100 | imgs_lab = []
101 | 
102 | # for i in range(imgs_num):
103 | # idx = i
104 | # print('Image', idx, 'Input Image', 'in%d.JPEG'%idx, 'Ref Image', 'ref%d.JPEG'%idx)
105 | 
106 | while 1:
107 | print(f'Make sure both the manga image and the reference images are under this path: {test_dir_path}')
108 | img_path = input("Please input the name of the image to be colorized (with file extension): ")
109 | img_path = os.path.join(test_dir_path, img_path)
110 | img_name = os.path.basename(img_path)
111 | img_name = os.path.splitext(img_name)[0]
112 | 
113 | if no_extractor:
114 | ref_img_path = os.path.join(test_dir_path, input(f"{1}/{1} reference image: "))
115 | 
116 | img1 = Image.open(img_path).convert("RGB")
117 | width, height = img1.size
118 | img2 = Image.open(ref_img_path).convert("RGB")
119 | 
120 | img1, img1_lab = preprocessing(img1)
121 | img2, img2_lab = preprocessing(img2)
122 | 
123 | img1 = img1.to(device)
124 | img1_lab = img1_lab.to(device)
125 | img2 = img2.to(device)
126 | img2_lab = img2_lab.to(device)
127 | 
128 | # print('-------',torch.max(img1_lab[:,:1,:,:]), torch.min(img1_lab[:,1:,:,:]))
129 | 
130 | with torch.no_grad():
131 | img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear',
132 | recompute_scale_factor=False, align_corners=False)
133 | img1_L_resize = F.interpolate(img1_lab[:, :1, :, :] / 50., size=(imgsize, imgsize), mode='bilinear',
134 | recompute_scale_factor=False, align_corners=False)
135 | 
136 | color_vector = colorEncoder(img2_resize)
137 | 
138 | fake_ab = colorUNet((img1_L_resize, color_vector))
139 | fake_ab = F.interpolate(fake_ab * 110, size=(height, width), mode='bilinear',
140 | recompute_scale_factor=False, align_corners=False)
141 | 
142 | fake_img = torch.cat((img1_lab[:, :1, :, :], fake_ab), 1)
143 | fake_img = Lab2RGB_out(fake_img)
144 | # io.imsave(out_img_path, fake_img)
145 | 
146 | out_folder = os.path.dirname(img_path)
147 | out_name = os.path.basename(img_path)
148 | out_name = os.path.splitext(out_name)[0]
149 | out_img_path = os.path.join(out_folder, 'color', f'{out_name}_color.png')
150 | 
151 | # show image
152 | Image.fromarray(fake_img).show()
153 | # save image
154 | folder_path = os.path.join(out_folder, 'color')
155 | if not os.path.exists(folder_path):
156 | os.mkdir(folder_path)
157 | io.imsave(out_img_path, fake_img)
158 | 
159 | continue
160 | 
161 | 
162 | 
163 | # extract panels from manga
164 | panel_extractor = PanelExtractor(min_pct_panel=5, max_pct_panel=90)
165 | panels, masks, panel_masks = panel_extractor.extract(img_path)
166 | panel_num = len(panels)
167 | 
168 | ref_img_paths = []
169 | # ref_img_path = os.path.join(test_dir_path, '%03d_ref.png' % idx)
170 | print("Please enter the names of the reference images in order, following the numbered prompts on the picture.")
171 | for i in range(panel_num):
172 | ref_img_path = os.path.join(test_dir_path, input(f"{i+1}/{panel_num} reference image: "))
173 | ref_img_paths.append(ref_img_path)
174 | 
175 | 
176 | 
177 | 
178 | fake_imgs = []
179 | for i in range(panel_num):
180 | img1 = Image.fromarray(panels[i]).convert("RGB")
181 | width, height = img1.size
182 | img2 = Image.open(ref_img_paths[i]).convert("RGB")
183 | 
184 | # img1 = Image.open(img_path).convert("RGB")
185 | # width, height = img1.size
186 | # img2 = Image.open(ref_img_path).convert("RGB")
187 | 
188 | img1, img1_lab = preprocessing(img1) 189 | img2, img2_lab = preprocessing(img2) 190 | 191 | img1 = img1.to(device) 192 | img1_lab = img1_lab.to(device) 193 | img2 = img2.to(device) 194 | img2_lab = img2_lab.to(device) 195 | 196 | # print('-------',torch.max(img1_lab[:,:1,:,:]), torch.min(img1_lab[:,1:,:,:])) 197 | 198 | with torch.no_grad(): 199 | img2_resize = F.interpolate(img2 / 255., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False) 200 | img1_L_resize = F.interpolate(img1_lab[:,:1,:,:] / 50., size=(imgsize, imgsize), mode='bilinear', recompute_scale_factor=False, align_corners=False) 201 | 202 | color_vector = colorEncoder(img2_resize) 203 | 204 | fake_ab = colorUNet((img1_L_resize, color_vector)) 205 | fake_ab = F.interpolate(fake_ab*110, size=(height, width), mode='bilinear', recompute_scale_factor=False, align_corners=False) 206 | 207 | fake_img = torch.cat((img1_lab[:,:1,:,:], fake_ab), 1) 208 | fake_img = Lab2RGB_out(fake_img) 209 | # io.imsave(f'test_datasets/gray_test/panels/{i}.png', fake_img) 210 | fake_imgs.append(fake_img) 211 | 212 | if panel_num == 1: 213 | out_folder = os.path.dirname(img_path) 214 | out_name = os.path.basename(img_path) 215 | out_name = os.path.splitext(out_name)[0] 216 | out_img_path = os.path.join(out_folder,'color',f'{out_name}_color.png') 217 | 218 | # show image 219 | Image.fromarray(fake_imgs[0]).show() 220 | # save image 221 | folder_path = os.path.join(out_folder, 'color') 222 | if not os.path.exists(folder_path): 223 | os.mkdir(folder_path) 224 | io.imsave(out_img_path, fake_imgs[0]) 225 | else: 226 | panel_extractor.concatPanels(img_path, fake_imgs, masks, panel_masks) 227 | 228 | print(f'colored image has been put to: {test_dir_path}color') 229 | 230 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from vgg_model import vgg19 8 | 9 | class DoubleConv(nn.Module): 10 | """(convolution => [BN] => ReLU) * 2""" 11 | 12 | def __init__(self, in_channels, out_channels, mid_channels=None): 13 | super().__init__() 14 | if not mid_channels: 15 | mid_channels = out_channels 16 | self.double_conv = nn.Sequential( 17 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 18 | nn.BatchNorm2d(mid_channels), 19 | nn.LeakyReLU(0.1, True), 20 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 21 | nn.BatchNorm2d(out_channels), 22 | nn.LeakyReLU(0.1, True) 23 | ) 24 | 25 | def forward(self, x): 26 | x = self.double_conv(x) 27 | return x 28 | 29 | class ResBlock(nn.Module): 30 | """(convolution => [BN] => ReLU) * 2""" 31 | 32 | def __init__(self, in_channels, out_channels): 33 | super().__init__() 34 | self.bottle_conv = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 35 | self.double_conv = nn.Sequential( 36 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 37 | nn.BatchNorm2d(out_channels), 38 | nn.LeakyReLU(0.2, True), 39 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 40 | ) 41 | 42 | def forward(self, x): 43 | x = self.bottle_conv(x) 44 | x = self.double_conv(x) + x 45 | return x / math.sqrt(2) 46 | 47 | 48 | class Down(nn.Module): 49 | """Downscaling with stride conv then double conv""" 50 | 51 | def __init__(self, in_channels, out_channels): 52 | super().__init__() 53 | self.main = nn.Sequential( 54 | 
nn.Conv2d(in_channels, in_channels, 4, 2, 1), 55 | nn.LeakyReLU(0.1, True), 56 | # DoubleConv(in_channels, out_channels) 57 | ResBlock(in_channels, out_channels) 58 | ) 59 | 60 | 61 | def forward(self, x): 62 | 63 | x = self.main(x) 64 | 65 | return x 66 | 67 | class SDFT(nn.Module): 68 | 69 | def __init__(self, color_dim, channels, kernel_size = 3): 70 | super().__init__() 71 | 72 | # generate global conv weights 73 | fan_in = channels * kernel_size ** 2 74 | self.kernel_size = kernel_size 75 | self.padding = kernel_size // 2 76 | 77 | self.scale = 1 / math.sqrt(fan_in) 78 | self.modulation = nn.Conv2d(color_dim, channels, 1) 79 | self.weight = nn.Parameter( 80 | torch.randn(1, channels, channels, kernel_size, kernel_size) 81 | ) 82 | 83 | def forward(self, fea, color_style): 84 | # for global adjustation 85 | B, C, H, W = fea.size() 86 | # print(fea.shape, color_style.shape) 87 | style = self.modulation(color_style).view(B, 1, C, 1, 1) 88 | weight = self.scale * self.weight * style 89 | # demodulation 90 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 91 | weight = weight * demod.view(B, C, 1, 1, 1) 92 | 93 | weight = weight.view( 94 | B * C, C, self.kernel_size, self.kernel_size 95 | ) 96 | 97 | fea = fea.view(1, B * C, H, W) 98 | fea = F.conv2d(fea, weight, padding=self.padding, groups=B) 99 | fea = fea.view(B, C, H, W) 100 | 101 | return fea 102 | 103 | 104 | class UpBlock(nn.Module): 105 | 106 | 107 | def __init__(self, color_dim, in_channels, out_channels, kernel_size = 3, bilinear=True): 108 | super().__init__() 109 | 110 | # if bilinear, use the normal convolutions to reduce the number of channels 111 | if bilinear: 112 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) 113 | 114 | else: 115 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 116 | 117 | self.conv_cat = nn.Sequential( 118 | nn.Conv2d(in_channels // 2 + in_channels // 8, out_channels, 1, 1, 0), 119 | nn.LeakyReLU(0.2, True), 120 | nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), 121 | nn.LeakyReLU(0.2, True) 122 | ) 123 | 124 | self.conv_s = nn.Conv2d(in_channels//2, out_channels, 1, 1, 0) 125 | 126 | # generate global conv weights 127 | self.SDFT = SDFT(color_dim, out_channels, kernel_size) 128 | 129 | 130 | def forward(self, x1, x2, color_style): 131 | # print(x1.shape, x2.shape, color_style.shape) 132 | x1 = self.up(x1) 133 | x1_s = self.conv_s(x1) 134 | 135 | x = torch.cat([x1, x2[:, ::4, :, :]], dim=1) 136 | x = self.conv_cat(x) 137 | x = self.SDFT(x, color_style) 138 | 139 | x = x + x1_s #ResBlock 140 | 141 | return x 142 | 143 | 144 | class ColorEncoder(nn.Module): 145 | def __init__(self, color_dim=512): 146 | super(ColorEncoder, self).__init__() 147 | 148 | # self.vgg = vgg19(pretrained_path=None) 149 | self.vgg = vgg19() 150 | 151 | self.feature2vector = nn.Sequential( 152 | nn.Conv2d(color_dim, color_dim, 4, 2, 2), # 8x8 153 | nn.LeakyReLU(0.2, True), 154 | nn.Conv2d(color_dim, color_dim, 3, 1, 1), 155 | nn.LeakyReLU(0.2, True), 156 | nn.Conv2d(color_dim, color_dim, 4, 2, 2), # 4x4 157 | nn.LeakyReLU(0.2, True), 158 | nn.Conv2d(color_dim, color_dim, 3, 1, 1), 159 | nn.LeakyReLU(0.2, True), 160 | nn.AdaptiveAvgPool2d((1, 1)), # 1x1 161 | nn.Conv2d(color_dim, color_dim//2, 1), # linear-1 162 | nn.LeakyReLU(0.2, True), 163 | nn.Conv2d(color_dim//2, color_dim//2, 1), # linear-2 164 | nn.LeakyReLU(0.2, True), 165 | nn.Conv2d(color_dim//2, color_dim, 1), # linear-3 166 | ) 167 | 168 | self.color_dim = color_dim 169 | 170 
| def forward(self, x):
171 | # x #[0, 1] RGB
172 | vgg_fea = self.vgg(x, layer_name='relu5_2') # [B, 512, 16, 16]
173 | 
174 | x_color = self.feature2vector(vgg_fea[-1]) # [B, 512, 1, 1]
175 | 
176 | return x_color
177 | 
178 | 
179 | class ColorUNet(nn.Module):
180 | ### this model outputs ab
181 | def __init__(self, n_channels=1, n_classes=3, bilinear=True):
182 | super(ColorUNet, self).__init__()
183 | self.n_channels = n_channels
184 | self.n_classes = n_classes
185 | self.bilinear = bilinear
186 | 
187 | self.inc = DoubleConv(n_channels, 64)
188 | self.down1 = Down(64, 128)
189 | self.down2 = Down(128, 256)
190 | self.down3 = Down(256, 512)
191 | factor = 2 if bilinear else 1
192 | self.down4 = Down(512, 1024 // factor)
193 | 
194 | self.up1 = UpBlock(512, 1024, 512 // factor, 3, bilinear)
195 | self.up2 = UpBlock(512, 512, 256 // factor, 3, bilinear)
196 | self.up3 = UpBlock(512, 256, 128 // factor, 5, bilinear)
197 | self.up4 = UpBlock(512, 128, 64, 5, bilinear)
198 | self.outc = nn.Sequential(
199 | nn.Conv2d(64, 64, 3, 1, 1),
200 | nn.LeakyReLU(0.2, True),
201 | nn.Conv2d(64, 2, 3, 1, 1),
202 | nn.Tanh() # [-1,1]
203 | )
204 | 
205 | def forward(self, x):
206 | # print(torch.max(x[0]), torch.min(x[0])) #[-1, 1] gray image L
207 | # print(torch.max(x[1]), torch.min(x[1])) # color vector
208 | 
209 | x_color = x[1] # [B, 512, 1, 1]
210 | 
211 | x1 = self.inc(x[0]) # [B, 64, 256, 256]
212 | x2 = self.down1(x1) # [B, 128, 128, 128]
213 | x3 = self.down2(x2) # [B, 256, 64, 64]
214 | x4 = self.down3(x3) # [B, 512, 32, 32]
215 | x5 = self.down4(x4) # [B, 512, 16, 16]
216 | 
217 | x6 = self.up1(x5, x4, x_color) # [B, 256, 32, 32]
218 | x7 = self.up2(x6, x3, x_color) # [B, 128, 64, 64]
219 | x8 = self.up3(x7, x2, x_color) # [B, 64, 128, 128]
220 | x9 = self.up4(x8, x1, x_color) # [B, 64, 256, 256]
221 | x_ab = self.outc(x9)
222 | 
223 | return x_ab
224 | 
-------------------------------------------------------------------------------- /requirement.txt: --------------------------------------------------------------------------------
1 | numpy
2 | scikit-image
3 | opencv-python
4 | torch
5 | torchvision
6 | Pillow
7 | tqdm
-------------------------------------------------------------------------------- /test_datasets/gray_test/001_in.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/001_in.png
-------------------------------------------------------------------------------- /test_datasets/gray_test/001_ref_a.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/001_ref_a.png
-------------------------------------------------------------------------------- /test_datasets/gray_test/001_ref_b.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/001_ref_b.png
-------------------------------------------------------------------------------- /test_datasets/gray_test/002_in.jpeg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/002_in.jpeg -------------------------------------------------------------------------------- /test_datasets/gray_test/002_in_ref_a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/002_in_ref_a.jpg -------------------------------------------------------------------------------- /test_datasets/gray_test/002_in_ref_b.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/002_in_ref_b.jpeg -------------------------------------------------------------------------------- /test_datasets/gray_test/003_in.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/003_in.jpeg -------------------------------------------------------------------------------- /test_datasets/gray_test/003_in_ref_a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/003_in_ref_a.jpg -------------------------------------------------------------------------------- /test_datasets/gray_test/003_in_ref_b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/003_in_ref_b.jpg -------------------------------------------------------------------------------- /test_datasets/gray_test/004_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/004_in.png -------------------------------------------------------------------------------- /test_datasets/gray_test/004_ref_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/004_ref_1.jpg -------------------------------------------------------------------------------- /test_datasets/gray_test/004_ref_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/004_ref_2.jpg -------------------------------------------------------------------------------- /test_datasets/gray_test/005_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/005_in.png -------------------------------------------------------------------------------- 
/test_datasets/gray_test/005_ref_1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/005_ref_1.jpeg -------------------------------------------------------------------------------- /test_datasets/gray_test/005_ref_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/005_ref_2.jpg -------------------------------------------------------------------------------- /test_datasets/gray_test/005_ref_3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/005_ref_3.jpeg -------------------------------------------------------------------------------- /test_datasets/gray_test/006_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/006_in.png -------------------------------------------------------------------------------- /test_datasets/gray_test/006_ref.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/006_ref.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/001_in_color_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/001_in_color_a.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/001_in_color_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/001_in_color_b.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/002_in_color_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/002_in_color_a.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/002_in_color_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/002_in_color_b.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/003_in_color_a.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/003_in_color_a.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/003_in_color_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/003_in_color_b.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/004_in_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/004_in_color.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/005_in_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/005_in_color.png -------------------------------------------------------------------------------- /test_datasets/gray_test/out/006_in_color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/gray_test/out/006_in_color.png -------------------------------------------------------------------------------- /test_datasets/sketch_test/001_in.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/sketch_test/001_in.jpg -------------------------------------------------------------------------------- /test_datasets/sketch_test/001_ref_a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/sketch_test/001_ref_a.jpg -------------------------------------------------------------------------------- /test_datasets/sketch_test/001_ref_b.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/sketch_test/001_ref_b.jpg -------------------------------------------------------------------------------- /test_datasets/sketch_test/out/001_in_color_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/sketch_test/out/001_in_color_a.png -------------------------------------------------------------------------------- /test_datasets/sketch_test/out/001_in_color_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linshys/Example_Based_Manga_Colorization---cGAN/b7a6f09c25ff76d48983749698d0af9839fdc2e1/test_datasets/sketch_test/out/001_in_color_b.png 
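Note on the Lab convention shared by inference.py above and the training scripts below: L is centered by subtracting 50 and divided by 50 (roughly [-1, 1]) before entering ColorUNet, the ab targets are divided by 110, and the network's tanh output is multiplied by 110 again before converting back to RGB. A minimal, self-contained sketch of that round trip follows for orientation; the helper names are illustrative and not part of the repository.

import numpy as np
import torch
from skimage import color

def to_network_inputs(rgb_uint8):
    # rgb_uint8: H x W x 3, uint8 in [0, 255]
    lab = color.rgb2lab(rgb_uint8)                 # L in [0, 100], ab roughly in [-110, 110]
    lab[:, :, 0] -= 50                             # center L around 0, as in Normalize()
    lab_t = torch.from_numpy(lab.transpose(2, 0, 1)).float().unsqueeze(0)
    return lab_t[:, :1] / 50., lab_t               # L scaled to ~[-1, 1] for ColorUNet

def to_rgb(l_scaled, fake_ab):
    # fake_ab: tanh output of ColorUNet in [-1, 1]; undo both scalings before lab2rgb
    lab = torch.cat((l_scaled * 50 + 50, fake_ab * 110), 1)[0].numpy()
    rgb = color.lab2rgb(lab.transpose(1, 2, 0))    # float in [0, 1]
    return (np.clip(rgb, 0, 1) * 255).astype("uint8")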
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from skimage import color, io 8 | import torch 9 | from torch import nn, optim 10 | from torch.nn import functional as F 11 | from torch.utils import data 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | 15 | # from ColorEncoder import ColorEncoder 16 | from models import ColorEncoder, ColorUNet 17 | from vgg_model import vgg19 18 | from data.data_loader import MultiResolutionDataset 19 | 20 | from utils import tensor_lab2rgb 21 | 22 | from distributed import ( 23 | get_rank, 24 | synchronize, 25 | reduce_loss_dict, 26 | ) 27 | 28 | def mkdirss(dirpath): 29 | if not os.path.exists(dirpath): 30 | os.makedirs(dirpath) 31 | 32 | def data_sampler(dataset, shuffle, distributed): 33 | if distributed: 34 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 35 | 36 | if shuffle: 37 | return data.RandomSampler(dataset) 38 | 39 | else: 40 | return data.SequentialSampler(dataset) 41 | 42 | 43 | def requires_grad(model, flag=True): 44 | for p in model.parameters(): 45 | p.requires_grad = flag 46 | 47 | 48 | def sample_data(loader): 49 | while True: 50 | for batch in loader: 51 | yield batch 52 | 53 | def Lab2RGB_out(img_lab): 54 | img_lab = img_lab.detach().cpu() 55 | img_l = img_lab[:,:1,:,:] 56 | img_ab = img_lab[:,1:,:,:] 57 | # print(torch.max(img_l), torch.min(img_l)) 58 | # print(torch.max(img_ab), torch.min(img_ab)) 59 | img_l = img_l + 50 60 | pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy() 61 | # grid_lab = utils.make_grid(pred_lab, nrow=1).numpy().astype("float64") 62 | # print(grid_lab.shape) 63 | out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8") 64 | return out 65 | 66 | def RGB2Lab(inputs): 67 | # input [0, 255] uint8 68 | # out l: [0, 100], ab: [-110, 110], float32 69 | return color.rgb2lab(inputs) 70 | 71 | def Normalize(inputs): 72 | l = inputs[:, :, 0:1] 73 | ab = inputs[:, :, 1:3] 74 | l = l - 50 75 | lab = np.concatenate((l, ab), 2) 76 | 77 | return lab.astype('float32') 78 | 79 | def numpy2tensor(inputs): 80 | out = torch.from_numpy(inputs.transpose(2,0,1)) 81 | return out 82 | 83 | def tensor2numpy(inputs): 84 | out = inputs[0,...].detach().cpu().numpy().transpose(1,2,0) 85 | return out 86 | 87 | def preprocessing(inputs): 88 | # input: rgb, [0, 255], uint8 89 | img_lab = Normalize(RGB2Lab(inputs)) 90 | img = np.array(inputs, 'float32') # [0, 255] 91 | img = numpy2tensor(img) 92 | img_lab = numpy2tensor(img_lab) 93 | return img.unsqueeze(0), img_lab.unsqueeze(0) 94 | 95 | def uncenter_l(inputs): 96 | l = inputs[:,:1,:,:] + 50 97 | ab = inputs[:,1:,:,:] 98 | return torch.cat((l, ab), 1) 99 | 100 | def train( 101 | args, 102 | loader, 103 | colorEncoder, 104 | colorUNet, 105 | vggnet, 106 | g_optim, 107 | device, 108 | ): 109 | loader = sample_data(loader) 110 | 111 | pbar = range(args.iter) 112 | 113 | if get_rank() == 0: 114 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 115 | 116 | g_loss_val = 0 117 | loss_dict = {} 118 | recon_val_all = 0 119 | fea_val_all = 0 120 | 121 | if args.distributed: 122 | colorEncoder_module = colorEncoder.module 123 | colorUNet_module = colorUNet.module 124 | 125 | else: 126 | colorEncoder_module = colorEncoder 127 | colorUNet_module = colorUNet 128 | 129 | for idx in 
pbar: 130 | i = idx + args.start_iter+1 131 | 132 | if i > args.iter: 133 | print("Done!") 134 | 135 | break 136 | 137 | img, img_ref, img_lab = next(loader) 138 | 139 | # ima = img_ref.numpy() 140 | # ima = ima[0].astype('uint8') 141 | # ima = Image.fromarray(ima.transpose(1,2,0)) 142 | # ima.show() 143 | 144 | img = img.to(device) # GT [B, 3, 256, 256] 145 | img_lab = img_lab.to(device) # GT 146 | 147 | img_ref = img_ref.to(device) # tps_transformed image RGB [B, 3, 256, 256] 148 | 149 | img_l = img_lab[:,:1,:,:] / 50 # [-1, 1] target L 150 | img_ab = img_lab[:,1:,:,:] / 110 # [-1, 1] target ab 151 | # img_ref_ab = img_ref_lab[:,1:,:,:] / 110 # [-1, 1] ref ab 152 | 153 | colorEncoder.train() 154 | colorUNet.train() 155 | 156 | requires_grad(colorEncoder, True) 157 | requires_grad(colorUNet, True) 158 | 159 | ref_color_vector = colorEncoder(img_ref / 255.) 160 | 161 | fake_swap_ab = colorUNet((img_l, ref_color_vector)) # [-1, 1] 162 | 163 | ## recon l1 loss 164 | recon_loss = (F.smooth_l1_loss(fake_swap_ab, img_ab)) * 1 165 | 166 | ## feature loss 167 | real_img_rgb = img / 255. 168 | features_A = vggnet(real_img_rgb, layer_name='all') 169 | 170 | fake_swap_rgb = tensor_lab2rgb(torch.cat((img_l*50+50, fake_swap_ab*110), 1)) # [0, 1] 171 | features_B = vggnet(fake_swap_rgb, layer_name='all') 172 | # fea_loss = F.l1_loss(features_A[-1], features_B[-1]) * 0.1 173 | # fea_loss = 0 174 | 175 | fea_loss1 = F.l1_loss(features_A[0], features_B[0]) / 32 * 0.1 176 | fea_loss2 = F.l1_loss(features_A[1], features_B[1]) / 16 * 0.1 177 | fea_loss3 = F.l1_loss(features_A[2], features_B[2]) / 8 * 0.1 178 | fea_loss4 = F.l1_loss(features_A[3], features_B[3]) / 4 * 0.1 179 | fea_loss5 = F.l1_loss(features_A[4], features_B[4]) * 0.1 180 | 181 | fea_loss = fea_loss1 + fea_loss2 + fea_loss3 + fea_loss4 + fea_loss5 182 | 183 | loss_dict["recon"] = recon_loss 184 | 185 | loss_dict["fea"] = fea_loss 186 | 187 | g_optim.zero_grad() 188 | (recon_loss+fea_loss).backward() 189 | g_optim.step() 190 | 191 | loss_reduced = reduce_loss_dict(loss_dict) 192 | 193 | 194 | recon_val = loss_reduced["recon"].mean().item() 195 | recon_val_all += recon_val 196 | # recon_val = 0 197 | fea_val = loss_reduced["fea"].mean().item() 198 | fea_val_all += fea_val 199 | # fea_val = 0 200 | 201 | if get_rank() == 0: 202 | pbar.set_description( 203 | ( 204 | f"recon:{recon_val:.4f}; fea:{fea_val:.4f};" 205 | ) 206 | ) 207 | 208 | 209 | if i % 50 == 0: 210 | print(f"recon_all:{recon_val_all/50:.4f}; fea_all:{fea_val_all/50:.4f};") 211 | recon_val_all = 0 212 | fea_val_all = 0 213 | 214 | if i % 500 == 0: 215 | with torch.no_grad(): 216 | colorEncoder.eval() 217 | colorUNet.eval() 218 | 219 | imgsize = 256 220 | for inum in range(15): 221 | val_img_path = 'test_datasets/val_Manga/in%d.jpg' % (inum + 1) 222 | val_ref_path = 'test_datasets/val_Manga/ref%d.jpg' % (inum + 1) 223 | # val_img_path = 'test_datasets/val_daytime/day_sample/in%d.jpg'%(inum+1) 224 | # val_ref_path = 'test_datasets/val_daytime/night_sample/dark4.jpg' 225 | out_name = 'in%d_ref%d.png'%(inum+1, inum+1) 226 | val_img = Image.open(val_img_path).convert("RGB").resize((imgsize, imgsize)) 227 | val_img_ref = Image.open(val_ref_path).convert("RGB").resize((imgsize, imgsize)) 228 | val_img, val_img_lab = preprocessing(val_img) 229 | val_img_ref, val_img_ref_lab = preprocessing(val_img_ref) 230 | 231 | # val_img = val_img.to(device) 232 | val_img_lab = val_img_lab.to(device) 233 | val_img_ref = val_img_ref.to(device) 234 | # val_img_ref_lab = val_img_ref_lab.to(device) 235 | 236 
| val_img_l = val_img_lab[:,:1,:,:] / 50. # [-1, 1] 237 | # val_img_ref_ab = val_img_ref_lab[:,1:,:,:] / 110. # [-1, 1] 238 | 239 | ref_color_vector = colorEncoder(val_img_ref / 255.) # [0, 1] 240 | fake_swap_ab = colorUNet((val_img_l, ref_color_vector)) 241 | 242 | fake_img = torch.cat((val_img_l*50, fake_swap_ab*110), 1) 243 | 244 | sample = np.concatenate((tensor2numpy(val_img), tensor2numpy(val_img_ref), Lab2RGB_out(fake_img)), 1) 245 | 246 | out_dir = 'training_logs/%s/%06d'%(args.experiment_name, i) 247 | mkdirss(out_dir) 248 | io.imsave('%s/%s'%(out_dir, out_name), sample.astype('uint8')) 249 | torch.cuda.empty_cache() 250 | if i % 2500 == 0: 251 | out_dir = "experiments/%s"%(args.experiment_name) 252 | mkdirss(out_dir) 253 | torch.save( 254 | { 255 | "colorEncoder": colorEncoder_module.state_dict(), 256 | "colorUNet": colorUNet_module.state_dict(), 257 | "g_optim": g_optim.state_dict(), 258 | "args": args, 259 | }, 260 | f"%s/{str(i).zfill(6)}.pt"%(out_dir), 261 | ) 262 | 263 | 264 | if __name__ == "__main__": 265 | device = "cuda" 266 | 267 | torch.backends.cudnn.benchmark = True 268 | 269 | parser = argparse.ArgumentParser() 270 | 271 | parser.add_argument("--datasets", type=str) 272 | parser.add_argument("--iter", type=int, default=100000) 273 | parser.add_argument("--batch", type=int, default=16) 274 | parser.add_argument("--size", type=int, default=256) 275 | parser.add_argument("--ckpt", type=str, default=None) 276 | parser.add_argument("--lr", type=float, default=0.0001) 277 | parser.add_argument("--experiment_name", type=str, default="default") 278 | parser.add_argument("--wandb", action="store_true") 279 | parser.add_argument("--local_rank", type=int, default=0) 280 | 281 | args = parser.parse_args() 282 | 283 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 284 | args.distributed = n_gpu > 1 285 | 286 | if args.distributed: 287 | torch.cuda.set_device(args.local_rank) 288 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 289 | synchronize() 290 | 291 | args.start_iter = 0 292 | 293 | vggnet = vgg19(pretrained_path = './experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad = False) 294 | vggnet = vggnet.to(device) 295 | vggnet.eval() 296 | 297 | colorEncoder = ColorEncoder(color_dim=512).to(device) 298 | colorUNet = ColorUNet(bilinear=True).to(device) 299 | 300 | 301 | g_optim = optim.Adam( 302 | list(colorEncoder.parameters()) + list(colorUNet.parameters()), 303 | lr=args.lr, 304 | betas=(0.9, 0.99), 305 | ) 306 | 307 | if args.ckpt is not None: 308 | print("load model:", args.ckpt) 309 | 310 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 311 | 312 | try: 313 | ckpt_name = os.path.basename(args.ckpt) 314 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 315 | 316 | except ValueError: 317 | pass 318 | 319 | colorEncoder.load_state_dict(ckpt["colorEncoder"]) 320 | colorUNet.load_state_dict(ckpt["colorUNet"]) 321 | g_optim.load_state_dict(ckpt["g_optim"]) 322 | 323 | # print(args.distributed) 324 | 325 | if args.distributed: 326 | colorEncoder = nn.parallel.DistributedDataParallel( 327 | colorEncoder, 328 | device_ids=[args.local_rank], 329 | output_device=args.local_rank, 330 | broadcast_buffers=False, 331 | ) 332 | 333 | colorUNet = nn.parallel.DistributedDataParallel( 334 | colorUNet, 335 | device_ids=[args.local_rank], 336 | output_device=args.local_rank, 337 | broadcast_buffers=False, 338 | ) 339 | 340 | 341 | transform = transforms.Compose( 342 | [ 343 | 
transforms.RandomHorizontalFlip(), 344 | transforms.RandomVerticalFlip(), 345 | transforms.RandomRotation(degrees=(0, 360)) 346 | ] 347 | ) 348 | 349 | datasets = [] 350 | dataset = MultiResolutionDataset(args.datasets, transform, args.size) 351 | datasets.append(dataset) 352 | 353 | loader = data.DataLoader( 354 | data.ConcatDataset(datasets), 355 | batch_size=args.batch, 356 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 357 | drop_last=True, 358 | ) 359 | 360 | train( 361 | args, 362 | loader, 363 | colorEncoder, 364 | colorUNet, 365 | vggnet, 366 | g_optim, 367 | device, 368 | ) 369 | 370 | -------------------------------------------------------------------------------- /train_all_gray.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import re 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from skimage import color, io 9 | import torch 10 | from torch import nn, optim 11 | from torch.nn import functional as F 12 | from torch.utils import data 13 | from torchvision import transforms 14 | from tqdm import tqdm 15 | from torch.autograd import Variable 16 | 17 | # from ColorEncoder import ColorEncoder 18 | from models import ColorEncoder, ColorUNet 19 | from vgg_model import vgg19 20 | from discriminator import Discriminator 21 | from data.data_loader import MultiResolutionDataset 22 | 23 | from utils import tensor_lab2rgb 24 | 25 | from distributed import ( 26 | get_rank, 27 | synchronize, 28 | reduce_loss_dict, 29 | ) 30 | 31 | 32 | def mkdirss(dirpath): 33 | if not os.path.exists(dirpath): 34 | os.makedirs(dirpath) 35 | 36 | 37 | def data_sampler(dataset, shuffle, distributed): 38 | if distributed: 39 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 40 | 41 | if shuffle: 42 | return data.RandomSampler(dataset) 43 | 44 | else: 45 | return data.SequentialSampler(dataset) 46 | 47 | 48 | def requires_grad(model, flag=True): 49 | for p in model.parameters(): 50 | p.requires_grad = flag 51 | 52 | 53 | def sample_data(loader): 54 | while True: 55 | for batch in loader: 56 | yield batch 57 | 58 | 59 | def Lab2RGB_out(img_lab): 60 | img_lab = img_lab.detach().cpu() 61 | img_l = img_lab[:, :1, :, :] 62 | img_ab = img_lab[:, 1:, :, :] 63 | # print(torch.max(img_l), torch.min(img_l)) 64 | # print(torch.max(img_ab), torch.min(img_ab)) 65 | img_l = img_l + 50 66 | pred_lab = torch.cat((img_l, img_ab), 1)[0, ...].numpy() 67 | # grid_lab = utils.make_grid(pred_lab, nrow=1).numpy().astype("float64") 68 | # print(grid_lab.shape) 69 | out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8") 70 | return out 71 | 72 | 73 | def RGB2Lab(inputs): 74 | # input [0, 255] uint8 75 | # out l: [0, 100], ab: [-110, 110], float32 76 | return color.rgb2lab(inputs) 77 | 78 | 79 | def Normalize(inputs): 80 | l = inputs[:, :, 0:1] 81 | ab = inputs[:, :, 1:3] 82 | l = l - 50 83 | lab = np.concatenate((l, ab), 2) 84 | 85 | return lab.astype('float32') 86 | 87 | 88 | def numpy2tensor(inputs): 89 | out = torch.from_numpy(inputs.transpose(2, 0, 1)) 90 | return out 91 | 92 | 93 | def tensor2numpy(inputs): 94 | out = inputs[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 95 | return out 96 | 97 | 98 | def preprocessing(inputs): 99 | # input: rgb, [0, 255], uint8 100 | img_lab = Normalize(RGB2Lab(inputs)) 101 | img = np.array(inputs, 'float32') # [0, 255] 102 | img = numpy2tensor(img) 103 | img_lab = numpy2tensor(img_lab) 104 | return img.unsqueeze(0), 
img_lab.unsqueeze(0) 105 | 106 | 107 | def uncenter_l(inputs): 108 | l = inputs[:, :1, :, :] + 50 109 | ab = inputs[:, 1:, :, :] 110 | return torch.cat((l, ab), 1) 111 | 112 | 113 | def train( 114 | args, 115 | loader, 116 | colorEncoder, 117 | colorUNet, 118 | discriminator, 119 | vggnet, 120 | g_optim, 121 | d_optim, 122 | device, 123 | ): 124 | loader = sample_data(loader) 125 | 126 | pbar = range(args.iter) 127 | 128 | if get_rank() == 0: 129 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 130 | 131 | g_loss_val = 0 132 | loss_dict = {} 133 | recon_val_all = 0 134 | fea_val_all = 0 135 | disc_val_all = 0 136 | disc_val_GAN_all = 0 137 | disc_val = 0 138 | count = 0 139 | criterion_GAN = torch.nn.MSELoss().to(device) 140 | 141 | # Calculate output of image discriminator (PatchGAN) 142 | patch = (1, args.size // 2 ** 4, args.size // 2 ** 4) 143 | Tensor = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor 144 | 145 | colorEncoder_module = colorEncoder 146 | colorUNet_module = colorUNet 147 | 148 | for idx in pbar: 149 | i = idx + args.start_iter+1 150 | 151 | if i > args.iter: 152 | print("Done!") 153 | 154 | break 155 | 156 | img, img_ref, img_lab = next(loader) 157 | 158 | # Adversarial ground truths 159 | valid = Variable(Tensor(np.ones((img.size(0), *patch))), requires_grad=False) 160 | fake = Variable(Tensor(np.zeros((img.size(0), *patch))), requires_grad=False) 161 | 162 | img = img.to(device) # GT [B, 3, 256, 256] 163 | img_lab = img_lab.to(device) # GT 164 | 165 | img_ref = img_ref.to(device) # tps_transformed image RGB [B, 3, 256, 256] 166 | 167 | img_l = img_lab[:, :1, :, :] / 50 # [-1, 1] target L 168 | img_ab = img_lab[:, 1:, :, :] / 110 # [-1, 1] target ab 169 | # img_ref_ab = img_ref_lab[:,1:,:,:] / 110 # [-1, 1] ref ab 170 | 171 | colorEncoder.train() 172 | colorUNet.train() 173 | discriminator.train() 174 | 175 | requires_grad(colorEncoder, True) 176 | requires_grad(colorUNet, True) 177 | requires_grad(discriminator, True) 178 | 179 | # ------------------ 180 | # Train Generators 181 | # ------------------ 182 | 183 | ref_color_vector = colorEncoder(img_ref / 255.) 184 | 185 | fake_swap_ab = colorUNet((img_l, ref_color_vector)) # [-1, 1] 186 | 187 | ## recon l1 loss 188 | recon_loss = (F.smooth_l1_loss(fake_swap_ab, img_ab)) 189 | 190 | ## feature loss 191 | real_img_rgb = img / 255. 192 | features_A = vggnet(real_img_rgb, layer_name='all') 193 | 194 | fake_swap_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, fake_swap_ab * 110), 1)) # [0, 1] 195 | features_B = vggnet(fake_swap_rgb, layer_name='all') 196 | # fea_loss = F.l1_loss(features_A[-1], features_B[-1]) * 0.1 197 | # fea_loss = 0 198 | 199 | fea_loss1 = F.l1_loss(features_A[0], features_B[0]) / 32 * 0.1 200 | fea_loss2 = F.l1_loss(features_A[1], features_B[1]) / 16 * 0.1 201 | fea_loss3 = F.l1_loss(features_A[2], features_B[2]) / 8 * 0.1 202 | fea_loss4 = F.l1_loss(features_A[3], features_B[3]) / 4 * 0.1 203 | fea_loss5 = F.l1_loss(features_A[4], features_B[4]) * 0.1 204 | 205 | fea_loss = fea_loss1 + fea_loss2 + fea_loss3 + fea_loss4 + fea_loss5 206 | 207 | ## discriminator loss 208 | real_img_rgb = img / 255. 209 | img_ref_rgb = img_ref / 255. 
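# Note: the conditional PatchGAN input assembled below is made of three RGB images:
# the colorized result (fake_swap_rgb), the input page re-rendered with zero ab (i.e. grayscale),
# and the color reference. Discriminator.forward concatenates them along the channel axis,
# which is why the first block in discriminator.py takes in_channels * 3.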
210 | zero_ab_image = torch.zeros_like(fake_swap_ab) 211 | input_img_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, zero_ab_image), 1)) # [0, 1] 212 | 213 | pred_fake = discriminator(fake_swap_rgb, input_img_rgb, img_ref_rgb) 214 | disc_loss_GAN = criterion_GAN(pred_fake, valid) 215 | disc_loss_GAN = disc_loss_GAN*0.01 216 | 217 | loss_dict["recon"] = recon_loss 218 | 219 | loss_dict["fea"] = fea_loss 220 | 221 | loss_dict["disc_loss_GAN"] = disc_loss_GAN 222 | 223 | g_optim.zero_grad() 224 | (recon_loss + fea_loss + disc_loss_GAN).backward() 225 | g_optim.step() 226 | 227 | # --------------------- 228 | # Train Discriminator 229 | # --------------------- 230 | # if the disc_loss_GAN<0.003, then start to train Discriminator 231 | # if i%3 == 0: 232 | 233 | # Real loss 234 | pred_real = discriminator(real_img_rgb, input_img_rgb, img_ref_rgb) 235 | loss_real = criterion_GAN(pred_real, valid) 236 | 237 | # Fake loss 238 | pred_fake = discriminator(fake_swap_rgb.detach(), input_img_rgb, img_ref_rgb) 239 | loss_fake = criterion_GAN(pred_fake, fake) 240 | 241 | # Total loss 242 | disc_loss = 0.5 * (loss_real + loss_fake) 243 | 244 | d_optim.zero_grad() 245 | disc_loss.backward() 246 | d_optim.step() 247 | 248 | # loss for discriminator itself 249 | disc_val = disc_loss.mean().item() 250 | disc_val_all += disc_val 251 | count +=1 252 | 253 | # -------------- 254 | # Log Progress 255 | # -------------- 256 | 257 | loss_reduced = reduce_loss_dict(loss_dict) 258 | 259 | recon_val = loss_reduced["recon"].mean().item() 260 | recon_val_all += recon_val 261 | # recon_val = 0 262 | fea_val = loss_reduced["fea"].mean().item() 263 | fea_val_all += fea_val 264 | # fea_val = 0 265 | 266 | # loss for generator 267 | disc_val_GAN = loss_reduced["disc_loss_GAN"].mean().item() 268 | disc_val_GAN_all += disc_val_GAN 269 | 270 | 271 | 272 | if get_rank() == 0: 273 | pbar.set_description( 274 | ( 275 | f"recon:{recon_val:.4f}; fea:{fea_val:.4f}; disc_GAN:{disc_val_GAN:.4f}; discriminator:{disc_val:.4f};" 276 | ) 277 | ) 278 | 279 | if i % 100 == 0: 280 | if disc_val_all!=0: 281 | disc_val_all = disc_val_all/count 282 | print(f"recon_all:{recon_val_all / 100:.4f}; fea_all:{fea_val_all / 100:.4f}; disc_GAN_all:{disc_val_GAN_all / 100:.4f};discriminator:{disc_val_all:.4f};") 283 | recon_val_all = 0 284 | fea_val_all = 0 285 | disc_val_GAN_all = 0 286 | disc_val_all = 0 287 | count = 0 288 | 289 | # this code is for model validation, you should prepare you own val dataset and edit code to use it 290 | # if i % 250 == 0: 291 | # with torch.no_grad(): 292 | # colorEncoder.eval() 293 | # colorUNet.eval() 294 | # 295 | # imgsize = 256 296 | # for inum in range(15): 297 | # val_img_path = 'test_datasets/val_Manga/in%d.jpg' % (inum + 1) 298 | # val_ref_path = 'test_datasets/val_Manga/ref%d.jpg' % (inum + 1) 299 | # # val_img_path = 'test_datasets/val_daytime/day_sample/in%d.jpg'%(inum+1) 300 | # # val_ref_path = 'test_datasets/val_daytime/night_sample/dark4.jpg' 301 | # out_name = 'in%d_ref%d.png' % (inum + 1, inum + 1) 302 | # val_img = Image.open(val_img_path).convert("RGB").resize((imgsize, imgsize)) 303 | # val_img_ref = Image.open(val_ref_path).convert("RGB").resize((imgsize, imgsize)) 304 | # val_img, val_img_lab = preprocessing(val_img) 305 | # val_img_ref, val_img_ref_lab = preprocessing(val_img_ref) 306 | # 307 | # # val_img = val_img.to(device) 308 | # val_img_lab = val_img_lab.to(device) 309 | # val_img_ref = val_img_ref.to(device) 310 | # # val_img_ref_lab = val_img_ref_lab.to(device) 311 | # 312 | # val_img_l 
= val_img_lab[:, :1, :, :] / 50. # [-1, 1] 313 | # # val_img_ref_ab = val_img_ref_lab[:,1:,:,:] / 110. # [-1, 1] 314 | # 315 | # ref_color_vector = colorEncoder(val_img_ref / 255.) # [0, 1] 316 | # fake_swap_ab = colorUNet((val_img_l, ref_color_vector)) 317 | # 318 | # fake_img = torch.cat((val_img_l * 50, fake_swap_ab * 110), 1) 319 | # 320 | # sample = np.concatenate( 321 | # (tensor2numpy(val_img), tensor2numpy(val_img_ref), Lab2RGB_out(fake_img)), 1) 322 | # 323 | # out_dir = 'training_logs/%s/%06d' % (args.experiment_name, i) 324 | # mkdirss(out_dir) 325 | # io.imsave('%s/%s' % (out_dir, out_name), sample.astype('uint8')) 326 | # torch.cuda.empty_cache() 327 | if i % 2000 == 0: 328 | out_dir_g = "experiments/%s" % (args.experiment_name) 329 | mkdirss(out_dir_g) 330 | torch.save( 331 | { 332 | "colorEncoder": colorEncoder_module.state_dict(), 333 | "colorUNet": colorUNet_module.state_dict(), 334 | "g_optim": g_optim.state_dict(), 335 | "args": args, 336 | }, 337 | f"%s/{str(i).zfill(6)}_gray.pt" % (out_dir_g), 338 | ) 339 | out_dir_d = "experiments/Discriminator" 340 | mkdirss(out_dir_d) 341 | torch.save( 342 | { 343 | "discriminator": discriminator.state_dict(), 344 | "d_optim": d_optim.state_dict(), 345 | "args": args, 346 | }, 347 | f"%s/{str(i).zfill(6)}_d.pt" % (out_dir_d), 348 | ) 349 | 350 | 351 | if __name__ == "__main__": 352 | device = "cuda" 353 | 354 | torch.backends.cudnn.benchmark = True 355 | 356 | parser = argparse.ArgumentParser() 357 | 358 | parser.add_argument("--datasets", type=str) 359 | parser.add_argument("--iter", type=int, default=100000) 360 | parser.add_argument("--batch", type=int, default=16) 361 | parser.add_argument("--size", type=int, default=256) 362 | parser.add_argument("--ckpt", type=str, default=None) 363 | parser.add_argument("--ckpt_disc", type=str, default=None) 364 | parser.add_argument("--lr", type=float, default=0.0001) 365 | parser.add_argument("--lr_disc", type=float, default=0.0002) 366 | parser.add_argument("--experiment_name", type=str, default="default") 367 | parser.add_argument("--wandb", action="store_true") 368 | parser.add_argument("--local_rank", type=int, default=0) 369 | 370 | args = parser.parse_args() 371 | 372 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 373 | args.distributed = n_gpu > 1 374 | 375 | args.start_iter = 0 376 | 377 | vggnet = vgg19(pretrained_path='./experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad=False) 378 | vggnet = vggnet.to(device) 379 | vggnet.eval() 380 | 381 | colorEncoder = ColorEncoder(color_dim=512).to(device) 382 | colorUNet = ColorUNet(bilinear=True).to(device) 383 | discriminator = Discriminator(in_channels=3).to(device) 384 | 385 | g_optim = optim.Adam( 386 | list(colorEncoder.parameters()) + list(colorUNet.parameters()), 387 | lr=args.lr, 388 | betas=(0.9, 0.99), 389 | ) 390 | 391 | d_optim = optim.Adam( 392 | discriminator.parameters(), 393 | lr=args.lr_disc, 394 | betas=(0.5, 0.999), 395 | ) 396 | 397 | if args.ckpt is not None: 398 | print("load model:", args.ckpt) 399 | 400 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 401 | 402 | try: 403 | ckpt_name = os.path.basename(args.ckpt) 404 | match = re.search(r'\d+', ckpt_name) 405 | if match: 406 | args.start_iter = int(match.group(0)) 407 | else: 408 | args.start_iter = 0 409 | except ValueError: 410 | pass 411 | 412 | colorEncoder.load_state_dict(ckpt["colorEncoder"]) 413 | colorUNet.load_state_dict(ckpt["colorUNet"]) 414 | g_optim.load_state_dict(ckpt["g_optim"]) 415 | 416 | if 
args.ckpt_disc is not None: 417 | print("load discriminator model:", args.ckpt_disc) 418 | 419 | ckpt_disc = torch.load(args.ckpt_disc, map_location=lambda storage, loc: storage) 420 | discriminator.load_state_dict(ckpt_disc["discriminator"]) 421 | d_optim.load_state_dict(ckpt_disc["d_optim"]) 422 | # print(args.distributed) 423 | 424 | transform = transforms.Compose( 425 | [ 426 | transforms.RandomHorizontalFlip(), 427 | # transforms.RandomVerticalFlip(), 428 | transforms.RandomRotation(degrees=(-90, 90)) 429 | ] 430 | ) 431 | 432 | datasets = [] 433 | dataset = MultiResolutionDataset(args.datasets, transform, args.size) 434 | datasets.append(dataset) 435 | 436 | loader = data.DataLoader( 437 | data.ConcatDataset(datasets), 438 | batch_size=args.batch, 439 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 440 | drop_last=True, 441 | ) 442 | 443 | train( 444 | args, 445 | loader, 446 | colorEncoder, 447 | colorUNet, 448 | discriminator, 449 | vggnet, 450 | g_optim, 451 | d_optim, 452 | device, 453 | ) 454 | -------------------------------------------------------------------------------- /train_all_sketch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | import re 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from skimage import color, io 9 | import torch 10 | from torch import nn, optim 11 | from torch.nn import functional as F 12 | from torch.utils import data 13 | from torchvision import transforms 14 | from tqdm import tqdm 15 | from torch.autograd import Variable 16 | 17 | # from ColorEncoder import ColorEncoder 18 | from models import ColorEncoder, ColorUNet 19 | from vgg_model import vgg19 20 | from discriminator import Discriminator 21 | # from data.data_loader import MultiResolutionDataset 22 | from data.data_loader_sketch import MultiResolutionDataset 23 | 24 | from utils import tensor_lab2rgb 25 | 26 | from distributed import ( 27 | get_rank, 28 | synchronize, 29 | reduce_loss_dict, 30 | ) 31 | 32 | 33 | def mkdirss(dirpath): 34 | if not os.path.exists(dirpath): 35 | os.makedirs(dirpath) 36 | 37 | 38 | def data_sampler(dataset, shuffle, distributed): 39 | if distributed: 40 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 41 | 42 | if shuffle: 43 | return data.RandomSampler(dataset) 44 | 45 | else: 46 | return data.SequentialSampler(dataset) 47 | 48 | 49 | def requires_grad(model, flag=True): 50 | for p in model.parameters(): 51 | p.requires_grad = flag 52 | 53 | 54 | def sample_data(loader): 55 | while True: 56 | for batch in loader: 57 | yield batch 58 | 59 | 60 | def Lab2RGB_out(img_lab): 61 | img_lab = img_lab.detach().cpu() 62 | img_l = img_lab[:, :1, :, :] 63 | img_ab = img_lab[:, 1:, :, :] 64 | # print(torch.max(img_l), torch.min(img_l)) 65 | # print(torch.max(img_ab), torch.min(img_ab)) 66 | img_l = img_l + 50 67 | pred_lab = torch.cat((img_l, img_ab), 1)[0, ...].numpy() 68 | # grid_lab = utils.make_grid(pred_lab, nrow=1).numpy().astype("float64") 69 | # print(grid_lab.shape) 70 | out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8") 71 | return out 72 | 73 | 74 | def RGB2Lab(inputs): 75 | # input [0, 255] uint8 76 | # out l: [0, 100], ab: [-110, 110], float32 77 | return color.rgb2lab(inputs) 78 | 79 | 80 | def Normalize(inputs): 81 | l = inputs[:, :, 0:1] 82 | ab = inputs[:, :, 1:3] 83 | l = l - 50 84 | lab = np.concatenate((l, ab), 2) 85 | 86 | return lab.astype('float32') 87 | 88 | 89 | def 
numpy2tensor(inputs): 90 | out = torch.from_numpy(inputs.transpose(2, 0, 1)) 91 | return out 92 | 93 | 94 | def tensor2numpy(inputs): 95 | out = inputs[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 96 | return out 97 | 98 | 99 | def preprocessing(inputs): 100 | # input: rgb, [0, 255], uint8 101 | img_lab = Normalize(RGB2Lab(inputs)) 102 | img = np.array(inputs, 'float32') # [0, 255] 103 | img = numpy2tensor(img) 104 | img_lab = numpy2tensor(img_lab) 105 | return img.unsqueeze(0), img_lab.unsqueeze(0) 106 | 107 | 108 | def uncenter_l(inputs): 109 | l = inputs[:, :1, :, :] + 50 110 | ab = inputs[:, 1:, :, :] 111 | return torch.cat((l, ab), 1) 112 | 113 | 114 | def train( 115 | args, 116 | loader, 117 | colorEncoder, 118 | colorUNet, 119 | discriminator, 120 | vggnet, 121 | g_optim, 122 | d_optim, 123 | device, 124 | ): 125 | loader = sample_data(loader) 126 | 127 | pbar = range(args.iter) 128 | 129 | if get_rank() == 0: 130 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 131 | 132 | g_loss_val = 0 133 | loss_dict = {} 134 | recon_val_all = 0 135 | fea_val_all = 0 136 | disc_val_all = 0 137 | disc_val_GAN_all = 0 138 | disc_val = 0 139 | count = 0 140 | criterion_GAN = torch.nn.MSELoss().to(device) 141 | 142 | # Calculate output of image discriminator (PatchGAN) 143 | patch = (1, args.size // 2 ** 4, args.size // 2 ** 4) 144 | Tensor = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor 145 | 146 | colorEncoder_module = colorEncoder 147 | colorUNet_module = colorUNet 148 | 149 | for idx in pbar: 150 | i = idx + args.start_iter + 1 151 | 152 | if i > args.iter: 153 | print("Done!") 154 | 155 | break 156 | 157 | # img, img_ref, img_lab = next(loader) 158 | img, img_ref, img_lab, img_lab_sketch = next(loader) 159 | 160 | # Adversarial ground truths 161 | valid = Variable(Tensor(np.ones((img.size(0), *patch))), requires_grad=False) 162 | fake = Variable(Tensor(np.zeros((img.size(0), *patch))), requires_grad=False) 163 | # ima = img_ref.numpy() 164 | # ima = ima[0].astype('uint8') 165 | # ima = Image.fromarray(ima.transpose(1,2,0)) 166 | # ima.show() 167 | 168 | img = img.to(device) # GT [B, 3, 256, 256] 169 | img_lab = img_lab.to(device) # GT 170 | img_lab_sketch = img_lab_sketch.to(device) 171 | 172 | img_ref = img_ref.to(device) # tps_transformed image RGB [B, 3, 256, 256] 173 | 174 | img_l = img_lab_sketch[:, :1, :, :] / 50 # [-1, 1] target L 175 | img_ab = img_lab[:, 1:, :, :] / 110 # [-1, 1] target ab 176 | # img_ref_ab = img_ref_lab[:,1:,:,:] / 110 # [-1, 1] ref ab 177 | 178 | colorEncoder.train() 179 | colorUNet.train() 180 | discriminator.train() 181 | 182 | requires_grad(colorEncoder, True) 183 | requires_grad(colorUNet, True) 184 | requires_grad(discriminator, True) 185 | 186 | # ------------------ 187 | # Train Generators 188 | # ------------------ 189 | 190 | ref_color_vector = colorEncoder(img_ref / 255.) 191 | 192 | fake_swap_ab = colorUNet((img_l, ref_color_vector)) # [-1, 1] 193 | 194 | ## recon l1 loss 195 | recon_loss = (F.smooth_l1_loss(fake_swap_ab, img_ab)) 196 | 197 | ## feature loss 198 | real_img_rgb = img / 255. 
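# NOTE: perceptual loss — the lines below compare five VGG19 activations (relu1_2 ... relu5_2, obtained via layer_name='all') of the ground-truth RGB image and of the colorized output; the shallower layers are divided by 32/16/8/4 respectively and every term is additionally scaled by 0.1 before summation.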
199 | features_A = vggnet(real_img_rgb, layer_name='all') 200 | 201 | fake_swap_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, fake_swap_ab * 110), 1)) # [0, 1] 202 | features_B = vggnet(fake_swap_rgb, layer_name='all') 203 | # fea_loss = F.l1_loss(features_A[-1], features_B[-1]) * 0.1 204 | # fea_loss = 0 205 | 206 | fea_loss1 = F.l1_loss(features_A[0], features_B[0]) / 32 * 0.1 207 | fea_loss2 = F.l1_loss(features_A[1], features_B[1]) / 16 * 0.1 208 | fea_loss3 = F.l1_loss(features_A[2], features_B[2]) / 8 * 0.1 209 | fea_loss4 = F.l1_loss(features_A[3], features_B[3]) / 4 * 0.1 210 | fea_loss5 = F.l1_loss(features_A[4], features_B[4]) * 0.1 211 | 212 | fea_loss = fea_loss1 + fea_loss2 + fea_loss3 + fea_loss4 + fea_loss5 213 | 214 | ## discriminator loss 215 | real_img_rgb = img / 255. 216 | img_ref_rgb = img_ref / 255. 217 | zero_ab_image = torch.zeros_like(fake_swap_ab) 218 | input_img_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, zero_ab_image), 1)) # [0, 1] 219 | 220 | # ima = input_img_rgb.cpu() 221 | # ima = ima.numpy()*255 222 | # ima = ima[0].astype('uint8') 223 | # ima = Image.fromarray(ima.transpose(1,2,0)) 224 | # ima.show() 225 | 226 | pred_fake = discriminator(fake_swap_rgb, input_img_rgb, img_ref_rgb) 227 | disc_loss_GAN = criterion_GAN(pred_fake, valid) 228 | disc_loss_GAN = disc_loss_GAN * 0.01 229 | 230 | loss_dict["recon"] = recon_loss 231 | 232 | loss_dict["fea"] = fea_loss 233 | 234 | loss_dict["disc_loss_GAN"] = disc_loss_GAN 235 | 236 | g_optim.zero_grad() 237 | (recon_loss + fea_loss + disc_loss_GAN).backward() 238 | g_optim.step() 239 | 240 | # --------------------- 241 | # Train Discriminator 242 | # --------------------- 243 | # if the disc_loss_GAN<0.003, then start to train Discriminator 244 | if i % 35 == 0: 245 | # Real loss 246 | pred_real = discriminator(real_img_rgb, input_img_rgb, img_ref_rgb) 247 | loss_real = criterion_GAN(pred_real, valid) 248 | 249 | # Fake loss 250 | pred_fake = discriminator(fake_swap_rgb.detach(), input_img_rgb, img_ref_rgb) 251 | loss_fake = criterion_GAN(pred_fake, fake) 252 | 253 | # Total loss 254 | disc_loss = 0.5 * (loss_real + loss_fake) 255 | 256 | d_optim.zero_grad() 257 | disc_loss.backward() 258 | d_optim.step() 259 | 260 | # loss for discriminator itself 261 | disc_val = disc_loss.mean().item() 262 | disc_val_all += disc_val 263 | count += 1 264 | 265 | # -------------- 266 | # Log Progress 267 | # -------------- 268 | 269 | loss_reduced = reduce_loss_dict(loss_dict) 270 | 271 | recon_val = loss_reduced["recon"].mean().item() 272 | recon_val_all += recon_val 273 | # recon_val = 0 274 | fea_val = loss_reduced["fea"].mean().item() 275 | fea_val_all += fea_val 276 | # fea_val = 0 277 | 278 | # loss for generator 279 | disc_val_GAN = loss_reduced["disc_loss_GAN"].mean().item() 280 | disc_val_GAN_all += disc_val_GAN 281 | 282 | if get_rank() == 0: 283 | pbar.set_description( 284 | ( 285 | f"recon:{recon_val:.4f}; fea:{fea_val:.4f}; disc_GAN:{disc_val_GAN:.4f}; discriminator:{disc_val:.4f};" 286 | ) 287 | ) 288 | 289 | if i % 100 == 0: 290 | if disc_val_all != 0: 291 | disc_val_all = disc_val_all / count 292 | print( 293 | f"recon_all:{recon_val_all / 100:.4f}; fea_all:{fea_val_all / 100:.4f}; disc_GAN_all:{disc_val_GAN_all / 100:.4f};discriminator:{disc_val_all:.4f};") 294 | recon_val_all = 0 295 | fea_val_all = 0 296 | disc_val_GAN_all = 0 297 | disc_val_all = 0 298 | count = 0 299 | 300 | # this code is for model validation, you should prepare you own val dataset and edit code to use it 301 | # if i % 250 == 0: 302 | 
# with torch.no_grad(): 303 | # colorEncoder.eval() 304 | # colorUNet.eval() 305 | # 306 | # imgsize = 256 307 | # for inum in range(12): 308 | # val_img_path = 'test_datasets/val_Sketch/in%d.jpg' % (inum + 1) 309 | # val_ref_path = 'test_datasets/val_Sketch/ref%d.jpg' % (inum + 1) 310 | # # val_img_path = 'test_datasets/val_daytime/day_sample/in%d.jpg'%(inum+1) 311 | # # val_ref_path = 'test_datasets/val_daytime/night_sample/dark4.jpg' 312 | # out_name = 'in%d_ref%d.png' % (inum + 1, inum + 1) 313 | # val_img = Image.open(val_img_path).convert("RGB").resize((imgsize, imgsize)) 314 | # val_img_ref = Image.open(val_ref_path).convert("RGB").resize((imgsize, imgsize)) 315 | # val_img, val_img_lab = preprocessing(val_img) 316 | # val_img_ref, val_img_ref_lab = preprocessing(val_img_ref) 317 | # 318 | # # val_img = val_img.to(device) 319 | # val_img_lab = val_img_lab.to(device) 320 | # val_img_ref = val_img_ref.to(device) 321 | # # val_img_ref_lab = val_img_ref_lab.to(device) 322 | # 323 | # val_img_l = val_img_lab[:, :1, :, :] / 50. # [-1, 1] 324 | # # val_img_ref_ab = val_img_ref_lab[:,1:,:,:] / 110. # [-1, 1] 325 | # 326 | # ref_color_vector = colorEncoder(val_img_ref / 255.) # [0, 1] 327 | # fake_swap_ab = colorUNet((val_img_l, ref_color_vector)) 328 | # 329 | # fake_img = torch.cat((val_img_l * 50, fake_swap_ab * 110), 1) 330 | # 331 | # sample = np.concatenate( 332 | # (tensor2numpy(val_img), tensor2numpy(val_img_ref), Lab2RGB_out(fake_img)), 1) 333 | # 334 | # out_dir = 'training_logs/%s/%06d' % (args.experiment_name, i) 335 | # mkdirss(out_dir) 336 | # io.imsave('%s/%s' % (out_dir, out_name), sample.astype('uint8')) 337 | # torch.cuda.empty_cache() 338 | if i % 2000 == 0: 339 | out_dir_g = "experiments/%s" % (args.experiment_name) 340 | mkdirss(out_dir_g) 341 | torch.save( 342 | { 343 | "colorEncoder": colorEncoder_module.state_dict(), 344 | "colorUNet": colorUNet_module.state_dict(), 345 | "g_optim": g_optim.state_dict(), 346 | "args": args, 347 | }, 348 | f"%s/{str(i).zfill(6)}_sketch.pt" % (out_dir_g), 349 | ) 350 | out_dir_d = "experiments/Discriminator" 351 | mkdirss(out_dir_d) 352 | torch.save( 353 | { 354 | "discriminator": discriminator.state_dict(), 355 | "d_optim": d_optim.state_dict(), 356 | "args": args, 357 | }, 358 | f"%s/{str(i).zfill(6)}_d.pt" % (out_dir_d), 359 | ) 360 | 361 | 362 | if __name__ == "__main__": 363 | device = "cuda" 364 | 365 | torch.backends.cudnn.benchmark = True 366 | 367 | parser = argparse.ArgumentParser() 368 | 369 | parser.add_argument("--datasets", type=str) 370 | parser.add_argument("--iter", type=int, default=200000) 371 | parser.add_argument("--batch", type=int, default=16) 372 | parser.add_argument("--size", type=int, default=256) 373 | parser.add_argument("--ckpt", type=str, default=None) 374 | parser.add_argument("--ckpt_disc", type=str, default=None) 375 | parser.add_argument("--lr", type=float, default=0.0001) 376 | parser.add_argument("--lr_disc", type=float, default=0.0002) 377 | parser.add_argument("--experiment_name", type=str, default="default") 378 | parser.add_argument("--wandb", action="store_true") 379 | parser.add_argument("--local_rank", type=int, default=0) 380 | 381 | args = parser.parse_args() 382 | 383 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 384 | args.distributed = n_gpu > 1 385 | 386 | args.start_iter = 0 387 | 388 | vggnet = vgg19(pretrained_path='./experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad=False) 389 | vggnet = vggnet.to(device) 390 | vggnet.eval() 391 | 392 | colorEncoder = 
ColorEncoder(color_dim=512).to(device) 393 | colorUNet = ColorUNet(bilinear=True).to(device) 394 | discriminator = Discriminator(in_channels=3).to(device) 395 | 396 | g_optim = optim.Adam( 397 | list(colorEncoder.parameters()) + list(colorUNet.parameters()), 398 | lr=args.lr, 399 | betas=(0.9, 0.99), 400 | ) 401 | 402 | d_optim = optim.Adam( 403 | discriminator.parameters(), 404 | lr=args.lr_disc, 405 | betas=(0.5, 0.999), 406 | ) 407 | 408 | if args.ckpt is not None: 409 | print("load model:", args.ckpt) 410 | 411 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 412 | 413 | try: 414 | ckpt_name = os.path.basename(args.ckpt) 415 | match = re.search(r'\d+', ckpt_name) 416 | if match: 417 | args.start_iter = int(match.group(0)) 418 | else: 419 | args.start_iter = 0 420 | except ValueError: 421 | pass 422 | 423 | colorEncoder.load_state_dict(ckpt["colorEncoder"]) 424 | colorUNet.load_state_dict(ckpt["colorUNet"]) 425 | g_optim.load_state_dict(ckpt["g_optim"]) 426 | 427 | if args.ckpt_disc is not None: 428 | print("load discriminator model:", args.ckpt_disc) 429 | 430 | ckpt_disc = torch.load(args.ckpt_disc, map_location=lambda storage, loc: storage) 431 | discriminator.load_state_dict(ckpt_disc["discriminator"]) 432 | d_optim.load_state_dict(ckpt_disc["d_optim"]) 433 | # print(args.distributed) 434 | 435 | transform = transforms.Compose( 436 | [ 437 | transforms.RandomHorizontalFlip(), 438 | # transforms.RandomVerticalFlip(), 439 | transforms.RandomRotation(degrees=(-90, 90)) 440 | ] 441 | ) 442 | 443 | datasets = [] 444 | dataset = MultiResolutionDataset(args.datasets, transform, args.size) 445 | datasets.append(dataset) 446 | 447 | loader = data.DataLoader( 448 | data.ConcatDataset(datasets), 449 | batch_size=args.batch, 450 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 451 | drop_last=True, 452 | ) 453 | 454 | train( 455 | args, 456 | loader, 457 | colorEncoder, 458 | colorUNet, 459 | discriminator, 460 | vggnet, 461 | g_optim, 462 | d_optim, 463 | device, 464 | ) 465 | -------------------------------------------------------------------------------- /train_disc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import os 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from skimage import color, io 8 | import torch 9 | from torch import nn, optim 10 | from torch.nn import functional as F 11 | from torch.utils import data 12 | from torchvision import transforms 13 | from tqdm import tqdm 14 | from torch.autograd import Variable 15 | 16 | # from ColorEncoder import ColorEncoder 17 | from models import ColorEncoder, ColorUNet 18 | from discriminator import Discriminator 19 | from data.data_loader import MultiResolutionDataset 20 | 21 | from utils import tensor_lab2rgb 22 | 23 | from distributed import ( 24 | get_rank, 25 | synchronize, 26 | reduce_loss_dict, 27 | ) 28 | 29 | 30 | def mkdirss(dirpath): 31 | if not os.path.exists(dirpath): 32 | os.makedirs(dirpath) 33 | 34 | 35 | def data_sampler(dataset, shuffle, distributed): 36 | if distributed: 37 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 38 | 39 | if shuffle: 40 | return data.RandomSampler(dataset) 41 | 42 | else: 43 | return data.SequentialSampler(dataset) 44 | 45 | 46 | def requires_grad(model, flag=True): 47 | for p in model.parameters(): 48 | p.requires_grad = flag 49 | 50 | 51 | def sample_data(loader): 52 | while True: 53 | for batch in loader: 54 | yield batch 55 | 56 | 57 | def 
Lab2RGB_out(img_lab): 58 | img_lab = img_lab.detach().cpu() 59 | img_l = img_lab[:, :1, :, :] 60 | img_ab = img_lab[:, 1:, :, :] 61 | # print(torch.max(img_l), torch.min(img_l)) 62 | # print(torch.max(img_ab), torch.min(img_ab)) 63 | img_l = img_l + 50 64 | pred_lab = torch.cat((img_l, img_ab), 1)[0, ...].numpy() 65 | # grid_lab = utils.make_grid(pred_lab, nrow=1).numpy().astype("float64") 66 | # print(grid_lab.shape) 67 | out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8") 68 | return out 69 | 70 | 71 | def RGB2Lab(inputs): 72 | # input [0, 255] uint8 73 | # out l: [0, 100], ab: [-110, 110], float32 74 | return color.rgb2lab(inputs) 75 | 76 | 77 | def Normalize(inputs): 78 | l = inputs[:, :, 0:1] 79 | ab = inputs[:, :, 1:3] 80 | l = l - 50 81 | lab = np.concatenate((l, ab), 2) 82 | 83 | return lab.astype('float32') 84 | 85 | 86 | def numpy2tensor(inputs): 87 | out = torch.from_numpy(inputs.transpose(2, 0, 1)) 88 | return out 89 | 90 | 91 | def tensor2numpy(inputs): 92 | out = inputs[0, ...].detach().cpu().numpy().transpose(1, 2, 0) 93 | return out 94 | 95 | 96 | def preprocessing(inputs): 97 | # input: rgb, [0, 255], uint8 98 | img_lab = Normalize(RGB2Lab(inputs)) 99 | img = np.array(inputs, 'float32') # [0, 255] 100 | img = numpy2tensor(img) 101 | img_lab = numpy2tensor(img_lab) 102 | return img.unsqueeze(0), img_lab.unsqueeze(0) 103 | 104 | 105 | def uncenter_l(inputs): 106 | l = inputs[:, :1, :, :] + 50 107 | ab = inputs[:, 1:, :, :] 108 | return torch.cat((l, ab), 1) 109 | 110 | 111 | def train( 112 | args, 113 | loader, 114 | colorEncoder, 115 | colorUNet, 116 | discriminator, 117 | d_optim, 118 | device, 119 | ): 120 | loader = sample_data(loader) 121 | 122 | pbar = range(args.iter) 123 | 124 | if get_rank() == 0: 125 | pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01) 126 | 127 | disc_val_all = 0 128 | criterion_GAN = torch.nn.MSELoss().to(device) 129 | 130 | # Calculate output of image discriminator (PatchGAN) 131 | patch = (1, args.size // 2 ** 4, args.size // 2 ** 4) 132 | Tensor = torch.cuda.FloatTensor if device == 'cuda' else torch.FloatTensor 133 | 134 | for idx in pbar: 135 | i = idx + args.start_iter 136 | 137 | if i > args.iter: 138 | print("Done!") 139 | 140 | break 141 | 142 | img, img_ref, img_lab = next(loader) 143 | 144 | # Adversarial ground truths 145 | valid = Variable(Tensor(np.ones((img.size(0), *patch))), requires_grad=False) 146 | fake = Variable(Tensor(np.zeros((img.size(0), *patch))), requires_grad=False) 147 | # ima = img.numpy() 148 | # ima = ima[0].astype('uint8') 149 | # ima = Image.fromarray(ima.transpose(1,2,0)) 150 | # ima.show() 151 | 152 | img = img.to(device) # GT [B, 3, 256, 256] 153 | img_lab = img_lab.to(device) # GT 154 | 155 | img_ref = img_ref.to(device) # tps_transformed image RGB [B, 3, 256, 256] 156 | 157 | img_l = img_lab[:, :1, :, :] / 50 # [-1, 1] target L 158 | img_ab = img_lab[:, 1:, :, :] / 110 # [-1, 1] target ab 159 | # img_ref_ab = img_ref_lab[:,1:,:,:] / 110 # [-1, 1] ref ab 160 | 161 | colorEncoder.eval() 162 | colorUNet.eval() 163 | discriminator.train() 164 | 165 | requires_grad(colorEncoder, False) 166 | requires_grad(colorUNet, False) 167 | requires_grad(discriminator, True) 168 | 169 | with torch.no_grad(): 170 | ref_color_vector = colorEncoder(img_ref / 255.) 
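# NOTE: the generator (colorEncoder + colorUNet) is frozen here (eval mode, requires_grad=False, inside torch.no_grad()); only the PatchGAN discriminator receives gradients, scoring (1, size//16, size//16) patch maps against all-ones (real) / all-zeros (fake) targets with an MSE (least-squares GAN) criterion.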
171 | fake_swap_ab = colorUNet((img_l, ref_color_vector)) # [-1, 1] 172 | 173 | fake_swap_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, fake_swap_ab * 110), 1)) # [0, 1] 174 | real_img_rgb = img / 255. 175 | img_ref_rgb = img_ref / 255. 176 | 177 | zero_ab_image = torch.zeros_like(fake_swap_ab) 178 | input_img_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, zero_ab_image), 1)) # [0, 1] 179 | 180 | # show the gray image 181 | 182 | # input_img_rgb_cpu = input_img_rgb.cpu() 183 | # ima = input_img_rgb_cpu.numpy() 184 | # ima = ima*255 185 | # ima = ima[0].astype('uint8') 186 | # ima = Image.fromarray(ima.transpose(1,2,0)) 187 | # ima.show() 188 | 189 | # Real loss 190 | pred_real = discriminator(real_img_rgb, input_img_rgb, img_ref_rgb) 191 | loss_real = criterion_GAN(pred_real, valid) 192 | 193 | # Fake loss 194 | pred_fake = discriminator(fake_swap_rgb.detach(), input_img_rgb, img_ref_rgb) 195 | loss_fake = criterion_GAN(pred_fake, fake) 196 | 197 | # Total loss 198 | disc_loss = 0.5 * (loss_real + loss_fake) 199 | 200 | d_optim.zero_grad() 201 | disc_loss.backward() 202 | d_optim.step() 203 | 204 | disc_val = disc_loss.mean().item() 205 | disc_val_all += disc_val 206 | 207 | if get_rank() == 0: 208 | pbar.set_description( 209 | ( 210 | f"discriminator:{disc_val:.4f};" 211 | ) 212 | ) 213 | 214 | if i % 100 == 0: 215 | print(f"discriminator:{disc_val_all / 100:.4f};") 216 | disc_val_all = 0 217 | if i % 1000 == 0: 218 | out_dir = "experiments/%s" % (args.experiment_name) 219 | mkdirss(out_dir) 220 | torch.save( 221 | { 222 | "discriminator": discriminator.state_dict(), 223 | "d_optim": d_optim.state_dict(), 224 | "args": args, 225 | }, 226 | f"%s/{str(i).zfill(6)}_ds.pt" % (out_dir), 227 | ) 228 | 229 | 230 | if __name__ == "__main__": 231 | device = "cuda" 232 | 233 | torch.backends.cudnn.benchmark = True 234 | 235 | parser = argparse.ArgumentParser() 236 | 237 | parser.add_argument("--datasets", type=str) 238 | parser.add_argument("--iter", type=int, default=100000) 239 | parser.add_argument("--batch", type=int, default=16) 240 | parser.add_argument("--size", type=int, default=256) 241 | parser.add_argument("--ckpt", type=str, default=None) 242 | parser.add_argument("--ckpt_disc", type=str, default=None) 243 | parser.add_argument("--lr", type=float, default=0.0002) 244 | parser.add_argument("--experiment_name", type=str, default="default") 245 | parser.add_argument("--wandb", action="store_true") 246 | parser.add_argument("--local_rank", type=int, default=0) 247 | 248 | args = parser.parse_args() 249 | 250 | n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 251 | args.distributed = n_gpu > 1 252 | 253 | if args.distributed: 254 | torch.cuda.set_device(args.local_rank) 255 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 256 | synchronize() 257 | 258 | args.start_iter = 0 259 | 260 | colorEncoder = ColorEncoder(color_dim=512).to(device) 261 | colorUNet = ColorUNet(bilinear=True).to(device) 262 | discriminator = Discriminator(in_channels=3).to(device) 263 | 264 | d_optim = optim.Adam( 265 | discriminator.parameters(), 266 | lr=args.lr, 267 | betas=(0.5, 0.999), 268 | ) 269 | 270 | if args.ckpt is not None: 271 | print("load model:", args.ckpt) 272 | 273 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 274 | 275 | colorEncoder.load_state_dict(ckpt["colorEncoder"]) 276 | colorUNet.load_state_dict(ckpt["colorUNet"]) 277 | 278 | if args.ckpt_disc is not None: 279 | print("load discriminator model:", 
args.ckpt_disc) 280 | 281 | ckpt_disc = torch.load(args.ckpt_disc, map_location=lambda storage, loc: storage) 282 | 283 | try: 284 | ckpt_name = os.path.basename(args.ckpt_disc) 285 | args.start_iter = int(os.path.splitext(ckpt_name)[0]) 286 | 287 | except ValueError: 288 | pass 289 | 290 | discriminator.load_state_dict(ckpt_disc["discriminator"]) 291 | d_optim.load_state_dict(ckpt_disc["d_optim"]) 292 | 293 | # print(args.distributed) 294 | 295 | if args.distributed: 296 | colorEncoder = nn.parallel.DistributedDataParallel( 297 | colorEncoder, 298 | device_ids=[args.local_rank], 299 | output_device=args.local_rank, 300 | broadcast_buffers=False, 301 | ) 302 | 303 | colorUNet = nn.parallel.DistributedDataParallel( 304 | colorUNet, 305 | device_ids=[args.local_rank], 306 | output_device=args.local_rank, 307 | broadcast_buffers=False, 308 | ) 309 | 310 | transform = transforms.Compose( 311 | [ 312 | transforms.RandomHorizontalFlip(), 313 | transforms.RandomVerticalFlip(), 314 | transforms.RandomRotation(degrees=(0, 360)) 315 | ] 316 | ) 317 | 318 | datasets = [] 319 | dataset = MultiResolutionDataset(args.datasets, transform, args.size) 320 | datasets.append(dataset) 321 | 322 | loader = data.DataLoader( 323 | data.ConcatDataset(datasets), 324 | batch_size=args.batch, 325 | sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed), 326 | drop_last=True, 327 | ) 328 | 329 | train( 330 | args, 331 | loader, 332 | colorEncoder, 333 | colorUNet, 334 | discriminator, 335 | d_optim, 336 | device, 337 | ) 338 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # stdlib 4 | import os 5 | # 3p 6 | from skimage import io 7 | import cv2 8 | xyz_from_rgb = np.array( 9 | [[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169], [0.019334, 0.119193, 0.950227]] 10 | ) 11 | rgb_from_xyz = np.array( 12 | [[3.24048134, -0.96925495, 0.05564664], [-1.53715152, 1.87599, -0.20404134], [-0.49853633, 0.04155593, 1.05731107]] 13 | ) 14 | 15 | 16 | def tensor_lab2rgb(input): 17 | """ 18 | n * 3* h *w 19 | """ 20 | input_trans = input.transpose(1, 2).transpose(2, 3) # n * h * w * 3 21 | L, a, b = input_trans[:, :, :, 0:1], input_trans[:, :, :, 1:2], input_trans[:, :, :, 2:] 22 | y = (L + 16.0) / 116.0 23 | x = (a / 500.0) + y 24 | z = y - (b / 200.0) 25 | 26 | neg_mask = z.data < 0 27 | z[neg_mask] = 0 28 | xyz = torch.cat((x, y, z), dim=3) 29 | 30 | mask = xyz.data > 0.2068966 31 | mask_xyz = xyz.clone() 32 | mask_xyz[mask] = torch.pow(xyz[mask], 3.0) 33 | mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.0) / 7.787 34 | mask_xyz[:, :, :, 0] = mask_xyz[:, :, :, 0] * 0.95047 35 | mask_xyz[:, :, :, 2] = mask_xyz[:, :, :, 2] * 1.08883 36 | 37 | rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view( 38 | input.size(0), input.size(2), input.size(3), 3 39 | ) 40 | rgb = rgb_trans.transpose(2, 3).transpose(1, 2) 41 | 42 | mask = rgb > 0.0031308 43 | mask_rgb = rgb.clone() 44 | mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055 45 | mask_rgb[~mask] = rgb[~mask] * 12.92 46 | 47 | neg_mask = mask_rgb.data < 0 48 | large_mask = mask_rgb.data > 1 49 | mask_rgb[neg_mask] = 0 50 | mask_rgb[large_mask] = 1 51 | return mask_rgb 52 | 53 | def get_files(img_dir): 54 | imgs, masks, xmls = list_files(img_dir) 55 | return imgs, masks, xmls 56 | 57 | 58 | def list_files(in_path): 59 | img_files = [] 60 | 
mask_files = [] 61 | gt_files = [] 62 | for (dirpath, dirnames, filenames) in os.walk(in_path): 63 | for file in filenames: 64 | filename, ext = os.path.splitext(file) 65 | ext = str.lower(ext) 66 | if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm': 67 | img_files.append(os.path.join(dirpath, file)) 68 | elif ext == '.bmp': 69 | mask_files.append(os.path.join(dirpath, file)) 70 | elif ext == '.xml' or ext == '.gt' or ext == '.txt': 71 | gt_files.append(os.path.join(dirpath, file)) 72 | elif ext == '.zip': 73 | continue 74 | return img_files, mask_files, gt_files 75 | 76 | 77 | def load_image(img_file): 78 | img = io.imread(img_file) # RGB order 79 | if img.shape[0] == 2: 80 | img = img[0] 81 | if len(img.shape) == 2: 82 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 83 | if img.shape[2] == 4: 84 | img = img[:, :, :3] 85 | img = np.array(img) 86 | 87 | return img -------------------------------------------------------------------------------- /vgg_model.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | from collections import namedtuple 3 | import torch 4 | import torch.nn as nn 5 | 6 | def vgg_preprocess(tensor): 7 | # input is RGB tensor which ranges in [0,1] 8 | # output is RGB tensor which ranges 9 | mean_val = torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(-1, 1, 1) 10 | std_val = torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor).view(-1, 1, 1) 11 | tensor_norm = (tensor - mean_val) / std_val 12 | return tensor_norm 13 | 14 | class vgg19(nn.Module): 15 | 16 | def __init__(self, pretrained_path = './experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad = False): 17 | super(vgg19, self).__init__() 18 | self.vgg_model = models.vgg19() 19 | if pretrained_path != None: 20 | print('----load pretrained vgg19----') 21 | self.vgg_model.load_state_dict(torch.load(pretrained_path)) 22 | print('----load done!----') 23 | self.vgg_feature = self.vgg_model.features 24 | self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature] 25 | # self.vgg_layer = ['conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 26 | # 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 27 | # 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 28 | # 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 29 | # 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'] 30 | 31 | # self.vgg_layer = ['relu1_2', 'relu2_2', 'relu3_2', 'relu4_2', 'relu5_2'] 32 | 33 | if not require_grad: 34 | for parameter in self.parameters(): 35 | parameter.requires_grad = False 36 | 37 | def forward(self, x, layer_name='relu5_2'): 38 | ### x: RGB [0, 1], input should be [0, 1] 39 | x = vgg_preprocess(x) 40 | 41 | conv1_1 = self.seq_list[0](x) 42 | relu1_1 = self.seq_list[1](conv1_1) 43 | conv1_2 = self.seq_list[2](relu1_1) 44 | relu1_2 = self.seq_list[3](conv1_2) 45 | pool1 = self.seq_list[4](relu1_2) 46 | 47 | conv2_1 = self.seq_list[5](pool1) 48 | relu2_1 = self.seq_list[6](conv2_1) 49 | conv2_2 = self.seq_list[7](relu2_1) 50 | relu2_2 = self.seq_list[8](conv2_2) 51 | pool2 = self.seq_list[9](relu2_2) 52 | 53 | conv3_1 = self.seq_list[10](pool2) 54 | relu3_1 = self.seq_list[11](conv3_1) 55 | conv3_2 = self.seq_list[12](relu3_1) 56 | relu3_2 = self.seq_list[13](conv3_2) 57 | conv3_3 = self.seq_list[14](relu3_2) 58 | relu3_3 = self.seq_list[15](conv3_3) 59 | conv3_4 = self.seq_list[16](relu3_3) 
60 | relu3_4 = self.seq_list[17](conv3_4) 61 | pool3 = self.seq_list[18](relu3_4) 62 | 63 | conv4_1 = self.seq_list[19](pool3) 64 | relu4_1 = self.seq_list[20](conv4_1) 65 | conv4_2 = self.seq_list[21](relu4_1) 66 | relu4_2 = self.seq_list[22](conv4_2) 67 | conv4_3 = self.seq_list[23](relu4_2) 68 | relu4_3 = self.seq_list[24](conv4_3) 69 | conv4_4 = self.seq_list[25](relu4_3) 70 | relu4_4 = self.seq_list[26](conv4_4) 71 | pool4 = self.seq_list[27](relu4_4) 72 | 73 | conv5_1 = self.seq_list[28](pool4) 74 | relu5_1 = self.seq_list[29](conv5_1) 75 | conv5_2 = self.seq_list[30](relu5_1) 76 | relu5_2 = self.seq_list[31](conv5_2) # [B, 512, 16, 16] 77 | conv5_3 = self.seq_list[32](relu5_2) 78 | relu5_3 = self.seq_list[33](conv5_3) 79 | conv5_4 = self.seq_list[34](relu5_3) 80 | relu5_4 = self.seq_list[35](conv5_4) 81 | pool5 = self.seq_list[36](relu5_4) # [B, 512, 8, 8] 82 | 83 | # vgg_output = namedtuple("vgg_output", self.vgg_layer) 84 | 85 | # vgg_list = [conv1_1, relu1_1, conv1_2, relu1_2, pool1, 86 | # conv2_1, relu2_1, conv2_2, relu2_2, pool2, 87 | # conv3_1, relu3_1, conv3_2, relu3_2, conv3_3, relu3_3, conv3_4, relu3_4, pool3, 88 | # conv4_1, relu4_1, conv4_2, relu4_2, conv4_3, relu4_3, conv4_4, relu4_4, pool4, 89 | # conv5_1, relu5_1, conv5_2, relu5_2, conv5_3, relu5_3, conv5_4, relu5_4, pool5] 90 | 91 | if layer_name == 'relu5_2': 92 | vgg_list = [relu5_2] 93 | elif layer_name == 'conv5_2': 94 | vgg_list = [conv5_2] 95 | elif layer_name == 'relu5_4': 96 | vgg_list = [relu5_4] 97 | elif layer_name == 'pool5': 98 | # print('pool5') 99 | vgg_list = [pool5] 100 | elif layer_name == 'all': 101 | vgg_list = [relu1_2, relu2_2, relu3_2, relu4_2, relu5_2] 102 | 103 | # out = vgg_output(*vgg_list) 104 | 105 | return vgg_list 106 | 107 | class vgg19_class_fea(nn.Module): 108 | 109 | def __init__(self, pretrained_path = './experiments/vgg19-dcbb9e9d.pth', require_grad = False): 110 | super(vgg19_class_fea, self).__init__() 111 | self.vgg_model = models.vgg19() 112 | print('----load pretrained vgg19----') 113 | self.vgg_model.load_state_dict(torch.load(pretrained_path)) 114 | print('----load done!----') 115 | self.vgg_feature = self.vgg_model.features 116 | self.avgpool = self.vgg_model.avgpool 117 | self.classifier = self.vgg_model.classifier 118 | 119 | self.seq_list = [nn.Sequential(ele) for ele in self.vgg_feature] # 37层 120 | if not require_grad: 121 | for parameter in self.parameters(): 122 | parameter.requires_grad = False 123 | 124 | def forward(self, x): 125 | ### x: RGB [0, 1], input should be [0, 1] 126 | x = vgg_preprocess(x) 127 | 128 | for i in range(len(self.seq_list)): 129 | x = self.seq_list[i](x) 130 | if i == 31: 131 | relu5_2 = x 132 | 133 | x = self.avgpool(x) 134 | x = torch.flatten(x, 1) 135 | x_class = self.classifier(x) 136 | return x_class, relu5_2 --------------------------------------------------------------------------------
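A minimal usage sketch (not part of the repository) for the vgg19 feature extractor defined above, as it is consumed by the perceptual loss in the training scripts; it assumes the pretrained weights have already been placed at ./experiments/VGG19/vgg19-dcbb9e9d.pth, and the random input tensor is purely illustrative:

```python
import torch
from vgg_model import vgg19

# Frozen VGG19 feature extractor used for the perceptual loss in the training scripts.
vggnet = vgg19(pretrained_path='./experiments/VGG19/vgg19-dcbb9e9d.pth',
               require_grad=False).eval()

x = torch.rand(1, 3, 256, 256)        # RGB image scaled to [0, 1] (dummy data)
feats = vggnet(x, layer_name='all')   # [relu1_2, relu2_2, relu3_2, relu4_2, relu5_2]
print([f.shape for f in feats])       # five feature maps, 64 to 512 channels
```

The five returned maps correspond one-to-one to the features_A / features_B lists indexed in train_all_gray.py and train_all_sketch.py.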