├── LICENSE
├── README.md
├── config.json
├── imgs
│   ├── int-dark-side.png
│   ├── int-eye-color.png
│   ├── int-male-female.png
│   ├── new-blob-samples-2.png
│   ├── new-blob-samples.png
│   ├── new-bottom-celeba.png
│   ├── new-bottom-seven-nine.png
│   ├── new-left-celeba.png
│   ├── new-random-celeba.png
│   ├── new-small-missing-celeba.png
│   ├── new-top-celeba.png
│   ├── probinpainting.gif
│   └── summary-gif.gif
├── main.py
├── main_generate.py
├── pixconcnn
│   ├── __init__.py
│   ├── generate.py
│   ├── layers.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   ├── gated_pixelcnn.py
│   │   └── pixel_constrained.py
│   └── training.py
├── requirements.txt
├── trained_models
│   ├── celeba
│   │   ├── config.json
│   │   └── model.pt
│   └── mnist
│       ├── config.json
│       └── model.pt
└── utils
    ├── __init__.py
    ├── dataloaders.py
    ├── init_models.py
    ├── loading.py
    ├── masks.py
    └── plots.py
/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof.
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019, Emilien Dupont & Suhas Suresha 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Probabilistic Semantic Inpainting with Pixel Constrained CNNs 2 | 3 | Pytorch implementation of [Probabilistic Semantic Inpainting with Pixel Constrained CNNs](https://arxiv.org/abs/1810.03728) (2018). 4 | 5 | This repo contains an implementation of Pixel Constrained CNN, a framework for performing probabilistic inpainting of images with arbitrary occlusions. 
It also includes all code to reproduce the experiments in the paper as well as the weights of the trained models. 6 | 7 | For a TensorFlow implementation, see this [repo](https://github.com/Schlumberger/pixel-constrained-cnn-tf). 8 | 9 | ## Examples 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | ## Usage 24 | 25 | ### Training 26 | 27 | The attributes of the model can be set in the `config.json` file. To train the model, run 28 | 29 | ``` 30 | python main.py config.json 31 | ``` 32 | 33 | This will also save the trained model and log various information as training progresses. Examples of `config.json` files are available in the `trained_models` directory. 34 | 35 | ### Inpainting 36 | 37 | To generate images with a trained model use `main_generate.py`. As an example, the following command generates 64 completions for images 73 and 84 in the MNIST dataset by conditioning on the top 14 rows. The model used to generate the completions is the trained MNIST model included in this repo and the results are saved to the `mnist_inpaintings` folder. 38 | 39 | ``` 40 | python main_generate.py -n mnist_inpaintings -m trained_models/mnist -t generation -i 73 84 -to 14 -ns 64 41 | ``` 42 | 43 | For a full list of options, run `python main_generate.py --help`. Note that if you do not have the MNIST dataset on your machine it will be automatically downloaded when running the above command. The CelebA dataset will have to be manually downloaded (see the Data sources section). If you already have the datasets downloaded, you can change the paths in `utils/dataloaders.py` to point to the correct folders on your machine. 44 | 45 | ## Trained models 46 | 47 | The trained models referenced in the paper are included in the `trained_models` folder. You can use the `main_generate.py` script to generate image completions (and other plots) with these models. 48 | 49 | ## Data sources 50 | 51 | The MNIST dataset can be automatically downloaded using `torchvision`. The CelebA dataset can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html). 
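The `main_generate.py` script is a thin wrapper around `utils.loading.load_model` and `pixconcnn.generate.generate_images`, so the same inpainting can also be driven directly from Python. The snippet below is a minimal sketch rather than part of the repo: the model folder matches the bundled MNIST model, the image indices follow the README example above, and the number of samples and output filename are only illustrative.

```python
import torch
from torchvision.utils import save_image

from pixconcnn.generate import generate_images
from utils.loading import load_model

# Load the bundled MNIST model; load_model also returns a data loader for its dataset
model, data_loader, _ = load_model('trained_models/mnist', model_version=None)

# Build a batch of dataset images and condition on their top 14 rows
batch = torch.stack([data_loader.dataset[idx][0] for idx in (73, 84)], dim=0)
mask_descriptors = [('top', 14)]

outputs = generate_images(model, batch, mask_descriptors, num_samples=8)

# outputs[i][j] holds the results for image i under mask j
samples = outputs[0][0]['samples']
save_image(samples.float() / (model.prior_net.num_colors - 1.), 'completions.png')
```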
52 | 53 | ## Citing 54 | 55 | If you find this work useful in your research, please cite using: 56 | 57 | ``` 58 | @article{dupont2018probabilistic, 59 | title={Probabilistic Semantic Inpainting with Pixel Constrained CNNs}, 60 | author={Dupont, Emilien and Suresha, Suhas}, 61 | journal={arXiv preprint arXiv:1810.03728}, 62 | year={2018} 63 | } 64 | ``` 65 | 66 | ## More examples 67 | 68 | 69 | 70 | 71 | 72 | 73 | ## License 74 | 75 | [Apache License 2.0](LICENSE) 76 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "mnist_model", 3 | "dataset": "mnist", 4 | "resize": 28, 5 | "crop": 28, 6 | "grayscale": true, 7 | "batch_size": 64, 8 | "constrained": true, 9 | "num_colors": 2, 10 | "filter_size": 5, 11 | "depth": 16, 12 | "num_filters_cond": 32, 13 | "num_filters_prior": 32, 14 | "lr": 4e-4, 15 | "epochs": 50, 16 | "mask_descriptor": ["random_rect", [10, 10]], 17 | "num_conds": 4, 18 | "num_samples": 64, 19 | "weight_cond_logits_loss": 1.0, 20 | "weight_prior_logits_loss": 0.0 21 | } 22 | -------------------------------------------------------------------------------- /imgs/int-dark-side.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/int-dark-side.png -------------------------------------------------------------------------------- /imgs/int-eye-color.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/int-eye-color.png -------------------------------------------------------------------------------- /imgs/int-male-female.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/int-male-female.png -------------------------------------------------------------------------------- /imgs/new-blob-samples-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-blob-samples-2.png -------------------------------------------------------------------------------- /imgs/new-blob-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-blob-samples.png -------------------------------------------------------------------------------- /imgs/new-bottom-celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-bottom-celeba.png -------------------------------------------------------------------------------- /imgs/new-bottom-seven-nine.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-bottom-seven-nine.png 
-------------------------------------------------------------------------------- /imgs/new-left-celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-left-celeba.png -------------------------------------------------------------------------------- /imgs/new-random-celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-random-celeba.png -------------------------------------------------------------------------------- /imgs/new-small-missing-celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-small-missing-celeba.png -------------------------------------------------------------------------------- /imgs/new-top-celeba.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-top-celeba.png -------------------------------------------------------------------------------- /imgs/probinpainting.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/probinpainting.gif -------------------------------------------------------------------------------- /imgs/summary-gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/summary-gif.gif -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import json 3 | import os 4 | import sys 5 | import time 6 | import torch 7 | from pixconcnn.training import Trainer, PixelConstrainedTrainer 8 | from torchvision.utils import save_image 9 | from utils.dataloaders import mnist, celeba 10 | from utils.init_models import initialize_model 11 | from utils.masks import batch_random_mask, get_repeated_conditional_pixels, MaskGenerator 12 | 13 | 14 | # Set device 15 | cuda = torch.cuda.is_available() 16 | device = torch.device("cuda" if cuda else "cpu") 17 | 18 | # Get config file from command line arguments 19 | if len(sys.argv) != 2: 20 | raise(RuntimeError("Wrong arguments, use python main.py ")) 21 | config_path = sys.argv[1] 22 | 23 | # Open config file 24 | with open(config_path) as config_file: 25 | config = json.load(config_file) 26 | 27 | name = config['name'] 28 | constrained = config['constrained'] 29 | batch_size = config['batch_size'] 30 | lr = config['lr'] 31 | num_colors = config['num_colors'] 32 | epochs = config['epochs'] 33 | dataset = config['dataset'] 34 | resize = config['resize'] # Only relevant for celeba 35 | crop = config['crop'] # Only relevant for celeba 36 | grayscale = config["grayscale"] # Only relevant for celeba 37 | num_conds = config['num_conds'] # Only relevant if constrained 38 | num_samples = config['num_samples'] # Only relevant if constrained 39 | 
filter_size = config['filter_size'] 40 | depth = config['depth'] 41 | num_filters_cond = config['num_filters_cond'] 42 | num_filters_prior = config['num_filters_prior'] 43 | mask_descriptor = config['mask_descriptor'] 44 | weight_cond_logits_loss = config['weight_cond_logits_loss'] 45 | weight_prior_logits_loss = config['weight_prior_logits_loss'] 46 | 47 | # Create a folder to store experiment results 48 | timestamp = time.strftime("%Y-%m-%d_%H-%M") 49 | directory = "{}_{}".format(timestamp, name) 50 | if not os.path.exists(directory): 51 | os.makedirs(directory) 52 | 53 | # Save config file in experiment directory 54 | with open(directory + '/config.json', 'w') as config_file: 55 | json.dump(config, config_file) 56 | 57 | # Get data 58 | if dataset == 'mnist': 59 | data_loader, _ = mnist(batch_size, num_colors=num_colors, size=resize) 60 | img_size = (1, resize, resize) 61 | elif dataset == 'celeba': 62 | data_loader = celeba(batch_size, num_colors=num_colors, size=resize, 63 | crop=crop, grayscale=grayscale) 64 | if grayscale: 65 | img_size = (1, resize, resize) 66 | else: 67 | img_size = (3, resize, resize) 68 | 69 | # Initialize model weights and architecture 70 | model = initialize_model(img_size, 71 | num_colors, 72 | depth, 73 | filter_size, 74 | constrained, 75 | num_filters_prior, 76 | num_filters_cond) 77 | model.to(device) 78 | print(model) 79 | 80 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 81 | 82 | if constrained: 83 | mask_generator = MaskGenerator(img_size, mask_descriptor) 84 | trainer = PixelConstrainedTrainer(model, optimizer, device, mask_generator, 85 | weight_cond_logits_loss=weight_cond_logits_loss, 86 | weight_prior_logits_loss=weight_prior_logits_loss) 87 | # Train model 88 | progress_imgs = trainer.train(data_loader, epochs, directory=directory) 89 | 90 | # Get a random batch of images 91 | for batch, _ in data_loader: 92 | break 93 | 94 | for i in range(num_conds): 95 | mask = mask_generator.get_masks(batch_size) 96 | print('Generating {}/{} conditionings'.format(i + 1, num_conds)) 97 | cond_pixels = get_repeated_conditional_pixels(batch[i:i+1], mask[i:i+1], 98 | num_colors, num_samples) 99 | # Save mask as tensor 100 | torch.save(mask[i:i+1], directory + '/mask{}.pt'.format(i)) 101 | # Save image that gave rise to the conditioning as tensor 102 | torch.save(batch[i:i+1], directory + '/source{}.pt'.format(i)) 103 | # Save conditional pixels as tensor and image 104 | torch.save(cond_pixels[0:1], directory + '/cond_pixels{}.pt'.format(i)) 105 | save_image(cond_pixels[0:1], directory + '/cond_pixels{}.png'.format(i)) 106 | 107 | cond_pixels = cond_pixels.to(device) 108 | samples = model.sample(cond_pixels) 109 | # Save samples and mean sample as tensor and image 110 | torch.save(samples, directory + '/samples_cond{}.pt'.format(i)) 111 | save_image(samples.float() / (num_colors - 1.), 112 | directory + '/samples_cond{}.png'.format(i)) 113 | save_image(samples.float().mean(dim=0) / (num_colors - 1.), 114 | directory + '/mean_cond{}.png'.format(i)) 115 | # Save conditional logits if image is binary 116 | if num_colors == 2: 117 | # Save conditional logits 118 | logits, _, cond_logits = model(batch[i:i+1].float().to(device), cond_pixels[0:1]) 119 | # Second dimension corresponds to different pixel values, so select probs of it being 1 120 | save_image(cond_logits[:, 1], directory + '/prob_of_one_cond{}.png'.format(i)) 121 | # Second dimension corresponds to different pixel values, so select probs of it being 1 122 | save_image(logits[:, 1], directory + 
'/prob_of_one_logits{}.png'.format(i)) 123 | else: 124 | trainer = Trainer(model, optimizer, device) 125 | progress_imgs = trainer.train(data_loader, epochs, directory=directory) 126 | 127 | # Save losses and plots of them 128 | with open(directory + '/losses.json', 'w') as losses_file: 129 | json.dump(trainer.losses, losses_file) 130 | 131 | # Save model 132 | torch.save(trainer.model.state_dict(), directory + '/model.pt') 133 | 134 | # Save gif of progress 135 | imageio.mimsave(directory + '/training.gif', progress_imgs, fps=24) 136 | -------------------------------------------------------------------------------- /main_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import time 5 | import torch 6 | import torch.nn.functional as F 7 | from PIL import Image 8 | from pixconcnn.generate import generate_images 9 | from utils.loading import load_model 10 | from utils.masks import get_conditional_pixels 11 | from utils.plots import uncertainty_plot, probs_and_conditional_plot 12 | from torchvision.utils import save_image 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | # Parse command line arguments 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('-m', '--model-folder', dest='model_folder', default=None, 19 | help='Path to trained pixel constrained model folder', 20 | required=True) 21 | parser.add_argument('-n', '--name', dest='name', default=None, 22 | help='Name of generation experiment', required=True) 23 | parser.add_argument('-t', '--gen_type', dest='gen_type', default=None, 24 | choices=['generation','logits','uncertainty'], 25 | help='Type of generation', required=True) 26 | parser.add_argument('-i', '--imgs', dest='imgs_idx', default=None, 27 | type=int, nargs='+', 28 | help='List of indices of images to perform generation with.') 29 | parser.add_argument('-ns', '--num-samples', dest='num_samples', default=None, 30 | type=int, help='Number of samples to generate for each image-mask combo.', 31 | required=True) 32 | parser.add_argument('-te', '--temp', dest='temp', default=1., 33 | type=float, help='Sampling temperature.') 34 | parser.add_argument('-v', '--model-version', dest='model_version', default=None, 35 | type=int, help='Version of model if not using the latest one.') 36 | 37 | parser.add_argument('-nr', '--num-row', dest='num_per_row', default=8, 38 | type=int, help='Number of images per row in grids.') 39 | parser.add_argument('-ni', '--num-iterations', dest='num_iters', default=None, 40 | type=int, help='Only relevant for logits. 
Number of iterations to plot intermediate logits for.') 41 | parser.add_argument('-r', '--random', dest='random_attribute', default=None, 42 | type=int, help='Number of random pixels to keep unmasked.') 43 | parser.add_argument('-b', '--bottom', dest='bottom_attribute', default=None, 44 | type=int, help='Number of bottom pixels to keep unmasked.') 45 | parser.add_argument('-to', '--top', dest='top_attribute', default=None, 46 | type=int, help='Number of top pixels to keep unmasked.') 47 | parser.add_argument('-c', '--center', dest='center_attribute', default=None, 48 | type=int, help='Number of central pixels to keep unmasked.') 49 | parser.add_argument('-e', '--edge', dest='edge_attribute', default=None, 50 | type=int, help='Number of edge pixels to keep unmasked.') 51 | parser.add_argument('-l', '--left', dest='left_attribute', default=None, 52 | type=int, help='Number of left pixels to keep unmasked.') 53 | parser.add_argument('-ri', '--right', dest='right_attribute', default=None, 54 | type=int, help='Number of right pixels to keep unmasked.') 55 | parser.add_argument('-rb', '--random-blob', dest='blob_attribute', default=None, 56 | type=int, nargs='+', help='First int should be maximum number of blobs, second lower bound on num_iters and third upper bound on num_iters.') 57 | parser.add_argument('-mf', '--mask-folder', dest='folder_attribute', default=None, 58 | help='Mask folder if using a cached mask.') 59 | 60 | # Unpack args 61 | args = parser.parse_args() 62 | 63 | # Create a folder to store generation results 64 | timestamp = time.strftime("%Y-%m-%d_%H-%M") 65 | directory = "gen_{}_{}".format(timestamp, args.name) 66 | if not os.path.exists(directory): 67 | os.makedirs(directory) 68 | 69 | # Save args 70 | with open(directory + '/args.json', 'w') as args_file: 71 | json.dump(vars(args), args_file) 72 | 73 | # Load model 74 | model, data_loader, _ = load_model(args.model_folder, model_version=args.model_version) 75 | 76 | # Convert input arguments to mask_descriptors 77 | mask_descriptors = [] 78 | if args.random_attribute is not None: 79 | mask_descriptors.append(('random', args.random_attribute)) 80 | if args.bottom_attribute is not None: 81 | mask_descriptors.append(('bottom', args.bottom_attribute)) 82 | if args.top_attribute is not None: 83 | mask_descriptors.append(('top', args.top_attribute)) 84 | if args.center_attribute is not None: 85 | mask_descriptors.append(('center', args.center_attribute)) 86 | if args.edge_attribute is not None: 87 | mask_descriptors.append(('edge', args.edge_attribute)) 88 | if args.left_attribute is not None: 89 | mask_descriptors.append(('left', args.left_attribute)) 90 | if args.right_attribute is not None: 91 | mask_descriptors.append(('right', args.right_attribute)) 92 | if args.blob_attribute is not None: 93 | max_num_blobs, lower_iter, upper_iter = args.blob_attribute 94 | mask_descriptors.append(('random_blob', (max_num_blobs, (lower_iter, upper_iter), 0.5))) 95 | if args.folder_attribute is not None: 96 | mask_descriptors.append(('random_blob_cache', (args.folder_attribute, 1))) 97 | 98 | imgs_idx = args.imgs_idx 99 | num_img = len(imgs_idx) 100 | total_num_imgs = args.num_samples * num_img * len(mask_descriptors) 101 | print("\nGenerating {} samples for {} different images combined with {} masks for a total of {} images".format(args.num_samples, num_img, len(mask_descriptors), total_num_imgs)) 102 | print("\nThe masks are {}\n".format(mask_descriptors)) 103 | 104 | # Create a batch from the images in imgs_idx 105 | batch = 
torch.stack([data_loader.dataset[img_idx][0] for img_idx in imgs_idx], dim=0) 106 | 107 | if args.gen_type == 'generation' or args.gen_type == 'uncertainty': 108 | # Generate images with model 109 | outputs = generate_images(model, batch, mask_descriptors, 110 | num_samples=args.num_samples, temp=args.temp, 111 | verbose=True) 112 | 113 | # Save images in folder 114 | for i in range(num_img): 115 | for j in range(len(mask_descriptors)): 116 | output = outputs[i][j] 117 | # Save every output as an image and a pytorch tensor 118 | torch.save(output["orig_img"].cpu(), directory + "/source_{}_{}.pt".format(i, j)) 119 | save_image(output["orig_img"].float() / (model.prior_net.num_colors - 1.), directory + "/source_{}_{}.png".format(i, j)) 120 | torch.save(output["cond_pixels"][0:1].cpu(), directory + "/cond_pixels_{}_{}.pt".format(i, j)) 121 | save_image(output["cond_pixels"][0:1, :3], directory + "/cond_pixels_{}_{}.png".format(i, j)) 122 | torch.save(output["mask"].cpu(), directory + "/mask_{}_{}.pt".format(i, j)) 123 | save_image(output["mask"], directory + "/mask_{}_{}.png".format(i, j)) 124 | save_image(output["samples"].float().mean(dim=0) / (model.prior_net.num_colors - 1.), directory + '/mean_samples_{}_{}.png'.format(i, j)) 125 | torch.save(output["samples"].cpu(), directory + "/samples_{}_{}.pt".format(i, j)) 126 | torch.save(output["log_probs"].cpu(), directory + "/log_probs_{}_{}.pt".format(i, j)) 127 | if args.gen_type == 'generation': 128 | save_image(output["samples"].float() / (model.prior_net.num_colors - 1.), directory + "/samples_{}_{}.png".format(i, j), nrow=args.num_per_row, pad_value=1) 129 | elif args.gen_type == 'uncertainty': 130 | sorted_samples, log_likelihoods = uncertainty_plot(output["samples"], output["log_probs"]) 131 | save_image(sorted_samples.float() / (model.prior_net.num_colors - 1.), directory + "/sorted_samples_{}_{}.png".format(i, j), nrow=args.num_per_row, pad_value=1) 132 | save_image(log_likelihoods, directory + "/log_likelihoods_{}_{}.png".format(i, j), pad_value=1, nrow=args.num_per_row) 133 | elif args.gen_type == 'logits': # Note this only works for binary images 134 | if model.prior_net.num_colors != 2: 135 | raise(RuntimeError("Logits generation only works for models with 2 colors. 
Current model has {} colors.".format(model.prior_net.num_colors))) 136 | # Generate images with model 137 | outputs = generate_images(model, batch, mask_descriptors, 138 | num_samples=args.num_samples, verbose=True, 139 | temp=args.temp) 140 | # Extract info 141 | img_size = model.prior_net.img_size 142 | num_pixels = img_size[1] * img_size[2] 143 | pix_per_iters = num_pixels // args.num_iters # Number of pixels to unmask per iteration 144 | # Save images in folder 145 | for i in range(num_img): 146 | for j in range(len(mask_descriptors)): 147 | output = outputs[i][j] 148 | mask = output["mask"] 149 | mask = mask.expand(args.num_samples, *mask.size()[1:]) 150 | samples = output["samples"] 151 | cond_pixels = get_conditional_pixels(samples, mask.float(), 2) 152 | cond_pixels = cond_pixels.to(device) 153 | mask = mask.to(device) 154 | samples = samples.to(device) 155 | logit_plots = {} 156 | for k in range(args.num_iters): 157 | logit_plots[k] = {} 158 | mask_step = mask.clone() 159 | if k != 0: # Do not modify mask for first iteration 160 | # Unmask num_pix_to_unmask pixels in raster order 161 | num_pix_to_unmask = k * pix_per_iters 162 | num_rows = num_pix_to_unmask // img_size[2] 163 | num_cols = num_pix_to_unmask - num_rows * img_size[2] 164 | if num_rows != 0: 165 | mask_step[:, :, :num_rows, :] = 1 166 | if num_cols != 0: 167 | mask_step[:, :, num_rows, :num_cols] = 1 168 | # Calculate logits with updated mask 169 | logits, prior_logits, cond_logits = model(samples.float() * mask_step.float(), cond_pixels) 170 | probs = F.softmax(logits.detach(), dim=1) 171 | # Create a plot for each sample 172 | for l in range(args.num_samples): 173 | logit_plot = probs_and_conditional_plot(samples[l].cpu(), 174 | probs[l, 1, 0].cpu(), 175 | mask_step[0, 0].cpu()) 176 | logit_plots[k][l] = logit_plot 177 | 178 | # Plot iterations as a grid for each sample 179 | for l in range(args.num_samples): 180 | logit_grid = [] 181 | for k in range(args.num_iters): 182 | logit_grid.append(logit_plots[k][l]) 183 | stacked_images = torch.stack(logit_grid, dim=0) 184 | save_image(stacked_images, directory + "/logit_plot_{}_{}_{}.png".format(i, j, l), pad_value=1) 185 | torch.save(stacked_images.cpu(), directory + "/logit_plot_{}_{}_{}.pt".format(i, j, l)) 186 | -------------------------------------------------------------------------------- /pixconcnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/pixconcnn/__init__.py -------------------------------------------------------------------------------- /pixconcnn/generate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import device as torch_device 3 | from torch import zeros as torch_zeros 4 | from torch.cuda import is_available as cuda_is_available 5 | from torchvision.utils import make_grid 6 | from utils.masks import MaskGenerator, get_repeated_conditional_pixels 7 | 8 | 9 | def generate_images(model, batch, mask_descriptors, num_samples=64, temp=1., 10 | verbose=False): 11 | """Generates image completions based on the images in batch masked by the 12 | masks in mask_descriptors. This will generate 13 | batch.size(0) * len(mask_descriptors) * num_samples completions, i.e. 14 | num_samples completions for every image and mask combination. 
15 | 16 | Parameters 17 | ---------- 18 | model : pixconcnn.models.pixel_constrained.PixelConstrained instance 19 | 20 | batch : torch.Tensor 21 | 22 | mask_descriptors : list of mask_descriptor 23 | See utils.masks.MaskGenerator for allowed mask_descriptors. 24 | 25 | num_samples : int 26 | Number of samples to generate for a given image-mask combination. 27 | 28 | temp : float 29 | Temperature for sampling. 30 | 31 | verbose : bool 32 | If True prints progress information while generating images 33 | """ 34 | device = torch_device("cuda" if cuda_is_available() else "cpu") 35 | model.to(device) 36 | outputs = [] 37 | for i in range(batch.size(0)): 38 | outputs_per_img = [] 39 | for j in range(len(mask_descriptors)): 40 | if verbose: 41 | print("Generating samples for image {} using mask {}".format(i, mask_descriptors[j])) 42 | # Get image and mask combination 43 | img = batch[i:i+1] 44 | mask_generator = MaskGenerator(model.prior_net.img_size, mask_descriptors[j]) 45 | mask = mask_generator.get_masks(1) 46 | # Create conditional pixels which will be used to sample completions 47 | cond_pixels = get_repeated_conditional_pixels(img, mask, model.prior_net.num_colors, num_samples) 48 | cond_pixels = cond_pixels.to(device) 49 | samples, log_probs = model.sample(cond_pixels, return_likelihood=True, temp=temp) 50 | outputs_per_img.append({ 51 | "orig_img": img, 52 | "cond_pixels": cond_pixels, 53 | "mask": mask, 54 | "samples": samples, 55 | "log_probs": log_probs 56 | }) 57 | outputs.append(outputs_per_img) 58 | return outputs 59 | -------------------------------------------------------------------------------- /pixconcnn/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResidualBlock(nn.Module): 6 | """ 7 | Residual block (note that number of in_channels and out_channels must be 8 | the same). 9 | 10 | Parameters 11 | ---------- 12 | in_channels : int 13 | 14 | out_channels : int 15 | 16 | kernel_size : int or tuple of ints 17 | 18 | stride : int 19 | 20 | padding : int 21 | """ 22 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding): 23 | super(ResidualBlock, self).__init__() 24 | 25 | self.convs = nn.Sequential( 26 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, 27 | padding=0), 28 | nn.BatchNorm2d(in_channels), 29 | nn.ReLU(True), 30 | nn.Conv2d(in_channels, out_channels, 31 | kernel_size=kernel_size, stride=stride, 32 | padding=padding), 33 | nn.BatchNorm2d(out_channels), 34 | nn.ReLU(True), 35 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, 36 | padding=0), 37 | nn.BatchNorm2d(out_channels), 38 | nn.ReLU(True) 39 | ) 40 | 41 | def forward(self, x): 42 | # In and out channels should be the same 43 | return x + self.convs(x) 44 | 45 | 46 | class MaskedConv2d(nn.Conv2d): 47 | """ 48 | Implements various 2d masked convolutions. 49 | 50 | Parameters 51 | ---------- 52 | mask_type : string 53 | Defines the type of mask to use. One of 'A', 'A_Red', 'A_Green', 54 | 'A_Blue', 'B', 'B_Red', 'B_Green', 'B_Blue', 'H', 'H_Red', 'H_Green', 55 | 'H_Blue', 'HR', 'HR_Red', 'HR_Green', 'HR_Blue', 'V', 'VR'. 
56 | """ 57 | def __init__(self, mask_type, *args, **kwargs): 58 | super(MaskedConv2d, self).__init__(*args, **kwargs) 59 | self.mask_type = mask_type 60 | 61 | # Initialize mask 62 | mask = torch.zeros(*self.weight.size()) 63 | _, kernel_c, kernel_h, kernel_w = self.weight.size() 64 | # If using a color mask, the number of channels must be divisible by 3 65 | if mask_type.endswith('Red') or mask_type.endswith('Green') or mask_type.endswith('Blue'): 66 | assert kernel_c % 3 == 0 67 | # If using a horizontal mask, the kernel height must be 1 68 | if mask_type.startswith('H'): 69 | assert kernel_h == 1 # Kernel should have shape (1, kernel_w) 70 | 71 | if mask_type == 'A': 72 | # For 3 by 3 kernel, this would be: 73 | # 1 1 1 74 | # 1 0 0 75 | # 0 0 0 76 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 77 | mask[:, :, :kernel_h // 2, :] = 1. 78 | elif mask_type == 'A_Red': 79 | # Mask type A for red channels. Same as regular mask A 80 | if kernel_h == 1 and kernel_w == 1: 81 | pass # Mask is all zeros for 1x1 convolution 82 | else: 83 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 84 | mask[:, :, :kernel_h // 2, :] = 1. 85 | elif mask_type == 'A_Green': 86 | # Mask type A for green channels. Same as regular mask A, except 87 | # central pixel of first third of channels is 1 88 | if kernel_h == 1 and kernel_w == 1: 89 | mask[:, :kernel_c // 3, 0, 0] = 1. 90 | else: 91 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 92 | mask[:, :, :kernel_h // 2, :] = 1. 93 | mask[:, :kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1. 94 | elif mask_type == 'A_Blue': 95 | # Mask type A for blue channels. Same as regular mask A, except 96 | # central pixel of first two thirds of channels is 1 97 | if kernel_h == 1 and kernel_w == 1: 98 | mask[:, :2 * kernel_c // 3, 0, 0] = 1. 99 | else: 100 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 101 | mask[:, :, :kernel_h // 2, :] = 1. 102 | mask[:, :2 * kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1. 103 | elif mask_type == 'B': 104 | # For 3 by 3 kernel, this would be: 105 | # 1 1 1 106 | # 1 1 0 107 | # 0 0 0 108 | mask[:, :, kernel_h // 2, :kernel_w // 2 + 1] = 1. 109 | mask[:, :, :kernel_h // 2, :] = 1. 110 | elif mask_type == 'B_Red': 111 | # Mask type B for red channels. Same as regular mask B, except last 112 | # two thirds of channels of central pixels are 0. Alternatively, 113 | # same as Mask A but with first third of channels of central pixels 114 | # are 1 115 | if kernel_h == 1 and kernel_w == 1: 116 | mask[:, :kernel_c // 3, 0, 0] = 1. 117 | else: 118 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 119 | mask[:, :, :kernel_h // 2, :] = 1. 120 | mask[:, :kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1. 121 | elif mask_type == 'B_Green': 122 | # Mask type B for green channels. Same as regular mask B, except 123 | # last third of channels of central pixels are 0 124 | if kernel_h == 1 and kernel_w == 1: 125 | mask[:, :2 * kernel_c // 3, 0, 0] = 1. 126 | else: 127 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 128 | mask[:, :, :kernel_h // 2, :] = 1. 129 | mask[:, :2 * kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1. 130 | elif mask_type == 'B_Blue': 131 | # Mask type B for blue channels. Same as regular mask B 132 | if kernel_h == 1 and kernel_w == 1: 133 | mask[:, :, 0, 0] = 1. 134 | else: 135 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1. 136 | mask[:, :, :kernel_h // 2, :] = 1. 137 | mask[:, :, kernel_h // 2, kernel_w // 2] = 1. 
138 | elif mask_type == 'H': 139 | # For 3 by 3 kernel, this would be: 140 | # 1 1 0 141 | # Mask for horizontal stack in regular gated conv 142 | mask[:, :, 0, :kernel_w // 2 + 1] = 1. 143 | elif mask_type == 'H_Red': 144 | mask[:, :, 0, :kernel_w // 2] = 1. 145 | mask[:, :kernel_c // 3, 0, kernel_w // 2] = 1. 146 | elif mask_type == 'H_Green': 147 | mask[:, :, 0, :kernel_w // 2] = 1. 148 | mask[:, :2 * kernel_c // 3, 0, kernel_w // 2] = 1. 149 | elif mask_type == 'H_Blue': 150 | mask[:, :, 0, :kernel_w // 2] = 1. 151 | mask[:, :, 0, kernel_w // 2] = 1. 152 | elif mask_type == 'HR': 153 | # For 3 by 3 kernel, this would be: 154 | # 1 0 0 155 | # Mask for horizontal stack in restricted gated conv 156 | mask[:, :, 0, :kernel_w // 2] = 1. 157 | elif mask_type == 'HR_Red': 158 | mask[:, :, 0, :kernel_w // 2] = 1. 159 | elif mask_type == 'HR_Green': 160 | mask[:, :, 0, :kernel_w // 2] = 1. 161 | mask[:, :kernel_c // 3, 0, kernel_w // 2] = 1. 162 | elif mask_type == 'HR_Blue': 163 | mask[:, :, 0, :kernel_w // 2] = 1. 164 | mask[:, :2 * kernel_c // 3, 0, kernel_w // 2] = 1. 165 | elif mask_type == 'V': 166 | # For 3 by 3 kernel, this would be: 167 | # 1 1 1 168 | # 1 1 1 169 | # 0 0 0 170 | mask[:, :, :kernel_h // 2 + 1, :] = 1. 171 | elif mask_type == 'VR': 172 | # For 3 by 3 kernel, this would be: 173 | # 1 1 1 174 | # 0 0 0 175 | # 0 0 0 176 | mask[:, :, :kernel_h // 2, :] = 1. 177 | 178 | # Register buffer adds a key to the state dict of the model. This will 179 | # track the attribute without registering it as a learnable parameter. 180 | # We require this since mask will be used in the forward pass. 181 | self.register_buffer('mask', mask) 182 | 183 | def forward(self, x): 184 | self.weight.data *= self.mask 185 | return super(MaskedConv2d, self).forward(x) 186 | 187 | 188 | class MaskedConvRGB(nn.Module): 189 | """ 190 | Masked convolution with RGB channel splitting. 191 | 192 | Parameters 193 | ---------- 194 | mask_type : string 195 | One of 'A', 'B', 'V' or 'H'. 196 | 197 | in_channels : int 198 | Must be divisible by 3 199 | 200 | out_channels : int 201 | Must be divisible by 3 202 | 203 | kernel_size : int or tuple of ints 204 | 205 | stride : int 206 | 207 | padding : int 208 | 209 | bias : bool 210 | If True adds a bias term to the convolution. 211 | """ 212 | def __init__(self, mask_type, in_channels, out_channels, kernel_size, 213 | stride, padding, bias): 214 | super(MaskedConvRGB, self).__init__() 215 | 216 | self.conv_R = MaskedConv2d(mask_type + '_Red', in_channels, 217 | out_channels // 3, kernel_size, 218 | stride=stride, padding=padding, bias=bias) 219 | self.conv_G = MaskedConv2d(mask_type + '_Green', in_channels, 220 | out_channels // 3, kernel_size, 221 | stride=stride, padding=padding, bias=bias) 222 | self.conv_B = MaskedConv2d(mask_type + '_Blue', in_channels, 223 | out_channels // 3, kernel_size, 224 | stride=stride, padding=padding, bias=bias) 225 | 226 | def forward(self, x): 227 | out_red = self.conv_R(x) 228 | out_green = self.conv_G(x) 229 | out_blue = self.conv_B(x) 230 | return torch.cat([out_red, out_green, out_blue], dim=1) 231 | 232 | 233 | class GatedConvBlock(nn.Module): 234 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 235 | restricted=False): 236 | """Gated PixelCNN convolutional block. Note the number of input and 237 | output channels must be the same, unless restricted is True. 
238 | 239 | Parameters 240 | ---------- 241 | in_channels : int 242 | 243 | out_channels : int 244 | 245 | kernel_size : int 246 | Note this MUST be int and not tuple like in regular convs. 247 | 248 | stride : int 249 | 250 | padding : int 251 | 252 | restricted : bool 253 | If True, uses restricted masks, otherwise uses regular masks. 254 | """ 255 | super(GatedConvBlock, self).__init__() 256 | 257 | assert type(kernel_size) is int 258 | 259 | self.restricted = restricted 260 | 261 | if restricted: 262 | vertical_mask = 'VR' 263 | horizontal_mask = 'HR' 264 | else: 265 | vertical_mask = 'V' 266 | horizontal_mask = 'H' 267 | 268 | self.vertical_conv = MaskedConv2d(vertical_mask, in_channels, 269 | 2 * out_channels, kernel_size, 270 | stride=stride, padding=padding) 271 | 272 | self.horizontal_conv = MaskedConv2d(horizontal_mask, in_channels, 273 | 2 * out_channels, (1, kernel_size), 274 | stride=stride, padding=(0, padding)) 275 | 276 | self.vertical_to_horizontal = nn.Conv2d(2 * out_channels, 277 | 2 * out_channels, 1) 278 | 279 | self.horizontal_conv_2 = nn.Conv2d(out_channels, out_channels, 1) 280 | 281 | def forward(self, v_input, h_input): 282 | # Vertical stack 283 | v_conv = self.vertical_conv(v_input) 284 | v_out = gated_activation(v_conv) 285 | # Vertical to horizontal 286 | v_to_h = self.vertical_to_horizontal(v_conv) 287 | # Horizontal stack 288 | h_conv = self.horizontal_conv(h_input) 289 | h_conv_activation = gated_activation(h_conv + v_to_h) 290 | h_conv2 = self.horizontal_conv_2(h_conv_activation) 291 | if self.restricted: 292 | h_out = h_conv2 293 | else: 294 | h_out = h_conv2 + h_input 295 | return v_out, h_out 296 | 297 | 298 | class GatedConvBlockRGB(nn.Module): 299 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 300 | restricted=False): 301 | """Gated PixelCNN convolutional block for RGB images. Note the number of 302 | input and output channels must be the same, unless restricted is True. 303 | 304 | Parameters 305 | ---------- 306 | in_channels : int 307 | 308 | out_channels : int 309 | 310 | kernel_size : int 311 | Note this MUST be int and not tuple like in regular convs. 312 | 313 | stride : int 314 | 315 | padding : int 316 | 317 | restricted : bool 318 | If True, uses restricted masks, otherwise uses regular masks. 
319 | """ 320 | super(GatedConvBlockRGB, self).__init__() 321 | 322 | assert type(kernel_size) is int 323 | 324 | self.restricted = restricted 325 | self.out_channels = out_channels 326 | 327 | if restricted: 328 | vertical_mask = 'VR' 329 | horizontal_mask = 'HR' 330 | else: 331 | vertical_mask = 'V' 332 | horizontal_mask = 'H' 333 | 334 | self.vertical_conv = MaskedConv2d(vertical_mask, in_channels, 335 | 2 * out_channels, kernel_size, 336 | stride=stride, padding=padding) 337 | 338 | self.horizontal_conv = MaskedConvRGB(horizontal_mask, in_channels, 339 | 2 * out_channels, (1, kernel_size), 340 | stride=stride, 341 | padding=(0, padding), bias=True) 342 | 343 | self.vertical_to_horizontal = nn.Conv2d(2 * out_channels, 344 | 2 * out_channels, 1) 345 | 346 | self.horizontal_conv_2 = MaskedConvRGB('B', out_channels, out_channels, 347 | (1, 1), stride=1, padding=0, 348 | bias=True) 349 | 350 | def forward(self, v_input, h_input): 351 | # Vertical stack 352 | v_conv = self.vertical_conv(v_input) 353 | v_out = gated_activation(v_conv) 354 | # Vertical to horizontal 355 | v_to_h = self.vertical_to_horizontal(v_conv) 356 | # Horizontal stack 357 | h_conv = self.horizontal_conv(h_input) + v_to_h 358 | # Gated activation must be applied for the R, G and B part of the 359 | # convolutional volume separately to avoid information from different 360 | # channels leaking 361 | channels_third = 2 * self.out_channels // 3 362 | h_conv_activation_R = gated_activation(h_conv[:, :channels_third]) 363 | h_conv_activation_G = gated_activation(h_conv[:, channels_third:2 * channels_third]) 364 | h_conv_activation_B = gated_activation(h_conv[:, 2 * channels_third:]) 365 | h_conv_activation = torch.cat([h_conv_activation_R, 366 | h_conv_activation_G, 367 | h_conv_activation_B], 368 | dim=1) 369 | # 1 by 1 convolution on horizontal stack 370 | h_conv2 = self.horizontal_conv_2(h_conv_activation) 371 | if self.restricted: 372 | h_out = h_conv2 373 | else: 374 | h_out = h_conv2 + h_input 375 | return v_out, h_out 376 | 377 | 378 | def gated_activation(input_vol): 379 | """Applies a gated activation to the convolutional volume. Note that this 380 | activation divides the number of channels by 2. 381 | 382 | Parameters 383 | ---------- 384 | input_vol : torch.Tensor 385 | Input convolutional volume. Shape (batch_size, channels, height, width) 386 | Note that number of channels must be even. 
387 | 388 | Returns 389 | ------- 390 | output_vol of shape (batch_size, channels // 2, height, width) 391 | """ 392 | # Extract number of channels from input volume 393 | channels = input_vol.size(1) 394 | # Get activations for first and second half of volume 395 | tanh_activation = torch.tanh(input_vol[:, channels // 2:]) 396 | sigmoid_activation = torch.sigmoid(input_vol[:, :channels // 2]) 397 | return tanh_activation * sigmoid_activation 398 | -------------------------------------------------------------------------------- /pixconcnn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/pixconcnn/models/__init__.py -------------------------------------------------------------------------------- /pixconcnn/models/cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from pixconcnn.layers import ResidualBlock 3 | 4 | 5 | class ResNet(nn.Module): 6 | """ResNet (with regular unmasked convolutions) mapping conditional pixel 7 | inputs to logits. 8 | 9 | Parameters 10 | ---------- 11 | img_size : tuple of ints 12 | Shape of input image. Note that since mask is appended to the masked 13 | image, the img_size given should be (num_channels + 1, height, width). 14 | Note that the extra channel is used to store the mask. 15 | 16 | num_colors : int 17 | Number of colors to quantize output into. Typically 256, but can be 18 | lower for e.g. binary images. 19 | 20 | num_filters : int 21 | Number of filters for each convolution layer in model. 22 | 23 | depth : int 24 | Number of layers of model. Must be at least 2 to have an input and 25 | output layer. 26 | 27 | filter_size : int 28 | Size of convolutional filters.
29 | """ 30 | def __init__(self, img_size=(2, 32, 32), num_colors=256, num_filters=32, 31 | depth=17, filter_size=5): 32 | super(ResNet, self).__init__() 33 | 34 | self.depth = depth 35 | self.filter_size = filter_size 36 | self.padding = (filter_size - 1) // 2 37 | self.img_size = img_size 38 | self.num_channels = img_size[0] - 1 # Only output logits for color channels, not mask channel 39 | self.num_colors = num_colors 40 | self.num_filters = num_filters 41 | 42 | layers = [nn.Conv2d(self.num_channels + 1, self.num_filters, 43 | self.filter_size, stride=1, padding=self.padding)] 44 | 45 | for _ in range(self.depth - 2): 46 | layers.append( 47 | ResidualBlock(self.num_filters, self.num_filters, 48 | self.filter_size, stride=1, padding=self.padding) 49 | ) 50 | 51 | # Final layer to output logits 52 | layers.append( 53 | nn.Conv2d(self.num_filters, self.num_colors * self.num_channels, 1) 54 | ) 55 | 56 | self.img_to_pixel_logits = nn.Sequential(*layers) 57 | 58 | def forward(self, x): 59 | _, height, width = self.img_size 60 | # Shape (batch, output_channels, height, width) 61 | logits = self.img_to_pixel_logits(x) 62 | # Shape (batch, num_colors, channels, height, width) 63 | return logits.view(-1, self.num_colors, self.num_channels, height, width) 64 | -------------------------------------------------------------------------------- /pixconcnn/models/gated_pixelcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from abc import ABCMeta, abstractmethod 5 | from pixconcnn.layers import GatedConvBlock, GatedConvBlockRGB, MaskedConvRGB 6 | 7 | 8 | class PixelCNNBaseClass(nn.Module): 9 | """Abstract class defining PixelCNN sampling which is the same for both 10 | single channel and RGB models. 11 | """ 12 | __metaclass__ = ABCMeta 13 | @abstractmethod 14 | def forward(self): 15 | """Forward method is implemented in child class (GatedPixelCNN or 16 | GatedPixelCNNRGB).""" 17 | pass 18 | 19 | def sample(self, device, num_samples=16, temp=1., return_likelihood=False): 20 | """Generates samples from a GatedPixelCNN or GatedPixelCNNRGB. 21 | 22 | Parameters 23 | ---------- 24 | device : torch.device instance 25 | 26 | num_samples : int 27 | Number of samples to generate 28 | 29 | temp : float 30 | Temperature of softmax distribution. Temperatures larger than 1 31 | make the distribution more uniform while temperatures lower than 32 | 1 make the distribution more peaky. 33 | 34 | return_likelihood : bool 35 | If True returns the log likelihood of the samples according to the 36 | model. 37 | """ 38 | # Set model to evaluation mode 39 | self.eval() 40 | 41 | samples = torch.zeros(num_samples, *self.img_size) 42 | samples = samples.to(device) 43 | channels, height, width = self.img_size 44 | 45 | # Sample pixel intensities from a batch of probability distributions 46 | # for each pixel in each channel 47 | with torch.no_grad(): 48 | for i in range(height): 49 | for j in range(width): 50 | for k in range(channels): 51 | logits = self.forward(samples) 52 | probs = F.softmax(logits / temp, dim=1) 53 | # Note that probs has shape 54 | # (batch, num_colors, channels, height, width) 55 | pixel_val = torch.multinomial(probs[:, :, k, i, j], 1) 56 | # The pixel intensities will be given by 0, 1, 2, ..., so 57 | # normalize these to be in 0 - 1 range as this is what the 58 | # model expects. 
Note that pixel_val has shape (batch, 1) 59 | # so remove last dimension 60 | samples[:, k, i, j] = pixel_val[:, 0].float() / (self.num_colors - 1) 61 | 62 | # Reset model to train mode 63 | self.train() 64 | 65 | # Unnormalize pixels 66 | samples = (samples * (self.num_colors - 1)).long() 67 | 68 | if return_likelihood: 69 | return samples.cpu(), self.log_likelihood(device, samples).cpu() 70 | else: 71 | return samples.cpu() 72 | 73 | def log_likelihood(self, device, samples): 74 | """Calculates log likelihood of samples under model. 75 | 76 | Parameters 77 | ---------- 78 | device : torch.device instance 79 | 80 | samples : torch.Tensor 81 | Batch of images. Shape (batch_size, num_channels, width, height). 82 | Values should be integers in [0, self.prior_net.num_colors - 1]. 83 | """ 84 | # Set model to evaluation mode 85 | self.eval() 86 | 87 | num_samples, num_channels, height, width = samples.size() 88 | log_probs = torch.zeros(num_samples) 89 | log_probs = log_probs.to(device) 90 | 91 | # Normalize samples before passing through model 92 | norm_samples = samples.float() / (self.num_colors - 1) 93 | # Calculate pixel probs according to the model 94 | logits = self.forward(norm_samples) 95 | # Note that probs has shape 96 | # (batch, num_colors, channels, height, width) 97 | probs = F.softmax(logits, dim=1) 98 | 99 | # Calculate probability of each pixel 100 | for i in range(height): 101 | for j in range(width): 102 | for k in range(num_channels): 103 | # Get the batch of true values at pixel (k, i, j) 104 | true_vals = samples[:, k, i, j] 105 | # Get probability assigned by model to true pixel 106 | probs_pixel = probs[:, true_vals, k, i, j][:, 0] 107 | # Add log probs (1e-9 to avoid log(0)) 108 | log_probs += torch.log(probs_pixel + 1e-9) 109 | 110 | # Reset model to train mode 111 | self.train() 112 | 113 | return log_probs 114 | 115 | 116 | 117 | class GatedPixelCNN(PixelCNNBaseClass): 118 | """Gated PixelCNN model for single channel images. 119 | 120 | Parameters 121 | ---------- 122 | img_size : tuple of ints 123 | Shape of input image. E.g. (1, 32, 32) 124 | 125 | num_colors : int 126 | Number of colors to quantize output into. Typically 256, but can be 127 | lower for e.g. binary images. 128 | 129 | num_filters : int 130 | Number of filters for each convolution layer in model. 131 | 132 | depth : int 133 | Number of layers of model. Must be at least 2 to have an input and 134 | output layer. 135 | 136 | filter_size : int 137 | Size of convolutional filters. 
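
    Example (an illustrative sketch; the arguments are assumptions chosen to
    keep the model small, not defaults used elsewhere in this repository):

    >>> import torch
    >>> model = GatedPixelCNN(img_size=(1, 28, 28), num_colors=2, depth=8)
    >>> x = torch.rand(4, 1, 28, 28)  # images normalized to the 0 - 1 range
    >>> model(x).shape  # (batch, num_colors, channels, height, width)
    torch.Size([4, 2, 1, 28, 28])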
138 | """ 139 | def __init__(self, img_size=(1, 32, 32), num_colors=256, num_filters=64, 140 | depth=17, filter_size=5): 141 | super(GatedPixelCNN, self).__init__() 142 | 143 | self.depth = depth 144 | self.filter_size = filter_size 145 | self.padding = (filter_size - 1) // 2 146 | self.img_size = img_size 147 | self.num_channels = img_size[0] 148 | self.num_colors = num_colors 149 | self.num_filters = num_filters 150 | 151 | # First layer is restricted (to avoid self-dependency on input pixel) 152 | self.input_to_stacks = GatedConvBlock(self.num_channels, 153 | self.num_filters, 154 | self.filter_size, 155 | stride=1, 156 | padding=self.padding, 157 | restricted=True) 158 | 159 | # Subsequent layers are regular gated blocks for vertical and horizontal 160 | # stack 161 | gated_stacks = [] 162 | # -2 since we are not counting first and last layer 163 | for _ in range(self.depth - 2): 164 | gated_stacks.append( 165 | GatedConvBlock(self.num_filters, self.num_filters, 166 | self.filter_size, stride=1, padding=self.padding) 167 | ) 168 | self.gated_stacks = nn.ModuleList(gated_stacks) 169 | 170 | # Final layer to output logits 171 | self.stacks_to_pixel_logits = nn.Conv2d(self.num_filters, self.num_colors * self.num_channels, 1) 172 | 173 | def forward(self, x): 174 | # Restricted gated layer 175 | vertical, horizontal = self.input_to_stacks(x, x) 176 | # Iterate over gated layers 177 | for gated_block in self.gated_stacks: 178 | vertical, horizontal = gated_block(vertical, horizontal) 179 | # Output logits from the horizontal stack (i.e. the stack which receives 180 | # information from both horizontal and vertical stack) 181 | logits = self.stacks_to_pixel_logits(horizontal) 182 | 183 | # Reshape logits 184 | _, height, width = self.img_size 185 | # Shape (batch, output_channels, height, width) -> 186 | # (batch, num_colors, channels, height, width) 187 | return logits.view(-1, self.num_colors, self.num_channels, height, width) 188 | 189 | 190 | class GatedPixelCNNRGB(PixelCNNBaseClass): 191 | """Gated PixelCNN model for RGB images. 192 | 193 | Parameters 194 | ---------- 195 | img_size : tuple of ints 196 | Shape of input image. E.g. (3, 32, 32) 197 | 198 | num_colors : int 199 | Number of colors to quantize output into. Typically 256, but can be 200 | lower for e.g. binary images. 201 | 202 | num_filters : int 203 | Number of filters for each convolution layer in model. 204 | 205 | depth : int 206 | Number of layers of model. Must be at least 2 to have an input and 207 | output layer. 208 | 209 | filter_size : int 210 | Size of convolutional filters. 
211 | """ 212 | def __init__(self, img_size=(3, 32, 32), num_colors=256, num_filters=64, 213 | depth=17, filter_size=5): 214 | super(GatedPixelCNNRGB, self).__init__() 215 | 216 | self.depth = depth 217 | self.filter_size = filter_size 218 | self.padding = (filter_size - 1) / 2 219 | self.img_size = img_size 220 | self.num_channels = img_size[0] 221 | self.num_colors = num_colors 222 | self.num_filters = num_filters 223 | 224 | # First layer is restricted (to avoid self-dependency on input pixel) 225 | self.input_to_stacks = GatedConvBlockRGB(self.num_channels, 226 | self.num_filters, 227 | self.filter_size, 228 | stride=1, 229 | padding=self.padding, 230 | restricted=True) 231 | 232 | # Subsequent layers are regular gated blocks for vertical and horizontal 233 | # stack 234 | gated_stacks = [] 235 | # -2 since we are not counting first and last layer 236 | for _ in range(self.depth - 2): 237 | gated_stacks.append( 238 | GatedConvBlockRGB(self.num_filters, self.num_filters, 239 | self.filter_size, stride=1, 240 | padding=self.padding) 241 | ) 242 | self.gated_stacks = nn.ModuleList(gated_stacks) 243 | 244 | # Final layer to output logits 245 | self.stacks_to_pixel_logits = nn.Sequential( 246 | MaskedConvRGB('B', self.num_filters, 1023, (1, 1), stride=1, padding=0, bias=True), 247 | nn.ReLU(True), 248 | MaskedConvRGB('B', 1023, self.num_colors * self.num_channels, (1, 1), stride=1, padding=0, bias=True) 249 | ) 250 | 251 | def forward(self, x): 252 | # Restricted gated layer 253 | vertical, horizontal = self.input_to_stacks(x, x) 254 | # Several gated layers 255 | for gated_block in self.gated_stacks: 256 | vertical, horizontal = gated_block(vertical, horizontal) 257 | # Output logits from the horizontal stack (i.e. the stack which receives 258 | # information from both horizontal and vertical stack) 259 | logits = self.stacks_to_pixel_logits(horizontal) 260 | 261 | # Reshape logits maintaining order between R, G and B channels 262 | _, height, width = self.img_size 263 | logits_red = logits[:, :self.num_colors] 264 | logits_green = logits[:, self.num_colors:2 * self.num_colors] 265 | logits_blue = logits[:, 2 * self.num_colors:] 266 | logits_red = logits_red.view(-1, self.num_colors, 1, height, width) 267 | logits_green = logits_green.view(-1, self.num_colors, 1, height, width) 268 | logits_blue = logits_blue.view(-1, self.num_colors, 1, height, width) 269 | # Shape (batch, num_colors, channels, height, width) 270 | return torch.cat([logits_red, logits_green, logits_blue], dim=2) 271 | -------------------------------------------------------------------------------- /pixconcnn/models/pixel_constrained.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PixelConstrained(nn.Module): 7 | """Pixel Constrained CNN model. 8 | 9 | Parameters 10 | ---------- 11 | prior_net : pixconcnn.models.gated_pixelcnn.GatedPixelCNN(RGB) instance 12 | Model defining the prior network. 13 | 14 | cond_net : pixconcnn.models.cnn.ResNet instance 15 | Model defining the conditioning network. 16 | """ 17 | def __init__(self, prior_net, cond_net): 18 | 19 | super(PixelConstrained, self).__init__() 20 | 21 | self.prior_net = prior_net 22 | self.cond_net = cond_net 23 | 24 | def forward(self, x, x_cond): 25 | """ 26 | x : torch.Tensor 27 | Image to predict logits for. 28 | 29 | x_cond : torch.Tensor 30 | Image containing pixels to be conditioned on. 
The pixels which are 31 | being conditioned on should retain their usual value, while all 32 | other pixels should be set to 0. In addition the mask should be 33 | appended to known pixels such that the shape of x_cond is 34 | (batch_size, channels + 1, height, width). For more details, see 35 | utils.masks.get_conditional_pixels function. 36 | """ 37 | prior_logits = self.prior_net(x) 38 | cond_logits = self.cond_net(x_cond) 39 | logits = prior_logits + cond_logits 40 | return logits, prior_logits, cond_logits 41 | 42 | def sample(self, x_cond, temp=1., return_likelihood=False): 43 | """Generate conditional samples from the model. The number of samples 44 | generated will be equal to the batch size of x_cond. 45 | 46 | Parameters 47 | ---------- 48 | x_cond : torch.Tensor 49 | Tensor containing the conditioning pixels. Should have shape 50 | (num_samples, channels + 1, height, width). 51 | 52 | temp : float 53 | Temperature of softmax distribution. Temperatures larger than 1 54 | make the distribution more uniforman while temperatures lower than 55 | 1 make the distribution more peaky. 56 | 57 | return_likelihood : bool 58 | If True returns the log likelihood of the samples according to the 59 | model. 60 | """ 61 | # Set model to evaluation mode 62 | self.eval() 63 | 64 | # "Channel dimension" of x_cond has size channels + 1, so decrease this 65 | num_samples, channels_plus_mask, height, width = x_cond.size() 66 | channels = channels_plus_mask - 1 67 | 68 | # Samples to be generated 69 | samples = torch.zeros((num_samples, channels, height, width)) 70 | # Move samples to same device as conditional tensor 71 | samples = samples.to(x_cond.device) 72 | num_colors = self.prior_net.num_colors 73 | 74 | # Sample pixel intensities from a batch of probability distributions 75 | # for each pixel in each channel 76 | with torch.no_grad(): 77 | for i in range(height): 78 | for j in range(width): 79 | # The unmasked pixels are the ones where the mask has 80 | # nonzero value 81 | unmasked = x_cond[:, -1, i, j] > 0 82 | # If (i, j)th pixel is known for all images in batch (i.e. 83 | # if all values in unmasked are True), do not perform 84 | # forward pass of the model 85 | sample_pixel = True 86 | if unmasked.long().sum().item() == num_samples: 87 | sample_pixel = False 88 | for k in range(channels): 89 | if sample_pixel: 90 | logits, _, _ = self.forward(samples, x_cond) 91 | probs = F.softmax(logits / temp, dim=1) 92 | # Note that probs has shape 93 | # (batch, num_colors, channels, height, width) 94 | pixel_val = torch.multinomial(probs[:, :, k, i, j], 1) 95 | # The pixel intensities will be given by 0, 1, 2, ..., so 96 | # normalize these to be in 0 - 1 range as this is what the 97 | # model expects 98 | samples[:, k, i, j] = pixel_val[:, 0].float() / (num_colors - 1) 99 | # Set all the unmasked pixels to the value they are 100 | # conditioned on 101 | samples[:, k, i, j][unmasked] = x_cond[:, k, i, j][unmasked] 102 | 103 | # Reset model to train mode 104 | self.train() 105 | 106 | # Unnormalize pixels 107 | samples = (samples * (num_colors - 1)).long() 108 | 109 | if return_likelihood: 110 | return samples.cpu(), self.log_likelihood(samples, x_cond).cpu() 111 | else: 112 | return samples.cpu() 113 | 114 | def sample_unconditional(self, device, num_samples=16): 115 | """Samples from prior model without conditioning.""" 116 | return self.prior_net.sample(device, num_samples) 117 | 118 | def log_likelihood(self, samples, x_cond): 119 | """Calculates log likelihood of samples under model. 
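
        In effect, because of the masked convolutions a single forward pass on
        the finished sample yields the autoregressive conditional over colors
        at every position, so the quantity computed below is (up to the 1e-9
        stabilizer)

            log p(x | x_cond) = sum over pixels (i, j) and channels k of
                                log p(x[k, i, j] | preceding pixels, x_cond),

        where factors for pixels that are conditioned on (mask = 1) are set to
        1 and therefore contribute zero.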
120 | 121 | Parameters 122 | ---------- 123 | samples : torch.Tensor 124 | Batch of images. Shape (batch_size, num_channels, width, height). 125 | Values should be integers in [0, self.prior_net.num_colors - 1]. 126 | 127 | x_cond : torch.Tensor 128 | Batch of conditional pixels. 129 | Shape (batch_size, num_channels + 1, width, height). 130 | """ 131 | # Set model to evaluation mode 132 | self.eval() 133 | 134 | num_samples, num_channels, height, width = samples.size() 135 | log_probs = torch.zeros(num_samples) 136 | log_probs = log_probs.to(x_cond.device) 137 | 138 | # Normalize samples before passing through model 139 | norm_samples = samples.float() / (self.prior_net.num_colors - 1) 140 | # Calculate pixel probs according to the model 141 | logits, _, _ = self.forward(norm_samples, x_cond) 142 | # Note that probs has shape 143 | # (batch, num_colors, channels, height, width) 144 | probs = F.softmax(logits, dim=1) 145 | 146 | # Calculate probability of each pixel 147 | for i in range(height): 148 | for j in range(width): 149 | # The unmasked pixels are the ones where the mask is nonzero 150 | unmasked = x_cond[:, -1, i, j] > 0 151 | for k in range(num_channels): 152 | # Get the batch of true values at pixel (k, i, j) 153 | true_vals = samples[:, k, i, j] 154 | # Get probability assigned by model to true pixel 155 | probs_pixel = probs[:, true_vals, k, i, j][:, 0] 156 | # Conditional pixels are known, so set probs of these to 1 157 | probs_pixel[unmasked] = 1. 158 | # Add log probs (1e-9 to avoid log(0)) 159 | log_probs += torch.log(probs_pixel + 1e-9) 160 | 161 | # Reset model to train mode 162 | self.train() 163 | 164 | return log_probs 165 | -------------------------------------------------------------------------------- /pixconcnn/training.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.masks import get_conditional_pixels 6 | from torchvision.utils import make_grid 7 | 8 | 9 | class Trainer(): 10 | """Class used to train PixelCNN models without conditioning. 11 | 12 | Parameters 13 | ---------- 14 | model : pixconcnn.models.gated_pixelcnn.GatedPixelCNN(RGB) instance 15 | 16 | optimizer : one of optimizers in torch.optim 17 | 18 | device : torch.device instance 19 | 20 | record_loss_every : int 21 | Frequency (in iterations) with which to record loss. 22 | 23 | save_model_every : int 24 | Frequency (in epochs) with which to save model. 25 | """ 26 | def __init__(self, model, optimizer, device, record_loss_every=10, 27 | save_model_every=5): 28 | self.device = device 29 | self.losses = {'total': []} 30 | self.mean_epoch_losses = [] 31 | self.model = model 32 | self.optimizer = optimizer 33 | self.record_loss_every = record_loss_every 34 | self.save_model_every = save_model_every 35 | self.steps = 0 36 | 37 | def train(self, data_loader, epochs, directory='.'): 38 | """Trains model on the data given in data_loader. 39 | 40 | Parameters 41 | ---------- 42 | data_loader : torch.utils.data.DataLoader instance 43 | 44 | epochs : int 45 | Number of epochs to train model for. 46 | 47 | directory : string 48 | Directory in which to store training progress, including trained 49 | models and samples generated at every epoch. 50 | 51 | Returns 52 | ------- 53 | List of numpy arrays of generated images after each epoch. 
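
        Example (an illustrative sketch; the model, optimizer, device and
        data loader are assumed to already exist and are not defined here):

        >>> trainer = Trainer(model, torch.optim.Adam(model.parameters()), device)
        >>> progress_imgs = trainer.train(train_loader, epochs=10, directory='./run')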
54 | """ 55 | # List of generated images after each epoch to track progress of model 56 | progress_imgs = [] 57 | 58 | for epoch in range(epochs): 59 | print("\nEpoch {}/{}".format(epoch + 1, epochs)) 60 | epoch_loss = self._train_epoch(data_loader) 61 | mean_epoch_loss = epoch_loss / len(data_loader) 62 | print("Epoch loss: {}".format(mean_epoch_loss)) 63 | self.mean_epoch_losses.append(mean_epoch_loss) 64 | 65 | # Create a grid of model samples (limit number of samples by scaling 66 | # by number of pixels in output image; this is needed because of 67 | # GPU memory limitations) 68 | if self.model.img_size[-1] > 32: 69 | scale_to_32 = self.model.img_size[-1] / 32 70 | num_images = 64 / (scale_to_32 * scale_to_32) 71 | else: 72 | num_images = 64 73 | # Generate samples from model 74 | samples = self.model.sample(self.device, num_images) 75 | img_grid = make_grid(samples).cpu() 76 | # Convert to numpy with channels in imageio order 77 | img_grid = img_grid.float().numpy().transpose(1, 2, 0) / (self.model.num_colors - 1.) 78 | progress_imgs.append(img_grid) 79 | # Save generated image 80 | imageio.imsave(directory + '/training{}.png'.format(epoch), progress_imgs[-1]) 81 | # Save model 82 | if epoch % self.save_model_every == 0: 83 | torch.save(self.model.state_dict(), 84 | directory + '/model{}.pt'.format(epoch)) 85 | 86 | return progress_imgs 87 | 88 | def _train_epoch(self, data_loader): 89 | epoch_loss = 0 90 | for i, (batch, _) in enumerate(data_loader): 91 | batch_loss = self._train_iteration(batch) 92 | epoch_loss += batch_loss 93 | if i % 50 == 0: 94 | print("Iteration {}/{}, Loss: {}".format(i + 1, 95 | len(data_loader), 96 | batch_loss)) 97 | return epoch_loss 98 | 99 | def _train_iteration(self, batch): 100 | self.optimizer.zero_grad() 101 | 102 | batch = batch.to(self.device) 103 | 104 | # Normalize batch, i.e. put it in 0 - 1 range before passing it through 105 | # the model 106 | norm_batch = batch.float() / (self.model.num_colors - 1) 107 | logits = self.model(norm_batch) 108 | 109 | loss = self._loss(logits, batch) 110 | loss.backward() 111 | self.optimizer.step() 112 | 113 | self.steps += 1 114 | 115 | return loss.item() 116 | 117 | def _loss(self, logits, batch): 118 | loss = F.cross_entropy(logits, batch) 119 | 120 | if self.steps % self.record_loss_every == 0: 121 | self.losses['total'].append(loss.item()) 122 | 123 | return loss 124 | 125 | 126 | class PixelConstrainedTrainer(): 127 | """Class used to train Pixel Constrained CNN models. 128 | 129 | Parameters 130 | ---------- 131 | model : pixconcnn.models.pixel_constrained.PixelConstrained instance 132 | 133 | optimizer : one of optimizers in torch.optim 134 | 135 | device : torch.device instance 136 | 137 | mask_generator : pixconcnn.utils.masks.MaskGenerator instance 138 | Defines the masks used during training. 139 | 140 | weight_cond_logits_loss : float 141 | Weight on conditional logits in the loss (called alpha in the paper) 142 | 143 | weight_cond_logits_loss : float 144 | Weight on prio logits in the loss. 145 | 146 | record_loss_every : int 147 | Frequency (in iterations) with which to record loss. 148 | 149 | save_model_every : int 150 | Frequency (in epochs) with which to save model. 
151 | """ 152 | def __init__(self, model, optimizer, device, mask_generator, 153 | weight_cond_logits_loss=0., weight_prior_logits_loss=0., 154 | record_loss_every=10, save_model_every=5): 155 | self.device = device 156 | self.losses = {'cond_logits': [], 'prior_logits': [], 'logits': [], 'total': []} # Keep track of losses 157 | self.mask_generator = mask_generator 158 | self.mean_epoch_losses = [] 159 | self.model = model 160 | self.optimizer = optimizer 161 | self.record_loss_every = record_loss_every 162 | self.save_model_every = save_model_every 163 | self.steps = 0 164 | self.weight_cond_logits_loss = weight_cond_logits_loss 165 | self.weight_prior_logits_loss = weight_prior_logits_loss 166 | 167 | def train(self, data_loader, epochs, directory='.'): 168 | """ 169 | Parameters 170 | ---------- 171 | data_loader : torch.utils.data.DataLoader instance 172 | 173 | epochs : int 174 | Number of epochs to train model for. 175 | 176 | directory : string 177 | Directory in which to store training progress, including trained 178 | models and samples generated at every epoch. 179 | 180 | Returns 181 | ------- 182 | List of numpy arrays of generated images after each epoch. 183 | """ 184 | # List of generated images after each epoch to track progress of model 185 | progress_imgs = [] 186 | 187 | # Use a fixed batch of images to test conditioning throughout training 188 | for batch, _ in data_loader: 189 | break 190 | test_mask = self.mask_generator.get_masks(batch.size(0)) 191 | cond_pixels = get_conditional_pixels(batch, test_mask, 192 | self.model.prior_net.num_colors) 193 | # Number of images generated in a batch is limited by GPU memory. 194 | # For 32 by 32 this should be 64, for 64 by 64 this should be 16 and 195 | # for 128 by 128 this should be 4 etc. 196 | if self.model.prior_net.img_size[-1] > 32: 197 | scale_to_32 = self.model.prior_net.img_size[-1] / 32 198 | num_images = 64 / (scale_to_32 * scale_to_32) 199 | else: 200 | num_images = 64 201 | cond_pixels = cond_pixels[:num_images] 202 | 203 | cond_pixels = cond_pixels.to(self.device) 204 | 205 | for epoch in range(epochs): 206 | print("\nEpoch {}/{}".format(epoch + 1, epochs)) 207 | epoch_loss = self._train_epoch(data_loader) 208 | mean_epoch_loss = epoch_loss / len(data_loader) 209 | print("Epoch loss: {}".format(mean_epoch_loss)) 210 | self.mean_epoch_losses.append(mean_epoch_loss) 211 | 212 | # Create a grid of model samples 213 | samples = self.model.sample(cond_pixels) 214 | img_grid = make_grid(samples, nrow=8).cpu() 215 | # Convert to numpy with channels in imageio order 216 | img_grid = img_grid.float().numpy().transpose(1, 2, 0) / (self.model.prior_net.num_colors - 1.) 
217 | progress_imgs.append(img_grid) 218 | 219 | # Save generated image 220 | imageio.imsave(directory + '/training{}.png'.format(epoch), progress_imgs[-1]) 221 | 222 | # Save model 223 | if epoch % self.save_model_every == 0: 224 | torch.save(self.model.state_dict(), 225 | directory + '/model{}.pt'.format(epoch)) 226 | 227 | return progress_imgs 228 | 229 | def _train_epoch(self, data_loader): 230 | epoch_loss = 0 231 | for i, (batch, _) in enumerate(data_loader): 232 | mask = self.mask_generator.get_masks(batch.size(0)) 233 | batch_loss = self._train_iteration(batch, mask) 234 | epoch_loss += batch_loss 235 | if i % 50 == 0: 236 | print("Iteration {}/{}, Loss: {}".format(i + 1, len(data_loader), 237 | batch_loss)) 238 | return epoch_loss 239 | 240 | def _train_iteration(self, batch, mask): 241 | self.optimizer.zero_grad() 242 | 243 | # Note that towards the end of the dataset the batch size may be smaller 244 | # if batch_size doesn't divide the number of examples. In that case, 245 | # slice mask so it has same shape as batch. 246 | cond_pixels = get_conditional_pixels(batch, mask[:batch.size(0)], self.model.prior_net.num_colors) 247 | 248 | batch = batch.to(self.device) 249 | cond_pixels = cond_pixels.to(self.device) 250 | 251 | # Normalize batch, i.e. put it in 0 - 1 range before passing it through 252 | # the model 253 | norm_batch = batch.float() / (self.model.prior_net.num_colors - 1) 254 | logits, prior_logits, cond_logits = self.model(norm_batch, cond_pixels) 255 | 256 | loss = self._loss(logits, prior_logits, cond_logits, batch) 257 | loss.backward() 258 | self.optimizer.step() 259 | 260 | self.steps += 1 261 | 262 | return loss.item() 263 | 264 | def _loss(self, logits, prior_logits, cond_logits, batch): 265 | logits_loss = F.cross_entropy(logits, batch) 266 | prior_logits_loss = F.cross_entropy(prior_logits, batch) 267 | cond_logits_loss = F.cross_entropy(cond_logits, batch) 268 | total_loss = logits_loss + \ 269 | self.weight_cond_logits_loss * cond_logits_loss + \ 270 | self.weight_prior_logits_loss * prior_logits_loss 271 | 272 | # Record losses 273 | if self.steps % self.record_loss_every == 0: 274 | self.losses['total'].append(total_loss.item()) 275 | self.losses['cond_logits'].append(cond_logits_loss.item()) 276 | self.losses['prior_logits'].append(prior_logits_loss.item()) 277 | self.losses['logits'].append(logits_loss.item()) 278 | 279 | return total_loss 280 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | imageio==2.4.1 2 | matplotlib==2.2.3 3 | numpy==1.11.2 4 | Pillow==6.2.0 5 | torch==0.4.1 6 | torchvision==0.2.1 7 | -------------------------------------------------------------------------------- /trained_models/celeba/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "0", 3 | "dataset": "celeba", 4 | "resize": 32, 5 | "crop": 89, 6 | "grayscale": false, 7 | "batch_size": 64, 8 | "constrained": true, 9 | "num_colors": 32, 10 | "filter_size": 5, 11 | "depth": 18, 12 | "num_filters_cond": 66, 13 | "num_filters_prior": 66, 14 | "lr": 4e-4, 15 | "epochs": 80, 16 | "mask_descriptor": ["random_blob", [4, [2, 7], 0.5]], 17 | "num_conds": 4, 18 | "num_samples": 64, 19 | "weight_cond_logits_loss": 1.0, 20 | "weight_prior_logits_loss": 0.0 21 | } 22 | -------------------------------------------------------------------------------- /trained_models/celeba/model.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/trained_models/celeba/model.pt -------------------------------------------------------------------------------- /trained_models/mnist/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "0", 3 | "dataset": "mnist", 4 | "resize": 28, 5 | "crop": 28, 6 | "grayscale": true, 7 | "batch_size": 64, 8 | "constrained": true, 9 | "num_colors": 2, 10 | "filter_size": 5, 11 | "depth": 16, 12 | "num_filters_cond": 32, 13 | "num_filters_prior": 32, 14 | "lr": 4e-4, 15 | "epochs": 50, 16 | "mask_descriptor": ["random_blob", [4, [2, 7], 0.5]], 17 | "num_conds": 4, 18 | "num_samples": 64, 19 | "weight_cond_logits_loss": 1.0, 20 | "weight_prior_logits_loss": 0.0 21 | } 22 | -------------------------------------------------------------------------------- /trained_models/mnist/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/trained_models/mnist/model.pt -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/utils/__init__.py -------------------------------------------------------------------------------- /utils/dataloaders.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from PIL import Image 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import datasets, transforms 5 | 6 | 7 | def mnist(batch_size=128, num_colors=256, size=28, 8 | path_to_data='../mnist_data'): 9 | """MNIST dataloader with (28, 28) images. 10 | 11 | Parameters 12 | ---------- 13 | batch_size : int 14 | 15 | num_colors : int 16 | Number of colors to quantize images into. Typically 256, but can be 17 | lower for e.g. binary images. 18 | 19 | size : int 20 | Size (height and width) of each image. Default is 28 for no resizing. 21 | 22 | path_to_data : string 23 | Path to MNIST data files. 24 | """ 25 | quantize = get_quantize_func(num_colors) 26 | 27 | all_transforms = transforms.Compose([ 28 | transforms.Resize(size), 29 | transforms.ToTensor(), 30 | transforms.Lambda(lambda x: quantize(x)) 31 | ]) 32 | 33 | train_data = datasets.MNIST(path_to_data, train=True, download=True, 34 | transform=all_transforms) 35 | test_data = datasets.MNIST(path_to_data, train=False, 36 | transform=all_transforms) 37 | 38 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True) 39 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True) 40 | 41 | return train_loader, test_loader 42 | 43 | 44 | def celeba(batch_size=128, num_colors=256, size=178, crop=178, grayscale=False, 45 | shuffle=True, path_to_data='../celeba_data'): 46 | """CelebA dataloader with square images. Note original CelebA images have 47 | shape (218, 178), this dataloader center crops these images to be (178, 178) 48 | by default. 49 | 50 | Parameters 51 | ---------- 52 | batch_size : int 53 | 54 | num_colors : int 55 | Number of colors to quantize images into. Typically 256, but can be 56 | lower for e.g. binary images. 
57 | 58 | size : int 59 | Size (height and width) of each image. 60 | 61 | crop : int 62 | Size of center crop. This crop happens *before* the resizing. 63 | 64 | grayscale : bool 65 | If True converts images to grayscale. 66 | 67 | shuffle : bool 68 | If True shuffles images. 69 | 70 | path_to_data : string 71 | Path to CelebA image files. 72 | """ 73 | quantize = get_quantize_func(num_colors) 74 | 75 | if grayscale: 76 | transform = transforms.Compose([ 77 | transforms.CenterCrop(crop), 78 | transforms.Resize(size), 79 | transforms.Grayscale(), 80 | transforms.ToTensor(), 81 | transforms.Lambda(lambda x: quantize(x)) 82 | ]) 83 | else: 84 | transform = transforms.Compose([ 85 | transforms.CenterCrop(crop), 86 | transforms.Resize(size), 87 | transforms.ToTensor(), 88 | transforms.Lambda(lambda x: quantize(x)) 89 | ]) 90 | celeba_data = CelebADataset(path_to_data, 91 | transform=transform) 92 | celeba_loader = DataLoader(celeba_data, batch_size=batch_size, 93 | shuffle=shuffle) 94 | return celeba_loader 95 | 96 | 97 | class CelebADataset(Dataset): 98 | """CelebA dataset. 99 | 100 | Parameters 101 | ---------- 102 | path_to_data : string 103 | Path to CelebA images. 104 | 105 | subsample : int 106 | Only load every |subsample| number of images. 107 | 108 | transform : None or one of torchvision.transforms instances 109 | """ 110 | def __init__(self, path_to_data, subsample=1, transform=None): 111 | self.img_paths = glob.glob(path_to_data + '/*')[::subsample] 112 | self.transform = transform 113 | 114 | def __len__(self): 115 | return len(self.img_paths) 116 | 117 | def __getitem__(self, idx): 118 | sample_path = self.img_paths[idx] 119 | sample = Image.open(sample_path) 120 | 121 | if self.transform: 122 | sample = self.transform(sample) 123 | # Since there are no labels, return 0 for the "label" 124 | return sample, 0 125 | 126 | 127 | def get_quantize_func(num_colors): 128 | """Returns a quantization function which can be used to set the number of 129 | colors in an image. 130 | 131 | Parameters 132 | ---------- 133 | num_colors : int 134 | Number of bins to quantize image into. Should be between 2 and 256. 135 | """ 136 | def quantize_func(batch): 137 | """Takes as input a float tensor with values in the 0 - 1 range and 138 | outputs a long tensor with integer values corresponding to each 139 | quantization bin. 140 | 141 | Parameters 142 | ---------- 143 | batch : torch.Tensor 144 | Values in 0 - 1 range. 145 | """ 146 | if num_colors == 2: 147 | return (batch > 0.5).long() 148 | else: 149 | return (batch * (num_colors - 1)).long() 150 | 151 | return quantize_func 152 | -------------------------------------------------------------------------------- /utils/init_models.py: -------------------------------------------------------------------------------- 1 | from pixconcnn.models.cnn import ResNet 2 | from pixconcnn.models.gated_pixelcnn import GatedPixelCNN, GatedPixelCNNRGB 3 | from pixconcnn.models.pixel_constrained import PixelConstrained 4 | 5 | 6 | def initialize_model(img_size, num_colors, depth, filter_size, constrained, 7 | num_filters_prior, num_filters_cond=1): 8 | """Helper function that initializes an appropriate model based on the 9 | input arguments. 10 | 11 | Parameters 12 | ---------- 13 | img_size : tuple of ints 14 | Specifies size of image as (channels, height, width), e.g. (3, 32, 32). 15 | If img_size[0] == 1, returned model will be for grayscale images. If 16 | img_size[0] == 3, returned model will be for RGB images. 
17 | 18 | num_colors : int 19 | Number of colors to quantize output into. Typically 256, but can be 20 | lower for e.g. binary images. 21 | 22 | depth : int 23 | Number of layers in model. 24 | 25 | filter_size : int 26 | Size of (square) convolutional filters of the model. 27 | 28 | constrained : bool 29 | If True returns a PixelConstrained model, otherwise returns a 30 | GatedPixelCNN or GatedPixelCNNRGB model. 31 | 32 | num_filters_prior : int 33 | Number of convolutional filters in each layer of prior network. 34 | 35 | num_filter_cond : int (optional) 36 | Required if using a PixelConstrained model. Number of of convolutional 37 | filters in each layer of conditioning network. 38 | """ 39 | 40 | if img_size[0] == 1: 41 | prior_net = GatedPixelCNN(img_size=img_size, 42 | num_colors=num_colors, 43 | num_filters=num_filters_prior, 44 | depth=depth, 45 | filter_size=filter_size) 46 | else: 47 | prior_net = GatedPixelCNNRGB(img_size=img_size, 48 | num_colors=num_colors, 49 | num_filters=num_filters_prior, 50 | depth=depth, 51 | filter_size=filter_size) 52 | 53 | if constrained: 54 | # Add extra color channel for mask in conditioning network 55 | cond_net = ResNet(img_size=(img_size[0] + 1,) + img_size[1:], 56 | num_colors=num_colors, 57 | num_filters=num_filters_cond, 58 | depth=depth, 59 | filter_size=filter_size) 60 | # Define a pixel constrained model based on prior and cond net 61 | return PixelConstrained(prior_net, cond_net) 62 | else: 63 | return prior_net 64 | -------------------------------------------------------------------------------- /utils/loading.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from utils.dataloaders import mnist, celeba 4 | from utils.init_models import initialize_model 5 | 6 | def load_model(directory, model_version=None): 7 | """ 8 | Returns model, data_loader and mask_descriptor of trained model. 9 | 10 | Parameters 11 | ---------- 12 | directory : string 13 | Directory where experiment was saved. For example './experiment_1'. 14 | 15 | model_version : int or None 16 | If None loads final model, otherwise loads model version determined by 17 | int. 
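
    Example (assumes the pretrained MNIST checkpoint shipped with this
    repository and that the MNIST data can be downloaded):

    >>> model, data_loader, mask_descriptor = load_model('trained_models/mnist')
    >>> mask_descriptor
    ['random_blob', [4, [2, 7], 0.5]]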
18 | """ 19 | path_to_config = directory + '/config.json' 20 | if model_version is None: 21 | path_to_model = directory + '/model.pt' 22 | else: 23 | path_to_model = directory + '/model{}.pt'.format(model_version) 24 | 25 | # Open config file 26 | with open(path_to_config) as config_file: 27 | config = json.load(config_file) 28 | 29 | # Load dataset info 30 | dataset = config["dataset"] 31 | resize = config["resize"] 32 | crop = config["crop"] 33 | batch_size = config["batch_size"] 34 | num_colors = config["num_colors"] 35 | if "grayscale" in config: 36 | grayscale = config["grayscale"] 37 | else: 38 | grayscale = False 39 | 40 | # Get data 41 | if dataset == 'mnist': 42 | # Extract the test dataset (second argument) 43 | _, data_loader = mnist(batch_size, num_colors, resize) 44 | img_size = (1, resize, resize) 45 | elif dataset == 'celeba': 46 | data_loader = celeba(batch_size, num_colors, resize, crop, grayscale) 47 | if grayscale: 48 | img_size = (1, resize, resize) 49 | else: 50 | img_size = (3, resize, resize) 51 | 52 | # Load model info 53 | constrained = config["constrained"] 54 | depth = config["depth"] 55 | num_filters_cond = config["num_filters_cond"] 56 | num_filters_prior = config["num_filters_prior"] 57 | filter_size = config["filter_size"] 58 | 59 | model = initialize_model(img_size, 60 | num_colors, 61 | depth, 62 | filter_size, 63 | constrained, 64 | num_filters_prior, 65 | num_filters_cond) 66 | 67 | model.load_state_dict(torch.load(path_to_model, map_location=lambda storage, loc: storage)) 68 | 69 | return model, data_loader, config["mask_descriptor"] 70 | -------------------------------------------------------------------------------- /utils/masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.utils.data import DataLoader, Dataset 4 | from torchvision import datasets, transforms 5 | 6 | 7 | class MaskGenerator(): 8 | """Class used to generate masks. Can be used to create masks during 9 | training or to build various masks for generation. 10 | 11 | Parameters 12 | ---------- 13 | img_size : tuple of ints 14 | E.g. (1, 28, 28) or (3, 64, 64) 15 | 16 | mask_descriptor : tuple of string and other 17 | Mask descriptors will be of the form (mask_type, mask_attribute). 18 | Allowed descriptors are: 19 | 1. ('random', None or int or tuple of ints): Generates random masks, 20 | where the position of visible pixels is selected uniformly 21 | at random over the image. If mask_attribute is None then the 22 | number of visible pixels is sampled uniformly between 1 and the 23 | total number of pixels in the image, otherwise it is fixed to 24 | the int given in mask_attribute. If mask_attribute is a tuple 25 | of ints, the number of visible pixels is sampled uniformly 26 | between the first int (lower bound) and the second int (upper 27 | bound). 28 | 2. ('bottom', int): Generates masks where only the bottom pixels are 29 | visible. The int determines the number of rows of the image to 30 | keep visible at the bottom. 31 | 3. ('top', int): Generates masks where only the top pixels are 32 | visible. The int determines the number of rows of the image to 33 | keep visible at the top. 34 | 4. ('center', int): Generates masks where only the central pixels 35 | are visible. The int determines the size in pixels of the sides 36 | of the square of visible pixels of the image. 37 | 5. ('edge', int): Generates masks where only the edge pixels of the 38 | image are visible. 
The int determines the thickness of the edges 39 | in pixels. 40 | 6. ('left', int): Generates masks where only the left pixels of the 41 | image are visible. The int determines the number of columns 42 | in pixels which are visible. 43 | 7. ('right', int): Generates masks where only the right pixels of 44 | the image are visible. The int determines the number of columns 45 | in pixels which are visible. 46 | 8. ('random_rect', (int, int)): Generates random rectangular masks 47 | where the maximum height and width of the rectangles are 48 | determined by the two ints. 49 | 9. ('random_blob', (int, (int, int), float)): Generates random 50 | blobs, where the number of blobs is determined by the first int, 51 | the range of iterations (see function definition) is determined 52 | by the tuple of ints and the threshold for making pixels visible 53 | is determined by the float. 54 | 10. ('random_blob_cache', (str, int)): Loads pregenerated random masks 55 | from a folder given by the string, using a batch_size given by 56 | the int. 57 | """ 58 | def __init__(self, img_size, mask_descriptor): 59 | self.img_size = img_size 60 | self.num_pixels = img_size[1] * img_size[2] 61 | self.mask_type, self.mask_attribute = mask_descriptor 62 | 63 | if self.mask_type == 'random_blob_cache': 64 | dset = datasets.ImageFolder(self.mask_attribute[0], 65 | transform=transforms.Compose([transforms.Grayscale(), 66 | transforms.ToTensor()])) 67 | self.data_loader = DataLoader(dset, batch_size=self.mask_attribute[1], shuffle=True) 68 | 69 | def get_masks(self, batch_size): 70 | """Returns a tensor of shape (batch_size, 1, img_size[1], img_size[2]) 71 | containing masks which were generated according to mask_type and 72 | mask_attribute. 73 | 74 | Parameters 75 | ---------- 76 | batch_size : int 77 | """ 78 | if self.mask_type == 'random': 79 | if self.mask_attribute is None: 80 | num_visibles = np.random.randint(1, self.num_pixels, size=batch_size) 81 | return batch_random_mask(self.img_size, num_visibles, batch_size) 82 | elif type(self.mask_attribute) == int: 83 | return batch_random_mask(self.img_size, self.mask_attribute, batch_size) 84 | else: 85 | lower_bound, upper_bound = self.mask_attribute 86 | num_visibles = np.random.randint(lower_bound, upper_bound, size=batch_size) 87 | return batch_random_mask(self.img_size, num_visibles, batch_size) 88 | elif self.mask_type == 'bottom': 89 | return batch_bottom_mask(self.img_size, self.mask_attribute, batch_size) 90 | elif self.mask_type == 'top': 91 | return batch_top_mask(self.img_size, self.mask_attribute, batch_size) 92 | elif self.mask_type == 'center': 93 | return batch_center_mask(self.img_size, self.mask_attribute, batch_size) 94 | elif self.mask_type == 'edge': 95 | return batch_edge_mask(self.img_size, self.mask_attribute, batch_size) 96 | elif self.mask_type == 'left': 97 | return batch_left_mask(self.img_size, self.mask_attribute, batch_size) 98 | elif self.mask_type == 'right': 99 | return batch_right_mask(self.img_size, self.mask_attribute, batch_size) 100 | elif self.mask_type == 'random_rect': 101 | return batch_random_rect_mask(self.img_size, self.mask_attribute[0], 102 | self.mask_attribute[1], batch_size) 103 | elif self.mask_type == 'random_blob': 104 | return batch_multi_random_blobs(self.img_size, 105 | self.mask_attribute[0], 106 | self.mask_attribute[1], 107 | self.mask_attribute[2], batch_size) 108 | elif self.mask_type == 'random_blob_cache': 109 | # Hacky way to get a single batch of data 110 | for mask_batch in self.data_loader: 111 | 
break 112 | # Zero index because Image folder returns (img, label) tuple 113 | return mask_batch[0] 114 | 115 | 116 | def single_random_mask(img_size, num_visible): 117 | """Returns random mask where 0 corresponds to a hidden value and 1 to a 118 | visible value. Shape of mask is same as img_size. 119 | 120 | Parameters 121 | ---------- 122 | img_size : tuple of ints 123 | E.g. (1, 32, 32) for grayscale or (3, 64, 64) for RGB. 124 | 125 | num_visible : int 126 | Number of visible values. 127 | """ 128 | _, height, width = img_size 129 | # Sample integers without replacement between 0 and the total number of 130 | # pixels. The measurements array will then contain a pixel indices 131 | # corresponding to locations where pixels will be visible. 132 | measurements = np.random.choice(range(height * width), size=num_visible, replace=False) 133 | # Create empty mask 134 | mask = torch.zeros(1, width, height) 135 | # Update mask with measurements 136 | for m in measurements: 137 | row = int(m / width) 138 | col = m % width 139 | mask[0, row, col] = 1 140 | return mask 141 | 142 | 143 | def batch_random_mask(img_size, num_visibles, batch_size, repeat=False): 144 | """Returns a batch of random masks. 145 | 146 | Parameters 147 | ---------- 148 | img_size : see single_random_mask 149 | 150 | num_visibles : int or list of ints 151 | If int will keep the number of visible pixels in the masks fixed, if 152 | list will change the number of visible pixels depending on the values 153 | in the list. List should have length equal to batch_size. 154 | 155 | batch_size : int 156 | Number of masks to create. 157 | 158 | repeat : bool 159 | If True returns a batch of the same mask repeated batch_size times. 160 | """ 161 | # Mask should have same shape as image, but only 1 channel 162 | mask_batch = torch.zeros(batch_size, 1, *img_size[1:]) 163 | if repeat: 164 | if not type(num_visibles) == int: 165 | raise RuntimeError("num_visibles must be an int if used with repeat=True. {} was provided instead.".format(type(num_visibles))) 166 | single_mask = single_random_mask(img_size, num_visibles) 167 | for i in range(batch_size): 168 | mask_batch[i] = single_mask 169 | else: 170 | if type(num_visibles) == int: 171 | for i in range(batch_size): 172 | mask_batch[i] = single_random_mask(img_size, num_visibles) 173 | else: 174 | for i in range(batch_size): 175 | mask_batch[i] = single_random_mask(img_size, num_visibles[i]) 176 | return mask_batch 177 | 178 | 179 | def batch_bottom_mask(img_size, num_rows, batch_size): 180 | """Masks all the output except the |num_rows| lowest rows (in the height 181 | dimension). 182 | 183 | Parameters 184 | ---------- 185 | img_size : see single_random_mask 186 | 187 | num_rows : int 188 | Number of rows from bottom which will be visible. 189 | 190 | batch_size : int 191 | Number of masks to create. 192 | """ 193 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 194 | mask[:, :, -num_rows:, :] = 1. 195 | return mask 196 | 197 | 198 | def batch_top_mask(img_size, num_rows, batch_size): 199 | """Masks all the output except the |num_rows| highest rows (in the height 200 | dimension). 201 | 202 | Parameters 203 | ---------- 204 | img_size : see single_random_mask 205 | 206 | num_rows : int 207 | Number of rows from top which will be visible. 208 | 209 | batch_size : int 210 | Number of masks to create. 211 | """ 212 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 213 | mask[:, :, :num_rows, :] = 1. 
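    # (the slice :num_rows selects the first num_rows rows, i.e. the top of the
    # image, and marks them as visible; batch_bottom_mask above uses -num_rows:
    # for the bottom rows in the same way)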
214 | return mask 215 | 216 | 217 | def batch_center_mask(img_size, num_pixels, batch_size): 218 | """Masks all the output except the num_pixels by num_pixels central square 219 | of the image. 220 | 221 | Parameters 222 | ---------- 223 | img_size : see single_random_mask 224 | 225 | num_pixels : int 226 | Should be even. If not even, num_pixels will be replaced with 227 | num_pixels - 1. 228 | 229 | batch_size : int 230 | Number of masks to create. 231 | """ 232 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 233 | _, height, width = img_size 234 | lower_height = int(height / 2 - num_pixels / 2) 235 | upper_height = int(height / 2 + num_pixels / 2) 236 | lower_width = int(width / 2 - num_pixels / 2) 237 | upper_width = int(width / 2 + num_pixels / 2) 238 | mask[:, :, lower_height:upper_height, lower_width:upper_width] = 1. 239 | return mask 240 | 241 | 242 | def batch_edge_mask(img_size, num_pixels, batch_size): 243 | """Masks all the output except the num_pixels thick edge of the image. 244 | 245 | Parameters 246 | ---------- 247 | img_size : see single_random_mask 248 | 249 | num_pixels : int 250 | Should be smaller than min(height / 2, width / 2). 251 | 252 | batch_size : int 253 | Number of masks to create. 254 | """ 255 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 256 | mask[:, :, :num_pixels, :] = 1. 257 | mask[:, :, -num_pixels:, :] = 1. 258 | mask[:, :, :, :num_pixels] = 1. 259 | mask[:, :, :, -num_pixels:] = 1. 260 | return mask 261 | 262 | 263 | def batch_left_mask(img_size, num_cols, batch_size): 264 | """Masks all the pixels except the left side of the image. 265 | 266 | Parameters 267 | ---------- 268 | img_size : see single_random_mask 269 | 270 | num_cols : int 271 | Number of columns of the left side of the image to remain visible. 272 | 273 | batch_size : int 274 | Number of masks to create. 275 | """ 276 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 277 | mask[:, :, :, :num_cols] = 1. 278 | return mask 279 | 280 | 281 | def batch_right_mask(img_size, num_cols, batch_size): 282 | """Masks all the pixels except the right side of the image. 283 | 284 | Parameters 285 | ---------- 286 | img_size : see single_random_mask 287 | 288 | num_cols : int 289 | Number of columns of the right side of the image to remain visible. 290 | 291 | batch_size : int 292 | Number of masks to create. 293 | """ 294 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 295 | mask[:, :, :, -num_cols:] = 1. 296 | return mask 297 | 298 | 299 | def random_rect_mask(img_size, max_height, max_width): 300 | """Returns a mask with a random rectangle of visible pixels. 301 | 302 | Parameters 303 | ---------- 304 | img_size : see single_random_mask 305 | 306 | max_height : int 307 | Maximum height of randomly sampled rectangle. 308 | 309 | max_width : int 310 | Maximum width of randomly sampled rectangle. 311 | """ 312 | mask = torch.zeros(1, *img_size[1:]) 313 | _, img_width, img_height = img_size 314 | # Sample top left corner of unmasked rectangle 315 | top_left = np.random.randint(0, img_height - 1), np.random.randint(0, img_width - 1) 316 | # Sample height of rectangle 317 | # This is a number between 1 and the max_height parameter. 
If the top left corner 318 | # is too close to the bottom of the image, make sure the rectangle doesn't exceed 319 | # this 320 | rect_height = np.random.randint(1, min(max_height, img_height - top_left[0])) 321 | # Sample width of rectangle 322 | rect_width = np.random.randint(1, min(max_width, img_width - top_left[1])) 323 | # Set visible pixels 324 | bottom_right = top_left[0] + rect_height, top_left[1] + rect_width 325 | mask[0, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = 1. 326 | return mask 327 | 328 | 329 | def batch_random_rect_mask(img_size, max_height, max_width, batch_size): 330 | """Returns a batch of masks with random rectangles of visible pixels. 331 | 332 | Parameters 333 | ---------- 334 | img_size : see single_random_mask 335 | 336 | max_height : int 337 | Maximum height of randomly sampled rectangle. 338 | 339 | max_width : int 340 | Maximum width of randomly sampled rectangle. 341 | 342 | batch_size : int 343 | Number of masks to create. 344 | """ 345 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 346 | for i in range(batch_size): 347 | mask[i] = random_rect_mask(img_size, max_height, max_width) 348 | return mask 349 | 350 | 351 | def random_blob(img_size, num_iter, threshold, fixed_init=None): 352 | """Generates masks with random connected blobs. 353 | 354 | Parameters 355 | ---------- 356 | img_size : see single_random_mask 357 | 358 | num_iter : int 359 | Number of iterations to expand random blob for. 360 | 361 | threshold : float 362 | Number between 0 and 1. Probability of keeping a pixel hidden. 363 | 364 | fixed_init : tuple of ints or None 365 | If fixed_init is None, central position of blob will be sampled 366 | randomly, otherwise expansion will start from fixed_init. E.g. 367 | fixed_init = (6, 12) will start the expansion from pixel in row 6, 368 | column 12. 369 | """ 370 | _, img_height, img_width = img_size 371 | # Defines the shifts around the central pixel which may be unmasked 372 | neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)] 373 | if fixed_init is None: 374 | # Sample random initial position 375 | init_pos = np.random.randint(0, img_height - 1), np.random.randint(0, img_width - 1) 376 | else: 377 | init_pos = (fixed_init[0], fixed_init[1]) 378 | # Initialize mask and make init_pos visible 379 | mask = torch.zeros(1, 1, *img_size[1:]) 380 | mask[0, 0, init_pos[0], init_pos[1]] = 1. 381 | # Initialize the list of seed positions 382 | seed_positions = [init_pos] 383 | # Randomly expand blob 384 | for i in range(num_iter): 385 | next_seed_positions = [] 386 | for seed_pos in seed_positions: 387 | # Sample probability that neighboring pixel will be visible 388 | prob_visible = np.random.rand(len(neighbors)) 389 | for j, neighbor in enumerate(neighbors): 390 | if prob_visible[j] > threshold: 391 | current_h, current_w = seed_pos 392 | shift_h, shift_w = neighbor 393 | # Ensure new height stays within image boundaries 394 | new_h = max(min(current_h + shift_h, img_height - 1), 0) 395 | # Ensure new width stays within image boundaries 396 | new_w = max(min(current_w + shift_w, img_width - 1), 0) 397 | # Update mask 398 | mask[0, 0, new_h, new_w] = 1. 399 | # Add new position to list of seeds 400 | next_seed_positions.append((new_h, new_w)) 401 | seed_positions = next_seed_positions 402 | return mask 403 | 404 | 405 | def multi_random_blobs(img_size, max_num_blobs, iter_range, threshold): 406 | """Generates masks with multiple random connected blobs. 
407 | 408 | Parameters 409 | ---------- 410 | max_num_blobs : int 411 | Maximum number of blobs. Number of blobs will be sampled between 1 and 412 | max_num_blobs 413 | 414 | iter_range : (int, int) 415 | Lower and upper bound on number of iterations to be used for each blob. 416 | This will be sampled for each blob. 417 | 418 | threshold : float 419 | Number between 0 and 1. Probability of keeping a pixel hidden. 420 | """ 421 | mask = torch.zeros(1, 1, *img_size[1:]) 422 | # Sample number of blobs 423 | num_blobs = np.random.randint(1, max_num_blobs + 1) 424 | for _ in range(num_blobs): 425 | num_iter = np.random.randint(iter_range[0], iter_range[1]) 426 | mask += random_blob(img_size, num_iter, threshold) 427 | mask[mask > 0] = 1. 428 | return mask 429 | 430 | 431 | def batch_multi_random_blobs(img_size, max_num_blobs, iter_range, threshold, 432 | batch_size): 433 | """Generates batch of masks with multiple random connected blobs.""" 434 | mask = torch.zeros(batch_size, 1, *img_size[1:]) 435 | for i in range(batch_size): 436 | mask[i] = multi_random_blobs(img_size, max_num_blobs, iter_range, threshold) 437 | return mask 438 | 439 | 440 | def get_conditional_pixels(batch, mask, num_colors): 441 | """Returns conditional pixels obtained from masking the data in batch with 442 | mask and appending the mask. E.g. if the input has size (N, C, H, W) 443 | then the output will have size (N, C + 1, H, W) i.e. the mask is appended 444 | as an extra color channel. 445 | 446 | Parameters 447 | ---------- 448 | batch : torch.Tensor 449 | Batch of data as returned by a DataLoader, i.e. unnormalized. 450 | Shape (num_examples, num_channels, width, height) 451 | 452 | mask : torch.Tensor 453 | Mask as returned by MaskGenerator.get_masks. 454 | Shape (num_examples, 1, width, height) 455 | 456 | num_colors : int 457 | Number of colors image is quantized to. 458 | """ 459 | batch_size, channels, width, height = batch.size() 460 | # Add extra channel to keep mask 461 | cond_pixels = torch.zeros((batch_size, channels + 1, height, width)) 462 | # Mask batch to only show visible pixels 463 | cond_pixels[:, :channels, :, :] = mask * batch.float() 464 | # Add mask scaled by number of colors in last channel dimension 465 | cond_pixels[:, -1:, :, :] = mask * (num_colors - 1) 466 | # Normalize conditional pixels to be in 0 - 1 range 467 | return cond_pixels / (num_colors - 1) 468 | 469 | 470 | def get_repeated_conditional_pixels(batch, mask, num_colors, num_reps): 471 | """Returns repeated conditional pixels. 472 | 473 | Parameters 474 | ---------- 475 | batch : torch.Tensor 476 | Shape (1, num_channels, width, height) 477 | 478 | mask : torch.Tensor 479 | Shape (1, num_channels, width, height) 480 | 481 | num_colors : int 482 | Number of colors image is quantized to. 483 | 484 | num_reps : int 485 | Number of times the conditional pixels will be repeated 486 | """ 487 | assert batch.size(0) == 1 488 | assert mask.size(0) == 1 489 | cond_pixels = get_conditional_pixels(batch, mask, num_colors) 490 | return cond_pixels.expand(num_reps, *cond_pixels.size()[1:]) 491 | -------------------------------------------------------------------------------- /utils/plots.py: -------------------------------------------------------------------------------- 1 | from matplotlib.pyplot import get_cmap 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def probs_and_conditional_plot(img, probs, mask, cmap='plasma'): 7 | """Creates a plot of pixel probabilities with the conditional pixels from 8 | the original image overlayed. 
Note this function only works for binary 9 | images. 10 | 11 | Parameters 12 | ---------- 13 | img : torch.Tensor 14 | Shape (1, H, W) 15 | 16 | probs : torch.Tensor 17 | Shape (H, W). Should be the probability of a pixel being 1. 18 | 19 | mask : torch.Tensor 20 | Shape (H, W) 21 | 22 | cmap : string 23 | Colormap to use for probs (as defined in matplotlib.plt, 24 | e.g. 'jet', 'viridis', ...) 25 | """ 26 | # Define function to convert array to colormap rgb image 27 | # The colorscale has a min value of 0 and a max value of 1 by default 28 | # (i.e. it does not rescale depending on the value of the probs) 29 | convert_to_cmap = get_cmap(cmap) 30 | # Create output image from colormap of probs 31 | rgba_probs = convert_to_cmap(probs.numpy()) 32 | output_img = np.delete(rgba_probs, 3, 2) # Convert to RGB 33 | # Overlay unmasked parts of original image over probs 34 | np_mask = mask.numpy().astype(bool) # Convert mask to boolean numpy array 35 | np_img = img.numpy()[0] # Convert img to grayscale numpy img 36 | output_img[:, :, 0][np_mask] = np_img[np_mask] 37 | output_img[:, :, 1][np_mask] = np_img[np_mask] 38 | output_img[:, :, 2][np_mask] = np_img[np_mask] 39 | # Convert numpy image to torch tensor 40 | return torch.Tensor(output_img.transpose(2, 0, 1)) 41 | 42 | 43 | def uncertainty_plot(samples, log_probs, cmap='plasma'): 44 | """Sorts samples by their log likelihoods and creates an image representing 45 | the log likelihood of each sample as a box with color and size proportional 46 | to the log likelihood. 47 | 48 | Parameters 49 | ---------- 50 | samples : torch.Tensor 51 | Shape (N, C, H, W) 52 | 53 | log_probs : torch.Tensor 54 | Shape (N,) 55 | 56 | cmap : string 57 | Colormap to use for likelihoods (as defined in matplotlib.plt, 58 | e.g. 'jet', 'viridis', ...) 59 | """ 60 | # Sorted by negative log likelihood 61 | sorted_nll, sorted_indices = torch.sort(-log_probs) 62 | sorted_samples = samples[sorted_indices] 63 | # Normalize log likelihoods to be in 0 - 1 range 64 | min_ll, max_ll = (-sorted_nll).min(), (-sorted_nll).max() 65 | normalized_likelihoods = ((-sorted_nll) - min_ll) / (max_ll - min_ll) 66 | 67 | # For each sample draw an image with a box proportional in size and 68 | # color to the log likelihood value 69 | num_samples, _, height, width = samples.size() 70 | # Initialize white background images on which to draw boxes 71 | ll_images = torch.ones(num_samples, 3, height, width) 72 | # Specify box sizes 73 | lower_width = width // 2 - width // 5 74 | upper_width = width // 2 + width // 5 75 | max_box_height = height 76 | min_box_height = 1 77 | # Generate colors for the boxes 78 | convert_to_cmap = get_cmap(cmap) 79 | # Remove alpha channel from colormap 80 | colors = convert_to_cmap(normalized_likelihoods.numpy())[:, :-1] 81 | 82 | # Fill out images with boxes 83 | for i in range(num_samples): 84 | norm_ll = normalized_likelihoods[i].item() 85 | box_height = int(min_box_height + (max_box_height - min_box_height) * norm_ll) 86 | box_color = colors[i] 87 | for j in range(3): 88 | ll_images[i, j, height - box_height:height, lower_width:upper_width] = box_color[j] 89 | 90 | return sorted_samples, ll_images 91 | --------------------------------------------------------------------------------
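Putting the pieces together (an illustrative sketch rather than a file from this
repository; the checkpoint path, mask choice and number of samples are
assumptions, while load_model, MaskGenerator, get_repeated_conditional_pixels
and PixelConstrained.sample are the functions and methods defined above):

    import torch
    from utils.loading import load_model
    from utils.masks import MaskGenerator, get_repeated_conditional_pixels

    # Load the pretrained constrained MNIST model shipped under trained_models/
    model, data_loader, mask_descriptor = load_model('trained_models/mnist')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Keep only the top 7 rows of a single test image visible
    mask_generator = MaskGenerator((1, 28, 28), ('top', 7))
    batch, _ = next(iter(data_loader))
    mask = mask_generator.get_masks(1)
    cond_pixels = get_repeated_conditional_pixels(batch[0:1], mask,
                                                  model.prior_net.num_colors, 8)
    cond_pixels = cond_pixels.to(device)

    # Draw 8 different completions that agree with the visible rows
    samples = model.sample(cond_pixels)
    print(samples.shape)  # torch.Size([8, 1, 28, 28])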