├── .gitignore ├── LICENSE ├── README.md ├── config ├── colorization_mirflickr25k.json ├── inpainting_celebahq.json ├── inpainting_places2.json └── uncropping_places2.json ├── core ├── base_dataset.py ├── base_model.py ├── base_network.py ├── logger.py ├── praser.py └── util.py ├── data ├── __init__.py ├── dataset.py └── util │ ├── auto_augment.py │ └── mask.py ├── eval.py ├── misc ├── Palette Image-to-Image Diffusion Models.pdf └── image │ ├── Mask_Places365_test_00143399.jpg │ ├── Mask_Places365_test_00144085.jpg │ ├── Mask_Places365_test_00209019.jpg │ ├── Mask_Places365_test_00263905.jpg │ ├── Out_Places365_test_00143399.jpg │ ├── Out_Places365_test_00144085.jpg │ ├── Out_Places365_test_00209019.jpg │ ├── Out_Places365_test_00263905.jpg │ ├── Process_02323.jpg │ ├── Process_26190.jpg │ ├── Process_Places365_test_00042384.jpg │ └── Process_Places365_test_00309553.jpg ├── models ├── __init__.py ├── guided_diffusion_modules │ ├── nn.py │ └── unet.py ├── loss.py ├── metric.py ├── model.py ├── network.py └── sr3_modules │ └── unet.py ├── preprocess └── mirflickr25k_preprocess.py ├── requirements.txt ├── run.py └── slurm └── inpainting_places2.slurm /.gitignore: -------------------------------------------------------------------------------- 1 | # myself 2 | experiments/* 3 | datasets/* 4 | !experiments/clean.sh 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Liangwei Jiang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Palette: Image-to-Image Diffusion Models 2 | 3 | [Paper](https://arxiv.org/pdf/2111.05826.pdf ) | [Project](https://iterative-refinement.github.io/palette/ ) 4 | 5 | ## Brief 6 | 7 | This is an unofficial implementation of **Palette: Image-to-Image Diffusion Models** by **Pytorch**, and it is mainly inherited from its super-resolution version [Image-Super-Resolution-via-Iterative-Refinement](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement). The code template is from my another seed project: [distributed-pytorch-template](https://github.com/Janspiry/distributed-pytorch-template). 8 | 9 | There are some implementation details with paper descriptions: 10 | 11 | - We adapted the U-Net architecture used in `Guided-Diffusion`, which give a substantial boost to sample quality. 12 | - We used the attention mechanism in low-resolution features (16×16) like vanilla `DDPM`. 13 | - We encode the $\gamma$ rather than $t$ in `Palette` and embed it with affine transformation. 14 | - We fix the variance $Σ_\theta(x_t, t)$ to a constant during the inference as described in `Palette`. 
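Below is a minimal, self-contained sketch of the last two points: conditioning the network on the continuous noise level $\gamma$ (instead of the integer timestep $t$) through an affine feature transformation, and a reverse step whose variance is fixed rather than predicted. It is illustrative only; the names (`GammaAffine`, `p_sample`, `denoise_fn`) are hypothetical and do not match the repository's code, whose actual implementation lives in `models/network.py` and the U-Net variants under `models/guided_diffusion_modules/` and `models/sr3_modules/`.

```python
import torch
import torch.nn as nn


class GammaAffine(nn.Module):
    """Illustrative sketch: embed the continuous noise level gamma (rather than
    the integer timestep t) and inject it into a feature map as a per-channel
    affine transformation (scale and shift)."""

    def __init__(self, hidden_dim: int, channels: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 2 * channels),
        )

    def forward(self, feat: torch.Tensor, gamma: torch.Tensor) -> torch.Tensor:
        # feat: (B, C, H, W); gamma: (B,) with values in (0, 1]
        scale, shift = self.mlp(gamma.view(-1, 1)).chunk(2, dim=1)
        return feat * (1 + scale[:, :, None, None]) + shift[:, :, None, None]


@torch.no_grad()
def p_sample(denoise_fn, y_t, y_cond, t, betas):
    """One reverse diffusion step with a fixed (not predicted) variance,
    conditioned on gamma_t = prod_{s<=t}(1 - beta_s)."""
    alphas = 1.0 - betas
    gammas = torch.cumprod(alphas, dim=0)      # cumulative product, i.e. alpha-bar
    gamma_t = gammas[t]
    gamma_prev = gammas[t - 1] if t > 0 else gammas.new_tensor(1.0)

    # Predict the noise from the concatenated condition image and current sample.
    noise_level = gamma_t.repeat(y_t.shape[0])
    eps = denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level)

    # Estimate y_0, then form the mean of the posterior q(y_{t-1} | y_t, y_0).
    y0_hat = ((y_t - (1 - gamma_t).sqrt() * eps) / gamma_t.sqrt()).clamp(-1.0, 1.0)
    mean = (
        gamma_prev.sqrt() * betas[t] / (1 - gamma_t) * y0_hat
        + alphas[t].sqrt() * (1 - gamma_prev) / (1 - gamma_t) * y_t
    )

    # Fixed variance as in vanilla DDPM; no noise is added at the final step.
    var = betas[t] * (1 - gamma_prev) / (1 - gamma_t)
    noise = torch.randn_like(y_t) if t > 0 else torch.zeros_like(y_t)
    return mean + var.sqrt() * noise
```

During training the same $\gamma$-conditioning is used in the forward direction: a noise level is sampled, the clean image is corrupted with it, and the network is trained to predict the injected noise with an MSE objective (see `mse_loss` in the config files).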
15 | 16 | ## Status 17 | 18 | ### Code 19 | - [x] Diffusion Model Pipeline 20 | - [x] Train/Test Process 21 | - [x] Save/Load Training State 22 | - [x] Logger/Tensorboard 23 | - [x] Multiple GPU Training (DDP) 24 | - [x] EMA 25 | - [x] Metrics (now for FID, IS) 26 | - [x] Dataset (now for inpainting, uncropping, colorization) 27 | - [x] Google colab script 🌟(now for inpainting) 28 | 29 | ### Task 30 | 31 | I try to finish following tasks in order: 32 | - [x] Inpainting on [CelebaHQ](https://drive.google.com/drive/folders/1CjZAajyf-jIknskoTQ4CGvVkAigkhNWA?usp=sharing)🚀 ([Google Colab](https://colab.research.google.com/drive/1wfcd6QKkN2AqZDGFKZLyGKAoI5xcXUgO#scrollTo=8VFpuekybeQK)) 33 | - [x] Inpainting on [Places2 with 128×128 centering mask](https://drive.google.com/drive/folders/1fLyFtrStfEtyrqwI0N_Xb_3idsf0gz0M?usp=sharing)🚀 34 | 35 | The follow-up experiment is uncertain, due to lack of time and GPU resources: 36 | 37 | - [ ] Uncropping on Places2 38 | - [ ] Colorization on ImageNet val set 39 | 40 | ## Results 41 | 42 | The DDPM model requires significant computational resources, and we have only built a few example models to validate the ideas in this paper. 43 | 44 | ### Visuals 45 | 46 | #### Celeba-HQ 47 | 48 | Results with 200 epochs and 930K iterations, and the first 100 samples in [centering mask](https://drive.google.com/drive/folders/10zyHZtYV5vCht2MGNCF8WzpZJT2ae2RS?usp=sharing) and [irregular mask](https://drive.google.com/drive/folders/1vmSI-R9J2yQZY1cVkSSZlTYil2DprzvY?usp=sharing). 49 | 50 | | ![Process_02323](misc//image//Process_02323.jpg) | ![Process_02323](misc//image//Process_26190.jpg) | 51 | | ------------------------------------------------ | ---- | 52 | 53 | #### Places2 with 128×128 centering mask 54 | 55 | Results with 16 epochs and 660K iterations, and the several **picked** samples in [centering mask](https://drive.google.com/drive/folders/1XusKO0_M6GUfPG-FOlID0Xcp0SiexKNe?usp=sharing). 56 | 57 | | ![Mask_Places365_test_00209019.jpg](misc//image//Mask_Places365_test_00209019.jpg) | ![Mask_Places365_test_00143399.jpg](misc//image//Mask_Places365_test_00143399.jpg) | ![Mask_Places365_test_00263905.jpg](misc//image//Mask_Places365_test_00263905.jpg) | ![Mask_Places365_test_00144085.jpg](misc//image//Mask_Places365_test_00144085.jpg) | 58 | | ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | ---- | 59 | | ![Out_Places365_test_00209019](misc//image//Out_Places365_test_00209019.jpg) | ![Out_Places365_test_00143399.jpg](misc//image//Out_Places365_test_00143399.jpg) | ![Out_Places365_test_00263905.jpg](misc//image//Out_Places365_test_00263905.jpg) | ![Out_Places365_test_00144085.jpg](misc//image//Out_Places365_test_00144085.jpg) | 60 | 61 | #### Uncropping on Places2 62 | 63 | Results with 8 epochs and 330K iterations, and the several **picked** samples in [uncropping](https://drive.google.com/drive/folders/1tC3B8ayaadhXAJrOCTrw15R8t84REPWJ?usp=sharing). 
64 | | ![Process_Places365_test_00309553](misc//image//Process_Places365_test_00309553.jpg) | ![Process_Places365_test_00042384](misc//image//Process_Places365_test_00042384.jpg) |
65 | | ------------------------------------------------ | ---- |
66 |
67 |
68 | ### Metrics
69 |
70 | | Tasks | Dataset | EMA | FID(-) | IS(+) |
71 | | -------------------- | ----------- | -------- | ---- | -------------------- |
72 | | Inpainting with centering mask | Celeba-HQ | False | 5.7873 | 3.0705 |
73 | | Inpainting with irregular mask | Celeba-HQ | False | 5.4026 | 3.1221 |
74 |
75 | ## Usage
76 | ### Environment
77 | ```bash
78 | pip install -r requirements.txt
79 | ```
80 |
81 | ### Pre-trained Model
82 |
83 | | Dataset | Task | Iterations | GPUs×Days×Bs | URL |
84 | | --------- | ---------- | ---------- | ------------ | ------------------------------------------------------------ |
85 | | Celeba-HQ | Inpainting | 930K | 2×5×3 | [Google Drive](https://drive.google.com/drive/folders/13YZ2UAmGJ-b7DICr-FDAPM7gctreJEoH?usp=sharing) |
86 | | Places2 | Inpainting | 660K | 4×8×10 | [Google Drive](https://drive.google.com/drive/folders/1Vz_HC0LcpV6yMLOd-SXyoaqJHtxyPBxZ?usp=sharing) |
87 |
88 | **Bs** indicates the batch size per GPU.
89 |
90 |
91 |
92 | ### Data Prepare
93 |
94 | We obtained most of these datasets from Kaggle, so they may differ slightly from the official versions; you can also download them from the official websites.
95 | - [CelebA-HQ resized (256x256) Kaggle](https://www.kaggle.com/datasets/badasstechie/celebahq-resized-256x256)
96 | - [Places2 Official](http://places2.csail.mit.edu/download.html) | [Places2 Kaggle](https://www.kaggle.com/datasets/nickj26/places2-mit-dataset?resource=download)
97 | - [ImageNet Official](https://www.image-net.org/download.php)
98 |
99 | We use the default split of these datasets for training and evaluation. The file lists we use can be found in [Celeba-HQ](https://drive.google.com/drive/folders/1-ym2Mi2jVKdWmWYKJ_L2TWXjUQh8z7H-?usp=sharing), [Places2](https://drive.google.com/drive/folders/11Qj2MtRfiD7LbKEveYwOLaiX62lm_2ww?usp=sharing).
100 |
101 | After preparing your own data, modify the corresponding config file to point to it. Take the following as an example:
102 |
103 | ```yaml
104 | "which_dataset": { // import designated dataset using arguments
105 | "name": ["data.dataset", "InpaintDataset"], // import Dataset() class
106 | "args":{ // arguments to initialize dataset
107 | "data_root": "your data path",
108 | "data_len": -1,
109 | "mask_mode": "hybrid"
110 | }
111 | },
112 | ```
113 |
114 | More options for the **dataloader** and **validation split** can also be found in the `datasets` part of the config file.
115 |
116 | ### Training/Resume Training
117 | 1. Download the checkpoints from the links given above.
118 | 2. Set `resume_state` in the config file to the directory of the previous checkpoint. In the following example, the directory contains the training states and the saved model:
119 |
120 | ```yaml
121 | "path": { //set every part file path
122 | "resume_state": "experiments/inpainting_celebahq_220426_150122/checkpoint/100"
123 | },
124 | ```
125 | 3. Set your network label in the `load_everything` function of `model.py`; the default is **Network**. Following the settings above, the optimizers and models will be loaded from `100.state` and `100_Network.pth` respectively.
126 |
127 | ```python
128 | netG_label = self.netG.__class__.__name__
129 | self.load_network(network=self.netG, network_label=netG_label, strict=False)
130 | ```
131 |
132 | 4.
Run the script:
133 |
134 | ```bash
135 | python run.py -p train -c config/inpainting_celebahq.json
136 | ```
137 |
138 | We tested the U-Net backbones used in `SR3` and `Guided Diffusion`, and the `Guided Diffusion` one shows more robust performance in our current experiments. More options for the **backbone**, **loss** and **metric** can be found in the `which_networks` part of the config file.
139 |
140 | ### Test
141 |
142 | 1. Modify the config file to point to your data, following the steps in the **Data Prepare** section.
143 | 2. Set your model path following the steps in the **Training/Resume Training** section.
144 | 3. Run the script:
145 | ```bash
146 | python run.py -p test -c config/inpainting_celebahq.json
147 | ```
148 |
149 | ### Evaluation
150 | 1. Create two folders, one for ground-truth images and one for sampled images; corresponding files must share the same file name.
151 |
152 | 2. Run the script:
153 |
154 | ```bash
155 | python eval.py -s [ground image path] -d [sample image path]
156 | ```
157 |
158 |
159 |
160 | ## Acknowledgements
161 | Our work is based on the following theoretical works:
162 | - [Denoising Diffusion Probabilistic Models](https://arxiv.org/pdf/2006.11239.pdf)
163 | - [Palette: Image-to-Image Diffusion Models](https://arxiv.org/pdf/2111.05826.pdf)
164 | - [Diffusion Models Beat GANs on Image Synthesis](https://arxiv.org/abs/2105.05233)
165 |
166 | and we benefit greatly from the following projects:
167 | - [openai/guided-diffusion](https://github.com/openai/guided-diffusion)
168 | - [LouisRouss/Diffusion-Based-Model-for-Colorization](https://github.com/LouisRouss/Diffusion-Based-Model-for-Colorization)
169 |
--------------------------------------------------------------------------------
/config/colorization_mirflickr25k.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "colorization_mirflickr25k", // experiments name
3 | "gpu_ids": [0], // gpu ids list, default is single 0
4 | "seed" : -1, // random seed, seed <0 represents randomization not used
5 | "finetune_norm": false, // find the parameters to optimize
6 |
7 | "path": { //set every part file path
8 | "base_dir": "experiments", // base path for all log except resume_state
9 | "code": "code", // code backup
10 | "tb_logger": "tb_logger", // path of tensorboard logger
11 | "results": "results",
12 | "checkpoint": "checkpoint",
13 | // "resume_state": "experiments/inpainting_places2_220413_143231/checkpoint/25"
14 | "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration
15 | },
16 |
17 | "datasets": { // train or test
18 | "train": {
19 | "which_dataset": { // import designated dataset using arguments
20 | "name": ["data.dataset", "ColorizationDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py])
21 | "args":{ // arguments to initialize dataset
22 | "data_root": "datasets/mirflickr25k/images",
23 | "data_flist": "datasets/mirflickr25k/flist/train.flist",
24 | "data_len": -1
25 | }
26 | },
27 | "dataloader":{
28 | "validation_split": 2, // percent or number
29 | "args":{ // arguments to initialize train_dataloader
30 | "batch_size": 4, // batch size in each gpu
31 | "num_workers": 4,
32 | "shuffle": true,
33 | "pin_memory": true,
34 | "drop_last": true
35 | },
36 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader
37 | "batch_size": 1, // batch size in each gpu
38 | "num_workers": 4,
39 | "shuffle": false,
40 |
"pin_memory": true, 41 | "drop_last": false 42 | } 43 | } 44 | }, 45 | "test": { 46 | "which_dataset": { 47 | "name": "ColorizationDataset", // import Dataset() class / function(not recommend) from default file 48 | "args":{ 49 | "data_root": "datasets/mirflickr25k/images", 50 | "data_flist": "datasets/mirflickr25k/flist/test.flist" 51 | } 52 | }, 53 | "dataloader":{ 54 | "args":{ 55 | "batch_size": 8, 56 | "num_workers": 4, 57 | "pin_memory": true 58 | } 59 | } 60 | } 61 | }, 62 | 63 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 64 | "which_model": { // import designated model(trainer) using arguments 65 | "name": ["models.model", "Palette"], // import Model() class / function(not recommend) from models.model.py (default is [models.model.py]) 66 | "args": { 67 | "sample_num": 8, // process of each image 68 | "task": "colorization", 69 | "ema_scheduler": { 70 | "ema_start": 1, 71 | "ema_iter": 1, 72 | "ema_decay": 0.9999 73 | }, 74 | "optimizers": [ 75 | { "lr": 5e-5, "weight_decay": 0} 76 | ] 77 | } 78 | }, 79 | "which_networks": [ // import designated list of networks using arguments 80 | { 81 | "name": ["models.network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/network.py]) 82 | "args": { // arguments to initialize network 83 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 84 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 85 | "unet": { 86 | "in_channel": 6, 87 | "out_channel": 3, 88 | "inner_channel": 64, 89 | "channel_mults": [ 90 | 1, 91 | 2, 92 | 4, 93 | 8 94 | ], 95 | "attn_res": [ 96 | // 32, 97 | 16 98 | // 8 99 | ], 100 | "num_head_channels": 32, 101 | "res_blocks": 2, 102 | "dropout": 0.2, 103 | "image_size": 224 104 | }, 105 | "beta_schedule": { 106 | "train": { 107 | "schedule": "linear", 108 | "n_timestep": 2000, 109 | // "n_timestep": 5, // debug 110 | "linear_start": 1e-6, 111 | "linear_end": 0.01 112 | }, 113 | "test": { 114 | "schedule": "linear", 115 | "n_timestep": 1000, 116 | "linear_start": 1e-4, 117 | "linear_end": 0.09 118 | } 119 | } 120 | } 121 | } 122 | ], 123 | "which_losses": [ // import designated list of losses without arguments 124 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 125 | ], 126 | "which_metrics": [ // import designated list of metrics without arguments 127 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 128 | ] 129 | }, 130 | 131 | "train": { // arguments for basic training 132 | "n_epoch": 1e8, // max epochs, not limited now 133 | "n_iter": 1e8, // max interations 134 | "val_epoch": 5, // valdation every specified number of epochs 135 | "save_checkpoint_epoch": 10, 136 | "log_iter": 1e4, // log every specified number of iterations 137 | "tensorboard" : true // tensorboardX enable 138 | }, 139 | 140 | "debug": { // arguments in debug mode, which will replace arguments in train 141 | "val_epoch": 1, 142 | "save_checkpoint_epoch": 1, 143 | "log_iter": 10, 144 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
145 | } 146 | } 147 | -------------------------------------------------------------------------------- /config/inpainting_celebahq.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "inpainting_celebahq", // experiments name 3 | "gpu_ids": [0], // gpu ids list, default is single 0 4 | "seed" : -1, // random seed, seed <0 represents randomization not used 5 | "finetune_norm": false, // find the parameters to optimize 6 | 7 | "path": { //set every part file path 8 | "base_dir": "experiments", // base path for all log except resume_state 9 | "code": "code", // code backup 10 | "tb_logger": "tb_logger", // path of tensorboard logger 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | "resume_state": "experiments/train_inpainting_celebahq_220426_233652/checkpoint/190" 14 | // "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration 15 | }, 16 | 17 | "datasets": { // train or test 18 | "train": { 19 | "which_dataset": { // import designated dataset using arguments 20 | "name": ["data.dataset", "InpaintDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 21 | "args":{ // arguments to initialize dataset 22 | "data_root": "datasets/celebahq/flist/train.flist", 23 | "data_len": -1, 24 | "mask_config": { 25 | "mask_mode": "hybrid" 26 | } 27 | } 28 | }, 29 | "dataloader":{ 30 | "validation_split": 2, // percent or number 31 | "args":{ // arguments to initialize train_dataloader 32 | "batch_size": 3, // batch size in each gpu 33 | "num_workers": 4, 34 | "shuffle": true, 35 | "pin_memory": true, 36 | "drop_last": true 37 | }, 38 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 39 | "batch_size": 1, // batch size in each gpu 40 | "num_workers": 4, 41 | "shuffle": false, 42 | "pin_memory": true, 43 | "drop_last": false 44 | } 45 | } 46 | }, 47 | "test": { 48 | "which_dataset": { 49 | "name": "InpaintDataset", // import Dataset() class / function(not recommend) from default file 50 | "args":{ 51 | "data_root": "datasets/celebahq/flist/test.flist", 52 | "mask_config": { 53 | "mask_mode": "center" 54 | } 55 | } 56 | }, 57 | "dataloader":{ 58 | "args":{ 59 | "batch_size": 8, 60 | "num_workers": 4, 61 | "pin_memory": true 62 | } 63 | } 64 | } 65 | }, 66 | 67 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 68 | "which_model": { // import designated model(trainer) using arguments 69 | "name": ["models.model", "Palette"], // import Model() class / function(not recommend) from models.model.py (default is [models.model.py]) 70 | "args": { 71 | "sample_num": 8, // process of each image 72 | "task": "inpainting", 73 | "ema_scheduler": { 74 | "ema_start": 1, 75 | "ema_iter": 1, 76 | "ema_decay": 0.9999 77 | }, 78 | "optimizers": [ 79 | { "lr": 5e-5, "weight_decay": 0} 80 | ] 81 | } 82 | }, 83 | "which_networks": [ // import designated list of networks using arguments 84 | { 85 | "name": ["models.network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/network.py]) 86 | "args": { // arguments to initialize network 87 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 88 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 89 | "unet": { 90 | "in_channel": 6, 91 | "out_channel": 3, 92 | "inner_channel": 64, 93 | 
"channel_mults": [ 94 | 1, 95 | 2, 96 | 4, 97 | 8 98 | ], 99 | "attn_res": [ 100 | // 32, 101 | 16 102 | // 8 103 | ], 104 | "num_head_channels": 32, 105 | "res_blocks": 2, 106 | "dropout": 0.2, 107 | "image_size": 256 108 | }, 109 | "beta_schedule": { 110 | "train": { 111 | "schedule": "linear", 112 | "n_timestep": 2000, 113 | // "n_timestep": 10, // debug 114 | "linear_start": 1e-6, 115 | "linear_end": 0.01 116 | }, 117 | "test": { 118 | "schedule": "linear", 119 | "n_timestep": 1000, 120 | "linear_start": 1e-4, 121 | "linear_end": 0.09 122 | } 123 | } 124 | } 125 | } 126 | ], 127 | "which_losses": [ // import designated list of losses without arguments 128 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 129 | ], 130 | "which_metrics": [ // import designated list of metrics without arguments 131 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 132 | ] 133 | }, 134 | 135 | "train": { // arguments for basic training 136 | "n_epoch": 1e8, // max epochs, not limited now 137 | "n_iter": 1e8, // max interations 138 | "val_epoch": 5, // valdation every specified number of epochs 139 | "save_checkpoint_epoch": 10, 140 | "log_iter": 1e3, // log every specified number of iterations 141 | "tensorboard" : true // tensorboardX enable 142 | }, 143 | 144 | "debug": { // arguments in debug mode, which will replace arguments in train 145 | "val_epoch": 1, 146 | "save_checkpoint_epoch": 1, 147 | "log_iter": 2, 148 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /config/inpainting_places2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "inpainting_places2", // experiments name 3 | "gpu_ids": [0], // gpu ids list, default is single 0 4 | "seed" : -1, // random seed, seed <0 represents randomization not used 5 | "finetune_norm": false, // find the parameters to optimize 6 | 7 | "path": { //set every part file path 8 | "base_dir": "experiments", // base path for all log except resume_state 9 | "code": "code", // code backup 10 | "tb_logger": "tb_logger", // path of tensorboard logger 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | "resume_state": "experiments/train_inpainting_places2_220429_160230/checkpoint/8" 14 | // "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration 15 | }, 16 | 17 | "datasets": { // train or test 18 | "train": { 19 | "which_dataset": { // import designated dataset using arguments 20 | "name": ["data.dataset", "InpaintDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 21 | "args":{ // arguments to initialize dataset 22 | "data_root": "datasets/place365/flist/train.flist", 23 | "data_len": -1, 24 | "mask_config": { 25 | "mask_mode": "hybrid" 26 | } 27 | } 28 | }, 29 | "dataloader":{ 30 | "validation_split": 2, // percent or number 31 | "args":{ // arguments to initialize train_dataloader 32 | "batch_size": 3, // batch size in each gpu 33 | "num_workers": 4, 34 | "shuffle": true, 35 | "pin_memory": true, 36 | "drop_last": true 37 | }, 38 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 39 | "batch_size": 1, // batch size in each gpu 40 | 
"num_workers": 4, 41 | "shuffle": false, 42 | "pin_memory": true, 43 | "drop_last": false 44 | } 45 | } 46 | }, 47 | "test": { 48 | "which_dataset": { 49 | "name": "InpaintDataset", // import Dataset() class / function(not recommend) from default file 50 | "args":{ 51 | "data_root": "datasets/place365/flist/test.flist", 52 | "mask_config": { 53 | "mask_mode": "center" 54 | } 55 | } 56 | }, 57 | "dataloader":{ 58 | "args":{ 59 | "batch_size": 8, 60 | "num_workers": 4, 61 | "pin_memory": true 62 | } 63 | } 64 | } 65 | }, 66 | 67 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 68 | "which_model": { // import designated model(trainer) using arguments 69 | "name": ["models.model", "Palette"], // import Model() class / function(not recommend) from models.model.py (default is [models.model.py]) 70 | "args": { 71 | "sample_num": 8, // process of each image 72 | "task": "inpainting", 73 | "ema_scheduler": { 74 | "ema_start": 1, 75 | "ema_iter": 1, 76 | "ema_decay": 0.9999 77 | }, 78 | "optimizers": [ 79 | { "lr": 5e-5, "weight_decay": 0} 80 | ] 81 | } 82 | }, 83 | "which_networks": [ // import designated list of networks using arguments 84 | { 85 | "name": ["models.network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/network.py]) 86 | "args": { // arguments to initialize network 87 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 88 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 89 | "unet": { 90 | "in_channel": 6, 91 | "out_channel": 3, 92 | "inner_channel": 64, 93 | "channel_mults": [ 94 | 1, 95 | 2, 96 | 4, 97 | 8 98 | ], 99 | "attn_res": [ 100 | // 32, 101 | 16 102 | // 8 103 | ], 104 | "num_head_channels": 32, 105 | "res_blocks": 2, 106 | "dropout": 0.2, 107 | "image_size": 256 108 | }, 109 | "beta_schedule": { 110 | "train": { 111 | "schedule": "linear", 112 | "n_timestep": 2000, 113 | // "n_timestep": 5, // debug 114 | "linear_start": 1e-6, 115 | "linear_end": 0.01 116 | }, 117 | "test": { 118 | "schedule": "linear", 119 | "n_timestep": 1000, 120 | "linear_start": 1e-4, 121 | "linear_end": 0.09 122 | } 123 | } 124 | } 125 | } 126 | ], 127 | "which_losses": [ // import designated list of losses without arguments 128 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 129 | ], 130 | "which_metrics": [ // import designated list of metrics without arguments 131 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 132 | ] 133 | }, 134 | 135 | "train": { // arguments for basic training 136 | "n_epoch": 1e8, // max epochs, not limited now 137 | "n_iter": 1e8, // max interations 138 | "val_epoch": 1, // valdation every specified number of epochs 139 | "save_checkpoint_epoch": 1, 140 | "log_iter": 1e4, // log every specified number of iterations 141 | "tensorboard" : true // tensorboardX enable 142 | }, 143 | 144 | "debug": { // arguments in debug mode, which will replace arguments in train 145 | "val_epoch": 1, 146 | "save_checkpoint_epoch": 1, 147 | "log_iter": 10, 148 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
149 | } 150 | } 151 | -------------------------------------------------------------------------------- /config/uncropping_places2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "uncropping_places2", // experiments name 3 | "gpu_ids": [0], // gpu ids list, default is single 0 4 | "seed" : -1, // random seed, seed <0 represents randomization not used 5 | "finetune_norm": false, // find the parameters to optimize 6 | 7 | "path": { //set every part file path 8 | "base_dir": "experiments", // base path for all log except resume_state 9 | "code": "code", // code backup 10 | "tb_logger": "tb_logger", // path of tensorboard logger 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | // "resume_state": "experiments/inpainting_places2_220413_143231/checkpoint/25" 14 | "resume_state": null // ex: 100, loading .state and .pth from given epoch and iteration 15 | }, 16 | 17 | "datasets": { // train or test 18 | "train": { 19 | "which_dataset": { // import designated dataset using arguments 20 | "name": ["data.dataset", "UncroppingDataset"], // import Dataset() class / function(not recommend) from data.dataset.py (default is [data.dataset.py]) 21 | "args":{ // arguments to initialize dataset 22 | "data_root": "datasets/place365/flist/train.flist", 23 | "data_len": -1, 24 | "mask_config": { 25 | "mask_mode": "hybrid" // onedirection | fourdirection | hybrid | manual 26 | } 27 | } 28 | }, 29 | "dataloader":{ 30 | "validation_split": 2, // percent or number 31 | "args":{ // arguments to initialize train_dataloader 32 | "batch_size": 3, // batch size in each gpu 33 | "num_workers": 4, 34 | "shuffle": true, 35 | "pin_memory": true, 36 | "drop_last": true 37 | }, 38 | "val_args":{ // arguments to initialize valid_dataloader, will overwrite the parameters in train_dataloader 39 | "batch_size": 1, // batch size in each gpu 40 | "num_workers": 4, 41 | "shuffle": false, 42 | "pin_memory": true, 43 | "drop_last": false 44 | } 45 | } 46 | }, 47 | "test": { 48 | "which_dataset": { 49 | "name": "UncroppingDataset", // import Dataset() class / function(not recommend) from default file 50 | "args":{ 51 | "data_root": "datasets/place365/flist/test.flist", 52 | "mask_config": { 53 | "mask_mode": "onedirection", 54 | "shape": [] 55 | } 56 | } 57 | }, 58 | "dataloader":{ 59 | "args":{ 60 | "batch_size": 8, 61 | "num_workers": 4, 62 | "pin_memory": true 63 | } 64 | } 65 | } 66 | }, 67 | 68 | "model": { // networks/metrics/losses/optimizers/lr_schedulers is a list and model is a dict 69 | "which_model": { // import designated model(trainer) using arguments 70 | "name": ["models.model", "Palette"], // import Model() class / function(not recommend) from models.model.py (default is [models.model.py]) 71 | "args": { 72 | "sample_num": 8, // process of each image 73 | "task": "uncropping", 74 | "ema_scheduler": { 75 | "ema_start": 1, 76 | "ema_iter": 1, 77 | "ema_decay": 0.9999 78 | }, 79 | "optimizers": [ 80 | { "lr": 5e-5, "weight_decay": 0} 81 | ] 82 | } 83 | }, 84 | "which_networks": [ // import designated list of networks using arguments 85 | { 86 | "name": ["models.network", "Network"], // import Network() class / function(not recommend) from default file (default is [models/network.py]) 87 | "args": { // arguments to initialize network 88 | "init_type": "kaiming", // method can be [normal | xavier| xavier_uniform | kaiming | orthogonal], default is kaiming 89 | "module_name": "guided_diffusion", // sr3 | guided_diffusion 90 | "unet": { 91 | "in_channel": 
6, 92 | "out_channel": 3, 93 | "inner_channel": 64, 94 | "channel_mults": [ 95 | 1, 96 | 2, 97 | 4, 98 | 8 99 | ], 100 | "attn_res": [ 101 | // 32, 102 | 16 103 | // 8 104 | ], 105 | "num_head_channels": 32, 106 | "res_blocks": 2, 107 | "dropout": 0.2, 108 | "image_size": 256 109 | }, 110 | "beta_schedule": { 111 | "train": { 112 | "schedule": "linear", 113 | "n_timestep": 2000, 114 | // "n_timestep": 5, // debug 115 | "linear_start": 1e-6, 116 | "linear_end": 0.01 117 | }, 118 | "test": { 119 | "schedule": "linear", 120 | "n_timestep": 1000, 121 | "linear_start": 1e-4, 122 | "linear_end": 0.09 123 | } 124 | } 125 | } 126 | } 127 | ], 128 | "which_optimizers": [ // len(networks) == len(optimizers) == len(lr_schedulers), it will be deleted after initialization if not used. 129 | { "name": "Adam", "args":{ "lr": 5e-5, "weight_decay": 0}} 130 | ], 131 | "which_lr_schedulers": [ // {} represents None, it will be deleted after initialization. 132 | {} 133 | // { "name": "LinearLR", "args": { "start_factor": 0.2, "total_iters": 1e3 }} 134 | ], 135 | "which_losses": [ // import designated list of losses without arguments 136 | "mse_loss" // import mse_loss() function/class from default file (default is [models/losses.py]), equivalent to { "name": "mse_loss", "args":{}} 137 | ], 138 | "which_metrics": [ // import designated list of metrics without arguments 139 | "mae" // import mae() function/class from default file (default is [models/metrics.py]), equivalent to { "name": "mae", "args":{}} 140 | ] 141 | }, 142 | 143 | "train": { // arguments for basic training 144 | "n_epoch": 1e8, // max epochs, not limited now 145 | "n_iter": 1e8, // max interations 146 | "val_epoch": 1, // valdation every specified number of epochs 147 | "save_checkpoint_epoch": 1, 148 | "log_iter": 1e4, // log every specified number of iterations 149 | "tensorboard" : true // tensorboardX enable 150 | }, 151 | 152 | "debug": { // arguments in debug mode, which will replace arguments in train 153 | "val_epoch": 1, 154 | "save_checkpoint_epoch": 1, 155 | "log_iter": 10, 156 | "debug_split": 50 // percent or number, change the size of dataloder to debug_split. 
157 | } 158 | } 159 | -------------------------------------------------------------------------------- /core/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | 7 | IMG_EXTENSIONS = [ 8 | '.jpg', '.JPG', '.jpeg', '.JPEG', 9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 10 | ] 11 | 12 | def is_image_file(filename): 13 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 14 | 15 | def make_dataset(dir): 16 | if os.path.isfile(dir): 17 | images = [i for i in np.genfromtxt(dir, dtype=np.str, encoding='utf-8')] 18 | else: 19 | images = [] 20 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 21 | for root, _, fnames in sorted(os.walk(dir)): 22 | for fname in sorted(fnames): 23 | if is_image_file(fname): 24 | path = os.path.join(root, fname) 25 | images.append(path) 26 | 27 | return images 28 | 29 | def pil_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | class BaseDataset(data.Dataset): 33 | def __init__(self, data_root, image_size=[256, 256], loader=pil_loader): 34 | self.imgs = make_dataset(data_root) 35 | self.tfs = transforms.Compose([ 36 | transforms.Resize((image_size[0], image_size[1])), 37 | transforms.ToTensor(), 38 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) 39 | ]) 40 | self.loader = loader 41 | 42 | def __getitem__(self, index): 43 | path = self.imgs[index] 44 | img = self.tfs(self.loader(path)) 45 | return img 46 | 47 | def __len__(self): 48 | return len(self.imgs) 49 | -------------------------------------------------------------------------------- /core/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import abstractmethod 3 | from functools import partial 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | import core.util as Util 11 | CustomResult = collections.namedtuple('CustomResult', 'name result') 12 | 13 | class BaseModel(): 14 | def __init__(self, opt, phase_loader, val_loader, metrics, logger, writer): 15 | """ init model with basic input, which are from __init__(**kwargs) function in inherited class """ 16 | self.opt = opt 17 | self.phase = opt['phase'] 18 | self.set_device = partial(Util.set_device, rank=opt['global_rank']) 19 | 20 | ''' optimizers and schedulers ''' 21 | self.schedulers = [] 22 | self.optimizers = [] 23 | 24 | ''' process record ''' 25 | self.batch_size = self.opt['datasets'][self.phase]['dataloader']['args']['batch_size'] 26 | self.epoch = 0 27 | self.iter = 0 28 | 29 | self.phase_loader = phase_loader 30 | self.val_loader = val_loader 31 | self.metrics = metrics 32 | 33 | ''' logger to log file, which only work on GPU 0. writer to tensorboard and result file ''' 34 | self.logger = logger 35 | self.writer = writer 36 | self.results_dict = CustomResult([],[]) # {"name":[], "result":[]} 37 | 38 | def train(self): 39 | while self.epoch <= self.opt['train']['n_epoch'] and self.iter <= self.opt['train']['n_iter']: 40 | self.epoch += 1 41 | if self.opt['distributed']: 42 | ''' sets the epoch for this sampler. 
When :attr:`shuffle=True`, this ensures all replicas use a different random ordering for each epoch ''' 43 | self.phase_loader.sampler.set_epoch(self.epoch) 44 | 45 | train_log = self.train_step() 46 | 47 | ''' save logged informations into log dict ''' 48 | train_log.update({'epoch': self.epoch, 'iters': self.iter}) 49 | 50 | ''' print logged informations to the screen and tensorboard ''' 51 | for key, value in train_log.items(): 52 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 53 | 54 | if self.epoch % self.opt['train']['save_checkpoint_epoch'] == 0: 55 | self.logger.info('Saving the self at the end of epoch {:.0f}'.format(self.epoch)) 56 | self.save_everything() 57 | 58 | if self.epoch % self.opt['train']['val_epoch'] == 0: 59 | self.logger.info("\n\n\n------------------------------Validation Start------------------------------") 60 | if self.val_loader is None: 61 | self.logger.warning('Validation stop where dataloader is None, Skip it.') 62 | else: 63 | val_log = self.val_step() 64 | for key, value in val_log.items(): 65 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 66 | self.logger.info("\n------------------------------Validation End------------------------------\n\n") 67 | self.logger.info('Number of Epochs has reached the limit, End.') 68 | 69 | def test(self): 70 | pass 71 | 72 | @abstractmethod 73 | def train_step(self): 74 | raise NotImplementedError('You must specify how to train your networks.') 75 | 76 | @abstractmethod 77 | def val_step(self): 78 | raise NotImplementedError('You must specify how to do validation on your networks.') 79 | 80 | def test_step(self): 81 | pass 82 | 83 | def print_network(self, network): 84 | """ print network structure, only work on GPU 0 """ 85 | if self.opt['global_rank'] !=0: 86 | return 87 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 88 | network = network.module 89 | 90 | s, n = str(network), sum(map(lambda x: x.numel(), network.parameters())) 91 | net_struc_str = '{}'.format(network.__class__.__name__) 92 | self.logger.info('Network structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 93 | self.logger.info(s) 94 | 95 | def save_network(self, network, network_label): 96 | """ save network structure, only work on GPU 0 """ 97 | if self.opt['global_rank'] !=0: 98 | return 99 | save_filename = '{}_{}.pth'.format(self.epoch, network_label) 100 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) 101 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 102 | network = network.module 103 | state_dict = network.state_dict() 104 | for key, param in state_dict.items(): 105 | state_dict[key] = param.cpu() 106 | torch.save(state_dict, save_path) 107 | 108 | def load_network(self, network, network_label, strict=True): 109 | if self.opt['path']['resume_state'] is None: 110 | return 111 | self.logger.info('Beign loading pretrained model [{:s}] ...'.format(network_label)) 112 | 113 | model_path = "{}_{}.pth".format(self. 
opt['path']['resume_state'], network_label) 114 | 115 | if not os.path.exists(model_path): 116 | self.logger.warning('Pretrained model in [{:s}] is not existed, Skip it'.format(model_path)) 117 | return 118 | 119 | self.logger.info('Loading pretrained model from [{:s}] ...'.format(model_path)) 120 | if isinstance(network, nn.DataParallel) or isinstance(network, nn.parallel.DistributedDataParallel): 121 | network = network.module 122 | network.load_state_dict(torch.load(model_path, map_location = lambda storage, loc: Util.set_device(storage)), strict=strict) 123 | 124 | def save_training_state(self): 125 | """ saves training state during training, only work on GPU 0 """ 126 | if self.opt['global_rank'] !=0: 127 | return 128 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' 129 | state = {'epoch': self.epoch, 'iter': self.iter, 'schedulers': [], 'optimizers': []} 130 | for s in self.schedulers: 131 | state['schedulers'].append(s.state_dict()) 132 | for o in self.optimizers: 133 | state['optimizers'].append(o.state_dict()) 134 | save_filename = '{}.state'.format(self.epoch) 135 | save_path = os.path.join(self.opt['path']['checkpoint'], save_filename) 136 | torch.save(state, save_path) 137 | 138 | def resume_training(self): 139 | """ resume the optimizers and schedulers for training, only work when phase is test or resume training enable """ 140 | if self.phase!='train' or self. opt['path']['resume_state'] is None: 141 | return 142 | self.logger.info('Beign loading training states'.format()) 143 | assert isinstance(self.optimizers, list) and isinstance(self.schedulers, list), 'optimizers and schedulers must be a list.' 144 | 145 | state_path = "{}.state".format(self. opt['path']['resume_state']) 146 | 147 | if not os.path.exists(state_path): 148 | self.logger.warning('Training state in [{:s}] is not existed, Skip it'.format(state_path)) 149 | return 150 | 151 | self.logger.info('Loading training state for [{:s}] ...'.format(state_path)) 152 | resume_state = torch.load(state_path, map_location = lambda storage, loc: self.set_device(storage)) 153 | 154 | resume_optimizers = resume_state['optimizers'] 155 | resume_schedulers = resume_state['schedulers'] 156 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers {} != {}'.format(len(resume_optimizers), len(self.optimizers)) 157 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers {} != {}'.format(len(resume_schedulers), len(self.schedulers)) 158 | for i, o in enumerate(resume_optimizers): 159 | self.optimizers[i].load_state_dict(o) 160 | for i, s in enumerate(resume_schedulers): 161 | self.schedulers[i].load_state_dict(s) 162 | 163 | self.epoch = resume_state['epoch'] 164 | self.iter = resume_state['iter'] 165 | 166 | def load_everything(self): 167 | pass 168 | 169 | @abstractmethod 170 | def save_everything(self): 171 | raise NotImplementedError('You must specify how to save your networks, optimizers and schedulers.') 172 | -------------------------------------------------------------------------------- /core/base_network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | class BaseNetwork(nn.Module): 3 | def __init__(self, init_type='kaiming', gain=0.02): 4 | super(BaseNetwork, self).__init__() 5 | self.init_type = init_type 6 | self.gain = gain 7 | 8 | def init_weights(self): 9 | """ 10 | initialize network's weights 11 | init_type: normal | 
xavier | kaiming | orthogonal 12 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 13 | """ 14 | 15 | def init_func(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('InstanceNorm2d') != -1: 18 | if hasattr(m, 'weight') and m.weight is not None: 19 | nn.init.constant_(m.weight.data, 1.0) 20 | if hasattr(m, 'bias') and m.bias is not None: 21 | nn.init.constant_(m.bias.data, 0.0) 22 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 23 | if self.init_type == 'normal': 24 | nn.init.normal_(m.weight.data, 0.0, self.gain) 25 | elif self.init_type == 'xavier': 26 | nn.init.xavier_normal_(m.weight.data, gain=self.gain) 27 | elif self.init_type == 'xavier_uniform': 28 | nn.init.xavier_uniform_(m.weight.data, gain=1.0) 29 | elif self.init_type == 'kaiming': 30 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 31 | elif self.init_type == 'orthogonal': 32 | nn.init.orthogonal_(m.weight.data, gain=self.gain) 33 | elif self.init_type == 'none': # uses pytorch's default init method 34 | m.reset_parameters() 35 | else: 36 | raise NotImplementedError('initialization method [%s] is not implemented' % self.init_type) 37 | if hasattr(m, 'bias') and m.bias is not None: 38 | nn.init.constant_(m.bias.data, 0.0) 39 | 40 | self.apply(init_func) 41 | # propagate to children 42 | for m in self.children(): 43 | if hasattr(m, 'init_weights'): 44 | m.init_weights(self.init_type, self.gain) 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import importlib 4 | from datetime import datetime 5 | import logging 6 | import pandas as pd 7 | 8 | import core.util as Util 9 | 10 | class InfoLogger(): 11 | """ 12 | use logging to record log, only work on GPU 0 by judging global_rank 13 | """ 14 | def __init__(self, opt): 15 | self.opt = opt 16 | self.rank = opt['global_rank'] 17 | self.phase = opt['phase'] 18 | 19 | self.setup_logger(None, opt['path']['experiments_root'], opt['phase'], level=logging.INFO, screen=False) 20 | self.logger = logging.getLogger(opt['phase']) 21 | self.infologger_ftns = {'info', 'warning', 'debug'} 22 | 23 | def __getattr__(self, name): 24 | if self.rank != 0: # info only print on GPU 0. 25 | def wrapper(info, *args, **kwargs): 26 | pass 27 | return wrapper 28 | if name in self.infologger_ftns: 29 | print_info = getattr(self.logger, name, None) 30 | def wrapper(info, *args, **kwargs): 31 | print_info(info, *args, **kwargs) 32 | return wrapper 33 | 34 | @staticmethod 35 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False): 36 | """ set up logger """ 37 | l = logging.getLogger(logger_name) 38 | formatter = logging.Formatter( 39 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S') 40 | log_file = os.path.join(root, '{}.log'.format(phase)) 41 | fh = logging.FileHandler(log_file, mode='a+') 42 | fh.setFormatter(formatter) 43 | l.setLevel(level) 44 | l.addHandler(fh) 45 | if screen: 46 | sh = logging.StreamHandler() 47 | sh.setFormatter(formatter) 48 | l.addHandler(sh) 49 | 50 | class VisualWriter(): 51 | """ 52 | use tensorboard to record visuals, support 'add_scalar', 'add_scalars', 'add_image', 'add_images', etc. funtion. 53 | Also integrated with save results function. 
54 | """ 55 | def __init__(self, opt, logger): 56 | log_dir = opt['path']['tb_logger'] 57 | self.result_dir = opt['path']['results'] 58 | enabled = opt['train']['tensorboard'] 59 | self.rank = opt['global_rank'] 60 | 61 | self.writer = None 62 | self.selected_module = "" 63 | 64 | if enabled and self.rank==0: 65 | log_dir = str(log_dir) 66 | 67 | # Retrieve vizualization writer. 68 | succeeded = False 69 | for module in ["tensorboardX", "torch.utils.tensorboard"]: 70 | try: 71 | self.writer = importlib.import_module(module).SummaryWriter(log_dir) 72 | succeeded = True 73 | break 74 | except ImportError: 75 | succeeded = False 76 | self.selected_module = module 77 | 78 | if not succeeded: 79 | message = "Warning: visualization (Tensorboard) is configured to use, but currently not installed on " \ 80 | "this machine. Please install TensorboardX with 'pip install tensorboardx', upgrade PyTorch to " \ 81 | "version >= 1.1 to use 'torch.utils.tensorboard' or turn off the option in the 'config.json' file." 82 | logger.warning(message) 83 | 84 | self.epoch = 0 85 | self.iter = 0 86 | self.phase = '' 87 | 88 | self.tb_writer_ftns = { 89 | 'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio', 90 | 'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding' 91 | } 92 | self.tag_mode_exceptions = {'add_histogram', 'add_embedding'} 93 | self.custom_ftns = {'close'} 94 | self.timer = datetime.now() 95 | 96 | def set_iter(self, epoch, iter, phase='train'): 97 | self.phase = phase 98 | self.epoch = epoch 99 | self.iter = iter 100 | 101 | def save_images(self, results): 102 | result_path = os.path.join(self.result_dir, self.phase) 103 | os.makedirs(result_path, exist_ok=True) 104 | result_path = os.path.join(result_path, str(self.epoch)) 105 | os.makedirs(result_path, exist_ok=True) 106 | 107 | ''' get names and corresponding images from results[OrderedDict] ''' 108 | try: 109 | names = results['name'] 110 | outputs = Util.postprocess(results['result']) 111 | for i in range(len(names)): 112 | Image.fromarray(outputs[i]).save(os.path.join(result_path, names[i])) 113 | except: 114 | raise NotImplementedError('You must specify the context of name and result in save_current_results functions of model.') 115 | 116 | def close(self): 117 | self.writer.close() 118 | print('Close the Tensorboard SummaryWriter.') 119 | 120 | 121 | def __getattr__(self, name): 122 | """ 123 | If visualization is configured to use: 124 | return add_data() methods of tensorboard with additional information (step, tag) added. 125 | Otherwise: 126 | return a blank function handle that does nothing 127 | """ 128 | if name in self.tb_writer_ftns: 129 | add_data = getattr(self.writer, name, None) 130 | def wrapper(tag, data, *args, **kwargs): 131 | if add_data is not None: 132 | # add phase(train/valid) tag 133 | if name not in self.tag_mode_exceptions: 134 | tag = '{}/{}'.format(self.phase, tag) 135 | add_data(tag, data, self.iter, *args, **kwargs) 136 | return wrapper 137 | else: 138 | # default action for returning methods defined in this class, set_step() for instance. 139 | try: 140 | attr = object.__getattr__(name) 141 | except AttributeError: 142 | raise AttributeError("type object '{}' has no attribute '{}'".format(self.selected_module, name)) 143 | return attr 144 | 145 | 146 | class LogTracker: 147 | """ 148 | record training numerical indicators. 
149 | """ 150 | def __init__(self, *keys, phase='train'): 151 | self.phase = phase 152 | self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average']) 153 | self.reset() 154 | 155 | def reset(self): 156 | for col in self._data.columns: 157 | self._data[col].values[:] = 0 158 | 159 | def update(self, key, value, n=1): 160 | self._data.total[key] += value * n 161 | self._data.counts[key] += n 162 | self._data.average[key] = self._data.total[key] / self._data.counts[key] 163 | 164 | def avg(self, key): 165 | return self._data.average[key] 166 | 167 | def result(self): 168 | return {'{}/{}'.format(self.phase, k):v for k, v in dict(self._data.average).items()} 169 | -------------------------------------------------------------------------------- /core/praser.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import json 4 | from pathlib import Path 5 | from datetime import datetime 6 | from functools import partial 7 | import importlib 8 | from types import FunctionType 9 | import shutil 10 | def init_obj(opt, logger, *args, default_file_name='default file', given_module=None, init_type='Network', **modify_kwargs): 11 | """ 12 | finds a function handle with the name given as 'name' in config, 13 | and returns the instance initialized with corresponding args. 14 | """ 15 | if opt is None or len(opt)<1: 16 | logger.info('Option is None when initialize {}'.format(init_type)) 17 | return None 18 | 19 | ''' default format is dict with name key ''' 20 | if isinstance(opt, str): 21 | opt = {'name': opt} 22 | logger.warning('Config is a str, converts to a dict {}'.format(opt)) 23 | 24 | name = opt['name'] 25 | ''' name can be list, indicates the file and class name of function ''' 26 | if isinstance(name, list): 27 | file_name, class_name = name[0], name[1] 28 | else: 29 | file_name, class_name = default_file_name, name 30 | try: 31 | if given_module is not None: 32 | module = given_module 33 | else: 34 | module = importlib.import_module(file_name) 35 | 36 | attr = getattr(module, class_name) 37 | kwargs = opt.get('args', {}) 38 | kwargs.update(modify_kwargs) 39 | ''' import class or function with args ''' 40 | if isinstance(attr, type): 41 | ret = attr(*args, **kwargs) 42 | ret.__name__ = ret.__class__.__name__ 43 | elif isinstance(attr, FunctionType): 44 | ret = partial(attr, *args, **kwargs) 45 | ret.__name__ = attr.__name__ 46 | # ret = attr 47 | logger.info('{} [{:s}() form {:s}] is created.'.format(init_type, class_name, file_name)) 48 | except: 49 | raise NotImplementedError('{} [{:s}() form {:s}] not recognized.'.format(init_type, class_name, file_name)) 50 | return ret 51 | 52 | 53 | def mkdirs(paths): 54 | if isinstance(paths, str): 55 | os.makedirs(paths, exist_ok=True) 56 | else: 57 | for path in paths: 58 | os.makedirs(path, exist_ok=True) 59 | 60 | def get_timestamp(): 61 | return datetime.now().strftime('%y%m%d_%H%M%S') 62 | 63 | 64 | def write_json(content, fname): 65 | fname = Path(fname) 66 | with fname.open('wt') as handle: 67 | json.dump(content, handle, indent=4, sort_keys=False) 68 | 69 | class NoneDict(dict): 70 | def __missing__(self, key): 71 | return None 72 | 73 | def dict_to_nonedict(opt): 74 | """ convert to NoneDict, which return None for missing key. 
""" 75 | if isinstance(opt, dict): 76 | new_opt = dict() 77 | for key, sub_opt in opt.items(): 78 | new_opt[key] = dict_to_nonedict(sub_opt) 79 | return NoneDict(**new_opt) 80 | elif isinstance(opt, list): 81 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 82 | else: 83 | return opt 84 | 85 | def dict2str(opt, indent_l=1): 86 | """ dict to string for logger """ 87 | msg = '' 88 | for k, v in opt.items(): 89 | if isinstance(v, dict): 90 | msg += ' ' * (indent_l * 2) + k + ':[\n' 91 | msg += dict2str(v, indent_l + 1) 92 | msg += ' ' * (indent_l * 2) + ']\n' 93 | else: 94 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 95 | return msg 96 | 97 | def parse(args): 98 | json_str = '' 99 | with open(args.config, 'r') as f: 100 | for line in f: 101 | line = line.split('//')[0] + '\n' 102 | json_str += line 103 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 104 | 105 | ''' replace the config context using args ''' 106 | opt['phase'] = args.phase 107 | if args.gpu_ids is not None: 108 | opt['gpu_ids'] = [int(id) for id in args.gpu_ids.split(',')] 109 | if args.batch is not None: 110 | opt['datasets'][opt['phase']]['dataloader']['args']['batch_size'] = args.batch 111 | 112 | ''' set cuda environment ''' 113 | if len(opt['gpu_ids']) > 1: 114 | opt['distributed'] = True 115 | else: 116 | opt['distributed'] = False 117 | 118 | ''' update name ''' 119 | if args.debug: 120 | opt['name'] = 'debug_{}'.format(opt['name']) 121 | elif opt['finetune_norm']: 122 | opt['name'] = 'finetune_{}'.format(opt['name']) 123 | else: 124 | opt['name'] = '{}_{}'.format(opt['phase'], opt['name']) 125 | 126 | ''' set log directory ''' 127 | experiments_root = os.path.join(opt['path']['base_dir'], '{}_{}'.format(opt['name'], get_timestamp())) 128 | mkdirs(experiments_root) 129 | 130 | ''' save json ''' 131 | write_json(opt, '{}/config.json'.format(experiments_root)) 132 | 133 | ''' change folder relative hierarchy ''' 134 | opt['path']['experiments_root'] = experiments_root 135 | for key, path in opt['path'].items(): 136 | if 'resume' not in key and 'base' not in key and 'root' not in key: 137 | opt['path'][key] = os.path.join(experiments_root, path) 138 | mkdirs(opt['path'][key]) 139 | 140 | ''' debug mode ''' 141 | if 'debug' in opt['name']: 142 | opt['train'].update(opt['debug']) 143 | 144 | ''' code backup ''' 145 | for name in os.listdir('.'): 146 | if name in ['config', 'models', 'core', 'slurm', 'data']: 147 | shutil.copytree(name, os.path.join(opt['path']['code'], name), ignore=shutil.ignore_patterns("*.pyc", "__pycache__")) 148 | if '.py' in name or '.sh' in name: 149 | shutil.copy(name, opt['path']['code']) 150 | return dict_to_nonedict(opt) 151 | 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /core/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import math 4 | import torch 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 10 | ''' 11 | Converts a torch Tensor into an image Numpy array 12 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 13 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 14 | ''' 15 | tensor = tensor.clamp_(*min_max) # clamp 16 | n_dim = tensor.dim() 17 | if n_dim == 4: 18 | n_img = len(tensor) 19 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), 
normalize=False).numpy() 20 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 21 | elif n_dim == 3: 22 | img_np = tensor.numpy() 23 | img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB 24 | elif n_dim == 2: 25 | img_np = tensor.numpy() 26 | else: 27 | raise TypeError('Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 28 | if out_type == np.uint8: 29 | img_np = ((img_np+1) * 127.5).round() 30 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. 31 | return img_np.astype(out_type).squeeze() 32 | 33 | def postprocess(images): 34 | return [tensor2img(image) for image in images] 35 | 36 | 37 | def set_seed(seed, gl_seed=0): 38 | """ set random seed, gl_seed used in worker_init_fn function """ 39 | if seed >=0 and gl_seed>=0: 40 | seed += gl_seed 41 | torch.manual_seed(seed) 42 | torch.cuda.manual_seed_all(seed) 43 | np.random.seed(seed) 44 | random.seed(seed) 45 | 46 | ''' change the deterministic and benchmark maybe cause uncertain convolution behavior. 47 | speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html ''' 48 | if seed >=0 and gl_seed>=0: # slower, more reproducible 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | else: # faster, less reproducible 52 | torch.backends.cudnn.deterministic = False 53 | torch.backends.cudnn.benchmark = True 54 | 55 | def set_gpu(args, distributed=False, rank=0): 56 | """ set parameter to gpu or ddp """ 57 | if args is None: 58 | return None 59 | if distributed and isinstance(args, torch.nn.Module): 60 | return DDP(args.cuda(), device_ids=[rank], output_device=rank, broadcast_buffers=True, find_unused_parameters=True) 61 | else: 62 | return args.cuda() 63 | 64 | def set_device(args, distributed=False, rank=0): 65 | """ set parameter to gpu or cpu """ 66 | if torch.cuda.is_available(): 67 | if isinstance(args, list): 68 | return (set_gpu(item, distributed, rank) for item in args) 69 | elif isinstance(args, dict): 70 | return {key:set_gpu(args[key], distributed, rank) for key in args} 71 | else: 72 | args = set_gpu(args, distributed, rank) 73 | return args 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import numpy as np 3 | 4 | from torch.utils.data.distributed import DistributedSampler 5 | from torch import Generator, randperm 6 | from torch.utils.data import DataLoader, Subset 7 | 8 | import core.util as Util 9 | from core.praser import init_obj 10 | 11 | 12 | def define_dataloader(logger, opt): 13 | """ create train/test dataloader and validation dataloader, validation dataloader is None when phase is test or not GPU 0 """ 14 | '''create dataset and set random seed''' 15 | dataloader_args = opt['datasets'][opt['phase']]['dataloader']['args'] 16 | worker_init_fn = partial(Util.set_seed, gl_seed=opt['seed']) 17 | 18 | phase_dataset, val_dataset = define_dataset(logger, opt) 19 | 20 | '''create datasampler''' 21 | data_sampler = None 22 | if opt['distributed']: 23 | data_sampler = DistributedSampler(phase_dataset, shuffle=dataloader_args.get('shuffle', False), num_replicas=opt['world_size'], rank=opt['global_rank']) 24 | dataloader_args.update({'shuffle':False}) # sampler option is mutually exclusive with shuffle 25 | 26 | ''' create dataloader and validation dataloader ''' 27 | dataloader = DataLoader(phase_dataset, 
sampler=data_sampler, worker_init_fn=worker_init_fn, **dataloader_args) 28 | ''' val_dataloader don't use DistributedSampler to run only GPU 0! ''' 29 | if opt['global_rank']==0 and val_dataset is not None: 30 | dataloader_args.update(opt['datasets'][opt['phase']]['dataloader'].get('val_args',{})) 31 | val_dataloader = DataLoader(val_dataset, worker_init_fn=worker_init_fn, **dataloader_args) 32 | else: 33 | val_dataloader = None 34 | return dataloader, val_dataloader 35 | 36 | 37 | def define_dataset(logger, opt): 38 | ''' loading Dataset() class from given file's name ''' 39 | dataset_opt = opt['datasets'][opt['phase']]['which_dataset'] 40 | phase_dataset = init_obj(dataset_opt, logger, default_file_name='data.dataset', init_type='Dataset') 41 | val_dataset = None 42 | 43 | valid_len = 0 44 | data_len = len(phase_dataset) 45 | if 'debug' in opt['name']: 46 | debug_split = opt['debug'].get('debug_split', 1.0) 47 | if isinstance(debug_split, int): 48 | data_len = debug_split 49 | else: 50 | data_len *= debug_split 51 | 52 | dataloder_opt = opt['datasets'][opt['phase']]['dataloader'] 53 | valid_split = dataloder_opt.get('validation_split', 0) 54 | 55 | ''' divide validation dataset, valid_split==0 when phase is test or validation_split is 0. ''' 56 | if valid_split > 0.0 or 'debug' in opt['name']: 57 | if isinstance(valid_split, int): 58 | assert valid_split < data_len, "Validation set size is configured to be larger than entire dataset." 59 | valid_len = valid_split 60 | else: 61 | valid_len = int(data_len * valid_split) 62 | data_len -= valid_len 63 | phase_dataset, val_dataset = subset_split(dataset=phase_dataset, lengths=[data_len, valid_len], generator=Generator().manual_seed(opt['seed'])) 64 | 65 | logger.info('Dataset for {} have {} samples.'.format(opt['phase'], data_len)) 66 | if opt['phase'] == 'train': 67 | logger.info('Dataset for {} have {} samples.'.format('val', valid_len)) 68 | return phase_dataset, val_dataset 69 | 70 | def subset_split(dataset, lengths, generator): 71 | """ 72 | split a dataset into non-overlapping new datasets of given lengths. 
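    For example, lengths=[900, 100] draws one random permutation of 1000 indices and
    returns a 900-sample Subset and a 100-sample Subset; a length of 0 yields None in
    the corresponding slot instead of an empty Subset.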
main code is from random_split function in pytorch 73 | """ 74 | indices = randperm(sum(lengths), generator=generator).tolist() 75 | Subsets = [] 76 | for offset, length in zip(np.add.accumulate(lengths), lengths): 77 | if length == 0: 78 | Subsets.append(None) 79 | else: 80 | Subsets.append(Subset(dataset, indices[offset - length : offset])) 81 | return Subsets 82 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from torchvision import transforms 3 | from PIL import Image 4 | import os 5 | import torch 6 | import numpy as np 7 | 8 | from .util.mask import (bbox2mask, brush_stroke_mask, get_irregular_mask, random_bbox, random_cropping_bbox) 9 | 10 | IMG_EXTENSIONS = [ 11 | '.jpg', '.JPG', '.jpeg', '.JPEG', 12 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 13 | ] 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def make_dataset(dir): 19 | if os.path.isfile(dir): 20 | images = [i for i in np.genfromtxt(dir, dtype=np.str, encoding='utf-8')] 21 | else: 22 | images = [] 23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 24 | for root, _, fnames in sorted(os.walk(dir)): 25 | for fname in sorted(fnames): 26 | if is_image_file(fname): 27 | path = os.path.join(root, fname) 28 | images.append(path) 29 | 30 | return images 31 | 32 | def pil_loader(path): 33 | return Image.open(path).convert('RGB') 34 | 35 | class InpaintDataset(data.Dataset): 36 | def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[256, 256], loader=pil_loader): 37 | imgs = make_dataset(data_root) 38 | if data_len > 0: 39 | self.imgs = imgs[:int(data_len)] 40 | else: 41 | self.imgs = imgs 42 | self.tfs = transforms.Compose([ 43 | transforms.Resize((image_size[0], image_size[1])), 44 | transforms.ToTensor(), 45 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]) 46 | ]) 47 | self.loader = loader 48 | self.mask_config = mask_config 49 | self.mask_mode = self.mask_config['mask_mode'] 50 | self.image_size = image_size 51 | 52 | def __getitem__(self, index): 53 | ret = {} 54 | path = self.imgs[index] 55 | img = self.tfs(self.loader(path)) 56 | mask = self.get_mask() 57 | cond_image = img*(1. - mask) + mask*torch.randn_like(img) 58 | mask_img = img*(1. 
- mask) + mask 59 | 60 | ret['gt_image'] = img 61 | ret['cond_image'] = cond_image 62 | ret['mask_image'] = mask_img 63 | ret['mask'] = mask 64 | ret['path'] = path.rsplit("/")[-1].rsplit("\\")[-1] 65 | return ret 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | 70 | def get_mask(self): 71 | if self.mask_mode == 'bbox': 72 | mask = bbox2mask(self.image_size, random_bbox()) 73 | elif self.mask_mode == 'center': 74 | h, w = self.image_size 75 | mask = bbox2mask(self.image_size, (h//4, w//4, h//2, w//2)) 76 | elif self.mask_mode == 'irregular': 77 | mask = get_irregular_mask(self.image_size) 78 | elif self.mask_mode == 'free_form': 79 | mask = brush_stroke_mask(self.image_size) 80 | elif self.mask_mode == 'hybrid': 81 | regular_mask = bbox2mask(self.image_size, random_bbox()) 82 | irregular_mask = brush_stroke_mask(self.image_size, ) 83 | mask = regular_mask | irregular_mask 84 | elif self.mask_mode == 'file': 85 | pass 86 | else: 87 | raise NotImplementedError( 88 | f'Mask mode {self.mask_mode} has not been implemented.') 89 | return torch.from_numpy(mask).permute(2,0,1) 90 | 91 | 92 | class UncroppingDataset(data.Dataset): 93 | def __init__(self, data_root, mask_config={}, data_len=-1, image_size=[256, 256], loader=pil_loader): 94 | imgs = make_dataset(data_root) 95 | if data_len > 0: 96 | self.imgs = imgs[:int(data_len)] 97 | else: 98 | self.imgs = imgs 99 | self.tfs = transforms.Compose([ 100 | transforms.Resize((image_size[0], image_size[1])), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]) 103 | ]) 104 | self.loader = loader 105 | self.mask_config = mask_config 106 | self.mask_mode = self.mask_config['mask_mode'] 107 | self.image_size = image_size 108 | 109 | def __getitem__(self, index): 110 | ret = {} 111 | path = self.imgs[index] 112 | img = self.tfs(self.loader(path)) 113 | mask = self.get_mask() 114 | cond_image = img*(1. - mask) + mask*torch.randn_like(img) 115 | mask_img = img*(1. 
- mask) + mask 116 | 117 | ret['gt_image'] = img 118 | ret['cond_image'] = cond_image 119 | ret['mask_image'] = mask_img 120 | ret['mask'] = mask 121 | ret['path'] = path.rsplit("/")[-1].rsplit("\\")[-1] 122 | return ret 123 | 124 | def __len__(self): 125 | return len(self.imgs) 126 | 127 | def get_mask(self): 128 | if self.mask_mode == 'manual': 129 | mask = bbox2mask(self.image_size, self.mask_config['shape']) 130 | elif self.mask_mode == 'fourdirection' or self.mask_mode == 'onedirection': 131 | mask = bbox2mask(self.image_size, random_cropping_bbox(mask_mode=self.mask_mode)) 132 | elif self.mask_mode == 'hybrid': 133 | if np.random.randint(0,2)<1: 134 | mask = bbox2mask(self.image_size, random_cropping_bbox(mask_mode='onedirection')) 135 | else: 136 | mask = bbox2mask(self.image_size, random_cropping_bbox(mask_mode='fourdirection')) 137 | elif self.mask_mode == 'file': 138 | pass 139 | else: 140 | raise NotImplementedError( 141 | f'Mask mode {self.mask_mode} has not been implemented.') 142 | return torch.from_numpy(mask).permute(2,0,1) 143 | 144 | 145 | class ColorizationDataset(data.Dataset): 146 | def __init__(self, data_root, data_flist, data_len=-1, image_size=[224, 224], loader=pil_loader): 147 | self.data_root = data_root 148 | flist = make_dataset(data_flist) 149 | if data_len > 0: 150 | self.flist = flist[:int(data_len)] 151 | else: 152 | self.flist = flist 153 | self.tfs = transforms.Compose([ 154 | transforms.Resize((image_size[0], image_size[1])), 155 | transforms.ToTensor(), 156 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5]) 157 | ]) 158 | self.loader = loader 159 | self.image_size = image_size 160 | 161 | def __getitem__(self, index): 162 | ret = {} 163 | file_name = str(self.flist[index]).zfill(5) + '.png' 164 | 165 | img = self.tfs(self.loader('{}/{}/{}'.format(self.data_root, 'color', file_name))) 166 | cond_image = self.tfs(self.loader('{}/{}/{}'.format(self.data_root, 'gray', file_name))) 167 | 168 | ret['gt_image'] = img 169 | ret['cond_image'] = cond_image 170 | ret['path'] = file_name 171 | return ret 172 | 173 | def __len__(self): 174 | return len(self.flist) 175 | 176 | 177 | -------------------------------------------------------------------------------- /data/util/auto_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from scipy import ndimage 4 | from PIL import Image, ImageEnhance, ImageOps 5 | 6 | 7 | class AutoAugment(object): 8 | def __init__(self): 9 | self.policies = [ 10 | ['Invert', 0.1, 7, 'Contrast', 0.2, 6], 11 | ['Rotate', 0.7, 2, 'TranslateX', 0.3, 9], 12 | ['Sharpness', 0.8, 1, 'Sharpness', 0.9, 3], 13 | ['ShearY', 0.5, 8, 'TranslateY', 0.7, 9], 14 | ['AutoContrast', 0.5, 8, 'Equalize', 0.9, 2], 15 | ['ShearY', 0.2, 7, 'Posterize', 0.3, 7], 16 | ['Color', 0.4, 3, 'Brightness', 0.6, 7], 17 | ['Sharpness', 0.3, 9, 'Brightness', 0.7, 9], 18 | ['Equalize', 0.6, 5, 'Equalize', 0.5, 1], 19 | ['Contrast', 0.6, 7, 'Sharpness', 0.6, 5], 20 | ['Color', 0.7, 7, 'TranslateX', 0.5, 8], 21 | ['Equalize', 0.3, 7, 'AutoContrast', 0.4, 8], 22 | ['TranslateY', 0.4, 3, 'Sharpness', 0.2, 6], 23 | ['Brightness', 0.9, 6, 'Color', 0.2, 8], 24 | ['Solarize', 0.5, 2, 'Invert', 0, 0.3], 25 | ['Equalize', 0.2, 0, 'AutoContrast', 0.6, 0], 26 | ['Equalize', 0.2, 8, 'Equalize', 0.6, 4], 27 | ['Color', 0.9, 9, 'Equalize', 0.6, 6], 28 | ['AutoContrast', 0.8, 4, 'Solarize', 0.2, 8], 29 | ['Brightness', 0.1, 3, 'Color', 0.7, 0], 30 | ['Solarize', 0.4, 5, 'AutoContrast', 0.9, 
3], 31 | ['TranslateY', 0.9, 9, 'TranslateY', 0.7, 9], 32 | ['AutoContrast', 0.9, 2, 'Solarize', 0.8, 3], 33 | ['Equalize', 0.8, 8, 'Invert', 0.1, 3], 34 | ['TranslateY', 0.7, 9, 'AutoContrast', 0.9, 1], 35 | ] 36 | 37 | def __call__(self, img): 38 | img = apply_policy(img, self.policies[random.randrange(len(self.policies))]) 39 | return img 40 | 41 | 42 | class ImageNetAutoAugment(object): 43 | def __init__(self): 44 | self.policies = [ 45 | ['Posterize', 0.4, 8, 'Rotate', 0.6, 9], 46 | ['Solarize', 0.6, 5, 'AutoContrast', 0.6, 5], 47 | ['Equalize', 0.8, 8, 'Equalize', 0.6, 3], 48 | ['Posterize', 0.6, 7, 'Posterize', 0.6, 6], 49 | ['Equalize', 0.4, 7, 'Solarize', 0.2, 4], 50 | ['Equalize', 0.4, 4, 'Rotate', 0.8, 8], 51 | ['Solarize', 0.6, 3, 'Equalize', 0.6, 7], 52 | ['Posterize', 0.8, 5, 'Equalize', 1.0, 2], 53 | ['Rotate', 0.2, 3, 'Solarize', 0.6, 8], 54 | ['Equalize', 0.6, 8, 'Posterize', 0.4, 6], 55 | ['Rotate', 0.8, 8, 'Color', 0.4, 0], 56 | ['Rotate', 0.4, 9, 'Equalize', 0.6, 2], 57 | ['Equalize', 0.0, 0.7, 'Equalize', 0.8, 8], 58 | ['Invert', 0.6, 4, 'Equalize', 1.0, 8], 59 | ['Color', 0.6, 4, 'Contrast', 1.0, 8], 60 | ['Rotate', 0.8, 8, 'Color', 1.0, 2], 61 | ['Color', 0.8, 8, 'Solarize', 0.8, 7], 62 | ['Sharpness', 0.4, 7, 'Invert', 0.6, 8], 63 | ['ShearX', 0.6, 5, 'Equalize', 1.0, 9], 64 | ['Color', 0.4, 0, 'Equalize', 0.6, 3], 65 | ['Equalize', 0.4, 7, 'Solarize', 0.2, 4], 66 | ['Solarize', 0.6, 5, 'AutoContrast', 0.6, 5], 67 | ['Invert', 0.6, 4, 'Equalize', 1.0, 8], 68 | ['Color', 0.6, 4, 'Contrast', 1.0, 8], 69 | ['Equalize', 0.8, 8, 'Equalize', 0.6, 3] 70 | ] 71 | 72 | def __call__(self, img): 73 | img = apply_policy(img, self.policies[random.randrange(len(self.policies))]) 74 | return img 75 | 76 | 77 | operations = { 78 | 'ShearX': lambda img, magnitude: shear_x(img, magnitude), 79 | 'ShearY': lambda img, magnitude: shear_y(img, magnitude), 80 | 'TranslateX': lambda img, magnitude: translate_x(img, magnitude), 81 | 'TranslateY': lambda img, magnitude: translate_y(img, magnitude), 82 | 'Rotate': lambda img, magnitude: rotate(img, magnitude), 83 | 'AutoContrast': lambda img, magnitude: auto_contrast(img, magnitude), 84 | 'Invert': lambda img, magnitude: invert(img, magnitude), 85 | 'Equalize': lambda img, magnitude: equalize(img, magnitude), 86 | 'Solarize': lambda img, magnitude: solarize(img, magnitude), 87 | 'Posterize': lambda img, magnitude: posterize(img, magnitude), 88 | 'Contrast': lambda img, magnitude: contrast(img, magnitude), 89 | 'Color': lambda img, magnitude: color(img, magnitude), 90 | 'Brightness': lambda img, magnitude: brightness(img, magnitude), 91 | 'Sharpness': lambda img, magnitude: sharpness(img, magnitude), 92 | 'Cutout': lambda img, magnitude: cutout(img, magnitude), 93 | } 94 | 95 | 96 | def apply_policy(img, policy): 97 | if random.random() < policy[1]: 98 | img = operations[policy[0]](img, policy[2]) 99 | if random.random() < policy[4]: 100 | img = operations[policy[3]](img, policy[5]) 101 | 102 | return img 103 | 104 | 105 | def transform_matrix_offset_center(matrix, x, y): 106 | o_x = float(x) / 2 + 0.5 107 | o_y = float(y) / 2 + 0.5 108 | offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]]) 109 | reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]]) 110 | transform_matrix = offset_matrix @ matrix @ reset_matrix 111 | return transform_matrix 112 | 113 | 114 | def shear_x(img, magnitude): 115 | img = np.array(img) 116 | magnitudes = np.linspace(-0.3, 0.3, 11) 117 | 118 | transform_matrix = np.array([[1, 
random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 0], 119 | [0, 1, 0], 120 | [0, 0, 1]]) 121 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 122 | affine_matrix = transform_matrix[:2, :2] 123 | offset = transform_matrix[:2, 2] 124 | img = np.stack([ndimage.interpolation.affine_transform( 125 | img[:, :, c], 126 | affine_matrix, 127 | offset) for c in range(img.shape[2])], axis=2) 128 | img = Image.fromarray(img) 129 | return img 130 | 131 | 132 | def shear_y(img, magnitude): 133 | img = np.array(img) 134 | magnitudes = np.linspace(-0.3, 0.3, 11) 135 | 136 | transform_matrix = np.array([[1, 0, 0], 137 | [random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]), 1, 0], 138 | [0, 0, 1]]) 139 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 140 | affine_matrix = transform_matrix[:2, :2] 141 | offset = transform_matrix[:2, 2] 142 | img = np.stack([ndimage.interpolation.affine_transform( 143 | img[:, :, c], 144 | affine_matrix, 145 | offset) for c in range(img.shape[2])], axis=2) 146 | img = Image.fromarray(img) 147 | return img 148 | 149 | 150 | def translate_x(img, magnitude): 151 | img = np.array(img) 152 | magnitudes = np.linspace(-150/331, 150/331, 11) 153 | 154 | transform_matrix = np.array([[1, 0, 0], 155 | [0, 1, img.shape[1]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], 156 | [0, 0, 1]]) 157 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 158 | affine_matrix = transform_matrix[:2, :2] 159 | offset = transform_matrix[:2, 2] 160 | img = np.stack([ndimage.interpolation.affine_transform( 161 | img[:, :, c], 162 | affine_matrix, 163 | offset) for c in range(img.shape[2])], axis=2) 164 | img = Image.fromarray(img) 165 | return img 166 | 167 | 168 | def translate_y(img, magnitude): 169 | img = np.array(img) 170 | magnitudes = np.linspace(-150/331, 150/331, 11) 171 | 172 | transform_matrix = np.array([[1, 0, img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])], 173 | [0, 1, 0], 174 | [0, 0, 1]]) 175 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 176 | affine_matrix = transform_matrix[:2, :2] 177 | offset = transform_matrix[:2, 2] 178 | img = np.stack([ndimage.interpolation.affine_transform( 179 | img[:, :, c], 180 | affine_matrix, 181 | offset) for c in range(img.shape[2])], axis=2) 182 | img = Image.fromarray(img) 183 | return img 184 | 185 | 186 | def rotate(img, magnitude): 187 | img = np.array(img) 188 | magnitudes = np.linspace(-30, 30, 11) 189 | theta = np.deg2rad(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 190 | transform_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], 191 | [np.sin(theta), np.cos(theta), 0], 192 | [0, 0, 1]]) 193 | transform_matrix = transform_matrix_offset_center(transform_matrix, img.shape[0], img.shape[1]) 194 | affine_matrix = transform_matrix[:2, :2] 195 | offset = transform_matrix[:2, 2] 196 | img = np.stack([ndimage.interpolation.affine_transform( 197 | img[:, :, c], 198 | affine_matrix, 199 | offset) for c in range(img.shape[2])], axis=2) 200 | img = Image.fromarray(img) 201 | return img 202 | 203 | 204 | def auto_contrast(img, magnitude): 205 | img = ImageOps.autocontrast(img) 206 | return img 207 | 208 | 209 | def invert(img, magnitude): 210 | img = ImageOps.invert(img) 211 | return img 212 | 213 | 214 | def equalize(img, magnitude): 215 | img = ImageOps.equalize(img) 216 | 
return img 217 | 218 | 219 | def solarize(img, magnitude): 220 | magnitudes = np.linspace(0, 256, 11) 221 | img = ImageOps.solarize(img, random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 222 | return img 223 | 224 | 225 | def posterize(img, magnitude): 226 | magnitudes = np.linspace(4, 8, 11) 227 | img = ImageOps.posterize(img, int(round(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])))) 228 | return img 229 | 230 | 231 | def contrast(img, magnitude): 232 | magnitudes = np.linspace(0.1, 1.9, 11) 233 | img = ImageEnhance.Contrast(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 234 | return img 235 | 236 | 237 | def color(img, magnitude): 238 | magnitudes = np.linspace(0.1, 1.9, 11) 239 | img = ImageEnhance.Color(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 240 | return img 241 | 242 | 243 | def brightness(img, magnitude): 244 | magnitudes = np.linspace(0.1, 1.9, 11) 245 | img = ImageEnhance.Brightness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 246 | return img 247 | 248 | 249 | def sharpness(img, magnitude): 250 | magnitudes = np.linspace(0.1, 1.9, 11) 251 | img = ImageEnhance.Sharpness(img).enhance(random.uniform(magnitudes[magnitude], magnitudes[magnitude+1])) 252 | return img 253 | 254 | 255 | def cutout(org_img, magnitude=None): 256 | 257 | magnitudes = np.linspace(0, 60/331, 11) 258 | 259 | img = np.copy(org_img) 260 | mask_val = img.mean() 261 | 262 | if magnitude is None: 263 | mask_size = 16 264 | else: 265 | mask_size = int(round(img.shape[0]*random.uniform(magnitudes[magnitude], magnitudes[magnitude+1]))) 266 | top = np.random.randint(0 - mask_size//2, img.shape[0] - mask_size) 267 | left = np.random.randint(0 - mask_size//2, img.shape[1] - mask_size) 268 | bottom = top + mask_size 269 | right = left + mask_size 270 | 271 | if top < 0: 272 | top = 0 273 | if left < 0: 274 | left = 0 275 | 276 | img[top:bottom, left:right, :].fill(mask_val) 277 | 278 | img = Image.fromarray(img) 279 | 280 | return img 281 | 282 | 283 | class Cutout(object): 284 | 285 | def __init__(self, length=16): 286 | self.length = length 287 | 288 | def __call__(self, img): 289 | img = np.array(img) 290 | 291 | mask_val = img.mean() 292 | 293 | top = np.random.randint(0 - self.length//2, img.shape[0] - self.length) 294 | left = np.random.randint(0 - self.length//2, img.shape[1] - self.length) 295 | bottom = top + self.length 296 | right = left + self.length 297 | 298 | top = 0 if top < 0 else top 299 | left = 0 if left < 0 else top 300 | 301 | img[top:bottom, left:right, :] = mask_val 302 | 303 | img = Image.fromarray(img) 304 | 305 | return img -------------------------------------------------------------------------------- /data/util/mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
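# Note: every mask generator below returns a numpy array of shape (h, w, 1), dtype
# 'uint8' by default, where 1 marks the hole to synthesize and 0 marks the valid
# region. Illustrative usage (image size and box values chosen for the example only):
#
#     mask = bbox2mask((256, 256), (64, 64, 128, 128))       # fixed rectangular hole
#     mask = bbox2mask((256, 256), random_bbox((256, 256)))  # random rectangular hole
#     mask = brush_stroke_mask((256, 256))                   # free-form brush strokes
#     mask = get_irregular_mask((256, 256))                  # irregular, 15%-50% area
#
# data/dataset.py then converts the result with torch.from_numpy(mask).permute(2, 0, 1)
# to obtain a (1, h, w) tensor.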
2 | import math 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image, ImageDraw 7 | 8 | 9 | def random_cropping_bbox(img_shape=(256,256), mask_mode='onedirection'): 10 | h, w = img_shape 11 | if mask_mode == 'onedirection': 12 | _type = np.random.randint(0, 4) 13 | if _type == 0: 14 | top, left, height, width = 0, 0, h, w//2 15 | elif _type == 1: 16 | top, left, height, width = 0, 0, h//2, w 17 | elif _type == 2: 18 | top, left, height, width = h//2, 0, h//2, w 19 | elif _type == 3: 20 | top, left, height, width = 0, w//2, h, w//2 21 | else: 22 | target_area = (h*w)//2 23 | width = np.random.randint(target_area//h, w) 24 | height = target_area//width 25 | if h==height: 26 | top = 0 27 | else: 28 | top = np.random.randint(0, h-height) 29 | if w==width: 30 | left = 0 31 | else: 32 | left = np.random.randint(0, w-width) 33 | return (top, left, height, width) 34 | 35 | def random_bbox(img_shape=(256,256), max_bbox_shape=(128, 128), max_bbox_delta=40, min_margin=20): 36 | """Generate a random bbox for the mask on a given image. 37 | 38 | In our implementation, the max value cannot be obtained since we use 39 | `np.random.randint`. And this may be different with other standard scripts 40 | in the community. 41 | 42 | Args: 43 | img_shape (tuple[int]): The size of a image, in the form of (h, w). 44 | max_bbox_shape (int | tuple[int]): Maximum shape of the mask box, 45 | in the form of (h, w). If it is an integer, the mask box will be 46 | square. 47 | max_bbox_delta (int | tuple[int]): Maximum delta of the mask box, 48 | in the form of (delta_h, delta_w). If it is an integer, delta_h 49 | and delta_w will be the same. Mask shape will be randomly sampled 50 | from the range of `max_bbox_shape - max_bbox_delta` and 51 | `max_bbox_shape`. Default: (40, 40). 52 | min_margin (int | tuple[int]): The minimum margin size from the 53 | edges of mask box to the image boarder, in the form of 54 | (margin_h, margin_w). If it is an integer, margin_h and margin_w 55 | will be the same. Default: (20, 20). 56 | 57 | Returns: 58 | tuple[int]: The generated box, (top, left, h, w). 
59 | """ 60 | if not isinstance(max_bbox_shape, tuple): 61 | max_bbox_shape = (max_bbox_shape, max_bbox_shape) 62 | if not isinstance(max_bbox_delta, tuple): 63 | max_bbox_delta = (max_bbox_delta, max_bbox_delta) 64 | if not isinstance(min_margin, tuple): 65 | min_margin = (min_margin, min_margin) 66 | 67 | img_h, img_w = img_shape[:2] 68 | max_mask_h, max_mask_w = max_bbox_shape 69 | max_delta_h, max_delta_w = max_bbox_delta 70 | margin_h, margin_w = min_margin 71 | 72 | if max_mask_h > img_h or max_mask_w > img_w: 73 | raise ValueError(f'mask shape {max_bbox_shape} should be smaller than ' 74 | f'image shape {img_shape}') 75 | if (max_delta_h // 2 * 2 >= max_mask_h 76 | or max_delta_w // 2 * 2 >= max_mask_w): 77 | raise ValueError(f'mask delta {max_bbox_delta} should be smaller than' 78 | f'mask shape {max_bbox_shape}') 79 | if img_h - max_mask_h < 2 * margin_h or img_w - max_mask_w < 2 * margin_w: 80 | raise ValueError(f'Margin {min_margin} cannot be satisfied for img' 81 | f'shape {img_shape} and mask shape {max_bbox_shape}') 82 | 83 | # get the max value of (top, left) 84 | max_top = img_h - margin_h - max_mask_h 85 | max_left = img_w - margin_w - max_mask_w 86 | # randomly select a (top, left) 87 | top = np.random.randint(margin_h, max_top) 88 | left = np.random.randint(margin_w, max_left) 89 | # randomly shrink the shape of mask box according to `max_bbox_delta` 90 | # the center of box is fixed 91 | delta_top = np.random.randint(0, max_delta_h // 2 + 1) 92 | delta_left = np.random.randint(0, max_delta_w // 2 + 1) 93 | top = top + delta_top 94 | left = left + delta_left 95 | h = max_mask_h - delta_top 96 | w = max_mask_w - delta_left 97 | return (top, left, h, w) 98 | 99 | 100 | def bbox2mask(img_shape, bbox, dtype='uint8'): 101 | """Generate mask in ndarray from bbox. 102 | 103 | The returned mask has the shape of (h, w, 1). '1' indicates the 104 | hole and '0' indicates the valid regions. 105 | 106 | We prefer to use `uint8` as the data type of masks, which may be different 107 | from other codes in the community. 108 | 109 | Args: 110 | img_shape (tuple[int]): The size of the image. 111 | bbox (tuple[int]): Configuration tuple, (top, left, height, width) 112 | dtype (str): Indicate the data type of returned masks. Default: 'uint8' 113 | 114 | Return: 115 | numpy.ndarray: Mask in the shape of (h, w, 1). 116 | """ 117 | 118 | height, width = img_shape[:2] 119 | 120 | mask = np.zeros((height, width, 1), dtype=dtype) 121 | mask[bbox[0]:bbox[0] + bbox[2], bbox[1]:bbox[1] + bbox[3], :] = 1 122 | 123 | return mask 124 | 125 | 126 | def brush_stroke_mask(img_shape, 127 | num_vertices=(4, 12), 128 | mean_angle=2 * math.pi / 5, 129 | angle_range=2 * math.pi / 15, 130 | brush_width=(12, 40), 131 | max_loops=4, 132 | dtype='uint8'): 133 | """Generate free-form mask. 134 | 135 | The method of generating free-form mask is in the following paper: 136 | Free-Form Image Inpainting with Gated Convolution. 137 | 138 | When you set the config of this type of mask. You may note the usage of 139 | `np.random.randint` and the range of `np.random.randint` is [left, right). 140 | 141 | We prefer to use `uint8` as the data type of masks, which may be different 142 | from other codes in the community. 143 | 144 | TODO: Rewrite the implementation of this function. 145 | 146 | Args: 147 | img_shape (tuple[int]): Size of the image. 148 | num_vertices (int | tuple[int]): Min and max number of vertices. If 149 | only give an integer, we will fix the number of vertices. 150 | Default: (4, 12). 
151 | mean_angle (float): Mean value of the angle in each vertex. The angle 152 | is measured in radians. Default: 2 * math.pi / 5. 153 | angle_range (float): Range of the random angle. 154 | Default: 2 * math.pi / 15. 155 | brush_width (int | tuple[int]): (min_width, max_width). If only give 156 | an integer, we will fix the width of brush. Default: (12, 40). 157 | max_loops (int): The max number of for loops of drawing strokes. 158 | dtype (str): Indicate the data type of returned masks. 159 | Default: 'uint8'. 160 | 161 | Returns: 162 | numpy.ndarray: Mask in the shape of (h, w, 1). 163 | """ 164 | 165 | img_h, img_w = img_shape[:2] 166 | if isinstance(num_vertices, int): 167 | min_num_vertices, max_num_vertices = num_vertices, num_vertices + 1 168 | elif isinstance(num_vertices, tuple): 169 | min_num_vertices, max_num_vertices = num_vertices 170 | else: 171 | raise TypeError('The type of num_vertices should be int' 172 | f'or tuple[int], but got type: {num_vertices}') 173 | 174 | if isinstance(brush_width, tuple): 175 | min_width, max_width = brush_width 176 | elif isinstance(brush_width, int): 177 | min_width, max_width = brush_width, brush_width + 1 178 | else: 179 | raise TypeError('The type of brush_width should be int' 180 | f'or tuple[int], but got type: {brush_width}') 181 | 182 | average_radius = math.sqrt(img_h * img_h + img_w * img_w) / 8 183 | mask = Image.new('L', (img_w, img_h), 0) 184 | 185 | loop_num = np.random.randint(1, max_loops) 186 | num_vertex_list = np.random.randint( 187 | min_num_vertices, max_num_vertices, size=loop_num) 188 | angle_min_list = np.random.uniform(0, angle_range, size=loop_num) 189 | angle_max_list = np.random.uniform(0, angle_range, size=loop_num) 190 | 191 | for loop_n in range(loop_num): 192 | num_vertex = num_vertex_list[loop_n] 193 | angle_min = mean_angle - angle_min_list[loop_n] 194 | angle_max = mean_angle + angle_max_list[loop_n] 195 | angles = [] 196 | vertex = [] 197 | 198 | # set random angle on each vertex 199 | angles = np.random.uniform(angle_min, angle_max, size=num_vertex) 200 | reverse_mask = (np.arange(num_vertex, dtype=np.float32) % 2) == 0 201 | angles[reverse_mask] = 2 * math.pi - angles[reverse_mask] 202 | 203 | h, w = mask.size 204 | 205 | # set random vertices 206 | vertex.append((np.random.randint(0, w), np.random.randint(0, h))) 207 | r_list = np.random.normal( 208 | loc=average_radius, scale=average_radius // 2, size=num_vertex) 209 | for i in range(num_vertex): 210 | r = np.clip(r_list[i], 0, 2 * average_radius) 211 | new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w) 212 | new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h) 213 | vertex.append((int(new_x), int(new_y))) 214 | # draw brush strokes according to the vertex and angle list 215 | draw = ImageDraw.Draw(mask) 216 | width = np.random.randint(min_width, max_width) 217 | draw.line(vertex, fill=1, width=width) 218 | for v in vertex: 219 | draw.ellipse((v[0] - width // 2, v[1] - width // 2, 220 | v[0] + width // 2, v[1] + width // 2), 221 | fill=1) 222 | # randomly flip the mask 223 | if np.random.normal() > 0: 224 | mask.transpose(Image.FLIP_LEFT_RIGHT) 225 | if np.random.normal() > 0: 226 | mask.transpose(Image.FLIP_TOP_BOTTOM) 227 | mask = np.array(mask).astype(dtype=getattr(np, dtype)) 228 | mask = mask[:, :, None] 229 | return mask 230 | 231 | 232 | def random_irregular_mask(img_shape, 233 | num_vertices=(4, 8), 234 | max_angle=4, 235 | length_range=(10, 100), 236 | brush_width=(10, 40), 237 | dtype='uint8'): 238 | """Generate random 
irregular masks. 239 | 240 | This is a modified version of free-form mask implemented in 241 | 'brush_stroke_mask'. 242 | 243 | We prefer to use `uint8` as the data type of masks, which may be different 244 | from other codes in the community. 245 | 246 | TODO: Rewrite the implementation of this function. 247 | 248 | Args: 249 | img_shape (tuple[int]): Size of the image. 250 | num_vertices (int | tuple[int]): Min and max number of vertices. If 251 | only give an integer, we will fix the number of vertices. 252 | Default: (4, 8). 253 | max_angle (float): Max value of angle at each vertex. Default 4.0. 254 | length_range (int | tuple[int]): (min_length, max_length). If only give 255 | an integer, we will fix the length of brush. Default: (10, 100). 256 | brush_width (int | tuple[int]): (min_width, max_width). If only give 257 | an integer, we will fix the width of brush. Default: (10, 40). 258 | dtype (str): Indicate the data type of returned masks. Default: 'uint8' 259 | 260 | Returns: 261 | numpy.ndarray: Mask in the shape of (h, w, 1). 262 | """ 263 | 264 | h, w = img_shape[:2] 265 | 266 | mask = np.zeros((h, w), dtype=dtype) 267 | if isinstance(length_range, int): 268 | min_length, max_length = length_range, length_range + 1 269 | elif isinstance(length_range, tuple): 270 | min_length, max_length = length_range 271 | else: 272 | raise TypeError('The type of length_range should be int' 273 | f'or tuple[int], but got type: {length_range}') 274 | if isinstance(num_vertices, int): 275 | min_num_vertices, max_num_vertices = num_vertices, num_vertices + 1 276 | elif isinstance(num_vertices, tuple): 277 | min_num_vertices, max_num_vertices = num_vertices 278 | else: 279 | raise TypeError('The type of num_vertices should be int' 280 | f'or tuple[int], but got type: {num_vertices}') 281 | 282 | if isinstance(brush_width, int): 283 | min_brush_width, max_brush_width = brush_width, brush_width + 1 284 | elif isinstance(brush_width, tuple): 285 | min_brush_width, max_brush_width = brush_width 286 | else: 287 | raise TypeError('The type of brush_width should be int' 288 | f'or tuple[int], but got type: {brush_width}') 289 | 290 | num_v = np.random.randint(min_num_vertices, max_num_vertices) 291 | 292 | for i in range(num_v): 293 | start_x = np.random.randint(w) 294 | start_y = np.random.randint(h) 295 | # from the start point, randomly setlect n \in [1, 6] directions. 296 | direction_num = np.random.randint(1, 6) 297 | angle_list = np.random.randint(0, max_angle, size=direction_num) 298 | length_list = np.random.randint( 299 | min_length, max_length, size=direction_num) 300 | brush_width_list = np.random.randint( 301 | min_brush_width, max_brush_width, size=direction_num) 302 | for direct_n in range(direction_num): 303 | angle = 0.01 + angle_list[direct_n] 304 | if i % 2 == 0: 305 | angle = 2 * math.pi - angle 306 | length = length_list[direct_n] 307 | brush_w = brush_width_list[direct_n] 308 | # compute end point according to the random angle 309 | end_x = (start_x + length * np.sin(angle)).astype(np.int32) 310 | end_y = (start_y + length * np.cos(angle)).astype(np.int32) 311 | 312 | cv2.line(mask, (start_y, start_x), (end_y, end_x), 1, brush_w) 313 | start_x, start_y = end_x, end_y 314 | mask = np.expand_dims(mask, axis=2) 315 | 316 | return mask 317 | 318 | 319 | def get_irregular_mask(img_shape, area_ratio_range=(0.15, 0.5), **kwargs): 320 | """Get irregular mask with the constraints in mask ratio 321 | 322 | Args: 323 | img_shape (tuple[int]): Size of the image. 
324 | area_ratio_range (tuple(float)): Contain the minimum and maximum area 325 | ratio. Default: (0.15, 0.5). 326 | 327 | Returns: 328 | numpy.ndarray: Mask in the shape of (h, w, 1). 329 | """ 330 | 331 | mask = random_irregular_mask(img_shape, **kwargs) 332 | min_ratio, max_ratio = area_ratio_range 333 | 334 | while not min_ratio < (np.sum(mask) / 335 | (img_shape[0] * img_shape[1])) < max_ratio: 336 | mask = random_irregular_mask(img_shape, **kwargs) 337 | 338 | return mask 339 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from cleanfid import fid 3 | from core.base_dataset import BaseDataset 4 | from models.metric import inception_score 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('-s', '--src', type=str, help='Ground truth images directory') 9 | parser.add_argument('-d', '--dst', type=str, help='Generate images directory') 10 | 11 | ''' parser configs ''' 12 | args = parser.parse_args() 13 | 14 | fid_score = fid.compute_fid(args.src, args.dst) 15 | is_mean, is_std = inception_score(BaseDataset(args.dst), cuda=True, batch_size=8, resize=True, splits=10) 16 | 17 | print('FID: {}'.format(fid_score)) 18 | print('IS:{} {}'.format(is_mean, is_std)) -------------------------------------------------------------------------------- /misc/Palette Image-to-Image Diffusion Models.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/Palette Image-to-Image Diffusion Models.pdf -------------------------------------------------------------------------------- /misc/image/Mask_Places365_test_00143399.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Mask_Places365_test_00143399.jpg -------------------------------------------------------------------------------- /misc/image/Mask_Places365_test_00144085.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Mask_Places365_test_00144085.jpg -------------------------------------------------------------------------------- /misc/image/Mask_Places365_test_00209019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Mask_Places365_test_00209019.jpg -------------------------------------------------------------------------------- /misc/image/Mask_Places365_test_00263905.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Mask_Places365_test_00263905.jpg -------------------------------------------------------------------------------- /misc/image/Out_Places365_test_00143399.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Out_Places365_test_00143399.jpg -------------------------------------------------------------------------------- /misc/image/Out_Places365_test_00144085.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Out_Places365_test_00144085.jpg -------------------------------------------------------------------------------- /misc/image/Out_Places365_test_00209019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Out_Places365_test_00209019.jpg -------------------------------------------------------------------------------- /misc/image/Out_Places365_test_00263905.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Out_Places365_test_00263905.jpg -------------------------------------------------------------------------------- /misc/image/Process_02323.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Process_02323.jpg -------------------------------------------------------------------------------- /misc/image/Process_26190.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Process_26190.jpg -------------------------------------------------------------------------------- /misc/image/Process_Places365_test_00042384.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Process_Places365_test_00042384.jpg -------------------------------------------------------------------------------- /misc/image/Process_Places365_test_00309553.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Janspiry/Palette-Image-to-Image-Diffusion-Models/136b29f58d0af6e5db9f3655d2891f5a855fcdaa/misc/image/Process_Places365_test_00309553.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from core.praser import init_obj 2 | 3 | def create_model(**cfg_model): 4 | """ create_model """ 5 | opt = cfg_model['opt'] 6 | logger = cfg_model['logger'] 7 | 8 | model_opt = opt['model']['which_model'] 9 | model_opt['args'].update(cfg_model) 10 | model = init_obj(model_opt, logger, default_file_name='models.model', init_type='Model') 11 | 12 | return model 13 | 14 | def define_network(logger, opt, network_opt): 15 | """ define network with weights initialization """ 16 | net = init_obj(network_opt, logger, default_file_name='models.network', init_type='Network') 17 | 18 | if opt['phase'] == 'train': 19 | 
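        # weight initialization is applied only in the training phase; for test/inference
        # the parameters are presumably restored from a resume checkpoint instead.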
logger.info('Network [{}] weights initialize using [{:s}] method.'.format(net.__class__.__name__, network_opt['args'].get('init_type', 'default'))) 20 | net.init_weights() 21 | return net 22 | 23 | 24 | def define_loss(logger, loss_opt): 25 | return init_obj(loss_opt, logger, default_file_name='models.loss', init_type='Loss') 26 | 27 | def define_metric(logger, metric_opt): 28 | return init_obj(metric_opt, logger, default_file_name='models.metric', init_type='Metric') 29 | 30 | -------------------------------------------------------------------------------- /models/guided_diffusion_modules/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class GroupNorm32(nn.GroupNorm): 12 | def forward(self, x): 13 | return super().forward(x.float()).type(x.dtype) 14 | 15 | 16 | def zero_module(module): 17 | """ 18 | Zero out the parameters of a module and return it. 19 | """ 20 | for p in module.parameters(): 21 | p.detach().zero_() 22 | return module 23 | 24 | 25 | def scale_module(module, scale): 26 | """ 27 | Scale the parameters of a module and return it. 28 | """ 29 | for p in module.parameters(): 30 | p.detach().mul_(scale) 31 | return module 32 | 33 | 34 | def mean_flat(tensor): 35 | """ 36 | Take the mean over all non-batch dimensions. 37 | """ 38 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 39 | 40 | 41 | def normalization(channels): 42 | """ 43 | Make a standard normalization layer. 44 | 45 | :param channels: number of input channels. 46 | :return: an nn.Module for normalization. 47 | """ 48 | return GroupNorm32(32, channels) 49 | 50 | 51 | 52 | def checkpoint(func, inputs, params, flag): 53 | """ 54 | Evaluate a function without caching intermediate activations, allowing for 55 | reduced memory at the expense of extra compute in the backward pass. 56 | 57 | :param func: the function to evaluate. 58 | :param inputs: the argument sequence to pass to `func`. 59 | :param params: a sequence of parameters `func` depends on but does not 60 | explicitly take as arguments. 61 | :param flag: if False, disable gradient checkpointing. 62 | """ 63 | if flag: 64 | args = tuple(inputs) + tuple(params) 65 | return CheckpointFunction.apply(func, len(inputs), *args) 66 | else: 67 | return func(*inputs) 68 | 69 | 70 | class CheckpointFunction(torch.autograd.Function): 71 | @staticmethod 72 | def forward(ctx, run_function, length, *args): 73 | ctx.run_function = run_function 74 | ctx.input_tensors = list(args[:length]) 75 | ctx.input_params = list(args[length:]) 76 | with torch.no_grad(): 77 | output_tensors = ctx.run_function(*ctx.input_tensors) 78 | return output_tensors 79 | 80 | @staticmethod 81 | def backward(ctx, *output_grads): 82 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 83 | with torch.enable_grad(): 84 | # Fixes a bug where the first op in run_function modifies the 85 | # Tensor storage in place, which is not allowed for detach()'d 86 | # Tensors. 
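            # view_as(x) yields shallow copies: new tensors that share storage with the
            # detached inputs but are distinct, non-leaf autograd nodes, so in-place ops
            # inside run_function hit the copies rather than the detached leaves, while
            # gradients still flow back to ctx.input_tensors.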
87 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 88 | output_tensors = ctx.run_function(*shallow_copies) 89 | input_grads = torch.autograd.grad( 90 | output_tensors, 91 | ctx.input_tensors + ctx.input_params, 92 | output_grads, 93 | allow_unused=True, 94 | ) 95 | del ctx.input_tensors 96 | del ctx.input_params 97 | del output_tensors 98 | return (None, None) + input_grads 99 | 100 | 101 | def count_flops_attn(model, _x, y): 102 | """ 103 | A counter for the `thop` package to count the operations in an 104 | attention operation. 105 | Meant to be used like: 106 | macs, params = thop.profile( 107 | model, 108 | inputs=(inputs, timestamps), 109 | custom_ops={QKVAttention: QKVAttention.count_flops}, 110 | ) 111 | """ 112 | b, c, *spatial = y[0].shape 113 | num_spatial = int(np.prod(spatial)) 114 | # We perform two matmuls with the same number of ops. 115 | # The first computes the weight matrix, the second computes 116 | # the combination of the value vectors. 117 | matmul_ops = 2 * b * (num_spatial ** 2) * c 118 | model.total_ops += torch.DoubleTensor([matmul_ops]) 119 | 120 | 121 | def gamma_embedding(gammas, dim, max_period=10000): 122 | """ 123 | Create sinusoidal timestep embeddings. 124 | :param gammas: a 1-D Tensor of N indices, one per batch element. 125 | These may be fractional. 126 | :param dim: the dimension of the output. 127 | :param max_period: controls the minimum frequency of the embeddings. 128 | :return: an [N x dim] Tensor of positional embeddings. 129 | """ 130 | half = dim // 2 131 | freqs = torch.exp( 132 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 133 | ).to(device=gammas.device) 134 | args = gammas[:, None].float() * freqs[None] 135 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 136 | if dim % 2: 137 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 138 | return embedding -------------------------------------------------------------------------------- /models/guided_diffusion_modules/unet.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .nn import ( 9 | checkpoint, 10 | zero_module, 11 | normalization, 12 | count_flops_attn, 13 | gamma_embedding 14 | ) 15 | 16 | class SiLU(nn.Module): 17 | def forward(self, x): 18 | return x * torch.sigmoid(x) 19 | 20 | class EmbedBlock(nn.Module): 21 | """ 22 | Any module where forward() takes embeddings as a second argument. 23 | """ 24 | 25 | @abstractmethod 26 | def forward(self, x, emb): 27 | """ 28 | Apply the module to `x` given `emb` embeddings. 29 | """ 30 | 31 | class EmbedSequential(nn.Sequential, EmbedBlock): 32 | """ 33 | A sequential module that passes embeddings to the children that 34 | support it as an extra input. 35 | """ 36 | 37 | def forward(self, x, emb): 38 | for layer in self: 39 | if isinstance(layer, EmbedBlock): 40 | x = layer(x, emb) 41 | else: 42 | x = layer(x) 43 | return x 44 | 45 | class Upsample(nn.Module): 46 | """ 47 | An upsampling layer with an optional convolution. 48 | :param channels: channels in the inputs and outputs. 49 | :param use_conv: a bool determining if a convolution is applied. 
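    :param out_channel: if specified (and use_conv is True), the number of output
        channels produced by the 3x3 convolution; otherwise the channel count is
        unchanged. Spatial resolution is always doubled with nearest-neighbour
        interpolation.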
50 | 51 | """ 52 | 53 | def __init__(self, channels, use_conv, out_channel=None): 54 | super().__init__() 55 | self.channels = channels 56 | self.out_channel = out_channel or channels 57 | self.use_conv = use_conv 58 | if use_conv: 59 | self.conv = nn.Conv2d(self.channels, self.out_channel, 3, padding=1) 60 | 61 | def forward(self, x): 62 | assert x.shape[1] == self.channels 63 | x = F.interpolate(x, scale_factor=2, mode="nearest") 64 | if self.use_conv: 65 | x = self.conv(x) 66 | return x 67 | 68 | class Downsample(nn.Module): 69 | """ 70 | A downsampling layer with an optional convolution. 71 | :param channels: channels in the inputs and outputs. 72 | :param use_conv: a bool determining if a convolution is applied. 73 | """ 74 | 75 | def __init__(self, channels, use_conv, out_channel=None): 76 | super().__init__() 77 | self.channels = channels 78 | self.out_channel = out_channel or channels 79 | self.use_conv = use_conv 80 | stride = 2 81 | if use_conv: 82 | self.op = nn.Conv2d( 83 | self.channels, self.out_channel, 3, stride=stride, padding=1 84 | ) 85 | else: 86 | assert self.channels == self.out_channel 87 | self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) 88 | 89 | def forward(self, x): 90 | assert x.shape[1] == self.channels 91 | return self.op(x) 92 | 93 | 94 | class ResBlock(EmbedBlock): 95 | """ 96 | A residual block that can optionally change the number of channels. 97 | :param channels: the number of input channels. 98 | :param emb_channels: the number of embedding channels. 99 | :param dropout: the rate of dropout. 100 | :param out_channel: if specified, the number of out channels. 101 | :param use_conv: if True and out_channel is specified, use a spatial 102 | convolution instead of a smaller 1x1 convolution to change the 103 | channels in the skip connection. 104 | :param use_checkpoint: if True, use gradient checkpointing on this module. 105 | :param up: if True, use this block for upsampling. 106 | :param down: if True, use this block for downsampling. 
107 | """ 108 | 109 | def __init__( 110 | self, 111 | channels, 112 | emb_channels, 113 | dropout, 114 | out_channel=None, 115 | use_conv=False, 116 | use_scale_shift_norm=False, 117 | use_checkpoint=False, 118 | up=False, 119 | down=False, 120 | ): 121 | super().__init__() 122 | self.channels = channels 123 | self.emb_channels = emb_channels 124 | self.dropout = dropout 125 | self.out_channel = out_channel or channels 126 | self.use_conv = use_conv 127 | self.use_checkpoint = use_checkpoint 128 | self.use_scale_shift_norm = use_scale_shift_norm 129 | 130 | self.in_layers = nn.Sequential( 131 | normalization(channels), 132 | SiLU(), 133 | nn.Conv2d(channels, self.out_channel, 3, padding=1), 134 | ) 135 | 136 | self.updown = up or down 137 | 138 | if up: 139 | self.h_upd = Upsample(channels, False) 140 | self.x_upd = Upsample(channels, False) 141 | elif down: 142 | self.h_upd = Downsample(channels, False) 143 | self.x_upd = Downsample(channels, False) 144 | else: 145 | self.h_upd = self.x_upd = nn.Identity() 146 | 147 | self.emb_layers = nn.Sequential( 148 | SiLU(), 149 | nn.Linear( 150 | emb_channels, 151 | 2 * self.out_channel if use_scale_shift_norm else self.out_channel, 152 | ), 153 | ) 154 | self.out_layers = nn.Sequential( 155 | normalization(self.out_channel), 156 | SiLU(), 157 | nn.Dropout(p=dropout), 158 | zero_module( 159 | nn.Conv2d(self.out_channel, self.out_channel, 3, padding=1) 160 | ), 161 | ) 162 | 163 | if self.out_channel == channels: 164 | self.skip_connection = nn.Identity() 165 | elif use_conv: 166 | self.skip_connection = nn.Conv2d( 167 | channels, self.out_channel, 3, padding=1 168 | ) 169 | else: 170 | self.skip_connection = nn.Conv2d(channels, self.out_channel, 1) 171 | 172 | def forward(self, x, emb): 173 | """ 174 | Apply the block to a Tensor, conditioned on a embedding. 175 | :param x: an [N x C x ...] Tensor of features. 176 | :param emb: an [N x emb_channels] Tensor of embeddings. 177 | :return: an [N x C x ...] Tensor of outputs. 178 | """ 179 | return checkpoint( 180 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 181 | ) 182 | 183 | def _forward(self, x, emb): 184 | if self.updown: 185 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 186 | h = in_rest(x) 187 | h = self.h_upd(h) 188 | x = self.x_upd(x) 189 | h = in_conv(h) 190 | else: 191 | h = self.in_layers(x) 192 | emb_out = self.emb_layers(emb).type(h.dtype) 193 | while len(emb_out.shape) < len(h.shape): 194 | emb_out = emb_out[..., None] 195 | if self.use_scale_shift_norm: 196 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 197 | scale, shift = torch.chunk(emb_out, 2, dim=1) 198 | h = out_norm(h) * (1 + scale) + shift 199 | h = out_rest(h) 200 | else: 201 | h = h + emb_out 202 | h = self.out_layers(h) 203 | return self.skip_connection(x) + h 204 | 205 | class AttentionBlock(nn.Module): 206 | """ 207 | An attention block that allows spatial positions to attend to each other. 208 | Originally ported from here, but adapted to the N-d case. 209 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
210 | """ 211 | 212 | def __init__( 213 | self, 214 | channels, 215 | num_heads=1, 216 | num_head_channels=-1, 217 | use_checkpoint=False, 218 | use_new_attention_order=False, 219 | ): 220 | super().__init__() 221 | self.channels = channels 222 | if num_head_channels == -1: 223 | self.num_heads = num_heads 224 | else: 225 | assert ( 226 | channels % num_head_channels == 0 227 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 228 | self.num_heads = channels // num_head_channels 229 | self.use_checkpoint = use_checkpoint 230 | self.norm = normalization(channels) 231 | self.qkv = nn.Conv1d(channels, channels * 3, 1) 232 | if use_new_attention_order: 233 | # split qkv before split heads 234 | self.attention = QKVAttention(self.num_heads) 235 | else: 236 | # split heads before split qkv 237 | self.attention = QKVAttentionLegacy(self.num_heads) 238 | 239 | self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) 240 | 241 | def forward(self, x): 242 | return checkpoint(self._forward, (x,), self.parameters(), True) 243 | 244 | def _forward(self, x): 245 | b, c, *spatial = x.shape 246 | x = x.reshape(b, c, -1) 247 | qkv = self.qkv(self.norm(x)) 248 | h = self.attention(qkv) 249 | h = self.proj_out(h) 250 | return (x + h).reshape(b, c, *spatial) 251 | 252 | 253 | class QKVAttentionLegacy(nn.Module): 254 | """ 255 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 256 | """ 257 | 258 | def __init__(self, n_heads): 259 | super().__init__() 260 | self.n_heads = n_heads 261 | 262 | def forward(self, qkv): 263 | """ 264 | Apply QKV attention. 265 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 266 | :return: an [N x (H * C) x T] tensor after attention. 267 | """ 268 | bs, width, length = qkv.shape 269 | assert width % (3 * self.n_heads) == 0 270 | ch = width // (3 * self.n_heads) 271 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) 272 | scale = 1 / math.sqrt(math.sqrt(ch)) 273 | weight = torch.einsum( 274 | "bct,bcs->bts", q * scale, k * scale 275 | ) # More stable with f16 than dividing afterwards 276 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 277 | a = torch.einsum("bts,bcs->bct", weight, v) 278 | return a.reshape(bs, -1, length) 279 | 280 | @staticmethod 281 | def count_flops(model, _x, y): 282 | return count_flops_attn(model, _x, y) 283 | 284 | 285 | class QKVAttention(nn.Module): 286 | """ 287 | A module which performs QKV attention and splits in a different order. 288 | """ 289 | 290 | def __init__(self, n_heads): 291 | super().__init__() 292 | self.n_heads = n_heads 293 | 294 | def forward(self, qkv): 295 | """ 296 | Apply QKV attention. 297 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 298 | :return: an [N x (H * C) x T] tensor after attention. 
299 | """ 300 | bs, width, length = qkv.shape 301 | assert width % (3 * self.n_heads) == 0 302 | ch = width // (3 * self.n_heads) 303 | q, k, v = qkv.chunk(3, dim=1) 304 | scale = 1 / math.sqrt(math.sqrt(ch)) 305 | weight = torch.einsum( 306 | "bct,bcs->bts", 307 | (q * scale).view(bs * self.n_heads, ch, length), 308 | (k * scale).view(bs * self.n_heads, ch, length), 309 | ) # More stable with f16 than dividing afterwards 310 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 311 | a = torch.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) 312 | return a.reshape(bs, -1, length) 313 | 314 | @staticmethod 315 | def count_flops(model, _x, y): 316 | return count_flops_attn(model, _x, y) 317 | 318 | class UNet(nn.Module): 319 | """ 320 | The full UNet model with attention and embedding. 321 | :param in_channel: channels in the input Tensor, for image colorization : Y_channels + X_channels . 322 | :param inner_channel: base channel count for the model. 323 | :param out_channel: channels in the output Tensor. 324 | :param res_blocks: number of residual blocks per downsample. 325 | :param attn_res: a collection of downsample rates at which 326 | attention will take place. May be a set, list, or tuple. 327 | For example, if this contains 4, then at 4x downsampling, attention 328 | will be used. 329 | :param dropout: the dropout probability. 330 | :param channel_mults: channel multiplier for each level of the UNet. 331 | :param conv_resample: if True, use learned convolutions for upsampling and 332 | downsampling. 333 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 334 | :param num_heads: the number of attention heads in each attention layer. 335 | :param num_heads_channels: if specified, ignore num_heads and instead use 336 | a fixed channel width per attention head. 337 | :param num_heads_upsample: works with num_heads to set a different number 338 | of heads for upsampling. Deprecated. 339 | :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. 340 | :param resblock_updown: use residual blocks for up/downsampling. 341 | :param use_new_attention_order: use a different attention pattern for potentially 342 | increased efficiency. 
343 | """ 344 | 345 | def __init__( 346 | self, 347 | image_size, 348 | in_channel, 349 | inner_channel, 350 | out_channel, 351 | res_blocks, 352 | attn_res, 353 | dropout=0, 354 | channel_mults=(1, 2, 4, 8), 355 | conv_resample=True, 356 | use_checkpoint=False, 357 | use_fp16=False, 358 | num_heads=1, 359 | num_head_channels=-1, 360 | num_heads_upsample=-1, 361 | use_scale_shift_norm=True, 362 | resblock_updown=True, 363 | use_new_attention_order=False, 364 | ): 365 | 366 | super().__init__() 367 | 368 | if num_heads_upsample == -1: 369 | num_heads_upsample = num_heads 370 | 371 | self.image_size = image_size 372 | self.in_channel = in_channel 373 | self.inner_channel = inner_channel 374 | self.out_channel = out_channel 375 | self.res_blocks = res_blocks 376 | self.attn_res = attn_res 377 | self.dropout = dropout 378 | self.channel_mults = channel_mults 379 | self.conv_resample = conv_resample 380 | self.use_checkpoint = use_checkpoint 381 | self.dtype = torch.float16 if use_fp16 else torch.float32 382 | self.num_heads = num_heads 383 | self.num_head_channels = num_head_channels 384 | self.num_heads_upsample = num_heads_upsample 385 | 386 | cond_embed_dim = inner_channel * 4 387 | self.cond_embed = nn.Sequential( 388 | nn.Linear(inner_channel, cond_embed_dim), 389 | SiLU(), 390 | nn.Linear(cond_embed_dim, cond_embed_dim), 391 | ) 392 | 393 | ch = input_ch = int(channel_mults[0] * inner_channel) 394 | self.input_blocks = nn.ModuleList( 395 | [EmbedSequential(nn.Conv2d(in_channel, ch, 3, padding=1))] 396 | ) 397 | self._feature_size = ch 398 | input_block_chans = [ch] 399 | ds = 1 400 | for level, mult in enumerate(channel_mults): 401 | for _ in range(res_blocks): 402 | layers = [ 403 | ResBlock( 404 | ch, 405 | cond_embed_dim, 406 | dropout, 407 | out_channel=int(mult * inner_channel), 408 | use_checkpoint=use_checkpoint, 409 | use_scale_shift_norm=use_scale_shift_norm, 410 | ) 411 | ] 412 | ch = int(mult * inner_channel) 413 | if ds in attn_res: 414 | layers.append( 415 | AttentionBlock( 416 | ch, 417 | use_checkpoint=use_checkpoint, 418 | num_heads=num_heads, 419 | num_head_channels=num_head_channels, 420 | use_new_attention_order=use_new_attention_order, 421 | ) 422 | ) 423 | self.input_blocks.append(EmbedSequential(*layers)) 424 | self._feature_size += ch 425 | input_block_chans.append(ch) 426 | if level != len(channel_mults) - 1: 427 | out_ch = ch 428 | self.input_blocks.append( 429 | EmbedSequential( 430 | ResBlock( 431 | ch, 432 | cond_embed_dim, 433 | dropout, 434 | out_channel=out_ch, 435 | use_checkpoint=use_checkpoint, 436 | use_scale_shift_norm=use_scale_shift_norm, 437 | down=True, 438 | ) 439 | if resblock_updown 440 | else Downsample( 441 | ch, conv_resample, out_channel=out_ch 442 | ) 443 | ) 444 | ) 445 | ch = out_ch 446 | input_block_chans.append(ch) 447 | ds *= 2 448 | self._feature_size += ch 449 | 450 | self.middle_block = EmbedSequential( 451 | ResBlock( 452 | ch, 453 | cond_embed_dim, 454 | dropout, 455 | use_checkpoint=use_checkpoint, 456 | use_scale_shift_norm=use_scale_shift_norm, 457 | ), 458 | AttentionBlock( 459 | ch, 460 | use_checkpoint=use_checkpoint, 461 | num_heads=num_heads, 462 | num_head_channels=num_head_channels, 463 | use_new_attention_order=use_new_attention_order, 464 | ), 465 | ResBlock( 466 | ch, 467 | cond_embed_dim, 468 | dropout, 469 | use_checkpoint=use_checkpoint, 470 | use_scale_shift_norm=use_scale_shift_norm, 471 | ), 472 | ) 473 | self._feature_size += ch 474 | 475 | self.output_blocks = nn.ModuleList([]) 476 | for level, mult in 
list(enumerate(channel_mults))[::-1]: 477 | for i in range(res_blocks + 1): 478 | ich = input_block_chans.pop() 479 | layers = [ 480 | ResBlock( 481 | ch + ich, 482 | cond_embed_dim, 483 | dropout, 484 | out_channel=int(inner_channel * mult), 485 | use_checkpoint=use_checkpoint, 486 | use_scale_shift_norm=use_scale_shift_norm, 487 | ) 488 | ] 489 | ch = int(inner_channel * mult) 490 | if ds in attn_res: 491 | layers.append( 492 | AttentionBlock( 493 | ch, 494 | use_checkpoint=use_checkpoint, 495 | num_heads=num_heads_upsample, 496 | num_head_channels=num_head_channels, 497 | use_new_attention_order=use_new_attention_order, 498 | ) 499 | ) 500 | if level and i == res_blocks: 501 | out_ch = ch 502 | layers.append( 503 | ResBlock( 504 | ch, 505 | cond_embed_dim, 506 | dropout, 507 | out_channel=out_ch, 508 | use_checkpoint=use_checkpoint, 509 | use_scale_shift_norm=use_scale_shift_norm, 510 | up=True, 511 | ) 512 | if resblock_updown 513 | else Upsample(ch, conv_resample, out_channel=out_ch) 514 | ) 515 | ds //= 2 516 | self.output_blocks.append(EmbedSequential(*layers)) 517 | self._feature_size += ch 518 | 519 | self.out = nn.Sequential( 520 | normalization(ch), 521 | SiLU(), 522 | zero_module(nn.Conv2d(input_ch, out_channel, 3, padding=1)), 523 | ) 524 | 525 | def forward(self, x, gammas): 526 | """ 527 | Apply the model to an input batch. 528 | :param x: an [N x 2 x ...] Tensor of inputs (B&W) 529 | :param gammas: a 1-D batch of gammas. 530 | :return: an [N x C x ...] Tensor of outputs. 531 | """ 532 | hs = [] 533 | gammas = gammas.view(-1, ) 534 | emb = self.cond_embed(gamma_embedding(gammas, self.inner_channel)) 535 | 536 | h = x.type(torch.float32) 537 | for module in self.input_blocks: 538 | h = module(h, emb) 539 | hs.append(h) 540 | h = self.middle_block(h, emb) 541 | for module in self.output_blocks: 542 | h = torch.cat([h, hs.pop()], dim=1) 543 | h = module(h, emb) 544 | h = h.type(x.dtype) 545 | return self.out(h) 546 | 547 | if __name__ == '__main__': 548 | b, c, h, w = 3, 6, 64, 64 549 | timsteps = 100 550 | model = UNet( 551 | image_size=h, 552 | in_channel=c, 553 | inner_channel=64, 554 | out_channel=3, 555 | res_blocks=2, 556 | attn_res=[8] 557 | ) 558 | x = torch.randn((b, c, h, w)) 559 | emb = torch.ones((b, )) 560 | out = model(x, emb) -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | # class mse_loss(nn.Module): 7 | # def __init__(self) -> None: 8 | # super().__init__() 9 | # self.loss_fn = nn.MSELoss() 10 | # def forward(self, output, target): 11 | # return self.loss_fn(output, target) 12 | 13 | 14 | def mse_loss(output, target): 15 | return F.mse_loss(output, target) 16 | 17 | 18 | class FocalLoss(nn.Module): 19 | def __init__(self, gamma=2, alpha=None, size_average=True): 20 | super(FocalLoss, self).__init__() 21 | self.gamma = gamma 22 | self.alpha = alpha 23 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 24 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 25 | self.size_average = size_average 26 | 27 | def forward(self, input, target): 28 | if input.dim()>2: 29 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 30 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 31 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 32 | 
target = target.view(-1,1) 33 | 34 | logpt = F.log_softmax(input) 35 | logpt = logpt.gather(1,target) 36 | logpt = logpt.view(-1) 37 | pt = Variable(logpt.data.exp()) 38 | 39 | if self.alpha is not None: 40 | if self.alpha.type()!=input.data.type(): 41 | self.alpha = self.alpha.type_as(input.data) 42 | at = self.alpha.gather(0,target.data.view(-1)) 43 | logpt = logpt * Variable(at) 44 | 45 | loss = -1 * (1-pt)**self.gamma * logpt 46 | if self.size_average: return loss.mean() 47 | else: return loss.sum() 48 | 49 | -------------------------------------------------------------------------------- /models/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn import functional as F 5 | import torch.utils.data 6 | 7 | from torchvision.models.inception import inception_v3 8 | 9 | import numpy as np 10 | from scipy.stats import entropy 11 | 12 | def mae(input, target): 13 | with torch.no_grad(): 14 | loss = nn.L1Loss() 15 | output = loss(input, target) 16 | return output 17 | 18 | 19 | def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=1): 20 | """Computes the inception score of the generated images imgs 21 | 22 | imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1] 23 | cuda -- whether or not to run on GPU 24 | batch_size -- batch size for feeding into Inception v3 25 | splits -- number of splits 26 | """ 27 | N = len(imgs) 28 | 29 | assert batch_size > 0 30 | assert N > batch_size 31 | 32 | # Set up dtype 33 | if cuda: 34 | dtype = torch.cuda.FloatTensor 35 | else: 36 | if torch.cuda.is_available(): 37 | print("WARNING: You have a CUDA device, so you should probably set cuda=True") 38 | dtype = torch.FloatTensor 39 | 40 | # Set up dataloader 41 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 42 | 43 | # Load inception model 44 | inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype) 45 | inception_model.eval() 46 | up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype) 47 | def get_pred(x): 48 | if resize: 49 | x = up(x) 50 | x = inception_model(x) 51 | return F.softmax(x).data.cpu().numpy() 52 | 53 | # Get predictions 54 | preds = np.zeros((N, 1000)) 55 | 56 | for i, batch in enumerate(dataloader, 0): 57 | batch = batch.type(dtype) 58 | batchv = Variable(batch) 59 | batch_size_i = batch.size()[0] 60 | 61 | preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv) 62 | 63 | # Now compute the mean kl-div 64 | split_scores = [] 65 | 66 | for k in range(splits): 67 | part = preds[k * (N // splits): (k+1) * (N // splits), :] 68 | py = np.mean(part, axis=0) 69 | scores = [] 70 | for i in range(part.shape[0]): 71 | pyx = part[i, :] 72 | scores.append(entropy(pyx, py)) 73 | split_scores.append(np.exp(np.mean(scores))) 74 | 75 | return np.mean(split_scores), np.std(split_scores) -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | from core.base_model import BaseModel 4 | from core.logger import LogTracker 5 | import copy 6 | class EMA(): 7 | def __init__(self, beta=0.9999): 8 | super().__init__() 9 | self.beta = beta 10 | def update_model_average(self, ma_model, current_model): 11 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 12 | old_weight, 
up_weight = ma_params.data, current_params.data 13 | ma_params.data = self.update_average(old_weight, up_weight) 14 | def update_average(self, old, new): 15 | if old is None: 16 | return new 17 | return old * self.beta + (1 - self.beta) * new 18 | 19 | class Palette(BaseModel): 20 | def __init__(self, networks, losses, sample_num, task, optimizers, ema_scheduler=None, **kwargs): 21 | ''' must to init BaseModel with kwargs ''' 22 | super(Palette, self).__init__(**kwargs) 23 | 24 | ''' networks, dataloder, optimizers, losses, etc. ''' 25 | self.loss_fn = losses[0] 26 | self.netG = networks[0] 27 | if ema_scheduler is not None: 28 | self.ema_scheduler = ema_scheduler 29 | self.netG_EMA = copy.deepcopy(self.netG) 30 | self.EMA = EMA(beta=self.ema_scheduler['ema_decay']) 31 | else: 32 | self.ema_scheduler = None 33 | 34 | ''' networks can be a list, and must convert by self.set_device function if using multiple GPU. ''' 35 | self.netG = self.set_device(self.netG, distributed=self.opt['distributed']) 36 | if self.ema_scheduler is not None: 37 | self.netG_EMA = self.set_device(self.netG_EMA, distributed=self.opt['distributed']) 38 | self.load_networks() 39 | 40 | self.optG = torch.optim.Adam(list(filter(lambda p: p.requires_grad, self.netG.parameters())), **optimizers[0]) 41 | self.optimizers.append(self.optG) 42 | self.resume_training() 43 | 44 | if self.opt['distributed']: 45 | self.netG.module.set_loss(self.loss_fn) 46 | self.netG.module.set_new_noise_schedule(phase=self.phase) 47 | else: 48 | self.netG.set_loss(self.loss_fn) 49 | self.netG.set_new_noise_schedule(phase=self.phase) 50 | 51 | ''' can rewrite in inherited class for more informations logging ''' 52 | self.train_metrics = LogTracker(*[m.__name__ for m in losses], phase='train') 53 | self.val_metrics = LogTracker(*[m.__name__ for m in self.metrics], phase='val') 54 | self.test_metrics = LogTracker(*[m.__name__ for m in self.metrics], phase='test') 55 | 56 | self.sample_num = sample_num 57 | self.task = task 58 | 59 | def set_input(self, data): 60 | ''' must use set_device in tensor ''' 61 | self.cond_image = self.set_device(data.get('cond_image')) 62 | self.gt_image = self.set_device(data.get('gt_image')) 63 | self.mask = self.set_device(data.get('mask')) 64 | self.mask_image = data.get('mask_image') 65 | self.path = data['path'] 66 | self.batch_size = len(data['path']) 67 | 68 | def get_current_visuals(self, phase='train'): 69 | dict = { 70 | 'gt_image': (self.gt_image.detach()[:].float().cpu()+1)/2, 71 | 'cond_image': (self.cond_image.detach()[:].float().cpu()+1)/2, 72 | } 73 | if self.task in ['inpainting','uncropping']: 74 | dict.update({ 75 | 'mask': self.mask.detach()[:].float().cpu(), 76 | 'mask_image': (self.mask_image+1)/2, 77 | }) 78 | if phase != 'train': 79 | dict.update({ 80 | 'output': (self.output.detach()[:].float().cpu()+1)/2 81 | }) 82 | return dict 83 | 84 | def save_current_results(self): 85 | ret_path = [] 86 | ret_result = [] 87 | for idx in range(self.batch_size): 88 | ret_path.append('GT_{}'.format(self.path[idx])) 89 | ret_result.append(self.gt_image[idx].detach().float().cpu()) 90 | 91 | ret_path.append('Process_{}'.format(self.path[idx])) 92 | ret_result.append(self.visuals[idx::self.batch_size].detach().float().cpu()) 93 | 94 | ret_path.append('Out_{}'.format(self.path[idx])) 95 | ret_result.append(self.visuals[idx-self.batch_size].detach().float().cpu()) 96 | 97 | if self.task in ['inpainting','uncropping']: 98 | ret_path.extend(['Mask_{}'.format(name) for name in self.path]) 99 | 
ret_result.extend(self.mask_image) 100 | 101 | self.results_dict = self.results_dict._replace(name=ret_path, result=ret_result) 102 | return self.results_dict._asdict() 103 | 104 | def train_step(self): 105 | self.netG.train() 106 | self.train_metrics.reset() 107 | for train_data in tqdm.tqdm(self.phase_loader): 108 | self.set_input(train_data) 109 | self.optG.zero_grad() 110 | loss = self.netG(self.gt_image, self.cond_image, mask=self.mask) 111 | loss.backward() 112 | self.optG.step() 113 | 114 | self.iter += self.batch_size 115 | self.writer.set_iter(self.epoch, self.iter, phase='train') 116 | self.train_metrics.update(self.loss_fn.__name__, loss.item()) 117 | if self.iter % self.opt['train']['log_iter'] == 0: 118 | for key, value in self.train_metrics.result().items(): 119 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 120 | self.writer.add_scalar(key, value) 121 | for key, value in self.get_current_visuals().items(): 122 | self.writer.add_images(key, value) 123 | if self.ema_scheduler is not None: 124 | if self.iter > self.ema_scheduler['ema_start'] and self.iter % self.ema_scheduler['ema_iter'] == 0: 125 | self.EMA.update_model_average(self.netG_EMA, self.netG) 126 | 127 | for scheduler in self.schedulers: 128 | scheduler.step() 129 | return self.train_metrics.result() 130 | 131 | def val_step(self): 132 | self.netG.eval() 133 | self.val_metrics.reset() 134 | with torch.no_grad(): 135 | for val_data in tqdm.tqdm(self.val_loader): 136 | self.set_input(val_data) 137 | if self.opt['distributed']: 138 | if self.task in ['inpainting','uncropping']: 139 | self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image, 140 | y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) 141 | else: 142 | self.output, self.visuals = self.netG.module.restoration(self.cond_image, sample_num=self.sample_num) 143 | else: 144 | if self.task in ['inpainting','uncropping']: 145 | self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image, 146 | y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) 147 | else: 148 | self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num) 149 | 150 | self.iter += self.batch_size 151 | self.writer.set_iter(self.epoch, self.iter, phase='val') 152 | 153 | for met in self.metrics: 154 | key = met.__name__ 155 | value = met(self.gt_image, self.output) 156 | self.val_metrics.update(key, value) 157 | self.writer.add_scalar(key, value) 158 | for key, value in self.get_current_visuals(phase='val').items(): 159 | self.writer.add_images(key, value) 160 | self.writer.save_images(self.save_current_results()) 161 | 162 | return self.val_metrics.result() 163 | 164 | def test(self): 165 | self.netG.eval() 166 | self.test_metrics.reset() 167 | with torch.no_grad(): 168 | for phase_data in tqdm.tqdm(self.phase_loader): 169 | self.set_input(phase_data) 170 | if self.opt['distributed']: 171 | if self.task in ['inpainting','uncropping']: 172 | self.output, self.visuals = self.netG.module.restoration(self.cond_image, y_t=self.cond_image, 173 | y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) 174 | else: 175 | self.output, self.visuals = self.netG.module.restoration(self.cond_image, sample_num=self.sample_num) 176 | else: 177 | if self.task in ['inpainting','uncropping']: 178 | self.output, self.visuals = self.netG.restoration(self.cond_image, y_t=self.cond_image, 179 | y_0=self.gt_image, mask=self.mask, sample_num=self.sample_num) 180 | else: 181 | 
self.output, self.visuals = self.netG.restoration(self.cond_image, sample_num=self.sample_num) 182 | 183 | self.iter += self.batch_size 184 | self.writer.set_iter(self.epoch, self.iter, phase='test') 185 | for met in self.metrics: 186 | key = met.__name__ 187 | value = met(self.gt_image, self.output) 188 | self.test_metrics.update(key, value) 189 | self.writer.add_scalar(key, value) 190 | for key, value in self.get_current_visuals(phase='test').items(): 191 | self.writer.add_images(key, value) 192 | self.writer.save_images(self.save_current_results()) 193 | 194 | test_log = self.test_metrics.result() 195 | ''' save logged information into log dict ''' 196 | test_log.update({'epoch': self.epoch, 'iters': self.iter}) 197 | 198 | ''' print logged information to the screen and tensorboard ''' 199 | for key, value in test_log.items(): 200 | self.logger.info('{:5s}: {}\t'.format(str(key), value)) 201 | 202 | def load_networks(self): 203 | """ load pretrained model and training state. """ 204 | if self.opt['distributed']: 205 | netG_label = self.netG.module.__class__.__name__ 206 | else: 207 | netG_label = self.netG.__class__.__name__ 208 | self.load_network(network=self.netG, network_label=netG_label, strict=False) 209 | if self.ema_scheduler is not None: 210 | self.load_network(network=self.netG_EMA, network_label=netG_label+'_ema', strict=False) 211 | 212 | def save_everything(self): 213 | """ save pretrained model and training state, which is only done on GPU 0. """ 214 | if self.opt['distributed']: 215 | netG_label = self.netG.module.__class__.__name__ 216 | else: 217 | netG_label = self.netG.__class__.__name__ 218 | self.save_network(network=self.netG, network_label=netG_label) 219 | if self.ema_scheduler is not None: 220 | self.save_network(network=self.netG_EMA, network_label=netG_label+'_ema') 221 | self.save_training_state() 222 | -------------------------------------------------------------------------------- /models/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from inspect import isfunction 4 | from functools import partial 5 | import numpy as np 6 | from tqdm import tqdm 7 | from core.base_network import BaseNetwork 8 | class Network(BaseNetwork): 9 | def __init__(self, unet, beta_schedule, module_name='sr3', **kwargs): 10 | super(Network, self).__init__(**kwargs) 11 | if module_name == 'sr3': 12 | from .sr3_modules.unet import UNet 13 | elif module_name == 'guided_diffusion': 14 | from .guided_diffusion_modules.unet import UNet 15 | 16 | self.denoise_fn = UNet(**unet) 17 | self.beta_schedule = beta_schedule 18 | 19 | def set_loss(self, loss_fn): 20 | self.loss_fn = loss_fn 21 | 22 | def set_new_noise_schedule(self, device=torch.device('cuda'), phase='train'): 23 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 24 | betas = make_beta_schedule(**self.beta_schedule[phase]) 25 | betas = betas.detach().cpu().numpy() if isinstance( 26 | betas, torch.Tensor) else betas 27 | alphas = 1. - betas 28 | 29 | timesteps, = betas.shape 30 | self.num_timesteps = int(timesteps) 31 | 32 | gammas = np.cumprod(alphas, axis=0) 33 | gammas_prev = np.append(1., gammas[:-1]) 34 | 35 | # calculations for diffusion q(x_t | x_{t-1}) and others 36 | self.register_buffer('gammas', to_torch(gammas)) 37 | self.register_buffer('sqrt_recip_gammas', to_torch(np.sqrt(1. / gammas))) 38 | self.register_buffer('sqrt_recipm1_gammas', to_torch(np.sqrt(1. 
/ gammas - 1))) 39 | 40 | # calculations for posterior q(x_{t-1} | x_t, x_0) 41 | posterior_variance = betas * (1. - gammas_prev) / (1. - gammas) 42 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 43 | self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) 44 | self.register_buffer('posterior_mean_coef1', to_torch(betas * np.sqrt(gammas_prev) / (1. - gammas))) 45 | self.register_buffer('posterior_mean_coef2', to_torch((1. - gammas_prev) * np.sqrt(alphas) / (1. - gammas))) 46 | 47 | def predict_start_from_noise(self, y_t, t, noise): 48 | return ( 49 | extract(self.sqrt_recip_gammas, t, y_t.shape) * y_t - 50 | extract(self.sqrt_recipm1_gammas, t, y_t.shape) * noise 51 | ) 52 | 53 | def q_posterior(self, y_0_hat, y_t, t): 54 | posterior_mean = ( 55 | extract(self.posterior_mean_coef1, t, y_t.shape) * y_0_hat + 56 | extract(self.posterior_mean_coef2, t, y_t.shape) * y_t 57 | ) 58 | posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, y_t.shape) 59 | return posterior_mean, posterior_log_variance_clipped 60 | 61 | def p_mean_variance(self, y_t, t, clip_denoised: bool, y_cond=None): 62 | noise_level = extract(self.gammas, t, x_shape=(1, 1)).to(y_t.device) 63 | y_0_hat = self.predict_start_from_noise( 64 | y_t, t=t, noise=self.denoise_fn(torch.cat([y_cond, y_t], dim=1), noise_level)) 65 | 66 | if clip_denoised: 67 | y_0_hat.clamp_(-1., 1.) 68 | 69 | model_mean, posterior_log_variance = self.q_posterior( 70 | y_0_hat=y_0_hat, y_t=y_t, t=t) 71 | return model_mean, posterior_log_variance 72 | 73 | def q_sample(self, y_0, sample_gammas, noise=None): 74 | noise = default(noise, lambda: torch.randn_like(y_0)) 75 | return ( 76 | sample_gammas.sqrt() * y_0 + 77 | (1 - sample_gammas).sqrt() * noise 78 | ) 79 | 80 | @torch.no_grad() 81 | def p_sample(self, y_t, t, clip_denoised=True, y_cond=None): 82 | model_mean, model_log_variance = self.p_mean_variance( 83 | y_t=y_t, t=t, clip_denoised=clip_denoised, y_cond=y_cond) 84 | noise = torch.randn_like(y_t) if any(t>0) else torch.zeros_like(y_t) 85 | return model_mean + noise * (0.5 * model_log_variance).exp() 86 | 87 | @torch.no_grad() 88 | def restoration(self, y_cond, y_t=None, y_0=None, mask=None, sample_num=8): 89 | b, *_ = y_cond.shape 90 | 91 | assert self.num_timesteps > sample_num, 'num_timesteps must greater than sample_num' 92 | sample_inter = (self.num_timesteps//sample_num) 93 | 94 | y_t = default(y_t, lambda: torch.randn_like(y_cond)) 95 | ret_arr = y_t 96 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 97 | t = torch.full((b,), i, device=y_cond.device, dtype=torch.long) 98 | y_t = self.p_sample(y_t, t, y_cond=y_cond) 99 | if mask is not None: 100 | y_t = y_0*(1.-mask) + mask*y_t 101 | if i % sample_inter == 0: 102 | ret_arr = torch.cat([ret_arr, y_t], dim=0) 103 | return y_t, ret_arr 104 | 105 | def forward(self, y_0, y_cond=None, mask=None, noise=None): 106 | # sampling from p(gammas) 107 | b, *_ = y_0.shape 108 | t = torch.randint(1, self.num_timesteps, (b,), device=y_0.device).long() 109 | gamma_t1 = extract(self.gammas, t-1, x_shape=(1, 1)) 110 | sqrt_gamma_t2 = extract(self.gammas, t, x_shape=(1, 1)) 111 | sample_gammas = (sqrt_gamma_t2-gamma_t1) * torch.rand((b, 1), device=y_0.device) + gamma_t1 112 | sample_gammas = sample_gammas.view(b, -1) 113 | 114 | noise = default(noise, lambda: torch.randn_like(y_0)) 115 | y_noisy = 
self.q_sample( 116 | y_0=y_0, sample_gammas=sample_gammas.view(-1, 1, 1, 1), noise=noise) 117 | 118 | if mask is not None: 119 | noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy*mask+(1.-mask)*y_0], dim=1), sample_gammas) 120 | loss = self.loss_fn(mask*noise, mask*noise_hat) 121 | else: 122 | noise_hat = self.denoise_fn(torch.cat([y_cond, y_noisy], dim=1), sample_gammas) 123 | loss = self.loss_fn(noise, noise_hat) 124 | return loss 125 | 126 | 127 | # gaussian diffusion trainer class 128 | def exists(x): 129 | return x is not None 130 | 131 | def default(val, d): 132 | if exists(val): 133 | return val 134 | return d() if isfunction(d) else d 135 | 136 | def extract(a, t, x_shape=(1,1,1,1)): 137 | b, *_ = t.shape 138 | out = a.gather(-1, t) 139 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 140 | 141 | # beta_schedule function 142 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 143 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 144 | warmup_time = int(n_timestep * warmup_frac) 145 | betas[:warmup_time] = np.linspace( 146 | linear_start, linear_end, warmup_time, dtype=np.float64) 147 | return betas 148 | 149 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-6, linear_end=1e-2, cosine_s=8e-3): 150 | if schedule == 'quad': 151 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, 152 | n_timestep, dtype=np.float64) ** 2 153 | elif schedule == 'linear': 154 | betas = np.linspace(linear_start, linear_end, 155 | n_timestep, dtype=np.float64) 156 | elif schedule == 'warmup10': 157 | betas = _warmup_beta(linear_start, linear_end, 158 | n_timestep, 0.1) 159 | elif schedule == 'warmup50': 160 | betas = _warmup_beta(linear_start, linear_end, 161 | n_timestep, 0.5) 162 | elif schedule == 'const': 163 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 164 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 165 | betas = 1. 
/ np.linspace(n_timestep, 166 | 1, n_timestep, dtype=np.float64) 167 | elif schedule == "cosine": 168 | timesteps = ( 169 | torch.arange(n_timestep + 1, dtype=torch.float64) / 170 | n_timestep + cosine_s 171 | ) 172 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 173 | alphas = torch.cos(alphas).pow(2) 174 | alphas = alphas / alphas[0] 175 | betas = 1 - alphas[1:] / alphas[:-1] 176 | betas = betas.clamp(max=0.999) 177 | else: 178 | raise NotImplementedError(schedule) 179 | return betas 180 | 181 | 182 | -------------------------------------------------------------------------------- /models/sr3_modules/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from inspect import isfunction 5 | 6 | class UNet(nn.Module): 7 | def __init__( 8 | self, 9 | in_channel=6, 10 | out_channel=3, 11 | inner_channel=32, 12 | norm_groups=32, 13 | channel_mults=(1, 2, 4, 8, 8), 14 | attn_res=(8), 15 | res_blocks=3, 16 | dropout=0, 17 | with_noise_level_emb=True, 18 | image_size=128 19 | ): 20 | super().__init__() 21 | 22 | if with_noise_level_emb: 23 | noise_level_channel = inner_channel 24 | self.noise_level_mlp = nn.Sequential( 25 | PositionalEncoding(inner_channel), 26 | nn.Linear(inner_channel, inner_channel * 4), 27 | Swish(), 28 | nn.Linear(inner_channel * 4, inner_channel) 29 | ) 30 | else: 31 | noise_level_channel = None 32 | self.noise_level_mlp = None 33 | 34 | num_mults = len(channel_mults) 35 | pre_channel = inner_channel 36 | feat_channels = [pre_channel] 37 | now_res = image_size 38 | downs = [nn.Conv2d(in_channel, inner_channel, 39 | kernel_size=3, padding=1)] 40 | for ind in range(num_mults): 41 | is_last = (ind == num_mults - 1) 42 | use_attn = (now_res in attn_res) 43 | channel_mult = inner_channel * channel_mults[ind] 44 | for _ in range(0, res_blocks): 45 | downs.append(ResnetBlocWithAttn( 46 | pre_channel, channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 47 | feat_channels.append(channel_mult) 48 | pre_channel = channel_mult 49 | if not is_last: 50 | downs.append(Downsample(pre_channel)) 51 | feat_channels.append(pre_channel) 52 | now_res = now_res//2 53 | self.downs = nn.ModuleList(downs) 54 | 55 | self.mid = nn.ModuleList([ 56 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 57 | dropout=dropout, with_attn=True), 58 | ResnetBlocWithAttn(pre_channel, pre_channel, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 59 | dropout=dropout, with_attn=False) 60 | ]) 61 | 62 | ups = [] 63 | for ind in reversed(range(num_mults)): 64 | is_last = (ind < 1) 65 | use_attn = (now_res in attn_res) 66 | channel_mult = inner_channel * channel_mults[ind] 67 | for _ in range(0, res_blocks+1): 68 | ups.append(ResnetBlocWithAttn( 69 | pre_channel+feat_channels.pop(), channel_mult, noise_level_emb_dim=noise_level_channel, norm_groups=norm_groups, 70 | dropout=dropout, with_attn=use_attn)) 71 | pre_channel = channel_mult 72 | if not is_last: 73 | ups.append(Upsample(pre_channel)) 74 | now_res = now_res*2 75 | 76 | self.ups = nn.ModuleList(ups) 77 | 78 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 79 | 80 | def forward(self, x, time): 81 | t = self.noise_level_mlp(time) if exists( 82 | self.noise_level_mlp) else None 83 | 84 | feats = [] 85 | for layer in self.downs: 86 | if isinstance(layer, ResnetBlocWithAttn): 87 | 
x = layer(x, t) 88 | else: 89 | x = layer(x) 90 | feats.append(x) 91 | 92 | for layer in self.mid: 93 | if isinstance(layer, ResnetBlocWithAttn): 94 | x = layer(x, t) 95 | else: 96 | x = layer(x) 97 | 98 | for layer in self.ups: 99 | if isinstance(layer, ResnetBlocWithAttn): 100 | x = layer(torch.cat((x, feats.pop()), dim=1), t) 101 | else: 102 | x = layer(x) 103 | 104 | return self.final_conv(x) 105 | 106 | 107 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 108 | class PositionalEncoding(nn.Module): 109 | def __init__(self, dim): 110 | super().__init__() 111 | self.dim = dim 112 | 113 | def forward(self, noise_level): 114 | count = self.dim // 2 115 | step = torch.arange(count, dtype=noise_level.dtype, device=noise_level.device) / count 116 | encoding = noise_level.unsqueeze(1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 117 | encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1) 118 | return encoding 119 | 120 | 121 | class FeatureWiseAffine(nn.Module): 122 | def __init__(self, in_channels, out_channels, use_affine_level=False): 123 | super(FeatureWiseAffine, self).__init__() 124 | self.use_affine_level = use_affine_level 125 | self.noise_func = nn.Sequential( 126 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level)) 127 | ) 128 | 129 | def forward(self, x, noise_embed): 130 | batch = x.shape[0] 131 | if self.use_affine_level: 132 | gamma, beta = self.noise_func(noise_embed).view(batch, -1, 1, 1).chunk(2, dim=1) 133 | x = (1 + gamma) * x + beta 134 | else: 135 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) 136 | return x 137 | 138 | 139 | class Swish(nn.Module): 140 | def forward(self, x): 141 | return x * torch.sigmoid(x) 142 | 143 | 144 | class Upsample(nn.Module): 145 | def __init__(self, dim): 146 | super().__init__() 147 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 148 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 149 | 150 | def forward(self, x): 151 | return self.conv(self.up(x)) 152 | 153 | 154 | class Downsample(nn.Module): 155 | def __init__(self, dim): 156 | super().__init__() 157 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 158 | 159 | def forward(self, x): 160 | return self.conv(x) 161 | 162 | 163 | # building block modules 164 | 165 | 166 | class Block(nn.Module): 167 | def __init__(self, dim, dim_out, groups=32, dropout=0): 168 | super().__init__() 169 | self.block = nn.Sequential( 170 | nn.GroupNorm(groups, dim), 171 | Swish(), 172 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 173 | nn.Conv2d(dim, dim_out, 3, padding=1) 174 | ) 175 | 176 | def forward(self, x): 177 | return self.block(x) 178 | 179 | 180 | class ResnetBlock(nn.Module): 181 | def __init__(self, dim, dim_out, noise_level_emb_dim=None, dropout=0, use_affine_level=False, norm_groups=32): 182 | super().__init__() 183 | self.noise_func = FeatureWiseAffine(noise_level_emb_dim, dim_out, use_affine_level) 184 | 185 | self.block1 = Block(dim, dim_out, groups=norm_groups) 186 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 187 | self.res_conv = nn.Conv2d( 188 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 189 | 190 | def forward(self, x, time_emb): 191 | b, c, h, w = x.shape 192 | h = self.block1(x) 193 | h = self.noise_func(h, time_emb) 194 | h = self.block2(h) 195 | return h + self.res_conv(x) 196 | 197 | 198 | class SelfAttention(nn.Module): 199 | def __init__(self, in_channel, n_head=1, norm_groups=32): 200 | super().__init__() 201 | 202 | self.n_head 
= n_head 203 | 204 | self.norm = nn.GroupNorm(norm_groups, in_channel) 205 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False) 206 | self.out = nn.Conv2d(in_channel, in_channel, 1) 207 | 208 | def forward(self, input): 209 | batch, channel, height, width = input.shape 210 | n_head = self.n_head 211 | head_dim = channel // n_head 212 | 213 | norm = self.norm(input) 214 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width) 215 | query, key, value = qkv.chunk(3, dim=2) # bhdyx 216 | 217 | attn = torch.einsum("bnchw, bncyx -> bnhwyx", query, key).contiguous() / math.sqrt(channel) 218 | attn = attn.view(batch, n_head, height, width, -1) 219 | attn = torch.softmax(attn, -1) 220 | attn = attn.view(batch, n_head, height, width, height, width) 221 | 222 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 223 | out = self.out(out.view(batch, channel, height, width)) 224 | 225 | return out + input 226 | 227 | 228 | class ResnetBlocWithAttn(nn.Module): 229 | def __init__(self, dim, dim_out, *, noise_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 230 | super().__init__() 231 | self.with_attn = with_attn 232 | self.res_block = ResnetBlock( 233 | dim, dim_out, noise_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 234 | if with_attn: 235 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 236 | 237 | def forward(self, x, time_emb): 238 | x = self.res_block(x, time_emb) 239 | if(self.with_attn): 240 | x = self.attn(x) 241 | return x 242 | 243 | 244 | def exists(x): 245 | return x is not None 246 | 247 | 248 | def default(val, d): 249 | if exists(val): 250 | return val 251 | return d() if isfunction(d) else d 252 | -------------------------------------------------------------------------------- /preprocess/mirflickr25k_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from sklearn.model_selection import train_test_split 4 | import cv2 5 | 6 | def convert_abl(ab, l): 7 | """ convert AB and L to RGB """ 8 | l = np.expand_dims(l, axis=3) 9 | lab = np.concatenate([l, ab], axis=3) 10 | if len(lab.shape)==4: 11 | image_color, image_l = [], [] 12 | for _color, _l in zip(lab, l): 13 | out = cv2.cvtColor(_color.astype('uint8'), cv2.COLOR_LAB2RGB) 14 | out = cv2.cvtColor(out, cv2.COLOR_RGB2BGR) 15 | image_color.append(out) 16 | image_l.append(cv2.cvtColor(_l.astype('uint8'), cv2.COLOR_GRAY2RGB)) 17 | image_color = np.array(image_color) 18 | image_l = np.array(image_l) 19 | else: 20 | image_color = cv2.cvtColor(lab.astype('uint8'), cv2.COLOR_LAB2RGB) 21 | image_l = cv2.cvtColor(l.astype('uint8'), cv2.COLOR_GRAY2RGB) 22 | return image_color, image_l 23 | 24 | def load_data(home): 25 | ab1 = np.load(os.path.join(home,"ab/ab", "ab1.npy")) 26 | ab2 = np.load(os.path.join(home, "ab/ab", "ab2.npy")) 27 | ab3 = np.load(os.path.join(home,"ab/ab", "ab3.npy")) 28 | ab = np.concatenate([ab1, ab2, ab3], axis=0) 29 | l = np.load(os.path.join(home,"l/gray_scale.npy")) 30 | return ab, l 31 | 32 | if __name__ == '__main__': 33 | home = './' # path saved .npy 34 | flist_save_path = './flist' 35 | image_save_path = './images' # images save path 36 | 37 | all_color, all_l = load_data(home) 38 | image_color, image_l = convert_abl(all_color, all_l) 39 | 40 | color_save_path, gray_save_path = '{}/color'.format(image_save_path), '{}/gray'.format(image_save_path) 41 | os.makedirs(color_save_path, exist_ok=True) 42 | os.makedirs(gray_save_path, exist_ok=True) 43 | for 
i in range(image_color.shape[0]): 44 | cv2.imwrite('{}/{}.png'.format(color_save_path, str(i).zfill(5)), image_color[i]) 45 | for i in range(image_l.shape[0]): 46 | cv2.imwrite('{}/{}.png'.format(gray_save_path, str(i).zfill(5)), image_l[i]) 47 | 48 | os.makedirs(flist_save_path, exist_ok=True) 49 | arr = np.random.permutation(25000) 50 | with open('{}/train.flist'.format(flist_save_path), 'w') as f: 51 | for item in arr[:24000]: 52 | print(str(item).zfill(5), file=f) 53 | with open('{}/test.flist'.format(flist_save_path), 'w') as f: 54 | for item in arr[24000:]: 55 | print(str(item).zfill(5), file=f) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.6 2 | torchvision 3 | numpy 4 | pandas 5 | tqdm 6 | tensorboardX>=1.14 7 | scipy 8 | opencv-python 9 | clean-fid 10 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | import torch 5 | import torch.multiprocessing as mp 6 | 7 | from core.logger import VisualWriter, InfoLogger 8 | import core.praser as Praser 9 | import core.util as Util 10 | from data import define_dataloader 11 | from models import create_model, define_network, define_loss, define_metric 12 | 13 | def main_worker(gpu, ngpus_per_node, opt): 14 | """ threads running on each GPU """ 15 | if 'local_rank' not in opt: 16 | opt['local_rank'] = opt['global_rank'] = gpu 17 | if opt['distributed']: 18 | torch.cuda.set_device(int(opt['local_rank'])) 19 | print('using GPU {} for training'.format(int(opt['local_rank']))) 20 | torch.distributed.init_process_group(backend = 'nccl', 21 | init_method = opt['init_method'], 22 | world_size = opt['world_size'], 23 | rank = opt['global_rank'], 24 | group_name='mtorch' 25 | ) 26 | '''set seed and and cuDNN environment ''' 27 | torch.backends.cudnn.enabled = True 28 | warnings.warn('You have chosen to use cudnn for accleration. torch.backends.cudnn.enabled=True') 29 | Util.set_seed(opt['seed']) 30 | 31 | ''' set logger ''' 32 | phase_logger = InfoLogger(opt) 33 | phase_writer = VisualWriter(opt, phase_logger) 34 | phase_logger.info('Create the log file in directory {}.\n'.format(opt['path']['experiments_root'])) 35 | 36 | '''set networks and dataset''' 37 | phase_loader, val_loader = define_dataloader(phase_logger, opt) # val_loader is None if phase is test. 
38 | networks = [define_network(phase_logger, opt, item_opt) for item_opt in opt['model']['which_networks']] 39 | 40 | ''' set metrics, loss, optimizer and schedulers ''' 41 | metrics = [define_metric(phase_logger, item_opt) for item_opt in opt['model']['which_metrics']] 42 | losses = [define_loss(phase_logger, item_opt) for item_opt in opt['model']['which_losses']] 43 | 44 | model = create_model( 45 | opt = opt, 46 | networks = networks, 47 | phase_loader = phase_loader, 48 | val_loader = val_loader, 49 | losses = losses, 50 | metrics = metrics, 51 | logger = phase_logger, 52 | writer = phase_writer 53 | ) 54 | 55 | phase_logger.info('Begin model {}.'.format(opt['phase'])) 56 | try: 57 | if opt['phase'] == 'train': 58 | model.train() 59 | else: 60 | model.test() 61 | finally: 62 | phase_writer.close() 63 | 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser() 67 | parser.add_argument('-c', '--config', type=str, default='config/colorization_mirflickr25k.json', help='JSON file for configuration') 68 | parser.add_argument('-p', '--phase', type=str, choices=['train','test'], help='Run train or test', default='train') 69 | parser.add_argument('-b', '--batch', type=int, default=None, help='Batch size in every gpu') 70 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 71 | parser.add_argument('-d', '--debug', action='store_true') 72 | parser.add_argument('-P', '--port', default='21012', type=str) 73 | 74 | ''' parser configs ''' 75 | args = parser.parse_args() 76 | opt = Praser.parse(args) 77 | 78 | ''' cuda devices ''' 79 | gpu_str = ','.join(str(x) for x in opt['gpu_ids']) 80 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_str 81 | print('export CUDA_VISIBLE_DEVICES={}'.format(gpu_str)) 82 | 83 | ''' use DistributedDataParallel(DDP) and multiprocessing for multi-gpu training''' 84 | # [Todo]: multi GPU on multi machine 85 | if opt['distributed']: 86 | ngpus_per_node = len(opt['gpu_ids']) # or torch.cuda.device_count() 87 | opt['world_size'] = ngpus_per_node 88 | opt['init_method'] = 'tcp://127.0.0.1:'+ args.port 89 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, opt)) 90 | else: 91 | opt['world_size'] = 1 92 | main_worker(0, 1, opt) -------------------------------------------------------------------------------- /slurm/inpainting_places2.slurm: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o experiments/slurm.log 3 | #SBATCH -J base 4 | #SBATCH -p dell 5 | #SBATCH --gres=gpu:4 6 | #SBATCH -c 16 7 | python run.py -c config/inpainting_places2.json -gpu 0,1,2,3 -b 8 8 | --------------------------------------------------------------------------------
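The training objective in `models/network.py` (`Network.forward` together with `q_sample`) reduces to a few lines of arithmetic: a noise level is drawn uniformly between the cumulative products gamma_{t-1} and gamma_t, the ground truth is diffused as y_noisy = sqrt(gamma) * y_0 + sqrt(1 - gamma) * eps, and the UNet is trained to recover eps (restricted to the hole region when a mask is given). The snippet below is a minimal standalone sketch of that arithmetic on toy tensors; the inline linear schedule and the name `toy_q_sample` are assumptions made only for this illustration and do not call into the repository code.

```python
import torch

# toy linear beta schedule (assumption for illustration; the repo builds it via make_beta_schedule)
n_timestep = 2000
betas = torch.linspace(1e-6, 1e-2, n_timestep)
gammas = torch.cumprod(1. - betas, dim=0)            # gamma_t = prod_{s<=t} (1 - beta_s)

def toy_q_sample(y_0, gamma, noise):
    # forward diffusion: y_t = sqrt(gamma) * y_0 + sqrt(1 - gamma) * eps
    return gamma.sqrt() * y_0 + (1. - gamma).sqrt() * noise

y_0 = torch.randn(2, 3, 8, 8)                        # stand-in for a batch of ground-truth images
t = torch.randint(1, n_timestep, (2,))
# gamma is sampled uniformly from [gamma_{t-1}, gamma_t], as in Network.forward
gamma = gammas[t - 1] + (gammas[t] - gammas[t - 1]) * torch.rand(2)
noise = torch.randn_like(y_0)
y_noisy = toy_q_sample(y_0, gamma.view(-1, 1, 1, 1), noise)

# The network predicts `noise` from (y_cond, y_noisy) conditioned on gamma;
# with a mask, only the hole contributes: loss = mse(mask * noise, mask * noise_hat).
# During restoration, each reverse step also re-imposes the known pixels:
#   y_t = y_0 * (1 - mask) + mask * y_t
```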
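`make_beta_schedule` only produces the per-step betas; `set_new_noise_schedule` then turns them into the cumulative gammas and posterior coefficients consumed by the rest of `models/network.py`. A quick sanity check of that pipeline, assuming the repository root is on `PYTHONPATH`; the 'linear' parameters below are illustrative defaults, not values copied from a specific file under `config/`.

```python
import numpy as np
from models.network import make_beta_schedule   # assumes the repo root is importable

betas = make_beta_schedule('linear', n_timestep=2000, linear_start=1e-6, linear_end=1e-2)
gammas = np.cumprod(1. - betas, axis=0)          # same reduction done in set_new_noise_schedule
print(gammas[0], gammas[-1])                     # ~1 at t=0 and ~0 at t=T, so y_T is (almost) pure noise
```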
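The `EMA` helper in `models/model.py` keeps an exponential moving average of the generator weights, `avg = beta * avg + (1 - beta) * new`, which `Palette.train_step` applies every `ema_iter` iterations once `ema_start` has been passed. A rough standalone usage sketch with a toy module in place of the Palette generator (again assuming the repository root is importable):

```python
import copy
import torch.nn as nn
from models.model import EMA            # assumes the repo root is importable

net = nn.Linear(4, 4)                   # toy stand-in for netG
net_ema = copy.deepcopy(net)            # netG_EMA starts as an exact copy of netG
ema = EMA(beta=0.9999)

# ... after an optimizer step on `net` ...
ema.update_model_average(net_ema, net)  # net_ema <- 0.9999 * net_ema + 0.0001 * net
```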