├── .gitignore
├── LICENSE
├── config
│   ├── ddim_bedroom.yml
│   ├── ddim_celeba.yml
│   ├── ddim_church.yml
│   ├── ddim_cifar10.yml
│   ├── iddpm_cifar10.yml
│   ├── pf_cifar10.yml
│   └── pf_deep_cifar10.yml
├── dataset
│   ├── __init__.py
│   ├── celeba.py
│   ├── ffhq.py
│   ├── lsun.py
│   ├── utils.py
│   └── vision.py
├── main.py
├── model
│   ├── ddim.py
│   ├── ema.py
│   ├── iDDPM
│   │   ├── nn.py
│   │   └── unet.py
│   └── scoresde
│       ├── ddpm.py
│       ├── layers.py
│       ├── layerspp.py
│       ├── ncsnpp.py
│       ├── normalization.py
│       ├── up_or_down_sampling.py
│       ├── upfirdn2d.cpp
│       ├── upfirdn2d.py
│       └── upfirdn2d_kernel.cu
├── readme.md
├── requirements.txt
├── runner
│   ├── method.py
│   ├── runner.py
│   └── schedule.py
└── tool
    ├── dataset.sh
    └── fid.sh

/.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | .idea/ 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
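The YAML files under config/ that follow all share the same top-level sections (Schedule, Dataset, Model, Train, Optim, Sample). As a rough sketch of how they are consumed — mirroring the yaml.safe_load pattern used in main.py near the end of this listing; the file name below is just one example taken from the tree above:

import os
import yaml

# Load one of the configs the same way main.py does (path is an assumption:
# working directory is the repository root).
with open(os.path.join(os.getcwd(), 'config', 'ddim_cifar10.yml'), 'r') as f:
    config = yaml.safe_load(f)

# Every config exposes the same top-level sections as nested dicts.
schedule = config['Schedule']  # noise schedule: type, diffusion_step, (usually) beta_start/beta_end
model = config['Model']        # 'struc' selects the network: DDIM, iDDPM, PF or PF_deep
sample = config['Sample']      # mpi4py flag, batch_size, last_only, total_num
print(model['struc'], schedule['diffusion_step'], sample['batch_size'])
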
-------------------------------------------------------------------------------- /config/ddim_bedroom.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'linear' 3 | beta_start: 0.0001 4 | beta_end: 0.02 5 | diffusion_step: 1000 6 | 7 | Dataset: 8 | dataset: 'LSUN' 9 | category: 'bedroom' 10 | image_size: 256 11 | channels: 3 12 | logit_transform: false 13 | uniform_dequantization: false 14 | gaussian_dequantization: false 15 | random_flip: true 16 | rescaled: true 17 | num_workers: 32 18 | batch_size: 64 19 | 20 | Model: 21 | struc: 'DDIM' 22 | type: "simple" 23 | in_channels: 3 24 | out_ch: 3 25 | ch: 128 26 | ch_mult: [ 1, 1, 2, 2, 4, 4 ] 27 | num_res_blocks: 2 28 | attn_resolutions: [ 16, ] 29 | dropout: 0.0 30 | var_type: fixedsmall 31 | resamp_with_conv: True 32 | image_size: 256 33 | 34 | Train: 35 | epoch: 10000 36 | loss_type: 'linear' 37 | ema_rate: 0.999 38 | ema: True 39 | 40 | Optim: 41 | weight_decay: 0.000 42 | optimizer: 'adam' 43 | lr: 0.00002 44 | beta1: 0.9 45 | amsgrad: false 46 | eps: 0.00000001 47 | 48 | Sample: 49 | mpi4py: false 50 | batch_size: 32 51 | last_only: True 52 | total_num: 12500 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /config/ddim_celeba.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'linear' 3 | beta_start: 0.0001 4 | beta_end: 0.02 5 | diffusion_step: 1000 6 | 7 | Dataset: 8 | dataset: 'CELEBA' 9 | image_size: 64 10 | channels: 3 11 | logit_transform: false 12 | uniform_dequantization: false 13 | gaussian_dequantization: false 14 | random_flip: true 15 | rescaled: true 16 | num_workers: 6 17 | batch_size: 64 18 | 19 | Model: 20 | struc: 'DDIM' 21 | type: "simple" 22 | in_channels: 3 23 | out_ch: 3 24 | ch: 128 25 | ch_mult: [ 1, 2, 2, 2, 4 ] 26 | num_res_blocks: 2 27 | attn_resolutions: [ 16, ] 28 | dropout: 0.1 29 | var_type: fixedlarge 30 | resamp_with_conv: True 31 | image_size: 64 32 | 33 | Train: 34 | epoch: 10000 35 | loss_type: 'linear' 36 | ema_rate: 0.9999 37 | ema: True 38 | 39 | Optim: 40 | weight_decay: 0.000 41 | optimizer: 'adam' 42 | lr: 0.0002 43 | beta1: 0.9 44 | amsgrad: false 45 | eps: 0.00000001 46 | grad_clip: 1.0 47 | 48 | Sample: 49 | mpi4py: true 50 | batch_size: 128 51 | last_only: True 52 | total_num: 12500 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /config/ddim_church.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'linear' 3 | beta_start: 0.0001 4 | beta_end: 0.02 5 | diffusion_step: 1000 6 | 7 | Dataset: 8 | dataset: 'LSUN' 9 | category: 'church_outdoor' 10 | image_size: 256 11 | channels: 3 12 | logit_transform: false 13 | uniform_dequantization: false 14 | gaussian_dequantization: false 15 | random_flip: true 16 | rescaled: true 17 | num_workers: 32 18 | batch_size: 64 19 | 20 | Model: 21 | struc: 'DDIM' 22 | type: "simple" 23 | in_channels: 3 24 | out_ch: 3 25 | ch: 128 26 | ch_mult: [ 1, 1, 2, 2, 4, 4 ] 27 | num_res_blocks: 2 28 | attn_resolutions: [ 16, ] 29 | dropout: 0.0 30 | var_type: fixedsmall 31 | resamp_with_conv: True 32 | image_size: 256 33 | 34 | Train: 35 | epoch: 10000 36 | loss_type: 'linear' 37 | ema_rate: 0.999 38 | ema: True 39 | 40 | Optim: 41 | weight_decay: 0.000 42 | optimizer: 'adam' 43 | lr: 0.00002 44 | beta1: 0.9 45 | amsgrad: false 46 | eps: 0.00000001 47 | 48 | Sample: 49 | mpi4py: false 50 | 
batch_size: 32 51 | last_only: True 52 | total_num: 12500 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /config/ddim_cifar10.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'linear' 3 | beta_start: 0.0001 4 | beta_end: 0.02 5 | diffusion_step: 1000 6 | 7 | Dataset: 8 | dataset: 'CIFAR10' 9 | image_size: 32 10 | channels: 3 11 | batch_size: 64 12 | logit_transform: false 13 | uniform_dequantization: false 14 | gaussian_dequantization: false 15 | random_flip: true 16 | rescaled: true 17 | num_workers: 6 18 | 19 | Model: 20 | struc: 'DDIM' 21 | type: "simple" 22 | image_size: 32 23 | in_channels: 3 24 | out_ch: 3 25 | ch: 128 26 | ch_mult: [ 1, 2, 2, 2 ] 27 | num_res_blocks: 2 28 | attn_resolutions: [ 16, ] 29 | dropout: 0.1 30 | var_type: fixedlarge 31 | resamp_with_conv: True 32 | 33 | Train: 34 | epoch: 1000 35 | loss_type: 'linear' 36 | ema_rate: 0.9999 37 | ema: True 38 | 39 | Optim: 40 | weight_decay: 0.000 41 | optimizer: 'adam' 42 | lr: 0.0002 43 | beta1: 0.9 44 | amsgrad: false 45 | eps: 0.00000001 46 | grad_clip: 1.0 47 | 48 | Sample: 49 | mpi4py: true 50 | batch_size: 512 51 | last_only: True 52 | total_num: 12500 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/iddpm_cifar10.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'cosine' 3 | diffusion_step: 1000 4 | learn_sigma: true 5 | sigma_small: false 6 | noise_schedule: cosine 7 | use_kl: false 8 | predict_xstart: false 9 | rescale_timesteps: true 10 | rescale_learned_sigmas: true 11 | timestep_respacing: ddim 12 | 13 | Dataset: 14 | dataset: 'CIFAR10' 15 | image_size: 32 16 | channels: 3 17 | batch_size: 256 18 | logit_transform: false 19 | uniform_dequantization: false 20 | gaussian_dequantization: false 21 | random_flip: true 22 | rescaled: true 23 | num_workers: 6 24 | 25 | Model: 26 | struc: 'iDDPM' 27 | in_channels: 3 28 | model_channels: 128 29 | out_channels: 6 30 | num_res_blocks: 3 31 | attention_resolutions: [ 2, 4 ] 32 | dropout: 0.3 33 | channel_mult: [ 1, 2, 2, 2 ] 34 | dims: 2 35 | conv_resample: true 36 | use_scale_shift_norm: true 37 | use_checkpoint: false 38 | num_heads: 4 39 | num_heads_upsample: 4 40 | 41 | Train: 42 | epoch: 1000 43 | loss_type: 'linear' 44 | ema_rate: 0.9999 45 | ema: True 46 | 47 | Optim: 48 | weight_decay: 0.000 49 | optimizer: 'adam' 50 | lr: 0.0001 51 | beta1: 0.9 52 | amsgrad: false 53 | eps: 0.00000001 54 | grad_clip: 1.0 55 | 56 | Sample: 57 | mpi4py: true 58 | batch_size: 512 59 | last_only: True 60 | total_num: 12500 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /config/pf_cifar10.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'linear' 3 | beta_start: 0.0001 4 | beta_end: 0.02 5 | diffusion_step: 1000 6 | 7 | Dataset: 8 | dataset: 'CIFAR10' 9 | image_size: 32 10 | channels: 3 11 | batch_size: 64 12 | logit_transform: false 13 | uniform_dequantization: false 14 | gaussian_dequantization: false 15 | random_flip: true 16 | rescaled: true 17 | num_workers: 6 18 | 19 | Model: 20 | struc: 'PF' 21 | image_size: 32 22 | num_channels: 3 23 | nf: 128 24 | ch_mult: [ 1, 2, 2, 2 ] 25 | num_res_blocks: 2 26 | attn_resolutions: [ 16, ] 27 | dropout: 0.1 28 | nonlinearity: 'swish' 29 | resamp_with_conv: true 30 | conditional: true 31 | centered: true 32 
| 33 | Train: 34 | epoch: 1000 35 | loss_type: 'linear' 36 | ema_rate: 0.9999 37 | ema: True 38 | 39 | Optim: 40 | weight_decay: 0.000 41 | optimizer: 'adam' 42 | lr: 0.0002 43 | beta1: 0.9 44 | amsgrad: false 45 | eps: 0.00000001 46 | grad_clip: 1.0 47 | 48 | Sample: 49 | mpi4py: true 50 | batch_size: 512 51 | last_only: True 52 | total_num: 12500 53 | 54 | 55 | -------------------------------------------------------------------------------- /config/pf_deep_cifar10.yml: -------------------------------------------------------------------------------- 1 | Schedule: 2 | type: 'linear' 3 | beta_start: 0.0001 4 | beta_end: 0.02 5 | diffusion_step: 1000 6 | 7 | Dataset: 8 | dataset: 'CIFAR10' 9 | image_size: 32 10 | channels: 3 11 | batch_size: 64 12 | logit_transform: false 13 | uniform_dequantization: false 14 | gaussian_dequantization: false 15 | random_flip: true 16 | rescaled: true 17 | num_workers: 6 18 | 19 | Model: 20 | struc: 'PF_deep' 21 | image_size: 32 22 | num_channels: 3 23 | nf: 128 24 | ch_mult: [ 1, 2, 2, 2 ] 25 | num_res_blocks: 8 26 | attn_resolutions: [ 16, ] 27 | dropout: 0.1 28 | nonlinearity: 'swish' 29 | resamp_with_conv: true 30 | conditional: true 31 | skip_rescale: true 32 | resblock_type: 'biggan' 33 | progressive: none 34 | progressive_input: none 35 | embedding_type: 'positional' 36 | init_scale: 0.0 37 | combine_method: 'sum' 38 | fir: false 39 | fir_kernel: [ 1, 3, 3, 1 ] 40 | continuous: true 41 | centered: true 42 | fourier_scale: 16 43 | scale_by_sigma: false 44 | 45 | Train: 46 | epoch: 1000 47 | loss_type: 'linear' 48 | ema_rate: 0.9999 49 | ema: True 50 | 51 | Optim: 52 | weight_decay: 0.000 53 | optimizer: 'adam' 54 | lr: 0.0002 55 | beta1: 0.9 56 | amsgrad: false 57 | eps: 0.00000001 58 | grad_clip: 1.0 59 | 60 | Sample: 61 | mpi4py: true 62 | batch_size: 512 63 | last_only: True 64 | total_num: 12500 65 | 66 | 67 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import os 24 | import torch 25 | import numbers 26 | import torchvision.transforms as transforms 27 | import torchvision.transforms.functional as F 28 | from torchvision.datasets import CIFAR10 29 | from dataset.celeba import CelebA 30 | # from dataset.ffhq import FFHQ 31 | from dataset.lsun import LSUN 32 | from torch.utils.data import Subset 33 | import numpy as np 34 | 35 | 36 | class Crop(object): 37 | def __init__(self, x1, x2, y1, y2): 38 | self.x1 = x1 39 | self.x2 = x2 40 | self.y1 = y1 41 | self.y2 = y2 42 | 43 | def __call__(self, img): 44 | return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) 45 | 46 | def __repr__(self): 47 | return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( 48 | self.x1, self.x2, self.y1, self.y2 49 | ) 50 | 51 | 52 | def get_dataset(args, config): 53 | if config['random_flip'] is False: 54 | tran_transform = test_transform = transforms.Compose( 55 | [transforms.Resize(config['image_size']), transforms.ToTensor()] 56 | ) 57 | else: 58 | tran_transform = transforms.Compose( 59 | [ 60 | transforms.Resize(config['image_size']), 61 | transforms.RandomHorizontalFlip(p=0.5), 62 | transforms.ToTensor(), 63 | ] 64 | ) 65 | test_transform = transforms.Compose( 66 | [transforms.Resize(config['image_size']), transforms.ToTensor()] 67 | ) 68 | 69 | if config['dataset'] == "CIFAR10": 70 | dataset = CIFAR10( 71 | os.path.join(os.getcwd(), "temp", "cifar10"), 72 | train=True, 73 | download=True, 74 | transform=tran_transform, 75 | ) 76 | test_dataset = CIFAR10( 77 | os.path.join(os.getcwd(), "temp", "cifar10"), 78 | train=False, 79 | download=True, 80 | transform=test_transform, 81 | ) 82 | 83 | elif config['dataset'] == "CELEBA": 84 | cx = 89 85 | cy = 121 86 | x1 = cy - 64 87 | x2 = cy + 64 88 | y1 = cx - 64 89 | y2 = cx + 64 90 | if config['random_flip']: 91 | dataset = CelebA( 92 | root=os.path.join(os.getcwd(), "temp", "celeba"), 93 | split="train", 94 | transform=transforms.Compose( 95 | [ 96 | Crop(x1, x2, y1, y2), 97 | transforms.Resize(config['image_size']), 98 | transforms.RandomHorizontalFlip(), 99 | transforms.ToTensor(), 100 | ] 101 | ), 102 | download=True, 103 | ) 104 | else: 105 | dataset = CelebA( 106 | root=os.path.join(os.getcwd(), "temp", "celeba"), 107 | split="train", 108 | transform=transforms.Compose( 109 | [ 110 | Crop(x1, x2, y1, y2), 111 | transforms.Resize(config['image_size']), 112 | transforms.ToTensor(), 113 | ] 114 | ), 115 | download=True, 116 | ) 117 | 118 | test_dataset = CelebA( 119 | root=os.path.join(os.getcwd(), "temp", "celeba"), 120 | split="test", 121 | transform=transforms.Compose( 122 | [ 123 | Crop(x1, x2, y1, y2), 124 | transforms.Resize(config['image_size']), 125 | transforms.ToTensor(), 126 | ] 127 | ), 128 | download=True, 129 | ) 130 | 131 | elif config['dataset'] == "LSUN": 132 | train_folder = "{}_train".format(config['category']) 133 | val_folder = "{}_val".format(config['category']) 134 | if config['random_flip']: 135 | dataset = LSUN( 136 | root=os.path.join(os.getcwd(), "temp", "lsun"), 137 | classes=[train_folder], 138 | transform=transforms.Compose( 139 | [ 140 | transforms.Resize(config['image_size']), 141 | transforms.CenterCrop(config['image_size']), 142 | transforms.RandomHorizontalFlip(p=0.5), 143 | transforms.ToTensor(), 144 | ] 145 | ), 146 | ) 147 | else: 148 | dataset = LSUN( 149 | root=os.path.join(os.getcwd(), "temp", "lsun"), 150 | classes=[train_folder], 151 | transform=transforms.Compose( 152 | [ 153 | transforms.Resize(config['image_size']), 154 | 
transforms.CenterCrop(config['image_size']), 155 | transforms.ToTensor(), 156 | ] 157 | ), 158 | ) 159 | 160 | test_dataset = LSUN( 161 | root=os.path.join(os.getcwd(), "temp", "lsun"), 162 | classes=[val_folder], 163 | transform=transforms.Compose( 164 | [ 165 | transforms.Resize(config['image_size']), 166 | transforms.CenterCrop(config['image_size']), 167 | transforms.ToTensor(), 168 | ] 169 | ), 170 | ) 171 | # 172 | # elif config.data.dataset == "FFHQ": 173 | # if config.data.random_flip: 174 | # dataset = FFHQ( 175 | # path=os.path.join(args.exp, "datasets", "FFHQ"), 176 | # transform=transforms.Compose( 177 | # [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()] 178 | # ), 179 | # resolution=config.data.image_size, 180 | # ) 181 | # else: 182 | # dataset = FFHQ( 183 | # path=os.path.join(args.exp, "datasets", "FFHQ"), 184 | # transform=transforms.ToTensor(), 185 | # resolution=config.data.image_size, 186 | # ) 187 | # 188 | # num_items = len(dataset) 189 | # indices = list(range(num_items)) 190 | # random_state = np.random.get_state() 191 | # np.random.seed(2019) 192 | # np.random.shuffle(indices) 193 | # np.random.set_state(random_state) 194 | # train_indices, test_indices = ( 195 | # indices[: int(num_items * 0.9)], 196 | # indices[int(num_items * 0.9) :], 197 | # ) 198 | # test_dataset = Subset(dataset, test_indices) 199 | # dataset = Subset(dataset, train_indices) 200 | else: 201 | dataset, test_dataset = None, None 202 | 203 | return dataset, test_dataset 204 | 205 | 206 | def logit_transform(image, lam=1e-6): 207 | image = lam + (1 - 2 * lam) * image 208 | return torch.log(image) - torch.log1p(-image) 209 | 210 | 211 | def data_transform(config, X): 212 | if config['uniform_dequantization']: 213 | X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0 214 | if config['gaussian_dequantization']: 215 | X = X + torch.randn_like(X) * 0.01 216 | 217 | if config['rescaled']: 218 | X = 2 * X - 1.0 219 | elif config['logit_transform']: 220 | X = logit_transform(X) 221 | 222 | # if hasattr(config, "image_mean"): 223 | # return X - config.image_mean.to(X.device)[None, ...] 224 | 225 | return X 226 | 227 | 228 | def inverse_data_transform(config, X): 229 | # if hasattr(config, "image_mean"): 230 | # X = X + config.image_mean.to(X.device)[None, ...] 231 | 232 | if config['logit_transform']: 233 | X = torch.sigmoid(X) 234 | elif config['rescaled']: 235 | X = (X + 1.0) / 2.0 236 | 237 | return torch.clamp(X, 0.0, 1.0) 238 | -------------------------------------------------------------------------------- /dataset/celeba.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 
14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch 24 | import os 25 | import PIL 26 | from .vision import VisionDataset 27 | from .utils import download_file_from_google_drive, check_integrity 28 | 29 | 30 | class CelebA(VisionDataset): 31 | """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. 32 | 33 | Args: 34 | root (string): Root directory where images are downloaded to. 35 | split (string): One of {'train', 'valid', 'test'}. 36 | Accordingly dataset is selected. 37 | target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, 38 | or ``landmarks``. Can also be a list to output a tuple with all specified target types. 39 | The targets represent: 40 | ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes 41 | ``identity`` (int): label for each person (data points with the same identity are the same person) 42 | ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) 43 | ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, 44 | righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) 45 | Defaults to ``attr``. 46 | transform (callable, optional): A function/transform that takes in an PIL image 47 | and returns a transformed version. E.g, ``transforms.ToTensor`` 48 | target_transform (callable, optional): A function/transform that takes in the 49 | target and transforms it. 50 | download (bool, optional): If true, downloads the dataset from the internet and 51 | puts it in root directory. If dataset is already downloaded, it is not 52 | downloaded again. 53 | """ 54 | 55 | base_folder = "celeba" 56 | # There currently does not appear to be a easy way to extract 7z in python (without introducing additional 57 | # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available 58 | # right now. 
59 | file_list = [ 60 | # File ID MD5 Hash Filename 61 | ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), 62 | # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), 63 | # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), 64 | ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), 65 | ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), 66 | ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), 67 | ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), 68 | # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), 69 | ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), 70 | ] 71 | 72 | def __init__(self, root, 73 | split="train", 74 | target_type="attr", 75 | transform=None, target_transform=None, 76 | download=False): 77 | import pandas 78 | super(CelebA, self).__init__(root) 79 | self.split = split 80 | if isinstance(target_type, list): 81 | self.target_type = target_type 82 | else: 83 | self.target_type = [target_type] 84 | self.transform = transform 85 | self.target_transform = target_transform 86 | 87 | if download: 88 | self.download() 89 | 90 | if not self._check_integrity(): 91 | raise RuntimeError('Dataset not found or corrupted.' + 92 | ' You can use download=True to download it') 93 | 94 | self.transform = transform 95 | self.target_transform = target_transform 96 | 97 | if split.lower() == "train": 98 | split = 0 99 | elif split.lower() == "valid": 100 | split = 1 101 | elif split.lower() == "test": 102 | split = 2 103 | else: 104 | raise ValueError('Wrong split entered! 
Please use split="train" ' 105 | 'or split="valid" or split="test"') 106 | 107 | with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: 108 | splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 109 | 110 | with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: 111 | self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) 112 | 113 | with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: 114 | self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) 115 | 116 | with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: 117 | self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) 118 | 119 | with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: 120 | self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) 121 | 122 | mask = (splits[1] == split) 123 | self.filename = splits[mask].index.values 124 | self.identity = torch.as_tensor(self.identity[mask].values) 125 | self.bbox = torch.as_tensor(self.bbox[mask].values) 126 | self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values) 127 | self.attr = torch.as_tensor(self.attr[mask].values) 128 | self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} 129 | 130 | def _check_integrity(self): 131 | for (_, md5, filename) in self.file_list: 132 | fpath = os.path.join(self.root, self.base_folder, filename) 133 | _, ext = os.path.splitext(filename) 134 | # Allow original archive to be deleted (zip and 7z) 135 | # Only need the extracted images 136 | if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): 137 | return False 138 | 139 | # Should check a hash of the images 140 | return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) 141 | 142 | def download(self): 143 | import zipfile 144 | 145 | if self._check_integrity(): 146 | print('Files already downloaded and verified') 147 | return 148 | 149 | for (file_id, md5, filename) in self.file_list: 150 | download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) 151 | 152 | with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: 153 | f.extractall(os.path.join(self.root, self.base_folder)) 154 | 155 | def __getitem__(self, index): 156 | X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) 157 | 158 | target = [] 159 | for t in self.target_type: 160 | if t == "attr": 161 | target.append(self.attr[index, :]) 162 | elif t == "identity": 163 | target.append(self.identity[index, 0]) 164 | elif t == "bbox": 165 | target.append(self.bbox[index, :]) 166 | elif t == "landmarks": 167 | target.append(self.landmarks_align[index, :]) 168 | else: 169 | raise ValueError("Target type \"{}\" is not recognized.".format(t)) 170 | target = tuple(target) if len(target) > 1 else target[0] 171 | 172 | if self.transform is not None: 173 | X = self.transform(X) 174 | 175 | if self.target_transform is not None: 176 | target = self.target_transform(target) 177 | 178 | return X, target 179 | 180 | def __len__(self): 181 | return len(self.attr) 182 | 183 | def extra_repr(self): 184 | lines = ["Target type: {target_type}", "Split: {split}"] 185 | return '\n'.join(lines).format(**self.__dict__) 186 | 
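As a usage note, the CelebA class above is instantiated by get_dataset in dataset/__init__.py with a fixed 128x128 face crop followed by a resize. A minimal standalone sketch of that wiring (the 64x64 size comes from config/ddim_celeba.yml, and temp/celeba is the repository's download location; both are assumptions about your local setup):

import os
import torchvision.transforms as transforms

from dataset import Crop            # Crop helper defined in dataset/__init__.py
from dataset.celeba import CelebA

# Same 128x128 crop centred at (cx, cy) = (89, 121) that get_dataset uses.
cx, cy = 89, 121
x1, x2 = cy - 64, cy + 64
y1, y2 = cx - 64, cx + 64

train_set = CelebA(
    root=os.path.join(os.getcwd(), "temp", "celeba"),
    split="train",
    transform=transforms.Compose([
        Crop(x1, x2, y1, y2),
        transforms.Resize(64),       # image_size from ddim_celeba.yml
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    download=True,                   # fetches the aligned images from Google Drive
)
img, attr = train_set[0]             # default target_type is 'attr'
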
-------------------------------------------------------------------------------- /dataset/ffhq.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from io import BytesIO 24 | 25 | import lmdb 26 | from PIL import Image 27 | from torch.utils.data import Dataset 28 | 29 | 30 | class FFHQ(Dataset): 31 | def __init__(self, path, transform, resolution=8): 32 | self.env = lmdb.open( 33 | path, 34 | max_readers=32, 35 | readonly=True, 36 | lock=False, 37 | readahead=False, 38 | meminit=False, 39 | ) 40 | 41 | if not self.env: 42 | raise IOError('Cannot open lmdb dataset', path) 43 | 44 | with self.env.begin(write=False) as txn: 45 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 46 | 47 | self.resolution = resolution 48 | self.transform = transform 49 | 50 | def __len__(self): 51 | return self.length 52 | 53 | def __getitem__(self, index): 54 | with self.env.begin(write=False) as txn: 55 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 56 | img_bytes = txn.get(key) 57 | 58 | buffer = BytesIO(img_bytes) 59 | img = Image.open(buffer) 60 | img = self.transform(img) 61 | target = 0 62 | 63 | return img, target -------------------------------------------------------------------------------- /dataset/lsun.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | from .vision import VisionDataset 24 | from PIL import Image 25 | import os 26 | import os.path 27 | import io 28 | from collections.abc import Iterable 29 | import pickle 30 | from torchvision.datasets.utils import verify_str_arg, iterable_to_str 31 | 32 | 33 | class LSUNClass(VisionDataset): 34 | def __init__(self, root, transform=None, target_transform=None): 35 | import lmdb 36 | 37 | super(LSUNClass, self).__init__( 38 | root, transform=transform, target_transform=target_transform 39 | ) 40 | 41 | self.env = lmdb.open( 42 | root, 43 | max_readers=1, 44 | readonly=True, 45 | lock=False, 46 | readahead=False, 47 | meminit=False, 48 | ) 49 | with self.env.begin(write=False) as txn: 50 | self.length = txn.stat()["entries"] 51 | root_split = root.split("/") 52 | cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") 53 | if os.path.isfile(cache_file): 54 | self.keys = pickle.load(open(cache_file, "rb")) 55 | else: 56 | with self.env.begin(write=False) as txn: 57 | self.keys = [key for key, _ in txn.cursor()] 58 | pickle.dump(self.keys, open(cache_file, "wb")) 59 | 60 | def __getitem__(self, index): 61 | img, target = None, None 62 | env = self.env 63 | with env.begin(write=False) as txn: 64 | imgbuf = txn.get(self.keys[index]) 65 | 66 | buf = io.BytesIO() 67 | buf.write(imgbuf) 68 | buf.seek(0) 69 | img = Image.open(buf).convert("RGB") 70 | 71 | if self.transform is not None: 72 | img = self.transform(img) 73 | 74 | if self.target_transform is not None: 75 | target = self.target_transform(target) 76 | 77 | return img, target 78 | 79 | def __len__(self): 80 | return self.length 81 | 82 | 83 | class LSUN(VisionDataset): 84 | """ 85 | `LSUN `_ dataset. 86 | 87 | Args: 88 | root (string): Root directory for the database files. 89 | classes (string or list): One of {'train', 'val', 'test'} or a list of 90 | categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. 91 | transform (callable, optional): A function/transform that takes in an PIL image 92 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 93 | target_transform (callable, optional): A function/transform that takes in the 94 | target and transforms it. 
95 | """ 96 | 97 | def __init__(self, root, classes="train", transform=None, target_transform=None): 98 | super(LSUN, self).__init__( 99 | root, transform=transform, target_transform=target_transform 100 | ) 101 | self.classes = self._verify_classes(classes) 102 | 103 | # for each class, create an LSUNClassDataset 104 | self.dbs = [] 105 | for c in self.classes: 106 | self.dbs.append( 107 | LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) 108 | ) 109 | 110 | self.indices = [] 111 | count = 0 112 | for db in self.dbs: 113 | count += len(db) 114 | self.indices.append(count) 115 | 116 | self.length = count 117 | 118 | def _verify_classes(self, classes): 119 | categories = [ 120 | "bedroom", 121 | "bridge", 122 | "church_outdoor", 123 | "classroom", 124 | "conference_room", 125 | "dining_room", 126 | "kitchen", 127 | "living_room", 128 | "restaurant", 129 | "tower", 130 | ] 131 | dset_opts = ["train", "val", "test"] 132 | 133 | try: 134 | verify_str_arg(classes, "classes", dset_opts) 135 | if classes == "test": 136 | classes = [classes] 137 | else: 138 | classes = [c + "_" + classes for c in categories] 139 | except ValueError: 140 | if not isinstance(classes, Iterable): 141 | msg = ( 142 | "Expected type str or Iterable for argument classes, " 143 | "but got type {}." 144 | ) 145 | raise ValueError(msg.format(type(classes))) 146 | 147 | classes = list(classes) 148 | msg_fmtstr = ( 149 | "Expected type str for elements in argument classes, " 150 | "but got type {}." 151 | ) 152 | for c in classes: 153 | verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) 154 | c_short = c.split("_") 155 | category, dset_opt = "_".join(c_short[:-1]), c_short[-1] 156 | 157 | msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 158 | msg = msg_fmtstr.format( 159 | category, "LSUN class", iterable_to_str(categories) 160 | ) 161 | verify_str_arg(category, valid_values=categories, custom_msg=msg) 162 | 163 | msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) 164 | verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) 165 | 166 | return classes 167 | 168 | def __getitem__(self, index): 169 | """ 170 | Args: 171 | index (int): Index 172 | 173 | Returns: 174 | tuple: Tuple (image, target) where target is the index of the target category. 
175 | """ 176 | target = 0 177 | sub = 0 178 | for ind in self.indices: 179 | if index < ind: 180 | break 181 | target += 1 182 | sub = ind 183 | 184 | db = self.dbs[target] 185 | index = index - sub 186 | 187 | if self.target_transform is not None: 188 | target = self.target_transform(target) 189 | 190 | img, _ = db[index] 191 | return img, target 192 | 193 | def __len__(self): 194 | return self.length 195 | 196 | def extra_repr(self): 197 | return "Classes: {classes}".format(**self.__dict__) 198 | -------------------------------------------------------------------------------- /dataset/utils.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import os 24 | import os.path 25 | import hashlib 26 | import errno 27 | from torch.utils.model_zoo import tqdm 28 | 29 | 30 | def gen_bar_updater(): 31 | pbar = tqdm(total=None) 32 | 33 | def bar_update(count, block_size, total_size): 34 | if pbar.total is None and total_size: 35 | pbar.total = total_size 36 | progress_bytes = count * block_size 37 | pbar.update(progress_bytes - pbar.n) 38 | 39 | return bar_update 40 | 41 | 42 | def check_integrity(fpath, md5=None): 43 | if md5 is None: 44 | return True 45 | if not os.path.isfile(fpath): 46 | return False 47 | md5o = hashlib.md5() 48 | with open(fpath, 'rb') as f: 49 | # read in 1MB chunks 50 | for chunk in iter(lambda: f.read(1024 * 1024), b''): 51 | md5o.update(chunk) 52 | md5c = md5o.hexdigest() 53 | if md5c != md5: 54 | return False 55 | return True 56 | 57 | 58 | def makedir_exist_ok(dirpath): 59 | """ 60 | Python2 support for os.makedirs(.., exist_ok=True) 61 | """ 62 | try: 63 | os.makedirs(dirpath) 64 | except OSError as e: 65 | if e.errno == errno.EEXIST: 66 | pass 67 | else: 68 | raise 69 | 70 | 71 | def download_url(url, root, filename=None, md5=None): 72 | """Download a file from a url and place it in root. 73 | 74 | Args: 75 | url (str): URL to download file from 76 | root (str): Directory to place downloaded file in 77 | filename (str, optional): Name to save the file under. If None, use the basename of the URL 78 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 79 | """ 80 | from six.moves import urllib 81 | 82 | root = os.path.expanduser(root) 83 | if not filename: 84 | filename = os.path.basename(url) 85 | fpath = os.path.join(root, filename) 86 | 87 | makedir_exist_ok(root) 88 | 89 | # downloads file 90 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 91 | print('Using downloaded and verified file: ' + fpath) 92 | else: 93 | try: 94 | print('Downloading ' + url + ' to ' + fpath) 95 | urllib.request.urlretrieve( 96 | url, fpath, 97 | reporthook=gen_bar_updater() 98 | ) 99 | except OSError: 100 | if url[:5] == 'https': 101 | url = url.replace('https:', 'http:') 102 | print('Failed download. Trying https -> http instead.' 103 | ' Downloading ' + url + ' to ' + fpath) 104 | urllib.request.urlretrieve( 105 | url, fpath, 106 | reporthook=gen_bar_updater() 107 | ) 108 | 109 | 110 | def list_dir(root, prefix=False): 111 | """List all directories at a given root 112 | 113 | Args: 114 | root (str): Path to directory whose folders need to be listed 115 | prefix (bool, optional): If true, prepends the path to each result, otherwise 116 | only returns the name of the directories found 117 | """ 118 | root = os.path.expanduser(root) 119 | directories = list( 120 | filter( 121 | lambda p: os.path.isdir(os.path.join(root, p)), 122 | os.listdir(root) 123 | ) 124 | ) 125 | 126 | if prefix is True: 127 | directories = [os.path.join(root, d) for d in directories] 128 | 129 | return directories 130 | 131 | 132 | def list_files(root, suffix, prefix=False): 133 | """List all files ending with a suffix at a given root 134 | 135 | Args: 136 | root (str): Path to directory whose folders need to be listed 137 | suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 138 | It uses the Python "str.endswith" method and is passed directly 139 | prefix (bool, optional): If true, prepends the path to each result, otherwise 140 | only returns the name of the files found 141 | """ 142 | root = os.path.expanduser(root) 143 | files = list( 144 | filter( 145 | lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), 146 | os.listdir(root) 147 | ) 148 | ) 149 | 150 | if prefix is True: 151 | files = [os.path.join(root, d) for d in files] 152 | 153 | return files 154 | 155 | 156 | def download_file_from_google_drive(file_id, root, filename=None, md5=None): 157 | """Download a Google Drive file from and place it in root. 158 | 159 | Args: 160 | file_id (str): id of file to be downloaded 161 | root (str): Directory to place downloaded file in 162 | filename (str, optional): Name to save the file under. If None, use the id of the file. 163 | md5 (str, optional): MD5 checksum of the download. 
If None, do not check 164 | """ 165 | # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url 166 | import requests 167 | url = "https://docs.google.com/uc?export=download" 168 | 169 | root = os.path.expanduser(root) 170 | if not filename: 171 | filename = file_id 172 | fpath = os.path.join(root, filename) 173 | 174 | makedir_exist_ok(root) 175 | 176 | if os.path.isfile(fpath) and check_integrity(fpath, md5): 177 | print('Using downloaded and verified file: ' + fpath) 178 | else: 179 | session = requests.Session() 180 | 181 | response = session.get(url, params={'id': file_id}, stream=True) 182 | token = _get_confirm_token(response) 183 | 184 | if token: 185 | params = {'id': file_id, 'confirm': token} 186 | response = session.get(url, params=params, stream=True) 187 | 188 | _save_response_content(response, fpath) 189 | 190 | 191 | def _get_confirm_token(response): 192 | for key, value in response.cookies.items(): 193 | if key.startswith('download_warning'): 194 | return value 195 | 196 | return None 197 | 198 | 199 | def _save_response_content(response, destination, chunk_size=32768): 200 | with open(destination, "wb") as f: 201 | pbar = tqdm(total=None) 202 | progress = 0 203 | for chunk in response.iter_content(chunk_size): 204 | if chunk: # filter out keep-alive new chunks 205 | f.write(chunk) 206 | progress += len(chunk) 207 | pbar.update(progress - pbar.n) 208 | pbar.close() 209 | -------------------------------------------------------------------------------- /dataset/vision.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 
22 | 23 | import os 24 | import torch 25 | import torch.utils.data as data 26 | 27 | 28 | class VisionDataset(data.Dataset): 29 | _repr_indent = 4 30 | 31 | def __init__(self, root, transforms=None, transform=None, target_transform=None): 32 | if isinstance(root, torch._six.string_classes): 33 | root = os.path.expanduser(root) 34 | self.root = root 35 | 36 | has_transforms = transforms is not None 37 | has_separate_transform = transform is not None or target_transform is not None 38 | if has_transforms and has_separate_transform: 39 | raise ValueError("Only transforms or transform/target_transform can " 40 | "be passed as argument") 41 | 42 | # for backwards-compatibility 43 | self.transform = transform 44 | self.target_transform = target_transform 45 | 46 | if has_separate_transform: 47 | transforms = StandardTransform(transform, target_transform) 48 | self.transforms = transforms 49 | 50 | def __getitem__(self, index): 51 | raise NotImplementedError 52 | 53 | def __len__(self): 54 | raise NotImplementedError 55 | 56 | def __repr__(self): 57 | head = "Dataset " + self.__class__.__name__ 58 | body = ["Number of datapoints: {}".format(self.__len__())] 59 | if self.root is not None: 60 | body.append("Root location: {}".format(self.root)) 61 | body += self.extra_repr().splitlines() 62 | if hasattr(self, 'transform') and self.transform is not None: 63 | body += self._format_transform_repr(self.transform, 64 | "Transforms: ") 65 | if hasattr(self, 'target_transform') and self.target_transform is not None: 66 | body += self._format_transform_repr(self.target_transform, 67 | "Target transforms: ") 68 | lines = [head] + [" " * self._repr_indent + line for line in body] 69 | return '\n'.join(lines) 70 | 71 | def _format_transform_repr(self, transform, head): 72 | lines = transform.__repr__().splitlines() 73 | return (["{}{}".format(head, lines[0])] + 74 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 75 | 76 | def extra_repr(self): 77 | return "" 78 | 79 | 80 | class StandardTransform(object): 81 | def __init__(self, transform=None, target_transform=None): 82 | self.transform = transform 83 | self.target_transform = target_transform 84 | 85 | def __call__(self, input, target): 86 | if self.transform is not None: 87 | input = self.transform(input) 88 | if self.target_transform is not None: 89 | target = self.target_transform(target) 90 | return input, target 91 | 92 | def _format_transform_repr(self, transform, head): 93 | lines = transform.__repr__().splitlines() 94 | return (["{}{}".format(head, lines[0])] + 95 | ["{}{}".format(" " * len(head), line) for line in lines[1:]]) 96 | 97 | def __repr__(self): 98 | body = [self.__class__.__name__] 99 | if self.transform is not None: 100 | body += self._format_transform_repr(self.transform, 101 | "Transform: ") 102 | if self.target_transform is not None: 103 | body += self._format_transform_repr(self.target_transform, 104 | "Target transform: ") 105 | 106 | return '\n'.join(body) 107 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Luping Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import yaml 17 | import sys 18 | import os 19 | import numpy as np 20 | import torch as th 21 | 22 | from runner.schedule import Schedule 23 | from runner.runner import Runner 24 | 25 | 26 | def args_and_config(): 27 | parser = argparse.ArgumentParser() 28 | 29 | parser.add_argument("--runner", type=str, default='sample', 30 | help="Choose the mode of runner") 31 | parser.add_argument("--config", type=str, default='ddim_cifar10.yml', 32 | help="Choose the config file") 33 | parser.add_argument("--model", type=str, default='DDIM', 34 | help="Choose the model's structure (DDIM, iDDPM, PF)") 35 | parser.add_argument("--method", type=str, default='F-PNDM', 36 | help="Choose the numerical methods (DDIM, FON, S-PNDM, F-PNDM, PF)") 37 | parser.add_argument("--sample_speed", type=int, default=50, 38 | help="Control the total generation step") 39 | parser.add_argument("--device", type=str, default='cuda', 40 | help="Choose the device to use") 41 | parser.add_argument("--image_path", type=str, default='temp/sample', 42 | help="Choose the path to save images") 43 | parser.add_argument("--model_path", type=str, default='temp/models/ddim/ema_cifar10.ckpt', 44 | help="Choose the path of model") 45 | parser.add_argument("--restart", action="store_true", 46 | help="Restart a previous training process") 47 | parser.add_argument("--train_path", type=str, default='temp/train', 48 | help="Choose the path to save training status") 49 | 50 | 51 | args = parser.parse_args() 52 | 53 | work_dir = os.getcwd() 54 | with open(f'{work_dir}/config/{args.config}', 'r') as f: 55 | config = yaml.safe_load(f) 56 | 57 | return args, config 58 | 59 | 60 | def check_config(): 61 | # image_size, total_step 62 | pass 63 | 64 | 65 | if __name__ == "__main__": 66 | args, config = args_and_config() 67 | 68 | if args.runner == 'sample' and config['Sample']['mpi4py']: 69 | from mpi4py import MPI 70 | 71 | comm = MPI.COMM_WORLD 72 | mpi_rank = comm.Get_rank() 73 | os.environ['CUDA_VISIBLE_DEVICES'] = str(mpi_rank) 74 | 75 | device = th.device(args.device) 76 | schedule = Schedule(args, config['Schedule']) 77 | if config['Model']['struc'] == 'DDIM': 78 | from model.ddim import Model 79 | model = Model(args, config['Model']).to(device) 80 | elif config['Model']['struc'] == 'iDDPM': 81 | from model.iDDPM.unet import UNetModel 82 | model = UNetModel(args, config['Model']).to(device) 83 | elif config['Model']['struc'] == 'PF': 84 | from model.scoresde.ddpm import DDPM 85 | model = DDPM(args, config['Model']).to(device) 86 | elif config['Model']['struc'] == 'PF_deep': 87 | from model.scoresde.ncsnpp import NCSNpp 88 | model = NCSNpp(args, config['Model']).to(device) 89 | else: 90 | model = None 91 | 92 | runner = Runner(args, config, schedule, model) 93 | if args.runner == 'train': 94 | runner.train() 95 | elif args.runner == 'sample': 96 | runner.sample_fid() 97 | 98 | -------------------------------------------------------------------------------- /model/ddim.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 
| # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import math 24 | import torch 25 | import torch.nn as nn 26 | 27 | 28 | def get_timestep_embedding(timesteps, embedding_dim): 29 | """ 30 | This matches the implementation in Denoising Diffusion Probabilistic Models: 31 | From Fairseq. 32 | Build sinusoidal embeddings. 33 | This matches the implementation in tensor2tensor, but differs slightly 34 | from the description in Section 3.5 of "Attention Is All You Need". 35 | """ 36 | assert len(timesteps.shape) == 1 37 | 38 | half_dim = embedding_dim // 2 39 | emb = math.log(10000) / (half_dim - 1) 40 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 41 | emb = emb.to(device=timesteps.device) 42 | emb = timesteps.float()[:, None] * emb[None, :] 43 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 44 | if embedding_dim % 2 == 1: # zero pad 45 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 46 | return emb 47 | 48 | 49 | def nonlinearity(x): 50 | # swish 51 | return x*torch.sigmoid(x) 52 | 53 | 54 | def Normalize(in_channels): 55 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 56 | 57 | 58 | class Upsample(nn.Module): 59 | def __init__(self, in_channels, with_conv): 60 | super().__init__() 61 | self.with_conv = with_conv 62 | if self.with_conv: 63 | self.conv = torch.nn.Conv2d(in_channels, 64 | in_channels, 65 | kernel_size=3, 66 | stride=1, 67 | padding=1) 68 | 69 | def forward(self, x): 70 | x = torch.nn.functional.interpolate( 71 | x, scale_factor=2.0, mode="nearest") 72 | if self.with_conv: 73 | x = self.conv(x) 74 | return x 75 | 76 | 77 | class Downsample(nn.Module): 78 | def __init__(self, in_channels, with_conv): 79 | super().__init__() 80 | self.with_conv = with_conv 81 | if self.with_conv: 82 | # no asymmetric padding in torch conv, must do it ourselves 83 | self.conv = torch.nn.Conv2d(in_channels, 84 | in_channels, 85 | kernel_size=3, 86 | stride=2, 87 | padding=0) 88 | 89 | def forward(self, x): 90 | if self.with_conv: 91 | pad = (0, 1, 0, 1) 92 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 93 | x = self.conv(x) 94 | else: 95 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 96 | return x 97 | 98 | 99 | class ResnetBlock(nn.Module): 100 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 101 | dropout, temb_channels=512): 102 | 
super().__init__() 103 | self.in_channels = in_channels 104 | out_channels = in_channels if out_channels is None else out_channels 105 | self.out_channels = out_channels 106 | self.use_conv_shortcut = conv_shortcut 107 | 108 | self.norm1 = Normalize(in_channels) 109 | self.conv1 = torch.nn.Conv2d(in_channels, 110 | out_channels, 111 | kernel_size=3, 112 | stride=1, 113 | padding=1) 114 | self.temb_proj = torch.nn.Linear(temb_channels, 115 | out_channels) 116 | self.norm2 = Normalize(out_channels) 117 | self.dropout = torch.nn.Dropout(dropout) 118 | self.conv2 = torch.nn.Conv2d(out_channels, 119 | out_channels, 120 | kernel_size=3, 121 | stride=1, 122 | padding=1) 123 | if self.in_channels != self.out_channels: 124 | if self.use_conv_shortcut: 125 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 126 | out_channels, 127 | kernel_size=3, 128 | stride=1, 129 | padding=1) 130 | else: 131 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 132 | out_channels, 133 | kernel_size=1, 134 | stride=1, 135 | padding=0) 136 | 137 | def forward(self, x, temb): 138 | h = x 139 | h = self.norm1(h) 140 | h = nonlinearity(h) 141 | h = self.conv1(h) 142 | 143 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 144 | 145 | h = self.norm2(h) 146 | h = nonlinearity(h) 147 | h = self.dropout(h) 148 | h = self.conv2(h) 149 | 150 | if self.in_channels != self.out_channels: 151 | if self.use_conv_shortcut: 152 | x = self.conv_shortcut(x) 153 | else: 154 | x = self.nin_shortcut(x) 155 | 156 | return x+h 157 | 158 | 159 | class AttnBlock(nn.Module): 160 | def __init__(self, in_channels): 161 | super().__init__() 162 | self.in_channels = in_channels 163 | 164 | self.norm = Normalize(in_channels) 165 | self.q = torch.nn.Conv2d(in_channels, 166 | in_channels, 167 | kernel_size=1, 168 | stride=1, 169 | padding=0) 170 | self.k = torch.nn.Conv2d(in_channels, 171 | in_channels, 172 | kernel_size=1, 173 | stride=1, 174 | padding=0) 175 | self.v = torch.nn.Conv2d(in_channels, 176 | in_channels, 177 | kernel_size=1, 178 | stride=1, 179 | padding=0) 180 | self.proj_out = torch.nn.Conv2d(in_channels, 181 | in_channels, 182 | kernel_size=1, 183 | stride=1, 184 | padding=0) 185 | 186 | def forward(self, x): 187 | h_ = x 188 | h_ = self.norm(h_) 189 | q = self.q(h_) 190 | k = self.k(h_) 191 | v = self.v(h_) 192 | 193 | # compute attention 194 | b, c, h, w = q.shape 195 | q = q.reshape(b, c, h*w) 196 | q = q.permute(0, 2, 1) # b,hw,c 197 | k = k.reshape(b, c, h*w) # b,c,hw 198 | w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 199 | w_ = w_ * (int(c)**(-0.5)) 200 | w_ = torch.nn.functional.softmax(w_, dim=2) 201 | 202 | # attend to values 203 | v = v.reshape(b, c, h*w) 204 | w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) 205 | # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 206 | h_ = torch.bmm(v, w_) 207 | h_ = h_.reshape(b, c, h, w) 208 | 209 | h_ = self.proj_out(h_) 210 | 211 | return x+h_ 212 | 213 | 214 | class Model(nn.Module): 215 | def __init__(self, args, config): 216 | super().__init__() 217 | self.config = config 218 | ch, out_ch, ch_mult = config['ch'], config['out_ch'], tuple(config['ch_mult']) 219 | num_res_blocks = config['num_res_blocks'] 220 | attn_resolutions = config['attn_resolutions'] 221 | dropout = config['dropout'] 222 | in_channels = config['in_channels'] 223 | resolution = config['image_size'] 224 | resamp_with_conv = config['resamp_with_conv'] 225 | num_timesteps = 1000 226 | 227 | if config['type'] == 'bayesian': 228 | self.logvar = 
nn.Parameter(torch.zeros(num_timesteps)) 229 | 230 | self.ch = ch 231 | self.temb_ch = self.ch*4 232 | self.num_resolutions = len(ch_mult) 233 | self.num_res_blocks = num_res_blocks 234 | self.resolution = resolution 235 | self.in_channels = in_channels 236 | 237 | # timestep embedding 238 | self.temb = nn.Module() 239 | self.temb.dense = nn.ModuleList([ 240 | torch.nn.Linear(self.ch, 241 | self.temb_ch), 242 | torch.nn.Linear(self.temb_ch, 243 | self.temb_ch), 244 | ]) 245 | 246 | # downsampling 247 | self.conv_in = torch.nn.Conv2d(in_channels, 248 | self.ch, 249 | kernel_size=3, 250 | stride=1, 251 | padding=1) 252 | 253 | curr_res = resolution 254 | in_ch_mult = (1,)+ch_mult 255 | self.down = nn.ModuleList() 256 | block_in = None 257 | for i_level in range(self.num_resolutions): 258 | block = nn.ModuleList() 259 | attn = nn.ModuleList() 260 | block_in = ch*in_ch_mult[i_level] 261 | block_out = ch*ch_mult[i_level] 262 | for i_block in range(self.num_res_blocks): 263 | block.append(ResnetBlock(in_channels=block_in, 264 | out_channels=block_out, 265 | temb_channels=self.temb_ch, 266 | dropout=dropout)) 267 | block_in = block_out 268 | if curr_res in attn_resolutions: 269 | attn.append(AttnBlock(block_in)) 270 | down = nn.Module() 271 | down.block = block 272 | down.attn = attn 273 | if i_level != self.num_resolutions-1: 274 | down.downsample = Downsample(block_in, resamp_with_conv) 275 | curr_res = curr_res // 2 276 | self.down.append(down) 277 | 278 | # middle 279 | self.mid = nn.Module() 280 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 281 | out_channels=block_in, 282 | temb_channels=self.temb_ch, 283 | dropout=dropout) 284 | self.mid.attn_1 = AttnBlock(block_in) 285 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 286 | out_channels=block_in, 287 | temb_channels=self.temb_ch, 288 | dropout=dropout) 289 | 290 | # upsampling 291 | self.up = nn.ModuleList() 292 | for i_level in reversed(range(self.num_resolutions)): 293 | block = nn.ModuleList() 294 | attn = nn.ModuleList() 295 | block_out = ch*ch_mult[i_level] 296 | skip_in = ch*ch_mult[i_level] 297 | for i_block in range(self.num_res_blocks+1): 298 | if i_block == self.num_res_blocks: 299 | skip_in = ch*in_ch_mult[i_level] 300 | block.append(ResnetBlock(in_channels=block_in+skip_in, 301 | out_channels=block_out, 302 | temb_channels=self.temb_ch, 303 | dropout=dropout)) 304 | block_in = block_out 305 | if curr_res in attn_resolutions: 306 | attn.append(AttnBlock(block_in)) 307 | up = nn.Module() 308 | up.block = block 309 | up.attn = attn 310 | if i_level != 0: 311 | up.upsample = Upsample(block_in, resamp_with_conv) 312 | curr_res = curr_res * 2 313 | self.up.insert(0, up) # prepend to get consistent order 314 | 315 | # end 316 | self.norm_out = Normalize(block_in) 317 | self.conv_out = torch.nn.Conv2d(block_in, 318 | out_ch, 319 | kernel_size=3, 320 | stride=1, 321 | padding=1) 322 | 323 | def forward(self, x, t): 324 | assert x.shape[2] == x.shape[3] == self.resolution 325 | 326 | # timestep embedding 327 | temb = get_timestep_embedding(t, self.ch) 328 | temb = self.temb.dense[0](temb) 329 | temb = nonlinearity(temb) 330 | temb = self.temb.dense[1](temb) 331 | 332 | # downsampling 333 | hs = [self.conv_in(x)] 334 | for i_level in range(self.num_resolutions): 335 | for i_block in range(self.num_res_blocks): 336 | h = self.down[i_level].block[i_block](hs[-1], temb) 337 | if len(self.down[i_level].attn) > 0: 338 | h = self.down[i_level].attn[i_block](h) 339 | hs.append(h) 340 | if i_level != self.num_resolutions-1: 341 | 
hs.append(self.down[i_level].downsample(hs[-1])) 342 | 343 | # middle 344 | h = hs[-1] 345 | h = self.mid.block_1(h, temb) 346 | h = self.mid.attn_1(h) 347 | h = self.mid.block_2(h, temb) 348 | 349 | # upsampling 350 | for i_level in reversed(range(self.num_resolutions)): 351 | for i_block in range(self.num_res_blocks+1): 352 | h = self.up[i_level].block[i_block]( 353 | torch.cat([h, hs.pop()], dim=1), temb) 354 | if len(self.up[i_level].attn) > 0: 355 | h = self.up[i_level].attn[i_block](h) 356 | if i_level != 0: 357 | h = self.up[i_level].upsample(h) 358 | 359 | # end 360 | h = self.norm_out(h) 361 | h = nonlinearity(h) 362 | h = self.conv_out(h) 363 | return h 364 | -------------------------------------------------------------------------------- /model/ema.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2020 Jiaming Song 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import torch.nn as nn 24 | 25 | 26 | class EMAHelper(object): 27 | def __init__(self, mu=0.999): 28 | self.mu = mu 29 | self.shadow = {} 30 | 31 | def register(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | module = module.module 34 | for name, param in module.named_parameters(): 35 | if param.requires_grad: 36 | self.shadow[name] = param.data.clone() 37 | 38 | def update(self, module): 39 | if isinstance(module, nn.DataParallel): 40 | module = module.module 41 | for name, param in module.named_parameters(): 42 | if param.requires_grad: 43 | self.shadow[name].data = ( 44 | 1. 
- self.mu) * param.data + self.mu * self.shadow[name].data 45 | 46 | def ema(self, module): 47 | if isinstance(module, nn.DataParallel): 48 | module = module.module 49 | for name, param in module.named_parameters(): 50 | if param.requires_grad: 51 | param.data.copy_(self.shadow[name].data) 52 | 53 | def ema_copy(self, module): 54 | if isinstance(module, nn.DataParallel): 55 | inner_module = module.module 56 | module_copy = type(inner_module)( 57 | inner_module.config).to(inner_module.config.device) 58 | module_copy.load_state_dict(inner_module.state_dict()) 59 | module_copy = nn.DataParallel(module_copy) 60 | else: 61 | module_copy = type(module)(module.config).to(module.config.device) 62 | module_copy.load_state_dict(module.state_dict()) 63 | # module_copy = copy.deepcopy(module) 64 | self.ema(module_copy) 65 | return module_copy 66 | 67 | def state_dict(self): 68 | return self.shadow 69 | 70 | def load_state_dict(self, state_dict): 71 | self.shadow = state_dict 72 | -------------------------------------------------------------------------------- /model/iDDPM/nn.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 OpenAI 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | """ 24 | Various utilities for neural networks. 25 | """ 26 | 27 | import math 28 | 29 | import torch as th 30 | import torch.nn as nn 31 | 32 | 33 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 34 | class SiLU(nn.Module): 35 | def forward(self, x): 36 | return x * th.sigmoid(x) 37 | 38 | 39 | class GroupNorm32(nn.GroupNorm): 40 | def forward(self, x): 41 | return super().forward(x.float()).type(x.dtype) 42 | 43 | 44 | def conv_nd(dims, *args, **kwargs): 45 | """ 46 | Create a 1D, 2D, or 3D convolution module. 47 | """ 48 | if dims == 1: 49 | return nn.Conv1d(*args, **kwargs) 50 | elif dims == 2: 51 | return nn.Conv2d(*args, **kwargs) 52 | elif dims == 3: 53 | return nn.Conv3d(*args, **kwargs) 54 | raise ValueError(f"unsupported dimensions: {dims}") 55 | 56 | 57 | def linear(*args, **kwargs): 58 | """ 59 | Create a linear module. 60 | """ 61 | return nn.Linear(*args, **kwargs) 62 | 63 | 64 | def avg_pool_nd(dims, *args, **kwargs): 65 | """ 66 | Create a 1D, 2D, or 3D average pooling module. 
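    For example (illustrative, mirroring the dispatch below),
    `avg_pool_nd(2, kernel_size=2, stride=2)` simply builds
    `nn.AvgPool2d(kernel_size=2, stride=2)`.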
67 | """ 68 | if dims == 1: 69 | return nn.AvgPool1d(*args, **kwargs) 70 | elif dims == 2: 71 | return nn.AvgPool2d(*args, **kwargs) 72 | elif dims == 3: 73 | return nn.AvgPool3d(*args, **kwargs) 74 | raise ValueError(f"unsupported dimensions: {dims}") 75 | 76 | 77 | def update_ema(target_params, source_params, rate=0.99): 78 | """ 79 | Update target parameters to be closer to those of source parameters using 80 | an exponential moving average. 81 | 82 | :param target_params: the target parameter sequence. 83 | :param source_params: the source parameter sequence. 84 | :param rate: the EMA rate (closer to 1 means slower). 85 | """ 86 | for targ, src in zip(target_params, source_params): 87 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 88 | 89 | 90 | def zero_module(module): 91 | """ 92 | Zero out the parameters of a module and return it. 93 | """ 94 | for p in module.parameters(): 95 | p.detach().zero_() 96 | return module 97 | 98 | 99 | def scale_module(module, scale): 100 | """ 101 | Scale the parameters of a module and return it. 102 | """ 103 | for p in module.parameters(): 104 | p.detach().mul_(scale) 105 | return module 106 | 107 | 108 | def mean_flat(tensor): 109 | """ 110 | Take the mean over all non-batch dimensions. 111 | """ 112 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 113 | 114 | 115 | def normalization(channels): 116 | """ 117 | Make a standard normalization layer. 118 | 119 | :param channels: number of input channels. 120 | :return: an nn.Module for normalization. 121 | """ 122 | return GroupNorm32(32, channels) 123 | 124 | 125 | def timestep_embedding(timesteps, dim, max_period=10000): 126 | """ 127 | Create sinusoidal timestep embeddings. 128 | 129 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 130 | These may be fractional. 131 | :param dim: the dimension of the output. 132 | :param max_period: controls the minimum frequency of the embeddings. 133 | :return: an [N x dim] Tensor of positional embeddings. 134 | """ 135 | half = dim // 2 136 | freqs = th.exp( 137 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 138 | ).to(device=timesteps.device) 139 | args = timesteps[:, None].float() * freqs[None] 140 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 141 | if dim % 2: 142 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 143 | return embedding 144 | 145 | 146 | def checkpoint(func, inputs, params, flag): 147 | """ 148 | Evaluate a function without caching intermediate activations, allowing for 149 | reduced memory at the expense of extra compute in the backward pass. 150 | 151 | :param func: the function to evaluate. 152 | :param inputs: the argument sequence to pass to `func`. 153 | :param params: a sequence of parameters `func` depends on but does not 154 | explicitly take as arguments. 155 | :param flag: if False, disable gradient checkpointing. 
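    For example, `ResBlock.forward` in model/iDDPM/unet.py wraps its inner
    computation as

        return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint)

    so the activations of `_forward` are recomputed during the backward pass
    instead of being cached.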
156 | """ 157 | if flag: 158 | args = tuple(inputs) + tuple(params) 159 | return CheckpointFunction.apply(func, len(inputs), *args) 160 | else: 161 | return func(*inputs) 162 | 163 | 164 | class CheckpointFunction(th.autograd.Function): 165 | @staticmethod 166 | def forward(ctx, run_function, length, *args): 167 | ctx.run_function = run_function 168 | ctx.input_tensors = list(args[:length]) 169 | ctx.input_params = list(args[length:]) 170 | with th.no_grad(): 171 | output_tensors = ctx.run_function(*ctx.input_tensors) 172 | return output_tensors 173 | 174 | @staticmethod 175 | def backward(ctx, *output_grads): 176 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 177 | with th.enable_grad(): 178 | # Fixes a bug where the first op in run_function modifies the 179 | # Tensor storage in place, which is not allowed for detach()'d 180 | # Tensors. 181 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 182 | output_tensors = ctx.run_function(*shallow_copies) 183 | input_grads = th.autograd.grad( 184 | output_tensors, 185 | ctx.input_tensors + ctx.input_params, 186 | output_grads, 187 | allow_unused=True, 188 | ) 189 | del ctx.input_tensors 190 | del ctx.input_params 191 | del output_tensors 192 | return (None, None) + input_grads 193 | -------------------------------------------------------------------------------- /model/iDDPM/unet.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2021 OpenAI 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | # SOFTWARE. 22 | 23 | import sys 24 | from abc import abstractmethod 25 | 26 | import math 27 | 28 | import numpy as np 29 | import torch as th 30 | import torch.nn as nn 31 | import torch.nn.functional as F 32 | 33 | from .nn import ( 34 | SiLU, 35 | conv_nd, 36 | linear, 37 | avg_pool_nd, 38 | zero_module, 39 | normalization, 40 | timestep_embedding, 41 | checkpoint, 42 | ) 43 | 44 | 45 | def convert_module_to_f16(l): 46 | """ 47 | Convert primitive modules to float16. 48 | """ 49 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 50 | l.weight.data = l.weight.data.half() 51 | l.bias.data = l.bias.data.half() 52 | 53 | 54 | def convert_module_to_f32(l): 55 | """ 56 | Convert primitive modules to float32, undoing convert_module_to_f16(). 
57 | """ 58 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 59 | l.weight.data = l.weight.data.float() 60 | l.bias.data = l.bias.data.float() 61 | 62 | 63 | class TimestepBlock(nn.Module): 64 | """ 65 | Any module where forward() takes timestep embeddings as a second argument. 66 | """ 67 | 68 | @abstractmethod 69 | def forward(self, x, emb): 70 | """ 71 | Apply the module to `x` given `emb` timestep embeddings. 72 | """ 73 | 74 | 75 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 76 | """ 77 | A sequential module that passes timestep embeddings to the children that 78 | support it as an extra input. 79 | """ 80 | 81 | def forward(self, x, emb): 82 | for layer in self: 83 | if isinstance(layer, TimestepBlock): 84 | x = layer(x, emb) 85 | else: 86 | x = layer(x) 87 | return x 88 | 89 | 90 | class Upsample(nn.Module): 91 | """ 92 | An upsampling layer with an optional convolution. 93 | 94 | :param channels: channels in the inputs and outputs. 95 | :param use_conv: a bool determining if a convolution is applied. 96 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 97 | upsampling occurs in the inner-two dimensions. 98 | """ 99 | 100 | def __init__(self, channels, use_conv, dims=2): 101 | super().__init__() 102 | self.channels = channels 103 | self.use_conv = use_conv 104 | self.dims = dims 105 | if use_conv: 106 | self.conv = conv_nd(dims, channels, channels, 3, padding=1) 107 | 108 | def forward(self, x): 109 | assert x.shape[1] == self.channels 110 | if self.dims == 3: 111 | x = F.interpolate( 112 | x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" 113 | ) 114 | else: 115 | x = F.interpolate(x, scale_factor=2, mode="nearest") 116 | if self.use_conv: 117 | x = self.conv(x) 118 | return x 119 | 120 | 121 | class Downsample(nn.Module): 122 | """ 123 | A downsampling layer with an optional convolution. 124 | 125 | :param channels: channels in the inputs and outputs. 126 | :param use_conv: a bool determining if a convolution is applied. 127 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 128 | downsampling occurs in the inner-two dimensions. 129 | """ 130 | 131 | def __init__(self, channels, use_conv, dims=2): 132 | super().__init__() 133 | self.channels = channels 134 | self.use_conv = use_conv 135 | self.dims = dims 136 | stride = 2 if dims != 3 else (1, 2, 2) 137 | if use_conv: 138 | self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) 139 | else: 140 | self.op = avg_pool_nd(stride) 141 | 142 | def forward(self, x): 143 | assert x.shape[1] == self.channels 144 | return self.op(x) 145 | 146 | 147 | class ResBlock(TimestepBlock): 148 | """ 149 | A residual block that can optionally change the number of channels. 150 | 151 | :param channels: the number of input channels. 152 | :param emb_channels: the number of timestep embedding channels. 153 | :param dropout: the rate of dropout. 154 | :param out_channels: if specified, the number of out channels. 155 | :param use_conv: if True and out_channels is specified, use a spatial 156 | convolution instead of a smaller 1x1 convolution to change the 157 | channels in the skip connection. 158 | :param dims: determines if the signal is 1D, 2D, or 3D. 159 | :param use_checkpoint: if True, use gradient checkpointing on this module. 
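    :param use_scale_shift_norm: if True, inject the timestep embedding as a
        per-channel scale and shift around the second normalization
        (h = norm(h) * (1 + scale) + shift) instead of adding it to the
        hidden state.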
160 | """ 161 | 162 | def __init__( 163 | self, 164 | channels, 165 | emb_channels, 166 | dropout, 167 | out_channels=None, 168 | use_conv=False, 169 | use_scale_shift_norm=False, 170 | dims=2, 171 | use_checkpoint=False, 172 | ): 173 | super().__init__() 174 | self.channels = channels 175 | self.emb_channels = emb_channels 176 | self.dropout = dropout 177 | self.out_channels = out_channels or channels 178 | self.use_conv = use_conv 179 | self.use_checkpoint = use_checkpoint 180 | self.use_scale_shift_norm = use_scale_shift_norm 181 | 182 | self.in_layers = nn.Sequential( 183 | normalization(channels), 184 | SiLU(), 185 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 186 | ) 187 | self.emb_layers = nn.Sequential( 188 | SiLU(), 189 | linear( 190 | emb_channels, 191 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 192 | ), 193 | ) 194 | self.out_layers = nn.Sequential( 195 | normalization(self.out_channels), 196 | SiLU(), 197 | nn.Dropout(p=dropout), 198 | zero_module( 199 | conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) 200 | ), 201 | ) 202 | 203 | if self.out_channels == channels: 204 | self.skip_connection = nn.Identity() 205 | elif use_conv: 206 | self.skip_connection = conv_nd( 207 | dims, channels, self.out_channels, 3, padding=1 208 | ) 209 | else: 210 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 211 | 212 | def forward(self, x, emb): 213 | """ 214 | Apply the block to a Tensor, conditioned on a timestep embedding. 215 | 216 | :param x: an [N x C x ...] Tensor of features. 217 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 218 | :return: an [N x C x ...] Tensor of outputs. 219 | """ 220 | return checkpoint( 221 | self._forward, (x, emb), self.parameters(), self.use_checkpoint 222 | ) 223 | 224 | def _forward(self, x, emb): 225 | h = self.in_layers(x) 226 | emb_out = self.emb_layers(emb).type(h.dtype) 227 | while len(emb_out.shape) < len(h.shape): 228 | emb_out = emb_out[..., None] 229 | if self.use_scale_shift_norm: 230 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 231 | scale, shift = th.chunk(emb_out, 2, dim=1) 232 | h = out_norm(h) * (1 + scale) + shift 233 | h = out_rest(h) 234 | else: 235 | h = h + emb_out 236 | h = self.out_layers(h) 237 | return self.skip_connection(x) + h 238 | 239 | 240 | class AttentionBlock(nn.Module): 241 | """ 242 | An attention block that allows spatial positions to attend to each other. 243 | 244 | Originally ported from here, but adapted to the N-d case. 245 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
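    The input is normalized, projected to stacked Q/K/V with a 1x1
    convolution, split across heads, attended over the flattened spatial
    positions by QKVAttention, and the (zero-initialized) output projection
    is added back to the input as a residual.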
246 | """ 247 | 248 | def __init__(self, channels, num_heads=1, use_checkpoint=False): 249 | super().__init__() 250 | self.channels = channels 251 | self.num_heads = num_heads 252 | self.use_checkpoint = use_checkpoint 253 | 254 | self.norm = normalization(channels) 255 | self.qkv = conv_nd(1, channels, channels * 3, 1) 256 | self.attention = QKVAttention() 257 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 258 | 259 | def forward(self, x): 260 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 261 | 262 | def _forward(self, x): 263 | b, c, *spatial = x.shape 264 | x = x.reshape(b, c, -1) 265 | qkv = self.qkv(self.norm(x)) 266 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 267 | h = self.attention(qkv) 268 | h = h.reshape(b, -1, h.shape[-1]) 269 | h = self.proj_out(h) 270 | return (x + h).reshape(b, c, *spatial) 271 | 272 | 273 | class QKVAttention(nn.Module): 274 | """ 275 | A module which performs QKV attention. 276 | """ 277 | 278 | def forward(self, qkv): 279 | """ 280 | Apply QKV attention. 281 | 282 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 283 | :return: an [N x C x T] tensor after attention. 284 | """ 285 | ch = qkv.shape[1] // 3 286 | q, k, v = th.split(qkv, ch, dim=1) 287 | scale = 1 / math.sqrt(math.sqrt(ch)) 288 | weight = th.einsum( 289 | "bct,bcs->bts", q * scale, k * scale 290 | ) # More stable with f16 than dividing afterwards 291 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 292 | return th.einsum("bts,bcs->bct", weight, v) 293 | 294 | @staticmethod 295 | def count_flops(model, _x, y): 296 | """ 297 | A counter for the `thop` package to count the operations in an 298 | attention operation. 299 | 300 | Meant to be used like: 301 | 302 | macs, params = thop.profile( 303 | model, 304 | inputs=(inputs, timestamps), 305 | custom_ops={QKVAttention: QKVAttention.count_flops}, 306 | ) 307 | 308 | """ 309 | b, c, *spatial = y[0].shape 310 | num_spatial = int(np.prod(spatial)) 311 | # We perform two matmuls with the same number of ops. 312 | # The first computes the weight matrix, the second computes 313 | # the combination of the value vectors. 314 | matmul_ops = 2 * b * (num_spatial ** 2) * c 315 | model.total_ops += th.DoubleTensor([matmul_ops]) 316 | 317 | 318 | class UNetModel(nn.Module): 319 | """ 320 | The full UNet model with attention and timestep embedding. 321 | 322 | :param in_channels: channels in the input Tensor. 323 | :param model_channels: base channel count for the model. 324 | :param out_channels: channels in the output Tensor. 325 | :param num_res_blocks: number of residual blocks per downsample. 326 | :param attention_resolutions: a collection of downsample rates at which 327 | attention will take place. May be a set, list, or tuple. 328 | For example, if this contains 4, then at 4x downsampling, attention 329 | will be used. 330 | :param dropout: the dropout probability. 331 | :param channel_mult: channel multiplier for each level of the UNet. 332 | :param conv_resample: if True, use learned convolutions for upsampling and 333 | downsampling. 334 | :param dims: determines if the signal is 1D, 2D, or 3D. 335 | :param num_classes: if specified (as an int), then this model will be 336 | class-conditional with `num_classes` classes. 337 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 338 | :param num_heads: the number of attention heads in each attention layer. 
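    :param num_heads_upsample: number of attention heads to use in the
        upsampling (output) blocks; -1 means reuse num_heads.
    :param use_scale_shift_norm: use a FiLM-like scale/shift conditioning on
        the timestep embedding inside each ResBlock.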
339 | """ 340 | 341 | def __init__(self, args, config): 342 | super().__init__() 343 | 344 | if config['num_heads_upsample'] == -1: 345 | num_heads_upsample = config['num_heads'] 346 | 347 | self.in_channels = in_channels = config['in_channels'] 348 | self.model_channels = model_channels = config['model_channels'] 349 | self.out_channels = out_channels = config['out_channels'] 350 | self.num_res_blocks = num_res_blocks = config['num_res_blocks'] 351 | self.attention_resolutions = attention_resolutions = config['attention_resolutions'] 352 | self.dropout = dropout = config['dropout'] 353 | self.channel_mult = channel_mult = config['channel_mult'] 354 | self.conv_resample = conv_resample = config['conv_resample'] 355 | self.num_classes = num_classes = None 356 | self.use_checkpoint = use_checkpoint = config['use_checkpoint'] 357 | self.num_heads = num_heads = config['num_heads'] 358 | self.num_heads_upsample = num_heads_upsample = config['num_heads_upsample'] 359 | 360 | dims = config['dims'] 361 | use_scale_shift_norm = config['use_scale_shift_norm'] 362 | 363 | time_embed_dim = model_channels * 4 364 | self.time_embed = nn.Sequential( 365 | linear(model_channels, time_embed_dim), 366 | SiLU(), 367 | linear(time_embed_dim, time_embed_dim), 368 | ) 369 | 370 | if self.num_classes is not None: 371 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 372 | 373 | self.input_blocks = nn.ModuleList( 374 | [ 375 | TimestepEmbedSequential( 376 | conv_nd(dims, in_channels, model_channels, 3, padding=1) 377 | ) 378 | ] 379 | ) 380 | input_block_chans = [model_channels] 381 | ch = model_channels 382 | ds = 1 383 | for level, mult in enumerate(channel_mult): 384 | for _ in range(num_res_blocks): 385 | layers = [ 386 | ResBlock( 387 | ch, 388 | time_embed_dim, 389 | dropout, 390 | out_channels=mult * model_channels, 391 | dims=dims, 392 | use_checkpoint=use_checkpoint, 393 | use_scale_shift_norm=use_scale_shift_norm, 394 | ) 395 | ] 396 | ch = mult * model_channels 397 | if ds in attention_resolutions: 398 | layers.append( 399 | AttentionBlock( 400 | ch, use_checkpoint=use_checkpoint, num_heads=num_heads 401 | ) 402 | ) 403 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 404 | input_block_chans.append(ch) 405 | if level != len(channel_mult) - 1: 406 | self.input_blocks.append( 407 | TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims)) 408 | ) 409 | input_block_chans.append(ch) 410 | ds *= 2 411 | 412 | self.middle_block = TimestepEmbedSequential( 413 | ResBlock( 414 | ch, 415 | time_embed_dim, 416 | dropout, 417 | dims=dims, 418 | use_checkpoint=use_checkpoint, 419 | use_scale_shift_norm=use_scale_shift_norm, 420 | ), 421 | AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), 422 | ResBlock( 423 | ch, 424 | time_embed_dim, 425 | dropout, 426 | dims=dims, 427 | use_checkpoint=use_checkpoint, 428 | use_scale_shift_norm=use_scale_shift_norm, 429 | ), 430 | ) 431 | 432 | self.output_blocks = nn.ModuleList([]) 433 | for level, mult in list(enumerate(channel_mult))[::-1]: 434 | for i in range(num_res_blocks + 1): 435 | layers = [ 436 | ResBlock( 437 | ch + input_block_chans.pop(), 438 | time_embed_dim, 439 | dropout, 440 | out_channels=model_channels * mult, 441 | dims=dims, 442 | use_checkpoint=use_checkpoint, 443 | use_scale_shift_norm=use_scale_shift_norm, 444 | ) 445 | ] 446 | ch = model_channels * mult 447 | if ds in attention_resolutions: 448 | layers.append( 449 | AttentionBlock( 450 | ch, 451 | use_checkpoint=use_checkpoint, 452 | 
num_heads=num_heads_upsample, 453 | ) 454 | ) 455 | if level and i == num_res_blocks: 456 | layers.append(Upsample(ch, conv_resample, dims=dims)) 457 | ds //= 2 458 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 459 | 460 | self.out = nn.Sequential( 461 | normalization(ch), 462 | SiLU(), 463 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 464 | ) 465 | 466 | def convert_to_fp16(self): 467 | """ 468 | Convert the torso of the model to float16. 469 | """ 470 | self.input_blocks.apply(convert_module_to_f16) 471 | self.middle_block.apply(convert_module_to_f16) 472 | self.output_blocks.apply(convert_module_to_f16) 473 | 474 | def convert_to_fp32(self): 475 | """ 476 | Convert the torso of the model to float32. 477 | """ 478 | self.input_blocks.apply(convert_module_to_f32) 479 | self.middle_block.apply(convert_module_to_f32) 480 | self.output_blocks.apply(convert_module_to_f32) 481 | 482 | @property 483 | def inner_dtype(self): 484 | """ 485 | Get the dtype used by the torso of the model. 486 | """ 487 | # print(next(self.input_blocks.parameters()).dtype) 488 | # return next(self.input_blocks.parameters()).dtype 489 | return th.float32 490 | 491 | def forward(self, x, timesteps, y=None): 492 | """ 493 | Apply the model to an input batch. 494 | 495 | :param x: an [N x C x ...] Tensor of inputs. 496 | :param timesteps: a 1-D batch of timesteps. 497 | :param y: an [N] Tensor of labels, if class-conditional. 498 | :return: an [N x C x ...] Tensor of outputs. 499 | """ 500 | assert (y is not None) == ( 501 | self.num_classes is not None 502 | ), "must specify y if and only if the model is class-conditional" 503 | 504 | hs = [] 505 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 506 | B, C = x.shape[:2] 507 | 508 | if self.num_classes is not None: 509 | assert y.shape == (x.shape[0],) 510 | emb = emb + self.label_emb(y) 511 | 512 | h = x.type(self.inner_dtype) 513 | for module in self.input_blocks: 514 | h = module(h, emb) 515 | hs.append(h) 516 | h = self.middle_block(h, emb) 517 | for module in self.output_blocks: 518 | cat_in = th.cat([h, hs.pop()], dim=1) 519 | h = module(cat_in, emb) 520 | h = h.type(x.dtype) 521 | output = self.out(h) 522 | model_output, model_var_values = th.split(output, C, dim=1) 523 | return model_output 524 | 525 | def get_feature_vectors(self, x, timesteps, y=None): 526 | """ 527 | Apply the model and return all of the intermediate tensors. 528 | 529 | :param x: an [N x C x ...] Tensor of inputs. 530 | :param timesteps: a 1-D batch of timesteps. 531 | :param y: an [N] Tensor of labels, if class-conditional. 532 | :return: a dict with the following keys: 533 | - 'down': a list of hidden state tensors from downsampling. 534 | - 'middle': the tensor of the output of the lowest-resolution 535 | block in the model. 536 | - 'up': a list of hidden state tensors from upsampling. 
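        A purely illustrative call: `feats = model.get_feature_vectors(x, t)`,
        after which `feats["middle"]` holds the lowest-resolution feature map
        and `feats["down"][-1]` the last downsampling activation.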
537 | """ 538 | hs = [] 539 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 540 | if self.num_classes is not None: 541 | assert y.shape == (x.shape[0],) 542 | emb = emb + self.label_emb(y) 543 | result = dict(down=[], up=[]) 544 | h = x.type(self.inner_dtype) 545 | for module in self.input_blocks: 546 | h = module(h, emb) 547 | hs.append(h) 548 | result["down"].append(h.type(x.dtype)) 549 | h = self.middle_block(h, emb) 550 | result["middle"] = h.type(x.dtype) 551 | for module in self.output_blocks: 552 | cat_in = th.cat([h, hs.pop()], dim=1) 553 | h = module(cat_in, emb) 554 | result["up"].append(h.type(x.dtype)) 555 | return result 556 | 557 | 558 | class SuperResModel(UNetModel): 559 | """ 560 | A UNetModel that performs super-resolution. 561 | 562 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 563 | """ 564 | 565 | def __init__(self, in_channels, *args, **kwargs): 566 | super().__init__(in_channels * 2, *args, **kwargs) 567 | 568 | def forward(self, x, timesteps, low_res=None, **kwargs): 569 | _, _, new_height, new_width = x.shape 570 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 571 | x = th.cat([x, upsampled], dim=1) 572 | return super().forward(x, timesteps, **kwargs) 573 | 574 | def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs): 575 | _, new_height, new_width, _ = x.shape 576 | upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear") 577 | x = th.cat([x, upsampled], dim=1) 578 | return super().get_feature_vectors(x, timesteps, **kwargs) 579 | 580 | -------------------------------------------------------------------------------- /model/scoresde/ddpm.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """DDPM model. 18 | 19 | This code is the pytorch equivalent of: 20 | https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py 21 | """ 22 | import torch 23 | import torch.nn as nn 24 | import numpy as np 25 | import functools 26 | 27 | from . 
import layers, normalization 28 | 29 | RefineBlock = layers.RefineBlock 30 | ResidualBlock = layers.ResidualBlock 31 | ResnetBlockDDPM = layers.ResnetBlockDDPM 32 | Upsample = layers.Upsample 33 | Downsample = layers.Downsample 34 | conv3x3 = layers.ddpm_conv3x3 35 | get_act = layers.get_act 36 | get_normalization = normalization.get_normalization 37 | default_initializer = layers.default_init 38 | 39 | 40 | class DDPM(nn.Module): 41 | def __init__(self, args, config): 42 | super().__init__() 43 | self.act = act = get_act(config) 44 | 45 | self.nf = nf = config['nf'] 46 | ch_mult = config['ch_mult'] 47 | self.num_res_blocks = num_res_blocks = config['num_res_blocks'] 48 | self.attn_resolutions = attn_resolutions = config['attn_resolutions'] 49 | dropout = config['dropout'] 50 | resamp_with_conv = config['resamp_with_conv'] 51 | self.num_resolutions = num_resolutions = len(ch_mult) 52 | self.all_resolutions = all_resolutions = [config['image_size'] // (2 ** i) for i in range(num_resolutions)] 53 | 54 | AttnBlock = functools.partial(layers.AttnBlock) 55 | self.conditional = conditional = config['conditional'] 56 | ResnetBlock = functools.partial(ResnetBlockDDPM, act=act, temb_dim=4 * nf, dropout=dropout) 57 | if conditional: 58 | # Condition on noise levels. 59 | modules = [nn.Linear(nf, nf * 4)] 60 | modules[0].weight.data = default_initializer()(modules[0].weight.data.shape) 61 | nn.init.zeros_(modules[0].bias) 62 | modules.append(nn.Linear(nf * 4, nf * 4)) 63 | modules[1].weight.data = default_initializer()(modules[1].weight.data.shape) 64 | nn.init.zeros_(modules[1].bias) 65 | 66 | self.centered = config['centered'] 67 | channels = config['num_channels'] 68 | 69 | # Downsampling block 70 | modules.append(conv3x3(channels, nf)) 71 | hs_c = [nf] 72 | in_ch = nf 73 | for i_level in range(num_resolutions): 74 | # Residual blocks for this resolution 75 | for i_block in range(num_res_blocks): 76 | out_ch = nf * ch_mult[i_level] 77 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 78 | in_ch = out_ch 79 | if all_resolutions[i_level] in attn_resolutions: 80 | modules.append(AttnBlock(channels=in_ch)) 81 | hs_c.append(in_ch) 82 | if i_level != num_resolutions - 1: 83 | modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv)) 84 | hs_c.append(in_ch) 85 | 86 | in_ch = hs_c[-1] 87 | modules.append(ResnetBlock(in_ch=in_ch)) 88 | modules.append(AttnBlock(channels=in_ch)) 89 | modules.append(ResnetBlock(in_ch=in_ch)) 90 | 91 | # Upsampling block 92 | for i_level in reversed(range(num_resolutions)): 93 | for i_block in range(num_res_blocks + 1): 94 | out_ch = nf * ch_mult[i_level] 95 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), out_ch=out_ch)) 96 | in_ch = out_ch 97 | if all_resolutions[i_level] in attn_resolutions: 98 | modules.append(AttnBlock(channels=in_ch)) 99 | if i_level != 0: 100 | modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv)) 101 | 102 | assert not hs_c 103 | modules.append(nn.GroupNorm(num_channels=in_ch, num_groups=32, eps=1e-6)) 104 | modules.append(conv3x3(in_ch, channels, init_scale=0.)) 105 | self.all_modules = nn.ModuleList(modules) 106 | 107 | self.scale_by_sigma = False 108 | 109 | def forward(self, x, labels): 110 | modules = self.all_modules 111 | m_idx = 0 112 | if self.conditional: 113 | # timestep/scale embedding 114 | timesteps = labels 115 | temb = layers.get_timestep_embedding(timesteps, self.nf) 116 | temb = modules[m_idx](temb) 117 | m_idx += 1 118 | temb = modules[m_idx](self.act(temb)) 119 | m_idx += 1 120 | else: 121 | 
temb = None 122 | 123 | if self.centered: 124 | # Input is in [-1, 1] 125 | h = x 126 | else: 127 | # Input is in [0, 1] 128 | h = 2 * x - 1. 129 | 130 | # Downsampling block 131 | hs = [modules[m_idx](h)] 132 | m_idx += 1 133 | for i_level in range(self.num_resolutions): 134 | # Residual blocks for this resolution 135 | for i_block in range(self.num_res_blocks): 136 | h = modules[m_idx](hs[-1], temb) 137 | m_idx += 1 138 | if h.shape[-1] in self.attn_resolutions: 139 | h = modules[m_idx](h) 140 | m_idx += 1 141 | hs.append(h) 142 | if i_level != self.num_resolutions - 1: 143 | hs.append(modules[m_idx](hs[-1])) 144 | m_idx += 1 145 | 146 | h = hs[-1] 147 | h = modules[m_idx](h, temb) 148 | m_idx += 1 149 | h = modules[m_idx](h) 150 | m_idx += 1 151 | h = modules[m_idx](h, temb) 152 | m_idx += 1 153 | 154 | # Upsampling block 155 | for i_level in reversed(range(self.num_resolutions)): 156 | for i_block in range(self.num_res_blocks + 1): 157 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 158 | m_idx += 1 159 | if h.shape[-1] in self.attn_resolutions: 160 | h = modules[m_idx](h) 161 | m_idx += 1 162 | if i_level != 0: 163 | h = modules[m_idx](h) 164 | m_idx += 1 165 | 166 | assert not hs 167 | h = self.act(modules[m_idx](h)) 168 | m_idx += 1 169 | h = modules[m_idx](h) 170 | m_idx += 1 171 | assert m_idx == len(modules) 172 | 173 | return h 174 | -------------------------------------------------------------------------------- /model/scoresde/layerspp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers for defining NCSN++. 18 | """ 19 | from . import layers 20 | from . 
import up_or_down_sampling 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | 26 | conv1x1 = layers.ddpm_conv1x1 27 | conv3x3 = layers.ddpm_conv3x3 28 | NIN = layers.NIN 29 | default_init = layers.default_init 30 | 31 | 32 | class GaussianFourierProjection(nn.Module): 33 | """Gaussian Fourier embeddings for noise levels.""" 34 | 35 | def __init__(self, embedding_size=256, scale=1.0): 36 | super().__init__() 37 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 38 | 39 | def forward(self, x): 40 | x_proj = x[:, None] * self.W[None, :] * 2 * np.pi 41 | return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 42 | 43 | 44 | class Combine(nn.Module): 45 | """Combine information from skip connections.""" 46 | 47 | def __init__(self, dim1, dim2, method='cat'): 48 | super().__init__() 49 | self.Conv_0 = conv1x1(dim1, dim2) 50 | self.method = method 51 | 52 | def forward(self, x, y): 53 | h = self.Conv_0(x) 54 | if self.method == 'cat': 55 | return torch.cat([h, y], dim=1) 56 | elif self.method == 'sum': 57 | return h + y 58 | else: 59 | raise ValueError(f'Method {self.method} not recognized.') 60 | 61 | 62 | class AttnBlockpp(nn.Module): 63 | """Channel-wise self-attention block. Modified from DDPM.""" 64 | 65 | def __init__(self, channels, skip_rescale=False, init_scale=0.): 66 | super().__init__() 67 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 32), num_channels=channels, 68 | eps=1e-6) 69 | self.NIN_0 = NIN(channels, channels) 70 | self.NIN_1 = NIN(channels, channels) 71 | self.NIN_2 = NIN(channels, channels) 72 | self.NIN_3 = NIN(channels, channels, init_scale=init_scale) 73 | self.skip_rescale = skip_rescale 74 | 75 | def forward(self, x): 76 | B, C, H, W = x.shape 77 | h = self.GroupNorm_0(x) 78 | q = self.NIN_0(h) 79 | k = self.NIN_1(h) 80 | v = self.NIN_2(h) 81 | 82 | w = torch.einsum('bchw,bcij->bhwij', q, k) * (int(C) ** (-0.5)) 83 | w = torch.reshape(w, (B, H, W, H * W)) 84 | w = F.softmax(w, dim=-1) 85 | w = torch.reshape(w, (B, H, W, H, W)) 86 | h = torch.einsum('bhwij,bcij->bchw', w, v) 87 | h = self.NIN_3(h) 88 | if not self.skip_rescale: 89 | return x + h 90 | else: 91 | return (x + h) / np.sqrt(2.) 
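# Illustrative note (not part of the original file): AttnBlockpp above is full
# spatial self-attention with 1x1 (NIN) projections. For x of shape
# [B, C, H, W] it forms attention weights over all H*W positions,
#     w = softmax(q . k / sqrt(C)),   h = w . v,
# and returns x + h, or (x + h) / sqrt(2) when skip_rescale=True, so that the
# residual sum keeps roughly constant variance (the NCSN++ convention).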
92 | 93 | 94 | class Upsample(nn.Module): 95 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 96 | fir_kernel=(1, 3, 3, 1)): 97 | super().__init__() 98 | out_ch = out_ch if out_ch else in_ch 99 | if not fir: 100 | if with_conv: 101 | self.Conv_0 = conv3x3(in_ch, out_ch) 102 | else: 103 | if with_conv: 104 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 105 | kernel=3, up=True, 106 | resample_kernel=fir_kernel, 107 | use_bias=True, 108 | kernel_init=default_init()) 109 | self.fir = fir 110 | self.with_conv = with_conv 111 | self.fir_kernel = fir_kernel 112 | self.out_ch = out_ch 113 | 114 | def forward(self, x): 115 | B, C, H, W = x.shape 116 | if not self.fir: 117 | h = F.interpolate(x, (H * 2, W * 2), 'nearest') 118 | if self.with_conv: 119 | h = self.Conv_0(h) 120 | else: 121 | if not self.with_conv: 122 | h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 123 | else: 124 | h = self.Conv2d_0(x) 125 | 126 | return h 127 | 128 | 129 | class Downsample(nn.Module): 130 | def __init__(self, in_ch=None, out_ch=None, with_conv=False, fir=False, 131 | fir_kernel=(1, 3, 3, 1)): 132 | super().__init__() 133 | out_ch = out_ch if out_ch else in_ch 134 | if not fir: 135 | if with_conv: 136 | self.Conv_0 = conv3x3(in_ch, out_ch, stride=2, padding=0) 137 | else: 138 | if with_conv: 139 | self.Conv2d_0 = up_or_down_sampling.Conv2d(in_ch, out_ch, 140 | kernel=3, down=True, 141 | resample_kernel=fir_kernel, 142 | use_bias=True, 143 | kernel_init=default_init()) 144 | self.fir = fir 145 | self.fir_kernel = fir_kernel 146 | self.with_conv = with_conv 147 | self.out_ch = out_ch 148 | 149 | def forward(self, x): 150 | B, C, H, W = x.shape 151 | if not self.fir: 152 | if self.with_conv: 153 | x = F.pad(x, (0, 1, 0, 1)) 154 | x = self.Conv_0(x) 155 | else: 156 | x = F.avg_pool2d(x, 2, stride=2) 157 | else: 158 | if not self.with_conv: 159 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 160 | else: 161 | x = self.Conv2d_0(x) 162 | 163 | return x 164 | 165 | 166 | class ResnetBlockDDPMpp(nn.Module): 167 | """ResBlock adapted from DDPM.""" 168 | 169 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, 170 | dropout=0.1, skip_rescale=False, init_scale=0.): 171 | super().__init__() 172 | out_ch = out_ch if out_ch else in_ch 173 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 174 | self.Conv_0 = conv3x3(in_ch, out_ch) 175 | if temb_dim is not None: 176 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 177 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape) 178 | nn.init.zeros_(self.Dense_0.bias) 179 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 180 | self.Dropout_0 = nn.Dropout(dropout) 181 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 182 | if in_ch != out_ch: 183 | if conv_shortcut: 184 | self.Conv_2 = conv3x3(in_ch, out_ch) 185 | else: 186 | self.NIN_0 = NIN(in_ch, out_ch) 187 | 188 | self.skip_rescale = skip_rescale 189 | self.act = act 190 | self.out_ch = out_ch 191 | self.conv_shortcut = conv_shortcut 192 | 193 | def forward(self, x, temb=None): 194 | h = self.act(self.GroupNorm_0(x)) 195 | h = self.Conv_0(h) 196 | if temb is not None: 197 | h += self.Dense_0(self.act(temb))[:, :, None, None] 198 | h = self.act(self.GroupNorm_1(h)) 199 | h = self.Dropout_0(h) 200 | h = self.Conv_1(h) 201 | if x.shape[1] != self.out_ch: 202 | if self.conv_shortcut: 203 | x = self.Conv_2(x) 204 
| else: 205 | x = self.NIN_0(x) 206 | if not self.skip_rescale: 207 | return x + h 208 | else: 209 | return (x + h) / np.sqrt(2.) 210 | 211 | 212 | class ResnetBlockBigGANpp(nn.Module): 213 | def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False, 214 | dropout=0.1, fir=False, fir_kernel=(1, 3, 3, 1), 215 | skip_rescale=True, init_scale=0.): 216 | super().__init__() 217 | 218 | out_ch = out_ch if out_ch else in_ch 219 | self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 32), num_channels=in_ch, eps=1e-6) 220 | self.up = up 221 | self.down = down 222 | self.fir = fir 223 | self.fir_kernel = fir_kernel 224 | 225 | self.Conv_0 = conv3x3(in_ch, out_ch) 226 | if temb_dim is not None: 227 | self.Dense_0 = nn.Linear(temb_dim, out_ch) 228 | self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape) 229 | nn.init.zeros_(self.Dense_0.bias) 230 | 231 | self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 32), num_channels=out_ch, eps=1e-6) 232 | self.Dropout_0 = nn.Dropout(dropout) 233 | self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=init_scale) 234 | if in_ch != out_ch or up or down: 235 | self.Conv_2 = conv1x1(in_ch, out_ch) 236 | 237 | self.skip_rescale = skip_rescale 238 | self.act = act 239 | self.in_ch = in_ch 240 | self.out_ch = out_ch 241 | 242 | def forward(self, x, temb=None): 243 | h = self.act(self.GroupNorm_0(x)) 244 | 245 | if self.up: 246 | if self.fir: 247 | h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2) 248 | x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2) 249 | else: 250 | h = up_or_down_sampling.naive_upsample_2d(h, factor=2) 251 | x = up_or_down_sampling.naive_upsample_2d(x, factor=2) 252 | elif self.down: 253 | if self.fir: 254 | h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2) 255 | x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2) 256 | else: 257 | h = up_or_down_sampling.naive_downsample_2d(h, factor=2) 258 | x = up_or_down_sampling.naive_downsample_2d(x, factor=2) 259 | 260 | h = self.Conv_0(h) 261 | # Add bias to each feature map conditioned on the time embedding 262 | if temb is not None: 263 | h += self.Dense_0(self.act(temb))[:, :, None, None] 264 | h = self.act(self.GroupNorm_1(h)) 265 | h = self.Dropout_0(h) 266 | h = self.Conv_1(h) 267 | 268 | if self.in_ch != self.out_ch or self.up or self.down: 269 | x = self.Conv_2(x) 270 | 271 | if not self.skip_rescale: 272 | return x + h 273 | else: 274 | return (x + h) / np.sqrt(2.) 275 | -------------------------------------------------------------------------------- /model/scoresde/ncsnpp.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | 18 | from . 
import layers, layerspp, normalization 19 | import torch.nn as nn 20 | import functools 21 | import torch 22 | import numpy as np 23 | 24 | ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp 25 | ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp 26 | Combine = layerspp.Combine 27 | conv3x3 = layerspp.conv3x3 28 | conv1x1 = layerspp.conv1x1 29 | get_act = layers.get_act 30 | get_normalization = normalization.get_normalization 31 | default_initializer = layers.default_init 32 | 33 | 34 | class NCSNpp(nn.Module): 35 | """NCSN++ model""" 36 | 37 | def __init__(self, args, config): 38 | super().__init__() 39 | self.config = config 40 | self.act = act = get_act(config) 41 | 42 | self.nf = nf = config['nf'] 43 | ch_mult = config['ch_mult'] 44 | self.num_res_blocks = num_res_blocks = config['num_res_blocks'] 45 | self.attn_resolutions = attn_resolutions = config['attn_resolutions'] 46 | dropout = config['dropout'] 47 | resamp_with_conv = config['resamp_with_conv'] 48 | self.num_resolutions = num_resolutions = len(ch_mult) 49 | self.all_resolutions = all_resolutions = [config['image_size'] // (2 ** i) for i in range(num_resolutions)] 50 | 51 | self.conditional = conditional = config['conditional'] # noise-conditional 52 | fir = config['fir'] 53 | fir_kernel = config['fir_kernel'] 54 | self.skip_rescale = skip_rescale = config['skip_rescale'] 55 | self.resblock_type = resblock_type = config['resblock_type'] 56 | self.progressive = progressive = config['progressive'] 57 | self.progressive_input = progressive_input = config['progressive_input'] 58 | self.embedding_type = embedding_type = config['embedding_type'] 59 | init_scale = config['init_scale'] 60 | assert progressive in ['none', 'output_skip', 'residual'] 61 | assert progressive_input in ['none', 'input_skip', 'residual'] 62 | assert embedding_type in ['fourier', 'positional'] 63 | combine_method = config['combine_method'] 64 | combiner = functools.partial(Combine, method=combine_method) 65 | 66 | modules = [] 67 | # timestep/noise_level embedding; only for continuous training 68 | if embedding_type == 'fourier': 69 | # Gaussian Fourier features embeddings. 70 | assert config['continuous'], "Fourier features are only used for continuous training." 
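# Gaussian Fourier features embed the (log) noise level through a fixed random frequency
# vector of size nf, producing one sin/cos pair per frequency (see
# layerspp.GaussianFourierProjection); this is why embed_dim is set to 2 * nf below.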
71 | 72 | modules.append(layerspp.GaussianFourierProjection( 73 | embedding_size=nf, scale=config['fourier_scale'] 74 | )) 75 | embed_dim = 2 * nf 76 | 77 | elif embedding_type == 'positional': 78 | embed_dim = nf 79 | 80 | else: 81 | raise ValueError(f'embedding type {embedding_type} unknown.') 82 | 83 | if conditional: 84 | modules.append(nn.Linear(embed_dim, nf * 4)) 85 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 86 | nn.init.zeros_(modules[-1].bias) 87 | modules.append(nn.Linear(nf * 4, nf * 4)) 88 | modules[-1].weight.data = default_initializer()(modules[-1].weight.shape) 89 | nn.init.zeros_(modules[-1].bias) 90 | 91 | AttnBlock = functools.partial(layerspp.AttnBlockpp, 92 | init_scale=init_scale, 93 | skip_rescale=skip_rescale) 94 | 95 | Upsample = functools.partial(layerspp.Upsample, 96 | with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 97 | 98 | if progressive == 'output_skip': 99 | self.pyramid_upsample = layerspp.Upsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 100 | elif progressive == 'residual': 101 | pyramid_upsample = functools.partial(layerspp.Upsample, 102 | fir=fir, fir_kernel=fir_kernel, with_conv=True) 103 | 104 | Downsample = functools.partial(layerspp.Downsample, 105 | with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel) 106 | 107 | if progressive_input == 'input_skip': 108 | self.pyramid_downsample = layerspp.Downsample(fir=fir, fir_kernel=fir_kernel, with_conv=False) 109 | elif progressive_input == 'residual': 110 | pyramid_downsample = functools.partial(layerspp.Downsample, 111 | fir=fir, fir_kernel=fir_kernel, with_conv=True) 112 | 113 | if resblock_type == 'ddpm': 114 | ResnetBlock = functools.partial(ResnetBlockDDPM, 115 | act=act, 116 | dropout=dropout, 117 | init_scale=init_scale, 118 | skip_rescale=skip_rescale, 119 | temb_dim=nf * 4) 120 | 121 | elif resblock_type == 'biggan': 122 | ResnetBlock = functools.partial(ResnetBlockBigGAN, 123 | act=act, 124 | dropout=dropout, 125 | fir=fir, 126 | fir_kernel=fir_kernel, 127 | init_scale=init_scale, 128 | skip_rescale=skip_rescale, 129 | temb_dim=nf * 4) 130 | 131 | else: 132 | raise ValueError(f'resblock type {resblock_type} unrecognized.') 133 | 134 | # Downsampling block 135 | 136 | channels = config['num_channels'] 137 | if progressive_input != 'none': 138 | input_pyramid_ch = channels 139 | 140 | modules.append(conv3x3(channels, nf)) 141 | hs_c = [nf] 142 | 143 | in_ch = nf 144 | for i_level in range(num_resolutions): 145 | # Residual blocks for this resolution 146 | for i_block in range(num_res_blocks): 147 | out_ch = nf * ch_mult[i_level] 148 | modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch)) 149 | in_ch = out_ch 150 | 151 | if all_resolutions[i_level] in attn_resolutions: 152 | modules.append(AttnBlock(channels=in_ch)) 153 | hs_c.append(in_ch) 154 | 155 | if i_level != num_resolutions - 1: 156 | if resblock_type == 'ddpm': 157 | modules.append(Downsample(in_ch=in_ch)) 158 | else: 159 | modules.append(ResnetBlock(down=True, in_ch=in_ch)) 160 | 161 | if progressive_input == 'input_skip': 162 | modules.append(combiner(dim1=input_pyramid_ch, dim2=in_ch)) 163 | if combine_method == 'cat': 164 | in_ch *= 2 165 | 166 | elif progressive_input == 'residual': 167 | modules.append(pyramid_downsample(in_ch=input_pyramid_ch, out_ch=in_ch)) 168 | input_pyramid_ch = in_ch 169 | 170 | hs_c.append(in_ch) 171 | 172 | in_ch = hs_c[-1] 173 | modules.append(ResnetBlock(in_ch=in_ch)) 174 | modules.append(AttnBlock(channels=in_ch)) 175 | 
modules.append(ResnetBlock(in_ch=in_ch)) 176 | 177 | pyramid_ch = 0 178 | # Upsampling block 179 | for i_level in reversed(range(num_resolutions)): 180 | for i_block in range(num_res_blocks + 1): 181 | out_ch = nf * ch_mult[i_level] 182 | modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(), 183 | out_ch=out_ch)) 184 | in_ch = out_ch 185 | 186 | if all_resolutions[i_level] in attn_resolutions: 187 | modules.append(AttnBlock(channels=in_ch)) 188 | 189 | if progressive != 'none': 190 | if i_level == num_resolutions - 1: 191 | if progressive == 'output_skip': 192 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 193 | num_channels=in_ch, eps=1e-6)) 194 | modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) 195 | pyramid_ch = channels 196 | elif progressive == 'residual': 197 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 198 | num_channels=in_ch, eps=1e-6)) 199 | modules.append(conv3x3(in_ch, in_ch, bias=True)) 200 | pyramid_ch = in_ch 201 | else: 202 | raise ValueError(f'{progressive} is not a valid name.') 203 | else: 204 | if progressive == 'output_skip': 205 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 206 | num_channels=in_ch, eps=1e-6)) 207 | modules.append(conv3x3(in_ch, channels, bias=True, init_scale=init_scale)) 208 | pyramid_ch = channels 209 | elif progressive == 'residual': 210 | modules.append(pyramid_upsample(in_ch=pyramid_ch, out_ch=in_ch)) 211 | pyramid_ch = in_ch 212 | else: 213 | raise ValueError(f'{progressive} is not a valid name') 214 | 215 | if i_level != 0: 216 | if resblock_type == 'ddpm': 217 | modules.append(Upsample(in_ch=in_ch)) 218 | else: 219 | modules.append(ResnetBlock(in_ch=in_ch, up=True)) 220 | 221 | assert not hs_c 222 | 223 | if progressive != 'output_skip': 224 | modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 32), 225 | num_channels=in_ch, eps=1e-6)) 226 | modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) 227 | 228 | self.all_modules = nn.ModuleList(modules) 229 | 230 | def forward(self, x, time_cond): 231 | # timestep/noise_level embedding; only for continuous training 232 | modules = self.all_modules 233 | m_idx = 0 234 | if self.embedding_type == 'fourier': 235 | # Gaussian Fourier features embeddings. 236 | used_sigmas = time_cond 237 | temb = modules[m_idx](torch.log(used_sigmas)) 238 | m_idx += 1 239 | 240 | elif self.embedding_type == 'positional': 241 | # Sinusoidal positional embeddings. 242 | timesteps = time_cond 243 | temb = layers.get_timestep_embedding(timesteps, self.nf) 244 | 245 | else: 246 | raise ValueError(f'embedding type {self.embedding_type} unknown.') 247 | 248 | if self.conditional: 249 | temb = modules[m_idx](temb) 250 | m_idx += 1 251 | temb = modules[m_idx](self.act(temb)) 252 | m_idx += 1 253 | else: 254 | temb = None 255 | 256 | if not self.config['centered']: 257 | # If input data is in [0, 1] 258 | x = 2 * x - 1. 
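# self.all_modules was filled in __init__ in a fixed order; the code below walks it with
# m_idx in exactly that order, and the `assert m_idx == len(modules)` at the end of
# forward() verifies that every module has been consumed exactly once.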
259 | 260 | # Downsampling block 261 | input_pyramid = None 262 | if self.progressive_input != 'none': 263 | input_pyramid = x 264 | 265 | hs = [modules[m_idx](x)] 266 | m_idx += 1 267 | for i_level in range(self.num_resolutions): 268 | # Residual blocks for this resolution 269 | for i_block in range(self.num_res_blocks): 270 | h = modules[m_idx](hs[-1], temb) 271 | m_idx += 1 272 | if h.shape[-1] in self.attn_resolutions: 273 | h = modules[m_idx](h) 274 | m_idx += 1 275 | 276 | hs.append(h) 277 | 278 | if i_level != self.num_resolutions - 1: 279 | if self.resblock_type == 'ddpm': 280 | h = modules[m_idx](hs[-1]) 281 | m_idx += 1 282 | else: 283 | h = modules[m_idx](hs[-1], temb) 284 | m_idx += 1 285 | 286 | if self.progressive_input == 'input_skip': 287 | input_pyramid = self.pyramid_downsample(input_pyramid) 288 | h = modules[m_idx](input_pyramid, h) 289 | m_idx += 1 290 | 291 | elif self.progressive_input == 'residual': 292 | input_pyramid = modules[m_idx](input_pyramid) 293 | m_idx += 1 294 | if self.skip_rescale: 295 | input_pyramid = (input_pyramid + h) / np.sqrt(2.) 296 | else: 297 | input_pyramid = input_pyramid + h 298 | h = input_pyramid 299 | 300 | hs.append(h) 301 | 302 | h = hs[-1] 303 | h = modules[m_idx](h, temb) 304 | m_idx += 1 305 | h = modules[m_idx](h) 306 | m_idx += 1 307 | h = modules[m_idx](h, temb) 308 | m_idx += 1 309 | 310 | pyramid = None 311 | 312 | # Upsampling block 313 | for i_level in reversed(range(self.num_resolutions)): 314 | for i_block in range(self.num_res_blocks + 1): 315 | h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb) 316 | m_idx += 1 317 | 318 | if h.shape[-1] in self.attn_resolutions: 319 | h = modules[m_idx](h) 320 | m_idx += 1 321 | 322 | if self.progressive != 'none': 323 | if i_level == self.num_resolutions - 1: 324 | if self.progressive == 'output_skip': 325 | pyramid = self.act(modules[m_idx](h)) 326 | m_idx += 1 327 | pyramid = modules[m_idx](pyramid) 328 | m_idx += 1 329 | elif self.progressive == 'residual': 330 | pyramid = self.act(modules[m_idx](h)) 331 | m_idx += 1 332 | pyramid = modules[m_idx](pyramid) 333 | m_idx += 1 334 | else: 335 | raise ValueError(f'{self.progressive} is not a valid name.') 336 | else: 337 | if self.progressive == 'output_skip': 338 | pyramid = self.pyramid_upsample(pyramid) 339 | pyramid_h = self.act(modules[m_idx](h)) 340 | m_idx += 1 341 | pyramid_h = modules[m_idx](pyramid_h) 342 | m_idx += 1 343 | pyramid = pyramid + pyramid_h 344 | elif self.progressive == 'residual': 345 | pyramid = modules[m_idx](pyramid) 346 | m_idx += 1 347 | if self.skip_rescale: 348 | pyramid = (pyramid + h) / np.sqrt(2.) 349 | else: 350 | pyramid = pyramid + h 351 | h = pyramid 352 | else: 353 | raise ValueError(f'{self.progressive} is not a valid name') 354 | 355 | if i_level != 0: 356 | if self.resblock_type == 'ddpm': 357 | h = modules[m_idx](h) 358 | m_idx += 1 359 | else: 360 | h = modules[m_idx](h, temb) 361 | m_idx += 1 362 | 363 | assert not hs 364 | 365 | if self.progressive == 'output_skip': 366 | h = pyramid 367 | else: 368 | h = self.act(modules[m_idx](h)) 369 | m_idx += 1 370 | h = modules[m_idx](h) 371 | m_idx += 1 372 | 373 | assert m_idx == len(modules) 374 | 375 | return h 376 | -------------------------------------------------------------------------------- /model/scoresde/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Normalization layers.""" 17 | import torch.nn as nn 18 | import torch 19 | import functools 20 | 21 | 22 | def get_normalization(config, conditional=False): 23 | """Obtain normalization modules from the config file.""" 24 | norm = config.model.normalization 25 | if conditional: 26 | if norm == 'InstanceNorm++': 27 | return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes) 28 | else: 29 | raise NotImplementedError(f'{norm} not implemented yet.') 30 | else: 31 | if norm == 'InstanceNorm': 32 | return nn.InstanceNorm2d 33 | elif norm == 'InstanceNorm++': 34 | return InstanceNorm2dPlus 35 | elif norm == 'VarianceNorm': 36 | return VarianceNorm2d 37 | elif norm == 'GroupNorm': 38 | return nn.GroupNorm 39 | else: 40 | raise ValueError('Unknown normalization: %s' % norm) 41 | 42 | 43 | class ConditionalBatchNorm2d(nn.Module): 44 | def __init__(self, num_features, num_classes, bias=True): 45 | super().__init__() 46 | self.num_features = num_features 47 | self.bias = bias 48 | self.bn = nn.BatchNorm2d(num_features, affine=False) 49 | if self.bias: 50 | self.embed = nn.Embedding(num_classes, num_features * 2) 51 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 52 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 53 | else: 54 | self.embed = nn.Embedding(num_classes, num_features) 55 | self.embed.weight.data.uniform_() 56 | 57 | def forward(self, x, y): 58 | out = self.bn(x) 59 | if self.bias: 60 | gamma, beta = self.embed(y).chunk(2, dim=1) 61 | out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(-1, self.num_features, 1, 1) 62 | else: 63 | gamma = self.embed(y) 64 | out = gamma.view(-1, self.num_features, 1, 1) * out 65 | return out 66 | 67 | 68 | class ConditionalInstanceNorm2d(nn.Module): 69 | def __init__(self, num_features, num_classes, bias=True): 70 | super().__init__() 71 | self.num_features = num_features 72 | self.bias = bias 73 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 74 | if bias: 75 | self.embed = nn.Embedding(num_classes, num_features * 2) 76 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 77 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 78 | else: 79 | self.embed = nn.Embedding(num_classes, num_features) 80 | self.embed.weight.data.uniform_() 81 | 82 | def forward(self, x, y): 83 | h = self.instance_norm(x) 84 | if self.bias: 85 | gamma, beta = self.embed(y).chunk(2, dim=-1) 86 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 87 | else: 88 | gamma = self.embed(y) 89 | out = gamma.view(-1, self.num_features, 1, 1) * h 90 | return out 91 | 92 | 93 | class ConditionalVarianceNorm2d(nn.Module): 94 | def __init__(self, num_features, num_classes, bias=False): 95 | super().__init__() 96 | self.num_features = num_features 97 
| self.bias = bias 98 | self.embed = nn.Embedding(num_classes, num_features) 99 | self.embed.weight.data.normal_(1, 0.02) 100 | 101 | def forward(self, x, y): 102 | vars = torch.var(x, dim=(2, 3), keepdim=True) 103 | h = x / torch.sqrt(vars + 1e-5) 104 | 105 | gamma = self.embed(y) 106 | out = gamma.view(-1, self.num_features, 1, 1) * h 107 | return out 108 | 109 | 110 | class VarianceNorm2d(nn.Module): 111 | def __init__(self, num_features, bias=False): 112 | super().__init__() 113 | self.num_features = num_features 114 | self.bias = bias 115 | self.alpha = nn.Parameter(torch.zeros(num_features)) 116 | self.alpha.data.normal_(1, 0.02) 117 | 118 | def forward(self, x): 119 | vars = torch.var(x, dim=(2, 3), keepdim=True) 120 | h = x / torch.sqrt(vars + 1e-5) 121 | 122 | out = self.alpha.view(-1, self.num_features, 1, 1) * h 123 | return out 124 | 125 | 126 | class ConditionalNoneNorm2d(nn.Module): 127 | def __init__(self, num_features, num_classes, bias=True): 128 | super().__init__() 129 | self.num_features = num_features 130 | self.bias = bias 131 | if bias: 132 | self.embed = nn.Embedding(num_classes, num_features * 2) 133 | self.embed.weight.data[:, :num_features].uniform_() # Initialise scale at N(1, 0.02) 134 | self.embed.weight.data[:, num_features:].zero_() # Initialise bias at 0 135 | else: 136 | self.embed = nn.Embedding(num_classes, num_features) 137 | self.embed.weight.data.uniform_() 138 | 139 | def forward(self, x, y): 140 | if self.bias: 141 | gamma, beta = self.embed(y).chunk(2, dim=-1) 142 | out = gamma.view(-1, self.num_features, 1, 1) * x + beta.view(-1, self.num_features, 1, 1) 143 | else: 144 | gamma = self.embed(y) 145 | out = gamma.view(-1, self.num_features, 1, 1) * x 146 | return out 147 | 148 | 149 | class NoneNorm2d(nn.Module): 150 | def __init__(self, num_features, bias=True): 151 | super().__init__() 152 | 153 | def forward(self, x): 154 | return x 155 | 156 | 157 | class InstanceNorm2dPlus(nn.Module): 158 | def __init__(self, num_features, bias=True): 159 | super().__init__() 160 | self.num_features = num_features 161 | self.bias = bias 162 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 163 | self.alpha = nn.Parameter(torch.zeros(num_features)) 164 | self.gamma = nn.Parameter(torch.zeros(num_features)) 165 | self.alpha.data.normal_(1, 0.02) 166 | self.gamma.data.normal_(1, 0.02) 167 | if bias: 168 | self.beta = nn.Parameter(torch.zeros(num_features)) 169 | 170 | def forward(self, x): 171 | means = torch.mean(x, dim=(2, 3)) 172 | m = torch.mean(means, dim=-1, keepdim=True) 173 | v = torch.var(means, dim=-1, keepdim=True) 174 | means = (means - m) / (torch.sqrt(v + 1e-5)) 175 | h = self.instance_norm(x) 176 | 177 | if self.bias: 178 | h = h + means[..., None, None] * self.alpha[..., None, None] 179 | out = self.gamma.view(-1, self.num_features, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1) 180 | else: 181 | h = h + means[..., None, None] * self.alpha[..., None, None] 182 | out = self.gamma.view(-1, self.num_features, 1, 1) * h 183 | return out 184 | 185 | 186 | class ConditionalInstanceNorm2dPlus(nn.Module): 187 | def __init__(self, num_features, num_classes, bias=True): 188 | super().__init__() 189 | self.num_features = num_features 190 | self.bias = bias 191 | self.instance_norm = nn.InstanceNorm2d(num_features, affine=False, track_running_stats=False) 192 | if bias: 193 | self.embed = nn.Embedding(num_classes, num_features * 3) 194 | self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02) # 
Initialise scale at N(1, 0.02) 195 | self.embed.weight.data[:, 2 * num_features:].zero_() # Initialise bias at 0 196 | else: 197 | self.embed = nn.Embedding(num_classes, 2 * num_features) 198 | self.embed.weight.data.normal_(1, 0.02) 199 | 200 | def forward(self, x, y): 201 | means = torch.mean(x, dim=(2, 3)) 202 | m = torch.mean(means, dim=-1, keepdim=True) 203 | v = torch.var(means, dim=-1, keepdim=True) 204 | means = (means - m) / (torch.sqrt(v + 1e-5)) 205 | h = self.instance_norm(x) 206 | 207 | if self.bias: 208 | gamma, alpha, beta = self.embed(y).chunk(3, dim=-1) 209 | h = h + means[..., None, None] * alpha[..., None, None] 210 | out = gamma.view(-1, self.num_features, 1, 1) * h + beta.view(-1, self.num_features, 1, 1) 211 | else: 212 | gamma, alpha = self.embed(y).chunk(2, dim=-1) 213 | h = h + means[..., None, None] * alpha[..., None, None] 214 | out = gamma.view(-1, self.num_features, 1, 1) * h 215 | return out 216 | -------------------------------------------------------------------------------- /model/scoresde/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Layers used for up-sampling or down-sampling images. 17 | 18 | Many functions are ported from https://github.com/NVlabs/stylegan2. 19 | """ 20 | 21 | import torch.nn as nn 22 | import torch 23 | import torch.nn.functional as F 24 | import numpy as np 25 | from . 
import upfirdn2d 26 | 27 | 28 | # Function ported from StyleGAN2 29 | def get_weight(module, 30 | shape, 31 | weight_var='weight', 32 | kernel_init=None): 33 | """Get/create weight tensor for a convolution or fully-connected layer.""" 34 | 35 | return module.param(weight_var, kernel_init, shape) 36 | 37 | 38 | class Conv2d(nn.Module): 39 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 40 | 41 | def __init__(self, in_ch, out_ch, kernel, up=False, down=False, 42 | resample_kernel=(1, 3, 3, 1), 43 | use_bias=True, 44 | kernel_init=None): 45 | super().__init__() 46 | assert not (up and down) 47 | assert kernel >= 1 and kernel % 2 == 1 48 | self.weight = nn.Parameter(torch.zeros(out_ch, in_ch, kernel, kernel)) 49 | if kernel_init is not None: 50 | self.weight.data = kernel_init(self.weight.data.shape) 51 | if use_bias: 52 | self.bias = nn.Parameter(torch.zeros(out_ch)) 53 | 54 | self.up = up 55 | self.down = down 56 | self.resample_kernel = resample_kernel 57 | self.kernel = kernel 58 | self.use_bias = use_bias 59 | 60 | def forward(self, x): 61 | if self.up: 62 | x = upsample_conv_2d(x, self.weight, k=self.resample_kernel) 63 | elif self.down: 64 | x = conv_downsample_2d(x, self.weight, k=self.resample_kernel) 65 | else: 66 | x = F.conv2d(x, self.weight, stride=1, padding=self.kernel // 2) 67 | 68 | if self.use_bias: 69 | x = x + self.bias.reshape(1, -1, 1, 1) 70 | 71 | return x 72 | 73 | 74 | def naive_upsample_2d(x, factor=2): 75 | _N, C, H, W = x.shape 76 | x = torch.reshape(x, (-1, C, H, 1, W, 1)) 77 | x = x.repeat(1, 1, 1, factor, 1, factor) 78 | return torch.reshape(x, (-1, C, H * factor, W * factor)) 79 | 80 | 81 | def naive_downsample_2d(x, factor=2): 82 | _N, C, H, W = x.shape 83 | x = torch.reshape(x, (-1, C, H // factor, factor, W // factor, factor)) 84 | return torch.mean(x, dim=(3, 5)) 85 | 86 | 87 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1): 88 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 89 | 90 | Padding is performed only once at the beginning, not between the 91 | operations. 92 | The fused op is considerably more efficient than performing the same 93 | calculation 94 | using standard TensorFlow ops. It supports gradients of arbitrary order. 95 | Args: 96 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 97 | C]`. 98 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 99 | outChannels]`. Grouped convolution can be performed by `inChannels = 100 | x.shape[0] // numGroups`. 101 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 102 | (separable). The default is `[1] * factor`, which corresponds to 103 | nearest-neighbor upsampling. 104 | factor: Integer upsampling factor (default: 2). 105 | gain: Scaling factor for signal magnitude (default: 1.0). 106 | 107 | Returns: 108 | Tensor of the shape `[N, C, H * factor, W * factor]` or 109 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 110 | """ 111 | 112 | assert isinstance(factor, int) and factor >= 1 113 | 114 | # Check weight shape. 115 | assert len(w.shape) == 4 116 | convH = w.shape[2] 117 | convW = w.shape[3] 118 | inC = w.shape[1] 119 | outC = w.shape[0] 120 | 121 | assert convW == convH 122 | 123 | # Setup filter kernel. 124 | if k is None: 125 | k = [1] * factor 126 | k = _setup_kernel(k) * (gain * (factor ** 2)) 127 | p = (k.shape[0] - factor) - (convW - 1) 128 | 129 | stride = (factor, factor) 130 | 131 | # Determine data dimensions. 
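# The grouped conv_transpose2d below fuses the zero-insertion upsampling with the
# convolution by w; the trailing upfirdn2d call then applies the FIR filter k, using the
# padding p computed above so that the output has shape (H * factor, W * factor), as
# stated in the docstring.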
132 | stride = [1, 1, factor, factor] 133 | output_shape = ((_shape(x, 2) - 1) * factor + convH, (_shape(x, 3) - 1) * factor + convW) 134 | output_padding = (output_shape[0] - (_shape(x, 2) - 1) * stride[0] - convH, 135 | output_shape[1] - (_shape(x, 3) - 1) * stride[1] - convW) 136 | assert output_padding[0] >= 0 and output_padding[1] >= 0 137 | num_groups = _shape(x, 1) // inC 138 | 139 | # Transpose weights. 140 | w = torch.reshape(w, (num_groups, -1, inC, convH, convW)) 141 | w = w[..., ::-1, ::-1].permute(0, 2, 1, 3, 4) 142 | w = torch.reshape(w, (num_groups * inC, -1, convH, convW)) 143 | 144 | x = F.conv_transpose2d(x, w, stride=stride, output_padding=output_padding, padding=0) 145 | ## Original TF code. 146 | # x = tf.nn.conv2d_transpose( 147 | # x, 148 | # w, 149 | # output_shape=output_shape, 150 | # strides=stride, 151 | # padding='VALID', 152 | # data_format=data_format) 153 | ## JAX equivalent 154 | 155 | return upfirdn2d(x, torch.tensor(k, device=x.device), 156 | pad=((p + 1) // 2 + factor - 1, p // 2 + 1)) 157 | 158 | 159 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1): 160 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 161 | 162 | Padding is performed only once at the beginning, not between the operations. 163 | The fused op is considerably more efficient than performing the same 164 | calculation 165 | using standard TensorFlow ops. It supports gradients of arbitrary order. 166 | Args: 167 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 168 | C]`. 169 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 170 | outChannels]`. Grouped convolution can be performed by `inChannels = 171 | x.shape[0] // numGroups`. 172 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 173 | (separable). The default is `[1] * factor`, which corresponds to 174 | average pooling. 175 | factor: Integer downsampling factor (default: 2). 176 | gain: Scaling factor for signal magnitude (default: 1.0). 177 | 178 | Returns: 179 | Tensor of the shape `[N, C, H // factor, W // factor]` or 180 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 181 | """ 182 | 183 | assert isinstance(factor, int) and factor >= 1 184 | _outC, _inC, convH, convW = w.shape 185 | assert convW == convH 186 | if k is None: 187 | k = [1] * factor 188 | k = _setup_kernel(k) * gain 189 | p = (k.shape[0] - factor) + (convW - 1) 190 | s = [factor, factor] 191 | x = upfirdn2d(x, torch.tensor(k, device=x.device), 192 | pad=((p + 1) // 2, p // 2)) 193 | return F.conv2d(x, w, stride=s, padding=0) 194 | 195 | 196 | def _setup_kernel(k): 197 | k = np.asarray(k, dtype=np.float32) 198 | if k.ndim == 1: 199 | k = np.outer(k, k) 200 | k /= np.sum(k) 201 | assert k.ndim == 2 202 | assert k.shape[0] == k.shape[1] 203 | return k 204 | 205 | 206 | def _shape(x, dim): 207 | return x.shape[dim] 208 | 209 | 210 | def upsample_2d(x, k=None, factor=2, gain=1): 211 | r"""Upsample a batch of 2D images with the given filter. 212 | 213 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 214 | and upsamples each image with the given filter. The filter is normalized so 215 | that 216 | if the input pixels are constant, they will be scaled by the specified 217 | `gain`. 218 | Pixels outside the image are assumed to be zero, and the filter is padded 219 | with 220 | zeros so that its shape is a multiple of the upsampling factor. 221 | Args: 222 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 223 | C]`. 
224 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 225 | (separable). The default is `[1] * factor`, which corresponds to 226 | nearest-neighbor upsampling. 227 | factor: Integer upsampling factor (default: 2). 228 | gain: Scaling factor for signal magnitude (default: 1.0). 229 | 230 | Returns: 231 | Tensor of the shape `[N, C, H * factor, W * factor]` 232 | """ 233 | assert isinstance(factor, int) and factor >= 1 234 | if k is None: 235 | k = [1] * factor 236 | k = _setup_kernel(k) * (gain * (factor ** 2)) 237 | p = k.shape[0] - factor 238 | return upfirdn2d(x, torch.tensor(k, device=x.device), 239 | up=factor, pad=((p + 1) // 2 + factor - 1, p // 2)) 240 | 241 | 242 | def downsample_2d(x, k=None, factor=2, gain=1): 243 | r"""Downsample a batch of 2D images with the given filter. 244 | 245 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 246 | and downsamples each image with the given filter. The filter is normalized 247 | so that 248 | if the input pixels are constant, they will be scaled by the specified 249 | `gain`. 250 | Pixels outside the image are assumed to be zero, and the filter is padded 251 | with 252 | zeros so that its shape is a multiple of the downsampling factor. 253 | Args: 254 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 255 | C]`. 256 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 257 | (separable). The default is `[1] * factor`, which corresponds to 258 | average pooling. 259 | factor: Integer downsampling factor (default: 2). 260 | gain: Scaling factor for signal magnitude (default: 1.0). 261 | 262 | Returns: 263 | Tensor of the shape `[N, C, H // factor, W // factor]` 264 | """ 265 | 266 | assert isinstance(factor, int) and factor >= 1 267 | if k is None: 268 | k = [1] * factor 269 | k = _setup_kernel(k) * gain 270 | p = k.shape[0] - factor 271 | return upfirdn2d(x, torch.tensor(k, device=x.device), 272 | down=factor, pad=((p + 1) // 2, p // 2)) 273 | -------------------------------------------------------------------------------- /model/scoresde/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // coding=utf-8 2 | // Copyright 2020 The Google Research Authors. 3 | // 4 | // Licensed under the Apache License, Version 2.0 (the "License"); 5 | // you may not use this file except in compliance with the License. 6 | // You may obtain a copy of the License at 7 | // 8 | // http://www.apache.org/licenses/LICENSE-2.0 9 | // 10 | // Unless required by applicable law or agreed to in writing, software 11 | // distributed under the License is distributed on an "AS IS" BASIS, 12 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | // See the License for the specific language governing permissions and 14 | // limitations under the License. 
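// upfirdn2d: upsample the input by zero-insertion (up_x, up_y), pad with
// (pad_x0, pad_x1, pad_y0, pad_y1), convolve with the FIR `kernel`, then downsample by
// keeping every (down_x, down_y)-th sample. upfirdn2d_op is implemented in
// upfirdn2d_kernel.cu; this file only adds the CUDA tensor checks and the pybind11 binding.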
15 |
16 | #include <torch/extension.h>
17 |
18 |
19 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
20 | int up_x, int up_y, int down_x, int down_y,
21 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
22 |
23 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
24 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
25 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
26 |
27 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
28 | int up_x, int up_y, int down_x, int down_y,
29 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
30 | CHECK_CUDA(input);
31 | CHECK_CUDA(kernel);
32 |
33 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
34 | }
35 |
36 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
37 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
38 | }
--------------------------------------------------------------------------------
/model/scoresde/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
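# This module JIT-compiles the CUDA extension above (upfirdn2d.cpp + upfirdn2d_kernel.cu)
# with torch.utils.cpp_extension.load() at import time, so the first import may take a
# while to build. The upfirdn2d() wrapper dispatches to the compiled op for GPU tensors
# and to the pure-PyTorch upfirdn2d_native() implementation for CPU tensors.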
15 | 16 | import os 17 | 18 | import torch 19 | from torch.nn import functional as F 20 | from torch.autograd import Function 21 | from torch.utils.cpp_extension import load 22 | 23 | 24 | module_path = os.path.dirname(__file__) 25 | upfirdn2d_op = load( 26 | "upfirdn2d", 27 | sources=[ 28 | os.path.join(module_path, "upfirdn2d.cpp"), 29 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 30 | ], 31 | ) 32 | 33 | 34 | class UpFirDn2dBackward(Function): 35 | @staticmethod 36 | def forward( 37 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 38 | ): 39 | 40 | up_x, up_y = up 41 | down_x, down_y = down 42 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 43 | 44 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 45 | 46 | grad_input = upfirdn2d_op.upfirdn2d( 47 | grad_output, 48 | grad_kernel, 49 | down_x, 50 | down_y, 51 | up_x, 52 | up_y, 53 | g_pad_x0, 54 | g_pad_x1, 55 | g_pad_y0, 56 | g_pad_y1, 57 | ) 58 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 59 | 60 | ctx.save_for_backward(kernel) 61 | 62 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 63 | 64 | ctx.up_x = up_x 65 | ctx.up_y = up_y 66 | ctx.down_x = down_x 67 | ctx.down_y = down_y 68 | ctx.pad_x0 = pad_x0 69 | ctx.pad_x1 = pad_x1 70 | ctx.pad_y0 = pad_y0 71 | ctx.pad_y1 = pad_y1 72 | ctx.in_size = in_size 73 | ctx.out_size = out_size 74 | 75 | return grad_input 76 | 77 | @staticmethod 78 | def backward(ctx, gradgrad_input): 79 | kernel, = ctx.saved_tensors 80 | 81 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 82 | 83 | gradgrad_out = upfirdn2d_op.upfirdn2d( 84 | gradgrad_input, 85 | kernel, 86 | ctx.up_x, 87 | ctx.up_y, 88 | ctx.down_x, 89 | ctx.down_y, 90 | ctx.pad_x0, 91 | ctx.pad_x1, 92 | ctx.pad_y0, 93 | ctx.pad_y1, 94 | ) 95 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 96 | gradgrad_out = gradgrad_out.view( 97 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 98 | ) 99 | 100 | return gradgrad_out, None, None, None, None, None, None, None, None 101 | 102 | 103 | class UpFirDn2d(Function): 104 | @staticmethod 105 | def forward(ctx, input, kernel, up, down, pad): 106 | up_x, up_y = up 107 | down_x, down_y = down 108 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 109 | 110 | kernel_h, kernel_w = kernel.shape 111 | batch, channel, in_h, in_w = input.shape 112 | ctx.in_size = input.shape 113 | 114 | input = input.reshape(-1, in_h, in_w, 1) 115 | 116 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 117 | 118 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 119 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 120 | ctx.out_size = (out_h, out_w) 121 | 122 | ctx.up = (up_x, up_y) 123 | ctx.down = (down_x, down_y) 124 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 125 | 126 | g_pad_x0 = kernel_w - pad_x0 - 1 127 | g_pad_y0 = kernel_h - pad_y0 - 1 128 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 129 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 130 | 131 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 132 | 133 | out = upfirdn2d_op.upfirdn2d( 134 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 135 | ) 136 | # out = out.view(major, out_h, out_w, minor) 137 | out = out.view(-1, channel, out_h, out_w) 138 | 139 | return out 140 | 141 | @staticmethod 142 | def backward(ctx, grad_output): 143 | kernel, grad_kernel = ctx.saved_tensors 144 | 145 | grad_input = 
UpFirDn2dBackward.apply( 146 | grad_output, 147 | kernel, 148 | grad_kernel, 149 | ctx.up, 150 | ctx.down, 151 | ctx.pad, 152 | ctx.g_pad, 153 | ctx.in_size, 154 | ctx.out_size, 155 | ) 156 | 157 | return grad_input, None, None, None, None 158 | 159 | 160 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 161 | if input.device.type == "cpu": 162 | out = upfirdn2d_native( 163 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 164 | ) 165 | 166 | else: 167 | out = UpFirDn2d.apply( 168 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 169 | ) 170 | 171 | return out 172 | 173 | 174 | def upfirdn2d_native( 175 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 176 | ): 177 | _, channel, in_h, in_w = input.shape 178 | input = input.reshape(-1, in_h, in_w, 1) 179 | 180 | _, in_h, in_w, minor = input.shape 181 | kernel_h, kernel_w = kernel.shape 182 | 183 | out = input.view(-1, in_h, 1, in_w, 1, minor) 184 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 185 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 186 | 187 | out = F.pad( 188 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 189 | ) 190 | out = out[ 191 | :, 192 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 193 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 194 | :, 195 | ] 196 | 197 | out = out.permute(0, 3, 1, 2) 198 | out = out.reshape( 199 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 200 | ) 201 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 202 | out = F.conv2d(out, w) 203 | out = out.reshape( 204 | -1, 205 | minor, 206 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 207 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 208 | ) 209 | out = out.permute(0, 2, 3, 1) 210 | out = out[:, ::down_y, ::down_x, :] 211 | 212 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 213 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 214 | 215 | return out.view(-1, channel, out_h, out_w) 216 | -------------------------------------------------------------------------------- /model/scoresde/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright 2020 The Google Research Authors. 2 | // 3 | // Licensed under the Apache License, Version 2.0 (the "License"); 4 | // you may not use this file except in compliance with the License. 5 | // You may obtain a copy of the License at 6 | // 7 | // http://www.apache.org/licenses/LICENSE-2.0 8 | // 9 | // Unless required by applicable law or agreed to in writing, software 10 | // distributed under the License is distributed on an "AS IS" BASIS, 11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | // See the License for the specific language governing permissions and 13 | // limitations under the License. 14 | 15 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 16 | // 17 | // This work is made available under the Nvidia Source Code License-NC. 
18 | // To view a copy of this license, visit 19 | // https://nvlabs.github.io/stylegan2/license.html 20 | 21 | #include 22 | 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #include 29 | #include 30 | 31 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 32 | int c = a / b; 33 | 34 | if (c * b > a) { 35 | c--; 36 | } 37 | 38 | return c; 39 | } 40 | 41 | struct UpFirDn2DKernelParams { 42 | int up_x; 43 | int up_y; 44 | int down_x; 45 | int down_y; 46 | int pad_x0; 47 | int pad_x1; 48 | int pad_y0; 49 | int pad_y1; 50 | 51 | int major_dim; 52 | int in_h; 53 | int in_w; 54 | int minor_dim; 55 | int kernel_h; 56 | int kernel_w; 57 | int out_h; 58 | int out_w; 59 | int loop_major; 60 | int loop_x; 61 | }; 62 | 63 | template 64 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 65 | const scalar_t *kernel, 66 | const UpFirDn2DKernelParams p) { 67 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 68 | int out_y = minor_idx / p.minor_dim; 69 | minor_idx -= out_y * p.minor_dim; 70 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 71 | int major_idx_base = blockIdx.z * p.loop_major; 72 | 73 | if (out_x_base >= p.out_w || out_y >= p.out_h || 74 | major_idx_base >= p.major_dim) { 75 | return; 76 | } 77 | 78 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 79 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 80 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 81 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; 84 | loop_major < p.loop_major && major_idx < p.major_dim; 85 | loop_major++, major_idx++) { 86 | for (int loop_x = 0, out_x = out_x_base; 87 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 88 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 89 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 90 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 91 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 92 | 93 | const scalar_t *x_p = 94 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 95 | minor_idx]; 96 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 97 | int x_px = p.minor_dim; 98 | int k_px = -p.up_x; 99 | int x_py = p.in_w * p.minor_dim; 100 | int k_py = -p.up_y * p.kernel_w; 101 | 102 | scalar_t v = 0.0f; 103 | 104 | for (int y = 0; y < h; y++) { 105 | for (int x = 0; x < w; x++) { 106 | v += static_cast(*x_p) * static_cast(*k_p); 107 | x_p += x_px; 108 | k_p += k_px; 109 | } 110 | 111 | x_p += x_py - w * x_px; 112 | k_p += k_py - w * k_px; 113 | } 114 | 115 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 116 | minor_idx] = v; 117 | } 118 | } 119 | } 120 | 121 | template 123 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 124 | const scalar_t *kernel, 125 | const UpFirDn2DKernelParams p) { 126 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 127 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 128 | 129 | __shared__ volatile float sk[kernel_h][kernel_w]; 130 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 131 | 132 | int minor_idx = blockIdx.x; 133 | int tile_out_y = minor_idx / p.minor_dim; 134 | minor_idx -= tile_out_y * p.minor_dim; 135 | tile_out_y *= tile_out_h; 136 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 137 | int major_idx_base = 
blockIdx.z * p.loop_major; 138 | 139 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 140 | major_idx_base >= p.major_dim) { 141 | return; 142 | } 143 | 144 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 145 | tap_idx += blockDim.x) { 146 | int ky = tap_idx / kernel_w; 147 | int kx = tap_idx - ky * kernel_w; 148 | scalar_t v = 0.0; 149 | 150 | if (kx < p.kernel_w & ky < p.kernel_h) { 151 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 152 | } 153 | 154 | sk[ky][kx] = v; 155 | } 156 | 157 | for (int loop_major = 0, major_idx = major_idx_base; 158 | loop_major < p.loop_major & major_idx < p.major_dim; 159 | loop_major++, major_idx++) { 160 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 161 | loop_x < p.loop_x & tile_out_x < p.out_w; 162 | loop_x++, tile_out_x += tile_out_w) { 163 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 164 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 165 | int tile_in_x = floor_div(tile_mid_x, up_x); 166 | int tile_in_y = floor_div(tile_mid_y, up_y); 167 | 168 | __syncthreads(); 169 | 170 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 171 | in_idx += blockDim.x) { 172 | int rel_in_y = in_idx / tile_in_w; 173 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 174 | int in_x = rel_in_x + tile_in_x; 175 | int in_y = rel_in_y + tile_in_y; 176 | 177 | scalar_t v = 0.0; 178 | 179 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 180 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 181 | p.minor_dim + 182 | minor_idx]; 183 | } 184 | 185 | sx[rel_in_y][rel_in_x] = v; 186 | } 187 | 188 | __syncthreads(); 189 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 190 | out_idx += blockDim.x) { 191 | int rel_out_y = out_idx / tile_out_w; 192 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 193 | int out_x = rel_out_x + tile_out_x; 194 | int out_y = rel_out_y + tile_out_y; 195 | 196 | int mid_x = tile_mid_x + rel_out_x * down_x; 197 | int mid_y = tile_mid_y + rel_out_y * down_y; 198 | int in_x = floor_div(mid_x, up_x); 199 | int in_y = floor_div(mid_y, up_y); 200 | int rel_in_x = in_x - tile_in_x; 201 | int rel_in_y = in_y - tile_in_y; 202 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 203 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 204 | 205 | scalar_t v = 0.0; 206 | 207 | #pragma unroll 208 | for (int y = 0; y < kernel_h / up_y; y++) 209 | #pragma unroll 210 | for (int x = 0; x < kernel_w / up_x; x++) 211 | v += sx[rel_in_y + y][rel_in_x + x] * 212 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 213 | 214 | if (out_x < p.out_w & out_y < p.out_h) { 215 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 216 | minor_idx] = v; 217 | } 218 | } 219 | } 220 | } 221 | } 222 | 223 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 224 | const torch::Tensor &kernel, int up_x, int up_y, 225 | int down_x, int down_y, int pad_x0, int pad_x1, 226 | int pad_y0, int pad_y1) { 227 | int curDevice = -1; 228 | cudaGetDevice(&curDevice); 229 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 230 | 231 | UpFirDn2DKernelParams p; 232 | 233 | auto x = input.contiguous(); 234 | auto k = kernel.contiguous(); 235 | 236 | p.major_dim = x.size(0); 237 | p.in_h = x.size(1); 238 | p.in_w = x.size(2); 239 | p.minor_dim = x.size(3); 240 | p.kernel_h = k.size(0); 241 | p.kernel_w = k.size(1); 242 | p.up_x = up_x; 243 | p.up_y = up_y; 244 | p.down_x = down_x; 245 | p.down_y = down_y; 246 | p.pad_x0 = pad_x0; 247 | 
p.pad_x1 = pad_x1; 248 | p.pad_y0 = pad_y0; 249 | p.pad_y1 = pad_y1; 250 | 251 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 252 | p.down_y; 253 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 254 | p.down_x; 255 | 256 | auto out = 257 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 258 | 259 | int mode = -1; 260 | 261 | int tile_out_h = -1; 262 | int tile_out_w = -1; 263 | 264 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 1; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 3 && p.kernel_w <= 3) { 273 | mode = 2; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 3; 281 | tile_out_h = 16; 282 | tile_out_w = 64; 283 | } 284 | 285 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 4; 288 | tile_out_h = 16; 289 | tile_out_w = 64; 290 | } 291 | 292 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 293 | p.kernel_h <= 4 && p.kernel_w <= 4) { 294 | mode = 5; 295 | tile_out_h = 8; 296 | tile_out_w = 32; 297 | } 298 | 299 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 300 | p.kernel_h <= 2 && p.kernel_w <= 2) { 301 | mode = 6; 302 | tile_out_h = 8; 303 | tile_out_w = 32; 304 | } 305 | 306 | dim3 block_size; 307 | dim3 grid_size; 308 | 309 | if (tile_out_h > 0 && tile_out_w > 0) { 310 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 311 | p.loop_x = 1; 312 | block_size = dim3(32 * 8, 1, 1); 313 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 314 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 315 | (p.major_dim - 1) / p.loop_major + 1); 316 | } else { 317 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 318 | p.loop_x = 4; 319 | block_size = dim3(4, 32, 1); 320 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 321 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 322 | (p.major_dim - 1) / p.loop_major + 1); 323 | } 324 | 325 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 326 | switch (mode) { 327 | case 1: 328 | upfirdn2d_kernel 329 | <<>>(out.data_ptr(), 330 | x.data_ptr(), 331 | k.data_ptr(), p); 332 | 333 | break; 334 | 335 | case 2: 336 | upfirdn2d_kernel 337 | <<>>(out.data_ptr(), 338 | x.data_ptr(), 339 | k.data_ptr(), p); 340 | 341 | break; 342 | 343 | case 3: 344 | upfirdn2d_kernel 345 | <<>>(out.data_ptr(), 346 | x.data_ptr(), 347 | k.data_ptr(), p); 348 | 349 | break; 350 | 351 | case 4: 352 | upfirdn2d_kernel 353 | <<>>(out.data_ptr(), 354 | x.data_ptr(), 355 | k.data_ptr(), p); 356 | 357 | break; 358 | 359 | case 5: 360 | upfirdn2d_kernel 361 | <<>>(out.data_ptr(), 362 | x.data_ptr(), 363 | k.data_ptr(), p); 364 | 365 | break; 366 | 367 | case 6: 368 | upfirdn2d_kernel 369 | <<>>(out.data_ptr(), 370 | x.data_ptr(), 371 | k.data_ptr(), p); 372 | 373 | break; 374 | 375 | default: 376 | upfirdn2d_kernel_large<<>>( 377 | out.data_ptr(), x.data_ptr(), 378 | k.data_ptr(), p); 379 | } 380 | }); 381 | 382 | return out; 383 | } -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Pseudo 
Numerical Methods for Diffusion Models on Manifolds (PNDM, PLMS | ICLR2022)
2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/pseudo-numerical-methods-for-diffusion-models-1/image-generation-on-celeba-64x64)](https://paperswithcode.com/sota/image-generation-on-celeba-64x64?p=pseudo-numerical-methods-for-diffusion-models-1)
3 |
4 | This repo is the official PyTorch implementation for the paper [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://openreview.net/forum?id=PlKWVd2yBkY) (PNDM, PLMS | ICLR2022)
5 |
6 | by [Luping Liu](https://luping-liu.github.io/), [Yi Ren](https://rayeren.github.io/), Zhijie Lin, Zhou Zhao (Zhejiang University).
7 |
8 |
9 | ## What does this code do?
10 | This code is not only the official implementation for PNDM, but also a generic framework for DDIM-like models, including:
11 | - [x] [Pseudo Numerical Methods for Diffusion Models on Manifolds (PNDM)](https://openreview.net/forum?id=PlKWVd2yBkY)
12 | - [x] [Denoising Diffusion Implicit Models (DDIM)](https://arxiv.org/abs/2010.02502)
13 | - [x] [Score-Based Generative Modeling through Stochastic Differential Equations (PF)](https://arxiv.org/abs/2011.13456)
14 | - [x] [Improved Denoising Diffusion Probabilistic Models (iDDPM)](https://arxiv.org/abs/2102.09672)
15 |
16 | ### Structure
17 | This code contains three main objects: method, schedule and model. The following table shows the options
18 | supported by this code and the role of each object.
19 |
20 | | Object   | Option                        | Role                                          |
21 | |----------|-------------------------------|-----------------------------------------------|
22 | | method   | DDIM, S-PNDM, F-PNDM, FON, PF | the numerical method used to generate samples |
23 | | schedule | linear, quad, cosine          | the schedule of adding noise to images        |
24 | | model    | DDIM, iDDPM, PF, PF_deep      | the neural network used to fit noise          |
25 |
26 | All of them can be combined at will, so this code provides at least 5x3x4=60 combinations for generating samples.
27 |
28 |
29 | ## Integration with 🤗 Diffusers library
30 |
31 | PNDM is now also available in 🧨 Diffusers and accessible via the [PNDMPipeline](https://huggingface.co/docs/diffusers/api/pipelines/pndm).
32 | Diffusers allows you to test PNDM in PyTorch in just a couple of lines of code.
33 | 34 | You can install diffusers as follows: 35 | 36 | ``` 37 | pip install diffusers torch accelerate 38 | ``` 39 | 40 | And then try out the sampler/scheduler with just a couple lines of code: 41 | 42 | ```python 43 | from diffusers import PNDMPipeline 44 | 45 | model_id = "google/ddpm-cifar10-32" 46 | 47 | # load model and scheduler 48 | pndm = PNDMPipeline.from_pretrained(model_id) 49 | 50 | # run pipeline in inference (sample random noise and denoise) 51 | image = pndm(num_inference_steps=50).images[0] 52 | 53 | # save image 54 | image.save("pndm_generated_image.png") 55 | ``` 56 | 57 | The PNDM scheduler can also be used with more powerful diffusion models such as [Stable Diffusion](https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#stable-diffusion-pipelines) 58 | 59 | You simply need to [accept the license on the Hub](https://huggingface.co/runwayml/stable-diffusion-v1-5), login with `huggingface-cli login` and install transformers: 60 | 61 | ``` 62 | pip install transformers 63 | ``` 64 | 65 | Then you can run: 66 | 67 | ```python 68 | from diffusers import StableDiffusionPipeline, PNDMScheduler 69 | 70 | pndm = PNDMScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") 71 | pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=pndm) 72 | 73 | image = pipeline("An astronaut riding a horse.").images[0] 74 | image.save("astronaut_riding_a_horse.png") 75 | ``` 76 | 77 | 78 | ## How to run the code 79 | 80 | ### Dependencies 81 | Run the following to install a subset of necessary python packages for our code. 82 | ```bash 83 | pip install -r requirements.txt 84 | ``` 85 | Tip: mpi4py can make the generation process faster using multi-gpus. It is not necessary and can be removed freely. 86 | 87 | ### Usage 88 | Evaluate our models through main.py. 89 | ```bash 90 | python main.py --runner sample --method F-PNDM --sample_speed 50 --device cuda --config ddim_cifar10.yml --image_path temp/results --model_path temp/models/ddim/ema_cifar10.ckpt 91 | ``` 92 | - runner (train|sample): choose the mode of runner 93 | - method (DDIM|FON|S-PNDM|F-PNDM|PF): choose the numerical methods 94 | - sample_speed: control the total generation step 95 | - device (cpu|cuda:0): choose the device to use 96 | - config: choose the config file 97 | - image_path: choose the path to save images 98 | - model_path: choose the path of model 99 | 100 | Train our models through main.py. 101 | ```bash 102 | python main.py --runner train --device cuda --config ddim_cifar10.yml --train_path temp/train 103 | ``` 104 | - train_path: choose the path to save training status 105 | 106 | ### Checkpoints & statistics 107 | All checkpoints of models and precalculated statistics for FID are provided in this [Google Drive](https://drive.google.com/drive/folders/1leEpziaPdYlshzB4QALY5pL7Snw8XM28?usp=share_link). 108 | 109 | 110 | ## References 111 | If you find the code useful for your research, please consider citing: 112 | ```bib 113 | @inproceedings{liu2022pseudo, 114 | title={Pseudo Numerical Methods for Diffusion Models on Manifolds}, 115 | author={Luping Liu and Yi Ren and Zhijie Lin and Zhou Zhao}, 116 | booktitle={International Conference on Learning Representations}, 117 | year={2022}, 118 | url={https://openreview.net/forum?id=PlKWVd2yBkY} 119 | } 120 | ``` 121 | This work is built upon some previous papers which might also interest you: 122 | - Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. 
Advances in Neural Information Processing Systems 33 (2020): 6840-6851. 123 | - Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising Diffusion Implicit Models. International Conference on Learning Representations. 2021. 124 | - Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. Score-Based Generative Modeling through Stochastic Differential Equations. International Conference on Learning Representations. 2021. 125 | 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | tensorboard 3 | pytorch-fid 4 | PyYAML 5 | mpi4py 6 | scipy 7 | numpy 8 | lmdb 9 | tqdm 10 | -------------------------------------------------------------------------------- /runner/method.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Luping Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import sys 16 | import copy 17 | import torch as th 18 | 19 | 20 | def choose_method(name): 21 | if name == 'DDIM': 22 | return gen_order_1 23 | elif name == 'S-PNDM': 24 | return gen_order_2 25 | elif name == 'F-PNDM': 26 | return gen_order_4 27 | elif name == 'FON': 28 | return gen_fon 29 | elif name == 'PF': 30 | return gen_pflow 31 | else: 32 | return None 33 | 34 | 35 | def gen_pflow(img, t, t_next, model, betas, total_step): 36 | n = img.shape[0] 37 | beta_0, beta_1 = betas[0], betas[-1] 38 | 39 | t_start = th.ones(n, device=img.device) * t 40 | beta_t = (beta_0 + t_start * (beta_1 - beta_0)) * total_step 41 | 42 | log_mean_coeff = (-0.25 * t_start ** 2 * (beta_1 - beta_0) - 0.5 * t_start * beta_0) * total_step 43 | std = th.sqrt(1. - th.exp(2. 
* log_mean_coeff)) 44 | 45 | # drift, diffusion -> f(x,t), g(t) 46 | drift, diffusion = -0.5 * beta_t.view(-1, 1, 1, 1) * img, th.sqrt(beta_t) 47 | score = - model(img, t_start * (total_step - 1)) / std.view(-1, 1, 1, 1) # score -> noise 48 | drift = drift - diffusion.view(-1, 1, 1, 1) ** 2 * score * 0.5 # drift -> dx/dt 49 | 50 | return drift 51 | 52 | 53 | def gen_fon(img, t, t_next, model, alphas_cump, ets): 54 | t_list = [t, (t + t_next) / 2.0, t_next] 55 | if len(ets) > 2: 56 | noise = model(img, t) 57 | img_next = transfer(img, t, t-1, noise, alphas_cump) 58 | delta1 = img_next - img 59 | ets.append(delta1) 60 | delta = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) 61 | else: 62 | noise = model(img, t_list[0]) 63 | img_ = transfer(img, t, t - 1, noise, alphas_cump) 64 | delta_1 = img_ - img 65 | ets.append(delta_1) 66 | 67 | img_2 = img + delta_1 * (t - t_next).view(-1, 1, 1, 1) / 2.0 68 | noise = model(img_2, t_list[1]) 69 | img_ = transfer(img, t, t - 1, noise, alphas_cump) 70 | delta_2 = img_ - img 71 | 72 | img_3 = img + delta_2 * (t - t_next).view(-1, 1, 1, 1) / 2.0 73 | noise = model(img_3, t_list[1]) 74 | img_ = transfer(img, t, t - 1, noise, alphas_cump) 75 | delta_3 = img_ - img 76 | 77 | img_4 = img + delta_3 * (t - t_next).view(-1, 1, 1, 1) 78 | noise = model(img_4, t_list[2]) 79 | img_ = transfer(img, t, t - 1, noise, alphas_cump) 80 | delta_4 = img_ - img 81 | delta = (1 / 6.0) * (delta_1 + 2*delta_2 + 2*delta_3 + delta_4) 82 | 83 | img_next = img + delta * (t - t_next).view(-1, 1, 1, 1) 84 | return img_next 85 | 86 | 87 | def gen_order_4(img, t, t_next, model, alphas_cump, ets): 88 | t_list = [t, (t+t_next)/2, t_next] 89 | if len(ets) > 2: 90 | noise_ = model(img, t) 91 | ets.append(noise_) 92 | noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4]) 93 | else: 94 | noise = runge_kutta(img, t_list, model, alphas_cump, ets) 95 | 96 | img_next = transfer(img, t, t_next, noise, alphas_cump) 97 | return img_next 98 | 99 | 100 | def runge_kutta(x, t_list, model, alphas_cump, ets): 101 | e_1 = model(x, t_list[0]) 102 | ets.append(e_1) 103 | x_2 = transfer(x, t_list[0], t_list[1], e_1, alphas_cump) 104 | 105 | e_2 = model(x_2, t_list[1]) 106 | x_3 = transfer(x, t_list[0], t_list[1], e_2, alphas_cump) 107 | 108 | e_3 = model(x_3, t_list[1]) 109 | x_4 = transfer(x, t_list[0], t_list[2], e_3, alphas_cump) 110 | 111 | e_4 = model(x_4, t_list[2]) 112 | et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4) 113 | 114 | return et 115 | 116 | 117 | def gen_order_2(img, t, t_next, model, alphas_cump, ets): 118 | if len(ets) > 0: 119 | noise_ = model(img, t) 120 | ets.append(noise_) 121 | noise = 0.5 * (3 * ets[-1] - ets[-2]) 122 | else: 123 | noise = improved_eular(img, t, t_next, model, alphas_cump, ets) 124 | 125 | img_next = transfer(img, t, t_next, noise, alphas_cump) 126 | return img_next 127 | 128 | 129 | def improved_eular(x, t, t_next, model, alphas_cump, ets): 130 | e_1 = model(x, t) 131 | ets.append(e_1) 132 | x_2 = transfer(x, t, t_next, e_1, alphas_cump) 133 | 134 | e_2 = model(x_2, t_next) 135 | et = (e_1 + e_2) / 2 136 | # x_next = transfer(x, t, t_next, et, alphas_cump) 137 | 138 | return et 139 | 140 | 141 | def gen_order_1(img, t, t_next, model, alphas_cump, ets): 142 | noise = model(img, t) 143 | ets.append(noise) 144 | img_next = transfer(img, t, t_next, noise, alphas_cump) 145 | return img_next 146 | 147 | 148 | def transfer(x, t, t_next, et, alphas_cump): 149 | at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1) 150 | at_next = 
alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1) 151 | 152 | x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - \ 153 | 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et) 154 | 155 | x_next = x + x_delta 156 | return x_next 157 | 158 | 159 | def transfer_dev(x, t, t_next, et, alphas_cump): 160 | at = alphas_cump[t.long()+1].view(-1, 1, 1, 1) 161 | at_next = alphas_cump[t_next.long()+1].view(-1, 1, 1, 1) 162 | 163 | x_start = th.sqrt(1.0 / at) * x - th.sqrt(1.0 / at - 1) * et 164 | x_start = x_start.clamp(-1.0, 1.0) 165 | 166 | x_next = x_start * th.sqrt(at_next) + th.sqrt(1 - at_next) * et 167 | 168 | return x_next 169 | -------------------------------------------------------------------------------- /runner/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Luping Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import sys 17 | import time 18 | import torch as th 19 | import numpy as np 20 | import torch.optim as optimi 21 | import torch.utils.data as data 22 | import torchvision.utils as tvu 23 | import torch.utils.tensorboard as tb 24 | from scipy import integrate 25 | # from torchdiffeq import odeint 26 | from tqdm.auto import tqdm 27 | 28 | from dataset import get_dataset, inverse_data_transform 29 | from model.ema import EMAHelper 30 | 31 | 32 | def get_optim(params, config): 33 | if config['optimizer'] == 'adam': 34 | optim = optimi.Adam(params, lr=config['lr'], weight_decay=config['weight_decay'], 35 | betas=(config['beta1'], 0.999), amsgrad=config['amsgrad'], 36 | eps=config['eps']) 37 | elif config['optimizer'] == 'sgd': 38 | optim = optimi.SGD(params, lr=config['lr'], momentum=0.9) 39 | else: 40 | optim = None 41 | 42 | return optim 43 | 44 | 45 | class Runner(object): 46 | def __init__(self, args, config, schedule, model): 47 | self.args = args 48 | self.config = config 49 | self.diffusion_step = config['Schedule']['diffusion_step'] 50 | self.sample_speed = args.sample_speed 51 | self.device = th.device(args.device) 52 | 53 | self.schedule = schedule 54 | self.model = model 55 | 56 | def train(self): 57 | schedule = self.schedule 58 | model = self.model 59 | model = th.nn.DataParallel(model) 60 | 61 | optim = get_optim(model.parameters(), self.config['Optim']) 62 | 63 | config = self.config['Dataset'] 64 | dataset, test_dataset = get_dataset(self.args, config) 65 | train_loader = data.DataLoader(dataset, batch_size=config['batch_size'], shuffle=True, 66 | num_workers=config['num_workers']) 67 | 68 | config = self.config['Train'] 69 | if config['ema']: 70 | ema = EMAHelper(mu=config['ema_rate']) 71 | ema.register(model) 72 | else: 73 | ema = None 74 | 75 | tb_logger = tb.SummaryWriter(f'temp/tensorboard/{time.strftime("%m%d-%H%M")}') 76 | epoch, step = 0, 0 77 | 78 | if self.args.restart: 79 | train_state = th.load(os.path.join(self.args.train_path, 'train.ckpt'), map_location=self.device) 
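# train.ckpt is the list [model_state_dict, optim_state_dict, epoch, step] saved every 5000 steps below; the following lines restore each piece before training resumes.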
80 | model.load_state_dict(train_state[0]) 81 | optim.load_state_dict(train_state[1]) 82 | epoch, step = train_state[2:4] 83 | if ema is not None: 84 | ema_state = th.load(os.path.join(self.args.train_path, 'ema.ckpt'), map_location=self.device) 85 | ema.load_state_dict(ema_state) 86 | 87 | for epoch in range(epoch, config['epoch']): 88 | for i, (img, y) in enumerate(train_loader): 89 | n = img.shape[0] 90 | model.train() 91 | step += 1 92 | t = th.randint(low=0, high=self.diffusion_step, size=(n // 2 + 1,)) 93 | t = th.cat([t, self.diffusion_step - t - 1], dim=0)[:n].to(self.device) 94 | img = img.to(self.device) * 2.0 - 1.0 95 | 96 | img_n, noise = schedule.diffusion(img, t) 97 | noise_p = model(img_n, t) 98 | 99 | if config['loss_type'] == 'linear': 100 | loss = (noise_p - noise).abs().sum(dim=(1, 2, 3)).mean(dim=0) 101 | elif config['loss_type'] == 'square': 102 | loss = (noise_p - noise).square().sum(dim=(1, 2, 3)).mean(dim=0) 103 | else: 104 | loss = None 105 | 106 | optim.zero_grad() 107 | loss.backward() 108 | try: 109 | th.nn.utils.clip_grad_norm_(model.parameters(), self.config['Optim']['grad_clip']) 110 | except Exception: 111 | pass 112 | optim.step() 113 | 114 | if ema is not None: 115 | ema.update(model) 116 | 117 | if step % 10 == 0: 118 | tb_logger.add_scalar('loss', loss, global_step=step) 119 | if step % 50 == 0: 120 | print(step, loss.item()) 121 | if step % 500 == 0: 122 | config = self.config['Dataset'] 123 | model.eval() 124 | skip = self.diffusion_step // self.sample_speed 125 | seq = range(0, self.diffusion_step, skip) 126 | noise = th.randn(16, config['channels'], config['image_size'], 127 | config['image_size'], device=self.device) 128 | img = self.sample_image(noise, seq, model) 129 | img = th.clamp(img * 0.5 + 0.5, 0.0, 1.0) 130 | tb_logger.add_images('sample', img, global_step=step) 131 | config = self.config['Train'] 132 | model.train() 133 | 134 | if step % 5000 == 0: 135 | train_state = [model.state_dict(), optim.state_dict(), epoch, step] 136 | th.save(train_state, os.path.join(self.args.train_path, 'train.ckpt')) 137 | if ema is not None: 138 | th.save(ema.state_dict(), os.path.join(self.args.train_path, 'ema.ckpt')) 139 | 140 | def sample_fid(self): 141 | config = self.config['Sample'] 142 | mpi_rank = 0 143 | if config['mpi4py']: 144 | from mpi4py import MPI 145 | comm = MPI.COMM_WORLD 146 | mpi_rank = comm.Get_rank() 147 | 148 | model = self.model 149 | device = self.device 150 | pflow = True if self.args.method == 'PF' else False 151 | 152 | model.load_state_dict(th.load(self.args.model_path, map_location=device), strict=True) 153 | model.eval() 154 | 155 | n = config['batch_size'] 156 | total_num = config['total_num'] 157 | 158 | skip = self.diffusion_step // self.sample_speed 159 | seq = range(0, self.diffusion_step, skip) 160 | seq_next = [-1] + list(seq[:-1]) 161 | image_num = 0 162 | 163 | config = self.config['Dataset'] 164 | if mpi_rank == 0: 165 | my_iter = tqdm(range(total_num // n + 1), ncols=120) 166 | else: 167 | my_iter = range(total_num // n + 1) 168 | 169 | for _ in my_iter: 170 | noise = th.randn(n, config['channels'], config['image_size'], 171 | config['image_size'], device=self.device) 172 | 173 | img = self.sample_image(noise, seq, model, pflow) 174 | 175 | img = inverse_data_transform(config, img) 176 | for i in range(img.shape[0]): 177 | if image_num+i > total_num: 178 | break 179 | tvu.save_image(img[i], os.path.join(self.args.image_path, f"{mpi_rank}-{image_num+i}.png")) 180 | 181 | image_num += n 182 | 183 | def 
sample_image(self, noise, seq, model, pflow=False): 184 | with th.no_grad(): 185 | if pflow: 186 | shape = noise.shape 187 | device = self.device 188 | tol = 1e-5 if self.sample_speed > 1 else self.sample_speed 189 | 190 | def drift_func(t, x): 191 | x = th.from_numpy(x.reshape(shape)).to(device).type(th.float32) 192 | drift = self.schedule.denoising(x, None, t, model, pflow=pflow) 193 | drift = drift.cpu().numpy().reshape((-1,)) 194 | return drift 195 | 196 | solution = integrate.solve_ivp(drift_func, (1, 1e-3), noise.cpu().numpy().reshape((-1,)), 197 | rtol=tol, atol=tol, method='RK45') 198 | img = th.tensor(solution.y[:, -1]).reshape(shape).type(th.float32) 199 | 200 | else: 201 | imgs = [noise] 202 | seq_next = [-1] + list(seq[:-1]) 203 | 204 | start = True 205 | n = noise.shape[0] 206 | 207 | for i, j in zip(reversed(seq), reversed(seq_next)): 208 | t = (th.ones(n) * i).to(self.device) 209 | t_next = (th.ones(n) * j).to(self.device) 210 | 211 | img_t = imgs[-1].to(self.device) 212 | img_next = self.schedule.denoising(img_t, t_next, t, model, start, pflow) 213 | start = False 214 | 215 | imgs.append(img_next.to('cpu')) 216 | 217 | img = imgs[-1] 218 | 219 | return img 220 | -------------------------------------------------------------------------------- /runner/schedule.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Luping Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
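# This module builds the beta schedule (linear, quad, or cosine) and the cumulative product of alphas, applies the forward diffusion q(x_t | x_0), and dispatches each reverse denoising step to the numerical method chosen in runner/method.py.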
14 | 15 | import sys 16 | import math 17 | import torch as th 18 | import torch.nn as nn 19 | import numpy as np 20 | 21 | import runner.method as mtd 22 | 23 | 24 | def get_schedule(args, config): 25 | if config['type'] == "quad": 26 | betas = (np.linspace(config['beta_start'] ** 0.5, config['beta_end'] ** 0.5, config['diffusion_step'], dtype=np.float64) ** 2) 27 | elif config['type'] == "linear": 28 | betas = np.linspace(config['beta_start'], config['beta_end'], config['diffusion_step'], dtype=np.float64) 29 | elif config['type'] == 'cosine': 30 | betas = betas_for_alpha_bar(config['diffusion_step'], lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2) 31 | else: 32 | betas = None 33 | 34 | betas = th.from_numpy(betas).float() 35 | alphas = 1.0 - betas 36 | alphas_cump = alphas.cumprod(dim=0) 37 | 38 | return betas, alphas_cump 39 | 40 | 41 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 42 | betas = [] 43 | for i in range(num_diffusion_timesteps): 44 | t1 = i / num_diffusion_timesteps 45 | t2 = (i + 1) / num_diffusion_timesteps 46 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 47 | return np.array(betas) 48 | 49 | 50 | class Schedule(object): 51 | def __init__(self, args, config): 52 | device = th.device(args.device) 53 | betas, alphas_cump = get_schedule(args, config) 54 | 55 | self.betas, self.alphas_cump = betas.to(device), alphas_cump.to(device) 56 | self.alphas_cump_pre = th.cat([th.ones(1).to(device), self.alphas_cump[:-1]], dim=0) 57 | self.total_step = config['diffusion_step'] 58 | 59 | self.method = mtd.choose_method(args.method) # add pflow 60 | self.ets = None 61 | 62 | def diffusion(self, img, t_end, t_start=0, noise=None): 63 | if noise is None: 64 | noise = th.randn_like(img) 65 | alpha = self.alphas_cump.index_select(0, t_end).view(-1, 1, 1, 1) 66 | img_n = img * alpha.sqrt() + noise * (1 - alpha).sqrt() 67 | 68 | return img_n, noise 69 | 70 | def denoising(self, img_n, t_end, t_start, model, first_step=False, pflow=False): 71 | if pflow: 72 | drift = self.method(img_n, t_start, t_end, model, self.betas, self.total_step) 73 | 74 | return drift 75 | else: 76 | if first_step: 77 | self.ets = [] 78 | img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets) 79 | 80 | return img_next 81 | 82 | -------------------------------------------------------------------------------- /tool/dataset.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luping-liu/PNDM/e771949c344875ebecda46967de68fc86e970569/tool/dataset.sh -------------------------------------------------------------------------------- /tool/fid.sh: -------------------------------------------------------------------------------- 1 | for method in DDIM S-PNDM F-PNDM FON PF; 2 | do 3 | echo $method 4 | mkdir -p ./temp/sample 5 | mpiexec -np 4 python main.py --runner sample --method $method --config pf_deep_cifar10.yml --model_path temp/models/pf_deep_cifar10.ckpt 6 | pytorch-fid ./temp/sample ~/llp/Datasets/fid_cifar10_train.npz --device cuda:3 7 | mv ./temp/sample ./temp/pf_deep/$method 8 | done --------------------------------------------------------------------------------