├── LICENSE
├── README.md
├── config.json
├── imgs
│   ├── int-dark-side.png
│   ├── int-eye-color.png
│   ├── int-male-female.png
│   ├── new-blob-samples-2.png
│   ├── new-blob-samples.png
│   ├── new-bottom-celeba.png
│   ├── new-bottom-seven-nine.png
│   ├── new-left-celeba.png
│   ├── new-random-celeba.png
│   ├── new-small-missing-celeba.png
│   ├── new-top-celeba.png
│   ├── probinpainting.gif
│   └── summary-gif.gif
├── main.py
├── main_generate.py
├── pixconcnn
│   ├── __init__.py
│   ├── generate.py
│   ├── layers.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   ├── gated_pixelcnn.py
│   │   └── pixel_constrained.py
│   └── training.py
├── requirements.txt
├── trained_models
│   ├── celeba
│   │   ├── config.json
│   │   └── model.pt
│   └── mnist
│       ├── config.json
│       └── model.pt
└── utils
    ├── __init__.py
    ├── dataloaders.py
    ├── init_models.py
    ├── loading.py
    ├── masks.py
    └── plots.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2019, Emilien Dupont & Suhas Suresha
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Probabilistic Semantic Inpainting with Pixel Constrained CNNs
2 |
3 | PyTorch implementation of [Probabilistic Semantic Inpainting with Pixel Constrained CNNs](https://arxiv.org/abs/1810.03728) (2018).
4 |
5 | This repo contains an implementation of Pixel Constrained CNN, a framework for performing probabilistic inpainting of images with arbitrary occlusions. It also includes all code to reproduce the experiments in the paper as well as the weights of the trained models.
6 |
7 | For a TensorFlow implementation, see this [repo](https://github.com/Schlumberger/pixel-constrained-cnn-tf).
8 |
9 | ## Examples
10 |
11 | <!-- example inpainting results (image files are in the imgs folder) -->
22 |
23 | ## Usage
24 |
25 | ### Training
26 |
27 | The attributes of the model can be set in the `config.json` file. To train the model, run
28 |
29 | ```
30 | python main.py config.json
31 | ```
32 |
33 | This will also save the trained model and log various information as training progresses. Examples of `config.json` files are available in the `trained_models` directory.
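
The `mask_descriptor` entry controls which pixels are hidden during training; the included `config.json` uses `["random_rect", [10, 10]]`. As a rough sketch (assuming `utils/masks.MaskGenerator` accepts the same mask types that `main_generate.py` exposes, e.g. `top`, `bottom`, `left`, `right`, `center`, `edge`, `random` and `random_blob`), a config that keeps only the top of each image visible during training might instead use:

```
"mask_descriptor": ["top", 14]
```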
34 |
35 | ### Inpainting
36 |
37 | To generate images with a trained model, use `main_generate.py`. As an example, the following command generates 64 completions for images 73 and 84 of the MNIST dataset by conditioning on the top 14 rows. The completions are generated with the trained MNIST model included in this repo, and the results are saved to a timestamped folder whose name ends in `mnist_inpaintings`.
38 |
39 | ```
40 | python main_generate.py -n mnist_inpaintings -m trained_models/mnist -t generation -i 73 84 -to 14 -ns 64
41 | ```
42 |
43 | For a full list of options, run `python main_generate.py --help`. Note that if you do not have the MNIST dataset on your machine, it will be downloaded automatically when you run the above command. The CelebA dataset has to be downloaded manually (see the Data sources section). If you already have the datasets downloaded, you can change the paths in `utils/dataloaders.py` to point to the correct folders on your machine.
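
The `uncertainty` generation type passes the samples and their log-likelihoods through `utils/plots.uncertainty_plot`, saving the completions in sorted order together with images of the likelihood values. As a sketch, reusing the MNIST example above:

```
python main_generate.py -n mnist_uncertainty -m trained_models/mnist -t uncertainty -i 73 84 -to 14 -ns 64
```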
44 |
45 | ## Trained models
46 |
47 | The trained models referenced in the paper are included in the `trained_models` folder. You can use the `main_generate.py` script to generate image completions (and other plots) with these models.
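
For example (the image indices and mask size below are illustrative, and CelebA must already be downloaded and wired up in `utils/dataloaders.py`), a command along these lines generates 64 completions each for images 0 and 1 with the included CelebA model, conditioning on the bottom rows:

```
python main_generate.py -n celeba_inpaintings -m trained_models/celeba -t generation -i 0 1 -b 16 -ns 64
```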
48 |
49 | ## Data sources
50 |
51 | The MNIST dataset can be automatically downloaded using `torchvision`. The CelebA dataset can be found [here](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
52 |
53 | ## Citing
54 |
55 | If you find this work useful in your research, please cite using:
56 |
57 | ```
58 | @article{dupont2018probabilistic,
59 | title={Probabilistic Semantic Inpainting with Pixel Constrained CNNs},
60 | author={Dupont, Emilien and Suresha, Suhas},
61 | journal={arXiv preprint arXiv:1810.03728},
62 | year={2018}
63 | }
64 | ```
65 |
66 | ## More examples
67 |
68 | <!-- additional example images (see the imgs folder) -->
72 |
73 | ## License
74 |
75 | [Apache License 2.0](LICENSE)
76 |
--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "mnist_model",
3 | "dataset": "mnist",
4 | "resize": 28,
5 | "crop": 28,
6 | "grayscale": true,
7 | "batch_size": 64,
8 | "constrained": true,
9 | "num_colors": 2,
10 | "filter_size": 5,
11 | "depth": 16,
12 | "num_filters_cond": 32,
13 | "num_filters_prior": 32,
14 | "lr": 4e-4,
15 | "epochs": 50,
16 | "mask_descriptor": ["random_rect", [10, 10]],
17 | "num_conds": 4,
18 | "num_samples": 64,
19 | "weight_cond_logits_loss": 1.0,
20 | "weight_prior_logits_loss": 0.0
21 | }
22 |
--------------------------------------------------------------------------------
/imgs/int-dark-side.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/int-dark-side.png
--------------------------------------------------------------------------------
/imgs/int-eye-color.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/int-eye-color.png
--------------------------------------------------------------------------------
/imgs/int-male-female.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/int-male-female.png
--------------------------------------------------------------------------------
/imgs/new-blob-samples-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-blob-samples-2.png
--------------------------------------------------------------------------------
/imgs/new-blob-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-blob-samples.png
--------------------------------------------------------------------------------
/imgs/new-bottom-celeba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-bottom-celeba.png
--------------------------------------------------------------------------------
/imgs/new-bottom-seven-nine.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-bottom-seven-nine.png
--------------------------------------------------------------------------------
/imgs/new-left-celeba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-left-celeba.png
--------------------------------------------------------------------------------
/imgs/new-random-celeba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-random-celeba.png
--------------------------------------------------------------------------------
/imgs/new-small-missing-celeba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-small-missing-celeba.png
--------------------------------------------------------------------------------
/imgs/new-top-celeba.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/new-top-celeba.png
--------------------------------------------------------------------------------
/imgs/probinpainting.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/probinpainting.gif
--------------------------------------------------------------------------------
/imgs/summary-gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/imgs/summary-gif.gif
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import json
3 | import os
4 | import sys
5 | import time
6 | import torch
7 | from pixconcnn.training import Trainer, PixelConstrainedTrainer
8 | from torchvision.utils import save_image
9 | from utils.dataloaders import mnist, celeba
10 | from utils.init_models import initialize_model
11 | from utils.masks import batch_random_mask, get_repeated_conditional_pixels, MaskGenerator
12 |
13 |
14 | # Set device
15 | cuda = torch.cuda.is_available()
16 | device = torch.device("cuda" if cuda else "cpu")
17 |
18 | # Get config file from command line arguments
19 | if len(sys.argv) != 2:
20 |     raise(RuntimeError("Wrong arguments, use python main.py <path_to_config>"))
21 | config_path = sys.argv[1]
22 |
23 | # Open config file
24 | with open(config_path) as config_file:
25 | config = json.load(config_file)
26 |
27 | name = config['name']
28 | constrained = config['constrained']
29 | batch_size = config['batch_size']
30 | lr = config['lr']
31 | num_colors = config['num_colors']
32 | epochs = config['epochs']
33 | dataset = config['dataset']
34 | resize = config['resize'] # Only relevant for celeba
35 | crop = config['crop'] # Only relevant for celeba
36 | grayscale = config["grayscale"] # Only relevant for celeba
37 | num_conds = config['num_conds'] # Only relevant if constrained
38 | num_samples = config['num_samples'] # Only relevant if constrained
39 | filter_size = config['filter_size']
40 | depth = config['depth']
41 | num_filters_cond = config['num_filters_cond']
42 | num_filters_prior = config['num_filters_prior']
43 | mask_descriptor = config['mask_descriptor']
44 | weight_cond_logits_loss = config['weight_cond_logits_loss']
45 | weight_prior_logits_loss = config['weight_prior_logits_loss']
46 |
47 | # Create a folder to store experiment results
48 | timestamp = time.strftime("%Y-%m-%d_%H-%M")
49 | directory = "{}_{}".format(timestamp, name)
50 | if not os.path.exists(directory):
51 | os.makedirs(directory)
52 |
53 | # Save config file in experiment directory
54 | with open(directory + '/config.json', 'w') as config_file:
55 | json.dump(config, config_file)
56 |
57 | # Get data
58 | if dataset == 'mnist':
59 | data_loader, _ = mnist(batch_size, num_colors=num_colors, size=resize)
60 | img_size = (1, resize, resize)
61 | elif dataset == 'celeba':
62 | data_loader = celeba(batch_size, num_colors=num_colors, size=resize,
63 | crop=crop, grayscale=grayscale)
64 | if grayscale:
65 | img_size = (1, resize, resize)
66 | else:
67 | img_size = (3, resize, resize)
68 |
69 | # Initialize model weights and architecture
70 | model = initialize_model(img_size,
71 | num_colors,
72 | depth,
73 | filter_size,
74 | constrained,
75 | num_filters_prior,
76 | num_filters_cond)
77 | model.to(device)
78 | print(model)
79 |
80 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
81 |
82 | if constrained:
83 | mask_generator = MaskGenerator(img_size, mask_descriptor)
84 | trainer = PixelConstrainedTrainer(model, optimizer, device, mask_generator,
85 | weight_cond_logits_loss=weight_cond_logits_loss,
86 | weight_prior_logits_loss=weight_prior_logits_loss)
87 | # Train model
88 | progress_imgs = trainer.train(data_loader, epochs, directory=directory)
89 |
90 | # Get a random batch of images
91 | for batch, _ in data_loader:
92 | break
93 |
94 | for i in range(num_conds):
95 | mask = mask_generator.get_masks(batch_size)
96 | print('Generating {}/{} conditionings'.format(i + 1, num_conds))
97 | cond_pixels = get_repeated_conditional_pixels(batch[i:i+1], mask[i:i+1],
98 | num_colors, num_samples)
99 | # Save mask as tensor
100 | torch.save(mask[i:i+1], directory + '/mask{}.pt'.format(i))
101 | # Save image that gave rise to the conditioning as tensor
102 | torch.save(batch[i:i+1], directory + '/source{}.pt'.format(i))
103 | # Save conditional pixels as tensor and image
104 | torch.save(cond_pixels[0:1], directory + '/cond_pixels{}.pt'.format(i))
105 | save_image(cond_pixels[0:1], directory + '/cond_pixels{}.png'.format(i))
106 |
107 | cond_pixels = cond_pixels.to(device)
108 | samples = model.sample(cond_pixels)
109 | # Save samples and mean sample as tensor and image
110 | torch.save(samples, directory + '/samples_cond{}.pt'.format(i))
111 | save_image(samples.float() / (num_colors - 1.),
112 | directory + '/samples_cond{}.png'.format(i))
113 | save_image(samples.float().mean(dim=0) / (num_colors - 1.),
114 | directory + '/mean_cond{}.png'.format(i))
115 | # Save conditional logits if image is binary
116 | if num_colors == 2:
117 | # Save conditional logits
118 | logits, _, cond_logits = model(batch[i:i+1].float().to(device), cond_pixels[0:1])
119 | # Second dimension corresponds to different pixel values, so select probs of it being 1
120 | save_image(cond_logits[:, 1], directory + '/prob_of_one_cond{}.png'.format(i))
121 | # Second dimension corresponds to different pixel values, so select probs of it being 1
122 | save_image(logits[:, 1], directory + '/prob_of_one_logits{}.png'.format(i))
123 | else:
124 | trainer = Trainer(model, optimizer, device)
125 | progress_imgs = trainer.train(data_loader, epochs, directory=directory)
126 |
127 | # Save losses to a json file
128 | with open(directory + '/losses.json', 'w') as losses_file:
129 | json.dump(trainer.losses, losses_file)
130 |
131 | # Save model
132 | torch.save(trainer.model.state_dict(), directory + '/model.pt')
133 |
134 | # Save gif of progress
135 | imageio.mimsave(directory + '/training.gif', progress_imgs, fps=24)
136 |
--------------------------------------------------------------------------------
/main_generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import time
5 | import torch
6 | import torch.nn.functional as F
7 | from PIL import Image
8 | from pixconcnn.generate import generate_images
9 | from utils.loading import load_model
10 | from utils.masks import get_conditional_pixels
11 | from utils.plots import uncertainty_plot, probs_and_conditional_plot
12 | from torchvision.utils import save_image
13 |
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 | # Parse command line arguments
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument('-m', '--model-folder', dest='model_folder', default=None,
19 | help='Path to trained pixel constrained model folder',
20 | required=True)
21 | parser.add_argument('-n', '--name', dest='name', default=None,
22 | help='Name of generation experiment', required=True)
23 | parser.add_argument('-t', '--gen_type', dest='gen_type', default=None,
24 | choices=['generation','logits','uncertainty'],
25 | help='Type of generation', required=True)
26 | parser.add_argument('-i', '--imgs', dest='imgs_idx', default=None,
27 | type=int, nargs='+',
28 | help='List of indices of images to perform generation with.')
29 | parser.add_argument('-ns', '--num-samples', dest='num_samples', default=None,
30 | type=int, help='Number of samples to generate for each image-mask combo.',
31 | required=True)
32 | parser.add_argument('-te', '--temp', dest='temp', default=1.,
33 | type=float, help='Sampling temperature.')
34 | parser.add_argument('-v', '--model-version', dest='model_version', default=None,
35 | type=int, help='Version of model if not using the latest one.')
36 |
37 | parser.add_argument('-nr', '--num-row', dest='num_per_row', default=8,
38 | type=int, help='Number of images per row in grids.')
39 | parser.add_argument('-ni', '--num-iterations', dest='num_iters', default=None,
40 | type=int, help='Only relevant for logits. Number of iterations to plot intermediate logits for.')
41 | parser.add_argument('-r', '--random', dest='random_attribute', default=None,
42 | type=int, help='Number of random pixels to keep unmasked.')
43 | parser.add_argument('-b', '--bottom', dest='bottom_attribute', default=None,
44 | type=int, help='Number of bottom pixels to keep unmasked.')
45 | parser.add_argument('-to', '--top', dest='top_attribute', default=None,
46 | type=int, help='Number of top pixels to keep unmasked.')
47 | parser.add_argument('-c', '--center', dest='center_attribute', default=None,
48 | type=int, help='Number of central pixels to keep unmasked.')
49 | parser.add_argument('-e', '--edge', dest='edge_attribute', default=None,
50 | type=int, help='Number of edge pixels to keep unmasked.')
51 | parser.add_argument('-l', '--left', dest='left_attribute', default=None,
52 | type=int, help='Number of left pixels to keep unmasked.')
53 | parser.add_argument('-ri', '--right', dest='right_attribute', default=None,
54 | type=int, help='Number of right pixels to keep unmasked.')
55 | parser.add_argument('-rb', '--random-blob', dest='blob_attribute', default=None,
56 | type=int, nargs='+', help='First int should be maximum number of blobs, second lower bound on num_iters and third upper bound on num_iters.')
57 | parser.add_argument('-mf', '--mask-folder', dest='folder_attribute', default=None,
58 | help='Mask folder if using a cached mask.')
59 |
60 | # Unpack args
61 | args = parser.parse_args()
62 |
63 | # Create a folder to store generation results
64 | timestamp = time.strftime("%Y-%m-%d_%H-%M")
65 | directory = "gen_{}_{}".format(timestamp, args.name)
66 | if not os.path.exists(directory):
67 | os.makedirs(directory)
68 |
69 | # Save args
70 | with open(directory + '/args.json', 'w') as args_file:
71 | json.dump(vars(args), args_file)
72 |
73 | # Load model
74 | model, data_loader, _ = load_model(args.model_folder, model_version=args.model_version)
75 |
76 | # Convert input arguments to mask_descriptors
77 | mask_descriptors = []
78 | if args.random_attribute is not None:
79 | mask_descriptors.append(('random', args.random_attribute))
80 | if args.bottom_attribute is not None:
81 | mask_descriptors.append(('bottom', args.bottom_attribute))
82 | if args.top_attribute is not None:
83 | mask_descriptors.append(('top', args.top_attribute))
84 | if args.center_attribute is not None:
85 | mask_descriptors.append(('center', args.center_attribute))
86 | if args.edge_attribute is not None:
87 | mask_descriptors.append(('edge', args.edge_attribute))
88 | if args.left_attribute is not None:
89 | mask_descriptors.append(('left', args.left_attribute))
90 | if args.right_attribute is not None:
91 | mask_descriptors.append(('right', args.right_attribute))
92 | if args.blob_attribute is not None:
93 | max_num_blobs, lower_iter, upper_iter = args.blob_attribute
94 | mask_descriptors.append(('random_blob', (max_num_blobs, (lower_iter, upper_iter), 0.5)))
95 | if args.folder_attribute is not None:
96 | mask_descriptors.append(('random_blob_cache', (args.folder_attribute, 1)))
97 |
98 | imgs_idx = args.imgs_idx
99 | num_img = len(imgs_idx)
100 | total_num_imgs = args.num_samples * num_img * len(mask_descriptors)
101 | print("\nGenerating {} samples for {} different images combined with {} masks for a total of {} images".format(args.num_samples, num_img, len(mask_descriptors), total_num_imgs))
102 | print("\nThe masks are {}\n".format(mask_descriptors))
103 |
104 | # Create a batch from the images in imgs_idx
105 | batch = torch.stack([data_loader.dataset[img_idx][0] for img_idx in imgs_idx], dim=0)
106 |
107 | if args.gen_type == 'generation' or args.gen_type == 'uncertainty':
108 | # Generate images with model
109 | outputs = generate_images(model, batch, mask_descriptors,
110 | num_samples=args.num_samples, temp=args.temp,
111 | verbose=True)
112 |
113 | # Save images in folder
114 | for i in range(num_img):
115 | for j in range(len(mask_descriptors)):
116 | output = outputs[i][j]
117 | # Save every output as an image and a pytorch tensor
118 | torch.save(output["orig_img"].cpu(), directory + "/source_{}_{}.pt".format(i, j))
119 | save_image(output["orig_img"].float() / (model.prior_net.num_colors - 1.), directory + "/source_{}_{}.png".format(i, j))
120 | torch.save(output["cond_pixels"][0:1].cpu(), directory + "/cond_pixels_{}_{}.pt".format(i, j))
121 | save_image(output["cond_pixels"][0:1, :3], directory + "/cond_pixels_{}_{}.png".format(i, j))
122 | torch.save(output["mask"].cpu(), directory + "/mask_{}_{}.pt".format(i, j))
123 | save_image(output["mask"], directory + "/mask_{}_{}.png".format(i, j))
124 | save_image(output["samples"].float().mean(dim=0) / (model.prior_net.num_colors - 1.), directory + '/mean_samples_{}_{}.png'.format(i, j))
125 | torch.save(output["samples"].cpu(), directory + "/samples_{}_{}.pt".format(i, j))
126 | torch.save(output["log_probs"].cpu(), directory + "/log_probs_{}_{}.pt".format(i, j))
127 | if args.gen_type == 'generation':
128 | save_image(output["samples"].float() / (model.prior_net.num_colors - 1.), directory + "/samples_{}_{}.png".format(i, j), nrow=args.num_per_row, pad_value=1)
129 | elif args.gen_type == 'uncertainty':
130 | sorted_samples, log_likelihoods = uncertainty_plot(output["samples"], output["log_probs"])
131 | save_image(sorted_samples.float() / (model.prior_net.num_colors - 1.), directory + "/sorted_samples_{}_{}.png".format(i, j), nrow=args.num_per_row, pad_value=1)
132 | save_image(log_likelihoods, directory + "/log_likelihoods_{}_{}.png".format(i, j), pad_value=1, nrow=args.num_per_row)
133 | elif args.gen_type == 'logits': # Note this only works for binary images
134 | if model.prior_net.num_colors != 2:
135 | raise(RuntimeError("Logits generation only works for models with 2 colors. Current model has {} colors.".format(model.prior_net.num_colors)))
136 | # Generate images with model
137 | outputs = generate_images(model, batch, mask_descriptors,
138 | num_samples=args.num_samples, verbose=True,
139 | temp=args.temp)
140 | # Extract info
141 | img_size = model.prior_net.img_size
142 | num_pixels = img_size[1] * img_size[2]
143 | pix_per_iters = num_pixels // args.num_iters # Number of pixels to unmask per iteration
144 | # Save images in folder
145 | for i in range(num_img):
146 | for j in range(len(mask_descriptors)):
147 | output = outputs[i][j]
148 | mask = output["mask"]
149 | mask = mask.expand(args.num_samples, *mask.size()[1:])
150 | samples = output["samples"]
151 | cond_pixels = get_conditional_pixels(samples, mask.float(), 2)
152 | cond_pixels = cond_pixels.to(device)
153 | mask = mask.to(device)
154 | samples = samples.to(device)
155 | logit_plots = {}
156 | for k in range(args.num_iters):
157 | logit_plots[k] = {}
158 | mask_step = mask.clone()
159 | if k != 0: # Do not modify mask for first iteration
160 | # Unmask num_pix_to_unmask pixels in raster order
161 | num_pix_to_unmask = k * pix_per_iters
162 | num_rows = num_pix_to_unmask // img_size[2]
163 | num_cols = num_pix_to_unmask - num_rows * img_size[2]
164 | if num_rows != 0:
165 | mask_step[:, :, :num_rows, :] = 1
166 | if num_cols != 0:
167 | mask_step[:, :, num_rows, :num_cols] = 1
168 | # Calculate logits with updated mask
169 | logits, prior_logits, cond_logits = model(samples.float() * mask_step.float(), cond_pixels)
170 | probs = F.softmax(logits.detach(), dim=1)
171 | # Create a plot for each sample
172 | for l in range(args.num_samples):
173 | logit_plot = probs_and_conditional_plot(samples[l].cpu(),
174 | probs[l, 1, 0].cpu(),
175 | mask_step[0, 0].cpu())
176 | logit_plots[k][l] = logit_plot
177 |
178 | # Plot iterations as a grid for each sample
179 | for l in range(args.num_samples):
180 | logit_grid = []
181 | for k in range(args.num_iters):
182 | logit_grid.append(logit_plots[k][l])
183 | stacked_images = torch.stack(logit_grid, dim=0)
184 | save_image(stacked_images, directory + "/logit_plot_{}_{}_{}.png".format(i, j, l), pad_value=1)
185 | torch.save(stacked_images.cpu(), directory + "/logit_plot_{}_{}_{}.pt".format(i, j, l))
186 |
--------------------------------------------------------------------------------
/pixconcnn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/pixconcnn/__init__.py
--------------------------------------------------------------------------------
/pixconcnn/generate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch import device as torch_device
3 | from torch import zeros as torch_zeros
4 | from torch.cuda import is_available as cuda_is_available
5 | from torchvision.utils import make_grid
6 | from utils.masks import MaskGenerator, get_repeated_conditional_pixels
7 |
8 |
9 | def generate_images(model, batch, mask_descriptors, num_samples=64, temp=1.,
10 | verbose=False):
11 | """Generates image completions based on the images in batch masked by the
12 | masks in mask_descriptors. This will generate
13 | batch.size(0) * len(mask_descriptors) * num_samples completions, i.e.
14 | num_samples completions for every image and mask combination.
15 |
16 | Parameters
17 | ----------
18 | model : pixconcnn.models.pixel_constrained.PixelConstrained instance
19 |
20 | batch : torch.Tensor
21 |
22 | mask_descriptors : list of mask_descriptor
23 | See utils.masks.MaskGenerator for allowed mask_descriptors.
24 |
25 | num_samples : int
26 | Number of samples to generate for a given image-mask combination.
27 |
28 | temp : float
29 | Temperature for sampling.
30 |
31 | verbose : bool
32 | If True prints progress information while generating images
33 | """
34 | device = torch_device("cuda" if cuda_is_available() else "cpu")
35 | model.to(device)
36 | outputs = []
37 | for i in range(batch.size(0)):
38 | outputs_per_img = []
39 | for j in range(len(mask_descriptors)):
40 | if verbose:
41 | print("Generating samples for image {} using mask {}".format(i, mask_descriptors[j]))
42 | # Get image and mask combination
43 | img = batch[i:i+1]
44 | mask_generator = MaskGenerator(model.prior_net.img_size, mask_descriptors[j])
45 | mask = mask_generator.get_masks(1)
46 | # Create conditional pixels which will be used to sample completions
47 | cond_pixels = get_repeated_conditional_pixels(img, mask, model.prior_net.num_colors, num_samples)
48 | cond_pixels = cond_pixels.to(device)
49 | samples, log_probs = model.sample(cond_pixels, return_likelihood=True, temp=temp)
50 | outputs_per_img.append({
51 | "orig_img": img,
52 | "cond_pixels": cond_pixels,
53 | "mask": mask,
54 | "samples": samples,
55 | "log_probs": log_probs
56 | })
57 | outputs.append(outputs_per_img)
58 | return outputs
59 |
--------------------------------------------------------------------------------
/pixconcnn/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ResidualBlock(nn.Module):
6 | """
7 | Residual block (note that number of in_channels and out_channels must be
8 | the same).
9 |
10 | Parameters
11 | ----------
12 | in_channels : int
13 |
14 | out_channels : int
15 |
16 | kernel_size : int or tuple of ints
17 |
18 | stride : int
19 |
20 | padding : int
21 | """
22 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
23 | super(ResidualBlock, self).__init__()
24 |
25 | self.convs = nn.Sequential(
26 | nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1,
27 | padding=0),
28 | nn.BatchNorm2d(in_channels),
29 | nn.ReLU(True),
30 | nn.Conv2d(in_channels, out_channels,
31 | kernel_size=kernel_size, stride=stride,
32 | padding=padding),
33 | nn.BatchNorm2d(out_channels),
34 | nn.ReLU(True),
35 | nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1,
36 | padding=0),
37 | nn.BatchNorm2d(out_channels),
38 | nn.ReLU(True)
39 | )
40 |
41 | def forward(self, x):
42 | # In and out channels should be the same
43 | return x + self.convs(x)
44 |
45 |
46 | class MaskedConv2d(nn.Conv2d):
47 | """
48 | Implements various 2d masked convolutions.
49 |
50 | Parameters
51 | ----------
52 | mask_type : string
53 | Defines the type of mask to use. One of 'A', 'A_Red', 'A_Green',
54 | 'A_Blue', 'B', 'B_Red', 'B_Green', 'B_Blue', 'H', 'H_Red', 'H_Green',
55 | 'H_Blue', 'HR', 'HR_Red', 'HR_Green', 'HR_Blue', 'V', 'VR'.
56 | """
57 | def __init__(self, mask_type, *args, **kwargs):
58 | super(MaskedConv2d, self).__init__(*args, **kwargs)
59 | self.mask_type = mask_type
60 |
61 | # Initialize mask
62 | mask = torch.zeros(*self.weight.size())
63 | _, kernel_c, kernel_h, kernel_w = self.weight.size()
64 | # If using a color mask, the number of channels must be divisible by 3
65 | if mask_type.endswith('Red') or mask_type.endswith('Green') or mask_type.endswith('Blue'):
66 | assert kernel_c % 3 == 0
67 | # If using a horizontal mask, the kernel height must be 1
68 | if mask_type.startswith('H'):
69 | assert kernel_h == 1 # Kernel should have shape (1, kernel_w)
70 |
71 | if mask_type == 'A':
72 | # For 3 by 3 kernel, this would be:
73 | # 1 1 1
74 | # 1 0 0
75 | # 0 0 0
76 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
77 | mask[:, :, :kernel_h // 2, :] = 1.
78 | elif mask_type == 'A_Red':
79 | # Mask type A for red channels. Same as regular mask A
80 | if kernel_h == 1 and kernel_w == 1:
81 | pass # Mask is all zeros for 1x1 convolution
82 | else:
83 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
84 | mask[:, :, :kernel_h // 2, :] = 1.
85 | elif mask_type == 'A_Green':
86 | # Mask type A for green channels. Same as regular mask A, except
87 | # central pixel of first third of channels is 1
88 | if kernel_h == 1 and kernel_w == 1:
89 | mask[:, :kernel_c // 3, 0, 0] = 1.
90 | else:
91 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
92 | mask[:, :, :kernel_h // 2, :] = 1.
93 | mask[:, :kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1.
94 | elif mask_type == 'A_Blue':
95 | # Mask type A for blue channels. Same as regular mask A, except
96 | # central pixel of first two thirds of channels is 1
97 | if kernel_h == 1 and kernel_w == 1:
98 | mask[:, :2 * kernel_c // 3, 0, 0] = 1.
99 | else:
100 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
101 | mask[:, :, :kernel_h // 2, :] = 1.
102 | mask[:, :2 * kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1.
103 | elif mask_type == 'B':
104 | # For 3 by 3 kernel, this would be:
105 | # 1 1 1
106 | # 1 1 0
107 | # 0 0 0
108 | mask[:, :, kernel_h // 2, :kernel_w // 2 + 1] = 1.
109 | mask[:, :, :kernel_h // 2, :] = 1.
110 | elif mask_type == 'B_Red':
111 |             # Mask type B for red channels. Same as regular mask B, except the
112 |             # last two thirds of channels of the central pixel are 0.
113 |             # Equivalently, the same as mask A but with the first third of
114 |             # channels of the central pixel set to 1
115 | if kernel_h == 1 and kernel_w == 1:
116 | mask[:, :kernel_c // 3, 0, 0] = 1.
117 | else:
118 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
119 | mask[:, :, :kernel_h // 2, :] = 1.
120 | mask[:, :kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1.
121 | elif mask_type == 'B_Green':
122 | # Mask type B for green channels. Same as regular mask B, except
123 | # last third of channels of central pixels are 0
124 | if kernel_h == 1 and kernel_w == 1:
125 | mask[:, :2 * kernel_c // 3, 0, 0] = 1.
126 | else:
127 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
128 | mask[:, :, :kernel_h // 2, :] = 1.
129 | mask[:, :2 * kernel_c // 3, kernel_h // 2, kernel_w // 2] = 1.
130 | elif mask_type == 'B_Blue':
131 | # Mask type B for blue channels. Same as regular mask B
132 | if kernel_h == 1 and kernel_w == 1:
133 | mask[:, :, 0, 0] = 1.
134 | else:
135 | mask[:, :, kernel_h // 2, :kernel_w // 2] = 1.
136 | mask[:, :, :kernel_h // 2, :] = 1.
137 | mask[:, :, kernel_h // 2, kernel_w // 2] = 1.
138 | elif mask_type == 'H':
139 | # For 3 by 3 kernel, this would be:
140 | # 1 1 0
141 | # Mask for horizontal stack in regular gated conv
142 | mask[:, :, 0, :kernel_w // 2 + 1] = 1.
143 | elif mask_type == 'H_Red':
144 | mask[:, :, 0, :kernel_w // 2] = 1.
145 | mask[:, :kernel_c // 3, 0, kernel_w // 2] = 1.
146 | elif mask_type == 'H_Green':
147 | mask[:, :, 0, :kernel_w // 2] = 1.
148 | mask[:, :2 * kernel_c // 3, 0, kernel_w // 2] = 1.
149 | elif mask_type == 'H_Blue':
150 | mask[:, :, 0, :kernel_w // 2] = 1.
151 | mask[:, :, 0, kernel_w // 2] = 1.
152 | elif mask_type == 'HR':
153 | # For 3 by 3 kernel, this would be:
154 | # 1 0 0
155 | # Mask for horizontal stack in restricted gated conv
156 | mask[:, :, 0, :kernel_w // 2] = 1.
157 | elif mask_type == 'HR_Red':
158 | mask[:, :, 0, :kernel_w // 2] = 1.
159 | elif mask_type == 'HR_Green':
160 | mask[:, :, 0, :kernel_w // 2] = 1.
161 | mask[:, :kernel_c // 3, 0, kernel_w // 2] = 1.
162 | elif mask_type == 'HR_Blue':
163 | mask[:, :, 0, :kernel_w // 2] = 1.
164 | mask[:, :2 * kernel_c // 3, 0, kernel_w // 2] = 1.
165 | elif mask_type == 'V':
166 | # For 3 by 3 kernel, this would be:
167 | # 1 1 1
168 | # 1 1 1
169 | # 0 0 0
170 | mask[:, :, :kernel_h // 2 + 1, :] = 1.
171 | elif mask_type == 'VR':
172 | # For 3 by 3 kernel, this would be:
173 | # 1 1 1
174 | # 0 0 0
175 | # 0 0 0
176 | mask[:, :, :kernel_h // 2, :] = 1.
177 |
178 | # Register buffer adds a key to the state dict of the model. This will
179 | # track the attribute without registering it as a learnable parameter.
180 | # We require this since mask will be used in the forward pass.
181 | self.register_buffer('mask', mask)
182 |
183 | def forward(self, x):
184 | self.weight.data *= self.mask
185 | return super(MaskedConv2d, self).forward(x)
186 |
187 |
188 | class MaskedConvRGB(nn.Module):
189 | """
190 | Masked convolution with RGB channel splitting.
191 |
192 | Parameters
193 | ----------
194 | mask_type : string
195 | One of 'A', 'B', 'V' or 'H'.
196 |
197 | in_channels : int
198 | Must be divisible by 3
199 |
200 | out_channels : int
201 | Must be divisible by 3
202 |
203 | kernel_size : int or tuple of ints
204 |
205 | stride : int
206 |
207 | padding : int
208 |
209 | bias : bool
210 | If True adds a bias term to the convolution.
211 | """
212 | def __init__(self, mask_type, in_channels, out_channels, kernel_size,
213 | stride, padding, bias):
214 | super(MaskedConvRGB, self).__init__()
215 |
216 | self.conv_R = MaskedConv2d(mask_type + '_Red', in_channels,
217 | out_channels // 3, kernel_size,
218 | stride=stride, padding=padding, bias=bias)
219 | self.conv_G = MaskedConv2d(mask_type + '_Green', in_channels,
220 | out_channels // 3, kernel_size,
221 | stride=stride, padding=padding, bias=bias)
222 | self.conv_B = MaskedConv2d(mask_type + '_Blue', in_channels,
223 | out_channels // 3, kernel_size,
224 | stride=stride, padding=padding, bias=bias)
225 |
226 | def forward(self, x):
227 | out_red = self.conv_R(x)
228 | out_green = self.conv_G(x)
229 | out_blue = self.conv_B(x)
230 | return torch.cat([out_red, out_green, out_blue], dim=1)
231 |
232 |
233 | class GatedConvBlock(nn.Module):
234 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
235 | restricted=False):
236 | """Gated PixelCNN convolutional block. Note the number of input and
237 | output channels must be the same, unless restricted is True.
238 |
239 | Parameters
240 | ----------
241 | in_channels : int
242 |
243 | out_channels : int
244 |
245 | kernel_size : int
246 | Note this MUST be int and not tuple like in regular convs.
247 |
248 | stride : int
249 |
250 | padding : int
251 |
252 | restricted : bool
253 | If True, uses restricted masks, otherwise uses regular masks.
254 | """
255 | super(GatedConvBlock, self).__init__()
256 |
257 | assert type(kernel_size) is int
258 |
259 | self.restricted = restricted
260 |
261 | if restricted:
262 | vertical_mask = 'VR'
263 | horizontal_mask = 'HR'
264 | else:
265 | vertical_mask = 'V'
266 | horizontal_mask = 'H'
267 |
268 | self.vertical_conv = MaskedConv2d(vertical_mask, in_channels,
269 | 2 * out_channels, kernel_size,
270 | stride=stride, padding=padding)
271 |
272 | self.horizontal_conv = MaskedConv2d(horizontal_mask, in_channels,
273 | 2 * out_channels, (1, kernel_size),
274 | stride=stride, padding=(0, padding))
275 |
276 | self.vertical_to_horizontal = nn.Conv2d(2 * out_channels,
277 | 2 * out_channels, 1)
278 |
279 | self.horizontal_conv_2 = nn.Conv2d(out_channels, out_channels, 1)
280 |
281 | def forward(self, v_input, h_input):
282 | # Vertical stack
283 | v_conv = self.vertical_conv(v_input)
284 | v_out = gated_activation(v_conv)
285 | # Vertical to horizontal
286 | v_to_h = self.vertical_to_horizontal(v_conv)
287 | # Horizontal stack
288 | h_conv = self.horizontal_conv(h_input)
289 | h_conv_activation = gated_activation(h_conv + v_to_h)
290 | h_conv2 = self.horizontal_conv_2(h_conv_activation)
291 | if self.restricted:
292 | h_out = h_conv2
293 | else:
294 | h_out = h_conv2 + h_input
295 | return v_out, h_out
296 |
297 |
298 | class GatedConvBlockRGB(nn.Module):
299 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
300 | restricted=False):
301 | """Gated PixelCNN convolutional block for RGB images. Note the number of
302 | input and output channels must be the same, unless restricted is True.
303 |
304 | Parameters
305 | ----------
306 | in_channels : int
307 |
308 | out_channels : int
309 |
310 | kernel_size : int
311 | Note this MUST be int and not tuple like in regular convs.
312 |
313 | stride : int
314 |
315 | padding : int
316 |
317 | restricted : bool
318 | If True, uses restricted masks, otherwise uses regular masks.
319 | """
320 | super(GatedConvBlockRGB, self).__init__()
321 |
322 | assert type(kernel_size) is int
323 |
324 | self.restricted = restricted
325 | self.out_channels = out_channels
326 |
327 | if restricted:
328 | vertical_mask = 'VR'
329 | horizontal_mask = 'HR'
330 | else:
331 | vertical_mask = 'V'
332 | horizontal_mask = 'H'
333 |
334 | self.vertical_conv = MaskedConv2d(vertical_mask, in_channels,
335 | 2 * out_channels, kernel_size,
336 | stride=stride, padding=padding)
337 |
338 | self.horizontal_conv = MaskedConvRGB(horizontal_mask, in_channels,
339 | 2 * out_channels, (1, kernel_size),
340 | stride=stride,
341 | padding=(0, padding), bias=True)
342 |
343 | self.vertical_to_horizontal = nn.Conv2d(2 * out_channels,
344 | 2 * out_channels, 1)
345 |
346 | self.horizontal_conv_2 = MaskedConvRGB('B', out_channels, out_channels,
347 | (1, 1), stride=1, padding=0,
348 | bias=True)
349 |
350 | def forward(self, v_input, h_input):
351 | # Vertical stack
352 | v_conv = self.vertical_conv(v_input)
353 | v_out = gated_activation(v_conv)
354 | # Vertical to horizontal
355 | v_to_h = self.vertical_to_horizontal(v_conv)
356 | # Horizontal stack
357 | h_conv = self.horizontal_conv(h_input) + v_to_h
358 | # Gated activation must be applied for the R, G and B part of the
359 | # convolutional volume separately to avoid information from different
360 | # channels leaking
361 | channels_third = 2 * self.out_channels // 3
362 | h_conv_activation_R = gated_activation(h_conv[:, :channels_third])
363 | h_conv_activation_G = gated_activation(h_conv[:, channels_third:2 * channels_third])
364 | h_conv_activation_B = gated_activation(h_conv[:, 2 * channels_third:])
365 | h_conv_activation = torch.cat([h_conv_activation_R,
366 | h_conv_activation_G,
367 | h_conv_activation_B],
368 | dim=1)
369 | # 1 by 1 convolution on horizontal stack
370 | h_conv2 = self.horizontal_conv_2(h_conv_activation)
371 | if self.restricted:
372 | h_out = h_conv2
373 | else:
374 | h_out = h_conv2 + h_input
375 | return v_out, h_out
376 |
377 |
378 | def gated_activation(input_vol):
379 | """Applies a gated activation to the convolutional volume. Note that this
380 | activation divides the number of channels by 2.
381 |
382 | Parameters
383 | ----------
384 | input_vol : torch.Tensor
385 | Input convolutional volume. Shape (batch_size, channels, height, width)
386 | Note that number of channels must be even.
387 |
388 | Returns
389 | -------
390 | output_vol of shape (batch_size, channels // 2, height, width)
391 | """
392 | # Extract number of channels from input volume
393 | channels = input_vol.size(1)
394 | # Get activations for first and second half of volume
395 | tanh_activation = torch.tanh(input_vol[:, channels // 2:])
396 | sigmoid_activation = torch.sigmoid(input_vol[:, :channels // 2])
397 | return tanh_activation * sigmoid_activation
398 |
--------------------------------------------------------------------------------
/pixconcnn/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/pixconcnn/models/__init__.py
--------------------------------------------------------------------------------
/pixconcnn/models/cnn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from pixconcnn.layers import ResidualBlock
3 |
4 |
5 | class ResNet(nn.Module):
6 |     """ResNet (with regular unmasked convolutions) mapping conditional pixel
7 | inputs to logits.
8 |
9 | Parameters
10 | ----------
11 | img_size : tuple of ints
12 | Shape of input image. Note that since mask is appended to the masked
13 | image, the img_size given should be (num_channels + 1, height, width).
14 |         The extra channel is used to store the mask.
15 |
16 | num_colors : int
17 | Number of colors to quantize output into. Typically 256, but can be
18 | lower for e.g. binary images.
19 |
20 | num_filters : int
21 | Number of filters for each convolution layer in model.
22 |
23 | depth : int
24 | Number of layers of model. Must be at least 2 to have an input and
25 | output layer.
26 |
27 | filter_size : int
28 | Size of convolutional filters.
29 | """
30 | def __init__(self, img_size=(2, 32, 32), num_colors=256, num_filters=32,
31 | depth=17, filter_size=5):
32 | super(ResNet, self).__init__()
33 |
34 | self.depth = depth
35 | self.filter_size = filter_size
36 | self.padding = (filter_size - 1) // 2
37 | self.img_size = img_size
38 | self.num_channels = img_size[0] - 1 # Only output logits for color channels, not mask channel
39 | self.num_colors = num_colors
40 | self.num_filters = num_filters
41 |
42 | layers = [nn.Conv2d(self.num_channels + 1, self.num_filters,
43 | self.filter_size, stride=1, padding=self.padding)]
44 |
45 | for _ in range(self.depth - 2):
46 | layers.append(
47 | ResidualBlock(self.num_filters, self.num_filters,
48 | self.filter_size, stride=1, padding=self.padding)
49 | )
50 |
51 | # Final layer to output logits
52 | layers.append(
53 | nn.Conv2d(self.num_filters, self.num_colors * self.num_channels, 1)
54 | )
55 |
56 | self.img_to_pixel_logits = nn.Sequential(*layers)
57 |
58 | def forward(self, x):
59 | _, height, width = self.img_size
60 | # Shape (batch, output_channels, height, width)
61 | logits = self.img_to_pixel_logits(x)
62 | # Shape (batch, num_colors, channels, height, width)
63 | return logits.view(-1, self.num_colors, self.num_channels, height, width)
64 |
--------------------------------------------------------------------------------
/pixconcnn/models/gated_pixelcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from abc import ABCMeta, abstractmethod
5 | from pixconcnn.layers import GatedConvBlock, GatedConvBlockRGB, MaskedConvRGB
6 |
7 |
8 | class PixelCNNBaseClass(nn.Module):
9 | """Abstract class defining PixelCNN sampling which is the same for both
10 | single channel and RGB models.
11 | """
12 | __metaclass__ = ABCMeta
13 | @abstractmethod
14 | def forward(self):
15 | """Forward method is implemented in child class (GatedPixelCNN or
16 | GatedPixelCNNRGB)."""
17 | pass
18 |
19 | def sample(self, device, num_samples=16, temp=1., return_likelihood=False):
20 | """Generates samples from a GatedPixelCNN or GatedPixelCNNRGB.
21 |
22 | Parameters
23 | ----------
24 | device : torch.device instance
25 |
26 | num_samples : int
27 | Number of samples to generate
28 |
29 | temp : float
30 | Temperature of softmax distribution. Temperatures larger than 1
31 | make the distribution more uniform while temperatures lower than
32 | 1 make the distribution more peaky.
33 |
34 | return_likelihood : bool
35 | If True returns the log likelihood of the samples according to the
36 | model.
37 | """
38 | # Set model to evaluation mode
39 | self.eval()
40 |
41 | samples = torch.zeros(num_samples, *self.img_size)
42 | samples = samples.to(device)
43 | channels, height, width = self.img_size
44 |
45 | # Sample pixel intensities from a batch of probability distributions
46 | # for each pixel in each channel
47 | with torch.no_grad():
48 | for i in range(height):
49 | for j in range(width):
50 | for k in range(channels):
51 | logits = self.forward(samples)
52 | probs = F.softmax(logits / temp, dim=1)
53 | # Note that probs has shape
54 | # (batch, num_colors, channels, height, width)
55 | pixel_val = torch.multinomial(probs[:, :, k, i, j], 1)
56 | # The pixel intensities will be given by 0, 1, 2, ..., so
57 | # normalize these to be in 0 - 1 range as this is what the
58 | # model expects. Note that pixel_val has shape (batch, 1)
59 | # so remove last dimension
60 | samples[:, k, i, j] = pixel_val[:, 0].float() / (self.num_colors - 1)
61 |
62 | # Reset model to train mode
63 | self.train()
64 |
65 | # Unnormalize pixels
66 | samples = (samples * (self.num_colors - 1)).long()
67 |
68 | if return_likelihood:
69 | return samples.cpu(), self.log_likelihood(device, samples).cpu()
70 | else:
71 | return samples.cpu()
72 |
73 | def log_likelihood(self, device, samples):
74 | """Calculates log likelihood of samples under model.
75 |
76 | Parameters
77 | ----------
78 | device : torch.device instance
79 |
80 | samples : torch.Tensor
81 |             Batch of images. Shape (batch_size, num_channels, height, width).
82 |             Values should be integers in [0, self.num_colors - 1].
83 | """
84 | # Set model to evaluation mode
85 | self.eval()
86 |
87 | num_samples, num_channels, height, width = samples.size()
88 | log_probs = torch.zeros(num_samples)
89 | log_probs = log_probs.to(device)
90 |
91 | # Normalize samples before passing through model
92 | norm_samples = samples.float() / (self.num_colors - 1)
93 | # Calculate pixel probs according to the model
94 | logits = self.forward(norm_samples)
95 | # Note that probs has shape
96 | # (batch, num_colors, channels, height, width)
97 | probs = F.softmax(logits, dim=1)
98 |
99 | # Calculate probability of each pixel
100 | for i in range(height):
101 | for j in range(width):
102 | for k in range(num_channels):
103 | # Get the batch of true values at pixel (k, i, j)
104 | true_vals = samples[:, k, i, j]
105 |                     # Get probability assigned by model to each sample's true pixel
106 |                     probs_pixel = probs[torch.arange(num_samples, device=samples.device), true_vals, k, i, j]
107 | # Add log probs (1e-9 to avoid log(0))
108 | log_probs += torch.log(probs_pixel + 1e-9)
109 |
110 | # Reset model to train mode
111 | self.train()
112 |
113 | return log_probs
114 |
115 |
116 |
117 | class GatedPixelCNN(PixelCNNBaseClass):
118 | """Gated PixelCNN model for single channel images.
119 |
120 | Parameters
121 | ----------
122 | img_size : tuple of ints
123 | Shape of input image. E.g. (1, 32, 32)
124 |
125 | num_colors : int
126 | Number of colors to quantize output into. Typically 256, but can be
127 | lower for e.g. binary images.
128 |
129 | num_filters : int
130 | Number of filters for each convolution layer in model.
131 |
132 | depth : int
133 | Number of layers of model. Must be at least 2 to have an input and
134 | output layer.
135 |
136 | filter_size : int
137 | Size of convolutional filters.
138 | """
139 | def __init__(self, img_size=(1, 32, 32), num_colors=256, num_filters=64,
140 | depth=17, filter_size=5):
141 | super(GatedPixelCNN, self).__init__()
142 |
143 | self.depth = depth
144 | self.filter_size = filter_size
145 | self.padding = (filter_size - 1) // 2
146 | self.img_size = img_size
147 | self.num_channels = img_size[0]
148 | self.num_colors = num_colors
149 | self.num_filters = num_filters
150 |
151 | # First layer is restricted (to avoid self-dependency on input pixel)
152 | self.input_to_stacks = GatedConvBlock(self.num_channels,
153 | self.num_filters,
154 | self.filter_size,
155 | stride=1,
156 | padding=self.padding,
157 | restricted=True)
158 |
159 | # Subsequent layers are regular gated blocks for vertical and horizontal
160 | # stack
161 | gated_stacks = []
162 | # -2 since we are not counting first and last layer
163 | for _ in range(self.depth - 2):
164 | gated_stacks.append(
165 | GatedConvBlock(self.num_filters, self.num_filters,
166 | self.filter_size, stride=1, padding=self.padding)
167 | )
168 | self.gated_stacks = nn.ModuleList(gated_stacks)
169 |
170 | # Final layer to output logits
171 | self.stacks_to_pixel_logits = nn.Conv2d(self.num_filters, self.num_colors * self.num_channels, 1)
172 |
173 | def forward(self, x):
174 | # Restricted gated layer
175 | vertical, horizontal = self.input_to_stacks(x, x)
176 | # Iterate over gated layers
177 | for gated_block in self.gated_stacks:
178 | vertical, horizontal = gated_block(vertical, horizontal)
179 | # Output logits from the horizontal stack (i.e. the stack which receives
180 | # information from both horizontal and vertical stack)
181 | logits = self.stacks_to_pixel_logits(horizontal)
182 |
183 | # Reshape logits
184 | _, height, width = self.img_size
185 | # Shape (batch, output_channels, height, width) ->
186 | # (batch, num_colors, channels, height, width)
187 | return logits.view(-1, self.num_colors, self.num_channels, height, width)
188 |
189 |
190 | class GatedPixelCNNRGB(PixelCNNBaseClass):
191 | """Gated PixelCNN model for RGB images.
192 |
193 | Parameters
194 | ----------
195 | img_size : tuple of ints
196 | Shape of input image. E.g. (3, 32, 32)
197 |
198 | num_colors : int
199 | Number of colors to quantize output into. Typically 256, but can be
200 | lower for e.g. binary images.
201 |
202 | num_filters : int
203 | Number of filters for each convolution layer in model.
204 |
205 | depth : int
206 | Number of layers of model. Must be at least 2 to have an input and
207 | output layer.
208 |
209 | filter_size : int
210 | Size of convolutional filters.
211 | """
212 | def __init__(self, img_size=(3, 32, 32), num_colors=256, num_filters=64,
213 | depth=17, filter_size=5):
214 | super(GatedPixelCNNRGB, self).__init__()
215 |
216 | self.depth = depth
217 | self.filter_size = filter_size
218 |         self.padding = (filter_size - 1) // 2
219 | self.img_size = img_size
220 | self.num_channels = img_size[0]
221 | self.num_colors = num_colors
222 | self.num_filters = num_filters
223 |
224 | # First layer is restricted (to avoid self-dependency on input pixel)
225 | self.input_to_stacks = GatedConvBlockRGB(self.num_channels,
226 | self.num_filters,
227 | self.filter_size,
228 | stride=1,
229 | padding=self.padding,
230 | restricted=True)
231 |
232 | # Subsequent layers are regular gated blocks for vertical and horizontal
233 | # stack
234 | gated_stacks = []
235 | # -2 since we are not counting first and last layer
236 | for _ in range(self.depth - 2):
237 | gated_stacks.append(
238 | GatedConvBlockRGB(self.num_filters, self.num_filters,
239 | self.filter_size, stride=1,
240 | padding=self.padding)
241 | )
242 | self.gated_stacks = nn.ModuleList(gated_stacks)
243 |
244 | # Final layer to output logits
245 | self.stacks_to_pixel_logits = nn.Sequential(
246 | MaskedConvRGB('B', self.num_filters, 1023, (1, 1), stride=1, padding=0, bias=True),
247 | nn.ReLU(True),
248 | MaskedConvRGB('B', 1023, self.num_colors * self.num_channels, (1, 1), stride=1, padding=0, bias=True)
249 | )
250 |
251 | def forward(self, x):
252 | # Restricted gated layer
253 | vertical, horizontal = self.input_to_stacks(x, x)
254 | # Several gated layers
255 | for gated_block in self.gated_stacks:
256 | vertical, horizontal = gated_block(vertical, horizontal)
257 | # Output logits from the horizontal stack (i.e. the stack which receives
258 | # information from both horizontal and vertical stack)
259 | logits = self.stacks_to_pixel_logits(horizontal)
260 |
261 | # Reshape logits maintaining order between R, G and B channels
262 | _, height, width = self.img_size
263 | logits_red = logits[:, :self.num_colors]
264 | logits_green = logits[:, self.num_colors:2 * self.num_colors]
265 | logits_blue = logits[:, 2 * self.num_colors:]
266 | logits_red = logits_red.view(-1, self.num_colors, 1, height, width)
267 | logits_green = logits_green.view(-1, self.num_colors, 1, height, width)
268 | logits_blue = logits_blue.view(-1, self.num_colors, 1, height, width)
269 | # Shape (batch, num_colors, channels, height, width)
270 | return torch.cat([logits_red, logits_green, logits_blue], dim=2)
271 |
--------------------------------------------------------------------------------
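A small sampling sketch for the gated PixelCNN models above. The image size, number of colors and depth are illustrative assumptions kept tiny because sampling loops over every pixel.

import torch
from pixconcnn.models.gated_pixelcnn import GatedPixelCNN

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GatedPixelCNN(img_size=(1, 8, 8), num_colors=2, num_filters=16, depth=5).to(device)
# Autoregressively sample 4 binary 8x8 images together with their log likelihoods
samples, log_probs = model.sample(device, num_samples=4, return_likelihood=True)
print(samples.shape, log_probs.shape)  # torch.Size([4, 1, 8, 8]) torch.Size([4])

--------------------------------------------------------------------------------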
/pixconcnn/models/pixel_constrained.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class PixelConstrained(nn.Module):
7 | """Pixel Constrained CNN model.
8 |
9 | Parameters
10 | ----------
11 | prior_net : pixconcnn.models.gated_pixelcnn.GatedPixelCNN(RGB) instance
12 | Model defining the prior network.
13 |
14 | cond_net : pixconcnn.models.cnn.ResNet instance
15 | Model defining the conditioning network.
16 | """
17 | def __init__(self, prior_net, cond_net):
18 |
19 | super(PixelConstrained, self).__init__()
20 |
21 | self.prior_net = prior_net
22 | self.cond_net = cond_net
23 |
24 | def forward(self, x, x_cond):
25 | """
26 | x : torch.Tensor
27 | Image to predict logits for.
28 |
29 | x_cond : torch.Tensor
30 | Image containing pixels to be conditioned on. The pixels which are
31 | being conditioned on should retain their usual value, while all
32 | other pixels should be set to 0. In addition the mask should be
33 | appended to known pixels such that the shape of x_cond is
34 | (batch_size, channels + 1, height, width). For more details, see
35 | utils.masks.get_conditional_pixels function.
36 | """
37 | prior_logits = self.prior_net(x)
38 | cond_logits = self.cond_net(x_cond)
39 | logits = prior_logits + cond_logits
40 | return logits, prior_logits, cond_logits
41 |
42 | def sample(self, x_cond, temp=1., return_likelihood=False):
43 | """Generate conditional samples from the model. The number of samples
44 | generated will be equal to the batch size of x_cond.
45 |
46 | Parameters
47 | ----------
48 | x_cond : torch.Tensor
49 | Tensor containing the conditioning pixels. Should have shape
50 | (num_samples, channels + 1, height, width).
51 |
52 | temp : float
53 | Temperature of softmax distribution. Temperatures larger than 1
54 |             make the distribution more uniform while temperatures lower than
55 | 1 make the distribution more peaky.
56 |
57 | return_likelihood : bool
58 | If True returns the log likelihood of the samples according to the
59 | model.
60 | """
61 | # Set model to evaluation mode
62 | self.eval()
63 |
64 | # "Channel dimension" of x_cond has size channels + 1, so decrease this
65 | num_samples, channels_plus_mask, height, width = x_cond.size()
66 | channels = channels_plus_mask - 1
67 |
68 | # Samples to be generated
69 | samples = torch.zeros((num_samples, channels, height, width))
70 | # Move samples to same device as conditional tensor
71 | samples = samples.to(x_cond.device)
72 | num_colors = self.prior_net.num_colors
73 |
74 | # Sample pixel intensities from a batch of probability distributions
75 | # for each pixel in each channel
76 | with torch.no_grad():
77 | for i in range(height):
78 | for j in range(width):
79 | # The unmasked pixels are the ones where the mask has
80 | # nonzero value
81 | unmasked = x_cond[:, -1, i, j] > 0
82 | # If (i, j)th pixel is known for all images in batch (i.e.
83 | # if all values in unmasked are True), do not perform
84 | # forward pass of the model
85 | sample_pixel = True
86 | if unmasked.long().sum().item() == num_samples:
87 | sample_pixel = False
88 | for k in range(channels):
89 | if sample_pixel:
90 | logits, _, _ = self.forward(samples, x_cond)
91 | probs = F.softmax(logits / temp, dim=1)
92 | # Note that probs has shape
93 | # (batch, num_colors, channels, height, width)
94 | pixel_val = torch.multinomial(probs[:, :, k, i, j], 1)
95 | # The pixel intensities will be given by 0, 1, 2, ..., so
96 | # normalize these to be in 0 - 1 range as this is what the
97 | # model expects
98 | samples[:, k, i, j] = pixel_val[:, 0].float() / (num_colors - 1)
99 | # Set all the unmasked pixels to the value they are
100 | # conditioned on
101 | samples[:, k, i, j][unmasked] = x_cond[:, k, i, j][unmasked]
102 |
103 | # Reset model to train mode
104 | self.train()
105 |
106 | # Unnormalize pixels
107 | samples = (samples * (num_colors - 1)).long()
108 |
109 | if return_likelihood:
110 | return samples.cpu(), self.log_likelihood(samples, x_cond).cpu()
111 | else:
112 | return samples.cpu()
113 |
114 | def sample_unconditional(self, device, num_samples=16):
115 | """Samples from prior model without conditioning."""
116 | return self.prior_net.sample(device, num_samples)
117 |
118 | def log_likelihood(self, samples, x_cond):
119 | """Calculates log likelihood of samples under model.
120 |
121 | Parameters
122 | ----------
123 | samples : torch.Tensor
124 |             Batch of images. Shape (batch_size, num_channels, height, width).
125 | Values should be integers in [0, self.prior_net.num_colors - 1].
126 |
127 | x_cond : torch.Tensor
128 | Batch of conditional pixels.
129 |             Shape (batch_size, num_channels + 1, height, width).
130 | """
131 | # Set model to evaluation mode
132 | self.eval()
133 |
134 | num_samples, num_channels, height, width = samples.size()
135 | log_probs = torch.zeros(num_samples)
136 | log_probs = log_probs.to(x_cond.device)
137 |
138 | # Normalize samples before passing through model
139 | norm_samples = samples.float() / (self.prior_net.num_colors - 1)
140 | # Calculate pixel probs according to the model
141 | logits, _, _ = self.forward(norm_samples, x_cond)
142 | # Note that probs has shape
143 | # (batch, num_colors, channels, height, width)
144 | probs = F.softmax(logits, dim=1)
145 |
146 | # Calculate probability of each pixel
147 | for i in range(height):
148 | for j in range(width):
149 | # The unmasked pixels are the ones where the mask is nonzero
150 | unmasked = x_cond[:, -1, i, j] > 0
151 | for k in range(num_channels):
152 | # Get the batch of true values at pixel (k, i, j)
153 | true_vals = samples[:, k, i, j]
154 |                     # Get probability assigned by model to each sample's true pixel
155 |                     probs_pixel = probs[torch.arange(num_samples, device=samples.device), true_vals, k, i, j]
156 | # Conditional pixels are known, so set probs of these to 1
157 | probs_pixel[unmasked] = 1.
158 | # Add log probs (1e-9 to avoid log(0))
159 | log_probs += torch.log(probs_pixel + 1e-9)
160 |
161 | # Reset model to train mode
162 | self.train()
163 |
164 | return log_probs
165 |
--------------------------------------------------------------------------------
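A conditional-sampling sketch for the PixelConstrained model above, built with the utils.init_models and utils.masks helpers from this repository. The image size, number of colors, mask type and all-zero conditioning image are illustrative assumptions.

import torch
from utils.init_models import initialize_model
from utils.masks import MaskGenerator, get_conditional_pixels

img_size, num_colors = (1, 8, 8), 2
model = initialize_model(img_size, num_colors, depth=5, filter_size=5,
                         constrained=True, num_filters_prior=16, num_filters_cond=16)

# Condition on the top 4 rows of a (here all-zero) binary image
batch = torch.zeros(4, *img_size).long()
mask = MaskGenerator(img_size, ('top', 4)).get_masks(4)
cond_pixels = get_conditional_pixels(batch, mask, num_colors)  # mask appended as extra channel
samples = model.sample(cond_pixels)
print(samples.shape)  # torch.Size([4, 1, 8, 8])

--------------------------------------------------------------------------------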
/pixconcnn/training.py:
--------------------------------------------------------------------------------
1 | import imageio
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from utils.masks import get_conditional_pixels
6 | from torchvision.utils import make_grid
7 |
8 |
9 | class Trainer():
10 | """Class used to train PixelCNN models without conditioning.
11 |
12 | Parameters
13 | ----------
14 | model : pixconcnn.models.gated_pixelcnn.GatedPixelCNN(RGB) instance
15 |
16 | optimizer : one of optimizers in torch.optim
17 |
18 | device : torch.device instance
19 |
20 | record_loss_every : int
21 | Frequency (in iterations) with which to record loss.
22 |
23 | save_model_every : int
24 | Frequency (in epochs) with which to save model.
25 | """
26 | def __init__(self, model, optimizer, device, record_loss_every=10,
27 | save_model_every=5):
28 | self.device = device
29 | self.losses = {'total': []}
30 | self.mean_epoch_losses = []
31 | self.model = model
32 | self.optimizer = optimizer
33 | self.record_loss_every = record_loss_every
34 | self.save_model_every = save_model_every
35 | self.steps = 0
36 |
37 | def train(self, data_loader, epochs, directory='.'):
38 | """Trains model on the data given in data_loader.
39 |
40 | Parameters
41 | ----------
42 | data_loader : torch.utils.data.DataLoader instance
43 |
44 | epochs : int
45 | Number of epochs to train model for.
46 |
47 | directory : string
48 | Directory in which to store training progress, including trained
49 | models and samples generated at every epoch.
50 |
51 | Returns
52 | -------
53 | List of numpy arrays of generated images after each epoch.
54 | """
55 | # List of generated images after each epoch to track progress of model
56 | progress_imgs = []
57 |
58 | for epoch in range(epochs):
59 | print("\nEpoch {}/{}".format(epoch + 1, epochs))
60 | epoch_loss = self._train_epoch(data_loader)
61 | mean_epoch_loss = epoch_loss / len(data_loader)
62 | print("Epoch loss: {}".format(mean_epoch_loss))
63 | self.mean_epoch_losses.append(mean_epoch_loss)
64 |
65 | # Create a grid of model samples (limit number of samples by scaling
66 | # by number of pixels in output image; this is needed because of
67 | # GPU memory limitations)
68 | if self.model.img_size[-1] > 32:
69 |                 scale_to_32 = self.model.img_size[-1] // 32
70 |                 num_images = 64 // (scale_to_32 * scale_to_32)
71 | else:
72 | num_images = 64
73 | # Generate samples from model
74 | samples = self.model.sample(self.device, num_images)
75 | img_grid = make_grid(samples).cpu()
76 | # Convert to numpy with channels in imageio order
77 | img_grid = img_grid.float().numpy().transpose(1, 2, 0) / (self.model.num_colors - 1.)
78 | progress_imgs.append(img_grid)
79 | # Save generated image
80 | imageio.imsave(directory + '/training{}.png'.format(epoch), progress_imgs[-1])
81 | # Save model
82 | if epoch % self.save_model_every == 0:
83 | torch.save(self.model.state_dict(),
84 | directory + '/model{}.pt'.format(epoch))
85 |
86 | return progress_imgs
87 |
88 | def _train_epoch(self, data_loader):
89 | epoch_loss = 0
90 | for i, (batch, _) in enumerate(data_loader):
91 | batch_loss = self._train_iteration(batch)
92 | epoch_loss += batch_loss
93 | if i % 50 == 0:
94 | print("Iteration {}/{}, Loss: {}".format(i + 1,
95 | len(data_loader),
96 | batch_loss))
97 | return epoch_loss
98 |
99 | def _train_iteration(self, batch):
100 | self.optimizer.zero_grad()
101 |
102 | batch = batch.to(self.device)
103 |
104 | # Normalize batch, i.e. put it in 0 - 1 range before passing it through
105 | # the model
106 | norm_batch = batch.float() / (self.model.num_colors - 1)
107 | logits = self.model(norm_batch)
108 |
109 | loss = self._loss(logits, batch)
110 | loss.backward()
111 | self.optimizer.step()
112 |
113 | self.steps += 1
114 |
115 | return loss.item()
116 |
117 | def _loss(self, logits, batch):
118 | loss = F.cross_entropy(logits, batch)
119 |
120 | if self.steps % self.record_loss_every == 0:
121 | self.losses['total'].append(loss.item())
122 |
123 | return loss
124 |
125 |
126 | class PixelConstrainedTrainer():
127 | """Class used to train Pixel Constrained CNN models.
128 |
129 | Parameters
130 | ----------
131 | model : pixconcnn.models.pixel_constrained.PixelConstrained instance
132 |
133 | optimizer : one of optimizers in torch.optim
134 |
135 | device : torch.device instance
136 |
137 |     mask_generator : utils.masks.MaskGenerator instance
138 | Defines the masks used during training.
139 |
140 | weight_cond_logits_loss : float
141 | Weight on conditional logits in the loss (called alpha in the paper)
142 |
143 |     weight_prior_logits_loss : float
144 |         Weight on prior logits in the loss.
145 |
146 | record_loss_every : int
147 | Frequency (in iterations) with which to record loss.
148 |
149 | save_model_every : int
150 | Frequency (in epochs) with which to save model.
151 | """
152 | def __init__(self, model, optimizer, device, mask_generator,
153 | weight_cond_logits_loss=0., weight_prior_logits_loss=0.,
154 | record_loss_every=10, save_model_every=5):
155 | self.device = device
156 | self.losses = {'cond_logits': [], 'prior_logits': [], 'logits': [], 'total': []} # Keep track of losses
157 | self.mask_generator = mask_generator
158 | self.mean_epoch_losses = []
159 | self.model = model
160 | self.optimizer = optimizer
161 | self.record_loss_every = record_loss_every
162 | self.save_model_every = save_model_every
163 | self.steps = 0
164 | self.weight_cond_logits_loss = weight_cond_logits_loss
165 | self.weight_prior_logits_loss = weight_prior_logits_loss
166 |
167 | def train(self, data_loader, epochs, directory='.'):
168 | """
169 | Parameters
170 | ----------
171 | data_loader : torch.utils.data.DataLoader instance
172 |
173 | epochs : int
174 | Number of epochs to train model for.
175 |
176 | directory : string
177 | Directory in which to store training progress, including trained
178 | models and samples generated at every epoch.
179 |
180 | Returns
181 | -------
182 | List of numpy arrays of generated images after each epoch.
183 | """
184 | # List of generated images after each epoch to track progress of model
185 | progress_imgs = []
186 |
187 | # Use a fixed batch of images to test conditioning throughout training
188 | for batch, _ in data_loader:
189 | break
190 | test_mask = self.mask_generator.get_masks(batch.size(0))
191 | cond_pixels = get_conditional_pixels(batch, test_mask,
192 | self.model.prior_net.num_colors)
193 | # Number of images generated in a batch is limited by GPU memory.
194 | # For 32 by 32 this should be 64, for 64 by 64 this should be 16 and
195 | # for 128 by 128 this should be 4 etc.
196 | if self.model.prior_net.img_size[-1] > 32:
197 |             scale_to_32 = self.model.prior_net.img_size[-1] // 32
198 |             num_images = 64 // (scale_to_32 * scale_to_32)
199 | else:
200 | num_images = 64
201 | cond_pixels = cond_pixels[:num_images]
202 |
203 | cond_pixels = cond_pixels.to(self.device)
204 |
205 | for epoch in range(epochs):
206 | print("\nEpoch {}/{}".format(epoch + 1, epochs))
207 | epoch_loss = self._train_epoch(data_loader)
208 | mean_epoch_loss = epoch_loss / len(data_loader)
209 | print("Epoch loss: {}".format(mean_epoch_loss))
210 | self.mean_epoch_losses.append(mean_epoch_loss)
211 |
212 | # Create a grid of model samples
213 | samples = self.model.sample(cond_pixels)
214 | img_grid = make_grid(samples, nrow=8).cpu()
215 | # Convert to numpy with channels in imageio order
216 | img_grid = img_grid.float().numpy().transpose(1, 2, 0) / (self.model.prior_net.num_colors - 1.)
217 | progress_imgs.append(img_grid)
218 |
219 | # Save generated image
220 | imageio.imsave(directory + '/training{}.png'.format(epoch), progress_imgs[-1])
221 |
222 | # Save model
223 | if epoch % self.save_model_every == 0:
224 | torch.save(self.model.state_dict(),
225 | directory + '/model{}.pt'.format(epoch))
226 |
227 | return progress_imgs
228 |
229 | def _train_epoch(self, data_loader):
230 | epoch_loss = 0
231 | for i, (batch, _) in enumerate(data_loader):
232 | mask = self.mask_generator.get_masks(batch.size(0))
233 | batch_loss = self._train_iteration(batch, mask)
234 | epoch_loss += batch_loss
235 | if i % 50 == 0:
236 | print("Iteration {}/{}, Loss: {}".format(i + 1, len(data_loader),
237 | batch_loss))
238 | return epoch_loss
239 |
240 | def _train_iteration(self, batch, mask):
241 | self.optimizer.zero_grad()
242 |
243 | # Note that towards the end of the dataset the batch size may be smaller
244 | # if batch_size doesn't divide the number of examples. In that case,
245 | # slice mask so it has same shape as batch.
246 | cond_pixels = get_conditional_pixels(batch, mask[:batch.size(0)], self.model.prior_net.num_colors)
247 |
248 | batch = batch.to(self.device)
249 | cond_pixels = cond_pixels.to(self.device)
250 |
251 | # Normalize batch, i.e. put it in 0 - 1 range before passing it through
252 | # the model
253 | norm_batch = batch.float() / (self.model.prior_net.num_colors - 1)
254 | logits, prior_logits, cond_logits = self.model(norm_batch, cond_pixels)
255 |
256 | loss = self._loss(logits, prior_logits, cond_logits, batch)
257 | loss.backward()
258 | self.optimizer.step()
259 |
260 | self.steps += 1
261 |
262 | return loss.item()
263 |
264 | def _loss(self, logits, prior_logits, cond_logits, batch):
265 | logits_loss = F.cross_entropy(logits, batch)
266 | prior_logits_loss = F.cross_entropy(prior_logits, batch)
267 | cond_logits_loss = F.cross_entropy(cond_logits, batch)
268 | total_loss = logits_loss + \
269 | self.weight_cond_logits_loss * cond_logits_loss + \
270 | self.weight_prior_logits_loss * prior_logits_loss
271 |
272 | # Record losses
273 | if self.steps % self.record_loss_every == 0:
274 | self.losses['total'].append(total_loss.item())
275 | self.losses['cond_logits'].append(cond_logits_loss.item())
276 | self.losses['prior_logits'].append(prior_logits_loss.item())
277 | self.losses['logits'].append(logits_loss.item())
278 |
279 | return total_loss
280 |
--------------------------------------------------------------------------------
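A minimal end-to-end training sketch wiring the trainer above to the dataloaders, model initialization and mask generation utilities in this repository, on binarized MNIST. The learning rate, depth and number of epochs are illustrative assumptions, not the settings used for the released models.

import torch
from utils.dataloaders import mnist
from utils.init_models import initialize_model
from utils.masks import MaskGenerator
from pixconcnn.training import PixelConstrainedTrainer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_size, num_colors = (1, 28, 28), 2

train_loader, _ = mnist(batch_size=64, num_colors=num_colors, size=28)
model = initialize_model(img_size, num_colors, depth=8, filter_size=5, constrained=True,
                         num_filters_prior=32, num_filters_cond=32).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=4e-4)
mask_generator = MaskGenerator(img_size, ('random_blob', (4, (2, 7), 0.5)))

trainer = PixelConstrainedTrainer(model, optimizer, device, mask_generator,
                                  weight_cond_logits_loss=1.0)
# Saves sample grids and model checkpoints into the given directory each epoch
progress_imgs = trainer.train(train_loader, epochs=1, directory='.')

--------------------------------------------------------------------------------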
/requirements.txt:
--------------------------------------------------------------------------------
1 | imageio==2.4.1
2 | matplotlib==2.2.3
3 | numpy==1.11.2
4 | Pillow==6.2.0
5 | torch==0.4.1
6 | torchvision==0.2.1
7 |
--------------------------------------------------------------------------------
/trained_models/celeba/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "0",
3 | "dataset": "celeba",
4 | "resize": 32,
5 | "crop": 89,
6 | "grayscale": false,
7 | "batch_size": 64,
8 | "constrained": true,
9 | "num_colors": 32,
10 | "filter_size": 5,
11 | "depth": 18,
12 | "num_filters_cond": 66,
13 | "num_filters_prior": 66,
14 | "lr": 4e-4,
15 | "epochs": 80,
16 | "mask_descriptor": ["random_blob", [4, [2, 7], 0.5]],
17 | "num_conds": 4,
18 | "num_samples": 64,
19 | "weight_cond_logits_loss": 1.0,
20 | "weight_prior_logits_loss": 0.0
21 | }
22 |
--------------------------------------------------------------------------------
/trained_models/celeba/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/trained_models/celeba/model.pt
--------------------------------------------------------------------------------
/trained_models/mnist/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "0",
3 | "dataset": "mnist",
4 | "resize": 28,
5 | "crop": 28,
6 | "grayscale": true,
7 | "batch_size": 64,
8 | "constrained": true,
9 | "num_colors": 2,
10 | "filter_size": 5,
11 | "depth": 16,
12 | "num_filters_cond": 32,
13 | "num_filters_prior": 32,
14 | "lr": 4e-4,
15 | "epochs": 50,
16 | "mask_descriptor": ["random_blob", [4, [2, 7], 0.5]],
17 | "num_conds": 4,
18 | "num_samples": 64,
19 | "weight_cond_logits_loss": 1.0,
20 | "weight_prior_logits_loss": 0.0
21 | }
22 |
--------------------------------------------------------------------------------
/trained_models/mnist/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/trained_models/mnist/model.pt
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Schlumberger/pixel-constrained-cnn-pytorch/2116e9c8d6d1c2231817a55a8f7c9776431ae522/utils/__init__.py
--------------------------------------------------------------------------------
/utils/dataloaders.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from PIL import Image
3 | from torch.utils.data import DataLoader, Dataset
4 | from torchvision import datasets, transforms
5 |
6 |
7 | def mnist(batch_size=128, num_colors=256, size=28,
8 | path_to_data='../mnist_data'):
9 | """MNIST dataloader with (28, 28) images.
10 |
11 | Parameters
12 | ----------
13 | batch_size : int
14 |
15 | num_colors : int
16 | Number of colors to quantize images into. Typically 256, but can be
17 | lower for e.g. binary images.
18 |
19 | size : int
20 | Size (height and width) of each image. Default is 28 for no resizing.
21 |
22 | path_to_data : string
23 | Path to MNIST data files.
24 | """
25 | quantize = get_quantize_func(num_colors)
26 |
27 | all_transforms = transforms.Compose([
28 | transforms.Resize(size),
29 | transforms.ToTensor(),
30 | transforms.Lambda(lambda x: quantize(x))
31 | ])
32 |
33 | train_data = datasets.MNIST(path_to_data, train=True, download=True,
34 | transform=all_transforms)
35 | test_data = datasets.MNIST(path_to_data, train=False,
36 | transform=all_transforms)
37 |
38 | train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
39 | test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)
40 |
41 | return train_loader, test_loader
42 |
43 |
44 | def celeba(batch_size=128, num_colors=256, size=178, crop=178, grayscale=False,
45 | shuffle=True, path_to_data='../celeba_data'):
46 |     """CelebA dataloader with square images. Note that original CelebA images
47 |     have shape (218, 178); this dataloader center crops them to (178, 178)
48 |     by default.
49 |
50 | Parameters
51 | ----------
52 | batch_size : int
53 |
54 | num_colors : int
55 | Number of colors to quantize images into. Typically 256, but can be
56 | lower for e.g. binary images.
57 |
58 | size : int
59 | Size (height and width) of each image.
60 |
61 | crop : int
62 | Size of center crop. This crop happens *before* the resizing.
63 |
64 | grayscale : bool
65 | If True converts images to grayscale.
66 |
67 | shuffle : bool
68 | If True shuffles images.
69 |
70 | path_to_data : string
71 | Path to CelebA image files.
72 | """
73 | quantize = get_quantize_func(num_colors)
74 |
75 | if grayscale:
76 | transform = transforms.Compose([
77 | transforms.CenterCrop(crop),
78 | transforms.Resize(size),
79 | transforms.Grayscale(),
80 | transforms.ToTensor(),
81 | transforms.Lambda(lambda x: quantize(x))
82 | ])
83 | else:
84 | transform = transforms.Compose([
85 | transforms.CenterCrop(crop),
86 | transforms.Resize(size),
87 | transforms.ToTensor(),
88 | transforms.Lambda(lambda x: quantize(x))
89 | ])
90 | celeba_data = CelebADataset(path_to_data,
91 | transform=transform)
92 | celeba_loader = DataLoader(celeba_data, batch_size=batch_size,
93 | shuffle=shuffle)
94 | return celeba_loader
95 |
96 |
97 | class CelebADataset(Dataset):
98 | """CelebA dataset.
99 |
100 | Parameters
101 | ----------
102 | path_to_data : string
103 | Path to CelebA images.
104 |
105 | subsample : int
106 |         Only load every |subsample|-th image.
107 |
108 | transform : None or one of torchvision.transforms instances
109 | """
110 | def __init__(self, path_to_data, subsample=1, transform=None):
111 | self.img_paths = glob.glob(path_to_data + '/*')[::subsample]
112 | self.transform = transform
113 |
114 | def __len__(self):
115 | return len(self.img_paths)
116 |
117 | def __getitem__(self, idx):
118 | sample_path = self.img_paths[idx]
119 | sample = Image.open(sample_path)
120 |
121 | if self.transform:
122 | sample = self.transform(sample)
123 | # Since there are no labels, return 0 for the "label"
124 | return sample, 0
125 |
126 |
127 | def get_quantize_func(num_colors):
128 | """Returns a quantization function which can be used to set the number of
129 | colors in an image.
130 |
131 | Parameters
132 | ----------
133 | num_colors : int
134 | Number of bins to quantize image into. Should be between 2 and 256.
135 | """
136 | def quantize_func(batch):
137 | """Takes as input a float tensor with values in the 0 - 1 range and
138 | outputs a long tensor with integer values corresponding to each
139 | quantization bin.
140 |
141 | Parameters
142 | ----------
143 | batch : torch.Tensor
144 | Values in 0 - 1 range.
145 | """
146 | if num_colors == 2:
147 | return (batch > 0.5).long()
148 | else:
149 | return (batch * (num_colors - 1)).long()
150 |
151 | return quantize_func
152 |
--------------------------------------------------------------------------------
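A short sketch of loading a binarized MNIST batch with the dataloader above; the batch size is an illustrative assumption. Images come back as long tensors of quantized color indices rather than floats in the 0 - 1 range.

from utils.dataloaders import mnist

# num_colors=2 binarizes the images via the quantization transform
train_loader, test_loader = mnist(batch_size=16, num_colors=2, size=28)
imgs, _ = next(iter(train_loader))
print(imgs.shape, imgs.dtype, imgs.min().item(), imgs.max().item())
# torch.Size([16, 1, 28, 28]) torch.int64 0 1

--------------------------------------------------------------------------------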
/utils/init_models.py:
--------------------------------------------------------------------------------
1 | from pixconcnn.models.cnn import ResNet
2 | from pixconcnn.models.gated_pixelcnn import GatedPixelCNN, GatedPixelCNNRGB
3 | from pixconcnn.models.pixel_constrained import PixelConstrained
4 |
5 |
6 | def initialize_model(img_size, num_colors, depth, filter_size, constrained,
7 | num_filters_prior, num_filters_cond=1):
8 | """Helper function that initializes an appropriate model based on the
9 | input arguments.
10 |
11 | Parameters
12 | ----------
13 | img_size : tuple of ints
14 | Specifies size of image as (channels, height, width), e.g. (3, 32, 32).
15 | If img_size[0] == 1, returned model will be for grayscale images. If
16 | img_size[0] == 3, returned model will be for RGB images.
17 |
18 | num_colors : int
19 | Number of colors to quantize output into. Typically 256, but can be
20 | lower for e.g. binary images.
21 |
22 | depth : int
23 | Number of layers in model.
24 |
25 | filter_size : int
26 | Size of (square) convolutional filters of the model.
27 |
28 | constrained : bool
29 | If True returns a PixelConstrained model, otherwise returns a
30 | GatedPixelCNN or GatedPixelCNNRGB model.
31 |
32 | num_filters_prior : int
33 | Number of convolutional filters in each layer of prior network.
34 |
35 |     num_filters_cond : int (optional)
36 |         Required if using a PixelConstrained model. Number of convolutional
37 | filters in each layer of conditioning network.
38 | """
39 |
40 | if img_size[0] == 1:
41 | prior_net = GatedPixelCNN(img_size=img_size,
42 | num_colors=num_colors,
43 | num_filters=num_filters_prior,
44 | depth=depth,
45 | filter_size=filter_size)
46 | else:
47 | prior_net = GatedPixelCNNRGB(img_size=img_size,
48 | num_colors=num_colors,
49 | num_filters=num_filters_prior,
50 | depth=depth,
51 | filter_size=filter_size)
52 |
53 | if constrained:
54 | # Add extra color channel for mask in conditioning network
55 | cond_net = ResNet(img_size=(img_size[0] + 1,) + img_size[1:],
56 | num_colors=num_colors,
57 | num_filters=num_filters_cond,
58 | depth=depth,
59 | filter_size=filter_size)
60 | # Define a pixel constrained model based on prior and cond net
61 | return PixelConstrained(prior_net, cond_net)
62 | else:
63 | return prior_net
64 |
--------------------------------------------------------------------------------
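A sketch of how initialize_model above selects between the grayscale and RGB variants and, when constrained, pairs the prior network with a ResNet conditioning network. The sizes and filter counts are illustrative assumptions.

from utils.init_models import initialize_model

# Unconstrained grayscale prior: returns a GatedPixelCNN
gray_prior = initialize_model((1, 28, 28), num_colors=2, depth=8, filter_size=5,
                              constrained=False, num_filters_prior=32)

# Constrained RGB model: GatedPixelCNNRGB prior wrapped in a PixelConstrained model
rgb_constrained = initialize_model((3, 32, 32), num_colors=32, depth=8, filter_size=5,
                                   constrained=True, num_filters_prior=32, num_filters_cond=32)
print(type(gray_prior).__name__, type(rgb_constrained).__name__)
# GatedPixelCNN PixelConstrained

--------------------------------------------------------------------------------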
/utils/loading.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | from utils.dataloaders import mnist, celeba
4 | from utils.init_models import initialize_model
5 |
6 | def load_model(directory, model_version=None):
7 | """
8 | Returns model, data_loader and mask_descriptor of trained model.
9 |
10 | Parameters
11 | ----------
12 | directory : string
13 | Directory where experiment was saved. For example './experiment_1'.
14 |
15 | model_version : int or None
16 | If None loads final model, otherwise loads model version determined by
17 | int.
18 | """
19 | path_to_config = directory + '/config.json'
20 | if model_version is None:
21 | path_to_model = directory + '/model.pt'
22 | else:
23 | path_to_model = directory + '/model{}.pt'.format(model_version)
24 |
25 | # Open config file
26 | with open(path_to_config) as config_file:
27 | config = json.load(config_file)
28 |
29 | # Load dataset info
30 | dataset = config["dataset"]
31 | resize = config["resize"]
32 | crop = config["crop"]
33 | batch_size = config["batch_size"]
34 | num_colors = config["num_colors"]
35 | if "grayscale" in config:
36 | grayscale = config["grayscale"]
37 | else:
38 | grayscale = False
39 |
40 | # Get data
41 | if dataset == 'mnist':
42 | # Extract the test dataset (second argument)
43 | _, data_loader = mnist(batch_size, num_colors, resize)
44 | img_size = (1, resize, resize)
45 | elif dataset == 'celeba':
46 | data_loader = celeba(batch_size, num_colors, resize, crop, grayscale)
47 | if grayscale:
48 | img_size = (1, resize, resize)
49 | else:
50 | img_size = (3, resize, resize)
51 |
52 | # Load model info
53 | constrained = config["constrained"]
54 | depth = config["depth"]
55 | num_filters_cond = config["num_filters_cond"]
56 | num_filters_prior = config["num_filters_prior"]
57 | filter_size = config["filter_size"]
58 |
59 | model = initialize_model(img_size,
60 | num_colors,
61 | depth,
62 | filter_size,
63 | constrained,
64 | num_filters_prior,
65 | num_filters_cond)
66 |
67 | model.load_state_dict(torch.load(path_to_model, map_location=lambda storage, loc: storage))
68 |
69 | return model, data_loader, config["mask_descriptor"]
70 |
--------------------------------------------------------------------------------
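A sketch of loading the pretrained MNIST model shipped with the repository and completing partially observed digits. The mask choice and number of samples are illustrative assumptions; since load_model also builds a data loader, the MNIST data must be available at the dataloader's default path.

from utils.loading import load_model
from utils.masks import MaskGenerator, get_conditional_pixels

model, data_loader, mask_descriptor = load_model('./trained_models/mnist')

imgs, _ = next(iter(data_loader))
# Keep only the bottom 7 rows visible and let the model complete the rest
mask = MaskGenerator((1, 28, 28), ('bottom', 7)).get_masks(imgs.size(0))
cond_pixels = get_conditional_pixels(imgs, mask, model.prior_net.num_colors)
samples = model.sample(cond_pixels[:4])  # 4 completions, shape (4, 1, 28, 28)

--------------------------------------------------------------------------------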
/utils/masks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.utils.data import DataLoader, Dataset
4 | from torchvision import datasets, transforms
5 |
6 |
7 | class MaskGenerator():
8 | """Class used to generate masks. Can be used to create masks during
9 | training or to build various masks for generation.
10 |
11 | Parameters
12 | ----------
13 | img_size : tuple of ints
14 | E.g. (1, 28, 28) or (3, 64, 64)
15 |
16 | mask_descriptor : tuple of string and other
17 | Mask descriptors will be of the form (mask_type, mask_attribute).
18 | Allowed descriptors are:
19 | 1. ('random', None or int or tuple of ints): Generates random masks,
20 | where the position of visible pixels is selected uniformly
21 | at random over the image. If mask_attribute is None then the
22 | number of visible pixels is sampled uniformly between 1 and the
23 | total number of pixels in the image, otherwise it is fixed to
24 | the int given in mask_attribute. If mask_attribute is a tuple
25 | of ints, the number of visible pixels is sampled uniformly
26 | between the first int (lower bound) and the second int (upper
27 | bound).
28 | 2. ('bottom', int): Generates masks where only the bottom pixels are
29 | visible. The int determines the number of rows of the image to
30 | keep visible at the bottom.
31 | 3. ('top', int): Generates masks where only the top pixels are
32 | visible. The int determines the number of rows of the image to
33 | keep visible at the top.
34 | 4. ('center', int): Generates masks where only the central pixels
35 | are visible. The int determines the size in pixels of the sides
36 | of the square of visible pixels of the image.
37 | 5. ('edge', int): Generates masks where only the edge pixels of the
38 | image are visible. The int determines the thickness of the edges
39 | in pixels.
40 | 6. ('left', int): Generates masks where only the left pixels of the
41 | image are visible. The int determines the number of columns
42 | in pixels which are visible.
43 | 7. ('right', int): Generates masks where only the right pixels of
44 | the image are visible. The int determines the number of columns
45 | in pixels which are visible.
46 | 8. ('random_rect', (int, int)): Generates random rectangular masks
47 | where the maximum height and width of the rectangles are
48 | determined by the two ints.
49 | 9. ('random_blob', (int, (int, int), float)): Generates random
50 | blobs, where the number of blobs is determined by the first int,
51 | the range of iterations (see function definition) is determined
52 | by the tuple of ints and the threshold for making pixels visible
53 | is determined by the float.
54 | 10. ('random_blob_cache', (str, int)): Loads pregenerated random masks
55 | from a folder given by the string, using a batch_size given by
56 | the int.
57 | """
58 | def __init__(self, img_size, mask_descriptor):
59 | self.img_size = img_size
60 | self.num_pixels = img_size[1] * img_size[2]
61 | self.mask_type, self.mask_attribute = mask_descriptor
62 |
63 | if self.mask_type == 'random_blob_cache':
64 | dset = datasets.ImageFolder(self.mask_attribute[0],
65 | transform=transforms.Compose([transforms.Grayscale(),
66 | transforms.ToTensor()]))
67 | self.data_loader = DataLoader(dset, batch_size=self.mask_attribute[1], shuffle=True)
68 |
69 | def get_masks(self, batch_size):
70 | """Returns a tensor of shape (batch_size, 1, img_size[1], img_size[2])
71 | containing masks which were generated according to mask_type and
72 | mask_attribute.
73 |
74 | Parameters
75 | ----------
76 | batch_size : int
77 | """
78 | if self.mask_type == 'random':
79 | if self.mask_attribute is None:
80 | num_visibles = np.random.randint(1, self.num_pixels, size=batch_size)
81 | return batch_random_mask(self.img_size, num_visibles, batch_size)
82 | elif type(self.mask_attribute) == int:
83 | return batch_random_mask(self.img_size, self.mask_attribute, batch_size)
84 | else:
85 | lower_bound, upper_bound = self.mask_attribute
86 | num_visibles = np.random.randint(lower_bound, upper_bound, size=batch_size)
87 | return batch_random_mask(self.img_size, num_visibles, batch_size)
88 | elif self.mask_type == 'bottom':
89 | return batch_bottom_mask(self.img_size, self.mask_attribute, batch_size)
90 | elif self.mask_type == 'top':
91 | return batch_top_mask(self.img_size, self.mask_attribute, batch_size)
92 | elif self.mask_type == 'center':
93 | return batch_center_mask(self.img_size, self.mask_attribute, batch_size)
94 | elif self.mask_type == 'edge':
95 | return batch_edge_mask(self.img_size, self.mask_attribute, batch_size)
96 | elif self.mask_type == 'left':
97 | return batch_left_mask(self.img_size, self.mask_attribute, batch_size)
98 | elif self.mask_type == 'right':
99 | return batch_right_mask(self.img_size, self.mask_attribute, batch_size)
100 | elif self.mask_type == 'random_rect':
101 | return batch_random_rect_mask(self.img_size, self.mask_attribute[0],
102 | self.mask_attribute[1], batch_size)
103 | elif self.mask_type == 'random_blob':
104 | return batch_multi_random_blobs(self.img_size,
105 | self.mask_attribute[0],
106 | self.mask_attribute[1],
107 | self.mask_attribute[2], batch_size)
108 | elif self.mask_type == 'random_blob_cache':
109 | # Hacky way to get a single batch of data
110 | for mask_batch in self.data_loader:
111 | break
112 | # Zero index because Image folder returns (img, label) tuple
113 | return mask_batch[0]
114 |
115 |
116 | def single_random_mask(img_size, num_visible):
117 | """Returns random mask where 0 corresponds to a hidden value and 1 to a
118 | visible value. Shape of mask is same as img_size.
119 |
120 | Parameters
121 | ----------
122 | img_size : tuple of ints
123 | E.g. (1, 32, 32) for grayscale or (3, 64, 64) for RGB.
124 |
125 | num_visible : int
126 | Number of visible values.
127 | """
128 | _, height, width = img_size
129 | # Sample integers without replacement between 0 and the total number of
130 |     # pixels. The measurements array will then contain the pixel indices
131 | # corresponding to locations where pixels will be visible.
132 | measurements = np.random.choice(range(height * width), size=num_visible, replace=False)
133 | # Create empty mask
134 |     mask = torch.zeros(1, height, width)
135 | # Update mask with measurements
136 | for m in measurements:
137 | row = int(m / width)
138 | col = m % width
139 | mask[0, row, col] = 1
140 | return mask
141 |
142 |
143 | def batch_random_mask(img_size, num_visibles, batch_size, repeat=False):
144 | """Returns a batch of random masks.
145 |
146 | Parameters
147 | ----------
148 | img_size : see single_random_mask
149 |
150 | num_visibles : int or list of ints
151 | If int will keep the number of visible pixels in the masks fixed, if
152 | list will change the number of visible pixels depending on the values
153 | in the list. List should have length equal to batch_size.
154 |
155 | batch_size : int
156 | Number of masks to create.
157 |
158 | repeat : bool
159 | If True returns a batch of the same mask repeated batch_size times.
160 | """
161 | # Mask should have same shape as image, but only 1 channel
162 | mask_batch = torch.zeros(batch_size, 1, *img_size[1:])
163 | if repeat:
164 | if not type(num_visibles) == int:
165 | raise RuntimeError("num_visibles must be an int if used with repeat=True. {} was provided instead.".format(type(num_visibles)))
166 | single_mask = single_random_mask(img_size, num_visibles)
167 | for i in range(batch_size):
168 | mask_batch[i] = single_mask
169 | else:
170 | if type(num_visibles) == int:
171 | for i in range(batch_size):
172 | mask_batch[i] = single_random_mask(img_size, num_visibles)
173 | else:
174 | for i in range(batch_size):
175 | mask_batch[i] = single_random_mask(img_size, num_visibles[i])
176 | return mask_batch
177 |
178 |
179 | def batch_bottom_mask(img_size, num_rows, batch_size):
180 | """Masks all the output except the |num_rows| lowest rows (in the height
181 | dimension).
182 |
183 | Parameters
184 | ----------
185 | img_size : see single_random_mask
186 |
187 | num_rows : int
188 | Number of rows from bottom which will be visible.
189 |
190 | batch_size : int
191 | Number of masks to create.
192 | """
193 | mask = torch.zeros(batch_size, 1, *img_size[1:])
194 | mask[:, :, -num_rows:, :] = 1.
195 | return mask
196 |
197 |
198 | def batch_top_mask(img_size, num_rows, batch_size):
199 | """Masks all the output except the |num_rows| highest rows (in the height
200 | dimension).
201 |
202 | Parameters
203 | ----------
204 | img_size : see single_random_mask
205 |
206 | num_rows : int
207 | Number of rows from top which will be visible.
208 |
209 | batch_size : int
210 | Number of masks to create.
211 | """
212 | mask = torch.zeros(batch_size, 1, *img_size[1:])
213 | mask[:, :, :num_rows, :] = 1.
214 | return mask
215 |
216 |
217 | def batch_center_mask(img_size, num_pixels, batch_size):
218 | """Masks all the output except the num_pixels by num_pixels central square
219 | of the image.
220 |
221 | Parameters
222 | ----------
223 | img_size : see single_random_mask
224 |
225 | num_pixels : int
226 | Should be even. If not even, num_pixels will be replaced with
227 | num_pixels - 1.
228 |
229 | batch_size : int
230 | Number of masks to create.
231 | """
232 | mask = torch.zeros(batch_size, 1, *img_size[1:])
233 | _, height, width = img_size
234 | lower_height = int(height / 2 - num_pixels / 2)
235 | upper_height = int(height / 2 + num_pixels / 2)
236 | lower_width = int(width / 2 - num_pixels / 2)
237 | upper_width = int(width / 2 + num_pixels / 2)
238 | mask[:, :, lower_height:upper_height, lower_width:upper_width] = 1.
239 | return mask
240 |
241 |
242 | def batch_edge_mask(img_size, num_pixels, batch_size):
243 | """Masks all the output except the num_pixels thick edge of the image.
244 |
245 | Parameters
246 | ----------
247 | img_size : see single_random_mask
248 |
249 | num_pixels : int
250 | Should be smaller than min(height / 2, width / 2).
251 |
252 | batch_size : int
253 | Number of masks to create.
254 | """
255 | mask = torch.zeros(batch_size, 1, *img_size[1:])
256 | mask[:, :, :num_pixels, :] = 1.
257 | mask[:, :, -num_pixels:, :] = 1.
258 | mask[:, :, :, :num_pixels] = 1.
259 | mask[:, :, :, -num_pixels:] = 1.
260 | return mask
261 |
262 |
263 | def batch_left_mask(img_size, num_cols, batch_size):
264 | """Masks all the pixels except the left side of the image.
265 |
266 | Parameters
267 | ----------
268 | img_size : see single_random_mask
269 |
270 | num_cols : int
271 | Number of columns of the left side of the image to remain visible.
272 |
273 | batch_size : int
274 | Number of masks to create.
275 | """
276 | mask = torch.zeros(batch_size, 1, *img_size[1:])
277 | mask[:, :, :, :num_cols] = 1.
278 | return mask
279 |
280 |
281 | def batch_right_mask(img_size, num_cols, batch_size):
282 | """Masks all the pixels except the right side of the image.
283 |
284 | Parameters
285 | ----------
286 | img_size : see single_random_mask
287 |
288 | num_cols : int
289 | Number of columns of the right side of the image to remain visible.
290 |
291 | batch_size : int
292 | Number of masks to create.
293 | """
294 | mask = torch.zeros(batch_size, 1, *img_size[1:])
295 | mask[:, :, :, -num_cols:] = 1.
296 | return mask
297 |
298 |
299 | def random_rect_mask(img_size, max_height, max_width):
300 | """Returns a mask with a random rectangle of visible pixels.
301 |
302 | Parameters
303 | ----------
304 | img_size : see single_random_mask
305 |
306 | max_height : int
307 | Maximum height of randomly sampled rectangle.
308 |
309 | max_width : int
310 | Maximum width of randomly sampled rectangle.
311 | """
312 | mask = torch.zeros(1, *img_size[1:])
313 |     _, img_height, img_width = img_size
314 | # Sample top left corner of unmasked rectangle
315 | top_left = np.random.randint(0, img_height - 1), np.random.randint(0, img_width - 1)
316 | # Sample height of rectangle
317 | # This is a number between 1 and the max_height parameter. If the top left corner
318 | # is too close to the bottom of the image, make sure the rectangle doesn't exceed
319 | # this
320 | rect_height = np.random.randint(1, min(max_height, img_height - top_left[0]))
321 | # Sample width of rectangle
322 | rect_width = np.random.randint(1, min(max_width, img_width - top_left[1]))
323 | # Set visible pixels
324 | bottom_right = top_left[0] + rect_height, top_left[1] + rect_width
325 | mask[0, top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]] = 1.
326 | return mask
327 |
328 |
329 | def batch_random_rect_mask(img_size, max_height, max_width, batch_size):
330 | """Returns a batch of masks with random rectangles of visible pixels.
331 |
332 | Parameters
333 | ----------
334 | img_size : see single_random_mask
335 |
336 | max_height : int
337 | Maximum height of randomly sampled rectangle.
338 |
339 | max_width : int
340 | Maximum width of randomly sampled rectangle.
341 |
342 | batch_size : int
343 | Number of masks to create.
344 | """
345 | mask = torch.zeros(batch_size, 1, *img_size[1:])
346 | for i in range(batch_size):
347 | mask[i] = random_rect_mask(img_size, max_height, max_width)
348 | return mask
349 |
350 |
351 | def random_blob(img_size, num_iter, threshold, fixed_init=None):
352 | """Generates masks with random connected blobs.
353 |
354 | Parameters
355 | ----------
356 | img_size : see single_random_mask
357 |
358 | num_iter : int
359 | Number of iterations to expand random blob for.
360 |
361 | threshold : float
362 | Number between 0 and 1. Probability of keeping a pixel hidden.
363 |
364 | fixed_init : tuple of ints or None
365 | If fixed_init is None, central position of blob will be sampled
366 | randomly, otherwise expansion will start from fixed_init. E.g.
367 | fixed_init = (6, 12) will start the expansion from pixel in row 6,
368 | column 12.
369 | """
370 | _, img_height, img_width = img_size
371 | # Defines the shifts around the central pixel which may be unmasked
372 | neighbors = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]
373 | if fixed_init is None:
374 | # Sample random initial position
375 | init_pos = np.random.randint(0, img_height - 1), np.random.randint(0, img_width - 1)
376 | else:
377 | init_pos = (fixed_init[0], fixed_init[1])
378 | # Initialize mask and make init_pos visible
379 | mask = torch.zeros(1, 1, *img_size[1:])
380 | mask[0, 0, init_pos[0], init_pos[1]] = 1.
381 | # Initialize the list of seed positions
382 | seed_positions = [init_pos]
383 | # Randomly expand blob
384 | for i in range(num_iter):
385 | next_seed_positions = []
386 | for seed_pos in seed_positions:
387 | # Sample probability that neighboring pixel will be visible
388 | prob_visible = np.random.rand(len(neighbors))
389 | for j, neighbor in enumerate(neighbors):
390 | if prob_visible[j] > threshold:
391 | current_h, current_w = seed_pos
392 | shift_h, shift_w = neighbor
393 | # Ensure new height stays within image boundaries
394 | new_h = max(min(current_h + shift_h, img_height - 1), 0)
395 | # Ensure new width stays within image boundaries
396 | new_w = max(min(current_w + shift_w, img_width - 1), 0)
397 | # Update mask
398 | mask[0, 0, new_h, new_w] = 1.
399 | # Add new position to list of seeds
400 | next_seed_positions.append((new_h, new_w))
401 | seed_positions = next_seed_positions
402 | return mask
403 |
404 |
405 | def multi_random_blobs(img_size, max_num_blobs, iter_range, threshold):
406 | """Generates masks with multiple random connected blobs.
407 |
408 | Parameters
409 | ----------
410 | max_num_blobs : int
411 | Maximum number of blobs. Number of blobs will be sampled between 1 and
412 | max_num_blobs
413 |
414 | iter_range : (int, int)
415 | Lower and upper bound on number of iterations to be used for each blob.
416 | This will be sampled for each blob.
417 |
418 | threshold : float
419 | Number between 0 and 1. Probability of keeping a pixel hidden.
420 | """
421 | mask = torch.zeros(1, 1, *img_size[1:])
422 | # Sample number of blobs
423 | num_blobs = np.random.randint(1, max_num_blobs + 1)
424 | for _ in range(num_blobs):
425 | num_iter = np.random.randint(iter_range[0], iter_range[1])
426 | mask += random_blob(img_size, num_iter, threshold)
427 | mask[mask > 0] = 1.
428 | return mask
429 |
430 |
431 | def batch_multi_random_blobs(img_size, max_num_blobs, iter_range, threshold,
432 | batch_size):
433 | """Generates batch of masks with multiple random connected blobs."""
434 | mask = torch.zeros(batch_size, 1, *img_size[1:])
435 | for i in range(batch_size):
436 | mask[i] = multi_random_blobs(img_size, max_num_blobs, iter_range, threshold)
437 | return mask
438 |
439 |
440 | def get_conditional_pixels(batch, mask, num_colors):
441 | """Returns conditional pixels obtained from masking the data in batch with
442 | mask and appending the mask. E.g. if the input has size (N, C, H, W)
443 | then the output will have size (N, C + 1, H, W) i.e. the mask is appended
444 | as an extra color channel.
445 |
446 | Parameters
447 | ----------
448 | batch : torch.Tensor
449 | Batch of data as returned by a DataLoader, i.e. unnormalized.
450 |         Shape (num_examples, num_channels, height, width)
451 |
452 | mask : torch.Tensor
453 | Mask as returned by MaskGenerator.get_masks.
454 |         Shape (num_examples, 1, height, width)
455 |
456 | num_colors : int
457 | Number of colors image is quantized to.
458 | """
459 |     batch_size, channels, height, width = batch.size()
460 | # Add extra channel to keep mask
461 | cond_pixels = torch.zeros((batch_size, channels + 1, height, width))
462 | # Mask batch to only show visible pixels
463 | cond_pixels[:, :channels, :, :] = mask * batch.float()
464 | # Add mask scaled by number of colors in last channel dimension
465 | cond_pixels[:, -1:, :, :] = mask * (num_colors - 1)
466 | # Normalize conditional pixels to be in 0 - 1 range
467 | return cond_pixels / (num_colors - 1)
468 |
469 |
470 | def get_repeated_conditional_pixels(batch, mask, num_colors, num_reps):
471 | """Returns repeated conditional pixels.
472 |
473 | Parameters
474 | ----------
475 | batch : torch.Tensor
476 |         Shape (1, num_channels, height, width)
477 |
478 | mask : torch.Tensor
479 |         Shape (1, 1, height, width)
480 |
481 | num_colors : int
482 | Number of colors image is quantized to.
483 |
484 | num_reps : int
485 | Number of times the conditional pixels will be repeated
486 | """
487 | assert batch.size(0) == 1
488 | assert mask.size(0) == 1
489 | cond_pixels = get_conditional_pixels(batch, mask, num_colors)
490 | return cond_pixels.expand(num_reps, *cond_pixels.size()[1:])
491 |
--------------------------------------------------------------------------------
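A brief shape-check sketch for MaskGenerator and get_conditional_pixels above, using the same random blob descriptor as the provided configs. The batch contents are random stand-ins.

import torch
from utils.masks import MaskGenerator, get_conditional_pixels

img_size, num_colors = (3, 32, 32), 32
mask_gen = MaskGenerator(img_size, ('random_blob', (4, (2, 7), 0.5)))
masks = mask_gen.get_masks(8)  # shape (8, 1, 32, 32), values in {0, 1}
batch = torch.randint(0, num_colors, (8, 3, 32, 32)).long()
cond_pixels = get_conditional_pixels(batch, masks, num_colors)
print(masks.shape, cond_pixels.shape)  # (8, 1, 32, 32) and (8, 4, 32, 32)
print(cond_pixels.min().item(), cond_pixels.max().item())  # normalized to the 0 - 1 range

--------------------------------------------------------------------------------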
/utils/plots.py:
--------------------------------------------------------------------------------
1 | from matplotlib.pyplot import get_cmap
2 | import numpy as np
3 | import torch
4 |
5 |
6 | def probs_and_conditional_plot(img, probs, mask, cmap='plasma'):
7 | """Creates a plot of pixel probabilities with the conditional pixels from
8 | the original image overlayed. Note this function only works for binary
9 | images.
10 |
11 | Parameters
12 | ----------
13 | img : torch.Tensor
14 | Shape (1, H, W)
15 |
16 | probs : torch.Tensor
17 | Shape (H, W). Should be the probability of a pixel being 1.
18 |
19 | mask : torch.Tensor
20 | Shape (H, W)
21 |
22 | cmap : string
23 | Colormap to use for probs (as defined in matplotlib.plt,
24 | e.g. 'jet', 'viridis', ...)
25 | """
26 | # Define function to convert array to colormap rgb image
27 | # The colorscale has a min value of 0 and a max value of 1 by default
28 | # (i.e. it does not rescale depending on the value of the probs)
29 | convert_to_cmap = get_cmap(cmap)
30 | # Create output image from colormap of probs
31 | rgba_probs = convert_to_cmap(probs.numpy())
32 | output_img = np.delete(rgba_probs, 3, 2) # Convert to RGB
33 | # Overlay unmasked parts of original image over probs
34 | np_mask = mask.numpy().astype(bool) # Convert mask to boolean numpy array
35 | np_img = img.numpy()[0] # Convert img to grayscale numpy img
36 | output_img[:, :, 0][np_mask] = np_img[np_mask]
37 | output_img[:, :, 1][np_mask] = np_img[np_mask]
38 | output_img[:, :, 2][np_mask] = np_img[np_mask]
39 | # Convert numpy image to torch tensor
40 | return torch.Tensor(output_img.transpose(2, 0, 1))
41 |
42 |
43 | def uncertainty_plot(samples, log_probs, cmap='plasma'):
44 | """Sorts samples by their log likelihoods and creates an image representing
45 | the log likelihood of each sample as a box with color and size proportional
46 | to the log likelihood.
47 |
48 | Parameters
49 | ----------
50 | samples : torch.Tensor
51 | Shape (N, C, H, W)
52 |
53 | log_probs : torch.Tensor
54 | Shape (N,)
55 |
56 | cmap : string
57 | Colormap to use for likelihoods (as defined in matplotlib.plt,
58 | e.g. 'jet', 'viridis', ...)
59 | """
60 | # Sorted by negative log likelihood
61 | sorted_nll, sorted_indices = torch.sort(-log_probs)
62 | sorted_samples = samples[sorted_indices]
63 | # Normalize log likelihoods to be in 0 - 1 range
64 | min_ll, max_ll = (-sorted_nll).min(), (-sorted_nll).max()
65 | normalized_likelihoods = ((-sorted_nll) - min_ll) / (max_ll - min_ll)
66 |
67 | # For each sample draw an image with a box proportional in size and
68 | # color to the log likelihood value
69 | num_samples, _, height, width = samples.size()
70 | # Initialize white background images on which to draw boxes
71 | ll_images = torch.ones(num_samples, 3, height, width)
72 | # Specify box sizes
73 | lower_width = width // 2 - width // 5
74 | upper_width = width // 2 + width // 5
75 | max_box_height = height
76 | min_box_height = 1
77 | # Generate colors for the boxes
78 | convert_to_cmap = get_cmap(cmap)
79 | # Remove alpha channel from colormap
80 | colors = convert_to_cmap(normalized_likelihoods.numpy())[:, :-1]
81 |
82 | # Fill out images with boxes
83 | for i in range(num_samples):
84 | norm_ll = normalized_likelihoods[i].item()
85 | box_height = int(min_box_height + (max_box_height - min_box_height) * norm_ll)
86 | box_color = colors[i]
87 | for j in range(3):
88 | ll_images[i, j, height - box_height:height, lower_width:upper_width] = box_color[j]
89 |
90 | return sorted_samples, ll_images
91 |
--------------------------------------------------------------------------------
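A short sketch of uncertainty_plot above: it sorts samples by likelihood and produces bar images that can be interleaved with the samples in a grid. The random stand-ins below take the place of the outputs of model.sample(..., return_likelihood=True).

import torch
from torchvision.utils import make_grid
from utils.plots import uncertainty_plot

samples = torch.rand(8, 3, 32, 32)   # stand-in for generated samples
log_probs = torch.randn(8)           # stand-in for their log likelihoods

sorted_samples, ll_images = uncertainty_plot(samples, log_probs, cmap='plasma')
# First 8 grid cells: samples sorted by likelihood; last 8: likelihood bars
grid = make_grid(torch.cat([sorted_samples, ll_images]), nrow=8)
print(sorted_samples.shape, ll_images.shape, grid.shape)

--------------------------------------------------------------------------------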