├── LICENSE ├── README.md ├── data.py ├── docs ├── _config.yml ├── images │ ├── 1-7.jpeg │ ├── 3-8.jpeg │ ├── 4-9.jpeg │ ├── 5-6.jpeg │ ├── best_generated.jpeg │ ├── blindspot.jpeg │ ├── blindspot.png │ ├── generated_plot.gif │ └── mask.png └── index.md ├── experiment.py ├── model.py ├── train.py └── vis.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional PixelCNN 2 | 3 | A PyTorch implementation of conditional PixelCNNs with code for reproducing visualizations found at the project page: https://jrbtaylor.github.io/conditional-pixelcnn/ 4 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Written by Jason Taylor 2018-2019 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision import datasets, transforms 10 | from torchvision.datasets.mnist import read_label_file, read_image_file 11 | 12 | 13 | def onehot(n_classes): 14 | def onehot_fcn(x): 15 | y = np.zeros((n_classes), dtype='float32') 16 | y[x] = 1 17 | return y 18 | return onehot_fcn 19 | 20 | 21 | def augment(rotate=5): 22 | return transforms.Compose([transforms.RandomRotation(rotate), 23 | transforms.ToTensor()]) 24 | 25 | 26 | def loader(dataset, batch_size, 
n_workers=8): 27 | assert dataset.lower() in ['mnist','emnist','fashionmnist'] 28 | 29 | loader_args = {'batch_size':batch_size, 30 | 'num_workers':n_workers, 31 | 'pin_memory':True} 32 | datapath = os.path.join(os.getenv('HOME'), 'data', dataset.lower()) 33 | dataset_args = {'root':datapath, 34 | 'download':True, 35 | 'transform':transforms.ToTensor()} 36 | 37 | if dataset.lower()=='mnist': 38 | dataset_init = datasets.MNIST 39 | n_classes = 10 40 | elif dataset.lower()=='emnist': 41 | dataset_init = EMNIST 42 | n_classes = 37 43 | dataset_args.update({'split':'letters'}) 44 | else: 45 | dataset_init = datasets.FashionMNIST 46 | n_classes = 10 47 | onehot_fcn = onehot(n_classes) 48 | dataset_args.update({'target_transform':onehot_fcn}) 49 | 50 | val_loader = torch.utils.data.DataLoader( 51 | dataset_init(train=False, **dataset_args), shuffle=False, **loader_args) 52 | 53 | dataset_args['transform'] = augment() 54 | train_loader = torch.utils.data.DataLoader( 55 | dataset_init(train=True, **dataset_args), shuffle=True, **loader_args) 56 | 57 | 58 | return train_loader, val_loader, onehot_fcn, n_classes 59 | 60 | 61 | # Note: Can't build master ver of pytorch/torchvision, so copied this here 62 | class EMNIST(datasets.MNIST): 63 | """`EMNIST `_ Dataset. 64 | Args: 65 | root (string): Root directory of dataset where ``processed/training.pt`` 66 | and ``processed/test.pt`` exist. 67 | split (string): The dataset has 6 different splits: ``byclass``, ``bymerge``, 68 | ``balanced``, ``letters``, ``digits`` and ``mnist``. This argument specifies 69 | which one to use. 70 | train (bool, optional): If True, creates dataset from ``training.pt``, 71 | otherwise from ``test.pt``. 72 | download (bool, optional): If true, downloads the dataset from the internet and 73 | puts it in root directory. If dataset is already downloaded, it is not 74 | downloaded again. 
75 | transform (callable, optional): A function/transform that takes in an PIL image 76 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 77 | target_transform (callable, optional): A function/transform that takes in the 78 | target and transforms it. 79 | """ 80 | url = 'http://biometrics.nist.gov/cs_links/EMNIST/gzip.zip' 81 | splits = ('byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist') 82 | 83 | def __init__(self, root, split, **kwargs): 84 | if split not in self.splits: 85 | raise ValueError('Split "{}" not found. Valid splits are: {}'.format( 86 | split, ', '.join(self.splits), 87 | )) 88 | self.split = split 89 | self.training_file = self._training_file(split) 90 | self.test_file = self._test_file(split) 91 | super(EMNIST, self).__init__(root, **kwargs) 92 | 93 | def _training_file(self, split): 94 | return 'training_{}.pt'.format(split) 95 | 96 | def _test_file(self, split): 97 | return 'test_{}.pt'.format(split) 98 | 99 | def download(self): 100 | """Download the EMNIST data if it doesn't exist in processed_folder already.""" 101 | import errno 102 | from six.moves import urllib 103 | import gzip 104 | import shutil 105 | import zipfile 106 | 107 | if self._check_exists(): 108 | return 109 | 110 | # download files 111 | try: 112 | os.makedirs(os.path.join(self.root, self.raw_folder)) 113 | os.makedirs(os.path.join(self.root, self.processed_folder)) 114 | except OSError as e: 115 | if e.errno == errno.EEXIST: 116 | pass 117 | else: 118 | raise 119 | 120 | print('Downloading ' + self.url) 121 | data = urllib.request.urlopen(self.url) 122 | filename = self.url.rpartition('/')[2] 123 | raw_folder = os.path.join(self.root, self.raw_folder) 124 | file_path = os.path.join(raw_folder, filename) 125 | with open(file_path, 'wb') as f: 126 | f.write(data.read()) 127 | 128 | print('Extracting zip archive') 129 | with zipfile.ZipFile(file_path) as zip_f: 130 | zip_f.extractall(raw_folder) 131 | os.unlink(file_path) 132 | gzip_folder = 
os.path.join(raw_folder, 'gzip') 133 | for gzip_file in os.listdir(gzip_folder): 134 | if gzip_file.endswith('.gz'): 135 | print('Extracting ' + gzip_file) 136 | with open(os.path.join(raw_folder, gzip_file.replace('.gz', '')), 'wb') as out_f, \ 137 | gzip.GzipFile(os.path.join(gzip_folder, gzip_file)) as zip_f: 138 | out_f.write(zip_f.read()) 139 | shutil.rmtree(gzip_folder) 140 | 141 | # process and save as torch files 142 | for split in self.splits: 143 | print('Processing ' + split) 144 | training_set = ( 145 | read_image_file(os.path.join(raw_folder, 'emnist-{}-train-images-idx3-ubyte'.format(split))), 146 | read_label_file(os.path.join(raw_folder, 'emnist-{}-train-labels-idx1-ubyte'.format(split))) 147 | ) 148 | test_set = ( 149 | read_image_file(os.path.join(raw_folder, 'emnist-{}-test-images-idx3-ubyte'.format(split))), 150 | read_label_file(os.path.join(raw_folder, 'emnist-{}-test-labels-idx1-ubyte'.format(split))) 151 | ) 152 | with open(os.path.join(self.root, self.processed_folder, self._training_file(split)), 'wb') as f: 153 | torch.save(training_set, f) 154 | with open(os.path.join(self.root, self.processed_folder, self._test_file(split)), 'wb') as f: 155 | torch.save(test_set, f) 156 | 157 | print('Done!') -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-slate -------------------------------------------------------------------------------- /docs/images/1-7.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/1-7.jpeg -------------------------------------------------------------------------------- /docs/images/3-8.jpeg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/3-8.jpeg -------------------------------------------------------------------------------- /docs/images/4-9.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/4-9.jpeg -------------------------------------------------------------------------------- /docs/images/5-6.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/5-6.jpeg -------------------------------------------------------------------------------- /docs/images/best_generated.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/best_generated.jpeg -------------------------------------------------------------------------------- /docs/images/blindspot.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/blindspot.jpeg -------------------------------------------------------------------------------- /docs/images/blindspot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/blindspot.png -------------------------------------------------------------------------------- /docs/images/generated_plot.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/generated_plot.gif -------------------------------------------------------------------------------- /docs/images/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jrbtaylor/conditional-pixelcnn/2141e69ac98e567d201fce29189b1caf5ef17d62/docs/images/mask.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | --- 2 | title: Conditional PixelCNNs 3 | tagline: 4 | description: A PyTorch implementation of Conditional PixelCNNs to generate between-class examples 5 | --- 6 | 7 | ## Motivation 8 | This is the first of what I expect will be a few posts. 9 | I started a [machine learning blog on WordPress](http://netsprawl.wordpress.com) in 2017 but abandoned it 2 posts in after finding that showing code without messing up the formatting was not possible — the narrow column format would wrap the code and render it unreadable. 10 | 11 | For 2018, my new year's resolution is to write 5 posts (as Github project pages). I wanted to play with PixelCNNs and finally try [PyTorch](http://pytorch.org) (I use Tensorflow for my work at [Envision.AI](http://envision.ai) and previously used Theano at McGill) so this post will include my thoughts on both. 12 | In particular, I was curious if PixelCNNs conditioned on class labels could generate believable between-class examples. 13 | 14 | 15 | 16 | 17 | ## Conditional PixelCNNs 18 | 19 | PixelCNNs are the convolutional version of PixelRNNs, which treat the pixels in an image as a sequence and predict each pixel after seeing the preceding pixels (defined as above and to the left, though this is arbitrary). 20 | PixelRNNs are an autoregressive model of the joint prior distribution for images: 21 | 22 |

$$p(\mathbf{x}) = p(x_0) \prod_{i>0} p(x_i \mid x_{<i})$$

23 | 24 | PixelRNNs are slow to train since the recurrence can't be parallelized — even small images have hundreds or thousands of pixels, which is a relatively long sequence for RNNs. 25 | Replacing the recurrence with masked convolutions, such that the convolution filter only sees pixels above and to the left, allows for faster training (figure from [conditional PixelCNN paper](https://arxiv.org/abs/1606.05328)). 26 | 27 |
28 | 29 | 30 | However, it's worth noting that the [original PixelCNN implementation](https://arxiv.org/abs/1601.06759) produced worse results than the PixelRNN. 31 | One possible reason for the degraded results, conjectured in the follow-up paper ([Conditional Image Generation with PixelCNN Decoders](https://arxiv.org/abs/1606.05328)), is the relative simplicity of the ReLU activations in the PixelCNN compared to the gated connections in the LSTM. 32 | The Conditional PixelCNN paper subsequently replaced the ReLUs with gated activations: 33 |

$$y = \tanh(W_f \ast x) \odot \sigma(W_g \ast x)$$

34 | Another possible reason offered in the follow-up paper is that stacking masked convolutional filters results in blind spots, failing to capture all the pixels above the one being predicted (figure from [paper](https://arxiv.org/abs/1606.05328)): 35 | 36 |
37 | 38 | 39 | #### PixelCNNs vs GANs 40 | 41 | PixelCNNs and GANs are currently the two flavors of deep learning models for generating images. 42 | GANs are receiving a lot of attention recently, but in many ways I find their popularity unwarranted. 43 | 44 | It's unclear what objective GANs are actually trying to optimize as the minimum of the training objective (i.e. fooling the discriminator) would result in the generator recreating all the training images and/or generating adversarial examples that don't necessarily resemble natural images. 45 | This is reflected in the notorious difficulty of training GANs and the myriad hacks to regularize them. 46 | The idea of pitting two nets against each other to produce training signals is interesting and has produced many good papers (notably cycleGAN) 47 | but I remain unconvinced that they're useful for much beyond making flashy posts on social media. 48 | 49 | On the other hand, PixelCNNs have a nice probabilistic underpinning. 50 | This allows them to not only generate images by sampling the distribution (left-to-right, top-to-bottom, following their autoregressive definition), 51 | but also means they can be used for other tasks. For example: as a pre-screening network to detect out-of-domain or adversarial examples; for detecting outliers in a training set; or estimating uncertainty at test. 52 | I'll cover some of these extensions more in my next post. 53 | 54 | I'd be interested in hearing if anyone has tried combining PixelCNNs and GANs. Perhaps the PixelCNN can be used as a prior or as a final stage of the decoder (conditioned on some higher-level learned representation) to avoid some of the training difficulties with GANs. 55 | 56 | 57 | 58 | 59 | ## Implementation 60 | 61 | My implementation uses the gated blocks but for rapid implementation, I decided to forego the two-stream solution to the blind spot problem (separating the filters into horizontal and vertical components). 
62 | There's code available for solving the blind-spot problem in Tensorflow and it'd be fairly trivial to re-write it in PyTorch. 63 | This way the masking is simple: everything below and to the right of the current pixel is zeroed-out in the filter and in the first layer the current pixel is also set to zero in the filter. 64 | ```python 65 | class MaskedConv(nn.Conv2d): 66 | def __init__(self,mask_type,in_channels,out_channels,kernel_size,stride=1): 67 | """ 68 | mask_type: 'A' for first layer of network, 'B' for all others 69 | """ 70 | super(MaskedConv,self).__init__(in_channels,out_channels,kernel_size, 71 | stride,padding=kernel_size//2) 72 | assert mask_type in ('A','B') 73 | mask = torch.ones(1,1,kernel_size,kernel_size) 74 | mask[:,:,kernel_size//2,kernel_size//2+(mask_type=='B'):] = 0 75 | mask[:,:,kernel_size//2+1:] = 0 76 | self.register_buffer('mask',mask) 77 | 78 | def forward(self,x): 79 | self.weight.data *= self.mask 80 | return super(MaskedConv,self).forward(x) 81 | ``` 82 | The implementation for the gated ResNet blocks is slightly more complicated: 83 | the PixelCNN has shortcut connections between the two halves of the network, like a U-Net; 84 | PyTorch allows the forward method of a Module to take multiple inputs *only* if they're Variables; 85 | since the feature maps from the first half of the network are not Variables, they must be stacked with the other input (the features from the preceding layer). 86 | This is avoided with the conditioning vector, since it is a Variable (in this case, the class label).
87 | ```python 88 | class GatedRes(nn.Module): 89 | def __init__(self,in_channels,out_channels,n_classes,kernel_size=3,stride=1, 90 | aux_channels=0): 91 | super(GatedRes,self).__init__() 92 | self.conv = MaskedConv('B',in_channels,2*out_channels,kernel_size, 93 | stride) 94 | self.y_embed = nn.Linear(n_classes,2*out_channels) 95 | self.out_channels = out_channels 96 | if aux_channels!=2*out_channels and aux_channels!=0: 97 | self.aux_shortcut = nn.Sequential( 98 | nn.Conv2d(aux_channels,2*out_channels,1), 99 | nn.BatchNorm2d(2*out_channels,momentum=0.1)) 100 | if in_channels!=out_channels: 101 | self.shortcut = nn.Sequential( 102 | nn.Conv2d(in_channels,out_channels,1), 103 | nn.BatchNorm2d(out_channels,momentum=0.1)) 104 | self.batchnorm = nn.BatchNorm2d(out_channels,momentum=0.1) 105 | 106 | def forward(self,x,y): 107 | # check for aux input from first half of net stacked into x 108 | if x.dim()==5: 109 | x,aux = torch.split(x,1,dim=0) 110 | x = torch.squeeze(x,0) 111 | aux = torch.squeeze(x,0) 112 | else: 113 | aux = None 114 | x1 = self.conv(x) 115 | y = torch.unsqueeze(torch.unsqueeze(self.y_embed(y),-1),-1) 116 | if aux is not None: 117 | if hasattr(self,'aux_shortcut'): 118 | aux = self.aux_shortcut(aux) 119 | x1 = (x1+aux)/2 120 | # split for gate (note: pytorch dims are [n,c,h,w]) 121 | xf,xg = torch.split(x1,self.out_channels,dim=1) 122 | yf,yg = torch.split(y,self.out_channels,dim=1) 123 | f = torch.tanh(xf+yf) 124 | g = torch.sigmoid(xg+yg) 125 | if hasattr(self,'shortcut'): 126 | x = self.shortcut(x) 127 | return x+self.batchnorm(g*f) 128 | ``` 129 | I wasn't sure where to put batch normalization from reading the original papers, so I placed it where I thought it made sense: prior to adding the residual connection. 130 | 131 | With those two classes implemented, the full network was relatively easy. 132 | The PyTorch scheme of defining everything as subclasses of `nn.Module`, initializing all the layers/operations/etc. 
in the constructor and then connecting them together in the `forward` method can be messy. 133 | This is especially true if you have lots of shortcut connections and want to code your model with loops for arbitrary depth. 134 | 135 | *Note:* to be able to save/restore the model, you have to store layers in a `ModuleList` instead of a regular list. 136 | Appending and indexing this list is otherwise the same though. 137 | 138 | ```python 139 | class PixelCNN(nn.Module): 140 | def __init__(self,in_channels,n_classes,n_features,n_layers,n_bins, 141 | dropout=0.5): 142 | super(PixelCNN,self).__init__() 143 | 144 | self.layers = nn.ModuleList() 145 | self.n_layers = n_layers 146 | 147 | # Up pass 148 | self.input_batchnorm = nn.BatchNorm2d(in_channels,momentum=0.1) 149 | for l in range(n_layers): 150 | if l==0: # start with normal conv 151 | block = nn.Sequential( 152 | MaskedConv('A',in_channels+1,n_features,kernel_size=7), 153 | nn.BatchNorm2d(n_features,momentum=0.1), 154 | nn.ReLU()) 155 | else: 156 | block = GatedRes(n_features, n_features, n_classes) 157 | self.layers.append(block) 158 | 159 | # Down pass 160 | for _ in range(n_layers): 161 | block = GatedRes(n_features, n_features,n_classes, 162 | aux_channels=n_features) 163 | self.layers.append(block) 164 | 165 | # Last layer: project to n_bins (output is [-1, n_bins, h, w]) 166 | self.layers.append( 167 | nn.Sequential(nn.Dropout2d(dropout), 168 | nn.Conv2d(n_features,n_bins,1), 169 | nn.LogSoftmax(dim=1))) 170 | 171 | def forward(self,x,y): 172 | # Add channel of ones so network can tell where padding is 173 | x = nn.functional.pad(x,(0,0,0,0,0,1,0,0),mode='constant',value=1) 174 | 175 | # Up pass 176 | features = [] 177 | i = -1 178 | for _ in range(self.n_layers): 179 | i += 1 180 | if i>0: 181 | x = self.layers[i](x,y) 182 | else: 183 | x = self.layers[i](x) 184 | features.append(x) 185 | 186 | # Down pass 187 | for _ in range(self.n_layers): 188 | i += 1 189 | x = 
self.layers[i](torch.stack((x,features.pop())),y) 190 | 191 | # Last layer 192 | i += 1 193 | x = self.layers[i](x) 194 | assert i==len(self.layers)-1 195 | assert len(features)==0 196 | return x 197 | ``` 198 | MNIST is practically black and white, so I discretized the pixel intensities (the prediction targets) to only 4 grayscale levels for the purposes of calculating cross-entropy loss. 199 | On natural images, the number of output levels would obviously need to be higher. 200 | All layers in the network have 200 features. 201 | For data augmentation I used random rotations of +/-5 degrees with nearest neighbour sampling. 202 | For training, I used Adam with a learning rate of $10^{-4}$ and dropout rate of 0.9. 203 | 204 | The higher number of features (than is necessary for MNIST) and higher dropout are a trade-off between training time and regularization. 205 | This is a trick that is rarely mentioned in papers but is helpful to avoid overfitting — I've only seen it mentioned in a paper for training on action recognition in video, where overfitting is a problem due to the high dimensionality vs the dataset sizes currently available. 206 | 207 | I have a single GTX1070 GPU at home, so I didn't run any kind of hyperparameter optimization: 208 | the ability to guess reasonable hyperparameters and have your model work says a lot about the robustness of Adam + batch normalization + dropout. 209 | The learning rate definitely could've been higher but this makes for a more interesting GIF. 210 | 211 | 212 | 213 | 214 | ## Results 215 | 216 |
217 | 218 | The GIF above shows a batch of 50 images (5 examples per class) generated after each epoch of training, progressing from seemingly random scribbles to something resembling actual digits. 219 | Here are the results at the best epoch: 220 | 221 |
222 | 223 | The motivation for this work was to see if a Conditional PixelCNN could also generate reasonable examples between classes. 224 | This is done by conditioning on soft labels instead of one-hot encoded labels. 225 | 226 | Let's try what I'd expect are easily confused pairs of digits: (1,7), (3,8), (4,9), (5,6) 227 | 228 | 229 |
230 | 231 | 232 |
233 | 234 | 235 |
236 | 237 | 238 |
239 | 240 | The generated between-class examples do not appear as realistic as the normal examples. 241 | It's possible the model needs some additional training signal (e.g. teacher forcing from a classifier network) to interpolate along the image manifold like that. 242 | This is somewhat disappointing because I had hoped that generating between-class examples might allow a learned form of mixup to be used (rather than averaging images). 243 | Testing this idea further would require many more GPUs to generate batches of inputs, so it's out of my reach for now anyway. 244 | 245 | 246 | ## Miscellaneous Thoughts on PyTorch 247 | 1. Debugging without a compiled computational graph (as used in TensorFlow or Theano) is faster and more intuitive. 248 | 2. Writing models where you have to initialize every operation/layer/etc. in the constructor and then call them in the forward method seems unnecessarily complicated. This is especially error-prone for models with shortcut connections if you write them with loops for arbitrary depth. 249 | 3. Point 2 cancels out point 1 for me. I prefer TensorFlow.
250 | -------------------------------------------------------------------------------- /experiment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Written by Jason Taylor 2018-2019 3 | """ 4 | 5 | import json 6 | import os 7 | 8 | import torch 9 | 10 | import data 11 | import model 12 | import train 13 | from vis import generate_between_classes 14 | 15 | def run(dataset='mnist', batch_size=64, n_features=200, n_layers=6, n_bins=4, 16 | optimizer='adam', learnrate=1e-4, dropout=0.9, exp_name='pixelCNN', 17 | exp_dir='~/experiments/conditional-pixelcnn/', cuda=True, 18 | resume=False): 19 | 20 | exp_name += '_%s_%ifeat_%ilayers_%ibins'%( 21 | dataset, n_features, n_layers, n_bins) 22 | exp_dir = os.path.join(os.path.expanduser(exp_dir), exp_name) 23 | if not os.path.isdir(exp_dir): 24 | os.makedirs(exp_dir) 25 | 26 | # Data loaders 27 | train_loader, val_loader, onehot_fcn, n_classes = data.loader(dataset, 28 | batch_size) 29 | 30 | if not resume: 31 | # Store experiment params in params.json 32 | params = {'batch_size':batch_size, 'n_features':n_features, 33 | 'n_layers':n_layers, 'n_bins':n_bins, 'optimizer': optimizer, 34 | 'learnrate':learnrate, 'dropout':dropout, 'cuda':cuda} 35 | with open(os.path.join(exp_dir,'params.json'),'w') as f: 36 | json.dump(params,f) 37 | 38 | # Model 39 | net = model.PixelCNN(1, n_classes, n_features, n_layers, n_bins, 40 | dropout) 41 | else: 42 | # if resuming, need to have params, stats and checkpoint files 43 | if not (os.path.isfile(os.path.join(exp_dir,'params.json')) 44 | and os.path.isfile(os.path.join(exp_dir,'stats.json')) 45 | and os.path.isfile(os.path.join(exp_dir,'last_checkpoint'))): 46 | raise Exception('Missing param, stats or checkpoint file on resume') 47 | net = torch.load(os.path.join(exp_dir, 'last_checkpoint')) 48 | 49 | # Define loss fcn, incl. 
label formatting from input 50 | def input2label(x): 51 | return torch.squeeze(torch.round((n_bins-1)*x).type(torch.LongTensor),1) 52 | loss_fcn = torch.nn.NLLLoss2d() 53 | 54 | # Train 55 | train.fit(train_loader, val_loader, net, exp_dir, input2label, loss_fcn, 56 | onehot_fcn, n_classes, optimizer, learnrate=learnrate, cuda=cuda, 57 | resume=resume) 58 | 59 | # Generate some between-class examples 60 | generate_between_classes(net, [28, 28], [1, 7], 61 | os.path.join(exp_dir,'1-7.jpeg'), n_classes, cuda) 62 | generate_between_classes(net, [28, 28], [3, 8], 63 | os.path.join(exp_dir,'3-8.jpeg'), n_classes, cuda) 64 | generate_between_classes(net, [28, 28], [4, 9], 65 | os.path.join(exp_dir,'4-9.jpeg'), n_classes, cuda) 66 | generate_between_classes(net, [28, 28], [5, 6], 67 | os.path.join(exp_dir,'5-6.jpeg'), n_classes, cuda) 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Written by Jason Taylor 2018-2019 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class MaskedConv(nn.Conv2d): 10 | def __init__(self,mask_type,in_channels,out_channels,kernel_size,stride=1): 11 | """ 12 | mask_type: 'A' for first layer of network, 'B' for all others 13 | """ 14 | super(MaskedConv,self).__init__(in_channels,out_channels,kernel_size, 15 | stride,padding=kernel_size//2) 16 | assert mask_type in ('A','B') 17 | mask = torch.ones(1,1,kernel_size,kernel_size) 18 | mask[:,:,kernel_size//2,kernel_size//2+(mask_type=='B'):] = 0 19 | mask[:,:,kernel_size//2+1:] = 0 20 | self.register_buffer('mask',mask) 21 | 22 | def forward(self,x): 23 | self.weight.data *= self.mask 24 | return super(MaskedConv,self).forward(x) 25 | 26 | 27 | class GatedRes(nn.Module): 28 | def __init__(self,in_channels,out_channels,n_classes,kernel_size=3,stride=1, 29 | aux_channels=0): 30 | super(GatedRes,self).__init__() 31 | self.conv = 
MaskedConv('B',in_channels,2*out_channels,kernel_size, 32 | stride) 33 | self.y_embed = nn.Linear(n_classes,2*out_channels) 34 | self.out_channels = out_channels 35 | if aux_channels!=2*out_channels and aux_channels!=0: 36 | self.aux_shortcut = nn.Sequential( 37 | nn.Conv2d(aux_channels,2*out_channels,1), 38 | nn.BatchNorm2d(2*out_channels,momentum=0.1)) 39 | if in_channels!=out_channels: 40 | self.shortcut = nn.Sequential( 41 | nn.Conv2d(in_channels,out_channels,1), 42 | nn.BatchNorm2d(out_channels,momentum=0.1)) 43 | self.batchnorm = nn.BatchNorm2d(out_channels,momentum=0.1) 44 | 45 | def forward(self,x,y): 46 | # check for aux input from first half of net stacked into x 47 | if x.dim()==5: 48 | x,aux = torch.split(x,1,dim=0) 49 | x = torch.squeeze(x,0) 50 | aux = torch.squeeze(x,0) 51 | else: 52 | aux = None 53 | x1 = self.conv(x) 54 | y = torch.unsqueeze(torch.unsqueeze(self.y_embed(y),-1),-1) 55 | if aux is not None: 56 | if hasattr(self,'aux_shortcut'): 57 | aux = self.aux_shortcut(aux) 58 | x1 = (x1+aux)/2 59 | # split for gate (note: pytorch dims are [n,c,h,w]) 60 | xf,xg = torch.split(x1,self.out_channels,dim=1) 61 | yf,yg = torch.split(y,self.out_channels,dim=1) 62 | f = torch.tanh(xf+yf) 63 | g = torch.sigmoid(xg+yg) 64 | if hasattr(self,'shortcut'): 65 | x = self.shortcut(x) 66 | return x+self.batchnorm(g*f) 67 | 68 | 69 | class PixelCNN(nn.Module): 70 | def __init__(self,in_channels,n_classes,n_features,n_layers,n_bins, 71 | dropout=0.5): 72 | super(PixelCNN,self).__init__() 73 | 74 | self.layers = nn.ModuleList() 75 | self.n_layers = n_layers 76 | 77 | # Up pass 78 | self.input_batchnorm = nn.BatchNorm2d(in_channels,momentum=0.1) 79 | for l in range(n_layers): 80 | if l==0: # start with normal conv 81 | block = nn.Sequential( 82 | MaskedConv('A',in_channels+1,n_features,kernel_size=7), 83 | nn.BatchNorm2d(n_features,momentum=0.1), 84 | nn.ReLU()) 85 | else: 86 | block = GatedRes(n_features, n_features, n_classes) 87 | self.layers.append(block) 88 | 89 
| # Down pass 90 | for _ in range(n_layers): 91 | block = GatedRes(n_features, n_features,n_classes, 92 | aux_channels=n_features) 93 | self.layers.append(block) 94 | 95 | # Last layer: project to n_bins (output is [-1, n_bins, h, w]) 96 | self.dropout = nn.Dropout2d(dropout) 97 | self.layers.append(GatedRes(n_features,n_bins,n_classes)) 98 | self.layers.append(nn.LogSoftmax(dim=1)) 99 | 100 | def forward(self,x,y): 101 | # Add channel of ones so network can tell where padding is 102 | x = nn.functional.pad(x,(0,0,0,0,0,1,0,0),mode='constant',value=1) 103 | 104 | # Up pass 105 | features = [] 106 | i = -1 107 | for _ in range(self.n_layers): 108 | i += 1 109 | if i>0: 110 | x = self.layers[i](x,y) 111 | else: 112 | x = self.layers[i](x) 113 | features.append(x) 114 | 115 | # Down pass 116 | for _ in range(self.n_layers): 117 | i += 1 118 | x = self.layers[i](torch.stack((x,features.pop())),y) 119 | 120 | # Last layer 121 | x = self.dropout(x) 122 | i += 1 123 | x = self.layers[i](x,y) 124 | i += 1 125 | x = self.layers[i](x) 126 | assert i==len(self.layers)-1 127 | assert len(features)==0 128 | return x 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Written by Jason Taylor 2018-2019 3 | """ 4 | 5 | import imageio 6 | import json 7 | import os 8 | from PIL import Image 9 | import time 10 | 11 | import matplotlib; matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | from progressbar import ProgressBar 15 | from skimage.transform import resize 16 | import torch 17 | from torch.autograd import Variable 18 | 19 | from vis import plot_stats, clearline, generate, tile_images 20 | 21 | 22 | def generate_images(model,img_size,n_classes,onehot_fcn,cuda=True): 23 | y = np.array(list(range(min(n_classes,10)))*5) # gpu mem limit 24 | y = 
np.concatenate([onehot_fcn(x)[np.newaxis,:] for x in y]) 25 | return generate(model, img_size, y, cuda) 26 | 27 | 28 | def plot_loss(train_loss, val_loss): 29 | fig = plt.figure(num=1, figsize=(4, 4), dpi=70, facecolor='w', 30 | edgecolor='k') 31 | plt.plot(range(1,len(train_loss)+1), train_loss, 'r', label='training') 32 | plt.plot(range(1,len(val_loss)+1), val_loss, 'b', label='validation') 33 | plt.title('After %i epochs'%len(train_loss)) 34 | plt.xlabel('Epoch') 35 | plt.ylabel('Cross-entropy loss') 36 | plt.rcParams.update({'font.size':10}) 37 | fig.tight_layout(pad=1) 38 | fig.canvas.draw() 39 | 40 | # now convert the plot to a numpy array 41 | plot = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 42 | plot = plot.reshape(fig.canvas.get_width_height()[::-1]+(3,)) 43 | plt.close(fig) 44 | return plot 45 | 46 | 47 | def fit(train_loader, val_loader, model, exp_path, label_preprocess, loss_fcn, 48 | onehot_fcn, n_classes=10, optimizer='adam', learnrate=1e-4, cuda=True, 49 | patience=10, max_epochs=200, resume=False): 50 | 51 | if cuda: 52 | model = model.cuda() 53 | 54 | if not os.path.isdir(exp_path): 55 | os.makedirs(exp_path) 56 | statsfile = os.path.join(exp_path,'stats.json') 57 | 58 | optimizer = {'adam':torch.optim.Adam(model.parameters(),lr=learnrate), 59 | 'sgd':torch.optim.SGD( 60 | model.parameters(),lr=learnrate,momentum=0.9), 61 | 'adamax':torch.optim.Adamax(model.parameters(),lr=learnrate) 62 | }[optimizer.lower()] 63 | 64 | # load a single example from the iterator to get the image size 65 | x = train_loader.sampler.data_source.__getitem__(0)[0] 66 | img_size = list(x.numpy().shape[1:]) 67 | 68 | if not resume: 69 | stats = {'loss':{'train':[],'val':[]}, 70 | 'mean_output':{'train':[],'val':[]}} 71 | best_val = np.inf 72 | stall = 0 73 | start_epoch = 0 74 | generated = [] 75 | plots = [] 76 | else: 77 | with open(statsfile,'r') as js: 78 | stats = json.load(js) 79 | best_val = np.min(stats['loss']['val']) 80 | stall = 
len(stats['loss']['val'])-np.argmin(stats['loss']['val'])-1 81 | start_epoch = len(stats['loss']['val'])-1 82 | generated = list(np.load(os.path.join(exp_path,'generated.npy'))) 83 | plots = list(np.load(os.path.join(exp_path,'generated_plots.npy'))) 84 | print('Resuming from epoch %i'%start_epoch) 85 | 86 | def save_img(x,filename): 87 | Image.fromarray((255*x).astype('uint8')).save(filename) 88 | 89 | def epoch(dataloader,training): 90 | bar = ProgressBar() 91 | losses = [] 92 | mean_outs = [] 93 | for x,y in bar(dataloader): 94 | label = label_preprocess(x) 95 | if cuda: 96 | x,y = x.cuda(),y.cuda() 97 | label = label.cuda() 98 | x,y = Variable(x),Variable(y) 99 | label = Variable(label) 100 | if training: 101 | optimizer.zero_grad() 102 | model.train() 103 | else: 104 | model.eval() 105 | output = model(x,y) 106 | loss = loss_fcn(output,label) 107 | # track mean output 108 | output = output.data.cpu().numpy() 109 | mean_outs.append(np.mean(np.argmax(output,axis=1))/output.shape[1]) 110 | if training: 111 | loss.backward() 112 | optimizer.step() 113 | losses.append(loss.data.cpu().numpy()) 114 | clearline() 115 | return float(np.mean(losses)), np.mean(mean_outs) 116 | 117 | for e in range(start_epoch,max_epochs): 118 | # Training 119 | t0 = time.time() 120 | loss,mean_out = epoch(train_loader,training=True) 121 | time_per_example = (time.time()-t0)/len(train_loader.dataset) 122 | stats['loss']['train'].append(loss) 123 | stats['mean_output']['train'].append(mean_out) 124 | print(('Epoch %3i: Training loss = %6.4f mean output = %1.2f ' 125 | '%4.2f msec/example')%(e,loss,mean_out,time_per_example*1000)) 126 | 127 | # Validation 128 | t0 = time.time() 129 | loss,mean_out = epoch(val_loader,training=False) 130 | time_per_example = (time.time()-t0)/len(val_loader.dataset) 131 | stats['loss']['val'].append(loss) 132 | stats['mean_output']['val'].append(mean_out) 133 | print((' Validation loss = %6.4f mean output = %1.2f ' 134 | '%4.2f 
msec/example')%(loss,mean_out,time_per_example*1000)) 135 | 136 | # Generate images and update gif 137 | new_frame = tile_images(generate_images(model, img_size, n_classes, 138 | onehot_fcn, cuda)) 139 | generated.append(new_frame) 140 | 141 | # Update gif with loss plot 142 | plot_frame = plot_loss(stats['loss']['train'],stats['loss']['val']) 143 | if new_frame.ndim==2: 144 | new_frame = np.repeat(new_frame[:,:,np.newaxis],3,axis=2) 145 | nw = int(new_frame.shape[1]*plot_frame.shape[0]/new_frame.shape[0]) 146 | new_frame = resize(new_frame,[plot_frame.shape[0],nw], 147 | order=0, preserve_range=True, mode='constant') 148 | plots.append(np.concatenate((plot_frame.astype('uint8'), 149 | new_frame.astype('uint8')), 150 | axis=1)) 151 | 152 | # Save gif arrays so it can resume training if interrupted 153 | np.save(os.path.join(exp_path,'generated.npy'),generated) 154 | np.save(os.path.join(exp_path,'generated_plots.npy'),plots) 155 | 156 | # Save stats and update training curves 157 | with open(statsfile,'w') as sf: 158 | json.dump(stats,sf) 159 | plot_stats(stats,exp_path) 160 | 161 | # Early stopping 162 | torch.save(model,os.path.join(exp_path,'last_checkpoint')) 163 | if loss=patience: 178 | break 179 | 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Written by Jason Taylor 2018-2019 3 | """ 4 | 5 | import json 6 | import os 7 | 8 | import imageio 9 | import matplotlib 10 | # Disable Xwindows backend before importing matplotlib.pyplot 11 | matplotlib.use('Agg') 12 | import matplotlib.pyplot as plt 13 | import numpy as np 14 | from progressbar import ProgressBar 15 | import torch 16 | from torch.autograd import Variable 17 | 18 | 19 | def plot_stats(stats,savepath): 20 | """ 21 | Make all the plots in stats. 
Stats can be a dict or a path to json (str) 22 | """ 23 | if type(stats) is str: 24 | assert os.path.isfile(stats) 25 | with open(stats,'r') as sf: 26 | stats = json.load(sf) 27 | assert type(stats) is dict 28 | 29 | assert type(savepath) is str 30 | if not os.path.isdir(savepath): 31 | os.makedirs(savepath) 32 | 33 | def _plot(y,title): 34 | plt.Figure() 35 | if type(y) is list: 36 | plt.plot(range(1,len(y)+1),y) 37 | elif type(y) is dict: 38 | for key,z in y.items(): 39 | plt.plot(range(1,len(z)+1),z,label=key) 40 | plt.legend() 41 | else: 42 | raise ValueError 43 | plt.xlabel('Epoch') 44 | plt.ylabel(title) 45 | plt.title(title) 46 | plt.savefig(os.path.join(savepath,title.replace(' ','_')+'.png')) 47 | plt.close() 48 | 49 | # Loop over stats dict and plot. Dicts within stats get plotted together. 50 | for key,value in stats.items(): 51 | _plot(value,key) 52 | 53 | 54 | def clearline(): 55 | CURSOR_UP_ONE = '\x1b[1A' 56 | ERASE_LINE = '\x1b[2K' 57 | print(CURSOR_UP_ONE+ERASE_LINE+CURSOR_UP_ONE) 58 | 59 | 60 | def generate(model, img_size, y, temp=0.8, cuda=True): 61 | model.eval() 62 | gen = torch.from_numpy(np.zeros([y.shape[0], 1]+img_size, dtype='float32')) 63 | y = torch.from_numpy(y) 64 | if cuda: 65 | y, gen = y.cuda(), gen.cuda() 66 | y, gen = Variable(y), Variable(gen) 67 | bar = ProgressBar() 68 | print('Generating images...') 69 | for r in bar(range(img_size[0])): 70 | for c in range(img_size[1]): 71 | out = model(gen, y) 72 | p = torch.exp(out)[:, :, r, c] 73 | p = torch.pow(p, 1/temp) 74 | p = p/torch.sum(p, -1, keepdim=True) 75 | sample = p.multinomial(1) 76 | gen[:, :, r, c] = sample.float()/(out.shape[1]-1) 77 | clearline() 78 | clearline() 79 | return (255*gen.data.cpu().numpy()).astype('uint8') 80 | 81 | 82 | def generate_between_classes(model, img_size, classes, saveto, 83 | n_classes, cuda=True): 84 | y = np.zeros((1,n_classes), dtype='float32') 85 | y[:,classes] = 1/len(classes) 86 | y = np.repeat(y,10,axis=0) 87 | gen = 
tile_images(generate(model, img_size, y, cuda),r=1) 88 | imageio.imsave(saveto,gen.astype('uint8')) 89 | 90 | 91 | def tile_images(imgs,r=0): 92 | n = len(imgs) 93 | h = imgs[0].shape[1] 94 | w = imgs[0].shape[2] 95 | if r==0: 96 | r = int(np.floor(np.sqrt(n))) 97 | while n%r!=0: 98 | r -= 1 99 | c = int(n/r) 100 | imgs = np.squeeze(np.array(imgs),axis=1) 101 | imgs = np.transpose(imgs,(1,2,0)) 102 | imgs = np.reshape(imgs,[h,w,r,c]) 103 | imgs = np.transpose(imgs,(2,3,0,1)) 104 | imgs = np.concatenate(imgs,1) 105 | imgs = np.concatenate(imgs,1) 106 | return imgs --------------------------------------------------------------------------------