├── .gitignore ├── models ├── __init__.py ├── flowpp │ ├── dequantization_cifar10.py │ ├── dequantization_imagenet64.py │ ├── dequantization_imagenet32.py │ ├── logistic.py │ ├── modules_imagenet64.py │ ├── modules_imagenet32.py │ └── modules_cifar10.py ├── normalization.py ├── layerspp.py ├── ncsnpp.py ├── utils.py ├── up_or_down_sampling.py └── layers.py ├── configs ├── vp │ ├── cifar10_ddpmpp_continuous.py │ ├── imagenet32_ddpmpp_continuous.py │ ├── cifar10_ddpmpp_deep_continuous.py │ └── imagenet32_ddpmpp_deep_continuous.py ├── subvp │ ├── cifar10_ddpmpp_continuous.py │ ├── imagenet32_ddpmpp_continuous.py │ ├── cifar10_ddpmpp_deep_continuous.py │ └── imagenet32_ddpmpp_deep_continuous.py ├── default_imagenet32_configs.py └── default_cifar10_configs.py ├── main.py ├── utils.py ├── likelihood.py ├── evaluation.py ├── requirements.txt ├── README.md ├── datasets.py ├── bound_likelihood.py ├── sde_lib.py └── losses.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | .idea/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | .eggs/ 12 | 13 | # PyPI distribution artifacts. 14 | build/ 15 | dist/ 16 | 17 | # Tests 18 | .pytest_cache/ 19 | 20 | # Other 21 | *.DS_Store 22 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'ode' 33 | sampling.smallest_time = 1e-3 34 | 35 | # data 36 | data = config.data 37 | data.centered = True 38 | 39 | # model 40 | model = config.model 41 | model.name = 'ncsnpp' 42 | model.scale_by_sigma = False 43 | model.ema_rate = 0.9999 44 | model.normalization = 'GroupNorm' 45 | model.nonlinearity = 'swish' 46 | model.nf = 128 47 | model.ch_mult = (1, 2, 2, 2) 48 | model.num_res_blocks = 4 49 | model.attn_resolutions = (16,) 50 | model.resamp_with_conv = True 51 | model.conditional = True 52 | model.fir = False 53 | model.fir_kernel = [1, 3, 3, 1] 54 | model.skip_rescale = True 55 | model.resblock_type = 'biggan' 56 | model.progressive = 'none' 57 | model.progressive_input = 'none' 58 | model.progressive_combine = 'sum' 59 | model.attention_type = 'ddpm' 60 | model.init_scale = 0. 61 | model.embedding_type = 'positional' 62 | model.fourier_scale = 16 63 | model.conv_size = 3 64 | 65 | return config 66 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.smallest_time = 1e-2 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'ode' 34 | sampling.smallest_time = 1e-2 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 
62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /configs/subvp/imagenet32_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_imagenet32_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.smallest_time = 1e-2 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'ode' 34 | sampling.smallest_time = 1e-2 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | 
model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /configs/vp/imagenet32_ddpmpp_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_imagenet32_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.smallest_time = 5e-5 if training.likelihood_weighting and training.importance_weighting else 1e-5 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'ode' 34 | sampling.smallest_time = 1e-3 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 62 | model.embedding_type = 'positional' 63 | model.fourier_scale = 16 64 | model.conv_size = 3 65 | 66 | return config 67 | -------------------------------------------------------------------------------- /configs/vp/cifar10_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'ode' 34 | sampling.smallest_time = 1e-3 35 | 36 | # eval 37 | evaluate = config.eval 38 | evaluate.begin_ckpt = 19 39 | evaluate.end_ckpt = 19 40 | evaluate.ckpt_id = 19 41 | 42 | # data 43 | data = config.data 44 | data.centered = True 45 | 46 | # model 47 | model = config.model 48 | model.name = 'ncsnpp' 49 | model.scale_by_sigma = False 50 | model.ema_rate = 0.9999 51 | model.normalization = 'GroupNorm' 52 | model.nonlinearity = 'swish' 53 | model.nf = 128 54 | model.ch_mult = (1, 2, 2, 2) 55 | model.num_res_blocks = 8 56 | model.attn_resolutions = (16,) 57 | model.resamp_with_conv = True 58 | model.conditional = True 59 | model.fir = False 60 | model.fir_kernel = [1, 3, 3, 1] 61 | model.skip_rescale = True 62 | model.resblock_type = 'biggan' 63 | model.progressive = 'none' 64 | model.progressive_input = 'none' 65 | model.progressive_combine = 'sum' 66 | model.attention_type = 'ddpm' 67 | model.init_scale = 0. 
68 | model.embedding_type = 'positional' 69 | model.fourier_scale = 16 70 | model.conv_size = 3 71 | 72 | return config 73 | -------------------------------------------------------------------------------- /configs/subvp/cifar10_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | training.smallest_time = 1e-2 31 | 32 | # sampling 33 | sampling = config.sampling 34 | sampling.method = 'ode' 35 | sampling.smallest_time = 1e-2 36 | 37 | # eval 38 | evaluate = config.eval 39 | evaluate.begin_ckpt = 19 40 | evaluate.end_ckpt = 19 41 | evaluate.ckpt_id = 19 42 | 43 | # data 44 | data = config.data 45 | data.centered = True 46 | 47 | # model 48 | model = config.model 49 | model.name = 'ncsnpp' 50 | model.scale_by_sigma = False 51 | model.ema_rate = 0.9999 52 | model.normalization = 'GroupNorm' 53 | model.nonlinearity = 'swish' 54 | model.nf = 128 55 | 
model.ch_mult = (1, 2, 2, 2) 56 | model.num_res_blocks = 8 57 | model.attn_resolutions = (16,) 58 | model.resamp_with_conv = True 59 | model.conditional = True 60 | model.fir = False 61 | model.fir_kernel = [1, 3, 3, 1] 62 | model.skip_rescale = True 63 | model.resblock_type = 'biggan' 64 | model.progressive = 'none' 65 | model.progressive_input = 'none' 66 | model.progressive_combine = 'sum' 67 | model.attention_type = 'ddpm' 68 | model.init_scale = 0. 69 | model.embedding_type = 'positional' 70 | model.fourier_scale = 16 71 | model.conv_size = 3 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /configs/subvp/imagenet32_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_imagenet32_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'subvpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | training.smallest_time = 1e-2 31 | 32 | # sampling 33 | sampling = config.sampling 34 | sampling.method = 'ode' 35 | sampling.smallest_time = 1e-2 36 | 37 | # eval 38 | evaluate = config.eval 39 | evaluate.begin_ckpt = 19 40 | evaluate.end_ckpt = 19 41 | evaluate.ckpt_id = 19 42 | 43 | # data 44 | data = config.data 45 | data.centered = True 46 | 47 | # model 48 | model = config.model 49 | model.name = 'ncsnpp' 50 | model.scale_by_sigma = False 51 | model.ema_rate = 0.9999 52 | model.normalization = 'GroupNorm' 53 | model.nonlinearity = 'swish' 54 | model.nf = 128 55 | model.ch_mult = (1, 2, 2, 2) 56 | model.num_res_blocks = 8 57 | model.attn_resolutions = (16,) 58 | model.resamp_with_conv = True 59 | model.conditional = True 60 | model.fir = False 61 | model.fir_kernel = [1, 3, 3, 1] 62 | model.skip_rescale = True 63 | model.resblock_type = 'biggan' 64 | model.progressive = 'none' 65 | model.progressive_input = 'none' 66 | model.progressive_combine = 'sum' 67 | model.attention_type = 'ddpm' 68 | model.init_scale = 0. 69 | model.embedding_type = 'positional' 70 | model.fourier_scale = 16 71 | model.conv_size = 3 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /configs/vp/imagenet32_ddpmpp_deep_continuous.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_imagenet32_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.n_iters = 950001 30 | training.smallest_time = 5e-5 if training.likelihood_weighting and training.importance_weighting else 1e-5 31 | 32 | # sampling 33 | sampling = config.sampling 34 | sampling.method = 'ode' 35 | sampling.smallest_time = 1e-3 36 | 37 | # eval 38 | evaluate = config.eval 39 | evaluate.begin_ckpt = 19 40 | evaluate.end_ckpt = 19 41 | evaluate.ckpt_id = 19 42 | 43 | # data 44 | data = config.data 45 | data.centered = True 46 | 47 | # model 48 | model = config.model 49 | model.name = 'ncsnpp' 50 | model.scale_by_sigma = False 51 | model.ema_rate = 0.9999 52 | model.normalization = 'GroupNorm' 53 | model.nonlinearity = 'swish' 54 | model.nf = 128 55 | model.ch_mult = (1, 2, 2, 2) 56 | model.num_res_blocks = 8 57 | model.attn_resolutions = (16,) 58 | model.resamp_with_conv = True 59 | model.conditional = True 60 | model.fir = False 61 | model.fir_kernel = [1, 3, 3, 1] 62 | model.skip_rescale = True 63 | model.resblock_type = 'biggan' 64 | model.progressive = 'none' 65 | 
model.progressive_input = 'none' 66 | model.progressive_combine = 'sum' 67 | model.attention_type = 'ddpm' 68 | model.init_scale = 0. 69 | model.embedding_type = 'positional' 70 | model.fourier_scale = 16 71 | model.conv_size = 3 72 | 73 | return config 74 | -------------------------------------------------------------------------------- /models/flowpp/dequantization_cifar10.py: -------------------------------------------------------------------------------- 1 | from .modules_cifar10 import * 2 | 3 | 4 | class ShallowProcessor(nn.Module): 5 | dropout_p: float = 0.2 6 | 7 | @nn.compact 8 | def __call__(self, x, train=False): 9 | # x is assumed to take values in [0, 1] 10 | x = x - 0.5 11 | (this, that), _ = CheckerboardSplit()(x, inverse=False) 12 | x = conv2d(self, jnp.concatenate([this, that], axis=3), name='proj', num_units=32) 13 | for i in range(3): 14 | x = gated_conv(self, x, name=f'c{i}', dropout_p=self.dropout_p, use_nin=False, a=None, train=train) 15 | return x 16 | 17 | 18 | class Dequantization(nn.Module): 19 | filters: int = 96 20 | components: int = 32 21 | blocks: int = 2 22 | attn_heads: int = 4 23 | dropout_p: float = 0. 24 | 25 | @nn.compact 26 | def __call__(self, eps, x, inverse=False, train=False): 27 | # x is assumed to take values in [0, 1] 28 | logp_eps = jnp.sum(-eps ** 2 / 2. 
- 0.5 * np.log(2 * np.pi), axis=(1, 2, 3)) 29 | 30 | coupling_params = dict( 31 | filters=self.filters, 32 | blocks=self.blocks, 33 | components=self.components, 34 | heads=self.attn_heads 35 | ) 36 | modules = [ 37 | CheckerboardSplit(), 38 | Norm(), Pointwise(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(), 39 | Norm(), Pointwise(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(), 40 | Norm(), Pointwise(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(), 41 | Norm(), Pointwise(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(), 42 | CheckerboardSplit(inverse_module=True), 43 | Sigmoid() 44 | ] 45 | 46 | context = ShallowProcessor(dropout_p=self.dropout_p)(x, train=train) 47 | 48 | if not inverse: 49 | logp_sum = 0. 50 | h = eps 51 | for module in modules: 52 | if isinstance(module, MixLogisticAttnCoupling): 53 | h, logp = module(h, context=context, inverse=inverse, train=train) 54 | else: 55 | h, logp = module(h, inverse=inverse) 56 | logp_sum = logp_sum + logp if logp is not None else logp_sum 57 | return h, logp_sum - logp_eps 58 | 59 | else: 60 | logp_sum = 0. 
61 | h = eps 62 | for module in modules[::-1]: 63 | if isinstance(module, MixLogisticAttnCoupling): 64 | h, logp = module(h, context=context, inverse=inverse, train=train) 65 | else: 66 | h, logp = module(h, inverse=inverse) 67 | logp_sum = logp_sum + logp if logp is not None else logp_sum 68 | return h, logp_sum - logp_eps 69 | -------------------------------------------------------------------------------- /configs/default_imagenet32_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_default_configs(): 5 | config = ml_collections.ConfigDict() 6 | # training 7 | config.training = training = ml_collections.ConfigDict() 8 | config.training.batch_size = 128 9 | training.n_iters = 1300001 10 | training.snapshot_freq = 50000 11 | training.log_freq = 100 12 | training.eval_freq = 100 13 | ## store additional checkpoints for preemption in cloud computing environments 14 | training.snapshot_freq_for_preemption = 10000 15 | ## produce samples at each snapshot. 
16 | training.snapshot_sampling = True 17 | training.likelihood_weighting = False 18 | training.importance_weighting = False 19 | training.continuous = True 20 | training.n_jitted_steps = 5 21 | training.reduce_mean = False 22 | training.smallest_time = 1e-5 23 | 24 | # sampling 25 | config.sampling = sampling = ml_collections.ConfigDict() 26 | sampling.n_steps_each = 1 27 | sampling.noise_removal = True 28 | sampling.probability_flow = False 29 | sampling.snr = 0.16 30 | 31 | # evaluation 32 | config.eval = evaluate = ml_collections.ConfigDict() 33 | evaluate.begin_ckpt = 26 34 | evaluate.end_ckpt = 26 35 | evaluate.ckpt_id = 26 36 | evaluate.batch_size = 1024 37 | evaluate.enable_sampling = False 38 | evaluate.num_samples = 50000 39 | evaluate.enable_loss = False 40 | evaluate.enable_bpd = True 41 | evaluate.bpd_dataset = 'test' 42 | evaluate.num_repeats = 1 43 | evaluate.bound = False 44 | evaluate.dsm = True 45 | evaluate.dequantizer = False 46 | evaluate.offset = True 47 | 48 | # variational dequantization 49 | config.deq = deq = ml_collections.ConfigDict() 50 | deq.n_iters = 300001 51 | deq.ema_rate = 0.99 52 | deq.dropout = 0.0 53 | deq.offset = True 54 | 55 | # data 56 | config.data = data = ml_collections.ConfigDict() 57 | data.dataset = 'ImageNet' 58 | data.image_size = 32 59 | data.random_flip = False 60 | data.centered = False 61 | data.uniform_dequantization = False 62 | data.num_channels = 3 63 | 64 | # model 65 | config.model = model = ml_collections.ConfigDict() 66 | model.sigma_min = 0.01 67 | model.sigma_max = 50 68 | model.num_scales = 1000 69 | model.beta_min = 0.1 70 | model.beta_max = 20. 71 | model.dropout = 0. 
72 | model.embedding_type = 'fourier' 73 | model.data_init = False 74 | model.trainable_embedding = False 75 | 76 | # optimization 77 | config.optim = optim = ml_collections.ConfigDict() 78 | optim.weight_decay = 0 79 | optim.optimizer = 'Adam' 80 | optim.lr = 2e-4 81 | optim.beta1 = 0.9 82 | optim.eps = 1e-8 83 | optim.warmup = 5000 84 | optim.grad_clip = 1. 85 | 86 | config.seed = 42 87 | 88 | return config -------------------------------------------------------------------------------- /configs/default_cifar10_configs.py: -------------------------------------------------------------------------------- 1 | import ml_collections 2 | 3 | 4 | def get_default_configs(): 5 | config = ml_collections.ConfigDict() 6 | # training 7 | config.training = training = ml_collections.ConfigDict() 8 | config.training.batch_size = 128 9 | training.n_iters = 1300001 10 | training.snapshot_freq = 50000 11 | training.log_freq = 100 12 | training.eval_freq = 100 13 | ## store additional checkpoints for preemption in cloud computing environments 14 | training.snapshot_freq_for_preemption = 10000 15 | ## produce samples at each snapshot. 
16 | training.snapshot_sampling = True 17 | training.likelihood_weighting = False 18 | training.importance_weighting = False 19 | training.continuous = True 20 | training.n_jitted_steps = 5 21 | training.reduce_mean = False 22 | training.smallest_time = 1e-5 23 | 24 | # variational dequantization 25 | config.deq = deq = ml_collections.ConfigDict() 26 | deq.n_iters = 300001 27 | deq.ema_rate = 0.99 28 | deq.dropout = 0.0 29 | deq.offset = True 30 | 31 | # sampling 32 | config.sampling = sampling = ml_collections.ConfigDict() 33 | sampling.n_steps_each = 1 34 | sampling.noise_removal = True 35 | sampling.probability_flow = False 36 | sampling.snr = 0.16 37 | sampling.smallest_time = 1e-3 38 | 39 | # evaluation 40 | config.eval = evaluate = ml_collections.ConfigDict() 41 | evaluate.begin_ckpt = 26 42 | evaluate.end_ckpt = 26 43 | evaluate.ckpt_id = 26 44 | evaluate.batch_size = 1024 45 | evaluate.enable_sampling = False 46 | evaluate.num_samples = 50000 47 | evaluate.enable_loss = False 48 | evaluate.enable_bpd = True 49 | evaluate.bpd_dataset = 'test' 50 | evaluate.num_repeats = 5 51 | evaluate.bound = False 52 | evaluate.dsm = True 53 | evaluate.dequantizer = False 54 | evaluate.offset = True 55 | 56 | # data 57 | config.data = data = ml_collections.ConfigDict() 58 | data.dataset = 'CIFAR10' 59 | data.image_size = 32 60 | data.random_flip = True 61 | data.centered = False 62 | data.uniform_dequantization = False 63 | data.num_channels = 3 64 | 65 | # model 66 | config.model = model = ml_collections.ConfigDict() 67 | model.sigma_min = 0.01 68 | model.sigma_max = 50 69 | model.num_scales = 1000 70 | model.beta_min = 0.1 71 | model.beta_max = 20. 
72 | model.dropout = 0.1 73 | model.embedding_type = 'fourier' 74 | model.data_init = False 75 | model.trainable_embedding = False 76 | 77 | # optimization 78 | config.optim = optim = ml_collections.ConfigDict() 79 | optim.weight_decay = 0 80 | optim.optimizer = 'Adam' 81 | optim.lr = 2e-4 82 | optim.beta1 = 0.9 83 | optim.eps = 1e-8 84 | optim.warmup = 5000 85 | optim.grad_clip = 1. 86 | 87 | config.seed = 42 88 | 89 | return config -------------------------------------------------------------------------------- /models/flowpp/dequantization_imagenet64.py: -------------------------------------------------------------------------------- 1 | from .modules_imagenet64 import * 2 | from .modules_cifar10 import CheckerboardSplit, Norm, TupleFlip 3 | import numpy as np 4 | 5 | 6 | class DeepProcessor(nn.Module): 7 | dropout_p: float = 0. 8 | 9 | @nn.compact 10 | def __call__(self, x, train=False): 11 | # x is assumed to take values in [0, 1] 12 | x = x - 0.5 13 | (this, that), _ = CheckerboardSplit()(x, inverse=False) 14 | processed_context = conv2d(self, jnp.concatenate([this, that], axis=3), name='proj', num_units=32) 15 | B, H, W, C = processed_context.shape 16 | pos_emb = self.param('pos_emb_dq', jax.nn.initializers.normal(0.01), (H, W, C)) 17 | 18 | for i in range(5): 19 | processed_context = gated_resnet(self, processed_context, name=f'c{i}', 20 | dropout_p=self.dropout_p, use_nin=False, a=None, train=train) 21 | processed_context = norm(self, processed_context, name=f'dqln{i}') 22 | 23 | return processed_context 24 | 25 | 26 | class Dequantization(nn.Module): 27 | filters: int = 96 28 | components: int = 4 29 | blocks: int = 5 30 | attn_heads: int = 4 31 | dropout_p: float = 0. 32 | use_nin: bool = True 33 | use_ln: bool = True 34 | 35 | @nn.compact 36 | def __call__(self, eps, x, inverse=False, train=False): 37 | # x is assumed to take values in [0, 1] 38 | logp_eps = jnp.sum(-eps ** 2 / 2. 
- 0.5 * np.log(2 * np.pi), axis=(1, 2, 3)) 39 | 40 | coupling_params = dict( 41 | filters=self.filters, 42 | blocks=self.blocks, 43 | components=self.components, 44 | heads=self.attn_heads, 45 | use_nin=self.use_nin, 46 | use_ln=self.use_ln 47 | ) 48 | modules = [ 49 | CheckerboardSplit(), 50 | Norm(), MixLogisticCoupling(**coupling_params), TupleFlip(), 51 | Norm(), MixLogisticCoupling(**coupling_params), TupleFlip(), 52 | Norm(), MixLogisticCoupling(**coupling_params), TupleFlip(), 53 | Norm(), MixLogisticCoupling(**coupling_params), TupleFlip(), 54 | CheckerboardSplit(inverse_module=True), 55 | Sigmoid() 56 | ] 57 | 58 | context = DeepProcessor(dropout_p=self.dropout_p)(x, train=train) 59 | 60 | if not inverse: 61 | logp_sum = 0. 62 | h = eps 63 | for module in modules: 64 | if isinstance(module, MixLogisticCoupling): 65 | h, logp = module(h, context=context, inverse=inverse, train=train) 66 | else: 67 | h, logp = module(h, inverse=inverse) 68 | logp_sum = logp_sum + logp if logp is not None else logp_sum 69 | return h, logp_sum - logp_eps 70 | 71 | else: 72 | logp_sum = 0. 73 | h = eps 74 | for module in modules[::-1]: 75 | if isinstance(module, MixLogisticCoupling): 76 | h, logp = module(h, context=context, inverse=inverse, train=train) 77 | else: 78 | h, logp = module(h, inverse=inverse) 79 | logp_sum = logp_sum + logp if logp is not None else logp_sum 80 | return h, logp_sum - logp_eps 81 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
"""Training and evaluation"""

import run_lib
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import tensorflow as tf
import logging
import os
# import chex

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval", "train_deq"], "Running mode: train or eval")
flags.DEFINE_string("eval_folder", "eval_test_bpd",
                    "The folder name for storing evaluation results")
flags.DEFINE_string("deq_folder", "flowpp_dequantizer", "The folder name for dequantizer training.")
flags.mark_flags_as_required(["workdir", "config", "mode"])


def _setup_logging(workdir):
  """Route root-logger output to the console and `workdir`/stdout.txt.

  Uses tf.io.gfile so the same code works on local disk and on
  Google Cloud Storage.
  """
  gfile_stream = tf.io.gfile.GFile(os.path.join(workdir, 'stdout.txt'), 'w')
  handler = logging.StreamHandler(gfile_stream)
  formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
  handler.setFormatter(formatter)
  logger = logging.getLogger()
  logger.addHandler(handler)
  logger.setLevel('INFO')


def main(argv):
  # Keep TensorFlow off the GPUs (JAX owns them) and stop JAX from
  # pre-allocating the whole GPU memory pool.
  tf.config.experimental.set_visible_devices([], "GPU")
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

  if FLAGS.mode == "train":
    # Create the working directory and set up logging.
    tf.io.gfile.makedirs(FLAGS.workdir)
    _setup_logging(FLAGS.workdir)
    # Run the training pipeline. (A previous revision re-tested
    # `FLAGS.mode == "train"` here with an else-branch calling
    # `run_lib.deq_score_joint_train` that could never execute, because the
    # enclosing branch already guaranteed mode == "train"; the dead branch
    # has been removed.)
    run_lib.train(FLAGS.config, FLAGS.workdir)
  elif FLAGS.mode == "eval":
    # Run the evaluation pipeline
    run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.eval_folder, FLAGS.deq_folder)
  elif FLAGS.mode == "train_deq":
    # Train the Flow++ dequantizer inside an existing score-model workdir.
    assert tf.io.gfile.exists(FLAGS.workdir)
    new_workdir = os.path.join(FLAGS.workdir, FLAGS.deq_folder)
    tf.io.gfile.makedirs(new_workdir)
    _setup_logging(new_workdir)
    # with chex.fake_pmap():
    run_lib.train_deq(FLAGS.config, FLAGS.workdir, new_workdir)
  else:
    raise ValueError(f"Mode {FLAGS.mode} not recognized.")


if __name__ == "__main__":
  app.run(main)
class Dequantization(nn.Module):
    """Flow++-style variational dequantizer for ImageNet 32x32.

    Maps base noise `eps` through two checkerboard stacks of attention-based
    mixture-of-logistics couplings conditioned on the image `x`, ending in a
    Sigmoid so the output lies in (0, 1). Returns the transformed noise and
    the accumulated flow log-determinant minus the base log-density of `eps`.
    """
    filters: int = 128       # channel width of the coupling networks
    components: int = 32     # number of logistic mixture components
    blocks: int = 8          # gated-resnet blocks per coupling
    attn_heads: int = 4      # attention heads in each coupling
    dropout_p: float = 0.
    use_nin: bool = True
    use_ln: bool = True

    @nn.compact
    def __call__(self, eps, x, inverse=False, train=False):
        # x is assumed to take values in [0, 1]
        # Standard-normal log-density of eps (summed over H, W, C).
        logp_eps = jnp.sum(-eps ** 2 / 2. - 0.5 * np.log(2 * np.pi), axis=(1, 2, 3))

        coupling_params = dict(
            filters=self.filters,
            blocks=self.blocks,
            components=self.components,
            heads=self.attn_heads,
            use_nin=self.use_nin,
            use_ln=self.use_ln
        )
        # Two checkerboard stacks of four Norm/coupling/flip triples each;
        # the trailing Sigmoid maps the result into (0, 1).
        modules = [
            CheckerboardSplit(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            CheckerboardSplit(inverse_module=True),
            CheckerboardSplit(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            Norm(), MixLogisticAttnCoupling(**coupling_params), TupleFlip(),
            CheckerboardSplit(inverse_module=True),
            Sigmoid()
        ]

        # Conditioning features computed from the image.
        context = DeepProcessor(dropout_p=self.dropout_p)(x, train=train)

        if not inverse:
            # Forward pass: run modules in order; couplings also see `context`.
            logp_sum = 0.
            h = eps
            for module in modules:
                if isinstance(module, MixLogisticAttnCoupling):
                    h, logp = module(h, context=context, inverse=inverse, train=train)
                else:
                    h, logp = module(h, inverse=inverse)
                logp_sum = logp_sum + logp if logp is not None else logp_sum
            # Flow log-determinant minus base log-density of eps.
            return h, logp_sum - logp_eps

        else:
            # Inverse pass: run modules in reverse order, each in inverse mode.
            logp_sum = 0.
            h = eps
            for module in modules[::-1]:
                if isinstance(module, MixLogisticAttnCoupling):
                    h, logp = module(h, context=context, inverse=inverse, train=train)
                else:
                    h, logp = module(h, inverse=inverse)
                logp_sum = logp_sum + logp if logp is not None else logp_sum
            return h, logp_sum - logp_eps
82 | h = eps 83 | for module in modules[::-1]: 84 | if isinstance(module, MixLogisticAttnCoupling): 85 | h, logp = module(h, context=context, inverse=inverse, train=train) 86 | else: 87 | h, logp = module(h, inverse=inverse) 88 | logp_sum = logp_sum + logp if logp is not None else logp_sum 89 | return h, logp_sum - logp_eps 90 | -------------------------------------------------------------------------------- /models/flowpp/logistic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported from https://github.com/aravindsrinivas/flowpp/blob/737fadb2218c1e2810a91b523498f97def2c30de/flows/logistic.py 3 | """ 4 | 5 | import jax.numpy as jnp 6 | import jax 7 | 8 | 9 | def logistic_logpdf(*, x, mean, logscale): 10 | """ 11 | log density of logistic distribution 12 | this operates elementwise 13 | """ 14 | z = (x - mean) * jnp.exp(-logscale) 15 | return z - logscale - 2 * jax.nn.softplus(z) 16 | 17 | 18 | def logistic_logcdf(*, x, mean, logscale): 19 | """ 20 | log cdf of logistic distribution 21 | this operates elementwise 22 | """ 23 | z = (x - mean) * jnp.exp(-logscale) 24 | return jax.nn.log_sigmoid(z) 25 | 26 | 27 | def mixlogistic_logpdf(*, x, prior_logits, means, logscales): 28 | """logpdf of a mixture of logistics""" 29 | assert len(x.shape) + 1 == len(prior_logits.shape) == len(means.shape) == len(logscales.shape) 30 | return jax.nn.logsumexp( 31 | jax.nn.log_softmax(prior_logits, axis=-1) + logistic_logpdf( 32 | x=jnp.expand_dims(x, -1), mean=means, logscale=logscales), 33 | axis=-1 34 | ) 35 | 36 | 37 | def mixlogistic_logcdf(*, x, prior_logits, means, logscales): 38 | """log cumulative distribution function of a mixture of logistics""" 39 | assert (len(x.shape) + 1 == len(prior_logits.shape) == len(means.shape) == len(logscales.shape)) 40 | return jax.nn.logsumexp( 41 | jax.nn.log_softmax(prior_logits, axis=-1) + logistic_logcdf( 42 | x=jnp.expand_dims(x, -1), mean=means, logscale=logscales), 43 | axis=-1 44 | ) 45 
def mixlogistic_sample(rng, *, prior_logits, means, logscales):
    """Draw a sample from a mixture of logistics.

    Args:
        rng: JAX PRNG key.
        prior_logits, means, logscales: mixture parameters whose trailing
            axis indexes the mixture components.

    Returns:
        A sample array with the shape of `prior_logits` minus its last axis.
    """
    # Sample mixture component
    rng, step_rng = jax.random.split(rng)
    # Gumbel-max trick: adding -log(-log U) noise to the logits and taking
    # the argmax samples an index from softmax(prior_logits); the uniform is
    # clipped away from {0, 1} for numerical safety.
    sampled_inds = jnp.argmax(
        prior_logits - jnp.log(-jnp.log(jax.random.uniform(step_rng, prior_logits.shape,
                                                           minval=1e-5, maxval=1. - 1e-5))),
        axis=-1
    )
    sampled_onehot = jax.nn.one_hot(sampled_inds, prior_logits.shape[-1])
    # Pull out the sampled mixture component
    means = jnp.sum(means * sampled_onehot, axis=-1)
    logscales = jnp.sum(logscales * sampled_onehot, axis=-1)
    # Sample from the component
    rng, step_rng = jax.random.split(rng)
    u = jax.random.uniform(step_rng, means.shape, minval=1e-5, maxval=1. - 1e-5)
    # Inverse-CDF sampling: log(u) - log(1 - u) is a standard logistic draw.
    x = means + jnp.exp(logscales) * (jnp.log(u) - jnp.log(1. - u))
    return x


def mixlogistic_invcdf(*, y, prior_logits, means, logscales, tol=1e-12, max_bisection_iters=200,
                       init_bounds_scale=200.):
    """inverse cumulative distribution function of a mixture of logistics

    Solves CDF(x) = y elementwise by bisection; the mixture CDF is monotone,
    so an initial bracket [lb, ub] around the solution is repeatedly halved.
    """
    assert len(y.shape) + 1 == len(prior_logits.shape) == len(means.shape) == len(logscales.shape)

    def body(carry, _):
        # One vectorized bisection step.
        x, lb, ub = carry
        cur_y = jnp.exp(mixlogistic_logcdf(x=x, prior_logits=prior_logits, means=means, logscales=logscales))
        # If the CDF overshoots the target, move down toward lb; else up toward ub.
        new_x = jnp.where(cur_y > y, (x + lb) / 2., (x + ub) / 2.)
        new_lb = jnp.where(cur_y > y, lb, x)
        new_ub = jnp.where(cur_y > y, x, ub)
        diff = jnp.max(jnp.abs(new_x - x))
        return (new_x, new_lb, new_ub), diff

    init_x = jnp.zeros_like(y)
    maxscales = jnp.sum(jnp.exp(logscales), axis=-1, keepdims=True)  # sum of scales across mixture components
    # Bracket generously around all component means.
    init_lb = jnp.min(means - init_bounds_scale * maxscales, axis=-1)
    init_ub = jnp.max(means + init_bounds_scale * maxscales, axis=-1)

    # NOTE(review): `tol` and the per-step `diff` are currently unused — the
    # scan always runs `max_bisection_iters` iterations (a fixed trip count
    # keeps this jittable). Consider lax.while_loop if early exit matters.
    (out_x, _, _), _ = jax.lax.scan(body, (init_x, init_lb, init_ub), jnp.arange(max_bisection_iters),
                                    length=max_bisection_iters)
    assert out_x.shape == y.shape
    return out_x
def gated_resnet(self, x, *, name, a, nonlinearity=concat_elu, conv=conv2d, use_nin, dropout_p, train=False):
    """Gated residual block: x + gate(conv(nonlin(conv(nonlin(x)) [+ proj(a)]))).

    Args:
        self: the enclosing Flax module (parameters are created on it).
        x: input feature map, NHWC.
        name: parameter-name prefix; all parameters of this block are scoped
            under it so repeated calls in one module stay distinct.
        a: optional auxiliary/conditioning input added after the first conv.
        nonlinearity: activation (default concat_elu, which doubles channels).
        conv: convolution constructor used for the main path.
        use_nin: use a 1x1 network-in-network layer instead of `conv` for the
            second (output) projection.
        dropout_p: dropout rate applied before the output projection.
        train: enables dropout when True.
    """
    num_filters = int(x.shape[-1])

    c1 = conv(self, nonlinearity(x), name=f'{name}_c1', num_units=num_filters)
    if a is not None:  # add short-cut connection if auxiliary input 'a' is given
        c1 += nin(self, nonlinearity(a), name=f'{name}_a_proj', num_units=num_filters)
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = nn.Dropout(rate=dropout_p, deterministic=not train)(c1)

    # Bug fix: this projection was named bare 'c2' (no `name` prefix), unlike
    # '{name}_c1' above, so every gated_resnet call inside the same enclosing
    # module collided on the same parameter name. Scope it under `name`.
    # NOTE: this changes the parameter tree, so checkpoints saved with the old
    # naming will not load for this layer.
    c2 = (nin if use_nin else conv)(self, c1, name=f'{name}_c2', num_units=num_filters * 2, init_scale=0.1)
    # The doubled channels are split into (value, gate) inside `gate`.
    return x + gate(c2, axis=3)
class MixLogisticCoupling(nn.Module):
    """
    CDF of mixture of logistics, followed by affine
    """
    filters: int             # channel width of the conditioning network
    blocks: int              # number of gated-resnet blocks
    components: int          # logistic mixture components per element
    heads: int = 4           # accepted for API parity; not used in this attention-free variant
    init_scale: float = 0.1  # init scale of the final parameter projection
    dropout_p: float = 0.
    use_nin: bool = True
    use_ln: bool = True
    with_affine: bool = True
    use_final_nin: bool = False
    nonlinearity: Any = concat_elu
    verbose: bool = True     # print input statistics once, at init time

    @nn.compact
    def __call__(self, x, context=None, inverse=False, train=False):
        # `x` is a (cf, ef) tuple from CheckerboardSplit: `cf` conditions the
        # transform parameters, `ef` is the half that gets transformed.
        def f(x, *, context=None):
            # Conditioning network: predicts per-element mixture-of-logistics
            # parameters and an affine (s, t) from `cf` (plus optional context).
            if not self.has_variable('params', 'pos_emb') and self.verbose:
                # debug stuff
                def tap_func(x, transforms):
                    xmean = jnp.mean(x, axis=list(range(len(x.shape))))
                    xvar = jnp.var(x, axis=list(range(len(x.shape))))
                    print(f'shape: {jnp.shape(x)}')
                    print(f'mean: {xmean}')
                    print(f'std: {jnp.sqrt(xvar)}')
                    print(f'min: {jnp.min(x)}')
                    print(f'max: {jnp.max(x)}')

                x = host_callback.id_tap(tap_func, x)

            B, H, W, C = x.shape
            # NOTE(review): `pos_emb` is created but never added to `x` in this
            # attention-free variant — confirm whether it is vestigial.
            pos_emb = self.param('pos_emb', jax.nn.initializers.normal(stddev=0.01), [H, W, self.filters])
            x = conv2d(self, x, name='c1', num_units=self.filters)
            for i_block in range(self.blocks):
                name = f'block{i_block}'
                x = gated_resnet(self, x, name=f'{name}_conv', a=context, use_nin=self.use_nin, dropout_p=self.dropout_p, train=train)
                if self.use_ln:
                    x = norm(self, x, name=f'{name}_ln1')

            x = self.nonlinearity(x)
            # NOTE(review): this final projection reuses `name` from the last
            # loop iteration, so its parameter is called 'block{blocks-1}_c2';
            # looks accidental but is load-bearing for checkpoint naming.
            x = (nin if self.use_final_nin else conv2d)(
                self, x, name=f'{name}_c2', num_units=C * (2 + 3 * self.components), init_scale=self.init_scale)

            assert x.shape == (B, H, W, C * (2 + 3 * self.components))
            x = jnp.reshape(x, [B, H, W, C, 2 + 3 * self.components])

            # Channel 0 -> tanh-squashed affine log-scale, channel 1 -> shift,
            # remaining 3*components -> mixture logits / means / logscales.
            s, t = jnp.tanh(x[:, :, :, :, 0]), x[:, :, :, :, 1]
            ml_logits, ml_means, ml_logscales = jnp.split(x[:, :, :, :, 2:], 3, axis=4)

            assert s.shape == t.shape == (B, H, W, C)
            assert ml_logits.shape == ml_means.shape == ml_logscales.shape == (B, H, W, C, self.components)
            return ml_logits, ml_means, ml_logscales, s, t

        assert isinstance(x, tuple)
        cf, ef = x
        ml_logits, ml_means, ml_logscales, s, t = f(cf, context=context)
        logp_sum = 0.

        mixlogistic_cdf = MixLogisticCDF()
        # Sigmoid constructed with inverse_module=True: its "forward" call
        # applies the inverse-sigmoid direction of the flow.
        sigmoid = Sigmoid(inverse_module=True)
        if self.with_affine:
            elementwise_affine = ElemwiseAffine()

        if not inverse:
            # Forward: MixLogisticCDF -> inverse-configured Sigmoid -> affine,
            # accumulating each sub-transform's log-det when it reports one.
            h, logp = mixlogistic_cdf(ef, logits=ml_logits, means=ml_means, logscales=ml_logscales, inverse=False)
            if logp is not None:
                logp_sum = logp_sum + logp
            h, logp = sigmoid(h, inverse=False)
            if logp is not None:
                logp_sum = logp_sum + logp
            if self.with_affine:
                h, logp = elementwise_affine(h, scales=jnp.exp(s), biases=t, logscales=s, inverse=False)
                if logp is not None:
                    logp_sum = logp_sum + logp
            return (cf, h), logp_sum

        else:
            # Inverse: undo the three sub-transforms in reverse order.
            if self.with_affine:
                h, logp = elementwise_affine(ef, scales=jnp.exp(s), biases=t, logscales=s, inverse=True)
                if logp is not None:
                    logp_sum = logp_sum + logp
            h, logp = sigmoid(h, inverse=True)
            if logp is not None:
                logp_sum = logp_sum + logp
            h, logp = mixlogistic_cdf(h, logits=ml_logits, means=ml_means, logscales=ml_logscales, inverse=True)
            if logp is not None:
                logp_sum = logp_sum + logp
            return (cf, h), logp_sum
def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format=None):
  """Make a grid of images and save it into an image file.

  Pixel values are assumed to be within [0, 1].

  Args:
    ndarray (array_like): 4D mini-batch images of shape (B x H x W x C).
    fp: A filename(string) or file object.
    nrow (int, optional): Number of images displayed in each row of the grid.
      The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
    padding (int, optional): amount of padding. Default: ``2``.
    pad_value (float, optional): Value for the padded pixels. Default: ``0``.
    format(Optional): If omitted, the format to use is determined from the
      filename extension. If a file object was used instead of a filename, this
      parameter should always be used.
  """
  if not (isinstance(ndarray, jnp.ndarray) or
          (isinstance(ndarray, list) and
           all(isinstance(t, jnp.ndarray) for t in ndarray))):
    raise TypeError("array_like of tensors expected, got {}".format(
        type(ndarray)))

  ndarray = jnp.asarray(ndarray)

  if ndarray.ndim == 4 and ndarray.shape[-1] == 1:  # single-channel images
    ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1)

  # make the mini-batch of images into a grid
  nmaps = ndarray.shape[0]
  xmaps = min(nrow, nmaps)
  ymaps = int(math.ceil(float(nmaps) / xmaps))
  height, width = int(ndarray.shape[1] + padding), int(ndarray.shape[2] +
                                                       padding)
  num_channels = ndarray.shape[3]
  grid = jnp.full(
      (height * ymaps + padding, width * xmaps + padding, num_channels),
      pad_value).astype(jnp.float32)
  k = 0
  for y in range(ymaps):
    for x in range(xmaps):
      if k >= nmaps:
        break
      # `jax.ops.index_update` is deprecated (and removed in newer JAX); the
      # functional `.at[...].set(...)` API is the supported equivalent and is
      # available in the pinned jax version.
      grid = grid.at[y * height + padding:(y + 1) * height,
                     x * width + padding:(x + 1) * width].set(ndarray[k])
      k = k + 1

  # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
  ndarr = jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)
  im = Image.fromarray(ndarr.copy())
  im.save(fp, format=format)
def flatten_dict(config):
  """Flatten a hierarchical dict to a simple dict.

  Nested dict keys are joined with '/'; tuple values are stringified so the
  result only contains flat, scalar-friendly values.
  """
  flat = {}
  for key, value in config.items():
    if isinstance(value, dict):
      # Recurse, then prefix each flattened subkey with this level's key.
      for subkey, subvalue in flatten_dict(value).items():
        flat[key + "/" + subkey] = subvalue
    elif isinstance(value, tuple):
      flat[key] = str(value)
    else:
      flat[key] = value
  return flat
def get_value_div_fn(fn):
  """Return both the function value and its estimated divergence via Hutchinson's trace estimator."""

  def value_div_fn(x, t, eps):
    # Differentiate eps^T fn(x, t) w.r.t. x; the aux output carries fn(x, t).
    def dotted(data):
      out = fn(data, t)
      return jnp.sum(out * eps), out

    eps_jvp, value = jax.grad(dotted, has_aux=True)(x)
    # eps^T J eps, summed over all non-batch axes, estimates tr(J).
    divergence = jnp.sum(eps_jvp * eps, axis=tuple(range(1, x.ndim)))
    return value, divergence

  return value_div_fn
def get_likelihood_fn(sde, model, inverse_scaler, hutchinson_type='Rademacher', rtol=1e-5, atol=1e-5, method='RK45',
                      eps=1e-5):
  """Create a function to compute the unbiased log-likelihood estimate of a given data point.

  Args:
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of the score-based model.
    inverse_scaler: The inverse data normalizer.
    hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
    rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
    atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
    method: A `str`. The algorithm for the black-box ODE solver.
      See documentation for `scipy.integrate.solve_ivp`.
    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.

  Returns:
    A function that takes random states, replicated training states, and a batch of data points
    and returns the log-likelihoods in bits/dim, the latent code, and the number of function
    evaluations cost by computation.
  """

  def drift_fn(state, x, t):
    """The drift function of the reverse-time SDE."""
    score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)
    # Probability flow ODE is a special case of Reverse SDE
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t)[0]

  @jax.pmap
  def p_value_div_fn(state, x, t, eps):
    """Pmapped divergence of the drift function."""
    value_div_fn = get_value_div_fn(lambda x, t: drift_fn(state, x, t))
    return value_div_fn(x, t, eps)

  p_prior_logp_fn = jax.pmap(sde.prior_logp)  # Pmapped log-PDF of the SDE's prior distribution

  # NOTE(review): `p_marginal_prob` is never used below — confirm whether it
  # can be removed.
  p_marginal_prob = jax.pmap(sde.marginal_prob)

  def likelihood_fn(prng, pstate, data):
    """Compute an unbiased estimate to the log-likelihood in bits/dim.

    Args:
      prng: An array of random states. The list dimension equals the number of devices.
      pstate: Replicated training state for running on multiple devices.
      data: A JAX array of shape [#devices, batch size, ...].

    Returns:
      bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim.
      z: A JAX array of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
    """
    rng, step_rng = jax.random.split(flax.jax_utils.unreplicate(prng))
    shape = data.shape
    # One fixed Hutchinson probe vector is drawn per call and reused for every
    # ODE right-hand-side evaluation.
    if hutchinson_type == 'Gaussian':
      epsilon = jax.random.normal(step_rng, shape)
    elif hutchinson_type == 'Rademacher':
      # Uniform over {-1, +1}.
      epsilon = jax.random.randint(step_rng, shape,
                                   minval=0, maxval=2).astype(jnp.float32) * 2 - 1
    else:
      raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

    def ode_func(t, x):
      # The solver state `x` packs the flattened sample followed by
      # shape[0]*shape[1] running log-density deltas (one per example).
      sample = mutils.from_flattened_numpy(x[:-shape[0] * shape[1]], shape)
      vec_t = jnp.ones((sample.shape[0], sample.shape[1])) * t
      drift, logp_grad = p_value_div_fn(pstate, sample, vec_t, epsilon)
      drift = mutils.to_flattened_numpy(drift)
      logp_grad = mutils.to_flattened_numpy(logp_grad)
      return np.concatenate([drift, logp_grad], axis=0)

    # Integrate the probability flow ODE from t=eps (data) to t=sde.T (prior).
    init = jnp.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0] * shape[1],))], axis=0)
    solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
    nfe = solution.nfev
    t = solution.t
    # Final solver state: latent code + accumulated log-density change.
    zp = jnp.asarray(solution.y[:, -1])
    z = mutils.from_flattened_numpy(zp[:-shape[0] * shape[1]], shape)
    delta_logp = zp[-shape[0] * shape[1]:].reshape((shape[0], shape[1]))
    prior_logp = p_prior_logp_fn(z)

    bpd = -(prior_logp + delta_logp)

    # Normalize by the per-example dimensionality and convert nats -> bits.
    N = np.prod(shape[2:])
    bpd = bpd / N / np.log(2.)

    # A hack to convert log-likelihoods to bits/dim
    # based on the gradient of the inverse data normalizer.
    offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8.
    bpd += offset
    return bpd, z, t, nfe, solution

  return likelihood_fn
def load_dataset_stats(config):
  """Load the pre-computed dataset statistics.

  Args:
    config: the experiment config; `config.data.dataset` selects the stats
      file (plus `category`/`image_size` for LSUN and ImageNet).

  Returns:
    A dict mapping the array names stored in the .npz file to NumPy arrays.

  Raises:
    ValueError: if no stats file is known for the configured dataset.
  """
  if config.data.dataset == 'CIFAR10':
    filename = 'assets/stats/cifar10_stats.npz'
  elif config.data.dataset == 'CELEBA':
    filename = 'assets/stats/celeba_stats.npz'
  elif config.data.dataset == 'LSUN':
    filename = f'assets/stats/lsun_{config.data.category}_{config.data.image_size}_stats.npz'
  elif config.data.dataset == 'ImageNet':
    filename = f'assets/stats/imagenet{config.data.image_size}_stats.npz'
  else:
    raise ValueError(f'Dataset {config.data.dataset} stats not found.')

  with tf.io.gfile.GFile(filename, 'rb') as fin:
    # Bug fix: np.load on an .npz returns a *lazy* NpzFile that reads from
    # `fin` on first access, so returning it after the `with` block closed
    # the file made every later stats[...] lookup fail. Materialize all
    # arrays while the file is still open and return a plain dict (indexing
    # with stats['name'] works as before).
    lazy = np.load(fin)
    stats = {key: lazy[key] for key in lazy.files}
  return stats
@tf.function
def run_inception_distributed(input_tensor,
                              inception_model,
                              num_batches=1,
                              inceptionv3=False):
  """Distribute the inception network computation to all available TPUs.

  Args:
    input_tensor: The input images. Assumed to be within [0, 255].
    inception_model: The inception network model obtained from `tfhub`.
    num_batches: The number of batches used for dividing the input.
    inceptionv3: If `True`, use InceptionV3, otherwise use InceptionV1.

  Returns:
    A dictionary with key `pool_3` and `logits`, representing the pool_3 and
    logits of the inception network respectively.
  """
  num_tpus = jax.local_device_count()
  # NOTE(review): tf.split requires the batch size to divide evenly by the
  # local device count — confirm callers guarantee this.
  input_tensors = tf.split(input_tensor, num_tpus, axis=0)
  pool3 = []
  # For InceptionV3 only feature vectors are produced, so no logits list.
  logits = [] if not inceptionv3 else None
  # Mirror JAX's device kind when pinning each shard to a TF device.
  device_format = '/TPU:{}' if 'TPU' in str(jax.devices()[0]) else '/GPU:{}'
  for i, tensor in enumerate(input_tensors):
    with tf.device(device_format.format(i)):
      tensor_on_device = tf.identity(tensor)
      res = run_inception_jit(
          tensor_on_device, inception_model, num_batches=num_batches,
          inceptionv3=inceptionv3)

      if not inceptionv3:
        pool3.append(res['pool_3'])
        logits.append(res['logits'])  # pytype: disable=attribute-error
      else:
        pool3.append(res)

  # Gather per-device shards back on the host CPU.
  with tf.device('/CPU'):
    return {
        'pool_3': tf.concat(pool3, axis=0),
        'logits': tf.concat(logits, axis=0) if not inceptionv3 else None
    }
colorama==0.4.3 27 | command-not-found==0.3 28 | ConfigArgParse==1.5.2 29 | configobj==5.0.6 30 | configparser==5.0.2 31 | constantly==15.1.0 32 | contextlib2==0.6.0.post1 33 | coverage==5.5 34 | coveralls==3.2.0 35 | cryptography==2.8 36 | cycler==0.10.0 37 | Cython==0.29.23 38 | dbus-python==1.2.16 39 | decorator==4.4.2 40 | defusedxml==0.7.1 41 | dill==0.3.3 42 | distlib==0.3.1 43 | distro==1.4.0 44 | distro-info===0.23ubuntu1 45 | dm-tree==0.1.6 46 | docker-pycreds==0.4.0 47 | docopt==0.6.2 48 | einops==0.3.0 49 | entrypoints==0.3 50 | execnet==1.9.0 51 | fancycompleter==0.9.1 52 | filelock==3.0.12 53 | flatbuffers==1.12 54 | flax==0.3.3 55 | fsspec==2021.6.1 56 | future==0.18.2 57 | gast==0.4.0 58 | gdown==3.13.0 59 | gitdb==4.0.7 60 | GitPython==3.1.17 61 | google-api-core==1.26.3 62 | google-api-python-client==1.8.0 63 | google-auth==1.30.0 64 | google-auth-httplib2==0.1.0 65 | google-auth-oauthlib==0.4.4 66 | google-pasta==0.2.0 67 | googleapis-common-protos==1.53.0 68 | gpytorch==1.5.0 69 | grpcio==1.34.1 70 | h5py==3.1.0 71 | httplib2==0.19.1 72 | hyperlink==19.0.0 73 | idna==2.10 74 | imageio==2.9.0 75 | importlib-metadata==1.5.0 76 | importlib-resources==5.1.3 77 | incremental==16.10.1 78 | ipykernel==5.5.5 79 | ipython==7.23.1 80 | ipython-genutils==0.2.0 81 | ipywidgets==7.6.3 82 | jax==0.2.18 83 | jaxlib==0.1.69 84 | jedi==0.17.2 85 | Jinja2==2.10.1 86 | joblib==1.0.1 87 | jsonpatch==1.22 88 | jsonpointer==2.0 89 | jsonschema==3.2.0 90 | jupyter-client==6.1.12 91 | jupyter-core==4.7.1 92 | jupyter-http-over-ws==0.0.8 93 | jupyterlab-pygments==0.1.2 94 | jupyterlab-widgets==1.0.0 95 | Keras-Applications==1.0.8 96 | keras-nightly==2.5.0.dev2021032900 97 | Keras-Preprocessing==1.1.2 98 | keyring==18.0.1 99 | kiwisolver==1.3.1 100 | language-selector==0.1 101 | launchpadlib==1.10.13 102 | lazr.restfulclient==0.14.2 103 | lazr.uri==1.0.3 104 | libtpu-nightly==0.1.dev20210709 105 | llvmlite==0.36.0 106 | Markdown==3.3.4 107 | MarkupSafe==1.1.0 108 | 
matplotlib==3.4.2 109 | matplotlib-inline==0.1.2 110 | mistune==0.8.4 111 | ml-collections==0.1.0 112 | mock==4.0.3 113 | more-itertools==4.2.0 114 | msgpack==1.0.2 115 | multidict==5.1.0 116 | nbclient==0.5.3 117 | nbconvert==6.0.7 118 | nbformat==5.1.3 119 | nest-asyncio==1.5.1 120 | netifaces==0.10.4 121 | networkx==2.5.1 122 | notebook==6.3.0 123 | numba==0.53.1 124 | numpy==1.19.5 125 | oauth2client==4.1.3 126 | oauthlib==3.1.0 127 | odl==0.7.0 128 | opt-einsum==3.3.0 129 | optax==0.0.9 130 | packaging==20.9 131 | pandas==1.2.4 132 | pandocfilters==1.4.3 133 | parso==0.7.1 134 | pathtools==0.1.2 135 | pdbpp==0.10.2 136 | pep8==1.7.1 137 | pexpect==4.6.0 138 | pickleshare==0.7.5 139 | Pillow==8.2.0 140 | piq==0.5.5 141 | plotly==5.1.0 142 | pluggy==0.7.1 143 | prometheus-client==0.10.1 144 | promise==2.3 145 | prompt-toolkit==3.0.18 146 | protobuf==3.15.8 147 | psutil==5.8.0 148 | ptyprocess==0.7.0 149 | py==1.10.0 150 | pyasn1==0.4.8 151 | pyasn1-modules==0.2.8 152 | pycparser==2.20 153 | pyDeprecate==0.3.0 154 | Pygments==2.9.0 155 | PyGObject==3.36.0 156 | PyHamcrest==1.9.0 157 | PyJWT==1.7.1 158 | pymacaroons==0.13.0 159 | PyNaCl==1.3.0 160 | pynufft==2021.2.0 161 | pyOpenSSL==19.0.0 162 | pyparsing==2.4.7 163 | pyrepl==0.9.0 164 | pyrsistent==0.15.5 165 | pyserial==3.4 166 | PySocks==1.7.1 167 | pytest==3.6.4 168 | pytest-cache==1.0 169 | pytest-cov==2.9.0 170 | pytest-pep8==1.0.6 171 | python-apt==2.0.0+ubuntu0.20.4.4 172 | python-dateutil==2.8.1 173 | python-debian===0.1.36ubuntu1 174 | pytorch-lightning==1.3.8 175 | pytz==2021.1 176 | PyWavelets==1.1.1 177 | PyYAML==5.4.1 178 | pyzmq==22.0.3 179 | requests==2.25.1 180 | requests-oauthlib==1.3.0 181 | requests-unixsocket==0.2.0 182 | rsa==4.7.2 183 | runstats==2.0.0 184 | scikit-image==0.18.2 185 | scikit-learn==0.24.2 186 | scipy==1.6.3 187 | seaborn==0.11.1 188 | SecretStorage==2.3.1 189 | Send2Trash==1.5.0 190 | sentry-sdk==1.1.0 191 | service-identity==18.1.0 192 | shortuuid==1.0.1 193 | 
sigpy==0.1.23 194 | SimpleITK==2.0.2 195 | simplejson==3.16.0 196 | six==1.15.0 197 | smmap==4.0.0 198 | sos==4.1 199 | ssh-import-id==5.10 200 | subprocess32==3.5.4 201 | systemd-python==234 202 | tenacity==7.0.0 203 | tensorboard==2.4.1 204 | tensorboard-data-server==0.6.0 205 | tensorboard-plugin-wit==1.8.0 206 | tensorflow==2.5.0 207 | tensorflow-addons==0.13.0 208 | tensorflow-datasets==4.3.0 209 | tensorflow-estimator==2.5.0 210 | tensorflow-gan==2.0.0 211 | tensorflow-hub==0.12.0 212 | tensorflow-io==0.18.0 213 | tensorflow-io-gcs-filesystem==0.18.0 214 | tensorflow-metadata==0.30.0 215 | tensorflow-probability==0.12.2 216 | termcolor==1.1.0 217 | terminado==0.9.5 218 | testpath==0.4.4 219 | tf-estimator-nightly==2.5.0.dev2021032601 220 | tf-nightly==2.6.0 221 | threadpoolctl==2.1.0 222 | tifffile==2021.6.14 223 | toolz==0.11.1 224 | torch==1.8.1 225 | torch-xla==1.8.1 226 | torchmetrics==0.4.0 227 | torchvision==0.9.1 228 | tornado==6.1 229 | tqdm==4.60.0 230 | traitlets==5.0.5 231 | Twisted==18.9.0 232 | typeguard==2.12.0 233 | typing-extensions==3.7.4.3 234 | ubuntu-advantage-tools==20.3 235 | ufw==0.36 236 | unattended-upgrades==0.1 237 | uritemplate==3.0.1 238 | urllib3==1.26.4 239 | virtualenv==20.4.4 240 | wadllib==1.3.3 241 | wandb==0.10.30 242 | wcwidth==0.2.5 243 | webencodings==0.5.1 244 | Werkzeug==1.0.1 245 | widgetsnbextension==3.5.1 246 | wmctrl==0.4 247 | wrapt==1.12.1 248 | yarl==1.6.3 249 | zipp==1.0.0 250 | zope.interface==4.7.1 251 | -------------------------------------------------------------------------------- /models/flowpp/modules_imagenet32.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | import flax.linen as nn 4 | from .modules_cifar10 import concat_elu, nin, gate, layernorm as norm, gated_attn as attn, MixLogisticCDF, Sigmoid, ElemwiseAffine 5 | from jax.experimental import host_callback 6 | from typing import Any 7 | 8 | 9 | def conv2d(self, x, *, 
name, num_units, filter_size=(3, 3), stride=(1, 1), pad='SAME', init_scale=1.): 10 | # use weight normalization (Salimans & Kingma, 2016) 11 | assert len(x.shape) == 4 12 | 13 | _V = self.param(f'{name}_V', jax.nn.initializers.normal(0.05), [*filter_size, x.shape[-1], num_units]) 14 | vnorm = _V * jax.lax.rsqrt(jnp.maximum(jnp.sum(jnp.square(_V), axis=(0, 1, 2)), 1e-12)) 15 | 16 | def g_initializer(rng, x): 17 | g = jnp.ones((num_units,)) 18 | W = g[None, None, None, :] * vnorm 19 | x = jax.lax.conv_general_dilated(x, W, window_strides=stride, padding=pad, 20 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) 21 | m_init = jax.lax.pmean(jnp.mean(x, axis=(0, 1, 2)), axis_name='batch') 22 | m2_init = jax.lax.pmean(jnp.mean(x ** 2, axis=(0, 1, 2)), axis_name='batch') 23 | v_init = m2_init - m_init ** 2 24 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 25 | return g * scale_init 26 | 27 | g = self.param(f'{name}_g', g_initializer, x) 28 | 29 | def b_initializer(rng, x): 30 | g = jnp.ones((num_units,)) 31 | W = g[None, None, None, :] * vnorm 32 | x = jax.lax.conv_general_dilated(x, W, window_strides=stride, padding=pad, 33 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) 34 | m_init = jax.lax.pmean(jnp.mean(x, axis=(0, 1, 2)), axis_name='batch') 35 | m2_init = jax.lax.pmean(jnp.mean(x ** 2, axis=(0, 1, 2)), axis_name='batch') 36 | v_init = m2_init - m_init ** 2 37 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 38 | m_init = jax.lax.pmean(jnp.mean(x, axis=(0, 1, 2)), axis_name='batch') 39 | return -m_init * scale_init 40 | 41 | b = self.param(f'{name}_b', b_initializer, x) 42 | W = g[None, None, None, :] * vnorm 43 | 44 | # calculate convolutional layer output 45 | x = jax.lax.conv_general_dilated(x, W, window_strides=stride, padding=pad, 46 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) + b[None, None, None, :] 47 | 48 | return x 49 | 50 | 51 | def gated_resnet(self, x, *, name, a, nonlinearity=concat_elu, conv=conv2d, use_nin, dropout_p, train=False): 52 | 
    """Gated residual block (Flow++ style).

    conv -> optional projection of auxiliary input `a` -> nonlinearity ->
    dropout -> doubled-width projection -> channel gate, added to `x`.
    """
    num_filters = int(x.shape[-1])

    c1 = conv(self, nonlinearity(x), name=f'{name}_c1', num_units=num_filters)
    if a is not None:  # add short-cut connection if auxiliary input 'a' is given
        c1 += nin(self, nonlinearity(a), name=f'{name}_a_proj', num_units=num_filters)
    c1 = nonlinearity(c1)
    if dropout_p > 0:
        c1 = nn.Dropout(rate=dropout_p, deterministic=not train)(c1)

    # Output has 2x channels; `gate` splits them into (value, sigmoid gate).
    c2 = (nin if use_nin else conv)(self, c1, name=f'{name}_c2', num_units=num_filters * 2, init_scale=0.1)
    return x + gate(c2, axis=3)


class MixLogisticAttnCoupling(nn.Module):
    """
    CDF of mixture of logistics, followed by affine
    """
    # Network width / depth for the conditioning network.
    filters: int
    blocks: int
    # Number of logistic mixture components per dimension.
    components: int
    heads: int = 4
    init_scale: float = 0.1
    dropout_p: float = 0.
    use_nin: bool = True
    use_ln: bool = True
    with_affine: bool = True
    use_final_nin: bool = False
    nonlinearity: Any = concat_elu
    verbose: bool = True

    @nn.compact
    def __call__(self, x, context=None, inverse=False, train=False):
        # `x` is a tuple (cf, ef): `cf` conditions the transform, `ef` is transformed.
        def f(x, *, context=None):
            # Conditioning network: predicts mixture parameters and affine (s, t) from cf.
            if not self.has_variable('params', 'pos_emb') and self.verbose:
                # debug stuff: print input statistics once, at initialization time.
                # NOTE(review): jax.experimental.host_callback is deprecated in
                # newer JAX releases — consider jax.debug.callback when upgrading.
                def tap_func(x, transforms):
                    xmean = jnp.mean(x, axis=list(range(len(x.shape))))
                    xvar = jnp.var(x, axis=list(range(len(x.shape))))
                    print(f'shape: {jnp.shape(x)}')
                    print(f'mean: {xmean}')
                    print(f'std: {jnp.sqrt(xvar)}')
                    print(f'min: {jnp.min(x)}')
                    print(f'max: {jnp.max(x)}')

                x = host_callback.id_tap(tap_func, x)

            B, H, W, C = x.shape
            pos_emb = self.param('pos_emb', jax.nn.initializers.normal(stddev=0.01), [H, W, self.filters])
            x = conv2d(self, x, name='c1', num_units=self.filters)
            for i_block in range(self.blocks):
                name = f'block{i_block}'
                x = gated_resnet(self, x, name=f'{name}_conv', a=context, use_nin=self.use_nin, dropout_p=self.dropout_p, train=train)
                if self.use_ln:
                    x = norm(self, x, name=f'{name}_ln1')
                x = attn(self, x, name=f'{name}_attn', pos_emb=pos_emb, heads=self.heads, dropout_p=self.dropout_p,
                         train=train)
                if self.use_ln:
                    x = norm(self, x, name=f'{name}_ln2')

            assert x.shape == (B, H, W, self.filters)
            x = self.nonlinearity(x)
            # NOTE(review): `name` below is the loop variable leaked from the
            # block loop, so this parameter is registered as 'block{blocks-1}_c2'
            # and this line raises NameError when self.blocks == 0 — presumably
            # blocks >= 1 always; confirm (renaming would break checkpoints).
            x = (nin if self.use_final_nin else conv2d)(
                self, x, name=f'{name}_c2', num_units=C * (2 + 3 * self.components), init_scale=self.init_scale)

            assert x.shape == (B, H, W, C * (2 + 3 * self.components))
            # Per output dimension: 1 log-scale s, 1 shift t, and 3*components
            # mixture parameters (logits, means, logscales).
            x = jnp.reshape(x, [B, H, W, C, 2 + 3 * self.components])

            s, t = jnp.tanh(x[:, :, :, :, 0]), x[:, :, :, :, 1]
            ml_logits, ml_means, ml_logscales = jnp.split(x[:, :, :, :, 2:], 3, axis=4)

            assert s.shape == t.shape == (B, H, W, C)
            assert ml_logits.shape == ml_means.shape == ml_logscales.shape == (B, H, W, C, self.components)
            return ml_logits, ml_means, ml_logscales, s, t

        assert isinstance(x, tuple)
        cf, ef = x
        ml_logits, ml_means, ml_logscales, s, t = f(cf, context=context)
        # Accumulates the log-determinant contributions of each sub-transform.
        logp_sum = 0.

        mixlogistic_cdf = MixLogisticCDF()
        sigmoid = Sigmoid(inverse_module=True)
        if self.with_affine:
            elementwise_affine = ElemwiseAffine()

        if not inverse:
            # Forward: MixLogistic CDF -> inverse sigmoid (logit) -> affine.
            h, logp = mixlogistic_cdf(ef, logits=ml_logits, means=ml_means, logscales=ml_logscales, inverse=False)
            if logp is not None:
                logp_sum = logp_sum + logp
            h, logp = sigmoid(h, inverse=False)
            if logp is not None:
                logp_sum = logp_sum + logp
            if self.with_affine:
                h, logp = elementwise_affine(h, scales=jnp.exp(s), biases=t, logscales=s, inverse=False)
                if logp is not None:
                    logp_sum = logp_sum + logp
            return (cf, h), logp_sum

        else:
            # Inverse: apply the same sub-transforms in reverse order, inverted.
            if self.with_affine:
                h, logp = elementwise_affine(ef, scales=jnp.exp(s), biases=t, logscales=s, inverse=True)
                if logp is not None:
                    logp_sum = logp_sum + logp
            h, logp = sigmoid(h, inverse=True)
            if logp is not None:
                logp_sum = logp_sum + logp
            h, logp = mixlogistic_cdf(h, logits=ml_logits, means=ml_means, logscales=ml_logscales, inverse=True)
            if logp is not None:
                logp_sum = logp_sum + logp
            return (cf, h), logp_sum
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Maximum Likelihood Training of Score-Based Diffusion Models

This repo contains the official implementation for the paper [Maximum Likelihood Training of Score-Based Diffusion Models](https://arxiv.org/abs/2101.09258)

by [Yang Song](https://yang-song.github.io)\*, [Conor Durkan](https://conormdurkan.github.io/)\*, [Iain Murray](https://homepages.inf.ed.ac.uk/imurray2/), and [Stefano Ermon](https://cs.stanford.edu/~ermon/). Published in NeurIPS 2021 (spotlight).
6 | 7 | -------------------- 8 | 9 | We prove the connection between the Kullback–Leibler divergence and the weighted combination of score matching losses used for training score-based generative models. Our results can be viewed as a generalization of both the de Bruijn identity in information theory and the evidence lower bound in variational inference. 10 | 11 | Our theoretical results enable *ScoreFlow*, a continuous normalizing flow model trained with a variational objective, which is much more efficient than neural ODEs. We report the state-of-the-art likelihood on CIFAR-10 and ImageNet 32x32 among all flow models, achieving comparable performance to cutting-edge autoregressive models. 12 | 13 | ## How to run the code 14 | 15 | ### Dependencies 16 | 17 | Run the following to install a subset of necessary python packages for our code 18 | ```sh 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ### Stats files for quantitative evaluation 23 | 24 | We provide stats files for computing FID and Inception scores for CIFAR-10 and ImageNet 32x32. You can find `cifar10_stats.npz` and `imagenet32_stats.npz` under the directory `assets/stats` in our [Google drive](https://drive.google.com/drive/folders/1gbDrVrFVSupFMRoK7HZo8aFgPvOtpmqB?usp=sharing). Download them and save to `assets/stats/` in the code repo. 25 | 26 | ### Usage 27 | 28 | Train and evaluate our models through `main.py`. Here are some common options: 29 | 30 | ```sh 31 | main.py: 32 | --config: Training configuration. 33 | (default: 'None') 34 | --eval_folder: The folder name for storing evaluation results 35 | (default: 'eval') 36 | --mode: <train|eval|train_deq>: Running mode: train or eval or training the Flow++ variational dequantization model 37 | --workdir: Working directory 38 | ``` 39 | 40 | * `config` is the path to the config file. Our config files are provided in `configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and should be quite self-explanatory. 
41 | 42 | **Naming conventions of config files**: the name of a config file contains the following attributes: 43 | 44 | * dataset: Either `cifar10` or `imagenet32` 45 | * model: Either `ddpmpp_continuous` or `ddpmpp_deep_continuous` 46 | 47 | * `workdir` is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results. 48 | 49 | * `eval_folder` is the name of a subfolder in `workdir` that stores all artifacts of the evaluation process, like meta checkpoints for supporting pre-emption recovery, image samples, and numpy dumps of quantitative results. 50 | 51 | * `mode` is either "train" or "eval" or "train_deq". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta` . When set to "eval", it can do the following: 52 | 53 | * Compute the log-likelihood on the training or test dataset. 54 | * Compute the lower bound of the log-likelihood on the training or test dataset. 55 | * Evaluate the loss function on the test / validation dataset. 56 | * Generate a fixed number of samples and compute its Inception score, FID, or KID. Prior to evaluation, stats files must have already been downloaded/computed and stored in `assets/stats`. 57 | 58 | When set to "train_deq", it trains a Flow++ variational dequantization model to bridge the gap of likelihoods on continuous and discrete images. Recommended if you want to compete with generative models trained on discrete images, such as VAEs and autoregressive models. `train_deq` mode also supports pre-emption recovery. 59 | 60 | 61 | These functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package. 62 | 63 | ### Configurations for training 64 | To turn on likelihood weighting, set `--config.training.likelihood_weighting`. 
To additionally turn on importance sampling for variance reduction, use `--config.training.likelihood_weighting --config.training.importance_weighting`. To train a separate Flow++ variational dequantizer, you need to first finish training a score-based model, then use `--mode=train_deq`. 65 | 66 | ### Configurations for evaluation 67 | To generate samples and evaluate sample quality, use the `--config.eval.enable_sampling` flag; to compute log-likelihoods, use the `--config.eval.enable_bpd` flag, and specify `--config.eval.dataset=train/test` to indicate whether to compute the likelihoods on the training or test dataset. Turn on `--config.eval.bound` to evaluate the variational bound for the log-likelihood. Enable `--config.eval.dequantizer` to use variational dequantization for likelihood computation. `--config.eval.num_repeats` configures the number of repetitions across the dataset (more can reduce the variance of the likelihoods; default to 5). 68 | 69 | ## Pretrained checkpoints 70 | All checkpoints are provided in this [Google drive](https://drive.google.com/drive/folders/1gbDrVrFVSupFMRoK7HZo8aFgPvOtpmqB?usp=sharing). 71 | 72 | Folder structure: 73 | 74 | * `assets`: contains `cifar10_stats.npz` and `imagenet32_stats.npz`. Necessary for computing FID and Inception scores. 75 | * `<dataset>_(deep)_<sde>_(likelihood)_(iw)_(flip)`. Here the part enclosed in `()` is optional. `deep` in the name specifies whether the score model is a deeper architecture (`ddpmpp_deep_continuous`). `likelihood` specifies whether the model was trained with likelihood weighting. `iw` specifies whether the model was trained with importance sampling for variance reduction. `flip` shows whether the model was trained with horizontal flip for data augmentation. Each folder has the following two subfolders: 76 | * `checkpoints`: contains the last checkpoint for the score-based model. 77 | * `flowpp_dequantizer/checkpoints`: contains the last checkpoint for the Flow++ variational dequantization model. 
78 | 79 | ## References 80 | 81 | If you find the code useful for your research, please consider citing 82 | ```bib 83 | @inproceedings{song2021maximum, 84 | title={Maximum Likelihood Training of Score-Based Diffusion Models}, 85 | author={Song, Yang and Durkan, Conor and Murray, Iain and Ermon, Stefano}, 86 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems}, 87 | year={2021} 88 | } 89 | ``` 90 | 91 | This work is built upon some previous papers which might also interest you: 92 | 93 | * Yang Song and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." *Proceedings of the 33rd Annual Conference on Neural Information Processing Systems*, 2019. 94 | * Yang Song and Stefano Ermon. "Improved techniques for training score-based generative models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*, 2020. 95 | * Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "Score-Based Generative Modeling through Stochastic Differential Equations". *Proceedings of the 9th International Conference on Learning Representations*, 2021. 96 | 97 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalization layers."""
import flax.linen as nn
import functools
import jax.nn.initializers as init
import jax.numpy as jnp
import jax


def get_normalization(config, conditional=False):
  """Obtain normalization modules from the config file.

  Args:
    config: config object; reads `config.model.normalization` (and
      `config.model.num_classes` for the conditional variant).
    conditional: if True, return a class-conditional normalization layer.

  Returns:
    A flax Module class (or a partial that constructs one).

  Raises:
    NotImplementedError: conditional requested for an unsupported norm.
    ValueError: unknown normalization name.
  """
  norm = config.model.normalization
  if conditional:
    if norm == 'InstanceNorm++':
      return functools.partial(ConditionalInstanceNorm2dPlus, num_classes=config.model.num_classes)
    else:
      raise NotImplementedError(f'{norm} not implemented yet.')
  else:
    if norm == 'InstanceNorm':
      return InstanceNorm2d
    elif norm == 'InstanceNorm++':
      return InstanceNorm2dPlus
    elif norm == 'VarianceNorm':
      return VarianceNorm2d
    elif norm == 'GroupNorm':
      return nn.GroupNorm
    else:
      raise ValueError('Unknown normalization: %s' % norm)


# "Invertible" group norm is currently just plain GroupNorm; two experimental
# alternatives are kept below as commented-out reference code.
IGroupNorm = nn.GroupNorm

# class IGroupNorm(nn.GroupNorm):
#   @nn.compact
#   def __call__(self, x):
#     return x

# class IGroupNorm(nn.GroupNorm):
#   """Invertible Group normalization. Inspired by ActNorm.
#   Requires data-dependent initialization.
#   """
#   logscale_factor: float = 10.
#
#   @nn.compact
#   def __call__(self, x):
#     """Applies group normalization to the input (arxiv.org/abs/1803.08494).
#
#     Args:
#       x: the input of shape N...C, where N is a batch dimension and C is a
#         channels dimensions. `...` represents an arbitrary number of extra
#         dimensions that are used to accumulate statistics over.
#
#     Returns:
#       Normalized inputs (the same shape as inputs).
#     """
#     x = jnp.asarray(x, jnp.float32)
#     if ((self.num_groups is None and self.group_size is None) or
#         (self.num_groups is not None and self.group_size is not None)):
#       raise ValueError('Either `num_groups` or `group_size` should be '
#                        'specified, but not both of them.')
#     num_groups = self.num_groups
#
#     if self.group_size is not None:
#       channels = x.shape[-1]
#       if channels % self.group_size != 0:
#         raise ValueError('Number of channels ({}) is not multiple of the '
#                          'group size ({}).'.format(channels, self.group_size))
#       num_groups = channels // self.group_size
#
#     input_shape = x.shape
#     group_shape = x.shape[:-1] + (num_groups, x.shape[-1] // num_groups)
#     x = x.reshape(group_shape)
#
#     reduction_axis = [d for d in range(x.ndim - 2)] + [x.ndim - 1]
#     mean = self.param('mean', lambda key: jnp.mean(x, axis=reduction_axis, keepdims=True))
#     x = x - mean
#
#     logs = self.param('logs',
#                       lambda key: jnp.log(jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=reduction_axis, keepdims=True)
#                                                         + self.epsilon)) / self.logscale_factor) * self.logscale_factor
#     x = x * jnp.exp(logs)
#     x = x.reshape(input_shape)
#
#     return x.astype(self.dtype)
# NOTE: end of commented-out experimental IGroupNorm variant (dead code kept
# for reference; the active implementation is the `IGroupNorm = nn.GroupNorm`
# alias above).


class VarianceNorm2d(nn.Module):
  """Variance normalization for images."""
  # Whether to add a learned per-channel bias after rescaling.
  bias: bool = False

  @staticmethod
  def scale_init(key, shape, dtype=jnp.float32):
    # Initialize the scale near 1 with a small Gaussian perturbation.
    normal_init = init.normal(0.02)
    return normal_init(key, shape, dtype=dtype) + 1.
109 | 110 | @nn.compact 111 | def __call__(self, x): 112 | variance = jnp.var(x, axis=(1, 2), keepdims=True) 113 | h = x / jnp.sqrt(variance + 1e-5) 114 | 115 | h = h * self.param('scale', VarianceNorm2d.scale_init, (1, 1, 1, x.shape[-1])) 116 | if self.bias: 117 | h = h + self.param('bias', init.zeros, (1, 1, 1, x.shape[-1])) 118 | 119 | return h 120 | 121 | 122 | class InstanceNorm2d(nn.Module): 123 | """Instance normalization for images.""" 124 | bias: bool = True 125 | 126 | @nn.compact 127 | def __call__(self, x): 128 | mean = jnp.mean(x, axis=(1, 2), keepdims=True) 129 | variance = jnp.var(x, axis=(1, 2), keepdims=True) 130 | h = (x - mean) / jnp.sqrt(variance + 1e-5) 131 | h = h * self.param('scale', init.ones, (1, 1, 1, x.shape[-1])) 132 | if self.bias: 133 | h = h + self.param('bias', init.zeros, (1, 1, 1, x.shape[-1])) 134 | 135 | return h 136 | 137 | 138 | class InstanceNorm2dPlus(nn.Module): 139 | """InstanceNorm++ as proposed in the original NCSN paper.""" 140 | bias: bool = True 141 | 142 | @staticmethod 143 | def scale_init(key, shape, dtype=jnp.float32): 144 | normal_init = init.normal(0.02) 145 | return normal_init(key, shape, dtype=dtype) + 1. 
  @nn.compact
  def __call__(self, x):
    # Per-sample, per-channel spatial means, then normalized across channels:
    # this re-injects (normalized) mean information that plain instance norm
    # removes, as in the NCSN InstanceNorm++ formulation.
    means = jnp.mean(x, axis=(1, 2))
    m = jnp.mean(means, axis=-1, keepdims=True)
    v = jnp.var(means, axis=-1, keepdims=True)
    means_plus = (means - m) / jnp.sqrt(v + 1e-5)

    h = (x - means[:, None, None, :]) / jnp.sqrt(jnp.var(x, axis=(1, 2), keepdims=True) + 1e-5)

    # alpha scales the re-injected mean path; gamma is the usual scale.
    h = h + means_plus[:, None, None, :] * self.param('alpha', InstanceNorm2dPlus.scale_init, (1, 1, 1, x.shape[-1]))
    h = h * self.param('gamma', InstanceNorm2dPlus.scale_init, (1, 1, 1, x.shape[-1]))
    if self.bias:
      h = h + self.param('beta', init.zeros, (1, 1, 1, x.shape[-1]))

    return h


class ConditionalInstanceNorm2dPlus(nn.Module):
  """Conditional InstanceNorm++ as in the original NCSN paper."""
  # Number of class labels; (gamma, alpha[, beta]) are looked up per class.
  num_classes: int = 10
  bias: bool = True

  @nn.compact
  def __call__(self, x, y):
    # Same statistics as the unconditional InstanceNorm2dPlus above.
    means = jnp.mean(x, axis=(1, 2))
    m = jnp.mean(means, axis=-1, keepdims=True)
    v = jnp.var(means, axis=-1, keepdims=True)
    means_plus = (means - m) / jnp.sqrt(v + 1e-5)
    h = (x - means[:, None, None, :]) / jnp.sqrt(jnp.var(x, axis=(1, 2), keepdims=True) + 1e-5)
    normal_init = init.normal(0.02)
    zero_init = init.zeros
    if self.bias:
      def init_embed(key, shape, dtype=jnp.float32):
        # Embedding packs (gamma, alpha, beta): gamma/alpha start near 1,
        # beta starts at 0.
        feature_size = shape[1] // 3
        normal = normal_init(
            key, (shape[0], 2 * feature_size), dtype=dtype) + 1.
        zero = zero_init(key, (shape[0], feature_size), dtype=dtype)
        return jnp.concatenate([normal, zero], axis=-1)

      embed = nn.Embed(num_embeddings=self.num_classes, features=x.shape[-1] * 3, embedding_init=init_embed)
    else:
      def init_embed(key, shape, dtype=jnp.float32):
        # Embedding packs (gamma, alpha), both starting near 1.
        return normal_init(key, shape, dtype=dtype) + 1.
      embed = nn.Embed(num_embeddings=self.num_classes, features=x.shape[-1] * 2, embedding_init=init_embed)

    if self.bias:
      gamma, alpha, beta = jnp.split(embed(y), 3, axis=-1)
      h = h + means_plus[:, None, None, :] * alpha[:, None, None, :]
      out = gamma[:, None, None, :] * h + beta[:, None, None, :]
    else:
      gamma, alpha = jnp.split(embed(y), 2, axis=-1)
      h = h + means_plus[:, None, None, :] * alpha[:, None, None, :]
      out = gamma[:, None, None, :] * h

    return out
--------------------------------------------------------------------------------
/models/layerspp.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Layers for defining NCSN++.
"""
from typing import Any, Optional, Tuple
from . import layers
from . import up_or_down_sampling
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np

# Short aliases for layer constructors defined in models/layers.py.
conv1x1 = layers.ddpm_conv1x1
conv3x3 = layers.ddpm_conv3x3
NIN = layers.NIN
default_init = layers.default_init


class GaussianFourierProjection(nn.Module):
  """Gaussian Fourier embeddings for noise levels."""
  embedding_size: int = 256
  scale: float = 1.0
  # If False (default), the random projection W is frozen via stop_gradient.
  trainable: bool = False

  @nn.compact
  def __call__(self, x):
    W = self.param('W', jax.nn.initializers.normal(stddev=self.scale), (self.embedding_size,))
    if not self.trainable:
      W = jax.lax.stop_gradient(W)
    x_proj = x[:, None] * W[None, :] * 2 * jnp.pi
    # Output has 2 * embedding_size features: [sin | cos].
    return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)


class Combine(nn.Module):
  """Combine information from skip connections."""
  # 'cat' concatenates along channels; 'sum' adds after a 1x1 projection.
  method: str = 'cat'

  @nn.compact
  def __call__(self, x, y):
    # Project x to y's channel count before combining.
    h = conv1x1(x, y.shape[-1])
    if self.method == 'cat':
      return jnp.concatenate([h, y], axis=-1)
    elif self.method == 'sum':
      return h + y
    else:
      raise ValueError(f'Method {self.method} not recognized.')


class AttnBlockpp(nn.Module):
  """Channel-wise self-attention block. Modified from DDPM."""
  # If True, divide the residual sum by sqrt(2) to preserve variance.
  skip_rescale: bool = False
  init_scale: float = 0.

  @nn.compact
  def __call__(self, x):
    B, H, W, C = x.shape
    h = nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x)
    q = NIN(C)(h)
    k = NIN(C)(h)
    v = NIN(C)(h)

    # Full spatial attention: every (h, w) position attends to every (H, W).
    w = jnp.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C) ** (-0.5))
    w = jnp.reshape(w, (B, H, W, H * W))
    w = jax.nn.softmax(w, axis=-1)
    w = jnp.reshape(w, (B, H, W, H, W))
    h = jnp.einsum('bhwHW,bHWc->bhwc', w, v)
    # Output projection initialized at scale init_scale (0 => identity at init).
    h = NIN(C, init_scale=self.init_scale)(h)
    if not self.skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)
class Upsample(nn.Module):
  # Output channels; defaults to the input channel count.
  out_ch: Optional[int] = None
  with_conv: bool = False
  # Use FIR (finite impulse response) resampling instead of nearest-neighbor.
  fir: bool = False
  # NOTE: annotation fixed from Tuple[int] (a 1-tuple) to a 4-tuple.
  fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1)

  @nn.compact
  def __call__(self, x):
    """2x spatial upsampling, optionally followed by / fused with a conv."""
    B, H, W, C = x.shape
    out_ch = self.out_ch if self.out_ch else C
    if not self.fir:
      h = jax.image.resize(x, (x.shape[0], H * 2, W * 2, C), 'nearest')
      if self.with_conv:
        h = conv3x3(h, out_ch)
    else:
      if not self.with_conv:
        h = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
      else:
        # Fused upsample + conv (StyleGAN2-style).
        h = up_or_down_sampling.Conv2d(out_ch,
                                       kernel=3,
                                       up=True,
                                       resample_kernel=self.fir_kernel,
                                       use_bias=True,
                                       kernel_init=default_init())(x)

    assert h.shape == (B, 2 * H, 2 * W, out_ch)
    return h


class Downsample(nn.Module):
  # Output channels; defaults to the input channel count.
  out_ch: Optional[int] = None
  with_conv: bool = False
  # Use FIR resampling instead of average pooling / strided conv.
  fir: bool = False
  # NOTE: annotation fixed from Tuple[int] (a 1-tuple) to a 4-tuple.
  fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1)

  @nn.compact
  def __call__(self, x):
    """2x spatial downsampling, optionally via a (fused) convolution."""
    B, H, W, C = x.shape
    out_ch = self.out_ch if self.out_ch else C
    if not self.fir:
      if self.with_conv:
        x = conv3x3(x, out_ch, stride=2)
      else:
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
    else:
      if not self.with_conv:
        x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
      else:
        # Fused downsample + conv (StyleGAN2-style).
        x = up_or_down_sampling.Conv2d(
            out_ch,
            kernel=3,
            down=True,
            resample_kernel=self.fir_kernel,
            use_bias=True,
            kernel_init=default_init())(x)

    assert x.shape == (B, H // 2, W // 2, out_ch)
    return x


class ResnetBlockDDPMpp(nn.Module):
  """ResBlock adapted from DDPM."""
  # Activation function (e.g. swish), supplied by the caller.
  act: Any
  out_ch: Optional[int] = None
  # Use a 3x3 conv (instead of a 1x1 NIN) for the channel-changing shortcut.
  conv_shortcut: bool = False
  dropout: float = 0.1
  # If True, divide the residual sum by sqrt(2) to preserve variance.
  skip_rescale: bool = False
  init_scale: float = 0.

  @nn.compact
  def __call__(self, x, temb=None, train=True):
    B, H, W, C = x.shape
    out_ch = self.out_ch if self.out_ch else C
    h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))
    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += nn.Dense(out_ch, kernel_init=default_init())(self.act(temb))[:, None, None, :]

    h = self.act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
    h = nn.Dropout(self.dropout)(h, deterministic=not train)
    # Second conv initialized at init_scale (0 => residual branch starts as 0).
    h = conv3x3(h, out_ch, init_scale=self.init_scale)
    if C != out_ch:
      if self.conv_shortcut:
        x = conv3x3(x, out_ch)
      else:
        x = NIN(out_ch)(x)

    if not self.skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)


class ResnetBlockBigGANpp(nn.Module):
  """ResBlock adapted from BigGAN."""
  # Activation function (e.g. swish), supplied by the caller.
  act: Any
  # At most one of up/down should be set; both branches resample h AND x.
  up: bool = False
  down: bool = False
  out_ch: Optional[int] = None
  dropout: float = 0.1
  fir: bool = False
  # NOTE: annotation fixed from Tuple[int] (a 1-tuple) to a 4-tuple.
  fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1)
  skip_rescale: bool = True
  init_scale: float = 0.

  @nn.compact
  def __call__(self, x, temb=None, train=True):
    B, H, W, C = x.shape
    out_ch = self.out_ch if self.out_ch else C
    h = self.act(nn.GroupNorm(num_groups=min(x.shape[-1] // 4, 32))(x))

    if self.up:
      if self.fir:
        h = up_or_down_sampling.upsample_2d(h, self.fir_kernel, factor=2)
        x = up_or_down_sampling.upsample_2d(x, self.fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
    elif self.down:
      if self.fir:
        h = up_or_down_sampling.downsample_2d(h, self.fir_kernel, factor=2)
        x = up_or_down_sampling.downsample_2d(x, self.fir_kernel, factor=2)
      else:
        h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
        x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

    h = conv3x3(h, out_ch)
    # Add bias to each feature map conditioned on the time embedding
    if temb is not None:
      h += nn.Dense(out_ch, kernel_init=default_init())(self.act(temb))[:, None, None, :]

    h = self.act(nn.GroupNorm(num_groups=min(h.shape[-1] // 4, 32))(h))
    h = nn.Dropout(self.dropout)(h, deterministic=not train)
    h = conv3x3(h, out_ch, init_scale=self.init_scale)
    # Shortcut must be projected whenever channels or resolution changed.
    if C != out_ch or self.up or self.down:
      x = conv1x1(x, out_ch)

    if not self.skip_rescale:
      return x + h
    else:
      return (x + h) / np.sqrt(2.)
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Return training and evaluation/test datasets from config files."""
import jax
import tensorflow as tf
import tensorflow_datasets as tfds


def get_data_scaler(config):
  """Data normalizer. Assume data are always in [0, 1]."""
  if config.data.centered:
    # Rescale to [-1, 1]
    return lambda x: x * 2. - 1.
  else:
    return lambda x: x


def get_data_inverse_scaler(config):
  """Inverse data normalizer."""
  if config.data.centered:
    # Rescale [-1, 1] to [0, 1]
    return lambda x: (x + 1.) / 2.
  else:
    return lambda x: x


def crop_resize(image, resolution):
  """Center-crop an image to a square and resize to the given resolution."""
  crop = tf.minimum(tf.shape(image)[0], tf.shape(image)[1])
  h, w = tf.shape(image)[0], tf.shape(image)[1]
  image = image[(h - crop) // 2:(h + crop) // 2,
          (w - crop) // 2:(w + crop) // 2]
  image = tf.image.resize(
    image,
    size=(resolution, resolution),
    antialias=True,
    method=tf.image.ResizeMethod.BICUBIC)
  return tf.cast(image, tf.uint8)


def resize_small(image, resolution):
  """Shrink an image so its shorter side equals `resolution`, keeping aspect ratio."""
  h, w = image.shape[0], image.shape[1]
  ratio = resolution / min(h, w)
  # BUG FIX: `tf.round(x, tf.int32)` passed the dtype as tf.round's `name`
  # argument; the rounded float must be cast to int32 explicitly.
  h = tf.cast(tf.round(h * ratio), tf.int32)
  w = tf.cast(tf.round(w * ratio), tf.int32)
  return tf.image.resize(image, [h, w], antialias=True)


def central_crop(image, size):
  """Crop the center of an image to the given size."""
  top = (image.shape[0] - size) // 2
  left = (image.shape[1] - size) // 2
  return tf.image.crop_to_bounding_box(image, top, left, size, size)


def get_dataset(config, additional_dim=None, uniform_dequantization=False, evaluation=False):
  """Create data loaders for training and evaluation.

  Args:
    config: A ml_collection.ConfigDict parsed from config files.
    additional_dim: An integer or `None`. If present, add one additional dimension to the output data,
      which equals the number of steps jitted together.
    uniform_dequantization: If `True`, add uniform dequantization to images.
    evaluation: If `True`, fix number of epochs to 1.

  Returns:
    train_ds, eval_ds, dataset_builder.

  Raises:
    ValueError: if the batch size is not divisible by the number of devices.
    NotImplementedError: if `config.data.dataset` is not supported.
  """
  # Compute batch size for this worker.
  batch_size = config.training.batch_size if not evaluation else config.eval.batch_size
  if batch_size % jax.device_count() != 0:
    # FIX: message previously had an unbalanced paren and read "must be divided by".
    raise ValueError(f'Batch size ({batch_size}) must be divisible by '
                     f'the number of devices ({jax.device_count()})')

  per_device_batch_size = batch_size // jax.device_count()
  # Reduce this when image resolution is too large and data pointer is stored
  shuffle_buffer_size = 10000
  prefetch_size = tf.data.experimental.AUTOTUNE
  num_epochs = None if not evaluation else 1
  # Create additional data dimension when jitting multiple steps together
  if additional_dim is None:
    batch_dims = [jax.local_device_count(), per_device_batch_size]
  else:
    batch_dims = [jax.local_device_count(), additional_dim, per_device_batch_size]

  def default_resize_op(img):
    # Shared by CIFAR10 / MNIST / SVHN: convert to float32 in [0, 1] and
    # resize to the configured square resolution.
    img = tf.image.convert_image_dtype(img, tf.float32)
    return tf.image.resize(img, [config.data.image_size, config.data.image_size], antialias=True)

  # Create dataset builders for each dataset.
  if config.data.dataset == 'CIFAR10':
    dataset_builder = tfds.builder('cifar10')
    train_split_name = 'train'
    eval_split_name = 'test'
    resize_op = default_resize_op

  elif config.data.dataset == 'MNIST':
    dataset_builder = tfds.builder('mnist')
    train_split_name = 'train'
    eval_split_name = 'test'
    resize_op = default_resize_op

  elif config.data.dataset == 'SVHN':
    dataset_builder = tfds.builder('svhn_cropped')
    train_split_name = 'train'
    eval_split_name = 'test'
    resize_op = default_resize_op

  elif config.data.dataset == 'CELEBA':
    dataset_builder = tfds.builder('celeb_a')
    train_split_name = 'train'
    eval_split_name = 'validation'

    def resize_op(img):
      img = tf.image.convert_image_dtype(img, tf.float32)
      img = central_crop(img, 140)
      img = resize_small(img, config.data.image_size)
      return img

  elif config.data.dataset == 'ImageNet':
    size = {
      32: '32x32',
      64: '64x64'
    }[config.data.image_size]
    dataset_builder = tfds.builder(f'downsampled_imagenet/{size}')
    train_split_name = 'train'
    eval_split_name = 'validation'

    def resize_op(img):
      # Already stored at the target resolution; only convert dtype.
      img = tf.image.convert_image_dtype(img, tf.float32)
      return img

  elif config.data.dataset in ('FFHQ', 'CelebAHQ', 'LSUN'):
    dataset_builder = tf.data.TFRecordDataset(config.data.tfrecords_path)
    train_split_name = eval_split_name = 'train'

  else:
    raise NotImplementedError(
      f'Dataset {config.data.dataset} not yet supported.')

  # Customize preprocess functions for each dataset.
  if config.data.dataset in ('FFHQ', 'CelebAHQ', 'LSUN'):
    def preprocess_fn(d):
      sample = tf.io.parse_single_example(d, features={
        'shape': tf.io.FixedLenFeature([3], tf.int64),
        'data': tf.io.FixedLenFeature([], tf.string)})
      data = tf.io.decode_raw(sample['data'], tf.uint8)
      data = tf.reshape(data, sample['shape'])
      data = tf.transpose(data, (1, 2, 0))  # CHW -> HWC
      img = tf.image.convert_image_dtype(data, tf.float32)
      if config.data.random_flip and not evaluation:
        img = tf.image.random_flip_left_right(img)
      if uniform_dequantization:
        # NOTE(review): `img.shape` may be dynamic after the reshape above;
        # this mirrors the original behavior — confirm shapes are static here.
        img = (tf.random.uniform(img.shape, dtype=tf.float32) + img * 255.) / 256.
      return dict(image=img, label=None)

  else:
    def preprocess_fn(d):
      """Basic preprocessing function scales data to [0, 1) and randomly flips."""
      img = resize_op(d['image'])
      if config.data.random_flip and not evaluation:
        img = tf.image.random_flip_left_right(img)
      if uniform_dequantization:
        # Use the statically-known config shape (img.shape can be dynamic).
        img = (tf.random.uniform((config.data.image_size, config.data.image_size,
                                  config.data.num_channels), dtype=tf.float32) + img * 255.) / 256.

      return dict(image=img, label=d.get('label', None))

  def create_dataset(dataset_builder, split):
    dataset_options = tf.data.Options()
    dataset_options.experimental_optimization.map_parallelization = True
    dataset_options.experimental_threading.private_threadpool_size = 48
    dataset_options.experimental_threading.max_intra_op_parallelism = 1
    read_config = tfds.ReadConfig(options=dataset_options)
    if isinstance(dataset_builder, tfds.core.DatasetBuilder):
      dataset_builder.download_and_prepare()
      ds = dataset_builder.as_dataset(
        split=split, shuffle_files=True, read_config=read_config)
    else:
      ds = dataset_builder.with_options(dataset_options)
    ds = ds.repeat(count=num_epochs)
    ds = ds.shuffle(shuffle_buffer_size)
    ds = ds.map(preprocess_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Nest batching: innermost dim first (avoid shadowing `batch_size` above).
    for dim in reversed(batch_dims):
      ds = ds.batch(dim, drop_remainder=True)
    return ds.prefetch(prefetch_size)

  train_ds = create_dataset(dataset_builder, train_split_name)
  eval_ds = create_dataset(dataset_builder, eval_split_name)
  return train_ds, eval_ds, dataset_builder
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""

import jax
import jax.numpy as jnp
import numpy as np

import utils
from utils import batch_mul
from models import utils as mutils
from utils import get_div_fn, get_value_div_fn


def get_likelihood_bound_fn(sde, model, inverse_scaler, hutchinson_type='Rademacher',
                            dsm=True, eps=1e-5, N=1000, importance_weighting=True,
                            eps_offset=True):
  """Create a function to compute the unbiased log-likelihood bound of a given data point.

  Args:
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    model: A `flax.linen.Module` object that represents the architecture of the score-based model.
    inverse_scaler: The inverse data normalizer.
    hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
    dsm: bool. Use denoising score matching bound if enabled; otherwise use sliced score matching.
    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.
    N: The number of time values to be sampled.
    importance_weighting: True if enable importance weighting for potential variance reduction.
    eps_offset: True if use Jensen's inequality to offset the likelihood bound due to non-zero starting time.

  Returns:
    A function that takes random states, replicated training states, and a batch of data points
    and returns the log-likelihoods in bits/dim, the latent code, and the number of function
    evaluations cost by computation.
  """

  def value_div_score_fn(state, x, t, eps):
    """Pmapped divergence of the drift function."""
    score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)
    value_div_fn = get_value_div_fn(lambda x, t: score_fn(x, t))
    return value_div_fn(x, t, eps)

  def div_drift_fn(x, t, eps):
    # Hutchinson estimate of the divergence of the SDE drift.
    div_fn = get_div_fn(lambda x, t: sde.sde(x, t)[0])
    return div_fn(x, t, eps)

  def sample_hutchinson_noise(rng, shape):
    """Draw the Hutchinson probe vector; raises for unknown noise types."""
    if hutchinson_type == 'Gaussian':
      return jax.random.normal(rng, shape)
    elif hutchinson_type == 'Rademacher':
      return jax.random.rademacher(rng, shape, dtype=jnp.float32)
    else:
      raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

  def likelihood_bound_fn(prng, state, data):
    """Compute an unbiased estimate to the log-likelihood in bits/dim.

    Args:
      prng: An array of random states. The list dimension equals the number of devices.
      state: Replicated training state for running on multiple devices.
      data: A JAX array of shape [#devices, batch size, ...].

    Returns:
      bpd: A JAX array of shape [#devices, batch size]. The log-likelihoods on `data` in bits/dim.
      N: same as input
    """
    rng, step_rng = jax.random.split(prng)
    if importance_weighting:
      time_samples = sde.sample_importance_weighted_time_for_likelihood(step_rng, (N, data.shape[0]), eps=eps)
      Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps)
    else:
      time_samples = jax.random.uniform(step_rng, (N, data.shape[0]), minval=eps, maxval=sde.T)
      Z = 1

    shape = data.shape
    if not dsm:
      # Sliced score matching bound.
      def scan_fn(carry, vec_time):
        rng, value = carry
        rng, step_rng = jax.random.split(rng)
        epsilon = sample_hutchinson_noise(step_rng, shape)

        rng, step_rng = jax.random.split(rng)
        noise = jax.random.normal(step_rng, shape)
        mean, std = sde.marginal_prob(data, vec_time)
        noisy_data = mean + utils.batch_mul(std, noise)
        score_val, score_div = value_div_score_fn(state, noisy_data, vec_time, epsilon)
        score_norm = jnp.square(score_val.reshape((score_val.shape[0], -1))).sum(axis=-1)
        drift_div = div_drift_fn(noisy_data, vec_time, epsilon)
        _, g = sde.sde(noisy_data, vec_time)  # drift unused; only diffusion g is needed
        integrand = utils.batch_mul(g ** 2, 2 * score_div + score_norm) - 2 * drift_div
        if importance_weighting:
          integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
        return (rng, value + integrand), integrand
    else:
      # Denoising score matching bound.
      score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)

      def scan_fn(carry, vec_time):
        rng, value = carry
        rng, step_rng = jax.random.split(rng)
        epsilon = sample_hutchinson_noise(step_rng, shape)
        rng, step_rng = jax.random.split(rng)
        noise = jax.random.normal(step_rng, shape)
        mean, std = sde.marginal_prob(data, vec_time)
        noisy_data = mean + utils.batch_mul(std, noise)
        drift_div = div_drift_fn(noisy_data, vec_time, epsilon)
        score_val = score_fn(noisy_data, vec_time)
        # Score of the Gaussian perturbation kernel N(mean, std^2).
        grad = utils.batch_mul(-(noisy_data - mean), 1 / std ** 2)
        diff1 = score_val - grad
        diff1 = jnp.square(diff1.reshape((diff1.shape[0], -1))).sum(axis=-1)
        diff2 = jnp.square(grad.reshape((grad.shape[0], -1))).sum(axis=-1)
        _, g = sde.sde(noisy_data, vec_time)  # drift unused; only diffusion g is needed
        integrand = utils.batch_mul(g ** 2, diff1 - diff2) - 2 * drift_div
        if importance_weighting:
          integrand = utils.batch_mul(std ** 2 / g ** 2 * Z, integrand)
        return (rng, value + integrand), integrand

    # Monte-Carlo average of the integrand over N sampled times.
    (rng, integral), _ = jax.lax.scan(scan_fn, (rng, jnp.zeros((shape[0],))), time_samples)
    integral = integral / N
    mean, std = sde.marginal_prob(data, jnp.ones((data.shape[0],)) * sde.T)
    rng, step_rng = jax.random.split(rng)
    noise = jax.random.normal(step_rng, shape)
    neg_prior_logp = -sde.prior_logp(mean + utils.batch_mul(std, noise))
    nlogp = neg_prior_logp + 0.5 * integral

    # whether to enable likelihood offset
    if eps_offset:
      score_fn = mutils.get_score_fn(sde, model, state.params_ema, state.model_state, train=False, continuous=True)
      offset_fn = get_likelihood_offset_fn(sde, score_fn, eps)
      rng, step_rng = jax.random.split(rng)
      nlogp = nlogp + offset_fn(step_rng, data)

    bpd = nlogp / np.log(2)
    dim = np.prod(shape[1:])
    bpd = bpd / dim

    # A hack to convert log-likelihoods to bits/dim
    # based on the gradient of the inverse data normalizer.
    offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8.
    bpd += offset

    return bpd, N

  return jax.pmap(likelihood_bound_fn, axis_name='batch')


def get_likelihood_offset_fn(sde, score_fn, eps=1e-5):
  """Create a function to compute the likelihood offset at the non-zero starting time `eps`."""

  def likelihood_offset_fn(prng, data):
    """Compute the Jensen-inequality offset term for the likelihood bound.

    Args:
      prng: A JAX random state.
      data: A JAX array of shape [batch size, ...].

    Returns:
      offset: A JAX array of shape [batch size] — the nats to add to the bound.
    """
    rng, step_rng = jax.random.split(prng)
    shape = data.shape

    eps_vec = jnp.full((shape[0],), eps)
    p_mean, p_std = sde.marginal_prob(data, eps_vec)
    rng, step_rng = jax.random.split(rng)
    noisy_data = p_mean + batch_mul(p_std, jax.random.normal(step_rng, shape))
    score = score_fn(noisy_data, eps_vec)

    # Tweedie-style reconstruction distribution q(x | x_eps).
    alpha, beta = sde.marginal_prob(jnp.ones_like(data), eps_vec)
    q_mean = noisy_data / alpha + batch_mul(beta ** 2, score / alpha)
    q_std = beta / jnp.mean(alpha, axis=(1, 2, 3))

    n_dim = np.prod(data.shape[1:])
    p_entropy = n_dim / 2. * (np.log(2 * np.pi) + 2 * jnp.log(p_std) + 1.)
    q_recon = n_dim / 2. * (np.log(2 * np.pi) + 2 * jnp.log(q_std)) + batch_mul(0.5 / (q_std ** 2),
                                                                                jnp.square(data - q_mean).sum(
                                                                                  axis=(1, 2, 3)))
    offset = q_recon - p_entropy
    return offset

  return likelihood_offset_fn
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file

from . import utils, layers, layerspp, normalization
import flax.linen as nn
import functools
import jax.numpy as jnp
import numpy as np
import ml_collections

ResnetBlockDDPM = layerspp.ResnetBlockDDPMpp
ResnetBlockBigGAN = layerspp.ResnetBlockBigGANpp
Combine = layerspp.Combine
conv3x3 = layerspp.conv3x3
conv1x1 = layerspp.conv1x1
get_act = layers.get_act
get_normalization = normalization.get_normalization
default_initializer = layers.default_init


@utils.register_model(name='ncsnpp')
class NCSNpp(nn.Module):
  """NCSN++ model"""
  config: ml_collections.ConfigDict

  @nn.compact
  def __call__(self, x, time_cond, train=True):
    # --- Config parsing ---
    config = self.config
    act = get_act(config)
    sigmas = utils.get_sigmas(config)  # kept for parity with the original model

    nf = config.model.nf
    ch_mult = config.model.ch_mult
    num_res_blocks = config.model.num_res_blocks
    attn_resolutions = config.model.attn_resolutions
    dropout = config.model.dropout
    resamp_with_conv = config.model.resamp_with_conv
    num_resolutions = len(ch_mult)

    conditional = config.model.conditional  # noise-conditional
    fir = config.model.fir
    fir_kernel = config.model.fir_kernel
    skip_rescale = config.model.skip_rescale
    resblock_type = config.model.resblock_type.lower()
    progressive = config.model.progressive.lower()
    progressive_input = config.model.progressive_input.lower()
    embedding_type = config.model.embedding_type.lower()
    init_scale = config.model.init_scale
    assert progressive in ['none', 'output_skip', 'residual']
    assert progressive_input in ['none', 'input_skip', 'residual']
    assert embedding_type in ['fourier', 'positional']
    combine_method = config.model.progressive_combine.lower()
    combiner = functools.partial(Combine, method=combine_method)
    sde = config.training.sde

    def norm_act(z):
      # GroupNorm followed by the configured activation; each call creates a
      # fresh GroupNorm module (this is inside @nn.compact, same as inlining).
      return act(nn.GroupNorm(num_groups=min(z.shape[-1] // 4, 32))(z))

    # --- Timestep / noise-level embedding (continuous training only) ---
    if embedding_type == 'fourier':
      # Gaussian Fourier features embeddings.
      assert config.training.continuous, "Fourier features are only used for continuous training."
      if sde.lower() == 'vesde':
        used_sigmas = time_cond
        temb = layerspp.GaussianFourierProjection(
          embedding_size=nf,
          scale=config.model.fourier_scale,
          trainable=config.model.trainable_embedding
        )(jnp.log(used_sigmas))
      else:
        temb = layerspp.GaussianFourierProjection(
          embedding_size=nf,
          scale=config.model.fourier_scale,
          trainable=config.model.trainable_embedding
        )(time_cond)

    elif embedding_type == 'positional':
      # Sinusoidal positional embeddings.
      assert sde.lower() != 'vesde'
      timesteps = time_cond
      temb = layers.get_timestep_embedding(timesteps, nf)

    else:
      raise ValueError(f'embedding type {embedding_type} unknown.')

    if conditional:
      temb = nn.Dense(nf * 4, kernel_init=default_initializer())(temb)
      temb = nn.Dense(nf * 4, kernel_init=default_initializer())(act(temb))
    else:
      temb = None

    # --- Building-block factories ---
    AttnBlock = functools.partial(layerspp.AttnBlockpp,
                                  init_scale=init_scale,
                                  skip_rescale=skip_rescale)

    Upsample = functools.partial(layerspp.Upsample,
                                 with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    if progressive == 'output_skip':
      pyramid_upsample = functools.partial(layerspp.Upsample,
                                           fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive == 'residual':
      pyramid_upsample = functools.partial(layerspp.Upsample,
                                           fir=fir, fir_kernel=fir_kernel, with_conv=True)

    Downsample = functools.partial(layerspp.Downsample,
                                   with_conv=resamp_with_conv, fir=fir, fir_kernel=fir_kernel)

    if progressive_input == 'input_skip':
      pyramid_downsample = functools.partial(layerspp.Downsample,
                                             fir=fir, fir_kernel=fir_kernel, with_conv=False)
    elif progressive_input == 'residual':
      pyramid_downsample = functools.partial(layerspp.Downsample,
                                             fir=fir, fir_kernel=fir_kernel, with_conv=True)

    if resblock_type == 'ddpm':
      ResnetBlock = functools.partial(ResnetBlockDDPM,
                                      act=act,
                                      dropout=dropout,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale)
    elif resblock_type == 'biggan':
      ResnetBlock = functools.partial(ResnetBlockBigGAN,
                                      act=act,
                                      dropout=dropout,
                                      fir=fir,
                                      fir_kernel=fir_kernel,
                                      init_scale=init_scale,
                                      skip_rescale=skip_rescale)
    else:
      raise ValueError(f'resblock type {resblock_type} unrecognized.')

    if not config.data.centered:
      # If input data is in [0, 1], rescale it to [-1, 1].
      x = 2 * x - 1.

    # --- Downsampling path ---
    input_pyramid = None
    if progressive_input != 'none':
      input_pyramid = x

    skips = [conv3x3(x, nf)]
    for level in range(num_resolutions):
      # Residual blocks for this resolution
      for block in range(num_res_blocks):
        h = ResnetBlock(out_ch=nf * ch_mult[level])(skips[-1], temb, train)
        if h.shape[1] in attn_resolutions:
          h = AttnBlock()(h)
        skips.append(h)

      if level != num_resolutions - 1:
        if resblock_type == 'ddpm':
          h = Downsample()(skips[-1])
        else:
          h = ResnetBlock(down=True)(skips[-1], temb, train)

        if progressive_input == 'input_skip':
          input_pyramid = pyramid_downsample()(input_pyramid)
          h = combiner()(input_pyramid, h)
        elif progressive_input == 'residual':
          input_pyramid = pyramid_downsample(out_ch=h.shape[-1])(input_pyramid)
          if skip_rescale:
            input_pyramid = (input_pyramid + h) / np.sqrt(2.)
          else:
            input_pyramid = input_pyramid + h
          h = input_pyramid

        skips.append(h)

    # --- Bottleneck ---
    h = skips[-1]
    h = ResnetBlock()(h, temb, train)
    h = AttnBlock()(h)
    h = ResnetBlock()(h, temb, train)

    pyramid = None

    # --- Upsampling path ---
    for level in reversed(range(num_resolutions)):
      for block in range(num_res_blocks + 1):
        h = ResnetBlock(out_ch=nf * ch_mult[level])(jnp.concatenate([h, skips.pop()], axis=-1),
                                                    temb,
                                                    train)

      if h.shape[1] in attn_resolutions:
        h = AttnBlock()(h)

      if progressive != 'none':
        if level == num_resolutions - 1:
          # Initialize the output pyramid at the coarsest resolution.
          if progressive == 'output_skip':
            pyramid = conv3x3(norm_act(h), x.shape[-1], bias=True, init_scale=init_scale)
          elif progressive == 'residual':
            pyramid = conv3x3(norm_act(h), h.shape[-1], bias=True)
          else:
            raise ValueError(f'{progressive} is not a valid name.')
        else:
          if progressive == 'output_skip':
            pyramid = pyramid_upsample()(pyramid)
            pyramid = pyramid + conv3x3(norm_act(h), x.shape[-1], bias=True, init_scale=init_scale)
          elif progressive == 'residual':
            pyramid = pyramid_upsample(out_ch=h.shape[-1])(pyramid)
            if skip_rescale:
              pyramid = (pyramid + h) / np.sqrt(2.)
            else:
              pyramid = pyramid + h
            h = pyramid
          else:
            raise ValueError(f'{progressive} is not a valid name')

      if level != 0:
        if resblock_type == 'ddpm':
          h = Upsample()(h)
        else:
          h = ResnetBlock(up=True)(h, temb, train)

    assert not skips

    # --- Output head ---
    if progressive == 'output_skip':
      h = pyramid
    else:
      h = norm_act(h)
      h = conv3x3(h, x.shape[-1], init_scale=init_scale)

    if config.model.scale_by_sigma:
      assert config.training.sde in ('vesde', 'linearvesde')
      used_sigmas = used_sigmas.reshape((x.shape[0], *([1] * len(x.shape[1:]))))
      h = h / used_sigmas

    return h
17 | """ 18 | from typing import Any 19 | 20 | import flax 21 | import functools 22 | import jax.numpy as jnp 23 | 24 | import datasets 25 | import sde_lib 26 | import jax 27 | import numpy as np 28 | from flax.training import checkpoints 29 | from utils import batch_mul 30 | 31 | 32 | # The dataclass that stores all training states 33 | @flax.struct.dataclass 34 | class State: 35 | step: int 36 | optimizer: flax.optim.Optimizer 37 | lr: float 38 | model_state: Any 39 | ema_rate: float 40 | params_ema: Any 41 | rng: Any 42 | 43 | 44 | @flax.struct.dataclass 45 | class DeqState: 46 | step: int 47 | optimizer: flax.optim.Optimizer 48 | lr: float 49 | ema_rate: float 50 | params_ema: Any 51 | ema_train_bpd: float 52 | ema_eval_bpd: float 53 | rng: Any 54 | 55 | 56 | _MODELS = {} 57 | 58 | 59 | def register_model(cls=None, *, name=None): 60 | """A decorator for registering model classes.""" 61 | 62 | def _register(cls): 63 | if name is None: 64 | local_name = cls.__name__ 65 | else: 66 | local_name = name 67 | if local_name in _MODELS: 68 | raise ValueError(f'Already registered model with name: {local_name}') 69 | _MODELS[local_name] = cls 70 | return cls 71 | 72 | if cls is None: 73 | return _register 74 | else: 75 | return _register(cls) 76 | 77 | 78 | def get_model(name): 79 | return _MODELS[name] 80 | 81 | 82 | def get_sigmas(config): 83 | """Get sigmas --- the set of noise levels for SMLD from config files. 
84 | Args: 85 | config: A ConfigDict object parsed from the config file 86 | Returns: 87 | sigmas: a jax numpy arrary of noise levels 88 | """ 89 | 90 | if config.training.sde.lower() == 'linearvesde': 91 | sigmas = jnp.sqrt(jnp.linspace(config.model.sigma_max ** 2, config.model.sigma_min ** 2, 92 | config.model.num_scales)) 93 | else: 94 | sigmas = jnp.exp( 95 | jnp.linspace( 96 | jnp.log(config.model.sigma_max), jnp.log(config.model.sigma_min), 97 | config.model.num_scales)) 98 | return sigmas 99 | 100 | 101 | def get_ddpm_params(config): 102 | """Get betas and alphas --- parameters used in the original DDPM paper.""" 103 | num_diffusion_timesteps = 1000 104 | # parameters need to be adapted if number of time steps differs from 1000 105 | beta_start = config.model.beta_min / config.model.num_scales 106 | beta_end = config.model.beta_max / config.model.num_scales 107 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 108 | 109 | alphas = 1. - betas 110 | alphas_cumprod = np.cumprod(alphas, axis=0) 111 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 112 | sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod) 113 | 114 | return { 115 | 'betas': betas, 116 | 'alphas': alphas, 117 | 'alphas_cumprod': alphas_cumprod, 118 | 'sqrt_alphas_cumprod': sqrt_alphas_cumprod, 119 | 'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod, 120 | 'beta_min': beta_start * (num_diffusion_timesteps - 1), 121 | 'beta_max': beta_end * (num_diffusion_timesteps - 1), 122 | 'num_diffusion_timesteps': num_diffusion_timesteps 123 | } 124 | 125 | 126 | def init_model(rng, config, data=None, label=None): 127 | """ Initialize a `flax.linen.Module` model. 
""" 128 | model_name = config.model.name 129 | model_def = functools.partial(get_model(model_name), config=config) 130 | input_shape = (config.training.batch_size // jax.local_device_count(), 131 | config.data.image_size, config.data.image_size, config.data.num_channels) 132 | label_shape = input_shape[:1] 133 | if data is None: 134 | init_input = jnp.zeros(input_shape) 135 | else: 136 | init_input = data 137 | if label is None: 138 | init_label = jnp.zeros(label_shape, dtype=jnp.int32) 139 | else: 140 | init_label = label 141 | params_rng, dropout_rng = jax.random.split(rng) 142 | model = model_def() 143 | variables = model.init({'params': params_rng, 'dropout': dropout_rng}, init_input, init_label) 144 | # Variables is a `flax.FrozenDict`. It is immutable and respects functional programming 145 | init_model_state, initial_params = variables.pop('params') 146 | return model, init_model_state, initial_params 147 | 148 | 149 | def data_dependent_init_of_dequantizer(rng, config, init_data): 150 | if config.data.dataset == 'ImageNet': 151 | if config.data.image_size == 32: 152 | from .flowpp import dequantization_imagenet32 153 | model = dequantization_imagenet32.Dequantization() 154 | elif config.data.image_size == 64: 155 | from .flowpp import dequantization_imagenet64 156 | model = dequantization_imagenet64.Dequantization() 157 | elif config.data.dataset == 'CIFAR10': 158 | from .flowpp import dequantization_cifar10 159 | model = dequantization_cifar10.Dequantization() 160 | 161 | rng, step_rng = jax.random.split(rng) 162 | u = jax.random.normal(step_rng, init_data.shape) 163 | 164 | @functools.partial(jax.pmap, axis_name='batch') 165 | def init_func(params_rng, dropout_rng, eps, data): 166 | return model.init({'params': params_rng, 'dropout': dropout_rng}, eps, data, inverse=False, train=False) 167 | 168 | rng, *params_rng = jax.random.split(rng, jax.local_device_count() + 1) 169 | params_rng = jnp.asarray(params_rng) 170 | rng, *dropout_rng = 
jax.random.split(rng, jax.local_device_count() + 1) 171 | dropout_rng = jnp.asarray(dropout_rng) 172 | variables = flax.jax_utils.unreplicate(init_func(params_rng, dropout_rng, u, init_data)) 173 | return model, variables 174 | 175 | 176 | def get_dequantizer(model, variables, train=False): 177 | def dequantizer(u, x, rng=None): 178 | if not train: 179 | u_deq, sldj = model.apply(variables, u, x, train=train, inverse=False) 180 | else: 181 | u_deq, sldj = model.apply(variables, u, x, train=train, inverse=False, rngs={'dropout': rng}) 182 | 183 | return u_deq, sldj 184 | 185 | return dequantizer 186 | 187 | 188 | def get_model_fn(model, params, states, train=False): 189 | """Create a function to give the output of the score-based model. 190 | 191 | Args: 192 | model: A `flax.linen.Module` object that represents the architecture of the score-based model. 193 | params: A dictionary that contains all trainable parameters. 194 | states: A dictionary that contains all mutable states. 195 | train: `True` for training and `False` for evaluation. 196 | 197 | Returns: 198 | A model function. 199 | """ 200 | 201 | def model_fn(x, labels, rng=None): 202 | """Compute the output of the score-based model. 203 | 204 | Args: 205 | x: A mini-batch of input data. 206 | labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently 207 | for different models. 
208 | rng: If present, it is the random state for dropout 209 | 210 | Returns: 211 | A tuple of (model output, new mutable states) 212 | """ 213 | variables = {'params': params, **states} 214 | if not train: 215 | return model.apply(variables, x, labels, train=False, mutable=False), states 216 | else: 217 | rngs = {'dropout': rng} 218 | return model.apply(variables, x, labels, train=True, mutable=list(states.keys()), rngs=rngs) 219 | # if states: 220 | # return outputs 221 | # else: 222 | # return outputs, states 223 | 224 | return model_fn 225 | 226 | 227 | def get_score_fn(sde, model, params, states, train=False, continuous=False, return_state=False): 228 | """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function. 229 | 230 | Args: 231 | sde: An `sde_lib.SDE` object that represents the forward SDE. 232 | model: A `flax.linen.Module` object that represents the architecture of the score-based model. 233 | params: A dictionary that contains all trainable parameters. 234 | states: A dictionary that contains all other mutable parameters. 235 | train: `True` for training and `False` for evaluation. 236 | continuous: If `True`, the score-based model is expected to directly take continuous time steps. 237 | return_state: If `True`, return the new mutable states alongside the model output. 238 | 239 | Returns: 240 | A score function. 241 | """ 242 | model_fn = get_model_fn(model, params, states, train=train) 243 | 244 | if isinstance(sde, sde_lib.VPSDE) or isinstance(sde, sde_lib.subVPSDE): 245 | def score_fn(x, t, rng=None): 246 | # Scale neural network output by standard deviation and flip sign 247 | if continuous or isinstance(sde, sde_lib.subVPSDE): 248 | # For VP-trained models, t=0 corresponds to the lowest noise level 249 | # The maximum value of time embedding is assumed to 999 for 250 | # continuously-trained models. 
251 | labels = t * 999 252 | model, state = model_fn(x, labels, rng) 253 | std = sde.marginal_prob(jnp.zeros_like(x), t)[1] 254 | else: 255 | # For VP-trained models, t=0 corresponds to the lowest noise level 256 | labels = t * (sde.N - 1) 257 | model, state = model_fn(x, labels, rng) 258 | std = sde.sqrt_1m_alphas_cumprod[labels.astype(jnp.int32)] 259 | 260 | score = batch_mul(-model, 1. / std) 261 | if return_state: 262 | return score, state 263 | else: 264 | return score 265 | 266 | elif isinstance(sde, sde_lib.VESDE): 267 | def score_fn(x, t, rng=None): 268 | if sde.linear is False: 269 | if continuous: 270 | labels = sde.marginal_prob(jnp.zeros_like(x), t)[1] 271 | else: 272 | # For VE-trained models, t=0 corresponds to the highest noise level 273 | labels = sde.T - t 274 | labels *= sde.N - 1 275 | labels = jnp.round(labels).astype(jnp.int32) 276 | 277 | score, state = model_fn(x, labels, rng) 278 | else: 279 | assert continuous 280 | labels = t * 999 281 | model, state = model_fn(x, labels, rng) 282 | std = sde.marginal_prob(jnp.zeros_like(x), t)[1] 283 | score = batch_mul(-model, 1. 
/ std) 284 | 285 | if return_state: 286 | return score, state 287 | else: 288 | return score 289 | 290 | else: 291 | raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.") 292 | 293 | return score_fn 294 | 295 | 296 | def to_flattened_numpy(x): 297 | """Flatten a JAX array `x` and convert it to numpy.""" 298 | return np.asarray(x.reshape((-1,)), dtype=np.float64) 299 | 300 | 301 | def from_flattened_numpy(x, shape): 302 | """Form a JAX array with the given `shape` from a flattened numpy array `x`.""" 303 | return jnp.asarray(x, dtype=jnp.float32).reshape(shape) -------------------------------------------------------------------------------- /sde_lib.py: -------------------------------------------------------------------------------- 1 | """Abstract SDE classes, Reverse SDE, and VE/VP SDEs.""" 2 | import abc 3 | import jax.numpy as jnp 4 | import jax 5 | import numpy as np 6 | from utils import batch_mul 7 | 8 | 9 | class SDE(abc.ABC): 10 | """SDE abstract class. Functions are designed for a mini-batch of inputs.""" 11 | 12 | def __init__(self, N): 13 | """Construct an SDE. 14 | 15 | Args: 16 | N: number of discretization time steps. 17 | """ 18 | super().__init__() 19 | self.N = N 20 | 21 | @property 22 | @abc.abstractmethod 23 | def T(self): 24 | """End time of the SDE.""" 25 | pass 26 | 27 | @abc.abstractmethod 28 | def sde(self, x, t): 29 | pass 30 | 31 | @abc.abstractmethod 32 | def marginal_prob(self, x, t): 33 | """Parameters to determine the marginal distribution of the SDE, $p_t(x)$.""" 34 | pass 35 | 36 | @abc.abstractmethod 37 | def prior_sampling(self, rng, shape): 38 | """Generate one sample from the prior distribution, $p_T(x)$.""" 39 | pass 40 | 41 | @abc.abstractmethod 42 | def prior_logp(self, z): 43 | """Compute log-density of the prior distribution. 44 | 45 | Useful for computing the log-likelihood via probability flow ODE. 
46 | 47 | Args: 48 | z: latent code 49 | Returns: 50 | log probability density 51 | """ 52 | pass 53 | 54 | def discretize(self, x, t): 55 | """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i. 56 | 57 | Useful for reverse diffusion sampling and probability flow sampling. 58 | Defaults to Euler-Maruyama discretization. 59 | 60 | Args: 61 | x: a JAX tensor. 62 | t: a JAX float representing the time step (from 0 to `self.T`) 63 | 64 | Returns: 65 | f, G 66 | """ 67 | dt = 1 / self.N 68 | drift, diffusion = self.sde(x, t) 69 | f = drift * dt 70 | G = diffusion * jnp.sqrt(dt) 71 | return f, G 72 | 73 | def reverse(self, score_fn, probability_flow=False): 74 | """Create the reverse-time SDE/ODE. 75 | 76 | Args: 77 | score_fn: A time-dependent score-based model that takes x and t and returns the score. 78 | probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling. 79 | """ 80 | N = self.N 81 | T = self.T 82 | sde_fn = self.sde 83 | discretize_fn = self.discretize 84 | 85 | # Build the class for reverse-time SDE. 86 | class RSDE(self.__class__): 87 | def __init__(self): 88 | self.N = N 89 | self.probability_flow = probability_flow 90 | 91 | @property 92 | def T(self): 93 | return T 94 | 95 | def sde(self, x, t): 96 | """Create the drift and diffusion functions for the reverse SDE/ODE.""" 97 | drift, diffusion = sde_fn(x, t) 98 | score = score_fn(x, t) 99 | drift = drift - batch_mul(diffusion ** 2, score * (0.5 if self.probability_flow else 1.)) 100 | # Set the diffusion function to zero for ODEs. 
101 | diffusion = jnp.zeros_like(t) if self.probability_flow else diffusion 102 | return drift, diffusion 103 | 104 | def discretize(self, x, t): 105 | """Create discretized iteration rules for the reverse diffusion sampler.""" 106 | f, G = discretize_fn(x, t) 107 | rev_f = f - batch_mul(G ** 2, score_fn(x, t) * (0.5 if self.probability_flow else 1.)) 108 | rev_G = jnp.zeros_like(t) if self.probability_flow else G 109 | return rev_f, rev_G 110 | 111 | return RSDE() 112 | 113 | 114 | class VPSDE(SDE): 115 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 116 | """Construct a Variance Preserving SDE. 117 | 118 | Args: 119 | beta_min: value of beta(0) 120 | beta_max: value of beta(1) 121 | N: number of discretization steps 122 | """ 123 | super().__init__(N) 124 | self.beta_0 = beta_min 125 | self.beta_1 = beta_max 126 | self.N = N 127 | self.discrete_betas = jnp.linspace(beta_min / N, beta_max / N, N) 128 | self.alphas = 1. - self.discrete_betas 129 | self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0) 130 | self.sqrt_alphas_cumprod = jnp.sqrt(self.alphas_cumprod) 131 | self.sqrt_1m_alphas_cumprod = jnp.sqrt(1. - self.alphas_cumprod) 132 | 133 | @property 134 | def T(self): 135 | return 1 136 | 137 | def sde(self, x, t): 138 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 139 | drift = -0.5 * batch_mul(beta_t, x) 140 | diffusion = jnp.sqrt(beta_t) 141 | return drift, diffusion 142 | 143 | def marginal_prob(self, x, t, high_precision=True): 144 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 145 | if high_precision: 146 | mean = batch_mul(jnp.where(jnp.abs(log_mean_coeff) <= 1e-3, 1 + log_mean_coeff, jnp.exp(log_mean_coeff)), x) 147 | std = jnp.where(jnp.abs(log_mean_coeff) <= 1e-3, jnp.sqrt(-2. * log_mean_coeff), 148 | jnp.sqrt(1 - jnp.exp(2. 
* log_mean_coeff))) 149 | else: 150 | mean = batch_mul(jnp.exp(log_mean_coeff), x) 151 | std = jnp.sqrt(1 - jnp.exp(2 * log_mean_coeff)) 152 | return mean, std 153 | 154 | def prior_sampling(self, rng, shape): 155 | return jax.random.normal(rng, shape) 156 | 157 | def prior_logp(self, z): 158 | shape = z.shape 159 | N = np.prod(shape[1:]) 160 | logp_fn = lambda z: -N / 2. * jnp.log(2 * np.pi) - jnp.sum(z ** 2) / 2. 161 | return jax.vmap(logp_fn)(z) 162 | 163 | def prior_entropy(self, z): 164 | shape = z.shape 165 | entropy = jnp.ones(shape) * (0.5 * jnp.log(2 * np.pi) + 0.5) 166 | entropy = entropy.reshape((z.shape[0], -1)) 167 | return jnp.sum(entropy, axis=-1) 168 | 169 | def discretize(self, x, t): 170 | """DDPM discretization.""" 171 | timestep = (t * (self.N - 1) / self.T).astype(jnp.int32) 172 | beta = self.discrete_betas[timestep] 173 | alpha = self.alphas[timestep] 174 | sqrt_beta = jnp.sqrt(beta) 175 | f = batch_mul(jnp.sqrt(alpha), x) - x 176 | G = sqrt_beta 177 | return f, G 178 | 179 | def likelihood_importance_cum_weight(self, t, eps=1e-5): 180 | exponent1 = 0.5 * eps * (eps - 2) * self.beta_0 - 0.5 * eps ** 2 * self.beta_1 181 | exponent2 = 0.5 * t * (t - 2) * self.beta_0 - 0.5 * t ** 2 * self.beta_1 182 | term1 = jnp.where(jnp.abs(exponent1) <= 1e-3, -exponent1, 1. - jnp.exp(exponent1)) 183 | term2 = jnp.where(jnp.abs(exponent2) <= 1e-3, -exponent2, 1. 
- jnp.exp(exponent2)) 184 | return 0.5 * (-2 * jnp.log(term1) + 2 * jnp.log(term2) 185 | + self.beta_0 * (-2 * eps + eps ** 2 - (t - 2) * t) 186 | + self.beta_1 * (-eps ** 2 + t ** 2)) 187 | 188 | def sample_importance_weighted_time_for_likelihood(self, rng, shape, quantile=None, eps=1e-5, steps=100): 189 | Z = self.likelihood_importance_cum_weight(self.T, eps=eps) 190 | if quantile is None: 191 | quantile = jax.random.uniform(rng, shape, minval=0, maxval=Z) 192 | lb = jnp.ones_like(quantile) * eps 193 | ub = jnp.ones_like(quantile) * self.T 194 | 195 | def bisection_func(carry, idx): 196 | lb, ub = carry 197 | mid = (lb + ub) / 2. 198 | value = self.likelihood_importance_cum_weight(mid, eps=eps) 199 | lb = jnp.where(value <= quantile, mid, lb) 200 | ub = jnp.where(value <= quantile, ub, mid) 201 | return (lb, ub), idx 202 | 203 | (lb, ub), _ = jax.lax.scan(bisection_func, (lb, ub), jnp.arange(0, steps)) 204 | return (lb + ub) / 2. 205 | 206 | 207 | class subVPSDE(SDE): 208 | def __init__(self, beta_min=0.1, beta_max=20, N=1000): 209 | """Construct the sub-VP SDE that excels at likelihoods. 210 | 211 | Args: 212 | beta_min: value of beta(0) 213 | beta_max: value of beta(1) 214 | N: number of discretization steps 215 | """ 216 | super().__init__(N) 217 | self.beta_0 = beta_min 218 | self.beta_1 = beta_max 219 | self.N = N 220 | 221 | @property 222 | def T(self): 223 | return 1 224 | 225 | def sde(self, x, t, high_precision=True): 226 | beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0) 227 | drift = -0.5 * batch_mul(beta_t, x) 228 | exponent = -2 * self.beta_0 * t - (self.beta_1 - self.beta_0) * t ** 2 229 | discount = 1. 
- jnp.exp(exponent) 230 | if high_precision: 231 | discount = jnp.where(jnp.abs(exponent) <= 1e-3, -exponent, discount) 232 | diffusion = jnp.sqrt(beta_t * discount) 233 | return drift, diffusion 234 | 235 | def marginal_prob(self, x, t, high_precision=True): 236 | log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 237 | if high_precision: 238 | mean = batch_mul(jnp.where(jnp.abs(log_mean_coeff) <= 1e-3, 1. + log_mean_coeff, jnp.exp(log_mean_coeff)), x) 239 | std = jnp.where(jnp.abs(log_mean_coeff) <= 1e-3, -2. * log_mean_coeff, 1 - jnp.exp(2. * log_mean_coeff)) 240 | else: 241 | mean = batch_mul(jnp.exp(log_mean_coeff), x) 242 | std = 1 - jnp.exp(2. * log_mean_coeff) 243 | return mean, std 244 | 245 | def prior_sampling(self, rng, shape): 246 | return jax.random.normal(rng, shape) 247 | 248 | def prior_logp(self, z): 249 | shape = z.shape 250 | N = np.prod(shape[1:]) 251 | logp_fn = lambda z: -N / 2. * jnp.log(2 * np.pi) - jnp.sum(z ** 2) / 2. 252 | return jax.vmap(logp_fn)(z) 253 | 254 | def prior_entropy(self, z): 255 | shape = z.shape 256 | entropy = jnp.ones(shape) * (0.5 * jnp.log(2 * np.pi) + 0.5) 257 | entropy = entropy.reshape((z.shape[0], -1)) 258 | return jnp.sum(entropy, axis=-1) 259 | 260 | def likelihood_importance_cum_weight(self, t, eps=1e-5): 261 | exponent1 = 0.5 * eps * (eps * self.beta_1 - (eps - 2) * self.beta_0) 262 | exponent2 = 0.5 * t * (self.beta_1 * t - (t - 2) * self.beta_0) 263 | term1 = jnp.where(exponent1 <= 1e-3, jnp.log(exponent1), jnp.log(jnp.exp(exponent1) - 1.)) 264 | term2 = jnp.where(exponent2 <= 1e-3, jnp.log(exponent2), jnp.log(jnp.exp(exponent2) - 1.)) 265 | return 0.5 * (-4 * term1 + 4 * term2 266 | + (2 * eps - eps ** 2 + t * (t - 2)) * self.beta_0 + (eps ** 2 - t ** 2) * self.beta_1) 267 | 268 | def sample_importance_weighted_time_for_likelihood(self, rng, shape, quantile=None, eps=1e-5, steps=100): 269 | Z = self.likelihood_importance_cum_weight(self.T, eps=eps) 270 | if quantile is 
None: 271 | quantile = jax.random.uniform(rng, shape, minval=0, maxval=Z) 272 | lb = jnp.ones_like(quantile) * eps 273 | ub = jnp.ones_like(quantile) * self.T 274 | 275 | def bisection_func(carry, idx): 276 | lb, ub = carry 277 | mid = (lb + ub) / 2. 278 | value = self.likelihood_importance_cum_weight(mid, eps=eps) 279 | lb = jnp.where(value <= quantile, mid, lb) 280 | ub = jnp.where(value <= quantile, ub, mid) 281 | return (lb, ub), idx 282 | 283 | (lb, ub), _ = jax.lax.scan(bisection_func, (lb, ub), jnp.arange(0, steps)) 284 | return (lb + ub) / 2. 285 | 286 | 287 | class VESDE(SDE): 288 | def __init__(self, sigma_min=0.01, sigma_max=50, N=1000, linear=False): 289 | """Construct a Variance Exploding SDE. 290 | 291 | Args: 292 | sigma_min: smallest sigma. 293 | sigma_max: largest sigma. 294 | N: number of discretization steps 295 | """ 296 | super().__init__(N) 297 | self.sigma_min = sigma_min 298 | self.sigma_max = sigma_max 299 | self.linear = linear 300 | if not linear: 301 | self.discrete_sigmas = jnp.exp(np.linspace(np.log(self.sigma_min), np.log(self.sigma_max), N)) 302 | else: 303 | self.discrete_sigmas = jnp.linspace(self.sigma_min, self.sigma_max, N) 304 | self.N = N 305 | 306 | @property 307 | def T(self): 308 | return 1 309 | 310 | def sde(self, x, t): 311 | drift = jnp.zeros_like(x) 312 | if not self.linear: 313 | sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 314 | diffusion = sigma * jnp.sqrt(2 * (jnp.log(self.sigma_max) - jnp.log(self.sigma_min))) 315 | else: 316 | diffusion = self.sigma_max * jnp.sqrt(2 * t) 317 | 318 | return drift, diffusion 319 | 320 | def marginal_prob(self, x, t): 321 | mean = x 322 | if not self.linear: 323 | std = self.sigma_min * (self.sigma_max / self.sigma_min) ** t 324 | else: 325 | std = t * self.sigma_max 326 | return mean, std 327 | 328 | def prior_sampling(self, rng, shape): 329 | return jax.random.normal(rng, shape) * self.sigma_max 330 | 331 | def prior_logp(self, z): 332 | shape = z.shape 333 | N 
= np.prod(shape[1:]) 334 | logp_fn = lambda z: -N / 2. * jnp.log(2 * np.pi * self.sigma_max ** 2) - jnp.sum(z ** 2) / (2 * self.sigma_max ** 2) 335 | return jax.vmap(logp_fn)(z) 336 | 337 | def prior_entropy(self, z): 338 | shape = z.shape 339 | entropy = jnp.ones(shape) * (0.5 * jnp.log(2 * np.pi * self.sigma_max ** 2) + 0.5) 340 | entropy = entropy.reshape((z.shape[0], -1)) 341 | return jnp.sum(entropy, axis=-1) 342 | 343 | def discretize(self, x, t): 344 | """SMLD(NCSN) discretization.""" 345 | if not self.linear: 346 | timestep = (t * (self.N - 1) / self.T).astype(jnp.int32) 347 | sigma = self.discrete_sigmas[timestep] 348 | adjacent_sigma = jnp.where(timestep == 0, jnp.zeros_like(timestep), self.discrete_sigmas[timestep - 1]) 349 | f = jnp.zeros_like(x) 350 | G = jnp.sqrt(sigma ** 2 - adjacent_sigma ** 2) 351 | return f, G 352 | else: 353 | return super().discretize(x, t) 354 | -------------------------------------------------------------------------------- /models/up_or_down_sampling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Layers used for up-sampling or down-sampling images. 18 | 19 | Many functions are ported from https://github.com/NVlabs/stylegan2. 
20 | """ 21 | 22 | import flax.linen as nn 23 | from typing import Any, Tuple, Optional, Sequence 24 | import jax 25 | import jax.nn as jnn 26 | import jax.numpy as jnp 27 | import numpy as np 28 | 29 | 30 | # Function ported from StyleGAN2 31 | def get_weight(module, 32 | shape, 33 | weight_var='weight', 34 | kernel_init=None): 35 | """Get/create weight tensor for a convolution or fully-connected layer.""" 36 | 37 | return module.param(weight_var, kernel_init, shape) 38 | 39 | 40 | class Conv2d(nn.Module): 41 | """Conv2d layer with optimal upsampling and downsampling (StyleGAN2).""" 42 | fmaps: int 43 | kernel: int 44 | up: bool = False 45 | down: bool = False 46 | resample_kernel: Tuple[int] = (1, 3, 3, 1) 47 | use_bias: bool = True 48 | weight_var: str = 'weight' 49 | kernel_init: Optional[Any] = None 50 | 51 | @nn.compact 52 | def __call__(self, x): 53 | assert not (self.up and self.down) 54 | assert self.kernel >= 1 and self.kernel % 2 == 1 55 | w = get_weight(self, (self.kernel, self.kernel, x.shape[-1], self.fmaps), 56 | weight_var=self.weight_var, 57 | kernel_init=self.kernel_init) 58 | if self.up: 59 | x = upsample_conv_2d(x, w, data_format='NHWC', k=self.resample_kernel) 60 | elif self.down: 61 | x = conv_downsample_2d(x, w, data_format='NHWC', k=self.resample_kernel) 62 | else: 63 | x = jax.lax.conv_general_dilated( 64 | x, 65 | w, 66 | window_strides=(1, 1), 67 | padding='SAME', 68 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) 69 | 70 | if self.use_bias: 71 | b = self.param('bias', jnn.initializers.zeros, (x.shape[-1],)) 72 | x = x + b.reshape((1, 1, 1, -1)) 73 | return x 74 | 75 | 76 | def naive_upsample_2d(x, factor=2): 77 | _N, H, W, C = x.shape 78 | x = jnp.reshape(x, [-1, H, 1, W, 1, C]) 79 | x = jnp.tile(x, [1, 1, factor, 1, factor, 1]) 80 | return jnp.reshape(x, [-1, H * factor, W * factor, C]) 81 | 82 | 83 | def naive_downsample_2d(x, factor=2): 84 | _N, H, W, C = x.shape 85 | x = jnp.reshape(x, [-1, H // factor, factor, W // factor, factor, C]) 
86 | return jnp.mean(x, axis=[2, 4]) 87 | 88 | 89 | def upsample_conv_2d(x, w, k=None, factor=2, gain=1, data_format='NHWC'): 90 | """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. 91 | 92 | Padding is performed only once at the beginning, not between the 93 | operations. 94 | The fused op is considerably more efficient than performing the same 95 | calculation 96 | using standard TensorFlow ops. It supports gradients of arbitrary order. 97 | Args: 98 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 99 | C]`. 100 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 101 | outChannels]`. Grouped convolution can be performed by `inChannels = 102 | x.shape[0] // numGroups`. 103 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 104 | (separable). The default is `[1] * factor`, which corresponds to 105 | nearest-neighbor upsampling. 106 | factor: Integer upsampling factor (default: 2). 107 | gain: Scaling factor for signal magnitude (default: 1.0). 108 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 109 | 110 | Returns: 111 | Tensor of the shape `[N, C, H * factor, W * factor]` or 112 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 113 | """ 114 | 115 | assert isinstance(factor, int) and factor >= 1 116 | 117 | # Check weight shape. 118 | assert len(w.shape) == 4 119 | convH = w.shape[0] 120 | convW = w.shape[1] 121 | inC = w.shape[2] 122 | outC = w.shape[3] 123 | assert convW == convH 124 | 125 | # Setup filter kernel. 126 | if k is None: 127 | k = [1] * factor 128 | k = _setup_kernel(k) * (gain * (factor ** 2)) 129 | p = (k.shape[0] - factor) - (convW - 1) 130 | 131 | stride = [factor, factor] 132 | # Determine data dimensions. 133 | if data_format == 'NCHW': 134 | num_groups = _shape(x, 1) // inC 135 | else: 136 | num_groups = _shape(x, 3) // inC 137 | 138 | # Transpose weights. 
139 | w = jnp.reshape(w, [convH, convW, inC, num_groups, -1]) 140 | w = jnp.transpose(w[::-1, ::-1], [0, 1, 4, 3, 2]) 141 | w = jnp.reshape(w, [convH, convW, -1, num_groups * inC]) 142 | 143 | ## Original TF code. 144 | # x = tf.nn.conv2d_transpose( 145 | # x, 146 | # w, 147 | # output_shape=output_shape, 148 | # strides=stride, 149 | # padding='VALID', 150 | # data_format=data_format) 151 | ## JAX equivalent 152 | x = jax.lax.conv_transpose( 153 | x, 154 | w, 155 | strides=stride, 156 | padding='VALID', 157 | transpose_kernel=True, 158 | dimension_numbers=(data_format, 'HWIO', data_format)) 159 | 160 | return _simple_upfirdn_2d( 161 | x, 162 | k, 163 | pad0=(p + 1) // 2 + factor - 1, 164 | pad1=p // 2 + 1, 165 | data_format=data_format) 166 | 167 | 168 | def conv_downsample_2d(x, w, k=None, factor=2, gain=1, data_format='NHWC'): 169 | """Fused `tf.nn.conv2d()` followed by `downsample_2d()`. 170 | 171 | Padding is performed only once at the beginning, not between the operations. 172 | The fused op is considerably more efficient than performing the same 173 | calculation 174 | using standard TensorFlow ops. It supports gradients of arbitrary order. 175 | Args: 176 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 177 | C]`. 178 | w: Weight tensor of the shape `[filterH, filterW, inChannels, 179 | outChannels]`. Grouped convolution can be performed by `inChannels = 180 | x.shape[0] // numGroups`. 181 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 182 | (separable). The default is `[1] * factor`, which corresponds to 183 | average pooling. 184 | factor: Integer downsampling factor (default: 2). 185 | gain: Scaling factor for signal magnitude (default: 1.0). 186 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 187 | 188 | Returns: 189 | Tensor of the shape `[N, C, H // factor, W // factor]` or 190 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 
191 | """ 192 | 193 | assert isinstance(factor, int) and factor >= 1 194 | convH, convW, _inC, _outC = w.shape 195 | assert convW == convH 196 | if k is None: 197 | k = [1] * factor 198 | k = _setup_kernel(k) * gain 199 | p = (k.shape[0] - factor) + (convW - 1) 200 | s = [factor, factor] 201 | x = _simple_upfirdn_2d(x, k, pad0=(p + 1) // 2, 202 | pad1=p // 2, data_format=data_format) 203 | 204 | return jax.lax.conv_general_dilated( 205 | x, 206 | w, 207 | window_strides=s, 208 | padding='VALID', 209 | dimension_numbers=(data_format, 'HWIO', data_format)) 210 | 211 | 212 | def upfirdn_2d(x, k, upx, upy, downx, downy, padx0, padx1, pady0, pady1): 213 | """Pad, upsample, FIR filter, and downsample a batch of 2D images. 214 | 215 | Accepts a batch of 2D images of the shape `[majorDim, inH, inW, minorDim]` 216 | and performs the following operations for each image, batched across 217 | `majorDim` and `minorDim`: 218 | 1. Pad the image with zeros by the specified number of pixels on each side 219 | (`padx0`, `padx1`, `pady0`, `pady1`). Specifying a negative value 220 | corresponds to cropping the image. 221 | 2. Upsample the image by inserting the zeros after each pixel (`upx`, 222 | `upy`). 223 | 3. Convolve the image with the specified 2D FIR filter (`k`), shrinking the 224 | image so that the footprint of all output pixels lies within the input 225 | image. 226 | 4. Downsample the image by throwing away pixels (`downx`, `downy`). 227 | This sequence of operations bears close resemblance to 228 | scipy.signal.upfirdn(). 229 | The fused op is considerably more efficient than performing the same 230 | calculation 231 | using standard TensorFlow ops. It supports gradients of arbitrary order. 232 | Args: 233 | x: Input tensor of the shape `[majorDim, inH, inW, minorDim]`. 234 | k: 2D FIR filter of the shape `[firH, firW]`. 235 | upx: Integer upsampling factor along the X-axis (default: 1). 236 | upy: Integer upsampling factor along the Y-axis (default: 1). 
237 | downx: Integer downsampling factor along the X-axis (default: 1). 238 | downy: Integer downsampling factor along the Y-axis (default: 1). 239 | padx0: Number of pixels to pad on the left side (default: 0). 240 | padx1: Number of pixels to pad on the right side (default: 0). 241 | pady0: Number of pixels to pad on the top side (default: 0). 242 | pady1: Number of pixels to pad on the bottom side (default: 0). 243 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` 244 | (default). 245 | 246 | Returns: 247 | Tensor of the shape `[majorDim, outH, outW, minorDim]`, and same 248 | datatype as `x`. 249 | """ 250 | k = jnp.asarray(k, dtype=np.float32) 251 | assert len(x.shape) == 4 252 | inH = x.shape[1] 253 | inW = x.shape[2] 254 | minorDim = x.shape[3] 255 | kernelH, kernelW = k.shape 256 | assert inW >= 1 and inH >= 1 257 | assert kernelW >= 1 and kernelH >= 1 258 | assert isinstance(upx, int) and isinstance(upy, int) 259 | assert isinstance(downx, int) and isinstance(downy, int) 260 | assert isinstance(padx0, int) and isinstance(padx1, int) 261 | assert isinstance(pady0, int) and isinstance(pady1, int) 262 | 263 | # Upsample (insert zeros). 264 | x = jnp.reshape(x, (-1, inH, 1, inW, 1, minorDim)) 265 | x = jnp.pad(x, [[0, 0], [0, 0], [0, upy - 1], [0, 0], [0, upx - 1], [0, 0]]) 266 | x = jnp.reshape(x, [-1, inH * upy, inW * upx, minorDim]) 267 | 268 | # Pad (crop if negative). 269 | x = jnp.pad(x, [[0, 0], [max(pady0, 0), max(pady1, 0)], 270 | [max(padx0, 0), max(padx1, 0)], [0, 0]]) 271 | x = x[:, 272 | max(-pady0, 0):x.shape[1] - max(-pady1, 0), 273 | max(-padx0, 0):x.shape[2] - max(-padx1, 0), :] 274 | 275 | # Convolve with filter. 
276 | x = jnp.transpose(x, [0, 3, 1, 2]) 277 | x = jnp.reshape(x, 278 | [-1, 1, inH * upy + pady0 + pady1, inW * upx + padx0 + padx1]) 279 | w = jnp.array(k[::-1, ::-1, None, None], dtype=x.dtype) 280 | x = jax.lax.conv_general_dilated( 281 | x, 282 | w, 283 | window_strides=(1, 1), 284 | padding='VALID', 285 | dimension_numbers=('NCHW', 'HWIO', 'NCHW')) 286 | 287 | x = jnp.reshape(x, [ 288 | -1, minorDim, inH * upy + pady0 + pady1 - kernelH + 1, 289 | inW * upx + padx0 + padx1 - kernelW + 1 290 | ]) 291 | x = jnp.transpose(x, [0, 2, 3, 1]) 292 | 293 | # Downsample (throw away pixels). 294 | return x[:, ::downy, ::downx, :] 295 | 296 | 297 | def _simple_upfirdn_2d(x, k, up=1, down=1, pad0=0, pad1=0, data_format='NCHW'): 298 | assert data_format in ['NCHW', 'NHWC'] 299 | assert len(x.shape) == 4 300 | y = x 301 | if data_format == 'NCHW': 302 | y = jnp.reshape(y, [-1, y.shape[2], y.shape[3], 1]) 303 | y = upfirdn_2d( 304 | y, 305 | k, 306 | upx=up, 307 | upy=up, 308 | downx=down, 309 | downy=down, 310 | padx0=pad0, 311 | padx1=pad1, 312 | pady0=pad0, 313 | pady1=pad1) 314 | if data_format == 'NCHW': 315 | y = jnp.reshape(y, [-1, x.shape[1], y.shape[1], y.shape[2]]) 316 | return y 317 | 318 | 319 | def _setup_kernel(k): 320 | k = np.asarray(k, dtype=np.float32) 321 | if k.ndim == 1: 322 | k = np.outer(k, k) 323 | k /= np.sum(k) 324 | assert k.ndim == 2 325 | assert k.shape[0] == k.shape[1] 326 | return k 327 | 328 | 329 | def _shape(x, dim): 330 | return x.shape[dim] 331 | 332 | 333 | def upsample_2d(x, k=None, factor=2, gain=1, data_format='NHWC'): 334 | r"""Upsample a batch of 2D images with the given filter. 335 | 336 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 337 | and upsamples each image with the given filter. The filter is normalized so 338 | that 339 | if the input pixels are constant, they will be scaled by the specified 340 | `gain`. 
341 | Pixels outside the image are assumed to be zero, and the filter is padded 342 | with 343 | zeros so that its shape is a multiple of the upsampling factor. 344 | Args: 345 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 346 | C]`. 347 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 348 | (separable). The default is `[1] * factor`, which corresponds to 349 | nearest-neighbor upsampling. 350 | factor: Integer upsampling factor (default: 2). 351 | gain: Scaling factor for signal magnitude (default: 1.0). 352 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 353 | 354 | Returns: 355 | Tensor of the shape `[N, C, H * factor, W * factor]` or 356 | `[N, H * factor, W * factor, C]`, and same datatype as `x`. 357 | """ 358 | assert isinstance(factor, int) and factor >= 1 359 | if k is None: 360 | k = [1] * factor 361 | k = _setup_kernel(k) * (gain * (factor ** 2)) 362 | p = k.shape[0] - factor 363 | return _simple_upfirdn_2d( 364 | x, 365 | k, 366 | up=factor, 367 | pad0=(p + 1) // 2 + factor - 1, 368 | pad1=p // 2, 369 | data_format=data_format) 370 | 371 | 372 | def downsample_2d(x, k=None, factor=2, gain=1, data_format='NHWC'): 373 | r"""Downsample a batch of 2D images with the given filter. 374 | 375 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` 376 | and downsamples each image with the given filter. The filter is normalized 377 | so that 378 | if the input pixels are constant, they will be scaled by the specified 379 | `gain`. 380 | Pixels outside the image are assumed to be zero, and the filter is padded 381 | with 382 | zeros so that its shape is a multiple of the downsampling factor. 383 | Args: 384 | x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, 385 | C]`. 386 | k: FIR filter of the shape `[firH, firW]` or `[firN]` 387 | (separable). The default is `[1] * factor`, which corresponds to 388 | average pooling. 389 | factor: Integer downsampling factor (default: 2). 
390 | gain: Scaling factor for signal magnitude (default: 1.0). 391 | data_format: `'NCHW'` or `'NHWC'` (default: `'NCHW'`). 392 | impl: Name of the implementation to use. Can be `"ref"` or 393 | `"cuda"` (default). 394 | 395 | Returns: 396 | Tensor of the shape `[N, C, H // factor, W // factor]` or 397 | `[N, H // factor, W // factor, C]`, and same datatype as `x`. 398 | """ 399 | 400 | assert isinstance(factor, int) and factor >= 1 401 | if k is None: 402 | k = [1] * factor 403 | k = _setup_kernel(k) * gain 404 | p = k.shape[0] - factor 405 | return _simple_upfirdn_2d( 406 | x, 407 | k, 408 | down=factor, 409 | pad0=(p + 1) // 2, 410 | pad1=p // 2, 411 | data_format=data_format) 412 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """All functions related to loss computation and optimization. 
17 | """ 18 | 19 | import flax 20 | import jax 21 | import jax.numpy as jnp 22 | import jax.random as random 23 | from models import utils as mutils 24 | from sde_lib import VESDE, VPSDE 25 | from utils import batch_mul, get_div_fn, get_value_div_fn 26 | import functools 27 | import numpy as np 28 | from bound_likelihood import get_likelihood_offset_fn 29 | 30 | 31 | def get_optimizer(config, beta2=0.999): 32 | """Returns a flax optimizer object based on `config`.""" 33 | if config.optim.optimizer == 'Adam': 34 | optimizer = flax.optim.Adam(beta1=config.optim.beta1, beta2=beta2, eps=config.optim.eps, 35 | weight_decay=config.optim.weight_decay) 36 | else: 37 | raise NotImplementedError( 38 | f'Optimizer {config.optim.optimizer} not supported yet!') 39 | 40 | return optimizer 41 | 42 | 43 | def optimization_manager(config, deq_score_joint=False): 44 | """Returns an optimize_fn based on `config`.""" 45 | 46 | def optimize_fn(state, 47 | grad, 48 | warmup=config.optim.warmup, 49 | grad_clip=config.optim.grad_clip): 50 | """Optimizes with warmup and gradient clipping (disabled if negative).""" 51 | lr = state.lr 52 | if warmup > 0: 53 | lr = lr * jnp.minimum(state.step / warmup, 1.0) 54 | if grad_clip >= 0: 55 | # Compute global gradient norm 56 | grad_norm = jnp.sqrt( 57 | sum([jnp.sum(jnp.square(x)) for x in jax.tree_leaves(grad)])) 58 | # Clip gradient 59 | clipped_grad = jax.tree_map( 60 | lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad) 61 | else: # disabling gradient clipping if grad_clip < 0 62 | clipped_grad = grad 63 | return state.optimizer.apply_gradient(clipped_grad, learning_rate=lr) 64 | 65 | def optimize_deq_score_fn(state, 66 | grad, 67 | warmup=config.optim.warmup, 68 | grad_clip=config.optim.grad_clip): 69 | lr = state.lr 70 | if warmup > 0: 71 | lr = lr * jnp.minimum(state.step / warmup, 1.0) 72 | if grad_clip >= 0: 73 | # Compute global gradient norm 74 | grad_norm = jnp.sqrt( 75 | sum([jnp.sum(jnp.square(x)) for x in 
jax.tree_leaves(grad)])) 76 | # Clip gradient 77 | clipped_grad = jax.tree_map( 78 | lambda x: x * grad_clip / jnp.maximum(grad_norm, grad_clip), grad) 79 | else: # disabling gradient clipping if grad_clip < 0 80 | clipped_grad = grad 81 | return (state.deq_optimizer.apply_gradient(clipped_grad['deq'], learning_rate=lr), 82 | state.score_optimizer.apply_gradient(clipped_grad['score'], learning_rate=lr)) 83 | 84 | if deq_score_joint: 85 | return optimize_deq_score_fn 86 | else: 87 | return optimize_fn 88 | 89 | 90 | def get_score_t(score_fn): 91 | def score_t(x, t, rng): 92 | tangent = jnp.ones_like(t) 93 | return jax.jvp(lambda time: score_fn(x, time, rng=rng), (t,), (tangent,)) 94 | 95 | return score_t 96 | 97 | 98 | def get_sde_loss_fn(sde, model, train, reduce_mean=True, continuous=True, likelihood_weighting=True, 99 | importance_weighting=True, eps=1e-5): 100 | """Create a loss function for training with arbirary SDEs. 101 | 102 | Args: 103 | sde: An `sde_lib.SDE` object that represents the forward SDE. 104 | model: A `flax.linen.Module` object that represents the architecture of the score-based model. 105 | train: `True` for training loss and `False` for evaluation loss. 106 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. 107 | continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires 108 | ad-hoc interpolation to take continuous time steps. 109 | likelihood_weighting: If `True`, weight the mixture of score matching losses 110 | according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended in our paper. 111 | importance_weighting: If `True`, use importance weighting to reduce the variance of likelihood weighting. 112 | eps: A `float` number. The smallest time step to sample from. 113 | 114 | Returns: 115 | A loss function. 
116 | """ 117 | reduce_op = jnp.mean if reduce_mean else lambda *args, **kwargs: 0.5 * jnp.sum(*args, **kwargs) 118 | 119 | def loss_fn(rng, params, states, batch): 120 | """Compute the loss function. 121 | 122 | Args: 123 | rng: A JAX random state. 124 | params: A dictionary that contains trainable parameters of the score-based model. 125 | states: A dictionary that contains mutable states of the score-based model. 126 | batch: A mini-batch of training data. 127 | 128 | Returns: 129 | loss: A scalar that represents the average loss value across the mini-batch. 130 | new_model_state: A dictionary that contains the mutated states of the score-based model. 131 | """ 132 | 133 | score_fn = mutils.get_score_fn(sde, model, params, states, train=train, continuous=continuous, return_state=True) 134 | data = batch['image'] 135 | 136 | rng, step_rng = random.split(rng) 137 | if likelihood_weighting and importance_weighting: 138 | t = sde.sample_importance_weighted_time_for_likelihood(step_rng, (data.shape[0],), eps=eps) 139 | else: 140 | t = random.uniform(step_rng, (data.shape[0],), minval=eps, maxval=sde.T) 141 | 142 | rng, step_rng = random.split(rng) 143 | z = random.normal(step_rng, data.shape) 144 | mean, std = sde.marginal_prob(data, t) 145 | perturbed_data = mean + batch_mul(std, z) 146 | rng, step_rng = random.split(rng) 147 | score, new_model_state = score_fn(perturbed_data, t, rng=step_rng) 148 | 149 | if likelihood_weighting: 150 | if importance_weighting: 151 | losses = jnp.square(batch_mul(score, std) + z) 152 | losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) 153 | else: 154 | g2 = sde.sde(jnp.zeros_like(data), t)[1] ** 2 155 | losses = jnp.square(score + batch_mul(z, 1. 
/ std)) 156 | losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) * g2 157 | 158 | else: 159 | losses = jnp.square(batch_mul(score, std) + z) 160 | losses = reduce_op(losses.reshape((losses.shape[0], -1)), axis=-1) 161 | 162 | loss = jnp.mean(losses) 163 | return loss, new_model_state 164 | 165 | return loss_fn 166 | 167 | 168 | def get_step_fn(sde, model, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False, 169 | importance_weighting=False, smallest_time=1e-5): 170 | """Create a one-step training/evaluation function. 171 | 172 | Args: 173 | sde: An `sde_lib.SDE` object that represents the forward SDE. 174 | model: A `flax.linen.Module` object that represents the architecture of the score-based model. 175 | train: `True` for training and `False` for evaluation. 176 | optimize_fn: An optimization function. 177 | reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions. 178 | continuous: `True` indicates that the model is defined to take continuous time steps. 179 | likelihood_weighting: If `True`, weight the mixture of score matching losses according to 180 | https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper. 181 | 182 | Returns: 183 | A one-step function for training or evaluation. 184 | """ 185 | if continuous: 186 | loss_fn = get_sde_loss_fn(sde, model, train, reduce_mean=reduce_mean, 187 | continuous=True, likelihood_weighting=likelihood_weighting, 188 | importance_weighting=importance_weighting, 189 | eps=smallest_time) 190 | else: 191 | assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training." 
192 | if isinstance(sde, VESDE): 193 | loss_fn = get_smld_loss_fn(sde, model, train, reduce_mean=reduce_mean) 194 | elif isinstance(sde, VPSDE): 195 | loss_fn = get_ddpm_loss_fn(sde, model, train, reduce_mean=reduce_mean) 196 | else: 197 | raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") 198 | 199 | def step_fn(carry_state, batch): 200 | """Running one step of training or evaluation. 201 | 202 | This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together 203 | for faster execution. 204 | 205 | Args: 206 | carry_state: A tuple (JAX random state, `flax.struct.dataclass` containing the training state). 207 | batch: A mini-batch of training/evaluation data. 208 | 209 | Returns: 210 | new_carry_state: The updated tuple of `carry_state`. 211 | loss: The average loss value of this state. 212 | """ 213 | 214 | (rng, state) = carry_state 215 | rng, step_rng = jax.random.split(rng) 216 | grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True) 217 | if train: 218 | params = state.optimizer.target 219 | states = state.model_state 220 | (loss, new_model_state), grad = grad_fn(step_rng, params, states, batch) 221 | 222 | grad = jax.lax.pmean(grad, axis_name='batch') 223 | new_optimizer = optimize_fn(state, grad) 224 | new_params_ema = jax.tree_multimap( 225 | lambda p_ema, p: p_ema * state.ema_rate + p * (1. 
- state.ema_rate), 226 | state.params_ema, new_optimizer.target 227 | ) 228 | step = state.step + 1 229 | new_state = state.replace( 230 | step=step, 231 | optimizer=new_optimizer, 232 | model_state=new_model_state, 233 | params_ema=new_params_ema 234 | ) 235 | else: 236 | loss, _ = loss_fn(step_rng, state.params_ema, state.model_state, batch) 237 | new_state = state 238 | 239 | loss = jax.lax.pmean(loss, axis_name='batch') 240 | new_carry_state = (rng, new_state) 241 | return new_carry_state, loss 242 | 243 | return step_fn 244 | 245 | 246 | def get_dequantization_loss_fn(sde, score_fn, deq_model, scaler, inverse_scaler, 247 | train=True, importance_weighting=True, eps=1e-5, 248 | eps_offset=True): 249 | def div_drift_fn(x, t, eps): 250 | div_fn = get_div_fn(lambda x, t: sde.sde(x, t)[0]) 251 | return div_fn(x, t, eps) 252 | 253 | def loss_fn(rng, params, batch): 254 | dequantizer = mutils.get_dequantizer(deq_model, params, train=train) 255 | 256 | data = batch['image'] 257 | shape = data.shape 258 | rng, step_rng = random.split(rng) 259 | u = random.normal(step_rng, shape) 260 | if train: 261 | rng, step_rng = random.split(rng) 262 | deq_noise, sldj = dequantizer(u, inverse_scaler(data), rng=step_rng) 263 | else: 264 | deq_noise, sldj = dequantizer(u, inverse_scaler(data)) 265 | 266 | data = scaler((inverse_scaler(data) * 255. + deq_noise) / 256.) 
267 | 268 | mean, std = sde.marginal_prob(data, jnp.ones((shape[0],)) * sde.T) 269 | rng, step_rng = jax.random.split(rng) 270 | z = jax.random.normal(step_rng, shape) 271 | neg_prior_logp = -sde.prior_logp(mean + batch_mul(std, z)) 272 | 273 | rng, step_rng = random.split(rng) 274 | if importance_weighting: 275 | t = sde.sample_importance_weighted_time_for_likelihood(step_rng, (shape[0],), eps=eps) 276 | Z = sde.likelihood_importance_cum_weight(sde.T, eps=eps) 277 | else: 278 | t = random.uniform(step_rng, (shape[0],), minval=eps, maxval=sde.T) 279 | 280 | rng, step_rng = random.split(rng) 281 | z = random.normal(step_rng, shape) 282 | mean, std = sde.marginal_prob(data, t) 283 | perturbed_data = mean + batch_mul(std, z) 284 | 285 | score = score_fn(perturbed_data, t) 286 | if importance_weighting: 287 | losses = jnp.square(batch_mul(score, std) + z) 288 | losses = jnp.sum(losses.reshape((losses.shape[0], -1)), axis=-1) 289 | grad_norm = jnp.square(z).reshape((z.shape[0], -1)).sum(axis=-1) 290 | losses = (losses - grad_norm) * Z 291 | else: 292 | g2 = sde.sde(jnp.zeros_like(data), t)[1] ** 2 293 | losses = jnp.square(score + batch_mul(z, 1. 
/ std)) 294 | losses = jnp.sum(losses.reshape((losses.shape[0], -1)), axis=-1) * g2 295 | grad_norm = jnp.square(z).reshape((z.shape[0], -1)).sum(axis=-1) 296 | grad_norm = grad_norm * g2 / (std ** 2) 297 | losses = losses - grad_norm 298 | 299 | rng, step_rng = random.split(rng) 300 | z = random.normal(step_rng, shape) 301 | rng, step_rng = random.split(rng) 302 | t = random.uniform(step_rng, (shape[0],), minval=eps, maxval=sde.T) 303 | mean, std = sde.marginal_prob(data, t) 304 | noisy_data = mean + batch_mul(std, z) 305 | rng, step_rng = random.split(rng) 306 | epsilon = random.rademacher(step_rng, shape, dtype=jnp.float32) 307 | drift_div = div_drift_fn(noisy_data, t, epsilon) 308 | 309 | losses = neg_prior_logp + 0.5 * (losses - 2 * drift_div) - sldj 310 | if eps_offset: 311 | offset_fn = get_likelihood_offset_fn(sde, score_fn, eps) 312 | rng, step_rng = random.split(rng) 313 | losses = losses + offset_fn(step_rng, data) 314 | 315 | dim = np.prod(shape[1:]) 316 | bpd = losses / np.log(2.) 317 | bpd = bpd / dim 318 | offset = jnp.log2(jax.grad(inverse_scaler)(0.)) + 8. 319 | bpd += offset 320 | bpd = bpd.mean() 321 | 322 | loss = jnp.mean(losses) 323 | 324 | return loss, bpd 325 | 326 | return loss_fn 327 | 328 | 329 | def get_dequantizer_step_fn(sde, score_fn, deq_model, scaler, inverse_scaler, 330 | train, optimize_fn=None, importance_weighting=False, smallest_time=1e-5, 331 | eps_offset=True): 332 | loss_fn = get_dequantization_loss_fn(sde, score_fn, deq_model, scaler, inverse_scaler, train, 333 | importance_weighting=importance_weighting, eps=smallest_time, 334 | eps_offset=eps_offset) 335 | 336 | def step_fn(carry_state, batch): 337 | """Running one step of training or evaluation. 338 | 339 | This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together 340 | for faster execution. 341 | 342 | Args: 343 | carry_state: A tuple (JAX random state, `flax.struct.dataclass` containing the training state). 
344 | batch: A mini-batch of training/evaluation data. 345 | 346 | Returns: 347 | new_carry_state: The updated tuple of `carry_state`. 348 | loss: The average loss value of this state. 349 | """ 350 | 351 | (rng, state) = carry_state 352 | rng, step_rng = jax.random.split(rng) 353 | grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True) 354 | if train: 355 | params = state.optimizer.target 356 | (loss, bpd), grad = grad_fn(step_rng, params, batch) 357 | 358 | grad = jax.lax.pmean(grad, axis_name='batch') 359 | new_optimizer = optimize_fn(state, grad) 360 | new_params_ema = jax.tree_multimap( 361 | lambda p_ema, p: p_ema * state.ema_rate + p * (1. - state.ema_rate), 362 | state.params_ema, new_optimizer.target 363 | ) 364 | step = state.step + 1 365 | new_state = state.replace( 366 | step=step, 367 | optimizer=new_optimizer, 368 | params_ema=new_params_ema 369 | ) 370 | else: 371 | loss, bpd = loss_fn(step_rng, state.params_ema, batch) 372 | new_state = state 373 | 374 | loss = jax.lax.pmean(loss, axis_name='batch') 375 | new_carry_state = (rng, new_state) 376 | return new_carry_state, (loss, bpd) 377 | 378 | return step_fn 379 | -------------------------------------------------------------------------------- /models/flowpp/modules_cifar10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported from https://github.com/aravindsrinivas/flowpp/blob/737fadb2218c1e2810a91b523498f97def2c30de/flows/flows.py 3 | """ 4 | 5 | import flax 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import flax.linen as nn 10 | from typing import Any 11 | from .logistic import * 12 | from jax.experimental import host_callback 13 | 14 | 15 | def safe_log(x): 16 | return jnp.log(jnp.clip(x, a_min=1e-7)) 17 | 18 | 19 | class CheckerboardSplit(nn.Module): 20 | inverse_module: bool = False 21 | 22 | @nn.compact 23 | def __call__(self, x, inverse=False): 24 | if self.inverse_module: 25 | inverse = not inverse 26 | 27 | 
if not inverse: 28 | B, H, W, C = x.shape 29 | x = jnp.reshape(x, [B, H, W // 2, 2, C]) 30 | a = x[:, :, :, 0, :] 31 | b = x[:, :, :, 1, :] 32 | assert a.shape == b.shape == (B, H, W // 2, C) 33 | return (a, b), None 34 | else: 35 | a, b = x 36 | assert a.shape == b.shape 37 | B, H, W_half, C = a.shape 38 | y = jnp.stack([a, b], axis=3) 39 | assert y.shape == (B, H, W_half, 2, C) 40 | return jnp.reshape(y, [B, H, W_half * 2, C]), None 41 | 42 | 43 | def init_normalization(self, x, *, name, init_scale=1.): 44 | def g_initializer(rng, x): 45 | # data based normalization 46 | # v_init = jnp.var(x, axis=0) 47 | m_init = jax.lax.pmean(jnp.mean(x, axis=0), axis_name='batch') 48 | m2_init = jax.lax.pmean(jnp.mean(x ** 2, axis=0), axis_name='batch') 49 | v_init = m2_init - m_init ** 2 50 | return init_scale * jax.lax.rsqrt(v_init + 1e-6) 51 | 52 | def b_initializer(rng, x): 53 | # data based normalization 54 | # m_init = jnp.mean(x, axis=0) 55 | # v_init = jnp.var(x, axis=0) 56 | 57 | m_init = jax.lax.pmean(jnp.mean(x, axis=0), axis_name='batch') 58 | m2_init = jax.lax.pmean(jnp.mean(x ** 2, axis=0), axis_name='batch') 59 | v_init = m2_init - m_init ** 2 60 | 61 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 62 | assert m_init.shape == v_init.shape == scale_init.shape 63 | 64 | return -m_init * scale_init 65 | 66 | g = self.param(f'{name}_g', g_initializer, x) 67 | b = self.param(f'{name}_b', b_initializer, x) 68 | 69 | return g, b 70 | 71 | 72 | class Norm(nn.Module): 73 | init_scale: float = 1. 
74 | 75 | @nn.compact 76 | def __call__(self, inputs, inverse=False): 77 | assert not isinstance(inputs, list) 78 | if isinstance(inputs, tuple): 79 | is_tuple = True 80 | else: 81 | inputs = [inputs] 82 | is_tuple = False 83 | 84 | bs = int(inputs[0].shape[0]) 85 | g_and_b = [] 86 | for (i, x) in enumerate(inputs): 87 | g, b = init_normalization(self, x, name='norm{}'.format(i), init_scale=self.init_scale) 88 | g = jnp.maximum(g, 1e-10) 89 | assert x.shape[0] == bs and g.shape == b.shape == x.shape[1:] 90 | g_and_b.append((g, b)) 91 | 92 | logd = jnp.full([bs], sum([jnp.sum(safe_log(g)) for (g, _) in g_and_b])) 93 | if not inverse: 94 | out = [x * g[None] + b[None] for (x, (g, b)) in zip(inputs, g_and_b)] 95 | else: 96 | out = [(x - b[None]) / g[None] for (x, (g, b)) in zip(inputs, g_and_b)] 97 | logd = -logd 98 | 99 | if not is_tuple: 100 | assert len(out) == 1 101 | return out[0], logd 102 | return tuple(out), logd 103 | 104 | 105 | class Pointwise(nn.Module): 106 | noisy_identity_init: float = 0.001 107 | 108 | @nn.compact 109 | def __call__(self, inputs, noisy_identity_init=0.001, inverse=False): 110 | assert not isinstance(inputs, list) 111 | if isinstance(inputs, tuple): 112 | is_tuple = True 113 | else: 114 | inputs = [inputs] 115 | is_tuple = False 116 | 117 | out, logds = [], [] 118 | for i, x in enumerate(inputs): 119 | if self.noisy_identity_init: 120 | # identity + gaussian noise 121 | def initializer(key, x): 122 | _, img_h, img_w, img_c = x.shape 123 | return jnp.eye(img_c) + self.noisy_identity_init * jax.random.normal(key, (img_c, img_c)) 124 | else: 125 | # random orthogonal 126 | def initializer(key, x): 127 | _, img_h, img_w, img_c = x.shape 128 | return jnp.linalg.qr(jax.random.normal(key, (img_c, img_c)))[0] 129 | 130 | W = self.param('W{}'.format(i), initializer, x) 131 | out.append(self._nin(x, W if not inverse else jnp.linalg.inv(W))) 132 | _, img_h, img_w, img_c = x.shape 133 | logds.append((1 if not inverse else -1) * img_h * img_w * 
jnp.linalg.slogdet(W)[1]) 134 | logd = jnp.full([inputs[0].shape[0]], sum(logds)) 135 | 136 | if not is_tuple: 137 | assert len(out) == 1 138 | return out[0], logd 139 | return tuple(out), logd 140 | 141 | @staticmethod 142 | def _nin(x, w, b=None): 143 | _, out_dim = w.shape 144 | s = x.shape 145 | x = jnp.reshape(x, [np.prod(s[:-1]), s[-1]]) 146 | x = x @ w 147 | if b is not None: 148 | assert len(b.shape) == 1 149 | x = x + b[None, :] 150 | return jnp.reshape(x, s[:-1] + (out_dim,)) 151 | 152 | 153 | def conv2d(self, x, *, name, num_units, filter_size=(3, 3), stride=(1, 1), pad='SAME', init_scale=1.): 154 | assert len(x.shape) == 4 155 | 156 | def W_initializer(rng, x): 157 | W = jax.random.normal(rng, [*filter_size, int(x.shape[-1]), num_units]) * 0.05 158 | y = jax.lax.conv_general_dilated(x, W, window_strides=stride, padding=pad, 159 | dimension_numbers=('NHWC', 'HWIO', 'NWHC')) 160 | # v_init = jnp.var(y, axis=(0, 1, 2)) 161 | m_init = jax.lax.pmean(jnp.mean(y, axis=(0, 1, 2)), axis_name='batch') 162 | m2_init = jax.lax.pmean(jnp.mean(y ** 2, axis=(0, 1, 2)), axis_name='batch') 163 | v_init = m2_init - m_init ** 2 164 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 165 | 166 | return W * scale_init[None, None, None, :] 167 | 168 | def b_initializer(rng, x): 169 | W = jax.random.normal(rng, [*filter_size, int(x.shape[-1]), num_units]) * 0.05 170 | y = jax.lax.conv_general_dilated(x, W, window_strides=stride, padding=pad, 171 | dimension_numbers=('NHWC', 'HWIO', 'NWHC')) 172 | # m_init = jnp.mean(y, axis=(0, 1, 2)) 173 | # v_init = jnp.var(y, axis=(0, 1, 2)) 174 | m_init = jax.lax.pmean(jnp.mean(y, axis=(0, 1, 2)), axis_name='batch') 175 | m2_init = jax.lax.pmean(jnp.mean(y ** 2, axis=(0, 1, 2)), axis_name='batch') 176 | v_init = m2_init - m_init ** 2 177 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 178 | 179 | return -m_init * scale_init 180 | 181 | W = self.param(f'{name}_W', W_initializer, x) 182 | b = self.param(f'{name}_b', 
b_initializer, x) 183 | 184 | return jax.lax.conv_general_dilated(x, W, window_strides=stride, padding=pad, 185 | dimension_numbers=('NWHC', 'HWIO', 'NWHC')) + b[None, None, None, :] 186 | 187 | 188 | def concat_elu(x): 189 | axis = len(x.shape) - 1 190 | return jax.nn.elu(jnp.concatenate([x, -x], axis)) 191 | 192 | 193 | def dense(self, x, *, name, num_units, init_scale=1.): 194 | _, in_dim = x.shape 195 | 196 | def W_initializer(rng, x): 197 | W = jax.random.normal(rng, [in_dim, num_units]) * 0.05 198 | y = x @ W 199 | # v_init = jnp.var(y, axis=0) 200 | m_init = jax.lax.pmean(jnp.mean(y, axis=0), axis_name='batch') 201 | m2_init = jax.lax.pmean(jnp.mean(y ** 2, axis=0), axis_name='batch') 202 | v_init = m2_init - m_init ** 2 203 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 204 | return W * scale_init[None, :] 205 | 206 | def b_initializer(rng, x): 207 | W = jax.random.normal(rng, [in_dim, num_units]) * 0.05 208 | y = x @ W 209 | # m_init = jnp.mean(y, axis=0) 210 | # v_init = jnp.var(y, axis=0) 211 | m_init = jax.lax.pmean(jnp.mean(y, axis=0), axis_name='batch') 212 | m2_init = jax.lax.pmean(jnp.mean(y ** 2, axis=0), axis_name='batch') 213 | v_init = m2_init - m_init ** 2 214 | scale_init = init_scale * jax.lax.rsqrt(v_init + 1e-6) 215 | return -m_init * scale_init 216 | 217 | W = self.param(f'{name}_W', W_initializer, x) 218 | b = self.param(f'{name}_b', b_initializer, x) 219 | 220 | return x @ W + b[None, :] 221 | 222 | 223 | def nin(self, x, *, num_units, **kwargs): 224 | assert 'num_units' not in kwargs 225 | s = x.shape 226 | x = jnp.reshape(x, [np.prod(s[:-1]), s[-1]]) 227 | x = dense(self, x, num_units=num_units, **kwargs) 228 | return jnp.reshape(x, s[:-1] + (num_units,)) 229 | 230 | 231 | def gate(x, *, axis): 232 | a, b = jnp.split(x, 2, axis=axis) 233 | return a * jax.nn.sigmoid(b) 234 | 235 | 236 | def gated_conv(self, x, *, name, a, nonlinearity=concat_elu, conv=conv2d, use_nin, dropout_p, train=False): 237 | num_filters = 
int(x.shape[-1]) 238 | 239 | c1 = conv(self, nonlinearity(x), name=f'{name}_c1', num_units=num_filters) 240 | if a is not None: # add short-cut connection if auxiliary input 'a' is given 241 | c1 += nin(self, nonlinearity(a), name=f'{name}_a_proj', num_units=num_filters) 242 | c1 = nonlinearity(c1) 243 | if dropout_p > 0: 244 | c1 = nn.Dropout(rate=dropout_p, deterministic=not train)(c1) 245 | 246 | c2 = (nin if use_nin else conv)(self, c1, name=f'{name}_c2', num_units=num_filters * 2, init_scale=0.1) 247 | return x + gate(c2, axis=3) 248 | 249 | 250 | def layernorm(self, x, *, name, e=1e-5): 251 | """Layer norm over last axis""" 252 | shape = [1] * (len(x.shape) - 1) + [int(x.shape[-1])] 253 | g = self.param(f'{name}_g', jax.nn.initializers.ones, shape) 254 | b = self.param(f'{name}_b', jax.nn.initializers.zeros, shape) 255 | u = jnp.mean(x, axis=-1, keepdims=True) 256 | s = jnp.mean(jnp.square(x - u), axis=-1, keepdims=True) 257 | return (x - u) * jax.lax.rsqrt(s + e) * g + b 258 | 259 | 260 | def gated_attn(self, x, *, name, pos_emb, heads, dropout_p, train=False): 261 | bs, height, width, ch = x.shape 262 | assert pos_emb.shape == (height, width, ch) 263 | assert ch % heads == 0 264 | timesteps = height * width 265 | dim = ch // heads 266 | # Position embeddings 267 | c = x + pos_emb[None, :, :, :] 268 | # b, h, t, d == batch, num heads, num timesteps, per-head dim (C // heads) 269 | c = nin(self, c, name=f'{name}_proj1', num_units=3 * ch) 270 | assert c.shape == (bs, height, width, 3 * ch) 271 | # Split into heads / Q / K / V 272 | c = jnp.reshape(c, [bs, timesteps, 3, heads, dim]) # b, t, 3, h, d 273 | c = jnp.transpose(c, [2, 0, 3, 1, 4]) # 3, b, h, t, d 274 | q_bhtd, k_bhtd, v_bhtd = c[0, ...], c[1, ...], c[2, ...] 
275 | assert q_bhtd.shape == k_bhtd.shape == v_bhtd.shape == (bs, heads, timesteps, dim) 276 | # Attention 277 | w_bhtt = jnp.einsum('bhTD,bhtD->bhTt', q_bhtd, k_bhtd) / np.sqrt(float(dim)) 278 | w_bhtt = jax.nn.softmax(w_bhtt) 279 | assert w_bhtt.shape == (bs, heads, timesteps, timesteps) 280 | a_bhtd = jnp.einsum('bhTt,bhtd->bhTd', w_bhtt, v_bhtd) 281 | # Merge heads 282 | a_bthd = jnp.transpose(a_bhtd, [0, 2, 1, 3]) 283 | assert a_bthd.shape == (bs, timesteps, heads, dim) 284 | a_btc = jnp.reshape(a_bthd, [bs, timesteps, ch]) 285 | # Project 286 | c1 = jnp.reshape(a_btc, [bs, height, width, ch]) 287 | if dropout_p > 0: 288 | c1 = nn.Dropout(rate=dropout_p, deterministic=not train)(c1) 289 | c2 = nin(self, c1, name=f'{name}_proj2', num_units=ch * 2, init_scale=0.1) 290 | return x + gate(c2, axis=3) 291 | 292 | 293 | def sumflat(x): 294 | return jnp.sum(jnp.reshape(x, [x.shape[0], -1]), axis=1) 295 | 296 | 297 | def inverse_sigmoid(x): 298 | return -safe_log(jax.lax.reciprocal(x) - 1.) 299 | 300 | 301 | class Sigmoid(nn.Module): 302 | inverse_module: bool = False 303 | 304 | @nn.compact 305 | def __call__(self, x, inverse=False): 306 | if self.inverse_module: 307 | inverse = not inverse 308 | if not inverse: 309 | y = jax.nn.sigmoid(x) 310 | logd = -jax.nn.softplus(x) - jax.nn.softplus(-x) 311 | return y, sumflat(logd) 312 | else: 313 | y = inverse_sigmoid(x) 314 | logd = -safe_log(x) - safe_log(1. - x) 315 | return y, sumflat(logd) 316 | 317 | 318 | class MixLogisticCDF(nn.Module): 319 | """ 320 | Elementwise transformation by the CDF of a mixture of logistics 321 | """ 322 | min_logscale: float = -7. 
323 | 324 | @nn.compact 325 | def __call__(self, x, logits, means, logscales, inverse=False): 326 | logistic_kwargs = dict( 327 | prior_logits=logits, 328 | means=means, 329 | logscales=jnp.maximum(logscales, self.min_logscale) 330 | ) 331 | if not inverse: 332 | out = jnp.exp(mixlogistic_logcdf(x=x, **logistic_kwargs)) 333 | logd = mixlogistic_logpdf(x=x, **logistic_kwargs) 334 | return out, sumflat(logd) 335 | else: 336 | out = mixlogistic_invcdf(y=jnp.clip(x, 0., 1.), **logistic_kwargs) 337 | logd = -mixlogistic_logpdf(x=out, **logistic_kwargs) 338 | return out, sumflat(logd) 339 | 340 | 341 | class ElemwiseAffine(nn.Module): 342 | 343 | @nn.compact 344 | def __call__(self, x, scales, biases, logscales=None, inverse=False): 345 | logscales = safe_log(scales) if logscales is None else logscales 346 | if not inverse: 347 | assert logscales.shape == x.shape 348 | return (x * scales + biases), sumflat(logscales) 349 | else: 350 | y = x 351 | assert logscales.shape == y.shape 352 | return ((y - biases) / scales), sumflat(-logscales) 353 | 354 | 355 | class MixLogisticAttnCoupling(nn.Module): 356 | """ 357 | CDF of mixture of logistics, followed by affine 358 | """ 359 | filters: int 360 | blocks: int 361 | components: int 362 | heads: int = 4 363 | init_scale: float = 0.1 364 | dropout_p: float = 0. 
365 | verbose: bool = True 366 | 367 | @nn.compact 368 | def __call__(self, x, context=None, inverse=False, train=False): 369 | def f(x, *, context=None): 370 | if not self.has_variable('params', 'pos_emb') and self.verbose: 371 | # debug stuff 372 | def tap_func(x, transforms): 373 | xmean = jnp.mean(x, axis=list(range(len(x.shape)))) 374 | xvar = jnp.var(x, axis=list(range(len(x.shape)))) 375 | print(f'shape: {jnp.shape(x)}') 376 | print(f'mean: {xmean}') 377 | print(f'std: {jnp.sqrt(xvar)}') 378 | print(f'min: {jnp.min(x)}') 379 | print(f'max: {jnp.max(x)}') 380 | 381 | x = host_callback.id_tap(tap_func, x) 382 | 383 | B, H, W, C = x.shape 384 | pos_emb = self.param('pos_emb', jax.nn.initializers.normal(stddev=0.01), [H, W, self.filters]) 385 | x = conv2d(self, x, name='proj_in', num_units=self.filters) 386 | for i_block in range(self.blocks): 387 | name = f'block{i_block}' 388 | x = gated_conv(self, x, name=f'{name}_conv', a=context, use_nin=True, dropout_p=self.dropout_p, train=train) 389 | x = layernorm(self, x, name=f'{name}_ln1') 390 | x = gated_attn(self, x, name=f'{name}_attn', pos_emb=pos_emb, heads=self.heads, dropout_p=self.dropout_p, 391 | train=train) 392 | x = layernorm(self, x, name=f'{name}_ln2') 393 | x = conv2d(self, x, name=f'{name}_proj_out', num_units=C * (2 + 3 * self.components), init_scale=self.init_scale) 394 | assert x.shape == (B, H, W, C * (2 + 3 * self.components)) 395 | x = jnp.reshape(x, [B, H, W, C, 2 + 3 * self.components]) 396 | 397 | s, t = jnp.tanh(x[:, :, :, :, 0]), x[:, :, :, :, 1] 398 | ml_logits, ml_means, ml_logscales = jnp.split(x[:, :, :, :, 2:], 3, axis=4) 399 | assert s.shape == t.shape == (B, H, W, C) 400 | assert ml_logits.shape == ml_means.shape == ml_logscales.shape == (B, H, W, C, self.components) 401 | return ml_logits, ml_means, ml_logscales, s, t 402 | 403 | assert isinstance(x, tuple) 404 | cf, ef = x 405 | ml_logits, ml_means, ml_logscales, s, t = f(cf, context=context) 406 | logp_sum = 0. 
407 | 408 | mixlogistic_cdf = MixLogisticCDF() 409 | sigmoid = Sigmoid(inverse_module=True) 410 | elementwise_affine = ElemwiseAffine() 411 | 412 | if not inverse: 413 | h, logp = mixlogistic_cdf(ef, logits=ml_logits, means=ml_means, logscales=ml_logscales, inverse=False) 414 | if logp is not None: 415 | logp_sum = logp_sum + logp 416 | h, logp = sigmoid(h, inverse=False) 417 | if logp is not None: 418 | logp_sum = logp_sum + logp 419 | h, logp = elementwise_affine(h, scales=jnp.exp(s), biases=t, logscales=s, inverse=False) 420 | if logp is not None: 421 | logp_sum = logp_sum + logp 422 | return (cf, h), logp_sum 423 | 424 | else: 425 | h, logp = elementwise_affine(ef, scales=jnp.exp(s), biases=t, logscales=s, inverse=True) 426 | if logp is not None: 427 | logp_sum = logp_sum + logp 428 | h, logp = sigmoid(h, inverse=True) 429 | if logp is not None: 430 | logp_sum = logp_sum + logp 431 | h, logp = mixlogistic_cdf(h, logits=ml_logits, means=ml_means, logscales=ml_logscales, inverse=True) 432 | if logp is not None: 433 | logp_sum = logp_sum + logp 434 | return (cf, h), logp_sum 435 | 436 | 437 | class TupleFlip(nn.Module): 438 | @nn.compact 439 | def __call__(self, x, inverse=False): 440 | assert isinstance(x, tuple) 441 | a, b = x 442 | return (b, a), None 443 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # pylint: skip-file 17 | """Common layers for defining score networks. 18 | """ 19 | import functools 20 | import math 21 | import string 22 | from typing import Any, Sequence, Optional 23 | 24 | import flax.linen as nn 25 | import jax 26 | import jax.nn as jnn 27 | import jax.numpy as jnp 28 | 29 | 30 | def get_act(config): 31 | """Get activation functions from the config file.""" 32 | 33 | if config.model.nonlinearity.lower() == 'elu': 34 | return nn.elu 35 | elif config.model.nonlinearity.lower() == 'relu': 36 | return nn.relu 37 | elif config.model.nonlinearity.lower() == 'lrelu': 38 | return functools.partial(nn.leaky_relu, negative_slope=0.2) 39 | elif config.model.nonlinearity.lower() == 'swish': 40 | return nn.swish 41 | else: 42 | raise NotImplementedError('activation function does not exist!') 43 | 44 | 45 | def ncsn_conv1x1(x, out_planes, stride=1, bias=True, dilation=1, init_scale=1.): 46 | """1x1 convolution with PyTorch initialization. 
Same as NCSNv1/v2.""" 47 | init_scale = 1e-10 if init_scale == 0 else init_scale 48 | kernel_init = jnn.initializers.variance_scaling(1 / 3 * init_scale, 'fan_in', 49 | 'uniform') 50 | kernel_shape = (1, 1) + (x.shape[-1], out_planes) 51 | bias_init = lambda key, shape: kernel_init(key, kernel_shape)[0, 0, 0, :] 52 | output = nn.Conv(out_planes, kernel_size=(1, 1), 53 | strides=(stride, stride), padding='SAME', use_bias=bias, 54 | kernel_dilation=(dilation, dilation), 55 | kernel_init=kernel_init, 56 | bias_init=bias_init)(x) 57 | return output 58 | 59 | 60 | def default_init(scale=1.): 61 | """The same initialization used in DDPM.""" 62 | scale = 1e-10 if scale == 0 else scale 63 | return jnn.initializers.variance_scaling(scale, 'fan_avg', 'uniform') 64 | 65 | 66 | def ddpm_conv1x1(x, out_planes, stride=1, bias=True, dilation=1, init_scale=1.): 67 | """1x1 convolution with DDPM initialization.""" 68 | bias_init = jnn.initializers.zeros 69 | output = nn.Conv(out_planes, kernel_size=(1, 1), 70 | strides=(stride, stride), padding='SAME', use_bias=bias, 71 | kernel_dilation=(dilation, dilation), 72 | kernel_init=default_init(init_scale), 73 | bias_init=bias_init)(x) 74 | return output 75 | 76 | 77 | def ncsn_conv3x3(x, out_planes, stride=1, bias=True, dilation=1, init_scale=1.): 78 | """3x3 convolution with PyTorch initialization. 
Same as NCSNv1/NCSNv2.""" 79 | init_scale = 1e-10 if init_scale == 0 else init_scale 80 | kernel_init = jnn.initializers.variance_scaling(1 / 3 * init_scale, 'fan_in', 81 | 'uniform') 82 | kernel_shape = (3, 3) + (x.shape[-1], out_planes) 83 | bias_init = lambda key, shape: kernel_init(key, kernel_shape)[0, 0, 0, :] 84 | output = nn.Conv(out_planes, 85 | kernel_size=(3, 3), 86 | strides=(stride, stride), 87 | padding='SAME', 88 | use_bias=bias, 89 | kernel_dilation=(dilation, dilation), 90 | kernel_init=kernel_init, 91 | bias_init=bias_init)(x) 92 | return output 93 | 94 | 95 | def ddpm_conv3x3(x, out_planes, stride=1, bias=True, dilation=1, init_scale=1.): 96 | """3x3 convolution with DDPM initialization.""" 97 | bias_init = jnn.initializers.zeros 98 | output = nn.Conv( 99 | out_planes, 100 | kernel_size=(3, 3), 101 | strides=(stride, stride), 102 | padding='SAME', 103 | use_bias=bias, 104 | kernel_dilation=(dilation, dilation), 105 | kernel_init=default_init(init_scale), 106 | bias_init=bias_init)(x) 107 | return output 108 | 109 | 110 | ########################################################################### 111 | # Functions below are ported over from the NCSNv1/NCSNv2 codebase: 112 | # https://github.com/ermongroup/ncsn 113 | # https://github.com/ermongroup/ncsnv2 114 | ########################################################################### 115 | 116 | 117 | class CRPBlock(nn.Module): 118 | """CRPBlock for RefineNet. Used in NCSNv2.""" 119 | features: int 120 | n_stages: int 121 | act: Any = nn.relu 122 | 123 | @nn.compact 124 | def __call__(self, x): 125 | x = self.act(x) 126 | path = x 127 | for _ in range(self.n_stages): 128 | path = nn.max_pool( 129 | path, window_shape=(5, 5), strides=(1, 1), padding='SAME') 130 | path = ncsn_conv3x3(path, self.features, stride=1, bias=False) 131 | x = path + x 132 | return x 133 | 134 | 135 | class CondCRPBlock(nn.Module): 136 | """Noise-conditional CRPBlock for RefineNet. 
Used in NCSNv1.""" 137 | features: int 138 | n_stages: int 139 | normalizer: Any 140 | act: Any = nn.relu 141 | 142 | @nn.compact 143 | def __call__(self, x, y): 144 | x = self.act(x) 145 | path = x 146 | for _ in range(self.n_stages): 147 | path = self.normalizer()(path, y) 148 | path = nn.avg_pool(path, window_shape=(5, 5), strides=(1, 1), padding='SAME') 149 | path = ncsn_conv3x3(path, self.features, stride=1, bias=False) 150 | x = path + x 151 | return x 152 | 153 | 154 | class RCUBlock(nn.Module): 155 | """RCUBlock for RefineNet. Used in NCSNv2.""" 156 | features: int 157 | n_blocks: int 158 | n_stages: int 159 | act: Any = nn.relu 160 | 161 | @nn.compact 162 | def __call__(self, x): 163 | for _ in range(self.n_blocks): 164 | residual = x 165 | for _ in range(self.n_stages): 166 | x = self.act(x) 167 | x = ncsn_conv3x3(x, self.features, stride=1, bias=False) 168 | x = x + residual 169 | 170 | return x 171 | 172 | 173 | class CondRCUBlock(nn.Module): 174 | """Noise-conditional RCUBlock for RefineNet. Used in NCSNv1.""" 175 | features: int 176 | n_blocks: int 177 | n_stages: int 178 | normalizer: Any 179 | act: Any = nn.relu 180 | 181 | @nn.compact 182 | def __call__(self, x, y): 183 | for _ in range(self.n_blocks): 184 | residual = x 185 | for _ in range(self.n_stages): 186 | x = self.normalizer()(x, y) 187 | x = self.act(x) 188 | x = ncsn_conv3x3(x, self.features, stride=1, bias=False) 189 | x += residual 190 | return x 191 | 192 | 193 | class MSFBlock(nn.Module): 194 | """MSFBlock for RefineNet. 
Used in NCSNv2.""" 195 | shape: Sequence[int] 196 | features: int 197 | interpolation: str = 'bilinear' 198 | 199 | @nn.compact 200 | def __call__(self, xs): 201 | sums = jnp.zeros((xs[0].shape[0], *self.shape, self.features)) 202 | for i in range(len(xs)): 203 | h = ncsn_conv3x3(xs[i], self.features, stride=1, bias=True) 204 | if self.interpolation == 'bilinear': 205 | h = jax.image.resize(h, (h.shape[0], *self.shape, h.shape[-1]), 'bilinear') 206 | elif self.interpolation == 'nearest_neighbor': 207 | h = jax.image.resize(h, (h.shape[0], *self.shape, h.shape[-1]), 'nearest') 208 | else: 209 | raise ValueError(f'Interpolation {self.interpolation} does not exist!') 210 | sums = sums + h 211 | return sums 212 | 213 | 214 | class CondMSFBlock(nn.Module): 215 | """Noise-conditional MSFBlock for RefineNet. Used in NCSNv1.""" 216 | shape: Sequence[int] 217 | features: int 218 | normalizer: Any 219 | interpolation: str = 'bilinear' 220 | 221 | @nn.compact 222 | def __call__(self, xs, y): 223 | sums = jnp.zeros((xs[0].shape[0], *self.shape, self.features)) 224 | for i in range(len(xs)): 225 | h = self.normalizer()(xs[i], y) 226 | h = ncsn_conv3x3(h, self.features, stride=1, bias=True) 227 | if self.interpolation == 'bilinear': 228 | h = jax.image.resize(h, (h.shape[0], *self.shape, h.shape[-1]), 'bilinear') 229 | elif self.interpolation == 'nearest_neighbor': 230 | h = jax.image.resize(h, (h.shape[0], *self.shape, h.shape[-1]), 'nearest') 231 | else: 232 | raise ValueError(f'Interpolation {self.interpolation} does not exist') 233 | sums = sums + h 234 | return sums 235 | 236 | 237 | class RefineBlock(nn.Module): 238 | """RefineBlock for building NCSNv2 RefineNet.""" 239 | output_shape: Sequence[int] 240 | features: int 241 | act: Any = nn.relu 242 | interpolation: str = 'bilinear' 243 | start: bool = False 244 | end: bool = False 245 | 246 | @nn.compact 247 | def __call__(self, xs): 248 | rcu_block = functools.partial(RCUBlock, n_blocks=2, n_stages=2, act=self.act) 249 | 
rcu_block_output = functools.partial(RCUBlock, 250 | features=self.features, 251 | n_blocks=3 if self.end else 1, 252 | n_stages=2, 253 | act=self.act) 254 | hs = [] 255 | for i in range(len(xs)): 256 | h = rcu_block(features=xs[i].shape[-1])(xs[i]) 257 | hs.append(h) 258 | 259 | if not self.start: 260 | msf = functools.partial(MSFBlock, features=self.features, interpolation=self.interpolation) 261 | h = msf(shape=self.output_shape)(hs) 262 | else: 263 | h = hs[0] 264 | 265 | crp = functools.partial(CRPBlock, features=self.features, n_stages=2, act=self.act) 266 | h = crp()(h) 267 | h = rcu_block_output()(h) 268 | return h 269 | 270 | 271 | class CondRefineBlock(nn.Module): 272 | """Noise-conditional RefineBlock for building NCSNv1 RefineNet.""" 273 | output_shape: Sequence[int] 274 | features: int 275 | normalizer: Any 276 | act: Any = nn.relu 277 | interpolation: str = 'bilinear' 278 | start: bool = False 279 | end: bool = False 280 | 281 | @nn.compact 282 | def __call__(self, xs, y): 283 | rcu_block = functools.partial(CondRCUBlock, n_blocks=2, n_stages=2, act=self.act, normalizer=self.normalizer) 284 | rcu_block_output = functools.partial(CondRCUBlock, 285 | features=self.features, 286 | n_blocks=3 if self.end else 1, 287 | n_stages=2, act=self.act, 288 | normalizer=self.normalizer) 289 | hs = [] 290 | for i in range(len(xs)): 291 | h = rcu_block(features=xs[i].shape[-1])(xs[i], y) 292 | hs.append(h) 293 | 294 | if not self.start: 295 | msf = functools.partial(CondMSFBlock, 296 | features=self.features, 297 | interpolation=self.interpolation, 298 | normalizer=self.normalizer) 299 | h = msf(shape=self.output_shape)(hs, y) 300 | else: 301 | h = hs[0] 302 | 303 | crp = functools.partial(CondCRPBlock, 304 | features=self.features, 305 | n_stages=2, act=self.act, 306 | normalizer=self.normalizer) 307 | h = crp()(h, y) 308 | h = rcu_block_output()(h, y) 309 | return h 310 | 311 | 312 | class ConvMeanPool(nn.Module): 313 | """ConvMeanPool for building the ResNet 
backbone.""" 314 | output_dim: int 315 | kernel_size: int = 3 316 | biases: bool = True 317 | 318 | @nn.compact 319 | def __call__(self, inputs): 320 | output = nn.Conv(features=self.output_dim, 321 | kernel_size=(self.kernel_size, self.kernel_size), 322 | strides=(1, 1), 323 | padding='SAME', 324 | use_bias=self.biases)(inputs) 325 | output = sum([ 326 | output[:, ::2, ::2, :], output[:, 1::2, ::2, :], 327 | output[:, ::2, 1::2, :], output[:, 1::2, 1::2, :] 328 | ]) / 4. 329 | return output 330 | 331 | 332 | class MeanPoolConv(nn.Module): 333 | """MeanPoolConv for building the ResNet backbone.""" 334 | output_dim: int 335 | kernel_size: int = 3 336 | biases: bool = True 337 | 338 | @nn.compact 339 | def __call__(self, inputs): 340 | output = inputs 341 | output = sum([ 342 | output[:, ::2, ::2, :], output[:, 1::2, ::2, :], 343 | output[:, ::2, 1::2, :], output[:, 1::2, 1::2, :] 344 | ]) / 4. 345 | output = nn.Conv( 346 | features=self.output_dim, 347 | kernel_size=(self.kernel_size, self.kernel_size), 348 | strides=(1, 1), 349 | padding='SAME', 350 | use_bias=self.biases)(output) 351 | return output 352 | 353 | 354 | class ResidualBlock(nn.Module): 355 | """The residual block for defining the ResNet backbone. 
Used in NCSNv2.""" 356 | output_dim: int 357 | normalization: Any 358 | resample: Optional[str] = None 359 | act: Any = nn.elu 360 | dilation: int = 1 361 | 362 | @nn.compact 363 | def __call__(self, x): 364 | h = self.normalization()(x) 365 | h = self.act(h) 366 | if self.resample == 'down': 367 | h = ncsn_conv3x3(h, h.shape[-1], dilation=self.dilation) 368 | h = self.normalization()(h) 369 | h = self.act(h) 370 | if self.dilation > 1: 371 | h = ncsn_conv3x3(h, self.output_dim, dilation=self.dilation) 372 | shortcut = ncsn_conv3x3(x, self.output_dim, dilation=self.dilation) 373 | else: 374 | h = ConvMeanPool(output_dim=self.output_dim)(h) 375 | shortcut = ConvMeanPool(output_dim=self.output_dim, kernel_size=1)(x) 376 | elif self.resample is None: 377 | if self.dilation > 1: 378 | if self.output_dim == x.shape[-1]: 379 | shortcut = x 380 | else: 381 | shortcut = ncsn_conv3x3(x, self.output_dim, dilation=self.dilation) 382 | h = ncsn_conv3x3(h, self.output_dim, dilation=self.dilation) 383 | h = self.normalization()(h) 384 | h = self.act(h) 385 | h = ncsn_conv3x3(h, self.output_dim, dilation=self.dilation) 386 | else: 387 | if self.output_dim == x.shape[-1]: 388 | shortcut = x 389 | else: 390 | shortcut = ncsn_conv1x1(x, self.output_dim) 391 | h = ncsn_conv3x3(h, self.output_dim) 392 | h = self.normalization()(h) 393 | h = self.act(h) 394 | h = ncsn_conv3x3(h, self.output_dim) 395 | 396 | return h + shortcut 397 | 398 | 399 | class ConditionalResidualBlock(nn.Module): 400 | """The noise-conditional residual block for building NCSNv1.""" 401 | output_dim: int 402 | normalization: Any 403 | resample: Optional[str] = None 404 | act: Any = nn.elu 405 | dilation: int = 1 406 | 407 | @nn.compact 408 | def __call__(self, x, y): 409 | h = self.normalization()(x, y) 410 | h = self.act(h) 411 | if self.resample == 'down': 412 | h = ncsn_conv3x3(h, h.shape[-1], dilation=self.dilation) 413 | h = self.normalization(h, y) 414 | h = self.act(h) 415 | if self.dilation > 1: 416 | h = 
ncsn_conv3x3(h, self.output_dim, dilation=self.dilation) 417 | shortcut = ncsn_conv3x3(x, self.output_dim, dilation=self.dilation) 418 | else: 419 | h = ConvMeanPool(output_dim=self.output_dim)(h) 420 | shortcut = ConvMeanPool(output_dim=self.output_dim, kernel_size=1)(x) 421 | elif self.resample is None: 422 | if self.dilation > 1: 423 | if self.output_dim == x.shape[-1]: 424 | shortcut = x 425 | else: 426 | shortcut = ncsn_conv3x3(x, self.output_dim, dilation=self.dilation) 427 | h = ncsn_conv3x3(h, self.output_dim, dilation=self.dilation) 428 | h = self.normalization()(h, y) 429 | h = self.act(h) 430 | h = ncsn_conv3x3(h, self.output_dim, dilation=self.dilation) 431 | else: 432 | if self.output_dim == x.shape[-1]: 433 | shortcut = x 434 | else: 435 | shortcut = ncsn_conv1x1(x, self.output_dim) 436 | h = ncsn_conv3x3(h, self.output_dim) 437 | h = self.normalization()(h, y) 438 | h = self.act(h) 439 | h = ncsn_conv3x3(h, self.output_dim) 440 | 441 | return h + shortcut 442 | 443 | 444 | ########################################################################### 445 | # Functions below are ported over from the DDPM codebase: 446 | # https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py 447 | ########################################################################### 448 | 449 | 450 | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000): 451 | assert len(timesteps.shape) == 1 # and timesteps.dtype == tf.int32 452 | half_dim = embedding_dim // 2 453 | # magic number 10000 is from transformers 454 | emb = math.log(max_positions) / (half_dim - 1) 455 | # emb = math.log(2.) 
/ (half_dim - 1) 456 | emb = jnp.exp(jnp.arange(half_dim, dtype=jnp.float32) * -emb) 457 | # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :] 458 | # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :] 459 | emb = timesteps[:, None] * emb[None, :] 460 | emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1) 461 | if embedding_dim % 2 == 1: # zero pad 462 | emb = jnp.pad(emb, [[0, 0], [0, 1]]) 463 | assert emb.shape == (timesteps.shape[0], embedding_dim) 464 | return emb 465 | 466 | 467 | class NIN(nn.Module): 468 | num_units: int 469 | init_scale: float = 0.1 470 | 471 | @nn.compact 472 | def __call__(self, x): 473 | in_dim = int(x.shape[-1]) 474 | W = self.param('W', default_init(scale=self.init_scale), (in_dim, self.num_units)) 475 | b = self.param('b', jnn.initializers.zeros, (self.num_units,)) 476 | y = contract_inner(x, W) + b 477 | assert y.shape == x.shape[:-1] + (self.num_units,) 478 | return y 479 | 480 | 481 | def _einsum(a, b, c, x, y): 482 | einsum_str = '{},{}->{}'.format(''.join(a), ''.join(b), ''.join(c)) 483 | return jnp.einsum(einsum_str, x, y) 484 | 485 | 486 | def contract_inner(x, y): 487 | """tensordot(x, y, 1).""" 488 | x_chars = list(string.ascii_lowercase[:len(x.shape)]) 489 | y_chars = list(string.ascii_uppercase[:len(y.shape)]) 490 | assert len(x_chars) == len(x.shape) and len(y_chars) == len(y.shape) 491 | y_chars[0] = x_chars[-1] # first axis of y and last of x get summed 492 | out_chars = x_chars[:-1] + y_chars[1:] 493 | return _einsum(x_chars, y_chars, out_chars, x, y) 494 | 495 | 496 | class AttnBlock(nn.Module): 497 | """Channel-wise self-attention block.""" 498 | normalize: Any 499 | 500 | @nn.compact 501 | def __call__(self, x): 502 | B, H, W, C = x.shape 503 | h = self.normalize()(x) 504 | q = NIN(C)(h) 505 | k = NIN(C)(h) 506 | v = NIN(C)(h) 507 | 508 | w = jnp.einsum('bhwc,bHWc->bhwHW', q, k) * (int(C) ** (-0.5)) 509 | w = jnp.reshape(w, (B, H, W, H * W)) 510 | w = jax.nn.softmax(w, 
axis=-1) 511 | w = jnp.reshape(w, (B, H, W, H, W)) 512 | h = jnp.einsum('bhwHW,bHWc->bhwc', w, v) 513 | h = NIN(C, init_scale=0.)(h) 514 | return x + h 515 | 516 | 517 | class Upsample(nn.Module): 518 | with_conv: bool = False 519 | 520 | @nn.compact 521 | def __call__(self, x): 522 | B, H, W, C = x.shape 523 | h = jax.image.resize(x, (x.shape[0], H * 2, W * 2, C), 'nearest') 524 | if self.with_conv: 525 | h = ddpm_conv3x3(h, C) 526 | return h 527 | 528 | 529 | class Downsample(nn.Module): 530 | with_conv: bool = False 531 | 532 | @nn.compact 533 | def __call__(self, x): 534 | B, H, W, C = x.shape 535 | if self.with_conv: 536 | x = ddpm_conv3x3(x, C, stride=2) 537 | else: 538 | x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME') 539 | assert x.shape == (B, H // 2, W // 2, C) 540 | return x 541 | 542 | 543 | class ResnetBlockDDPM(nn.Module): 544 | """The ResNet Blocks used in DDPM.""" 545 | act: Any 546 | normalize: Any 547 | out_ch: Optional[int] = None 548 | conv_shortcut: bool = False 549 | dropout: float = 0.5 550 | 551 | @nn.compact 552 | def __call__(self, x, temb=None, train=True): 553 | B, H, W, C = x.shape 554 | out_ch = self.out_ch if self.out_ch else C 555 | h = self.act(self.normalize()(x)) 556 | h = ddpm_conv3x3(h, out_ch) 557 | # Add bias to each feature map conditioned on the time embedding 558 | if temb is not None: 559 | h += nn.Dense(out_ch, kernel_init=default_init())(self.act(temb))[:, None, None, :] 560 | h = self.act(self.normalize()(h)) 561 | h = nn.Dropout(self.dropout)(h, deterministic=not train) 562 | h = ddpm_conv3x3(h, out_ch, init_scale=0.) 563 | if C != out_ch: 564 | if self.conv_shortcut: 565 | x = ddpm_conv3x3(x, out_ch) 566 | else: 567 | x = NIN(out_ch)(x) 568 | return x + h 569 | --------------------------------------------------------------------------------