├── LICENSE.md ├── README.md ├── data ├── __init__.py ├── cifar10_data.py ├── imagenet_data.py └── pixelcnn_samples.png ├── pixel_cnn_pp ├── __init__.py ├── model.py └── nn.py ├── train.py └── utils ├── __init__.py └── plotting.py /LICENSE.md: -------------------------------------------------------------------------------- 1 | This project is licensed under both the MIT License and the Apache License 2.0: 2 | 3 | ## MIT License 4 | 5 | Copyright (c) 2019 OpenAI (http://openai.com) 6 | 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | 14 | The above copyright notice and this permission notice shall be included in all 15 | copies or substantial portions of the Software. 16 | 17 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | SOFTWARE. 24 | 25 | ## Apache License 2.0 26 | 27 | Apache License 28 | Version 2.0, January 2004 29 | http://www.apache.org/licenses/ 30 | 31 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 32 | 33 | 1. Definitions. 34 | 35 | "License" shall mean the terms and conditions for use, reproduction, 36 | and distribution as defined by Sections 1 through 9 of this document. 37 | 38 | "Licensor" shall mean the copyright owner or entity authorized by 39 | the copyright owner that is granting the License. 40 | 41 | "Legal Entity" shall mean the union of the acting entity and all 42 | other entities that control, are controlled by, or are under common 43 | control with that entity. For the purposes of this definition, 44 | "control" means (i) the power, direct or indirect, to cause the 45 | direction or management of such entity, whether by contract or 46 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 47 | outstanding shares, or (iii) beneficial ownership of such entity. 48 | 49 | "You" (or "Your") shall mean an individual or Legal Entity 50 | exercising permissions granted by this License. 51 | 52 | "Source" form shall mean the preferred form for making modifications, 53 | including but not limited to software source code, documentation 54 | source, and configuration files. 55 | 56 | "Object" form shall mean any form resulting from mechanical 57 | transformation or translation of a Source form, including but 58 | not limited to compiled object code, generated documentation, 59 | and conversions to other media types. 60 | 61 | "Work" shall mean the work of authorship, whether in Source or 62 | Object form, made available under the License, as indicated by a 63 | copyright notice that is included in or attached to the work 64 | (an example is provided in the Appendix below). 
65 | 66 | "Derivative Works" shall mean any work, whether in Source or Object 67 | form, that is based on (or derived from) the Work and for which the 68 | editorial revisions, annotations, elaborations, or other modifications 69 | represent, as a whole, an original work of authorship. For the purposes 70 | of this License, Derivative Works shall not include works that remain 71 | separable from, or merely link (or bind by name) to the interfaces of, 72 | the Work and Derivative Works thereof. 73 | 74 | "Contribution" shall mean any work of authorship, including 75 | the original version of the Work and any modifications or additions 76 | to that Work or Derivative Works thereof, that is intentionally 77 | submitted to Licensor for inclusion in the Work by the copyright owner 78 | or by an individual or Legal Entity authorized to submit on behalf of 79 | the copyright owner. For the purposes of this definition, "submitted" 80 | means any form of electronic, verbal, or written communication sent 81 | to the Licensor or its representatives, including but not limited to 82 | communication on electronic mailing lists, source code control systems, 83 | and issue tracking systems that are managed by, or on behalf of, the 84 | Licensor for the purpose of discussing and improving the Work, but 85 | excluding communication that is conspicuously marked or otherwise 86 | designated in writing by the copyright owner as "Not a Contribution." 87 | 88 | "Contributor" shall mean Licensor and any individual or Legal Entity 89 | on behalf of whom a Contribution has been received by Licensor and 90 | subsequently incorporated within the Work. 91 | 92 | 2. Grant of Copyright License. Subject to the terms and conditions of 93 | this License, each Contributor hereby grants to You a perpetual, 94 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 95 | copyright license to reproduce, prepare Derivative Works of, 96 | publicly display, publicly perform, sublicense, and distribute the 97 | Work and such Derivative Works in Source or Object form. 98 | 99 | 3. Grant of Patent License. Subject to the terms and conditions of 100 | this License, each Contributor hereby grants to You a perpetual, 101 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 102 | (except as stated in this section) patent license to make, have made, 103 | use, offer to sell, sell, import, and otherwise transfer the Work, 104 | where such license applies only to those patent claims licensable 105 | by such Contributor that are necessarily infringed by their 106 | Contribution(s) alone or by combination of their Contribution(s) 107 | with the Work to which such Contribution(s) was submitted. If You 108 | institute patent litigation against any entity (including a 109 | cross-claim or counterclaim in a lawsuit) alleging that the Work 110 | or a Contribution incorporated within the Work constitutes direct 111 | or contributory patent infringement, then any patent licenses 112 | granted to You under this License for that Work shall terminate 113 | as of the date such litigation is filed. 114 | 115 | 4. Redistribution. 
You may reproduce and distribute copies of the 116 | Work or Derivative Works thereof in any medium, with or without 117 | modifications, and in Source or Object form, provided that You 118 | meet the following conditions: 119 | 120 | (a) You must give any other recipients of the Work or 121 | Derivative Works a copy of this License; and 122 | 123 | (b) You must cause any modified files to carry prominent notices 124 | stating that You changed the files; and 125 | 126 | (c) You must retain, in the Source form of any Derivative Works 127 | that You distribute, all copyright, patent, trademark, and 128 | attribution notices from the Source form of the Work, 129 | excluding those notices that do not pertain to any part of 130 | the Derivative Works; and 131 | 132 | (d) If the Work includes a "NOTICE" text file as part of its 133 | distribution, then any Derivative Works that You distribute must 134 | include a readable copy of the attribution notices contained 135 | within such NOTICE file, excluding those notices that do not 136 | pertain to any part of the Derivative Works, in at least one 137 | of the following places: within a NOTICE text file distributed 138 | as part of the Derivative Works; within the Source form or 139 | documentation, if provided along with the Derivative Works; or, 140 | within a display generated by the Derivative Works, if and 141 | wherever such third-party notices normally appear. The contents 142 | of the NOTICE file are for informational purposes only and 143 | do not modify the License. You may add Your own attribution 144 | notices within Derivative Works that You distribute, alongside 145 | or as an addendum to the NOTICE text from the Work, provided 146 | that such additional attribution notices cannot be construed 147 | as modifying the License. 148 | 149 | You may add Your own copyright statement to Your modifications and 150 | may provide additional or different license terms and conditions 151 | for use, reproduction, or distribution of Your modifications, or 152 | for any such Derivative Works as a whole, provided Your use, 153 | reproduction, and distribution of the Work otherwise complies with 154 | the conditions stated in this License. 155 | 156 | 5. Submission of Contributions. Unless You explicitly state otherwise, 157 | any Contribution intentionally submitted for inclusion in the Work 158 | by You to the Licensor shall be under the terms and conditions of 159 | this License, without any additional terms or conditions. 160 | Notwithstanding the above, nothing herein shall supersede or modify 161 | the terms of any separate license agreement you may have executed 162 | with Licensor regarding such Contributions. 163 | 164 | 6. Trademarks. This License does not grant permission to use the trade 165 | names, trademarks, service marks, or product names of the Licensor, 166 | except as required for reasonable and customary use in describing the 167 | origin of the Work and reproducing the content of the NOTICE file. 168 | 169 | 7. Disclaimer of Warranty. Unless required by applicable law or 170 | agreed to in writing, Licensor provides the Work (and each 171 | Contributor provides its Contributions) on an "AS IS" BASIS, 172 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 173 | implied, including, without limitation, any warranties or conditions 174 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 175 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 176 | appropriateness of using or redistributing the Work and assume any 177 | risks associated with Your exercise of permissions under this License. 178 | 179 | 8. Limitation of Liability. In no event and under no legal theory, 180 | whether in tort (including negligence), contract, or otherwise, 181 | unless required by applicable law (such as deliberate and grossly 182 | negligent acts) or agreed to in writing, shall any Contributor be 183 | liable to You for damages, including any direct, indirect, special, 184 | incidental, or consequential damages of any character arising as a 185 | result of this License or out of the use or inability to use the 186 | Work (including but not limited to damages for loss of goodwill, 187 | work stoppage, computer failure or malfunction, or any and all 188 | other commercial damages or losses), even if such Contributor 189 | has been advised of the possibility of such damages. 190 | 191 | 9. Accepting Warranty or Additional Liability. While redistributing 192 | the Work or Derivative Works thereof, You may choose to offer, 193 | and charge a fee for, acceptance of support, warranty, indemnity, 194 | or other liability obligations and/or rights consistent with this 195 | License. However, in accepting such obligations, You may act only 196 | on Your own behalf and on Your sole responsibility, not on behalf 197 | of any other Contributor, and only if You agree to indemnify, 198 | defend, and hold each Contributor harmless for any liability 199 | incurred by, or claims asserted against, such Contributor by reason 200 | of your accepting any such warranty or additional liability. 201 | 202 | END OF TERMS AND CONDITIONS 203 | 204 | APPENDIX: How to apply the Apache License to your work. 205 | 206 | To apply the Apache License to your work, attach the following 207 | boilerplate notice, with the fields enclosed by brackets "[]" 208 | replaced with your own identifying information. (Don't include 209 | the brackets!) The text should be enclosed in the appropriate 210 | comment syntax for the file format. We also recommend that a 211 | file or class name and description of purpose be included on the 212 | same "printed page" as the copyright notice for easier 213 | identification within third-party archives. 214 | 215 | Copyright 2019 OpenAI (http://openai.com) 216 | 217 | Licensed under the Apache License, Version 2.0 (the "License"); 218 | you may not use this file except in compliance with the License. 219 | You may obtain a copy of the License at 220 | 221 | http://www.apache.org/licenses/LICENSE-2.0 222 | 223 | Unless required by applicable law or agreed to in writing, software 224 | distributed under the License is distributed on an "AS IS" BASIS, 225 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 226 | See the License for the specific language governing permissions and 227 | limitations under the License. 
228 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | 4 | # pixel-cnn++ 5 | 6 | This is a Python3 / [Tensorflow](https://www.tensorflow.org/) implementation 7 | of [PixelCNN++](https://openreview.net/pdf?id=BJrFC6ceg), as described in the following 8 | paper: 9 | 10 | **PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications**, by 11 | Tim Salimans, Andrej Karpathy, Xi Chen, Diederik P. Kingma, and Yaroslav Bulatov. 12 | 13 | Our work builds on PixelCNNs as originally proposed by [van den Oord et al.](https://arxiv.org/abs/1606.05328) 14 | in June 2016. PixelCNNs are a class of powerful generative models with tractable 15 | likelihood that are also easy to sample from. The core convolutional neural network 16 | computes a probability distribution over the value of one pixel conditioned on the values 17 | of the pixels to the left of and above it. Below are example samples from a model 18 | trained on CIFAR-10 that achieves **2.92 bits per dimension** (compared to 3.03 for 19 | the PixelCNN in van den Oord et al.): 20 | 21 | Samples from the model (**left**) and samples from a model that is conditioned 22 | on the CIFAR-10 class labels (**right**): 23 | 24 | ![Improved PixelCNN samples](data/pixelcnn_samples.png) 25 | 26 | This code supports multi-GPU training of our improved PixelCNN on [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) 27 | and [Small ImageNet](http://image-net.org/small/download.php), and is easy to adapt 28 | to additional datasets. Training on a machine with 8 Maxwell TITAN X GPUs achieves 29 | 3.0 bits per dimension in about 10 hours, and it takes approximately 5 days to converge to 2.92. 30 | 31 | ## Setup 32 | 33 | To run this code you need the following: 34 | 35 | - a machine with multiple GPUs 36 | - Python3 37 | - the Numpy, TensorFlow and imageio packages: 38 | ``` 39 | pip install numpy tensorflow-gpu imageio 40 | ``` 41 | 42 | ## Training the model 43 | 44 | Use the `train.py` script to train the model. To train the default model on 45 | CIFAR-10 simply use: 46 | 47 | ``` 48 | python3 train.py 49 | ``` 50 | 51 | At a minimum you will want to change `--data_dir` and `--save_dir`, which 52 | point to the path the data is downloaded to (if not already available) and 53 | the path where checkpoints are saved. 54 | 55 | **I want to train on fewer GPUs**. To train on fewer GPUs we recommend using `CUDA_VISIBLE_DEVICES` 56 | to restrict which GPUs are visible and then running the script. Don't forget to adjust 57 | the `--nr_gpu` flag accordingly. 58 | 59 | **I want to train on my own dataset**. Have a look at the `DataLoader` classes 60 | in the `data/` folder. You have to write an analogous data iterator object for 61 | your own dataset and the code should work well from there; a minimal sketch follows below. 62 |
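As a rough editorial sketch (not part of the original code; the class body just mirrors `data/cifar10_data.py`, and the zero-filled array is a placeholder for loading your own images), a custom loader only needs to provide this interface:

```python
import numpy as np

class DataLoader(object):
    """ yields (batch_size, H, W, C) uint8 batches; the interface train.py relies on """

    def __init__(self, data_dir, subset, batch_size, rng=None, shuffle=False, return_labels=False):
        # placeholder: load your own images here as an (N,H,W,C) uint8 array
        self.data = np.zeros((1000, 32, 32, 3), dtype=np.uint8)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.p = 0  # pointer into the current epoch
        self.rng = np.random.RandomState(1) if rng is None else rng

    def get_observation_size(self):
        return self.data.shape[1:]  # e.g. (32, 32, 3)

    def reset(self):
        self.p = 0

    def __iter__(self):
        return self

    def __next__(self, n=None):
        if n is None: n = self.batch_size
        if self.p == 0 and self.shuffle:  # reshuffle once per epoch
            self.data = self.data[self.rng.permutation(self.data.shape[0])]
        if self.p + n > self.data.shape[0]:  # epoch finished
            self.reset()
            raise StopIteration
        x = self.data[self.p:self.p + n]
        self.p += n
        return x

    next = __next__  # Python 2 compatibility
```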
63 | ## Pretrained model checkpoint 64 | 65 | You can download our pretrained (TensorFlow) model that achieves 2.92 bpd on CIFAR-10 [here](http://alpha.openai.com/pxpp.zip) (656MB). 66 | 67 | ## Citation 68 | 69 | If you find this code useful please cite us in your work: 70 | 71 | ``` 72 | @inproceedings{Salimans2017PixelCNN, 73 | title={PixelCNN++: Improving the PixelCNN with Discretized Logistic Mixture Likelihood and Other Modifications}, 74 | author={Tim Salimans and Andrej Karpathy and Xi Chen and Diederik P. Kingma}, 75 | booktitle={ICLR}, 76 | year={2017} 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/pixel-cnn/bbc15688dd37934a12c2759cf2b34975e15901d9/data/__init__.py -------------------------------------------------------------------------------- /data/cifar10_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for downloading and unpacking the CIFAR-10 dataset, originally published 3 | by Krizhevsky et al. and hosted here: https://www.cs.toronto.edu/~kriz/cifar.html 4 | """ 5 | 6 | import os 7 | import sys 8 | import tarfile 9 | from six.moves import urllib 10 | import numpy as np 11 | 12 | def maybe_download_and_extract(data_dir, url='http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'): 13 | if not os.path.exists(os.path.join(data_dir, 'cifar-10-batches-py')): 14 | if not os.path.exists(data_dir): 15 | os.makedirs(data_dir) 16 | filename = url.split('/')[-1] 17 | filepath = os.path.join(data_dir, filename) 18 | if not os.path.exists(filepath): 19 | def _progress(count, block_size, total_size): 20 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 21 | float(count * block_size) / float(total_size) * 100.0)) 22 | sys.stdout.flush() 23 | filepath, _ = urllib.request.urlretrieve(url, filepath, _progress) 24 | print() 25 | statinfo = os.stat(filepath) 26 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 27 | tarfile.open(filepath, 'r:gz').extractall(data_dir) 28 | 29 | def unpickle(file): 30 | fo = open(file, 'rb') 31 | if (sys.version_info >= (3, 0)): 32 | import pickle 33 | d = pickle.load(fo, encoding='latin1') 34 | else: 35 | import cPickle 36 | d = cPickle.load(fo) 37 | fo.close() 38 | return {'x': d['data'].reshape((10000,3,32,32)), 'y': np.array(d['labels']).astype(np.uint8)} 39 | 40 | def load(data_dir, subset='train'): 41 | maybe_download_and_extract(data_dir) 42 | if subset=='train': 43 | train_data = [unpickle(os.path.join(data_dir,'cifar-10-batches-py','data_batch_' + str(i))) for i in range(1,6)] 44 | trainx = np.concatenate([d['x'] for d in train_data],axis=0) 45 | trainy = np.concatenate([d['y'] for d in train_data],axis=0) 46 | return trainx, trainy 47 | elif subset=='test': 48 | test_data = unpickle(os.path.join(data_dir,'cifar-10-batches-py','test_batch')) 49 | testx = test_data['x'] 50 | testy = test_data['y'] 51 | return testx, testy 52 | else: 53 | raise NotImplementedError('subset should be either train or test') 54 | 55 | class DataLoader(object): 56 | """ an object that generates batches of CIFAR-10 data for training """ 57 | 58 | def __init__(self, data_dir, subset, batch_size, rng=None, shuffle=False, return_labels=False): 59 | """ 60 | - data_dir is the location where the files are stored 61 | - subset is train|test 62 | - batch_size is int, of #examples to load at once 63 | - rng is np.random.RandomState object for reproducibility 64 | """ 65 | 66 | self.data_dir = data_dir 67 | self.batch_size = batch_size 68 | self.shuffle = shuffle 69 | self.return_labels = return_labels 70 | 71 | # create temporary storage for the data, if not yet created 72 | if not os.path.exists(data_dir): 73 | print('creating folder', data_dir) 74 | os.makedirs(data_dir) 75 | 76 | # load CIFAR-10 training data to RAM 77 | self.data, self.labels = load(os.path.join(data_dir,'cifar-10-python'), subset=subset) 78 | self.data = np.transpose(self.data, (0,2,3,1)) # (N,3,32,32) -> (N,32,32,3) 79 | 80 | self.p = 0 # pointer to where we are in iteration 81 | self.rng = np.random.RandomState(1) if rng is None else rng 82 | 83 | def get_observation_size(self): 84 | return self.data.shape[1:] 85 | 86 | def get_num_labels(self): 87 | return np.amax(self.labels) + 1 88 | 89 | def reset(self): 90 | self.p = 0 91 | 92 | def __iter__(self): 93 | return self 94 | 95 | def __next__(self, n=None): 96 | """ n is the number of examples to fetch """ 97 | if n is None: n = self.batch_size 98 | 99 | # on first iteration lazily permute all data 100 | if self.p == 0 and self.shuffle: 101 | inds = self.rng.permutation(self.data.shape[0]) 102 | self.data = self.data[inds] 103 | self.labels = self.labels[inds] 104 | 105 | # on last iteration reset the counter and raise StopIteration 106 | if self.p + n > self.data.shape[0]: 107 | self.reset() # reset for next time we get called 108 | raise StopIteration 109 | 110 | # on intermediate iterations fetch the next batch 111 | x = self.data[self.p : self.p + n] 112 | y = self.labels[self.p : self.p + n] 113 | self.p += n # advance by however many examples were actually fetched 114 | 115 | if self.return_labels: 116 | return x,y 117 | else: 118 | return x 119 | 120 | next = __next__ # Python 2 compatibility (https://stackoverflow.com/questions/29578469/how-to-make-an-object-both-a-python2-and-python3-iterator) 121 | 122 | 123 |
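A hypothetical quick check of this loader (the path is an example; the first call downloads CIFAR-10):

```python
from data.cifar10_data import DataLoader

loader = DataLoader('/tmp/pxpp/data', 'train', batch_size=16, shuffle=True, return_labels=True)
x, y = next(loader)  # x: (16, 32, 32, 3) uint8 images, y: (16,) uint8 labels
print(loader.get_observation_size(), loader.get_num_labels())  # (32, 32, 3) 10
```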
-------------------------------------------------------------------------------- /data/imagenet_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for downloading and loading the small ImageNet dataset used in van den Oord et al. 3 | The PNG archives are downloaded, unpacked and converted to an npz file automatically 4 | on first use (see maybe_download_and_extract and maybe_preprocess below), 5 | so no manual preprocessing is needed. 6 | 7 | """ 8 | 9 | import os 10 | import sys 11 | import tarfile 12 | from six.moves import urllib 13 | 14 | import numpy as np 15 | from imageio import imread 16 | 17 | def fetch(url, filepath): 18 | filename = url.split('/')[-1] 19 | def _progress(count, block_size, total_size): 20 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, 21 | float(count * block_size) / float(total_size) * 100.0)) 22 | sys.stdout.flush() 23 | print(url) 24 | filepath, headers = urllib.request.urlretrieve(url, filepath, _progress) 25 | print() 26 | statinfo = os.stat(filepath) 27 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 28 | 29 | def maybe_download_and_extract(data_dir): 30 | # more info on the dataset at http://image-net.org/small/download.php 31 | # downloads and extracts the two tar files for train/val 32 | 33 | train_dir = os.path.join(data_dir, 'train_32x32') 34 | if not os.path.exists(train_dir): 35 | train_url = 'http://image-net.org/small/train_32x32.tar' # 4GB 36 | filepath = os.path.join(data_dir, 'train_32x32.tar') 37 | fetch(train_url, filepath) 38 | print('unpacking the tar file', filepath) 39 | tarfile.open(filepath, 'r').extractall(data_dir) # creates the train_32x32 folder 40 | 41 | test_dir = os.path.join(data_dir, 'valid_32x32') 42 | if not os.path.exists(test_dir): 43 | test_url = 'http://image-net.org/small/valid_32x32.tar' # 154MB 44 | filepath = os.path.join(data_dir, 'valid_32x32.tar') 45 | fetch(test_url, filepath) 46 | print('unpacking the tar file', filepath) 47 | tarfile.open(filepath, 'r').extractall(data_dir) # creates the valid_32x32 folder 48 | 49 | def maybe_preprocess(data_dir): 50 | 51 | npz_file = os.path.join(data_dir, 'imgnet_32x32.npz') 52 | if os.path.exists(npz_file): 53 | return # all good 54 | 55 | trainx = [] 56 | train_dir = os.path.join(data_dir, 'train_32x32') 57 | for f in os.listdir(train_dir): 58 | if f.endswith('.png'): 59 | print('reading', f) 60 | filepath = os.path.join(train_dir, f) 61 | trainx.append(imread(filepath).reshape((1,32,32,3))) 62 | trainx = np.concatenate(trainx, axis=0) 63 | 64 | testx = [] 65 | test_dir = os.path.join(data_dir, 'valid_32x32') 66 | for f in os.listdir(test_dir): 67 | if f.endswith('.png'): 68 | print('reading', f) 69 | filepath = os.path.join(test_dir, f) 70 | testx.append(imread(filepath).reshape((1,32,32,3))) 71 | testx = np.concatenate(testx, axis=0) 72 | 73 | np.savez(npz_file, trainx=trainx, testx=testx) 74 | 75 | 76 | def load(data_dir, subset='train'): 77 | if not os.path.exists(data_dir): 78 | print('creating folder', data_dir) 79 | os.makedirs(data_dir) 80 | maybe_download_and_extract(data_dir) 81 | maybe_preprocess(data_dir) 82 | imagenet_data = np.load(os.path.join(data_dir,'imgnet_32x32.npz')) 83 | return imagenet_data['trainx'] if subset == 'train' else imagenet_data['testx'] 84 | 85 | 86 | 87 | class DataLoader(object): 88 | """ an object that generates batches of small ImageNet data for training """ 89 | 90 | def __init__(self, data_dir, subset, batch_size, rng=None, shuffle=False, **kwargs): 91 | """ 92 | - data_dir is location where the files are stored 93 | - subset is train|test 94 | - batch_size is int, of #examples to load at once 95 | - rng is np.random.RandomState object for reproducibility 96 | """ 97 | 98 | self.data_dir = data_dir 99 | self.batch_size = batch_size 100 | self.shuffle = shuffle 101 | 102 | self.data = load(os.path.join(data_dir,'small_imagenet'), subset=subset) 103 | 104 | self.p = 0 # pointer to where we are in iteration 105 | self.rng = np.random.RandomState(1) if rng is None else rng 106 | 107 | def get_observation_size(self): 108 | return self.data.shape[1:] 109 | 110 | def reset(self): 111 | self.p = 0 112 | 113 | def __iter__(self): 114 | return self 115 | 116 | def __next__(self, n=None): 117 | """ n is the number of examples to fetch """ 118 | if n is None: n = self.batch_size 119 | 120 | # on first iteration lazily permute all data 121 | if self.p == 0 and self.shuffle: 122 | inds = self.rng.permutation(self.data.shape[0]) 123 | self.data = self.data[inds] 124 | 125 | # on last iteration reset the counter and raise StopIteration 126 | if self.p + n > self.data.shape[0]: 127 | self.reset() # reset for next time we get called 128 | raise StopIteration 129 | 130 | # on intermediate iterations fetch the next batch 131 | x = self.data[self.p : self.p + n] 132 | self.p += n # advance by however many examples were actually fetched 133 | 134 | return x 135 | 136 | next = __next__ # Python 2 compatibility (https://stackoverflow.com/questions/29578469/how-to-make-an-object-both-a-python2-and-python3-iterator) 137 | 138 |
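The arrays handed back by load() are plain uint8 images; a hypothetical spot check (the path is an example, and the first call downloads ~4GB of PNGs and packs them into the npz):

```python
from data.imagenet_data import load

trainx = load('/tmp/pxpp/data/small_imagenet', subset='train')
print(trainx.dtype, trainx.shape)  # uint8 (N, 32, 32, 3)
```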
-------------------------------------------------------------------------------- /data/pixelcnn_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/pixel-cnn/bbc15688dd37934a12c2759cf2b34975e15901d9/data/pixelcnn_samples.png -------------------------------------------------------------------------------- /pixel_cnn_pp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/pixel-cnn/bbc15688dd37934a12c2759cf2b34975e15901d9/pixel_cnn_pp/__init__.py -------------------------------------------------------------------------------- /pixel_cnn_pp/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | The core Pixel-CNN model 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.contrib.framework.python.ops import arg_scope 8 | import pixel_cnn_pp.nn as nn 9 | 10 | def model_spec(x, h=None, init=False, ema=None, dropout_p=0.5, nr_resnet=5, nr_filters=160, nr_logistic_mix=10, resnet_nonlinearity='concat_elu', energy_distance=False): 11 | """ 12 | We receive a Tensor x of shape (N,H,W,D1) (e.g. (12,32,32,3)) and produce 13 | a Tensor x_out of shape (N,H,W,D2) (e.g. (12,32,32,100)), where each fiber 14 | of the x_out tensor describes the predictive distribution for the RGB at 15 | that position. 16 | 'h' is an optional N x K matrix of values to condition our generative model on 17 | """ 18 | 19 | counters = {} 20 | with arg_scope([nn.conv2d, nn.deconv2d, nn.gated_resnet, nn.dense], counters=counters, init=init, ema=ema, dropout_p=dropout_p): 21 | 22 | # parse resnet nonlinearity argument 23 | if resnet_nonlinearity == 'concat_elu': 24 | resnet_nonlinearity = nn.concat_elu 25 | elif resnet_nonlinearity == 'elu': 26 | resnet_nonlinearity = tf.nn.elu 27 | elif resnet_nonlinearity == 'relu': 28 | resnet_nonlinearity = tf.nn.relu 29 | else: 30 | raise ValueError('resnet nonlinearity ' + resnet_nonlinearity + ' is not supported') 31 | 32 | with arg_scope([nn.gated_resnet], nonlinearity=resnet_nonlinearity, h=h): 33 | 34 | # ////////// up pass through pixelCNN //////// 35 | xs = nn.int_shape(x) 36 | x_pad = tf.concat([x,tf.ones(xs[:-1]+[1])],3) # add channel of ones to distinguish image from padding later on 37 | u_list = [nn.down_shift(nn.down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2, 3]))] # stream for pixels above 38 | ul_list = [nn.down_shift(nn.down_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[1,3])) + \ 39 | nn.right_shift(nn.down_right_shifted_conv2d(x_pad, num_filters=nr_filters, filter_size=[2,1]))] # stream for up and to the left
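# (editor's note) Why these shifts make the model causal: down_shifted_conv2d pads the
# top of the image so that each output row depends only on the current row and rows above
# it; the extra down_shift then moves everything one row down, so row r depends strictly
# on rows < r. Similarly, the right_shift applied to the down_right stream keeps each
# pixel blind to itself and to pixels on its right, which is exactly the dependency
# structure the autoregressive factorization requires, without any masked convolutions.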
40 | 41 | for rep in range(nr_resnet): 42 | u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d)) 43 | ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d)) 44 | 45 | u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2])) 46 | ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2])) 47 | 48 | for rep in range(nr_resnet): 49 | u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d)) 50 | ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d)) 51 | 52 | u_list.append(nn.down_shifted_conv2d(u_list[-1], num_filters=nr_filters, stride=[2, 2])) 53 | ul_list.append(nn.down_right_shifted_conv2d(ul_list[-1], num_filters=nr_filters, stride=[2, 2])) 54 | 55 | for rep in range(nr_resnet): 56 | u_list.append(nn.gated_resnet(u_list[-1], conv=nn.down_shifted_conv2d)) 57 | ul_list.append(nn.gated_resnet(ul_list[-1], u_list[-1], conv=nn.down_right_shifted_conv2d)) 58 | 59 | # remember nodes 60 | for t in u_list+ul_list: 61 | tf.add_to_collection('checkpoints', t) 62 |
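# (editor's note) The down pass below mirrors the up pass like a U-Net: u_list/ul_list
# now hold the activations of every up-pass layer, and each down-pass gated_resnet pops
# one off as its auxiliary input 'a', giving symmetric skip connections across scales.
# The asserts at the end of model_spec verify that every stored activation is consumed
# exactly once.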
63 | # /////// down pass //////// 64 | u = u_list.pop() 65 | ul = ul_list.pop() 66 | for rep in range(nr_resnet): 67 | u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d) 68 | ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d) 69 | tf.add_to_collection('checkpoints', u) 70 | tf.add_to_collection('checkpoints', ul) 71 | 72 | u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2]) 73 | ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2]) 74 | 75 | for rep in range(nr_resnet+1): 76 | u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d) 77 | ul = nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d) 78 | tf.add_to_collection('checkpoints', u) 79 | tf.add_to_collection('checkpoints', ul) 80 | 81 | u = nn.down_shifted_deconv2d(u, num_filters=nr_filters, stride=[2, 2]) 82 | ul = nn.down_right_shifted_deconv2d(ul, num_filters=nr_filters, stride=[2, 2]) 83 | 84 | for rep in range(nr_resnet+1): 85 | u = nn.gated_resnet(u, u_list.pop(), conv=nn.down_shifted_conv2d) 86 | ul =
nn.gated_resnet(ul, tf.concat([u, ul_list.pop()],3), conv=nn.down_right_shifted_conv2d) 87 | tf.add_to_collection('checkpoints', u) 88 | tf.add_to_collection('checkpoints', ul) 89 | 90 | if energy_distance: 91 | f = nn.nin(tf.nn.elu(ul), 64) 92 | 93 | # generate 10 samples 94 | fs = [] 95 | for rep in range(10): 96 | fs.append(f) 97 | f = tf.concat(fs, 0) 98 | fs = nn.int_shape(f) 99 | f += nn.nin(tf.random_uniform(shape=fs[:-1] + [4], minval=-1., maxval=1.), 64) 100 | f = nn.nin(nn.concat_elu(f), 64) 101 | x_sample = tf.tanh(nn.nin(nn.concat_elu(f), 3, init_scale=0.1)) 102 | 103 | x_sample = tf.split(x_sample, 10, 0) 104 | 105 | assert len(u_list) == 0 106 | assert len(ul_list) == 0 107 | 108 | return x_sample 109 | 110 | else: 111 | x_out = nn.nin(tf.nn.elu(ul),10*nr_logistic_mix) 112 | 113 | assert len(u_list) == 0 114 | assert len(ul_list) == 0 115 | 116 | return x_out 117 | 118 | -------------------------------------------------------------------------------- /pixel_cnn_pp/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various tensorflow utilities 3 | """ 4 | 5 | import numpy as np 6 | import tensorflow as tf 7 | from tensorflow.contrib.framework.python.ops import add_arg_scope 8 | 9 | def int_shape(x): 10 | return list(map(int, x.get_shape())) 11 | 12 | def concat_elu(x): 13 | """ like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU """ 14 | axis = len(x.get_shape())-1 15 | return tf.nn.elu(tf.concat([x, -x], axis)) 16 | 17 | def log_sum_exp(x): 18 | """ numerically stable log_sum_exp implementation that prevents overflow """ 19 | axis = len(x.get_shape())-1 20 | m = tf.reduce_max(x, axis) 21 | m2 = tf.reduce_max(x, axis, keepdims=True) 22 | return m + tf.log(tf.reduce_sum(tf.exp(x-m2), axis)) 23 | 24 | def log_prob_from_logits(x): 25 | """ numerically stable log_softmax implementation that prevents overflow """ 26 | axis = len(x.get_shape())-1 27 | m = tf.reduce_max(x, axis, keepdims=True) 28 | return x - m - tf.log(tf.reduce_sum(tf.exp(x-m), axis, keepdims=True)) 29 | 30 | def energy_distance(x, x_sample): 31 | l1 = 0. 32 | for xs in x_sample: 33 | l1 += tf.reduce_sum(tf.pow(1e-10 + tf.reduce_sum(tf.square(127.5*(x-xs)),3), 0.75)) 34 | l1 /= len(x_sample) 35 | 36 | l2 = 0. 37 | n = 0 38 | for i in range(len(x_sample)): 39 | for j in range(i+1,len(x_sample)): 40 | l2 += tf.reduce_sum(tf.pow(1e-10 + tf.reduce_sum(tf.square(127.5*(x_sample[i] - x_sample[j])), 3), 0.75)) 41 | n += 1 42 | l2 /= n 43 | 44 | return 2.*l1 - l2 45 | 46 | def discretized_mix_logistic_loss(x,l,sum_all=True): 47 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """ 48 | xs = int_shape(x) # true image (i.e. labels) to regress to, e.g. (B,32,32,3) 49 | ls = int_shape(l) # predicted distribution, e.g. (B,32,32,100) 50 | nr_mix = int(ls[-1] / 10) # here and below: unpacking the params of the mixture of logistics 51 | logit_probs = l[:,:,:,:nr_mix] 52 | l = tf.reshape(l[:,:,:,nr_mix:], xs + [nr_mix*3]) 53 | means = l[:,:,:,:,:nr_mix] 54 | log_scales = tf.maximum(l[:,:,:,:,nr_mix:2*nr_mix], -7.) 
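# (editor's note) Layout of the 10*nr_mix output channels per pixel: nr_mix mixture
# logits, then for each of the 3 sub-pixels nr_mix means, nr_mix log-scales (clamped
# at -7 for numerical stability), and nr_mix tanh-squashed coefficients that linearly
# condition the green/blue means on the red/green sub-pixel values.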
55 | coeffs = tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix]) 56 | x = tf.reshape(x, xs + [1]) + tf.zeros(xs + [nr_mix]) # here and below: getting the means and adjusting them based on preceding sub-pixels 57 | m2 = tf.reshape(means[:,:,:,1,:] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0],xs[1],xs[2],1,nr_mix]) 58 | m3 = tf.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0],xs[1],xs[2],1,nr_mix]) 59 | means = tf.concat([tf.reshape(means[:,:,:,0,:], [xs[0],xs[1],xs[2],1,nr_mix]), m2, m3],3) 60 | centered_x = x - means 61 | inv_stdv = tf.exp(-log_scales) 62 | plus_in = inv_stdv * (centered_x + 1./255.) 63 | cdf_plus = tf.nn.sigmoid(plus_in) 64 | min_in = inv_stdv * (centered_x - 1./255.) 65 | cdf_min = tf.nn.sigmoid(min_in) 66 | log_cdf_plus = plus_in - tf.nn.softplus(plus_in) # log probability for edge case of 0 (before scaling) 67 | log_one_minus_cdf_min = -tf.nn.softplus(min_in) # log probability for edge case of 255 (before scaling) 68 | cdf_delta = cdf_plus - cdf_min # probability for all other cases 69 | mid_in = inv_stdv * centered_x 70 | log_pdf_mid = mid_in - log_scales - 2.*tf.nn.softplus(mid_in) # log probability in the center of the bin, to be used in extreme cases (not actually used in our code) 71 | 72 | # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us) 73 | 74 | # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select() 75 | # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta))) 76 | 77 | # robust version, that still works if probabilities are below 1e-5 (which never happens in our code) 78 | # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires us to use some ugly tricks to avoid potential NaNs 79 | # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue 80 | # if the probability on a sub-pixel is below 1e-5, we use an approximation based on the assumption that the log-density is constant in the bin of the observed sub-pixel value 81 | log_probs = tf.where(x < -0.999, log_cdf_plus, tf.where(x > 0.999, log_one_minus_cdf_min, tf.where(cdf_delta > 1e-5, tf.log(tf.maximum(cdf_delta, 1e-12)), log_pdf_mid - np.log(127.5)))) 82 | 83 | log_probs = tf.reduce_sum(log_probs,3) + log_prob_from_logits(logit_probs) 84 | if sum_all: 85 | return -tf.reduce_sum(log_sum_exp(log_probs)) 86 | else: 87 | return -tf.reduce_sum(log_sum_exp(log_probs),[1,2]) 88 | 89 | def sample_from_discretized_mix_logistic(l,nr_mix): 90 | ls = int_shape(l) 91 | xs = ls[:-1] + [3] 92 | # unpack parameters 93 | logit_probs = l[:, :, :, :nr_mix] 94 | l = tf.reshape(l[:, :, :, nr_mix:], xs + [nr_mix*3]) 95 | # sample mixture indicator from softmax 96 | sel = tf.one_hot(tf.argmax(logit_probs - tf.log(-tf.log(tf.random_uniform(logit_probs.get_shape(), minval=1e-5, maxval=1. - 1e-5))), 3), depth=nr_mix, dtype=tf.float32) 97 | sel = tf.reshape(sel, xs[:-1] + [1,nr_mix]) 98 | # select logistic parameters 99 | means = tf.reduce_sum(l[:,:,:,:,:nr_mix]*sel,4) 100 | log_scales = tf.maximum(tf.reduce_sum(l[:,:,:,:,nr_mix:2*nr_mix]*sel,4), -7.)
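# (editor's note) The argmax over logits plus -log(-log(uniform)) noise above is the
# Gumbel-max trick: it draws the mixture indicator from softmax(logit_probs) without an
# explicit multinomial sampling op. Below, the chosen component's parameters are used
# to sample via the logistic inverse CDF, means + exp(log_scales)*(log(u) - log(1-u)).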
101 | coeffs = tf.reduce_sum(tf.nn.tanh(l[:,:,:,:,2*nr_mix:3*nr_mix])*sel,4) 102 | # sample from logistic & clip to interval 103 | # we don't actually round to the nearest 8bit value when sampling 104 | u = tf.random_uniform(means.get_shape(), minval=1e-5, maxval=1. - 1e-5) 105 | x = means + tf.exp(log_scales)*(tf.log(u) - tf.log(1. - u)) 106 | x0 = tf.minimum(tf.maximum(x[:,:,:,0], -1.), 1.) 107 | x1 = tf.minimum(tf.maximum(x[:,:,:,1] + coeffs[:,:,:,0]*x0, -1.), 1.) 108 | x2 = tf.minimum(tf.maximum(x[:,:,:,2] + coeffs[:,:,:,1]*x0 + coeffs[:,:,:,2]*x1, -1.), 1.) 109 | return tf.concat([tf.reshape(x0,xs[:-1]+[1]), tf.reshape(x1,xs[:-1]+[1]), tf.reshape(x2,xs[:-1]+[1])],3) 110 | 111 | def get_var_maybe_avg(var_name, ema, **kwargs): 112 | ''' utility for retrieving polyak averaged params ''' 113 | v = tf.get_variable(var_name, **kwargs) 114 | if ema is not None: 115 | v = ema.average(v) 116 | return v 117 | 118 | def get_vars_maybe_avg(var_names, ema, **kwargs): 119 | ''' utility for retrieving polyak averaged params ''' 120 | vars = [] 121 | for vn in var_names: 122 | vars.append(get_var_maybe_avg(vn, ema, **kwargs)) 123 | return vars 124 | 125 | def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999): 126 | ''' Adam optimizer ''' 127 | updates = [] 128 | if type(cost_or_grads) is not list: 129 | grads = tf.gradients(cost_or_grads, params) 130 | else: 131 | grads = cost_or_grads 132 | t = tf.Variable(1., 'adam_t') 133 | for p, g in zip(params, grads): 134 | mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') 135 | if mom1>0: 136 | v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') 137 | v_t = mom1*v + (1. - mom1)*g 138 | v_hat = v_t / (1. - tf.pow(mom1,t)) 139 | updates.append(v.assign(v_t)) 140 | else: 141 | v_hat = g 142 | mg_t = mom2*mg + (1. - mom2)*tf.square(g) 143 | mg_hat = mg_t / (1. 
- tf.pow(mom2,t)) 144 | g_t = v_hat / tf.sqrt(mg_hat + 1e-8) 145 | p_t = p - lr * g_t 146 | updates.append(mg.assign(mg_t)) 147 | updates.append(p.assign(p_t)) 148 | updates.append(t.assign_add(1)) 149 | return tf.group(*updates) 150 | 151 | def get_name(layer_name, counters): 152 | ''' utility for keeping track of layer names ''' 153 | if not layer_name in counters: 154 | counters[layer_name] = 0 155 | name = layer_name + '_' + str(counters[layer_name]) 156 | counters[layer_name] += 1 157 | return name 158 | 159 | @add_arg_scope 160 | def dense(x_, num_units, nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 161 | ''' fully connected layer ''' 162 | name = get_name('dense', counters) 163 | with tf.variable_scope(name): 164 | V = get_var_maybe_avg('V', ema, shape=[int(x_.get_shape()[1]),num_units], dtype=tf.float32, 165 | initializer=tf.random_normal_initializer(0, 0.05), trainable=True) 166 | g = get_var_maybe_avg('g', ema, shape=[num_units], dtype=tf.float32, 167 | initializer=tf.constant_initializer(1.), trainable=True) 168 | b = get_var_maybe_avg('b', ema, shape=[num_units], dtype=tf.float32, 169 | initializer=tf.constant_initializer(0.), trainable=True) 170 | 171 | # use weight normalization (Salimans & Kingma, 2016) 172 | x = tf.matmul(x_, V) 173 | scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0])) 174 | x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units]) 175 | 176 | if init: # normalize x 177 | m_init, v_init = tf.nn.moments(x, [0]) 178 | scale_init = init_scale/tf.sqrt(v_init + 1e-10) 179 | with tf.control_dependencies([g.assign(g*scale_init), b.assign_add(-m_init*scale_init)]): 180 | # x = tf.identity(x) 181 | x = tf.matmul(x_, V) 182 | scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0])) 183 | x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units]) 184 | 185 | # apply nonlinearity 186 | if nonlinearity is not None: 187 | x = nonlinearity(x) 188 | 189 | return x 190 | 191 | @add_arg_scope 192 | def conv2d(x_, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 193 | ''' convolutional layer ''' 194 | name = get_name('conv2d', counters) 195 | with tf.variable_scope(name): 196 | V = get_var_maybe_avg('V', ema, shape=filter_size+[int(x_.get_shape()[-1]),num_filters], dtype=tf.float32, 197 | initializer=tf.random_normal_initializer(0, 0.05), trainable=True) 198 | g = get_var_maybe_avg('g', ema, shape=[num_filters], dtype=tf.float32, 199 | initializer=tf.constant_initializer(1.), trainable=True) 200 | b = get_var_maybe_avg('b', ema, shape=[num_filters], dtype=tf.float32, 201 | initializer=tf.constant_initializer(0.), trainable=True) 202 | 203 | # use weight normalization (Salimans & Kingma, 2016) 204 | W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2]) 205 | 206 | # calculate convolutional layer output 207 | x = tf.nn.bias_add(tf.nn.conv2d(x_, W, [1] + stride + [1], pad), b) 208 | 209 | if init: # normalize x 210 | m_init, v_init = tf.nn.moments(x, [0,1,2]) 211 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 212 | with tf.control_dependencies([g.assign(g * scale_init), b.assign_add(-m_init * scale_init)]): 213 | # x = tf.identity(x) 214 | W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2]) 215 | x = tf.nn.bias_add(tf.nn.conv2d(x_, W, [1] + stride + [1], pad), b) 216 | 217 | # apply nonlinearity 218 | if nonlinearity is not None: 219 | x = nonlinearity(x) 220 | 221 | return
x 222 | 223 | @add_arg_scope 224 | def deconv2d(x_, num_filters, filter_size=[3,3], stride=[1,1], pad='SAME', nonlinearity=None, init_scale=1., counters={}, init=False, ema=None, **kwargs): 225 | ''' transposed convolutional layer ''' 226 | name = get_name('deconv2d', counters) 227 | xs = int_shape(x_) 228 | if pad=='SAME': 229 | target_shape = [xs[0], xs[1]*stride[0], xs[2]*stride[1], num_filters] 230 | else: 231 | target_shape = [xs[0], xs[1]*stride[0] + filter_size[0]-1, xs[2]*stride[1] + filter_size[1]-1, num_filters] 232 | with tf.variable_scope(name): 233 | V = get_var_maybe_avg('V', ema, shape=filter_size+[num_filters,int(x_.get_shape()[-1])], dtype=tf.float32, 234 | initializer=tf.random_normal_initializer(0, 0.05), trainable=True) 235 | g = get_var_maybe_avg('g', ema, shape=[num_filters], dtype=tf.float32, 236 | initializer=tf.constant_initializer(1.), trainable=True) 237 | b = get_var_maybe_avg('b', ema, shape=[num_filters], dtype=tf.float32, 238 | initializer=tf.constant_initializer(0.), trainable=True) 239 | 240 | # use weight normalization (Salimans & Kingma, 2016) 241 | W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3]) 242 | 243 | # calculate convolutional layer output 244 | x = tf.nn.conv2d_transpose(x_, W, target_shape, [1] + stride + [1], padding=pad) 245 | x = tf.nn.bias_add(x, b) 246 | 247 | if init: # normalize x 248 | m_init, v_init = tf.nn.moments(x, [0,1,2]) 249 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 250 | with tf.control_dependencies([g.assign(g * scale_init), b.assign_add(-m_init * scale_init)]): 251 | # x = tf.identity(x) 252 | W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3]) 253 | x = tf.nn.conv2d_transpose(x_, W, target_shape, [1] + stride + [1], padding=pad) 254 | x = tf.nn.bias_add(x, b) 255 | 256 | # apply nonlinearity 257 | if nonlinearity is not None: 258 | x = nonlinearity(x) 259 | 260 | return x 261 | 262 | @add_arg_scope 263 | def nin(x, num_units, **kwargs): 264 | """ a network in network layer (1x1 CONV) """ 265 | s = int_shape(x) 266 | x = tf.reshape(x, [np.prod(s[:-1]),s[-1]]) 267 | x = dense(x, num_units, **kwargs) 268 | return tf.reshape(x, s[:-1]+[num_units]) 269 | 270 | ''' meta-layer consisting of multiple base layers ''' 271 | 272 | @add_arg_scope 273 | def gated_resnet(x, a=None, h=None, nonlinearity=concat_elu, conv=conv2d, init=False, counters={}, ema=None, dropout_p=0., **kwargs): 274 | xs = int_shape(x) 275 | num_filters = xs[-1] 276 | 277 | c1 = conv(nonlinearity(x), num_filters) 278 | if a is not None: # add short-cut connection if auxiliary input 'a' is given 279 | c1 += nin(nonlinearity(a), num_filters) 280 | c1 = nonlinearity(c1) 281 | if dropout_p > 0: 282 | c1 = tf.nn.dropout(c1, keep_prob=1. 
- dropout_p) 283 | c2 = conv(c1, num_filters * 2, init_scale=0.1) 284 | 285 | # add projection of h vector if included: conditional generation 286 | if h is not None: 287 | with tf.variable_scope(get_name('conditional_weights', counters)): 288 | hw = get_var_maybe_avg('hw', ema, shape=[int_shape(h)[-1], 2 * num_filters], dtype=tf.float32, 289 | initializer=tf.random_normal_initializer(0, 0.05), trainable=True) 290 | if init: 291 | hw = hw.initialized_value() 292 | c2 += tf.reshape(tf.matmul(h, hw), [xs[0], 1, 1, 2 * num_filters]) 293 | 294 | a, b = tf.split(c2, 2, 3) 295 | c3 = a * tf.nn.sigmoid(b) 296 | return x + c3 297 | 298 | ''' utilities for shifting the image around, an efficient alternative to masked convolutions ''' 299 | 300 | def down_shift(x): 301 | xs = int_shape(x) 302 | return tf.concat([tf.zeros([xs[0],1,xs[2],xs[3]]), x[:,:xs[1]-1,:,:]],1) 303 | 304 | def right_shift(x): 305 | xs = int_shape(x) 306 | return tf.concat([tf.zeros([xs[0],xs[1],1,xs[3]]), x[:,:,:xs[2]-1,:]],2) 307 | 308 | @add_arg_scope 309 | def down_shifted_conv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): 310 | x = tf.pad(x, [[0,0],[filter_size[0]-1,0], [int((filter_size[1]-1)/2),int((filter_size[1]-1)/2)],[0,0]]) 311 | return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 312 | 313 | @add_arg_scope 314 | def down_shifted_deconv2d(x, num_filters, filter_size=[2,3], stride=[1,1], **kwargs): 315 | x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 316 | xs = int_shape(x) 317 | return x[:,:(xs[1]-filter_size[0]+1),int((filter_size[1]-1)/2):(xs[2]-int((filter_size[1]-1)/2)),:] 318 | 319 | @add_arg_scope 320 | def down_right_shifted_conv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): 321 | x = tf.pad(x, [[0,0],[filter_size[0]-1, 0], [filter_size[1]-1, 0],[0,0]]) 322 | return conv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 323 | 324 | @add_arg_scope 325 | def down_right_shifted_deconv2d(x, num_filters, filter_size=[2,2], stride=[1,1], **kwargs): 326 | x = deconv2d(x, num_filters, filter_size=filter_size, pad='VALID', stride=stride, **kwargs) 327 | xs = int_shape(x) 328 | return x[:,:(xs[1]-filter_size[0]+1),:(xs[2]-filter_size[1]+1),:] 329 |
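A quick way to convince yourself of the causal padding arithmetic used above is the following editorial sketch in plain NumPy, mirroring down_shifted_conv2d's padding of filter_size[0]-1 rows at the top for a 2x3 all-ones filter:

```python
import numpy as np

# pad 1 row on top and 1 column on each side, then a VALID 2x3 "convolution"
# keeps output row r a function of input rows r-1 and r only
x = np.zeros((5, 5)); x[2, 2] = 1.0          # single impulse at row 2
xp = np.pad(x, ((1, 0), (1, 1)))             # (filter_h-1, 0) rows, centered cols
out = np.array([[xp[r:r+2, c:c+3].sum() for c in range(5)] for r in range(5)])
print(np.nonzero(out.sum(axis=1))[0])        # -> [2 3]: the impulse influences its
                                             # own row and the row below, never above
```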
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a Pixel-CNN++ generative model on CIFAR-10 or Small ImageNet data. 3 | Uses multiple GPUs, indicated by the flag --nr_gpu 4 | 5 | Example usage: 6 | CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py --nr_gpu 4 7 | """ 8 | 9 | import os 10 | import sys 11 | import json 12 | import argparse 13 | import time 14 | 15 | import numpy as np 16 | import tensorflow as tf 17 | 18 | from pixel_cnn_pp import nn 19 | from pixel_cnn_pp.model import model_spec 20 | from utils import plotting 21 | 22 | # ----------------------------------------------------------------------------- 23 | parser = argparse.ArgumentParser() 24 | # data I/O 25 | parser.add_argument('-i', '--data_dir', type=str, default='/local_home/tim/pxpp/data', help='Location for the dataset') 26 | parser.add_argument('-o', '--save_dir', type=str, default='/local_home/tim/pxpp/save', help='Location for parameter checkpoints and samples') 27 | parser.add_argument('-d', '--data_set', type=str, default='cifar', help='Dataset to use: one of cifar|imagenet') 28 | parser.add_argument('-t', '--save_interval', type=int, default=20, help='How many epochs between writing a checkpoint and samples?') 29 | parser.add_argument('-r', '--load_params', dest='load_params', action='store_true', help='Restore training from a previous model checkpoint?') 30 | # model 31 | parser.add_argument('-q', '--nr_resnet', type=int, default=5, help='Number of residual blocks per stage of the model') 32 | parser.add_argument('-n', '--nr_filters', type=int, default=160, help='Number of filters to use across the model. Higher = larger model.') 33 | parser.add_argument('-m', '--nr_logistic_mix', type=int, default=10, help='Number of logistic components in the mixture. Higher = more flexible model') 34 | parser.add_argument('-z', '--resnet_nonlinearity', type=str, default='concat_elu', help='Which nonlinearity to use in the ResNet layers. One of "concat_elu", "elu", "relu" ') 35 | parser.add_argument('-c', '--class_conditional', dest='class_conditional', action='store_true', help='Condition generative model on labels?') 36 | parser.add_argument('-ed', '--energy_distance', dest='energy_distance', action='store_true', help='use energy distance in place of likelihood') 37 | # optimization 38 | parser.add_argument('-l', '--learning_rate', type=float, default=0.001, help='Base learning rate') 39 | parser.add_argument('-e', '--lr_decay', type=float, default=0.999995, help='Learning rate decay, applied every step of the optimization') 40 | parser.add_argument('-b', '--batch_size', type=int, default=16, help='Batch size during training per GPU') 41 | parser.add_argument('-u', '--init_batch_size', type=int, default=16, help='How much data to use for data-dependent initialization.') 42 | parser.add_argument('-p', '--dropout_p', type=float, default=0.5, help='Dropout strength (i.e. 1 - keep_prob). 0 = no dropout, higher = more dropout.') 43 | parser.add_argument('-x', '--max_epochs', type=int, default=5000, help='How many epochs to run in total?') 44 | parser.add_argument('-g', '--nr_gpu', type=int, default=8, help='How many GPUs to distribute the training across?') 45 | # evaluation 46 | parser.add_argument('--polyak_decay', type=float, default=0.9995, help='Exponential decay rate of the sum of previous model iterates during Polyak averaging') 47 | parser.add_argument('-ns', '--num_samples', type=int, default=1, help='How many batches of samples to output.') 48 | # reproducibility 49 | parser.add_argument('-s', '--seed', type=int, default=1, help='Random seed to use') 50 | args = parser.parse_args() 51 | print('input args:\n', json.dumps(vars(args), indent=4, separators=(',',':'))) # pretty print args 52 | 53 | # ----------------------------------------------------------------------------- 54 | # fix random seed for reproducibility 55 | rng = np.random.RandomState(args.seed) 56 | tf.set_random_seed(args.seed) 57 | 58 | # energy distance or maximum likelihood? 59 | if args.energy_distance: 60 | loss_fun = nn.energy_distance 61 | else: 62 | loss_fun = nn.discretized_mix_logistic_loss
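# (editor's note) Both loss functions above take (x, model_output): for the likelihood
# loss the output is the (B,32,32,10*nr_logistic_mix) mixture-parameter tensor, while
# with --energy_distance the model instead returns a list of sampled images and
# nn.energy_distance compares those samples against the data.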
63 | 64 | # initialize data loaders for train/test splits 65 | if args.data_set == 'imagenet' and args.class_conditional: 66 | raise ValueError("We currently don't have labels for the small imagenet data set") 67 | if args.data_set == 'cifar': 68 | import data.cifar10_data as cifar10_data 69 | DataLoader = cifar10_data.DataLoader 70 | elif args.data_set == 'imagenet': 71 | import data.imagenet_data as imagenet_data 72 | DataLoader = imagenet_data.DataLoader 73 | else: 74 | raise ValueError("unsupported dataset") 75 | train_data = DataLoader(args.data_dir, 'train', args.batch_size * args.nr_gpu, rng=rng, shuffle=True, return_labels=args.class_conditional) 76 | test_data = DataLoader(args.data_dir, 'test', args.batch_size * args.nr_gpu, shuffle=False, return_labels=args.class_conditional) 77 | obs_shape = train_data.get_observation_size() # e.g.
a tuple (32,32,3) 78 | assert len(obs_shape) == 3, 'assumed right now' 79 | 80 | # data place holders 81 | x_init = tf.placeholder(tf.float32, shape=(args.init_batch_size,) + obs_shape) 82 | xs = [tf.placeholder(tf.float32, shape=(args.batch_size, ) + obs_shape) for i in range(args.nr_gpu)] 83 | 84 | # if the model is class-conditional we'll set up label placeholders + one-hot encodings 'h' to condition on 85 | if args.class_conditional: 86 | num_labels = train_data.get_num_labels() 87 | y_init = tf.placeholder(tf.int32, shape=(args.init_batch_size,)) 88 | h_init = tf.one_hot(y_init, num_labels) 89 | y_sample = np.split(np.mod(np.arange(args.batch_size*args.nr_gpu), num_labels), args.nr_gpu) 90 | h_sample = [tf.one_hot(tf.Variable(y_sample[i], trainable=False), num_labels) for i in range(args.nr_gpu)] 91 | ys = [tf.placeholder(tf.int32, shape=(args.batch_size,)) for i in range(args.nr_gpu)] 92 | hs = [tf.one_hot(ys[i], num_labels) for i in range(args.nr_gpu)] 93 | else: 94 | h_init = None 95 | h_sample = [None] * args.nr_gpu 96 | hs = h_sample 97 | 98 | # create the model 99 | model_opt = { 'nr_resnet': args.nr_resnet, 'nr_filters': args.nr_filters, 'nr_logistic_mix': args.nr_logistic_mix, 'resnet_nonlinearity': args.resnet_nonlinearity, 'energy_distance': args.energy_distance } 100 | model = tf.make_template('model', model_spec) 101 | 102 | # run once for data dependent initialization of parameters 103 | init_pass = model(x_init, h_init, init=True, dropout_p=args.dropout_p, **model_opt) 104 | 105 | # keep track of moving average 106 | all_params = tf.trainable_variables() 107 | ema = tf.train.ExponentialMovingAverage(decay=args.polyak_decay) 108 | maintain_averages_op = tf.group(ema.apply(all_params)) 109 | ema_params = [ema.average(p) for p in all_params] 110 | 111 | # get loss gradients over multiple GPUs + sampling 112 | grads = [] 113 | loss_gen = [] 114 | loss_gen_test = [] 115 | new_x_gen = [] 116 | for i in range(args.nr_gpu): 117 | with tf.device('/gpu:%d' % i): 118 | # train 119 | out = model(xs[i], hs[i], ema=None, dropout_p=args.dropout_p, **model_opt) 120 | loss_gen.append(loss_fun(tf.stop_gradient(xs[i]), out)) 121 | 122 | # gradients 123 | grads.append(tf.gradients(loss_gen[i], all_params, colocate_gradients_with_ops=True)) 124 | 125 | # test 126 | out = model(xs[i], hs[i], ema=ema, dropout_p=0., **model_opt) 127 | loss_gen_test.append(loss_fun(xs[i], out)) 128 | 129 | # sample 130 | out = model(xs[i], h_sample[i], ema=ema, dropout_p=0, **model_opt) 131 | if args.energy_distance: 132 | new_x_gen.append(out[0]) 133 | else: 134 | new_x_gen.append(nn.sample_from_discretized_mix_logistic(out, args.nr_logistic_mix)) 135 | 136 | # add losses and gradients together and get training updates 137 | tf_lr = tf.placeholder(tf.float32, shape=[]) 138 | with tf.device('/gpu:0'): 139 | for i in range(1,args.nr_gpu): 140 | loss_gen[0] += loss_gen[i] 141 | loss_gen_test[0] += loss_gen_test[i] 142 | for j in range(len(grads[0])): 143 | grads[0][j] += grads[i][j] 144 | # training op 145 | optimizer = tf.group(nn.adam_updates(all_params, grads[0], lr=tf_lr, mom1=0.95, mom2=0.9995), maintain_averages_op) 146 | 147 | # convert loss to bits/dim 148 | bits_per_dim = loss_gen[0]/(args.nr_gpu*np.log(2.)*np.prod(obs_shape)*args.batch_size) 149 | bits_per_dim_test = loss_gen_test[0]/(args.nr_gpu*np.log(2.)*np.prod(obs_shape)*args.batch_size) 150 | 151 | # sample from the model 152 | def sample_from_model(sess): 153 | x_gen = [np.zeros((args.batch_size,) + obs_shape, dtype=np.float32) for i in 
151 | # sample from the model
152 | def sample_from_model(sess):
153 |     x_gen = [np.zeros((args.batch_size,) + obs_shape, dtype=np.float32) for i in range(args.nr_gpu)]
154 |     for yi in range(obs_shape[0]):
155 |         for xi in range(obs_shape[1]):
156 |             new_x_gen_np = sess.run(new_x_gen, {xs[i]: x_gen[i] for i in range(args.nr_gpu)})
157 |             for i in range(args.nr_gpu):
158 |                 x_gen[i][:,yi,xi,:] = new_x_gen_np[i][:,yi,xi,:]
159 |     return np.concatenate(x_gen, axis=0)
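# Note on cost: PixelCNN sampling is autoregressive, so the loop above runs one
# full forward pass of the network per pixel position -- obs_shape[0]*obs_shape[1]
# sess.run calls in total (1024 for 32x32 images). Each call re-predicts the
# whole image, but only the pixel at (yi,xi) has all of its context filled in,
# so only that position is kept per iteration.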
160 | 
161 | # init & save
162 | initializer = tf.global_variables_initializer()
163 | saver = tf.train.Saver()
164 | 
165 | # turn numpy inputs into feed_dict for use with tensorflow
166 | def make_feed_dict(data, init=False):
167 |     if type(data) is tuple:
168 |         x,y = data
169 |     else:
170 |         x = data
171 |         y = None
172 |     x = np.cast[np.float32]((x - 127.5) / 127.5) # input to pixelCNN is scaled from uint8 [0,255] to float in range [-1,1]
173 |     if init:
174 |         feed_dict = {x_init: x}
175 |         if y is not None:
176 |             feed_dict.update({y_init: y})
177 |     else:
178 |         x = np.split(x, args.nr_gpu)
179 |         feed_dict = {xs[i]: x[i] for i in range(args.nr_gpu)}
180 |         if y is not None:
181 |             y = np.split(y, args.nr_gpu)
182 |             feed_dict.update({ys[i]: y[i] for i in range(args.nr_gpu)})
183 |     return feed_dict
184 | 
185 | # //////////// perform training //////////////
186 | if not os.path.exists(args.save_dir):
187 |     os.makedirs(args.save_dir)
188 | test_bpd = []
189 | lr = args.learning_rate
190 | with tf.Session() as sess:
191 |     for epoch in range(args.max_epochs):
192 |         begin = time.time()
193 | 
194 |         # init
195 |         if epoch == 0:
196 |             train_data.reset() # rewind the iterator back to 0 to do one full epoch
197 |             if args.load_params:
198 |                 ckpt_file = args.save_dir + '/params_' + args.data_set + '.ckpt'
199 |                 print('restoring parameters from', ckpt_file)
200 |                 saver.restore(sess, ckpt_file)
201 |             else:
202 |                 print('initializing the model...')
203 |                 sess.run(initializer)
204 |                 feed_dict = make_feed_dict(train_data.next(args.init_batch_size), init=True) # manually retrieve exactly init_batch_size examples
205 |                 sess.run(init_pass, feed_dict)
206 |             print('starting training')
207 | 
208 |         # train for one epoch
209 |         train_losses = []
210 |         for d in train_data:
211 |             feed_dict = make_feed_dict(d)
212 |             # decay the learning rate, then forward/backward/update the model on each gpu
213 |             lr *= args.lr_decay
214 |             feed_dict.update({ tf_lr: lr })
215 |             l,_ = sess.run([bits_per_dim, optimizer], feed_dict)
216 |             train_losses.append(l)
217 |         train_loss_gen = np.mean(train_losses)
218 | 
219 |         # compute likelihood over test data
220 |         test_losses = []
221 |         for d in test_data:
222 |             feed_dict = make_feed_dict(d)
223 |             l = sess.run(bits_per_dim_test, feed_dict)
224 |             test_losses.append(l)
225 |         test_loss_gen = np.mean(test_losses)
226 |         test_bpd.append(test_loss_gen)
227 | 
228 |         # log progress to console
229 |         print("Epoch %d, time = %ds, train bits_per_dim = %.4f, test bits_per_dim = %.4f" % (epoch, time.time()-begin, train_loss_gen, test_loss_gen))
230 |         sys.stdout.flush()
231 | 
232 |         if epoch % args.save_interval == 0:
233 | 
234 |             # generate samples from the model
235 |             sample_x = []
236 |             for i in range(args.num_samples):
237 |                 sample_x.append(sample_from_model(sess))
238 |             sample_x = np.concatenate(sample_x,axis=0)
239 |             img_tile = plotting.img_tile(sample_x[:100], aspect_ratio=1.0, border_color=1.0, stretch=True)
240 |             img = plotting.plot_img(img_tile, title=args.data_set + ' samples')
241 |             plotting.plt.savefig(os.path.join(args.save_dir,'%s_sample%d.png' % (args.data_set, epoch)))
242 |             plotting.plt.close('all')
243 |             np.savez(os.path.join(args.save_dir,'%s_sample%d.npz' % (args.data_set, epoch)), sample_x)
244 | 
245 |             # save params
246 |             saver.save(sess, args.save_dir + '/params_' + args.data_set + '.ckpt')
247 |             np.savez(args.save_dir + '/test_bpd_' + args.data_set + '.npz', test_bpd=np.array(test_bpd))
248 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/pixel-cnn/bbc15688dd37934a12c2759cf2b34975e15901d9/utils/__init__.py
--------------------------------------------------------------------------------
/utils/plotting.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('Agg') # render to files, no display needed
4 | from matplotlib import pyplot as plt
5 | 
6 | # Plot image examples.
7 | def plot_img(img, title=None):
8 |     plt.figure()
9 |     plt.imshow(img, interpolation='nearest')
10 |     if title is not None:
11 |         plt.title(title)
12 |     plt.axis('off')
13 |     plt.tight_layout()
14 | 
15 | def img_stretch(img):
16 |     img = img.astype(float)
17 |     img -= np.min(img)
18 |     img /= np.max(img)+1e-12
19 |     return img
20 | 
21 | def img_tile(imgs, aspect_ratio=1.0, tile_shape=None, border=1,
22 |              border_color=0, stretch=False):
23 |     ''' Tile images in a grid.
24 |     If tile_shape is provided, only as many images as specified in tile_shape
25 |     will be included in the output.
26 |     '''
27 | 
28 |     # Prepare images
29 |     if stretch:
30 |         imgs = img_stretch(imgs)
31 |     imgs = np.array(imgs)
32 |     if imgs.ndim != 3 and imgs.ndim != 4:
33 |         raise ValueError('imgs has wrong number of dimensions.')
34 |     n_imgs = imgs.shape[0]
35 | 
36 |     # Grid shape
37 |     img_shape = np.array(imgs.shape[1:3])
38 |     if tile_shape is None:
39 |         img_aspect_ratio = img_shape[1] / float(img_shape[0])
40 |         aspect_ratio *= img_aspect_ratio
41 |         tile_height = int(np.ceil(np.sqrt(n_imgs * aspect_ratio)))
42 |         tile_width = int(np.ceil(np.sqrt(n_imgs / aspect_ratio)))
43 |         grid_shape = np.array((tile_height, tile_width))
44 |     else:
45 |         assert len(tile_shape) == 2
46 |         grid_shape = np.array(tile_shape)
47 | 
48 |     # Tile image shape
49 |     tile_img_shape = np.array(imgs.shape[1:])
50 |     tile_img_shape[:2] = (img_shape[:2] + border) * grid_shape[:2] - border
51 | 
52 |     # Assemble tile image
53 |     tile_img = np.empty(tile_img_shape)
54 |     tile_img[:] = border_color
55 |     for i in range(grid_shape[0]):
56 |         for j in range(grid_shape[1]):
57 |             img_idx = j + i*grid_shape[1]
58 |             if img_idx >= n_imgs:
59 |                 # No more images - stop filling out the grid.
60 |                 break
61 |             img = imgs[img_idx]
62 |             yoff = (img_shape[0] + border) * i
63 |             xoff = (img_shape[1] + border) * j
64 |             tile_img[yoff:yoff+img_shape[0], xoff:xoff+img_shape[1], ...] = img
65 | 
66 |     return tile_img
67 | 
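# (usage sketch, hypothetical data) This mirrors how train.py calls into this
# module: tile the first batch of samples into a grid, plot, and save to disk.
#
#   imgs = np.random.rand(100, 32, 32, 3)          # stand-in for model samples
#   tile = img_tile(imgs, aspect_ratio=1.0, border_color=1.0, stretch=True)
#   plot_img(tile, title='cifar samples')
#   plt.savefig('samples.png')
#   plt.close('all')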
68 | def conv_filter_tile(filters):
69 |     n_filters, n_channels, height, width = filters.shape
70 |     tile_shape = None
71 |     if n_channels == 3:
72 |         # Interpret 3 color channels as RGB
73 |         filters = np.transpose(filters, (0, 2, 3, 1))
74 |     else:
75 |         # Organize tile such that each row corresponds to a filter and the
76 |         # columns are the filter channels
77 |         tile_shape = (n_channels, n_filters)
78 |         filters = np.transpose(filters, (1, 0, 2, 3))
79 |         filters = np.resize(filters, (n_filters*n_channels, height, width))
80 |     filters = img_stretch(filters)
81 |     return img_tile(filters, tile_shape=tile_shape)
82 | 
83 | def scale_to_unit_interval(ndar, eps=1e-8):
84 |     """ Scales all values in the ndarray ndar to be between 0 and 1 """
85 |     ndar = ndar.copy()
86 |     ndar -= ndar.min()
87 |     ndar *= 1.0 / (ndar.max() + eps)
88 |     return ndar
89 | 
90 | 
91 | def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
92 |                        scale_rows_to_unit_interval=True,
93 |                        output_pixel_vals=True):
94 |     """
95 |     Transform an array with one flattened image per row, into an array in
96 |     which images are reshaped and laid out like tiles on a floor.
97 | 
98 |     This function is useful for visualizing datasets whose rows are images,
99 |     and also columns of matrices for transforming those rows
100 |     (such as the first layer of a neural net).
101 | 
102 |     :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can
103 |     be 2-D ndarrays or None;
104 |     :param X: a 2-D array in which every row is a flattened image.
105 | 
106 |     :type img_shape: tuple; (height, width)
107 |     :param img_shape: the original shape of each image
108 | 
109 |     :type tile_shape: tuple; (rows, cols)
110 |     :param tile_shape: the number of images to tile (rows, cols)
111 | 
112 |     :param output_pixel_vals: if output should be pixel values (i.e. uint8
113 |     values) or floats
114 | 
115 |     :param scale_rows_to_unit_interval: if the values need to be scaled before
116 |     being plotted to [0,1] or not
117 | 
118 | 
119 |     :returns: array suitable for viewing as an image.
120 |     (See:`PIL.Image.fromarray`.)
121 |     :rtype: a 2-d array with same dtype as X.
122 | 
123 |     """
124 | 
125 |     assert len(img_shape) == 2
126 |     assert len(tile_shape) == 2
127 |     assert len(tile_spacing) == 2
128 | 
129 |     # The expression below can be re-written in a more C style as
130 |     # follows :
131 |     #
132 |     # out_shape = [0,0]
133 |     # out_shape[0] = (img_shape[0] + tile_spacing[0]) * tile_shape[0] -
134 |     #                tile_spacing[0]
135 |     # out_shape[1] = (img_shape[1] + tile_spacing[1]) * tile_shape[1] -
136 |     #                tile_spacing[1]
137 |     out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp
138 |                  in zip(img_shape, tile_shape, tile_spacing)]
139 | 
140 |     if isinstance(X, tuple):
141 |         assert len(X) == 4
142 |         # Create an output numpy ndarray to store the image
143 |         if output_pixel_vals:
144 |             out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype='uint8')
145 |         else:
146 |             out_array = np.zeros((out_shape[0], out_shape[1], 4), dtype=next((c.dtype for c in X if c is not None), np.float32)) # X is a tuple here, so take the dtype from the first non-None channel (asking the tuple itself for .dtype would raise an AttributeError)
147 | 
148 |         #colors default to 0, alpha defaults to 1 (opaque)
149 |         if output_pixel_vals:
150 |             channel_defaults = [0, 0, 0, 255]
151 |         else:
152 |             channel_defaults = [0., 0., 0., 1.]
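        # Fill each of the 4 RGBA output channels in turn: a channel given as
        # None is filled with its default value above, while a channel given as
        # an array is rendered by a recursive call into the single-channel
        # branch below.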
153 | 
154 |         for i in range(4):
155 |             if X[i] is None:
156 |                 # if channel is None, fill it with zeros of the correct
157 |                 # dtype
158 |                 out_array[:, :, i] = np.zeros(out_shape,
159 |                     dtype='uint8' if output_pixel_vals else out_array.dtype
160 |                 ) + channel_defaults[i]
161 |             else:
162 |                 # use a recursive call to compute the channel and store it
163 |                 # in the output
164 |                 out_array[:, :, i] = tile_raster_images(X[i], img_shape, tile_shape, tile_spacing, scale_rows_to_unit_interval, output_pixel_vals)
165 |         return out_array
166 | 
167 |     else:
168 |         # if we are dealing with only one channel
169 |         H, W = img_shape
170 |         Hs, Ws = tile_spacing
171 | 
172 |         # generate a matrix to store the output
173 |         out_array = np.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype)
174 | 
175 | 
176 |         for tile_row in range(tile_shape[0]):
177 |             for tile_col in range(tile_shape[1]):
178 |                 if tile_row * tile_shape[1] + tile_col < X.shape[0]:
179 |                     if scale_rows_to_unit_interval:
180 |                         # if we should scale values to be between 0 and 1
181 |                         # do this by calling the `scale_to_unit_interval`
182 |                         # function
183 |                         this_img = scale_to_unit_interval(X[tile_row * tile_shape[1] + tile_col].reshape(img_shape))
184 |                     else:
185 |                         this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)
186 |                     # add the slice to the corresponding position in the
187 |                     # output array
188 |                     out_array[
189 |                         tile_row * (H+Hs): tile_row * (H + Hs) + H,
190 |                         tile_col * (W+Ws): tile_col * (W + Ws) + W
191 |                     ] \
192 |                         = this_img * (255 if output_pixel_vals else 1)
193 |         return out_array
194 | 
195 | 
--------------------------------------------------------------------------------