├── cleanfid ├── __init__.py ├── leaderboard.py ├── utils.py ├── downloads_helper.py ├── features.py └── resize.py ├── flow_models ├── __init__.py ├── resflow │ ├── __init__.py │ ├── layers │ │ ├── base │ │ │ ├── __init__.py │ │ │ ├── utils.py │ │ │ └── activations.py │ │ ├── __init__.py │ │ ├── container.py │ │ ├── squeeze.py │ │ ├── mask_utils.py │ │ ├── glow.py │ │ ├── elemwise.py │ │ ├── nonlinear_activation.py │ │ ├── normalization.py │ │ └── act_norm.py │ ├── datasets.py │ ├── lr_scheduler.py │ ├── toy_data.py │ └── visualize_flow.py └── wolf │ ├── flows │ ├── resflow │ │ ├── __init__.py │ │ ├── layers │ │ │ ├── base │ │ │ │ ├── __init__.py │ │ │ │ ├── utils.py │ │ │ │ └── activations.py │ │ │ ├── __init__.py │ │ │ ├── container.py │ │ │ ├── squeeze.py │ │ │ ├── mask_utils.py │ │ │ ├── glow.py │ │ │ ├── elemwise.py │ │ │ ├── nonlinear_activation.py │ │ │ ├── normalization.py │ │ │ └── act_norm.py │ │ ├── datasets.py │ │ ├── lr_scheduler.py │ │ ├── toy_data.py │ │ └── visualize_flow.py │ ├── couplings │ │ └── __init__.py │ ├── __init__.py │ └── flow.py │ ├── __init__.py │ ├── nnet │ ├── resnets │ │ ├── __init__.py │ │ └── resnet.py │ ├── __init__.py │ ├── layer_norm.py │ ├── adaptive_instance_norm.py │ ├── positional_encoding.py │ ├── shift_conv.py │ └── weight_norm.py │ ├── modules │ ├── generators │ │ ├── __init__.py │ │ └── generator.py │ ├── dequantization │ │ └── __init__.py │ ├── discriminators │ │ ├── priors │ │ │ ├── __init__.py │ │ │ └── prior.py │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── categorical.py │ │ └── gaussian.py │ ├── __init__.py │ └── encoders │ │ ├── __init__.py │ │ ├── encoder.py │ │ ├── global_encoder.py │ │ └── local_encoder.py │ ├── optim │ └── __init__.py │ ├── data │ └── __init__.py │ ├── wolf_configs │ ├── cifar10 │ │ ├── glow │ │ │ ├── glow-base-uni.json │ │ │ ├── resflow-gaussian-uni.json │ │ │ ├── glow-cat-uni.json │ │ │ ├── resflow-gaussian-uni-squeeze.json │ │ │ ├── glow-gaussian-uni.json │ │ │ ├── glow-base-var.json 
│ │ │ └── glow-gaussian-var.json │ │ └── macow │ │ │ ├── macow-base-uni.json │ │ │ ├── macow-cat-uni.json │ │ │ ├── macow-gaussian-uni.json │ │ │ └── macow-base-var.json │ ├── imagenet │ │ └── 64x64 │ │ │ └── glow │ │ │ ├── glow-base-uni.json │ │ │ ├── resflow-gaussian-uni.json │ │ │ └── glow-gaussian-uni.json │ ├── lsun │ │ └── 128x128 │ │ │ └── glow │ │ │ ├── glow-base-uni.json │ │ │ └── glow-gaussian-uni.json │ └── celebA-HQ │ │ ├── glow │ │ ├── glow-base-uni.json │ │ ├── glow-gaussian-uni.json │ │ └── glow-base-var.json │ │ └── macow │ │ ├── macow-base-uni.json │ │ ├── macow-gaussian-uni.json │ │ └── macow-base-var.json │ └── utils.py ├── figures └── overview.png ├── op ├── __init__.py ├── fused_bias_act.cpp ├── upfirdn2d.cpp ├── fused_act.py └── fused_bias_act_kernel.cu ├── .gitignore ├── requirements.txt ├── models ├── __init__.py └── ema.py ├── configs ├── ve │ ├── CIFAR10 │ │ └── indm.py │ └── CELEBA │ │ └── indm.py ├── vp │ ├── CIFAR10 │ │ ├── indm_nll.py │ │ └── indm_fid.py │ └── CELEBA │ │ ├── indm_nll.py │ │ └── indm_fid.py ├── default_celeba_configs.py └── default_cifar10_configs.py └── main.py /cleanfid/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flow_models/__init__.py: -------------------------------------------------------------------------------- 1 | from flow_models.resflow import ResidualFlow -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/byeonghu-na/INDM/HEAD/figures/overview.png -------------------------------------------------------------------------------- /flow_models/resflow/__init__.py: -------------------------------------------------------------------------------- 1 | from flow_models.resflow.resflow_ import ResidualFlow 2 | 
-------------------------------------------------------------------------------- /flow_models/wolf/flows/resflow/__init__.py: -------------------------------------------------------------------------------- 1 | from flow_models.wolf.flows.resflow.resflow_ import ResidualFlow 2 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /flow_models/wolf/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.wolf import WolfModel, WolfCore 4 | -------------------------------------------------------------------------------- /flow_models/wolf/nnet/resnets/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.nnet.resnets.resnet import * 4 | -------------------------------------------------------------------------------- /flow_models/resflow/layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations import * 2 | from .lipschitz import * 3 | from .mixed_lipschitz import * 4 | -------------------------------------------------------------------------------- /flow_models/wolf/modules/generators/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.modules.generators.generator import Generator 4 | -------------------------------------------------------------------------------- /flow_models/wolf/flows/resflow/layers/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .activations 
import * 2 | from .lipschitz import * 3 | from .mixed_lipschitz import * 4 | -------------------------------------------------------------------------------- /flow_models/wolf/optim/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.optim.lr_scheduler import InverseSquareRootScheduler, ExponentialScheduler -------------------------------------------------------------------------------- /flow_models/wolf/flows/couplings/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.flows.couplings.coupling import NICE1d, NICE2d, MaskedConvFlow 4 | -------------------------------------------------------------------------------- /flow_models/wolf/modules/dequantization/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.modules.dequantization.dequantizer import DeQuantizer, UniformDeQuantizer, FlowDeQuantizer 4 | -------------------------------------------------------------------------------- /flow_models/wolf/data/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.data.image import load_datasets, iterate_minibatches, get_batch, binarize_data, binarize_image 4 | from flow_models.wolf.data.image import preprocess, postprocess 5 | -------------------------------------------------------------------------------- /flow_models/wolf/modules/discriminators/priors/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.modules.discriminators.priors.prior import Prior, NormalPrior 4 | from flow_models.wolf.modules.discriminators.priors.flow import FlowPrior 5 | 
-------------------------------------------------------------------------------- /flow_models/wolf/modules/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.modules.dequantization import * 4 | from flow_models.wolf.modules.encoders import * 5 | from flow_models.wolf.modules.discriminators import * 6 | from flow_models.wolf.modules.generators import * 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled python modules. 2 | *.pyc 3 | 4 | # Byte-compiled 5 | __pycache__/ 6 | .cache/ 7 | .idea/ 8 | 9 | # Python egg metadata, regenerated from source files by setuptools. 10 | /*.egg-info 11 | .eggs/ 12 | 13 | # PyPI distribution artifacts. 14 | build/ 15 | dist/ 16 | 17 | # Tests 18 | .pytest_cache/ 19 | 20 | # Other 21 | *.DS_Store 22 | /assets/ 23 | -------------------------------------------------------------------------------- /flow_models/wolf/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.modules.encoders.encoder import Encoder 4 | from flow_models.wolf.modules.encoders.global_encoder import GlobalResNetEncoderBatchNorm, GlobalResNetEncoderGroupNorm 5 | from flow_models.wolf.modules.encoders.local_encoder import LocalResNetEncoderBatchNorm, LocalResNetEncoderGroupNorm 6 | -------------------------------------------------------------------------------- /flow_models/wolf/modules/discriminators/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.modules.discriminators.discriminator import Discriminator 4 | from flow_models.wolf.modules.discriminators.gaussian import GaussianDiscriminator 5 | from 
flow_models.wolf.modules.discriminators.categorical import CategoricalDiscriminator 6 | from flow_models.wolf.modules.discriminators.priors import * 7 | -------------------------------------------------------------------------------- /flow_models/resflow/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from flow_models.resflow.layers.act_norm import * 2 | from flow_models.resflow.layers.container import * 3 | from flow_models.resflow.layers.coupling import * 4 | from flow_models.resflow.layers.elemwise import * 5 | from flow_models.resflow.layers.iresblock import * 6 | from flow_models.resflow.layers.normalization import * 7 | from flow_models.resflow.layers.squeeze import * 8 | from flow_models.resflow.layers.glow import * 9 | -------------------------------------------------------------------------------- /flow_models/resflow/layers/base/utils.py: -------------------------------------------------------------------------------- 1 | #from torch._six import container_abcs 2 | import collections.abc as container_abcs 3 | from itertools import repeat 4 | 5 | 6 | def _ntuple(n): 7 | 8 | def parse(x): 9 | if isinstance(x, container_abcs.Iterable): 10 | return x 11 | return tuple(repeat(x, n)) 12 | 13 | return parse 14 | 15 | 16 | _single = _ntuple(1) 17 | _pair = _ntuple(2) 18 | _triple = _ntuple(3) 19 | _quadruple = _ntuple(4) 20 | -------------------------------------------------------------------------------- /flow_models/wolf/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.nnet.weight_norm import LinearWeightNorm, Conv2dWeightNorm, ConvTranspose2dWeightNorm 4 | from flow_models.wolf.nnet.shift_conv import ShiftedConv2d 5 | from flow_models.wolf.nnet.resnets import * 6 | from flow_models.wolf.nnet.attention import MultiHeadAttention, MultiHeadAttention2d 7 | from flow_models.wolf.nnet.layer_norm import 
LayerNorm 8 | from flow_models.wolf.nnet.adaptive_instance_norm import AdaIN2d 9 | -------------------------------------------------------------------------------- /flow_models/wolf/flows/resflow/layers/base/utils.py: -------------------------------------------------------------------------------- 1 | #from torch._six import container_abcs 2 | import collections.abc as container_abcs 3 | from itertools import repeat 4 | 5 | 6 | def _ntuple(n): 7 | 8 | def parse(x): 9 | if isinstance(x, container_abcs.Iterable): 10 | return x 11 | return tuple(repeat(x, n)) 12 | 13 | return parse 14 | 15 | 16 | _single = _ntuple(1) 17 | _pair = _ntuple(2) 18 | _triple = _ntuple(3) 19 | _quadruple = _ntuple(4) 20 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | overrides==3.1.0 2 | ml-collections==0.1.0 3 | tensorflow-gan==2.0.0 4 | tensorflow_io 5 | tensorflow_datasets==3.1.0 6 | tensorflow==2.4.0 7 | tensorflow-addons==0.12.0 8 | tensorflow_probability==0.12.0 9 | tensorboard==2.4.0 10 | absl-py==0.10.0 11 | ninja 12 | scipy 13 | natsort 14 | imageio 15 | -f https://storage.googleapis.com/jax-releases/jax_releases.html 16 | jax[cuda111] 17 | -f https://download.pytorch.org/whl/torch_stable.html 18 | torch==1.7.1+cu110 19 | torchvision==0.8.2+cu110 20 | torchaudio==0.7.2 21 | -------------------------------------------------------------------------------- /flow_models/wolf/nnet/layer_norm.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): 8 | # if not export and torch.cuda.is_available(): 9 | # try: 10 | # from apex.normalization import FusedLayerNorm 11 | # return FusedLayerNorm(normalized_shape, eps, elementwise_affine) 12 | # except ImportError: 13 | 
# pass 14 | return nn.LayerNorm(normalized_shape, eps, elementwise_affine) 15 | -------------------------------------------------------------------------------- /flow_models/wolf/flows/resflow/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from flow_models.wolf.flows.resflow.layers.act_norm import * 2 | from flow_models.wolf.flows.resflow.layers.container import * 3 | from flow_models.wolf.flows.resflow.layers.coupling import * 4 | from flow_models.wolf.flows.resflow.layers.elemwise import * 5 | from flow_models.wolf.flows.resflow.layers.iresblock import * 6 | from flow_models.wolf.flows.resflow.layers.normalization import * 7 | from flow_models.wolf.flows.resflow.layers.squeeze import * 8 | from flow_models.wolf.flows.resflow.layers.glow import * 9 | -------------------------------------------------------------------------------- /flow_models/wolf/flows/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from flow_models.wolf.flows.flow import Flow 4 | from flow_models.wolf.flows.normalization import ActNorm1dFlow, ActNorm2dFlow 5 | from flow_models.wolf.flows.activation import LeakyReLUFlow, ELUFlow, PowshrinkFlow, IdentityFlow, SigmoidFlow 6 | from flow_models.wolf.flows.permutation import Conv1x1Flow, InvertibleLinearFlow, InvertibleMultiHeadFlow 7 | from flow_models.wolf.flows.multiscale_architecture import MultiScaleExternal, MultiScaleInternal 8 | from flow_models.wolf.flows.couplings import * 9 | from flow_models.wolf.flows.glow import Glow 10 | from flow_models.wolf.flows.macow import MaCow 11 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/glow/glow-base-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [6, 6], 
4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "activation": "elu", 11 | "inverse": true, 12 | "transform": "affine", 13 | "prior_transform": "affine", 14 | "alpha": 1.0, 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "uniform" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/macow/macow-base-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [4, 4], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "kernel_size": [2, 3], 11 | "activation": "elu", 12 | "inverse": true, 13 | "transform": "affine", 14 | "prior_transform": "affine", 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "uniform" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/imagenet/64x64/glow/glow-base-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 5, 6 | "num_steps": [2, [8, 8], [8, 8], [6, 6], 4], 7 | "factors": [4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512, 512], 10 | "activation": "elu", 11 | "inverse": true, 12 | "transform": "affine", 13 | "prior_transform": "affine", 14 | "alpha": 1.0, 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "uniform" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/lsun/128x128/glow/glow-base-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 6, 6 | "num_steps": [2, [16, 16], [16, 16], [8, 8], [4, 4], 2], 7 | "factors": [4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": 
[24, 256, 256, 256, 256, 512], 10 | "activation": "elu", 11 | "inverse": true, 12 | "transform": "affine", 13 | "prior_transform": "affine", 14 | "alpha": 1.0, 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "uniform" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /flow_models/wolf/modules/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | from typing import Dict 4 | import torch.nn as nn 5 | 6 | 7 | class Encoder(nn.Module): 8 | """ 9 | Encoder base class 10 | """ 11 | _registry = dict() 12 | 13 | def __init__(self): 14 | super(Encoder, self).__init__() 15 | 16 | def init(self, x, init_scale=1.0): 17 | raise NotImplementedError 18 | 19 | @classmethod 20 | def register(cls, name: str): 21 | Encoder._registry[name] = cls 22 | 23 | @classmethod 24 | def by_name(cls, name: str): 25 | return Encoder._registry[name] 26 | 27 | @classmethod 28 | def from_params(cls, params: Dict): 29 | raise NotImplementedError 30 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/celebA-HQ/glow/glow-base-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 7, 6 | "num_steps": [2, [12, 12], [10, 10], [8, 8], [4, 4], [2, 2], 1], 7 | "factors": [4, 4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 256, 256, 256, 256, 256, 512], 10 | "activation": "elu", 11 | "inverse": true, 12 | "transform": "affine", 13 | "prior_transform": "affine", 14 | "alpha": 1.0, 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "uniform" 24 | } 25 | } 26 | 
-------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/celebA-HQ/macow/macow-base-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 7, 6 | "num_steps": [2, [10, 10], [8, 8], [8, 8], [4, 4], [1, 1], 1], 7 | "factors": [4, 4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 256, 256, 256, 256, 256, 512], 10 | "kernel_size": [2, 3], 11 | "activation": "elu", 12 | "inverse": true, 13 | "transform": "affine", 14 | "prior_transform": "affine", 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "uniform" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/glow/resflow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "resflow" 5 | } 6 | }, 7 | "discriminator" : { 8 | "type": "gaussian", 9 | "encoder": { 10 | "type": "global_resnet_bn", 11 | "levels": 3, 12 | "in_planes": 3, 13 | "hidden_planes": [48, 96, 96], 14 | "out_planes": 8, 15 | "activation": "elu" 16 | }, 17 | "in_dim": 128, 18 | "dim": 64, 19 | "prior": { 20 | "type": "flow", 21 | "num_steps": 2, 22 | "in_features": 64, 23 | "hidden_features": 256, 24 | "activation": "elu", 25 | "transform": "affine", 26 | "alpha": 1.0, 27 | "coupling_type": "mlp" 28 | } 29 | }, 30 | "dequantizer": { 31 | "type": "uniform" 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/glow/glow-cat-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 4, 6 | 
"num_steps": [2, [6, 6], [6, 6], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "h_channels": 64, 11 | "h_type": "global_linear", 12 | "activation": "elu", 13 | "inverse": true, 14 | "transform": "affine", 15 | "prior_transform": "affine", 16 | "alpha": 1.0, 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "categorical", 23 | "num_events": 10, 24 | "dim": 64, 25 | "activation": "elu" 26 | }, 27 | "dequantizer": { 28 | "type": "uniform" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/imagenet/64x64/glow/resflow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "resflow" 5 | } 6 | }, 7 | "discriminator" : { 8 | "type": "gaussian", 9 | "encoder": { 10 | "type": "global_resnet_bn", 11 | "levels": 3, 12 | "in_planes": 12, 13 | "hidden_planes": [48, 96, 96], 14 | "out_planes": 8, 15 | "activation": "elu" 16 | }, 17 | "in_dim": 128, 18 | "dim": 64, 19 | "prior": { 20 | "type": "flow", 21 | "num_steps": 2, 22 | "in_features": 64, 23 | "hidden_features": 256, 24 | "activation": "elu", 25 | "transform": "affine", 26 | "alpha": 1.0, 27 | "coupling_type": "mlp" 28 | } 29 | }, 30 | "dequantizer": { 31 | "type": "uniform" 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/glow/resflow-gaussian-uni-squeeze.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "resflow" 5 | } 6 | }, 7 | "discriminator" : { 8 | "type": "gaussian", 9 | "encoder": { 10 | "type": "global_resnet_bn", 11 | "levels": 3, 12 | "in_planes": 12, 13 | "hidden_planes": [48, 96, 96], 14 | "out_planes": 32, 15 | "activation": "elu" 16 | }, 17 | 
"in_dim": 128, 18 | "dim": 64, 19 | "prior": { 20 | "type": "flow", 21 | "num_steps": 2, 22 | "in_features": 64, 23 | "hidden_features": 256, 24 | "activation": "elu", 25 | "transform": "affine", 26 | "alpha": 1.0, 27 | "coupling_type": "mlp" 28 | } 29 | }, 30 | "dequantizer": { 31 | "type": "uniform" 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/macow/macow-cat-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [4, 4], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "kernel_size": [2, 3], 11 | "h_channels": 64, 12 | "h_type": "global_linear", 13 | "activation": "elu", 14 | "inverse": true, 15 | "transform": "affine", 16 | "prior_transform": "affine", 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "categorical", 23 | "num_events": 10, 24 | "dim": 64, 25 | "activation": "elu" 26 | }, 27 | "dequantizer": { 28 | "type": "uniform" 29 | } 30 | } 31 | -------------------------------------------------------------------------------- /flow_models/wolf/nnet/adaptive_instance_norm.py: -------------------------------------------------------------------------------- 1 | __author__ = 'max' 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class AdaIN2d(nn.Module): 8 | def __init__(self, in_channels, in_features): 9 | super(AdaIN2d, self).__init__() 10 | self.norm = nn.InstanceNorm2d(in_channels, affine=False, track_running_stats=False) 11 | self.net = nn.Linear(in_features, 2 * in_channels) 12 | self.reset_parameters() 13 | 14 | def forward(self, x, h): 15 | # [batch, num_features * 2] 16 | h = self.net(h) 17 | bs, fs = h.size() 18 | h.view(bs, fs, 1, 1) 19 | # [batch, num_features, 1, 1] 20 | b, s = h.chunk(2, 1) 21 | 
x = self.norm(x) 22 | return x * (s + 1) + b 23 | 24 | def reset_parameters(self): 25 | nn.init.constant_(self.net.weight, 0.0) 26 | nn.init.constant_(self.net.bias, 0.0) 27 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, 
int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/celebA-HQ/macow/macow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 7, 6 | "num_steps": [2, [10, 10], [8, 8], [8, 8], [4, 4], [1, 1], 1], 7 | "factors": [4, 4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 256, 256, 256, 256, 256, 512], 10 | "kernel_size": [2, 3], 11 | "h_channels": 256, 12 | "h_type": "global_linear", 13 | "activation": "elu", 14 | "inverse": true, 15 | "transform": "affine", 16 | "prior_transform": "affine", 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "gaussian", 23 | "encoder": { 24 | "type": "global_resnet_bn", 25 | "levels": 6, 26 | "in_planes": 3, 27 | "hidden_planes": [48, 96, 96, 192, 192, 256], 28 | "out_planes": 32, 29 | "activation": "elu" 30 | }, 31 | "in_dim": 512, 32 | "dim": 256, 33 | "prior": { 34 | "type": "normal" 35 | } 36 | }, 37 | "dequantizer": { 38 | "type": "uniform" 39 | } 40 | } 41 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/macow/macow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [4, 4], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "kernel_size": [2, 3], 11 | 
"h_channels": 64, 12 | "h_type": "global_linear", 13 | "activation": "elu", 14 | "inverse": true, 15 | "transform": "affine", 16 | "prior_transform": "affine", 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "gaussian", 23 | "encoder": { 24 | "type": "global_resnet_bn", 25 | "levels": 3, 26 | "in_planes": 3, 27 | "hidden_planes": [48, 96, 96], 28 | "out_planes": 8, 29 | "activation": "elu" 30 | }, 31 | "in_dim": 128, 32 | "dim": 64, 33 | "prior": { 34 | "type": "flow", 35 | "num_steps": 2, 36 | "in_features": 64, 37 | "hidden_features": 256, 38 | "activation": "elu", 39 | "transform": "affine", 40 | "coupling_type": "mlp" 41 | } 42 | }, 43 | "dequantizer": { 44 | "type": "uniform" 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/glow/glow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [6, 6], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "h_channels": 64, 11 | "h_type": "global_linear", 12 | "activation": "elu", 13 | "inverse": true, 14 | "transform": "affine", 15 | "prior_transform": "affine", 16 | "alpha": 1.0, 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "gaussian", 23 | "encoder": { 24 | "type": "global_resnet_bn", 25 | "levels": 3, 26 | "in_planes": 3, 27 | "hidden_planes": [48, 96, 96], 28 | "out_planes": 8, 29 | "activation": "elu" 30 | }, 31 | "in_dim": 128, 32 | "dim": 64, 33 | "prior": { 34 | "type": "flow", 35 | "num_steps": 2, 36 | "in_features": 64, 37 | "hidden_features": 256, 38 | "activation": "elu", 39 | "transform": "affine", 40 | "alpha": 1.0, 41 | "coupling_type": "mlp" 42 | } 43 | }, 44 | "dequantizer": { 45 | 
"type": "uniform" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/imagenet/64x64/glow/glow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 5, 6 | "num_steps": [2, [8, 8], [8, 8], [6, 6], 4], 7 | "factors": [4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512, 512], 10 | "h_channels": 128, 11 | "h_type": "global_linear", 12 | "activation": "elu", 13 | "inverse": true, 14 | "transform": "affine", 15 | "prior_transform": "affine", 16 | "alpha": 1.0, 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "gaussian", 23 | "encoder": { 24 | "type": "global_resnet_bn", 25 | "levels": 4, 26 | "in_planes": 3, 27 | "hidden_planes": [48, 96, 96, 192], 28 | "out_planes": 32, 29 | "activation": "elu" 30 | }, 31 | "in_dim": 512, 32 | "dim": 128, 33 | "prior": { 34 | "type": "flow", 35 | "num_steps": 2, 36 | "in_features": 128, 37 | "hidden_features": 256, 38 | "activation": "elu", 39 | "transform": "affine", 40 | "alpha": 1.0, 41 | "coupling_type": "mlp" 42 | } 43 | }, 44 | "dequantizer": { 45 | "type": "uniform" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/lsun/128x128/glow/glow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 6, 6 | "num_steps": [2, [16, 16], [16, 16], [8, 8], [4, 4], 2], 7 | "factors": [4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 256, 256, 256, 256, 512], 10 | "h_channels": 256, 11 | "h_type": "global_linear", 12 | "activation": "elu", 13 | "inverse": true, 14 | "transform": "affine", 15 | "prior_transform": "affine", 16 | "alpha": 1.0, 17 | 
"coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "gaussian", 23 | "encoder": { 24 | "type": "global_resnet_bn", 25 | "levels": 5, 26 | "in_planes": 3, 27 | "hidden_planes": [48, 96, 96, 192, 192], 28 | "out_planes": 32, 29 | "activation": "elu" 30 | }, 31 | "in_dim": 512, 32 | "dim": 256, 33 | "prior": { 34 | "type": "flow", 35 | "num_steps": 2, 36 | "in_features": 256, 37 | "hidden_features": 512, 38 | "activation": "elu", 39 | "transform": "affine", 40 | "alpha": 1.0, 41 | "coupling_type": "mlp" 42 | } 43 | }, 44 | "dequantizer": { 45 | "type": "uniform" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/glow/glow-base-var.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [6, 6], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "activation": "elu", 11 | "inverse": true, 12 | "transform": "affine", 13 | "prior_transform": "affine", 14 | "alpha": 1.0, 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "flow", 24 | "encoder": { 25 | "type": "local_resnet_bn", 26 | "levels": 2, 27 | "in_planes": 3, 28 | "hidden_planes": [48, 96], 29 | "out_planes": 4, 30 | "activation": "elu" 31 | }, 32 | "flow": { 33 | "type": "glow", 34 | "levels": 2, 35 | "num_steps": [2, 4], 36 | "factors": [], 37 | "in_channels": 3, 38 | "hidden_channels": [24, 256], 39 | "h_channels": 4, 40 | "h_type": "local_linear", 41 | "activation": "elu", 42 | "inverse": false, 43 | "transform": "affine", 44 | "prior_transform": "affine", 45 | "alpha": 1.0, 46 | "coupling_type": "conv" 47 | } 48 | } 49 | } 50 | 
-------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/celebA-HQ/glow/glow-gaussian-uni.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "glow", 5 | "levels": 7, 6 | "num_steps": [2, [12, 12], [10, 10], [8, 8], [4, 4], [2, 2], 1], 7 | "factors": [4, 4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 256, 256, 256, 256, 256, 512], 10 | "h_channels": 256, 11 | "h_type": "global_linear", 12 | "activation": "elu", 13 | "inverse": true, 14 | "transform": "affine", 15 | "prior_transform": "affine", 16 | "alpha": 1.0, 17 | "coupling_type": "conv", 18 | "num_groups": [2, 4, 4, 4, 4, 4, 4] 19 | } 20 | }, 21 | "discriminator" : { 22 | "type": "gaussian", 23 | "encoder": { 24 | "type": "global_resnet_bn", 25 | "levels": 6, 26 | "in_planes": 3, 27 | "hidden_planes": [48, 96, 96, 192, 192, 256], 28 | "out_planes": 32, 29 | "activation": "elu" 30 | }, 31 | "in_dim": 512, 32 | "dim": 256, 33 | "prior": { 34 | "type": "flow", 35 | "num_steps": 2, 36 | "in_features": 256, 37 | "hidden_features": 512, 38 | "activation": "elu", 39 | "transform": "affine", 40 | "alpha": 1.0, 41 | "coupling_type": "mlp" 42 | } 43 | }, 44 | "dequantizer": { 45 | "type": "uniform" 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/cifar10/macow/macow-base-var.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 4, 6 | "num_steps": [2, [6, 6], [4, 4], 4], 7 | "factors": [4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 512, 512, 512], 10 | "kernel_size": [2, 3], 11 | "activation": "elu", 12 | "inverse": true, 13 | "transform": "affine", 14 | "prior_transform": "affine", 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4] 17 | } 18 | }, 19 | 
class SequentialFlow(nn.Module):
    """A generalized nn.Sequential container for normalizing flows.

    Layers are applied in order by ``forward`` and in reverse order (through
    each layer's ``inverse`` method) by ``inverse``. When a log-density term
    is supplied it is threaded through every layer, and each layer is then
    expected to return an ``(output, log_density)`` pair.
    """

    def __init__(self, layersList):
        super(SequentialFlow, self).__init__()
        self.chain = nn.ModuleList(layersList)

    def forward(self, x, logpx=None):
        """Apply every layer in order; return ``x`` or ``(x, logpx)``."""
        if logpx is None:
            for layer in self.chain:
                x = layer(x)
            return x
        for layer in self.chain:
            x, logpx = layer(x, logpx)
        return x, logpx

    def inverse(self, y, logpy=None):
        """Apply every layer's ``inverse`` last-to-first; return ``y`` or ``(y, logpy)``."""
        if logpy is None:
            for layer in reversed(self.chain):
                y = layer.inverse(y)
            return y
        for layer in reversed(self.chain):
            y, logpy = layer.inverse(y, logpy)
        return y, logpy


class Inverse(nn.Module):
    """Wrap a flow so that ``forward`` and ``inverse`` are swapped."""

    def __init__(self, flow):
        super(Inverse, self).__init__()
        self.flow = flow

    def forward(self, x, logpx=None):
        return self.flow.inverse(x, logpx)

    def inverse(self, y, logpy=None):
        return self.flow.forward(y, logpy)
"glow", 34 | "levels": 3, 35 | "num_steps": [2, [4], 2], 36 | "factors": [2], 37 | "in_channels": 3, 38 | "hidden_channels": [24, 256, 256], 39 | "h_channels": 4, 40 | "h_type": "local_linear", 41 | "activation": "elu", 42 | "inverse": false, 43 | "transform": "affine", 44 | "prior_transform": "affine", 45 | "alpha": 1.0, 46 | "coupling_type": "conv" 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /flow_models/wolf/wolf_configs/celebA-HQ/macow/macow-base-var.json: -------------------------------------------------------------------------------- 1 | { 2 | "generator": { 3 | "flow": { 4 | "type": "macow", 5 | "levels": 7, 6 | "num_steps": [2, [10, 10], [8, 8], [8, 8], [4, 4], [1, 1], 1], 7 | "factors": [4, 4, 4, 4, 4], 8 | "in_channels": 3, 9 | "hidden_channels": [24, 256, 256, 256, 256, 256, 512], 10 | "kernel_size": [2, 3], 11 | "activation": "elu", 12 | "inverse": true, 13 | "transform": "affine", 14 | "prior_transform": "affine", 15 | "coupling_type": "conv", 16 | "num_groups": [2, 4, 4, 4, 4, 4, 4] 17 | } 18 | }, 19 | "discriminator" : { 20 | "type": "base" 21 | }, 22 | "dequantizer": { 23 | "type": "flow", 24 | "encoder": { 25 | "type": "local_resnet_bn", 26 | "levels": 3, 27 | "in_planes": 3, 28 | "hidden_planes": [48, 96, 96], 29 | "out_planes": 4, 30 | "activation": "elu" 31 | }, 32 | "flow": { 33 | "type": "macow", 34 | "levels": 3, 35 | "num_steps": [2, [4], 2], 36 | "factors": [2], 37 | "in_channels": 3, 38 | "hidden_channels": [24, 256, 256], 39 | "kernel_size": [2, 3], 40 | "h_channels": 4, 41 | "h_type": "local_linear", 42 | "activation": "elu", 43 | "inverse": false, 44 | "transform": "affine", 45 | "prior_transform": "affine", 46 | "coupling_type": "conv" 47 | } 48 | } 49 | } 50 | -------------------------------------------------------------------------------- /flow_models/wolf/flows/resflow/layers/container.py: -------------------------------------------------------------------------------- 1 
class SequentialFlow(nn.Module):
    """A generalized nn.Sequential container for normalizing flows.

    Same contract as a plain sequential flow, with an extra conditioning
    value ``h`` that is forwarded unchanged as a keyword argument to every
    layer. Layers run in order in ``forward`` and in reverse order (via their
    ``inverse`` method) in ``inverse``. When a log-density term is supplied,
    each layer is expected to return an ``(output, log_density)`` pair.
    """

    def __init__(self, layersList):
        super(SequentialFlow, self).__init__()
        self.chain = nn.ModuleList(layersList)

    def forward(self, x, logpx=None, h=None):
        """Apply every layer in order; return ``x`` or ``(x, logpx)``."""
        if logpx is None:
            for layer in self.chain:
                x = layer(x, h=h)
            return x
        for layer in self.chain:
            x, logpx = layer(x, logpx, h=h)
        return x, logpx

    def inverse(self, y, logpy=None, h=None):
        """Apply every layer's ``inverse`` last-to-first; return ``y`` or ``(y, logpy)``."""
        if logpy is None:
            for layer in reversed(self.chain):
                y = layer.inverse(y, h=h)
            return y
        for layer in reversed(self.chain):
            y, logpy = layer.inverse(y, logpy, h=h)
        return y, logpy


class Inverse(nn.Module):
    """Wrap a flow so that ``forward`` and ``inverse`` are swapped."""

    def __init__(self, flow):
        super(Inverse, self).__init__()
        self.flow = flow

    def forward(self, x, logpx=None, h=None):
        return self.flow.inverse(x, logpx, h=h)

    def inverse(self, y, logpy=None, h=None):
        return self.flow.forward(y, logpy, h=h)
def _get_checkerboard_mask(x, swap=False):
    """Return a 0/1 checkerboard mask shaped and typed like the 4-D tensor ``x``.

    With ``swap=False`` the top-left entry is 1; with ``swap=True`` it is 0.
    The same spatial pattern is shared by every sample and channel.
    """
    n, c, h, w = x.size()

    # Round odd sizes up to the next even number so the 2x2 tile repeats
    # cleanly; the surplus row/column is cropped below.
    H = ((h - 1) // 2 + 1) * 2
    W = ((w - 1) // 2 + 1) * 2

    tile = [[0, 1], [1, 0]] if swap else [[1, 0], [0, 1]]
    mask = torch.Tensor(tile).repeat(H // 2, W // 2)
    mask = mask[:h, :w]
    mask = mask.contiguous().view(1, 1, h, w).expand(n, c, h, w).type_as(x.data)

    return mask


def _get_channel_mask(x, swap=False):
    """Return a 0/1 mask selecting one half of the channels of the 4-D tensor ``x``.

    With ``swap=False`` the first ``c // 2`` channels are 1; with
    ``swap=True`` the remaining channels are 1. The channel count must be even.
    """
    n, c, h, w = x.size()
    assert (c % 2 == 0)

    # zeros_like keeps the mask on x's device and in x's dtype, matching the
    # checkerboard mask and the mask_type=None path of get_mask.
    mask = torch.zeros_like(x)
    if not swap:
        mask[:, :c // 2] = 1
    else:
        mask[:, c // 2:] = 1
    return mask


def get_mask(x, mask_type=None):
    """Build the mask named by ``mask_type`` for the tensor ``x``.

    Args:
        x: 4-D tensor the mask is built for (any shape when ``mask_type`` is None).
        mask_type: None (all-zero mask), 'channel0', 'channel1',
            'checkerboard0' or 'checkerboard1'.

    Raises:
        ValueError: if ``mask_type`` is not one of the names above.
    """
    if mask_type is None:
        return torch.zeros(x.size()).to(x)
    elif mask_type == 'channel0':
        return _get_channel_mask(x, swap=False)
    elif mask_type == 'channel1':
        return _get_channel_mask(x, swap=True)
    elif mask_type == 'checkerboard0':
        return _get_checkerboard_mask(x, swap=False)
    elif mask_type == 'checkerboard1':
        return _get_checkerboard_mask(x, swap=True)
    else:
        raise ValueError('Unknown mask type {}'.format(mask_type))
def get_score(model_name=None, dataset_name=None,
              dataset_res=None, dataset_split=None, task_name=None):
    """Download the clean-fid leaderboard CSV and return the matching rows.

    Every argument is an optional exact-match filter on the column of the
    same name; ``None`` disables that filter. Cell values and header names
    are compared and returned with surrounding whitespace stripped.

    Returns:
        A list with one dict per matching row, mapping each header name to
        the cell value (all values are strings).
    """
    # Local import: only this function needs it.
    import io

    url = "https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv"

    # Parse the response in memory instead of going through a fixed
    # /tmp/leaderboard.csv path, which collides between concurrent callers,
    # does not exist on every platform and was left behind on errors.
    with urllib.request.urlopen(url) as response:
        text = response.read().decode("utf-8")

    filters = {
        "model_name": model_name,
        "dataset_name": dataset_name,
        "dataset_res": dataset_res,
        "dataset_split": dataset_split,
        "task_name": task_name,
    }
    active_filters = {key: val for key, val in filters.items() if val is not None}

    csvreader = csv.reader(io.StringIO(text, newline=""))
    l_fields = [name.strip() for name in next(csvreader)]
    d_field2idx = {name: idx for idx, name in enumerate(l_fields)}

    l_matches = []
    for row in csvreader:
        # skip empty rows
        if not row:
            continue
        # skip rows rejected by any active filter
        if any(row[d_field2idx[key]].strip() != val
               for key, val in active_filters.items()):
            continue
        l_matches.append({name: row[d_field2idx[name]].strip() for name in l_fields})
    return l_matches
class InvertibleLinear(nn.Module):
    """Invertible bias-free linear layer with a square ``dim x dim`` weight.

    The weight is initialized to a random permutation matrix. When a
    log-density term is passed in, ``forward`` subtracts ``log|det W|`` from
    it and ``inverse`` adds it back.
    """

    def __init__(self, dim):
        super(InvertibleLinear, self).__init__()
        self.dim = dim
        # random row permutation of the identity matrix
        self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)])

    def forward(self, x, logpx=None):
        y = F.linear(x, self.weight)
        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad

    def inverse(self, y, logpy=None):
        x = F.linear(y, self.weight.inverse())
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad

    @property
    def _logdetgrad(self):
        # slogdet returns (sign, log|det|). Computing the log-magnitude
        # directly avoids the under/overflow of det() followed by log().
        return torch.slogdet(self.weight)[1]

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
class InvertibleConv2d(nn.Module):
    """Invertible 1x1 convolution with a square ``dim x dim`` channel-mixing weight.

    The weight is initialized to a random permutation matrix. When a
    log-density term is passed in, ``forward`` subtracts ``H * W * log|det W|``
    from it (one factor per spatial position) and ``inverse`` adds it back.
    """

    def __init__(self, dim):
        super(InvertibleConv2d, self).__init__()
        self.dim = dim
        # random row permutation of the identity matrix
        self.weight = nn.Parameter(torch.eye(dim)[torch.randperm(dim)])

    def forward(self, x, logpx=None):
        y = F.conv2d(x, self.weight.view(self.dim, self.dim, 1, 1))
        if logpx is None:
            return y
        else:
            return y, logpx - self._logdetgrad.expand_as(logpx) * x.shape[2] * x.shape[3]

    def inverse(self, y, logpy=None):
        x = F.conv2d(y, self.weight.inverse().view(self.dim, self.dim, 1, 1))
        if logpy is None:
            return x
        else:
            return x, logpy + self._logdetgrad.expand_as(logpy) * x.shape[2] * x.shape[3]

    @property
    def _logdetgrad(self):
        # slogdet returns (sign, log|det|). Computing the log-magnitude
        # directly avoids the under/overflow of det() followed by log().
        return torch.slogdet(self.weight)[1]

    def extra_repr(self):
        return 'dim={}'.format(self.dim)
class ResizeDataset(torch.utils.data.Dataset):
    """
    A placeholder Dataset that enables parallelizing the resize operation
    using multiple CPU cores

    files: list of all files in the folder
    size: stored on the instance; not read by the methods of this class
    fn_resize: function that takes an np_array as input [0,255]
    """

    def __init__(self, files, size=(299, 299), fn_resize=None):
        self.files = files
        self.transforms = torchvision.transforms.ToTensor()
        self.size = size
        self.fn_resize = fn_resize

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        """Load file ``i``, resize every element along its first axis, and
        return the stacked result passed through ``self.transforms``.

        Raises:
            ValueError: if the loaded array has no elements along axis 0.
            TypeError: if the resized data is neither uint8 nor float32.
        """
        path = str(self.files[i])
        if ".npz" in path:
            # NpzFile keeps the archive open until closed; use it as a
            # context manager so worker processes do not leak file handles.
            with np.load(path) as archive:
                img_np = archive['samples']
        else:
            # NOTE(review): for a single decoded image the loop below iterates
            # over the first axis of an H x W x 3 array; confirm whether
            # non-npz inputs are still meant to be supported.
            img_pil = Image.open(path).convert('RGB')
            img_np = np.array(img_pil)

        # fn_resize expects a np array and returns a np array
        resized = [self.fn_resize(img) for img in img_np]
        if not resized:
            raise ValueError("no samples found in {}".format(path))
        # Every result is reshaped to the shape of the first one, then all
        # are stacked along a new leading axis in a single allocation.
        item_shape = resized[0].shape
        img_resized = np.stack([item.reshape(item_shape) for item in resized])

        # ToTensor() converts to [0,1] only if input in uint8
        if img_resized.dtype == "uint8":
            img_t = self.transforms(np.array(img_resized))*255
        elif img_resized.dtype == "float32":
            img_t = self.transforms(img_resized)
        else:
            raise TypeError(
                "unsupported dtype {} after resize; expected uint8 or float32".format(img_resized.dtype))

        return img_t
class NormalPrior(Prior):
    """
    Standard normal prior: zero mean and unit variance in every dimension.
    """

    def __init__(self):
        super(NormalPrior, self).__init__()

    @overrides
    def log_probability(self, z):
        """Log-density of ``z`` under the standard normal.

        The last of the three indexed axes (``dim=2``) is summed out.
        """
        # z: [batch, nsamples, dim]
        squared_norm = z.pow(2).sum(dim=2)
        log_norm_const = math.log(math.pi * 2.) * z.size(2)
        # [batch, nsamples]
        return (squared_norm + log_norm_const) * -0.5

    @overrides
    def sample(self, nsamples, dim, device=torch.device('cpu')):
        """Draw ``nsamples`` standard-normal vectors of length ``dim``."""
        return torch.randn(nsamples, dim, device=device)

    @overrides
    def calcKL(self, z, eps, mu, logvar):
        """Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over dim 1.

        ``z`` and ``eps`` are accepted for interface compatibility only.
        """
        per_dim = mu.pow(2) + logvar.exp() - logvar - 1
        return 0.5 * per_dim.sum(dim=1)

    @overrides
    def init(self, z, eps, mu, logvar, init_scale=1.0):
        """Initialisation pass; this prior simply returns the KL term."""
        return self.calcKL(z, eps, mu, logvar)

    @classmethod
    def from_params(cls, params: Dict) -> "NormalPrior":
        """Build the prior; ``params`` is not read."""
        return NormalPrior()


NormalPrior.register('normal')
class SwishFn(torch.autograd.Function):
    """Swish ``x * sigmoid(beta * x) / 1.1`` with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, beta):
        """Return ``x * sigmoid(beta * x) / 1.1`` and stash what backward needs."""
        beta_sigm = torch.sigmoid(beta * x)
        output = x * beta_sigm
        # Save the sigmoid itself: recovering it in backward as output / x
        # evaluates 0 / 0 = NaN wherever x == 0.
        ctx.save_for_backward(x, output, beta, beta_sigm)
        return output / 1.1

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``x`` and ``beta``."""
        x, output, beta, beta_sigm = ctx.saved_tensors
        # d/dx [x * s] = s + beta * x * s * (1 - s), with s = sigmoid(beta * x)
        grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output))
        # d/dbeta [x * s] = x * output - output ** 2, reduced over all elements
        grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta)
        return grad_x / 1.1, grad_beta / 1.1
class SwishFn(torch.autograd.Function):
    """Swish ``x * sigmoid(beta * x) / 1.1`` with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, beta):
        """Return ``x * sigmoid(beta * x) / 1.1`` and stash what backward needs."""
        beta_sigm = torch.sigmoid(beta * x)
        output = x * beta_sigm
        # Keep the sigmoid for backward; dividing output by x there would
        # produce NaN for every element where x == 0.
        ctx.save_for_backward(x, output, beta, beta_sigm)
        return output / 1.1

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients with respect to ``x`` and ``beta``."""
        x, output, beta, beta_sigm = ctx.saved_tensors
        # d/dx [x * s] = s + beta * x * s * (1 - s), with s = sigmoid(beta * x)
        grad_x = grad_output * (beta * output + beta_sigm * (1 - beta * output))
        # d/dbeta [x * s] = x * output - output ** 2, reduced over all elements
        grad_beta = torch.sum(grad_output * (x * output - output * output)).expand_as(beta)
        return grad_x / 1.1, grad_beta / 1.1
def check_download_url(local_folder, url):
    """Return the local copy of ``url`` inside ``local_folder``, fetching it if absent.

    The file name is the basename of ``url``. The payload is written to a
    ``.part`` file and renamed only after the copy finishes, so an
    interrupted download is never mistaken for a complete file on the next
    call.

    Args:
        local_folder: directory that holds (or will hold) the file.
        url: location to fetch from when the file is missing.

    Returns:
        Path of the local file.
    """
    name = os.path.basename(url)
    local_path = os.path.join(local_folder, name)
    if not os.path.exists(local_path):
        if local_folder:
            # makedirs('') raises, and the current directory already exists
            os.makedirs(local_folder, exist_ok=True)
        print(f"downloading statistics to {local_path}")
        partial_path = local_path + ".part"
        try:
            with urllib.request.urlopen(url) as response, open(partial_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            os.replace(partial_path, local_path)
        finally:
            # Only present here if the copy or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
    return local_path
class PositionalEncoding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.
    Padding symbols are ignored.
    """

    def __init__(self, encoding_dim, padding_idx, init_size=1024):
        """Pre-compute ``init_size`` encodings of width ``encoding_dim``.

        ``padding_idx`` may be None, in which case ``forward`` returns the
        same encodings for every batch element.
        """
        super().__init__()
        self.encoding_dim = encoding_dim
        self.padding_idx = padding_idx
        self.weights = PositionalEncoding.get_embedding(
            init_size,
            encoding_dim,
            padding_idx,
        )
        # Used only as a dtype/device reference for `weights` in forward().
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_encodings, encoding_dim, padding_idx=None):
        """Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".

        Returns a ``[num_encodings, encoding_dim]`` tensor whose row 0 is all
        zeros. ``padding_idx`` is accepted but not read here.
        """
        half_dim = encoding_dim // 2
        # Guard the denominator: half_dim == 1 (encoding_dim of 2 or 3)
        # would otherwise divide by zero.
        emb = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_encodings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_encodings, -1)
        if encoding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_encodings, 1)], dim=1)
        emb[0, :] = 0
        return emb

    def forward(self, x):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = x.size()[:2]
        max_pos = seq_len + 1
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed; the width is stored as
            # `encoding_dim` (there is no `embedding_dim` attribute).
            self.weights = PositionalEncoding.get_embedding(
                max_pos,
                self.encoding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.type_as(self._float_tensor)

        if self.padding_idx is None:
            # Positions start at 1; row 0 is the zeroed entry.
            return self.weights[1:seq_len + 1].detach()
        else:
            positions = make_positions(x, self.padding_idx)
            return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
class Discriminator(nn.Module):
    """
    Discriminator base class.

    Every sampling method here is a placeholder that returns ``None``
    values. The class also keeps a name -> class registry shared by all
    subclasses.
    """
    _registry = {}

    def __init__(self):
        super().__init__()

    def sample_from_prior(self, nsamples=1, device=torch.device('cpu')):
        """
        Placeholder for prior sampling.

        Args:
            nsamples: int
                Number of samples
            device: torch.device
                device to store the samples

        Returns: None in this base class.

        """
        return None

    def sample_from_posterior(self, x, y=None, nsamples=1, random=True):
        """
        Placeholder for posterior sampling.

        Args:
            x: Tensor
                The input data
            y: Tensor or None
                The label id of the data (for conditional generation).
            nsamples: int
                Number of samples for each instance.
            random: bool
                if True, perform random sampling.

        Returns: (None, None) in this base class.

        """
        return None, None

    def sampling_and_KL(self, x, y=None, nsamples=1):
        """
        Placeholder for joint sampling and KL computation.

        Args:
            x: Tensor
                The input data
            y: Tensor or None
                The label id of the data (for conditional generation).
            nsamples: int
                Number of samples for each instance.

        Returns: (None, None) in this base class.

        """
        return None, None

    def init(self, x, y=None, init_scale=1.0):
        """Run ``sampling_and_KL`` without gradient tracking; ``init_scale`` is not read."""
        with torch.no_grad():
            result = self.sampling_and_KL(x, y=y)
        return result

    def to_device(self, device):
        """No-op hook."""
        pass

    def sync(self):
        """No-op hook."""
        pass

    @classmethod
    def register(cls, name: str):
        """Record ``cls`` under ``name`` in the shared registry."""
        Discriminator._registry[name] = cls

    @classmethod
    def by_name(cls, name: str):
        """Look up a registered class; raises KeyError for unknown names."""
        return Discriminator._registry[name]

    @classmethod
    def from_params(cls, params: Dict) -> "Discriminator":
        """Build the base discriminator; ``params`` is not read."""
        return Discriminator()


Discriminator.register('base')
class LogitTransform(nn.Module):
    """
    Logit pre-processing with an alpha rescaling:
        forward:  y = logit(a + (1 - 2a) * x)
        inverse:  x = (sigmoid(y) - a) / (1 - 2a)
    When a log-density is supplied, forward subtracts the summed
    log-derivative and inverse adds it back.
    """

    def __init__(self, alpha=_DEFAULT_ALPHA):
        nn.Module.__init__(self)
        self.alpha = alpha

    def forward(self, x, logpx=None):
        """Map ``x`` to logit space; also update ``logpx`` when it is given."""
        scaled = self.alpha + (1 - 2 * self.alpha) * x
        y = torch.log(scaled) - torch.log(1 - scaled)
        if logpx is not None:
            # one log-det value per batch row
            logdet = self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True)
            return y, logpx - logdet
        return y

    def inverse(self, y, logpy=None):
        """Map ``y`` back from logit space; also update ``logpy`` when it is given."""
        x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha)
        if logpy is not None:
            logdet = self._logdetgrad(x).view(x.size(0), -1).sum(1, keepdim=True)
            return x, logpy + logdet
        return x

    def _logdetgrad(self, x):
        """Element-wise log-derivative of the forward map evaluated at ``x``."""
        scaled = self.alpha + (1 - 2 * self.alpha) * x
        return -torch.log(scaled - scaled * scaled) + math.log(1 - 2 * self.alpha)

    def __repr__(self):
        return ('{name}({alpha})'.format(name=self.__class__.__name__, alpha=self.alpha))
class Normalize(nn.Module):
    """
    Per-channel normalisation ``(x - mean) / std`` applied to the first
    ``len(mean)`` channels of a 4-D input; remaining channels are copied
    unchanged. ``mean`` and ``std`` are stored as float32 buffers.
    """

    def __init__(self, mean, std):
        nn.Module.__init__(self)
        for buffer_name, value in (('mean', mean), ('std', std)):
            self.register_buffer(buffer_name, torch.as_tensor(value, dtype=torch.float32))

    def forward(self, x, logpx=None):
        """Normalise a copy of ``x``; also update ``logpx`` when it is given."""
        n_channels = len(self.mean)
        out = x.clone()
        # In-place on a view of the clone, so `x` itself is left untouched.
        head = out[:, :n_channels]
        head.sub_(self.mean[None, :, None, None])
        head.div_(self.std[None, :, None, None])
        if logpx is not None:
            return out, logpx - self._logdetgrad(x)
        return out

    def inverse(self, y, logpy=None):
        """Undo the normalisation on a copy of ``y``; also update ``logpy`` when given."""
        n_channels = len(self.mean)
        out = y.clone()
        head = out[:, :n_channels]
        head.mul_(self.std[None, :, None, None])
        head.add_(self.mean[None, :, None, None])
        if logpy is not None:
            return out, logpy + self._logdetgrad(out)
        return out

    def _logdetgrad(self, x):
        """Sum of ``-log|std|`` over channels and both spatial axes, one value per batch row."""
        neg_log_std = self.std.abs().log().neg()
        per_element = neg_log_std.view(1, -1, 1, 1).expand(
            x.shape[0], len(self.std), x.shape[2], x.shape[3]
        )
        return per_element.reshape(x.shape[0], -1).sum(-1, keepdim=True)
class ShiftedConv2d(nn.Conv2d):
    """
    Conv2d with shift operation.
    A -> top
    B -> bottom
    C -> left
    D -> right

    The input is padded on one side by the full kernel extent (and by half
    the kernel on the two orthogonal sides), then one row or column is cut
    off before an unpadded convolution is applied.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=(1, 1), dilation=1, groups=1, bias=True, order='A'):
        assert len(stride) == 2
        assert len(kernel_size) == 2
        assert order in {'A', 'B', 'C', 'D'}, 'unknown order: {}'.format(order)
        if order in {'A', 'B'}:
            assert kernel_size[1] % 2 == 1, 'kernel width cannot be even number: {}'.format(kernel_size)
        else:
            assert kernel_size[0] % 2 == 1, 'kernel height cannot be even number: {}'.format(kernel_size)

        full_h, full_w = kernel_size[0], kernel_size[1]
        half_h = (full_h - 1) // 2
        half_w = (full_w - 1) // 2
        # order -> (padding as left/right/top/bottom, cut as top/bottom/left/right)
        layouts = {
            'A': ((half_w, half_w, full_h, 0), (0, -1, 0, 0)),
            'B': ((half_w, half_w, 0, full_h), (1, 0, 0, 0)),
            'C': ((full_w, 0, half_h, half_h), (0, 0, 0, -1)),
            'D': ((0, full_w, half_h, half_h), (0, 0, 1, 0)),
        }

        self.order = order
        if order not in layouts:
            # Only reachable when assertions are disabled.
            self.shift_padding = None
            raise ValueError('unknown order: {}'.format(order))
        self.shift_padding, self.cut = layouts[order]

        super(ShiftedConv2d, self).__init__(in_channels, out_channels, kernel_size, padding=0,
                                            stride=stride, dilation=dilation, groups=groups, bias=bias)

    def forward(self, input, shifted=True):
        """Convolve ``input``; with ``shifted=True`` pad and crop it first."""
        if shifted:
            padded = F.pad(input, self.shift_padding)
            _, _, height, width = padded.size()
            top, bottom, left, right = self.cut
            # bottom/right offsets are 0 or -1, so this drops at most one line
            input = padded[:, :, top:height + bottom, left:width + right]
        return F.conv2d(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    @overrides
    def extra_repr(self):
        template = (super(ShiftedConv2d, self).extra_repr()
                    + ', order={order}'
                    + ', shift_padding={shift_padding}'
                    + ', cut={cut}')
        return template.format(**self.__dict__)
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a per-channel bias, apply leaky ReLU and multiply by ``scale``.

    CPU tensors take a pure-PyTorch path; every other device is dispatched
    to ``FusedLeakyReLUFunction``.

    Args:
        input: tensor whose dimension 1 is the channel axis.
        bias: 1-D tensor with one entry per channel.
        negative_slope: slope applied to negative values.
        scale: factor applied to the activated result.

    Returns:
        Tensor with the same shape as ``input``.
    """
    if input.device.type == "cpu":
        # Reshape the bias to (1, C, 1, ..., 1) so it broadcasts over input.
        rest_dim = [1] * (input.ndim - bias.ndim - 1)
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim),
                # honour the caller's slope instead of a hard-coded 0.2
                negative_slope=negative_slope,
            )
            * scale
        )

    else:
        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
class Sigmoid(nn.Module):
    """Element-wise sigmoid layer that can also track a log-density.

    ``forward`` adds the summed log-derivative to ``logpx`` and ``inverse``
    subtracts it, so a forward/inverse round trip restores both values.
    """

    def __init__(self, eps=1e-5):
        # `eps` is accepted for interface compatibility; it is not used.
        super(Sigmoid, self).__init__()

    def forward(self, x, logpx=None):
        """Return ``sigmoid(x)``, plus the updated ``logpx`` when it is given."""
        if logpx is None:
            return torch.sigmoid(x)
        else:
            return torch.sigmoid(x), logpx + self._logdetgrad(x)

    def inverse(self, y, logpy=None):
        """Return ``logit(y)``, plus the updated ``logpy`` when it is given."""
        # logit(y) = log(y) - log(1 - y); the operands were previously swapped.
        x = torch.log(y) - torch.log(1. - y)
        if logpy is None:
            return x
        else:
            # The log-derivative is a function of the pre-activation x, not of y.
            return x, logpy - self._logdetgrad(x)

    def _logdetgrad(self, x):
        """Per-row sum of ``log sigmoid'(x)`` with shape ``[batch, 1]``."""
        # log(e^-x / (1 + e^-x)^2) = logsigmoid(x) + logsigmoid(-x); this form
        # does not overflow for large |x|.
        logdet = nn.functional.logsigmoid(x) + nn.functional.logsigmoid(-x)
        return logdet.view(x.shape[0], -1).sum(1).reshape(x.shape[0], 1)

    def __repr__(self):
        # The class has no `num_features` attribute to print.
        return '{name}()'.format(name=self.__class__.__name__)


class Tanh(nn.Module):
    """Element-wise tanh layer that can also track a log-density.

    Uses the same sign convention as ``Sigmoid``: ``forward`` adds the
    log-derivative and ``inverse`` subtracts it.
    """

    def __init__(self, eps=1e-5):
        # `eps` is accepted for interface compatibility; it is not used.
        super(Tanh, self).__init__()

    def forward(self, x, logpx=None):
        """Return ``tanh(x)``, plus the updated ``logpx`` when it is given."""
        if logpx is None:
            return torch.tanh(x)
        else:
            return torch.tanh(x), logpx + self._logdetgrad(x)

    def inverse(self, y, logpy=None):
        """Return ``atanh(y)``, plus the updated ``logpy`` when it is given."""
        x = 0.5 * (torch.log(1. + y) - torch.log(1. - y))
        if logpy is None:
            return x
        else:
            # Evaluate the log-derivative at the recovered pre-activation.
            return x, logpy - self._logdetgrad(x)

    def _logdetgrad(self, x):
        """Per-row sum of ``log(1 - tanh(x)^2)`` with shape ``[batch, 1]``."""
        # log(4 e^-2x / (1 + e^-2x)^2) = 2 * (log 2 - x - softplus(-2x)); this
        # form does not overflow for large negative x.
        logdet = 2. * (math.log(2.) - x - nn.functional.softplus(-2. * x))
        return logdet.view(x.shape[0], -1).sum(1).reshape(x.shape[0], 1)

    def __repr__(self):
        # The class has no `num_features` attribute to print.
        return '{name}()'.format(name=self.__class__.__name__)


# noinspection PyUnusedLocal
class LogitTransform_(nn.Module):
    """
    Logit pre-processing with an alpha rescaling:
        forward_transform:  y = logit(a + (1 - 2a) * x)
        reverse:            x = (sigmoid(y) - a) / (1 - 2a)
    """

    def __init__(self, alpha):
        nn.Module.__init__(self)
        self.alpha = alpha

    def forward_transform(self, x, logpx=None):
        """Map ``x`` to logit space; add the summed log-derivative to ``logpx`` if given."""
        s = self.alpha + (1 - 2 * self.alpha) * x
        y = torch.log(s) - torch.log(1 - s)
        if logpx is None:
            return y
        return y, logpx + self._logdetgrad(x).reshape(x.size(0), -1).sum(1).reshape(x.size(0), 1)

    def reverse(self, y, logpy=None, **kwargs):
        """Map ``y`` back; subtract the summed log-derivative from ``logpy`` if given.

        Extra keyword arguments are accepted and ignored.
        """
        x = (torch.sigmoid(y) - self.alpha) / (1 - 2 * self.alpha)
        if logpy is None:
            return x
        return x, logpy - self._logdetgrad(x).reshape(x.size(0), -1).sum(1).reshape(x.size(0), 1)

    def _logdetgrad(self, x):
        """Element-wise log-derivative of ``forward_transform`` evaluated at ``x``."""
        s = self.alpha + (1 - 2 * self.alpha) * x
        logdetgrad = -torch.log(s - s * s) + math.log(1 - 2 * self.alpha)
        return logdetgrad

    def __repr__(self):
        return '{name}({alpha})'.format(name=self.__class__.__name__, **self.__dict__)
// Element-wise kernel: optionally adds a bias, applies the activation (or
// gradient variant) selected by `act * 10 + grad`, and multiplies by `scale`.
// Each block covers `loop_x * blockDim.x` consecutive elements; a thread
// starts at its own index and advances by `blockDim.x` up to `loop_x` times.
template <typename scalar_t>
static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
    int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
    int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

    scalar_t zero = 0.0;

    for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
        scalar_t x = p_x[xi];

        if (use_bias) {
            // step_b is the number of elements per channel slice, so this
            // picks the bias entry of the channel that xi belongs to.
            x += p_b[(xi / step_b) % size_b];
        }

        scalar_t ref = use_ref ? p_ref[xi] : zero;

        scalar_t y;

        // Tens digit: activation (1 = linear, 3 = leaky ReLU).
        // Units digit: 0 = forward, 1 = first gradient, 2 = second gradient.
        // Unknown codes fall through to the linear forward case.
        switch (act * 10 + grad) {
            default:
            case 10: y = x; break;
            case 11: y = x; break;
            case 12: y = 0.0; break;

            case 30: y = (x > 0.0) ? x : x * alpha; break;
            case 31: y = (ref > 0.0) ? x : x * alpha; break;
            case 32: y = 0.0; break;
        }

        out[xi] = y * scale;
    }
}


// Host wrapper: makes the inputs contiguous, derives the launch geometry and
// runs the kernel on the current CUDA stream. An empty `bias` or `refer`
// tensor disables the corresponding kernel input.
torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
    int act, int grad, float alpha, float scale) {
    int curDevice = -1;
    cudaGetDevice(&curDevice);
    cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);

    auto x = input.contiguous();
    auto b = bias.contiguous();
    auto ref = refer.contiguous();

    int use_bias = b.numel() ? 1 : 0;
    int use_ref = ref.numel() ? 1 : 0;

    int size_x = x.numel();
    int size_b = b.numel();
    int step_b = 1;

    // Product of all dimensions after dimension 1.
    for (int i = 1 + 1; i < x.dim(); i++) {
        step_b *= x.size(i);
    }

    int loop_x = 4;
    int block_size = 4 * 32;
    // Ceiling division: each block processes loop_x * block_size elements.
    int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

    auto y = torch::empty_like(x);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(),
            x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(),
            ref.data_ptr<scalar_t>(),
            act,
            grad,
            alpha,
            scale,
            loop_x,
            size_x,
            step_b,
            size_b,
            use_bias,
            use_ref
        );
    });

    return y;
}
def get_config():
    """Return the default CIFAR-10 config overridden for VE-SDE NCSN++ with a flow.

    Starts from ``get_default_configs()`` and overwrites fields of its
    ``training``, ``sampling``, ``model`` and ``flow`` sections, in that
    order, with the values listed below.
    """
    config = get_default_configs()

    overrides = (
        ('training', {
            'sde': 'vesde',
            'continuous': True,
            'likelihood_weighting': True,
            'importance_sampling': True,
        }),
        ('sampling', {
            'method': 'pc',
            'predictor': 'reverse_diffusion',
            'corrector': 'langevin',
        }),
        ('model', {
            'name': 'ncsnpp',
            'scale_by_sigma': True,
            'ema_rate': 0.999,
            'normalization': 'GroupNorm',
            'nonlinearity': 'swish',
            'nf': 128,
            'ch_mult': (1, 2, 2, 2),
            'num_res_blocks': 4,
            'attn_resolutions': (16,),
            'resamp_with_conv': True,
            'conditional': True,
            'fir': True,
            'fir_kernel': [1, 3, 3, 1],
            'skip_rescale': True,
            'resblock_type': 'biggan',
            'progressive': 'none',
            'progressive_input': 'residual',
            'progressive_combine': 'sum',
            'attention_type': 'ddpm',
            'init_scale': 0.,
            'fourier_scale': 16,
            'conv_size': 3,
        }),
        ('flow', {
            'model': 'wolf',
            'lr': 1e-3,
            'ema_rate': 0.999,
            'optim_reset': False,
            'nblocks': '16-16',
            'intermediate_dim': 512,
            'resblock_type': 'resflow',
            'model_config': 'flow_models/wolf/wolf_configs/cifar10/glow/resflow-gaussian-uni.json',
            'rank': 1,
            'local_rank': 0,
            'batch_size': 512,
            'eval_batch_size': 4,
            'batch_steps': 1,
            'init_batch_size': 1024,
            'epochs': 500,
            'valid_epochs': 1,
            'seed': 65537,
            'train_k': 1,
            'log_interval': 10,
            'warmup_steps': 500,
            'lr_decay': 0.999997,
            'beta1': 0.9,
            'beta2': 0.999,
            'eps': 1e-8,
            'weight_decay': 0,
            'amsgrad': True,
            'grad_clip': 0,
            'dataset': 'cifar10',
            'category': None,
            'image_size': 32,
            'workers': 4,
            'n_bits': 8,
            'recover': -1,
        }),
    )

    # Apply section by section, field by field, in declaration order.
    for section_name, fields in overrides:
        section = getattr(config, section_name)
        for field_name, value in fields.items():
            setattr(section, field_name, value)

    return config
15 | 16 | # Lint as: python3 17 | """Training NCSN++ on CIFAR-10 with VE SDE.""" 18 | from configs.default_celeba_configs import get_default_configs 19 | 20 | 21 | def get_config(): 22 | config = get_default_configs() 23 | # training 24 | training = config.training 25 | training.sde = 'vesde' 26 | training.continuous = True 27 | 28 | training.likelihood_weighting = True 29 | training.importance_sampling = True 30 | 31 | # sampling 32 | sampling = config.sampling 33 | sampling.method = 'pc' 34 | sampling.predictor = 'reverse_diffusion' 35 | sampling.corrector = 'langevin' 36 | 37 | # model 38 | model = config.model 39 | model.name = 'ncsnpp' 40 | model.scale_by_sigma = True 41 | model.ema_rate = 0.999 42 | model.normalization = 'GroupNorm' 43 | model.nonlinearity = 'swish' 44 | model.nf = 128 45 | model.ch_mult = (1, 2, 2, 2) 46 | model.num_res_blocks = 4 47 | model.attn_resolutions = (16,) 48 | model.resamp_with_conv = True 49 | model.conditional = True 50 | model.fir = True 51 | model.fir_kernel = [1, 3, 3, 1] 52 | model.skip_rescale = True 53 | model.resblock_type = 'biggan' 54 | model.progressive = 'none' 55 | model.progressive_input = 'residual' 56 | model.progressive_combine = 'sum' 57 | model.attention_type = 'ddpm' 58 | model.init_scale = 0. 
class MovingBatchNormNd(nn.Module):
    """Mean-only normalisation layer backed by a moving average.

    The layer subtracts a per-channel mean (channel = dim 1 of the input)
    and, when ``affine`` is true, adds a learnable per-channel bias. No
    scaling is applied, so ``logpx`` / ``logpy`` are passed through
    unchanged. Subclasses supply ``shape``, the broadcast shape used to view
    the per-channel vectors.
    """

    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        # Counts training-mode forward calls; used for the bn_lag de-biasing.
        self.register_buffer('step', torch.zeros(1))
        if not self.affine:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        """Zero the running mean and, if present, the bias."""
        self.running_mean.zero_()
        if self.affine:
            self.bias.data.zero_()

    def forward(self, x, logpx=None):
        """Subtract the tracked mean (and add the bias) from ``x``.

        In training mode the running mean and step counter are updated as a
        side effect. The mean that is subtracted is the running mean as it
        was *before* this update, unless ``bn_lag > 0``, in which case it is
        a de-biased blend of the batch mean and that running mean.

        Returns ``y`` when ``logpx`` is None, else ``(y, logpx)``.
        """
        channels = x.size(1)
        mean = self.running_mean.clone().detach()

        if self.training:
            # Per-channel mean over the batch and all remaining dimensions.
            per_channel = x.transpose(0, 1).contiguous().view(channels, -1)
            batch_mean = per_channel.mean(dim=1)

            if self.bn_lag > 0:
                blended = batch_mean - (1 - self.bn_lag) * (batch_mean - mean.detach())
                mean = blended / (1. - self.bn_lag ** (self.step[0] + 1))

            # Exponential moving average update of the stored statistics.
            self.running_mean.sub_(self.decay * (self.running_mean - batch_mean.data))
            self.step.add_(1)

        y = x - mean.view(*self.shape).expand_as(x)
        if self.affine:
            y = y + self.bias.view(*self.shape).expand_as(x)

        if logpx is not None:
            return y, logpx
        return y

    def inverse(self, y, logpy=None):
        """Undo ``forward`` using the stored running mean.

        Returns ``x`` when ``logpy`` is None, else ``(x, logpy)``.
        """
        if self.affine:
            y = y - self.bias.view(*self.shape).expand_as(y)

        x = y + self.running_mean.view(*self.shape).expand_as(y)

        if logpy is not None:
            return x, logpy
        return x

    def __repr__(self):
        template = '{}({}, eps={}, decay={}, bn_lag={}, affine={})'
        return template.format(
            self.__class__.__name__,
            self.num_features,
            self.eps,
            self.decay,
            self.bn_lag,
            self.affine,
        )


class MovingBatchNorm1d(MovingBatchNormNd):
    """Variant whose per-channel vectors are viewed as ``[1, C]``."""

    @property
    def shape(self):
        return [1, -1]


class MovingBatchNorm2d(MovingBatchNormNd):
    """Variant whose per-channel vectors are viewed as ``[1, C, 1, 1]``."""

    @property
    def shape(self):
        return [1, -1, 1, 1]
def get_config():
    """Return the default CIFAR-10 config overridden for VP-SDE NCSN++ with a flow.

    Starts from ``get_default_configs()`` and overwrites fields of its
    ``training``, ``sampling``, ``data``, ``model`` and ``flow`` sections, in
    that order, with the values listed below.
    """
    config = get_default_configs()

    overrides = (
        ('training', {
            'sde': 'vpsde',
            'continuous': True,
            'reduce_mean': True,
        }),
        ('sampling', {
            'method': 'ode',
            'predictor': 'euler_maruyama',
            'corrector': 'none',
        }),
        ('data', {
            'centered': True,
        }),
        ('model', {
            'name': 'ncsnpp',
            'scale_by_sigma': False,
            'ema_rate': 0.9999,
            'normalization': 'GroupNorm',
            'nonlinearity': 'swish',
            'nf': 128,
            'ch_mult': (1, 2, 2, 2),
            'num_res_blocks': 4,
            'attn_resolutions': (16,),
            'resamp_with_conv': True,
            'conditional': True,
            'fir': False,
            'fir_kernel': [1, 3, 3, 1],
            'skip_rescale': True,
            'resblock_type': 'biggan',
            'progressive': 'none',
            'progressive_input': 'none',
            'progressive_combine': 'sum',
            'attention_type': 'ddpm',
            'init_scale': 0.,
            'embedding_type': 'positional',
            'fourier_scale': 16,
            'conv_size': 3,
        }),
        ('flow', {
            'model': 'wolf',
            'lr': 1e-3,
            'ema_rate': 0.999,
            'optim_reset': False,
            'nblocks': '16-16',
            'intermediate_dim': 512,
            'resblock_type': 'resflow',
            'model_config': 'flow_models/wolf/wolf_configs/cifar10/glow/resflow-gaussian-uni.json',
            'rank': 1,
            'local_rank': 0,
            'batch_size': 512,
            'eval_batch_size': 4,
            'batch_steps': 1,
            'init_batch_size': 1024,
            'epochs': 500,
            'valid_epochs': 1,
            'seed': 65537,
            'train_k': 1,
            'log_interval': 10,
            'warmup_steps': 500,
            'lr_decay': 0.999997,
            'beta1': 0.9,
            'beta2': 0.999,
            'eps': 1e-8,
            'weight_decay': 0,
            'amsgrad': True,
            'grad_clip': 0,
            'dataset': 'cifar10',
            'category': None,
            'image_size': 32,
            'workers': 4,
            'n_bits': 8,
            'recover': -1,
        }),
    )

    # Apply section by section, field by field, in declaration order.
    for section_name, fields in overrides:
        section = getattr(config, section_name)
        for field_name, value in fields.items():
            setattr(section, field_name, value)

    return config
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_celeba_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | 30 | # sampling 31 | sampling = config.sampling 32 | sampling.method = 'ode' 33 | sampling.predictor = 'euler_maruyama' 34 | sampling.corrector = 'none' 35 | 36 | # data 37 | data = config.data 38 | data.centered = True 39 | 40 | # model 41 | model = config.model 42 | model.name = 'ncsnpp' 43 | model.scale_by_sigma = False 44 | model.ema_rate = 0.9999 45 | model.normalization = 'GroupNorm' 46 | model.nonlinearity = 'swish' 47 | model.nf = 128 48 | model.ch_mult = (1, 2, 2, 2) 49 | model.num_res_blocks = 4 50 | model.attn_resolutions = (16,) 51 | model.resamp_with_conv = True 52 | model.conditional = True 53 | model.fir = False 54 | model.fir_kernel = [1, 3, 3, 1] 55 | model.skip_rescale = True 56 | model.resblock_type = 'biggan' 57 | model.progressive = 'none' 58 | model.progressive_input = 'none' 59 | model.progressive_combine = 'sum' 60 | model.attention_type = 'ddpm' 61 | model.init_scale = 0. 
class MovingBatchNormNd(nn.Module):
    """Mean-only normalisation layer backed by a moving average.

    Subtracts a per-channel mean (channel = dim 1 of the input) and, when
    ``affine`` is true, adds a learnable per-channel bias. No scaling is
    applied, so the ``logpx`` / ``logpy`` accumulators pass through
    unchanged. Subclasses define ``shape``, the broadcast shape used to view
    the per-channel vectors.

    Args:
        num_features: number of channels.
        eps: stored on the module; not used by the computations here.
        decay: step size of the running-mean update.
        bn_lag: when > 0, blend the batch mean with the running mean (with
            de-biasing by the step count) during training.
        affine: whether to learn a per-channel bias.
    """

    def __init__(self, num_features, eps=1e-4, decay=0.1, bn_lag=0., affine=True):
        super(MovingBatchNormNd, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.decay = decay
        self.bn_lag = bn_lag
        # Number of training-mode forward calls seen so far.
        self.register_buffer('step', torch.zeros(1))
        if self.affine:
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.reset_parameters()

    @property
    def shape(self):
        raise NotImplementedError

    def reset_parameters(self):
        """Zero the running mean and, if present, the bias."""
        self.running_mean.zero_()
        if self.affine:
            self.bias.data.zero_()

    def forward(self, x, logpx=None, h=None):
        """Subtract the tracked mean (and add the bias) from ``x``.

        In training mode the running mean and the step counter are updated
        as a side effect. ``h`` is accepted for interface compatibility and
        ignored.

        Returns ``y`` when ``logpx`` is None, else ``(y, logpx)``.
        """
        c = x.size(1)
        # Snapshot taken before the update below: with bn_lag == 0 the mean
        # subtracted in training is the previous running mean.
        used_mean = self.running_mean.clone().detach()

        if self.training:
            # compute batch statistics
            x_t = x.transpose(0, 1).contiguous().view(c, -1)
            batch_mean = torch.mean(x_t, dim=1)

            # moving average
            if self.bn_lag > 0:
                used_mean = batch_mean - (1 - self.bn_lag) * (batch_mean - used_mean.detach())
                used_mean /= (1. - self.bn_lag**(self.step[0] + 1))

            # update running estimates
            self.running_mean -= self.decay * (self.running_mean - batch_mean.data)
            self.step += 1

        # perform normalization
        used_mean = used_mean.view(*self.shape).expand_as(x)

        y = x - used_mean

        if self.affine:
            bias = self.bias.view(*self.shape).expand_as(x)
            y = y + bias

        if logpx is None:
            return y
        else:
            return y, logpx

    def inverse(self, y, logpy=None, h=None):
        """Undo ``forward`` using the stored running mean.

        ``h`` is accepted (and ignored) so the signature matches ``forward``
        and callers can pass the same keyword to both directions.

        Returns ``x`` when ``logpy`` is None, else ``(x, logpy)``.
        """
        used_mean = self.running_mean

        if self.affine:
            bias = self.bias.view(*self.shape).expand_as(y)
            y = y - bias

        used_mean = used_mean.view(*self.shape).expand_as(y)
        x = y + used_mean

        if logpy is None:
            return x
        else:
            return x, logpy

    def __repr__(self):
        return (
            '{name}({num_features}, eps={eps}, decay={decay}, bn_lag={bn_lag},'
            ' affine={affine})'.format(name=self.__class__.__name__, **self.__dict__)
        )


class MovingBatchNorm1d(MovingBatchNormNd):
    """Variant whose per-channel vectors are viewed as ``[1, C]``."""

    @property
    def shape(self):
        return [1, -1]


class MovingBatchNorm2d(MovingBatchNormNd):
    """Variant whose per-channel vectors are viewed as ``[1, C, 1, 1]``."""

    @property
    def shape(self):
        return [1, -1, 1, 1]
class ActNormNd(nn.Module):
    """Per-channel affine layer ``y = (x + bias) * exp(weight)``.

    ``weight`` and ``bias`` are learnable vectors of length ``num_features``
    drawn uniformly from ``[-1e-5, 1e-5]``. The ``initialized`` buffer is
    created as 1, so the data-dependent initialisation in ``forward`` only
    runs if that buffer is reset to 0 elsewhere. Subclasses define ``shape``,
    the broadcast shape used to view the per-channel vectors.
    """

    def __init__(self, num_features, eps=1e-12):
        super(ActNormNd, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('initialized', torch.tensor(1))
        # Weight is drawn before bias; the order fixes which random numbers
        # each parameter receives under a given seed.
        for param in (self.weight, self.bias):
            nn.init.uniform_(param, -1e-5, 1e-5)

    @property
    def shape(self):
        raise NotImplementedError

    def forward(self, x, logpx=None, h=None):
        """Apply the affine map; ``h`` is accepted and ignored.

        Returns ``y`` when ``logpx`` is None, else
        ``(y, logpx - logdet)`` with ``logdet`` of shape ``(batch, 1)``.
        """
        channels = x.size(1)

        if not self.initialized:
            with torch.no_grad():
                # Per-channel statistics over the batch and remaining dims.
                flat = x.transpose(0, 1).contiguous().view(channels, -1)
                mean = torch.mean(flat, dim=1)
                var = torch.var(flat, dim=1)

                # Floor the variance at 0.2 for numerical stability.
                var = torch.max(var, torch.tensor(0.2).to(var))

                self.bias.data.copy_(-mean)
                self.weight.data.copy_(-0.5 * torch.log(var))
                self.initialized.fill_(1)

        shift = self.bias.view(*self.shape).expand_as(x)
        log_scale = self.weight.view(*self.shape).expand_as(x)
        y = (x + shift) * torch.exp(log_scale)

        if logpx is not None:
            return y, logpx - self._logdetgrad(x)
        return y

    def inverse(self, y, logpy=None, h=None):
        """Invert ``forward``; ``h`` is accepted and ignored.

        Returns ``x`` when ``logpy`` is None, else ``(x, logpy + logdet)``.
        """
        assert self.initialized
        shift = self.bias.view(*self.shape).expand_as(y)
        log_scale = self.weight.view(*self.shape).expand_as(y)
        x = y * torch.exp(-log_scale) - shift

        if logpy is not None:
            return x, logpy + self._logdetgrad(x)
        return x

    def _logdetgrad(self, x):
        """Sum of ``weight`` over every non-batch element, shape ``(batch, 1)``."""
        expanded = self.weight.view(*self.shape).expand(*x.size())
        return expanded.contiguous().view(x.size(0), -1).sum(1, keepdim=True)

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.__dict__['num_features'])


class ActNorm1d(ActNormNd):
    """Variant whose per-channel vectors are viewed as ``[1, C]``."""

    @property
    def shape(self):
        return [1, -1]


class ActNorm2d(ActNormNd):
    """Variant whose per-channel vectors are viewed as ``[1, C, 1, 1]``."""

    @property
    def shape(self):
        return [1, -1, 1, 1]
class CIFAR10(object):
    """Thin wrapper around ``torchvision.datasets.CIFAR10`` (downloads if needed)."""

    def __init__(self, dataroot, train=True, transform=None):
        self.cifar10 = vdsets.CIFAR10(dataroot, train=train, download=True, transform=transform)

    def __len__(self):
        return len(self.cifar10)

    @property
    def ndim(self):
        return 3

    def __getitem__(self, index):
        return self.cifar10[index]


class CelebA5bit(object):
    """Tensor loaded from ``LOC`` and divided by 31, served as ``(x, 0)`` pairs."""

    LOC = 'data/celebahq64_5bit/celeba_full_64x64_5bit.pth'

    def __init__(self, train=True, transform=None):
        # NOTE(review): torch.load unpickles the file; only use trusted data.
        images = torch.load(self.LOC).float().div(31)
        # The non-training split is the first 5000 entries of the same tensor.
        self.dataset = images if train else images[:5000]
        self.transform = transform

    def __len__(self):
        return self.dataset.size(0)

    @property
    def ndim(self):
        return self.dataset.size(1)

    def __getitem__(self, index):
        x = self.dataset[index]
        if self.transform is not None:
            x = self.transform(x)
        return x, 0


class CelebAHQ(Dataset):
    """``Dataset`` bound to the CelebA-HQ 256 train/validation files."""

    TRAIN_LOC = 'data/celebahq/celeba256_train.pth'
    TEST_LOC = 'data/celebahq/celeba256_validation.pth'

    def __init__(self, train=True, transform=None):
        loc = self.TRAIN_LOC if train else self.TEST_LOC
        super(CelebAHQ, self).__init__(loc, transform)


class Imagenet32(Dataset):
    """``Dataset`` bound to the ImageNet 32x32 train/validation files."""

    TRAIN_LOC = 'data/imagenet32/train_32x32.pth'
    TEST_LOC = 'data/imagenet32/valid_32x32.pth'

    def __init__(self, train=True, transform=None):
        loc = self.TRAIN_LOC if train else self.TEST_LOC
        super(Imagenet32, self).__init__(loc, transform)


class Imagenet64(Dataset):
    """``Dataset`` bound to the ImageNet 64x64 files, created with ``in_mem=False``."""

    TRAIN_LOC = 'data/imagenet64/train_64x64.pth'
    TEST_LOC = 'data/imagenet64/valid_64x64.pth'

    def __init__(self, train=True, transform=None):
        loc = self.TRAIN_LOC if train else self.TEST_LOC
        super(Imagenet64, self).__init__(loc, transform, in_mem=False)
def get_config():
    """Return the default CelebA config overridden for VP-SDE NCSN++ with a flow.

    Starts from ``get_default_configs()`` and overwrites fields of its
    ``training``, ``sampling``, ``data``, ``model`` and ``flow`` sections, in
    that order, with the values listed below. Likelihood weighting and
    importance sampling are both switched off.
    """
    config = get_default_configs()

    overrides = (
        ('training', {
            'sde': 'vpsde',
            'continuous': True,
            'reduce_mean': True,
            'likelihood_weighting': False,
            'importance_sampling': False,
        }),
        ('sampling', {
            'method': 'ode',
            'predictor': 'euler_maruyama',
            'corrector': 'none',
        }),
        ('data', {
            'centered': True,
        }),
        ('model', {
            'name': 'ncsnpp',
            'scale_by_sigma': False,
            'ema_rate': 0.9999,
            'normalization': 'GroupNorm',
            'nonlinearity': 'swish',
            'nf': 128,
            'ch_mult': (1, 2, 2, 2),
            'num_res_blocks': 4,
            'attn_resolutions': (16,),
            'resamp_with_conv': True,
            'conditional': True,
            'fir': False,
            'fir_kernel': [1, 3, 3, 1],
            'skip_rescale': True,
            'resblock_type': 'biggan',
            'progressive': 'none',
            'progressive_input': 'none',
            'progressive_combine': 'sum',
            'attention_type': 'ddpm',
            'init_scale': 0.,
            'embedding_type': 'positional',
            'fourier_scale': 16,
            'conv_size': 3,
        }),
        ('flow', {
            'model': 'wolf',
            'lr': 1e-3,
            'ema_rate': 0.999,
            'optim_reset': False,
            'nblocks': '16-16',
            'intermediate_dim': 512,
            'resblock_type': 'resflow',
            'model_config': 'flow_models/wolf/wolf_configs/imagenet/64x64/glow/resflow-gaussian-uni.json',
            'rank': 1,
            'local_rank': 0,
            'batch_size': 512,
            'eval_batch_size': 4,
            'batch_steps': 1,
            'init_batch_size': 1024,
            'epochs': 500,
            'valid_epochs': 1,
            'seed': 65537,
            'train_k': 1,
            'log_interval': 10,
            'warmup_steps': 500,
            'lr_decay': 0.999997,
            'beta1': 0.9,
            'beta2': 0.999,
            'eps': 1e-8,
            'weight_decay': 0,
            'amsgrad': True,
            'grad_clip': 0,
            'dataset': 'celeba',
            'category': None,
            'image_size': 64,
            'workers': 4,
            'n_bits': 8,
            'recover': -1,
        }),
    )

    # Apply section by section, field by field, in declaration order.
    for section_name, fields in overrides:
        section = getattr(config, section_name)
        for field_name, value in fields.items():
            setattr(section, field_name, value)

    return config
15 | 16 | # Lint as: python3 17 | """Training NCSNv3 on CIFAR-10 with continuous sigmas.""" 18 | 19 | from configs.default_cifar10_configs import get_default_configs 20 | 21 | 22 | def get_config(): 23 | config = get_default_configs() 24 | # training 25 | training = config.training 26 | training.sde = 'vpsde' 27 | training.continuous = True 28 | training.reduce_mean = True 29 | training.likelihood_weighting = False 30 | training.importance_sampling = False 31 | 32 | # sampling 33 | sampling = config.sampling 34 | sampling.method = 'ode' 35 | sampling.predictor = 'euler_maruyama' 36 | sampling.corrector = 'none' 37 | 38 | # data 39 | data = config.data 40 | data.centered = True 41 | 42 | # model 43 | model = config.model 44 | model.name = 'ncsnpp' 45 | model.scale_by_sigma = False 46 | model.ema_rate = 0.9999 47 | model.normalization = 'GroupNorm' 48 | model.nonlinearity = 'swish' 49 | model.nf = 128 50 | model.ch_mult = (1, 2, 2, 2) 51 | model.num_res_blocks = 4 52 | model.attn_resolutions = (16,) 53 | model.resamp_with_conv = True 54 | model.conditional = True 55 | model.fir = False 56 | model.fir_kernel = [1, 3, 3, 1] 57 | model.skip_rescale = True 58 | model.resblock_type = 'biggan' 59 | model.progressive = 'none' 60 | model.progressive_input = 'none' 61 | model.progressive_combine = 'sum' 62 | model.attention_type = 'ddpm' 63 | model.init_scale = 0. 
class Dataset(object):
    """Tensor-file dataset yielding ``(x, 0)`` pairs scaled by 1/255.

    With ``in_mem=True`` the whole tensor is converted to float and divided
    by 255 once at construction; otherwise the stored tensor is left as
    loaded and each item is converted when it is fetched.
    """

    def __init__(self, loc, transform=None, in_mem=True):
        self.in_mem = in_mem
        # NOTE(review): torch.load unpickles the file; only use trusted data.
        loaded = torch.load(loc)
        self.dataset = loaded.float().div(255) if in_mem else loaded
        self.transform = transform

    def __len__(self):
        return self.dataset.size(0)

    @property
    def ndim(self):
        """Size of dimension 1 of the stored tensor."""
        return self.dataset.size(1)

    def __getitem__(self, index):
        x = self.dataset[index]
        if not self.in_mem:
            x = x.float().div(255)
        if self.transform is not None:
            x = self.transform(x)
        return x, 0
class CelebA5bit(object):
    """Tensor-backed dataset loaded from the file at ``LOC``.

    The stored tensor is converted to float and divided by 31 once, at
    construction time. Every item is returned with the constant label ``0``.
    """

    LOC = 'data/celebahq64_5bit/celeba_full_64x64_5bit.pth'

    def __init__(self, train=True, transform=None):
        """Load the tensor; keep only the first 5000 entries when ``train`` is false."""
        images = torch.load(self.LOC).float().div(31)
        self.dataset = images if train else images[:5000]
        self.transform = transform

    def __len__(self):
        """Number of entries along the first dimension."""
        return self.dataset.size(0)

    @property
    def ndim(self):
        """Size of the second dimension of the stored tensor."""
        return self.dataset.size(1)

    def __getitem__(self, index):
        """Return ``(item, 0)``, applying ``transform`` to the item when one was given."""
        sample = self.dataset[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, 0
class GlobalResNetEncoderBatchNorm(Encoder):
    """
    Global ResNet encoder built from batch-normalized ResNet stages.

    One ``ResNetBatchNorm`` stage is created per level (strides ``[1, 2]``),
    followed by a 1x1 convolution to ``out_planes`` channels and an ELU.
    """
    def __init__(self, levels, in_planes, out_planes, hidden_planes, activation):
        super(GlobalResNetEncoderBatchNorm, self).__init__()
        assert len(hidden_planes) == levels

        modules = OrderedDict()
        channels = in_planes
        for idx, width in enumerate(hidden_planes):
            modules['resnet{}'.format(idx)] = ResNetBatchNorm(channels, [width, width],
                                                              [1, 2], activation)
            channels = width
        modules['top'] = nn.Conv2d(channels, out_planes, 1, bias=True)
        modules['activate'] = nn.ELU(inplace=True)
        self.net = nn.Sequential(modules)

    def forward(self, x):
        """Run the network and flatten everything except the batch dimension."""
        # [batch, out_planes, h, w]
        features = self.net(x)
        # [batch, out_planes * h * w]
        batch = features.size(0)
        return features.contiguous().view(batch, -1)

    def init(self, x, init_scale=1.0):
        """Plain forward pass without gradient tracking; ``init_scale`` is not used."""
        with torch.no_grad():
            return self(x)

    @classmethod
    def from_params(cls, params: Dict) -> "GlobalResNetEncoderBatchNorm":
        """Construct the encoder from a dict of constructor keyword arguments."""
        return GlobalResNetEncoderBatchNorm(**params)
class ActNormNd(nn.Module):
    """Per-channel affine layer ``y = (x + bias) * exp(weight)``.

    Subclasses provide ``shape``, the view shape used to broadcast the
    per-channel parameters against the input. The ``initialized`` buffer is
    registered as 1, so the data-dependent initialization in ``forward`` only
    runs if that buffer is reset to 0 by the caller.
    """

    def __init__(self, num_features, eps=1e-12):
        super(ActNormNd, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        self.register_buffer('initialized', torch.tensor(1))
        # Both parameters start very close to zero (weight first, then bias).
        for param in (self.weight, self.bias):
            nn.init.uniform_(param, -1e-5, 1e-5)

    @property
    def shape(self):
        """View shape for the parameters; must be supplied by subclasses."""
        raise NotImplementedError

    def _broadcast(self, param, like):
        """View a per-channel parameter with ``self.shape`` and expand it to ``like``."""
        return param.view(*self.shape).expand_as(like)

    def forward(self, x, logpx=None):
        """Apply the affine map; also return the updated ``logpx`` when one is given."""
        channels = x.size(1)

        if not self.initialized:
            with torch.no_grad():
                # Per-channel statistics over every other dimension.
                per_channel = x.transpose(0, 1).reshape(channels, -1)
                mean = per_channel.mean(dim=1)
                # Variance is floored at 0.2 for numerical stability.
                var = torch.clamp(per_channel.var(dim=1), min=0.2)

                self.bias.data.copy_(-mean)
                self.weight.data.copy_(-0.5 * torch.log(var))
                self.initialized.fill_(1)
                print('init: ', self.weight.min().item(), self.weight.max().item(), self.weight.mean().item())

        shift = self._broadcast(self.bias, x)
        log_scale = self._broadcast(self.weight, x)
        y = (x + shift) * torch.exp(log_scale)

        if logpx is not None:
            return y, logpx - self._logdetgrad(x)
        return y

    def inverse(self, y, logpy=None):
        """Undo ``forward``; also return the updated ``logpy`` when one is given."""
        assert self.initialized
        shift = self._broadcast(self.bias, y)
        log_scale = self._broadcast(self.weight, y)

        x = y * torch.exp(-log_scale) - shift
        if logpy is not None:
            return x, logpy + self._logdetgrad(x)
        return x

    def _logdetgrad(self, x):
        """Sum of ``weight`` over all non-batch positions, shaped ``[batch, 1]``."""
        expanded = self.weight.view(*self.shape).expand(*x.size())
        return expanded.reshape(x.size(0), -1).sum(1, keepdim=True)

    def __repr__(self):
        return '{}({})'.format(type(self).__name__, self.num_features)


class ActNorm1d(ActNormNd):
    """ActNorm whose parameters broadcast as ``[1, C]``."""

    @property
    def shape(self):
        return [1, -1]


class ActNorm2d(ActNormNd):
    """ActNorm whose parameters broadcast as ``[1, C, 1, 1]``."""

    @property
    def shape(self):
        return [1, -1, 1, 1]
28 | logits = torch.tensor(logits).float() 29 | self.cat_dist = Categorical(logits=logits) 30 | else: 31 | probs = torch.full((num_events, ), 1.0 / num_events).float() 32 | self.cat_dist = Categorical(probs=probs) 33 | 34 | if activation == 'relu': 35 | Actv = nn.ReLU(inplace=True) 36 | elif activation == 'elu': 37 | Actv = nn.ELU(inplace=True) 38 | else: 39 | Actv = nn.LeakyReLU(inplace=True, negative_slope=1e-1) 40 | self.embed = nn.Embedding(num_events, dim) 41 | self.net = nn.Sequential( 42 | nn.Linear(dim, 4 * dim), 43 | Actv, 44 | nn.Linear(4 * dim, 4 * dim), 45 | Actv, 46 | nn.Linear(4 * dim, dim) 47 | ) 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self): 51 | nn.init.uniform_(self.embed.weight, -0.1, 0.1) 52 | 53 | @overrides 54 | def to_device(self, device): 55 | logits = self.cat_dist.logits.to(device) 56 | self.cat_dist = Categorical(logits=logits) 57 | 58 | @overrides 59 | def init(self, x, y=None, init_scale=1.0): 60 | with torch.no_grad(): 61 | z, KL = self.sampling_and_KL(x, y=y) 62 | return z.squeeze(1), KL 63 | 64 | @overrides 65 | def sample_from_prior(self, nsamples=1, device=torch.device('cpu')): 66 | # [nsamples] 67 | cids = self.cat_dist.sample((nsamples, )).to(device) 68 | cids = torch.sort(cids)[0] 69 | # [nsamples, dim] 70 | return self.net(self.embed(cids)) 71 | 72 | @overrides 73 | def sample_from_posterior(self, x, y=None, nsamples=1, random=True): 74 | assert y is not None 75 | log_probs = x.new_zeros(x.size(0), nsamples) 76 | # [batch, nsamples, dim] 77 | z = self.net(self.embed(y)).unsqueeze(1) + log_probs.unsqueeze(2) 78 | return z, log_probs 79 | 80 | @overrides 81 | def sampling_and_KL(self, x, y=None, nsamples=1): 82 | # [batch, nsamples, dim] 83 | z, _ = self.sample_from_posterior(x, y=y, nsamples=nsamples, random=True) 84 | # [batch,] 85 | log_probs_prior = self.cat_dist.log_prob(y) 86 | KL = -log_probs_prior 87 | return z, KL 88 | 89 | @classmethod 90 | def from_params(cls, params: Dict) -> 
"CategoricalDiscriminator": 91 | return CategoricalDiscriminator(**params) 92 | 93 | 94 | CategoricalDiscriminator.register('categorical') 95 | -------------------------------------------------------------------------------- /flow_models/resflow/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class CosineAnnealingWarmRestarts(_LRScheduler): 6 | r"""Set the learning rate of each parameter group using a cosine annealing 7 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` 8 | is the number of epochs since the last restart and :math:`T_{i}` is the number 9 | of epochs between two warm restarts in SGDR: 10 | .. math:: 11 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 12 | \cos(\frac{T_{cur}}{T_{i}}\pi)) 13 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 14 | When :math:`T_{cur}=0`(after restart), set :math:`\eta_t=\eta_{max}`. 15 | It has been proposed in 16 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 17 | Args: 18 | optimizer (Optimizer): Wrapped optimizer. 19 | T_0 (int): Number of iterations for the first restart. 20 | T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. 21 | eta_min (float, optional): Minimum learning rate. Default: 0. 22 | last_epoch (int, optional): The index of last epoch. Default: -1. 23 | .. 
_SGDR\: Stochastic Gradient Descent with Warm Restarts: 24 | https://arxiv.org/abs/1608.03983 25 | """ 26 | 27 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): 28 | if T_0 <= 0 or not isinstance(T_0, int): 29 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 30 | if T_mult < 1 or not isinstance(T_mult, int): 31 | raise ValueError("Expected integer T_mul >= 1, but got {}".format(T_mult)) 32 | self.T_0 = T_0 33 | self.T_i = T_0 34 | self.T_mult = T_mult 35 | self.eta_min = eta_min 36 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) 37 | self.T_cur = last_epoch 38 | 39 | def get_lr(self): 40 | return [ 41 | self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 42 | for base_lr in self.base_lrs 43 | ] 44 | 45 | def step(self, epoch=None): 46 | """Step could be called after every update, i.e. if one epoch has 10 iterations 47 | (number_of_train_examples / batch_size), we should call SGDR.step(0.1), SGDR.step(0.2), etc. 48 | This function can be called in an interleaved way. 
49 | Example: 50 | >>> scheduler = SGDR(optimizer, T_0, T_mult) 51 | >>> for epoch in range(20): 52 | >>> scheduler.step() 53 | >>> scheduler.step(26) 54 | >>> scheduler.step() # scheduler.step(27), instead of scheduler(20) 55 | """ 56 | if epoch is None: 57 | epoch = self.last_epoch + 1 58 | self.T_cur = self.T_cur + 1 59 | if self.T_cur >= self.T_i: 60 | self.T_cur = self.T_cur - self.T_i 61 | self.T_i = self.T_i * self.T_mult 62 | else: 63 | if epoch >= self.T_0: 64 | if self.T_mult == 1: 65 | self.T_cur = epoch % self.T_0 66 | else: 67 | n = int(math.log((epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult)) 68 | self.T_cur = epoch - self.T_0 * (self.T_mult**n - 1) / (self.T_mult - 1) 69 | self.T_i = self.T_0 * self.T_mult**(n) 70 | else: 71 | self.T_i = self.T_0 72 | self.T_cur = epoch 73 | self.last_epoch = math.floor(epoch) 74 | for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): 75 | param_group['lr'] = lr 76 | -------------------------------------------------------------------------------- /flow_models/wolf/flows/resflow/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | 4 | 5 | class CosineAnnealingWarmRestarts(_LRScheduler): 6 | r"""Set the learning rate of each parameter group using a cosine annealing 7 | schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}` 8 | is the number of epochs since the last restart and :math:`T_{i}` is the number 9 | of epochs between two warm restarts in SGDR: 10 | .. math:: 11 | \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 + 12 | \cos(\frac{T_{cur}}{T_{i}}\pi)) 13 | When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`. 14 | When :math:`T_{cur}=0`(after restart), set :math:`\eta_t=\eta_{max}`. 15 | It has been proposed in 16 | `SGDR: Stochastic Gradient Descent with Warm Restarts`_. 17 | Args: 18 | optimizer (Optimizer): Wrapped optimizer. 
19 | T_0 (int): Number of iterations for the first restart. 20 | T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1. 21 | eta_min (float, optional): Minimum learning rate. Default: 0. 22 | last_epoch (int, optional): The index of last epoch. Default: -1. 23 | .. _SGDR\: Stochastic Gradient Descent with Warm Restarts: 24 | https://arxiv.org/abs/1608.03983 25 | """ 26 | 27 | def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1): 28 | if T_0 <= 0 or not isinstance(T_0, int): 29 | raise ValueError("Expected positive integer T_0, but got {}".format(T_0)) 30 | if T_mult < 1 or not isinstance(T_mult, int): 31 | raise ValueError("Expected integer T_mul >= 1, but got {}".format(T_mult)) 32 | self.T_0 = T_0 33 | self.T_i = T_0 34 | self.T_mult = T_mult 35 | self.eta_min = eta_min 36 | super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch) 37 | self.T_cur = last_epoch 38 | 39 | def get_lr(self): 40 | return [ 41 | self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * self.T_cur / self.T_i)) / 2 42 | for base_lr in self.base_lrs 43 | ] 44 | 45 | def step(self, epoch=None): 46 | """Step could be called after every update, i.e. if one epoch has 10 iterations 47 | (number_of_train_examples / batch_size), we should call SGDR.step(0.1), SGDR.step(0.2), etc. 48 | This function can be called in an interleaved way. 
class ExponentialMovingAverage:
    """
    Keeps an exponential moving average ("shadow" copy) of a set of parameters.

    Only parameters with ``requires_grad`` set are tracked by the average.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay, in [0, 1].
          use_num_updates: Whether to count updates and use the count to cap
            the decay in `update`.

        Raises:
          ValueError: If `decay` lies outside [0, 1].
        """
        if decay > 1.0 or decay < 0.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = None
        if use_num_updates:
            self.num_updates = 0
        self.shadow_params = [
            tensor.clone().detach() for tensor in parameters if tensor.requires_grad
        ]
        self.collected_params = []

    def update(self, parameters):
        """
        Move the shadow parameters towards the current parameter values.

        Call this after every parameter change, e.g. after `optimizer.step()`.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        rate = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Cap the decay early on: (1 + n) / (10 + n) grows towards 1 with n.
            warmup_rate = (1 + self.num_updates) / (10 + self.num_updates)
            rate = min(rate, warmup_rate)
        step_size = 1.0 - rate
        with torch.no_grad():
            trainable = [tensor for tensor in parameters if tensor.requires_grad]
            for shadow, current in zip(self.shadow_params, trainable):
                shadow.sub_(step_size * (shadow - current))

    def copy_to(self, parameters):
        """
        Overwrite the given trainable parameters with the stored averages.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
        """
        trainable = [tensor for tensor in parameters if tensor.requires_grad]
        for shadow, target in zip(self.shadow_params, trainable):
            target.data.copy_(shadow.data)

    def store(self, parameters):
        """
        Save a copy of the current parameters for a later `restore`.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        snapshot = []
        for tensor in parameters:
            snapshot.append(tensor.clone())
        self.collected_params = snapshot

    def restore(self, parameters):
        """
        Write back the parameters saved by `store`.

        Typical use: `store`, then `copy_to` to evaluate or save the model with
        the averaged weights, then `restore` to continue training with the
        original weights.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for saved, target in zip(self.collected_params, parameters):
            target.data.copy_(saved.data)

    def state_dict(self):
        """Return decay, update counter and shadow parameters as a dict."""
        return {
            'decay': self.decay,
            'num_updates': self.num_updates,
            'shadow_params': self.shadow_params,
        }

    def load_state_dict(self, state_dict):
        """Load the fields produced by `state_dict`."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = state_dict['shadow_params']
def _attach_file_logger(log_path):
  """Send root-logger records at INFO and above to `log_path`, appending."""
  # FileHandler in append mode creates the file when missing, so no
  # exists-check is needed, and logging closes the file at shutdown.
  handler = logging.FileHandler(log_path, mode='a')
  handler.setFormatter(
      logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s'))
  logger = logging.getLogger()
  logger.addHandler(handler)
  logger.setLevel('INFO')


def _dump_config(config_dict, path):
  """Write a human-readable listing of a one-level-nested config dict to `path`."""
  with open(path, 'w') as f:
    for section, value in config_dict.items():
      f.write(str(section) + '\n')
      # NOTE: only dict-valued sections have their contents listed; for any
      # other top-level entry just the key is written.
      if isinstance(value, dict):
        for key, item in value.items():
          f.write('> ' + str(key) + ': ' + str(item) + '\n')
      f.write('\n\n')


def main(argv):
  """Entry point: dump the config, set up file logging, then train or evaluate.

  Args:
    argv: Positional command-line arguments passed by `app.run`; unused.

  Raises:
    ValueError: If `FLAGS.mode` is neither "train" nor "eval".
  """
  # Create the working directory before anything is written into it.
  tf.io.gfile.makedirs(FLAGS.workdir)
  _dump_config(FLAGS.config.to_dict(), os.path.join(FLAGS.workdir, 'config.txt'))

  if FLAGS.mode == "train":
    _attach_file_logger(os.path.join(FLAGS.workdir, 'stdout.txt'))
    # Run the training pipeline
    run_lib.train(FLAGS.config, FLAGS.workdir, FLAGS.assetdir)
  elif FLAGS.mode == "eval":
    eval_dir = os.path.join(FLAGS.workdir, FLAGS.eval_folder)
    tf.io.gfile.makedirs(eval_dir)
    _attach_file_logger(os.path.join(FLAGS.workdir, 'evaluation_history.txt'))
    # Run the evaluation pipeline
    run_lib.evaluate(FLAGS.config, FLAGS.workdir, FLAGS.assetdir, FLAGS.eval_folder)
  else:
    raise ValueError(f"Mode {FLAGS.mode} not recognized.")
| super(_ResNet, self).__init__() 16 | assert len(planes) == len(strides) 17 | 18 | blocks = [] 19 | for i in range(len(planes)): 20 | plane = planes[i] 21 | stride = strides[i] 22 | block = resnet_block(inplanes, plane, stride=stride, activation=activation, **kwargs) 23 | blocks.append(block) 24 | inplanes = plane 25 | 26 | self.main = nn.Sequential(*blocks) 27 | 28 | def init(self, x, init_scale=1.0): 29 | for block in self.main: 30 | x = block.init(x, init_scale=init_scale) 31 | return x 32 | 33 | def forward(self, x): 34 | return self.main(x) 35 | 36 | 37 | class _DeResNet(nn.Module): 38 | def __init__(self, deresnet_block, inplanes, planes, strides, output_paddings, activation, **kwargs): 39 | super(_DeResNet, self).__init__() 40 | assert len(planes) == len(strides) 41 | assert len(planes) == len(output_paddings) 42 | 43 | blocks = [] 44 | for i in range(len(planes)): 45 | plane = planes[i] 46 | stride = strides[i] 47 | output_padding = output_paddings[i] 48 | block = deresnet_block(inplanes, plane, stride=stride, output_padding=output_padding, 49 | activation=activation, **kwargs) 50 | blocks.append(block) 51 | inplanes = plane 52 | 53 | self.main = nn.Sequential(*blocks) 54 | 55 | def init(self, x, init_scale=1.0): 56 | for block in self.main: 57 | x = block.init(x, init_scale=init_scale) 58 | return x 59 | 60 | def forward(self, x): 61 | return self.main(x) 62 | 63 | 64 | class ResNetBatchNorm(_ResNet): 65 | def __init__(self, inplanes, planes, strides, activation): 66 | super(ResNetBatchNorm, self).__init__(ResNetBlockBatchNorm, inplanes, planes, strides, activation) 67 | 68 | 69 | class ResNetWeightNorm(_ResNet): 70 | def __init__(self, inplanes, planes, strides, activation): 71 | super(ResNetWeightNorm, self).__init__(ResNetBlockWeightNorm, inplanes, planes, strides, activation) 72 | 73 | 74 | class ResNetGroupNorm(_ResNet): 75 | def __init__(self, inplanes, planes, strides, activation, num_groups): 76 | super(ResNetGroupNorm, 
self).__init__(ResNetBlockGroupNorm, inplanes, planes, strides, activation, 77 | num_groups=num_groups) 78 | 79 | 80 | class DeResNetBatchNorm(_DeResNet): 81 | def __init__(self, inplanes, planes, strides, output_paddings, activation): 82 | super(DeResNetBatchNorm, self).__init__(DeResNetBlockBatchNorm, inplanes, planes, strides, 83 | output_paddings, activation) 84 | 85 | 86 | class DeResNetWeightNorm(_DeResNet): 87 | def __init__(self, inplanes, planes, strides, output_paddings, activation): 88 | super(DeResNetWeightNorm, self).__init__(DeResNetBlockWeightNorm, inplanes, planes, strides, 89 | output_paddings, activation) 90 | 91 | 92 | class DeResNetGroupNorm(_DeResNet): 93 | def __init__(self, inplanes, planes, strides, output_paddings, activation, num_groups): 94 | super(DeResNetGroupNorm, self).__init__(DeResNetBlockGroupNorm, inplanes, planes, strides, 95 | output_paddings, activation, num_groups=num_groups) 96 | -------------------------------------------------------------------------------- /cleanfid/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | helpers for extractign features from image 3 | """ 4 | import os 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import cleanfid 9 | from cleanfid.downloads_helper import * 10 | from cleanfid.inception_pytorch import InceptionV3 11 | 12 | 13 | class InceptionV3W(nn.Module): 14 | """ 15 | Wrapper around Inception V3 torchscript model provided here 16 | https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt 17 | 18 | path: locally saved inception weights 19 | """ 20 | 21 | def __init__(self, path, download=True, resize_inside=False): 22 | super(InceptionV3W, self).__init__() 23 | # download the network if it is not present at the given directory 24 | # use the current directory by default 25 | if download: 26 | check_download_inception(fpath=path) 27 | path = os.path.join(path, 
"inception-2015-12-05.pt") 28 | self.base = torch.jit.load(path).eval() 29 | self.layers = self.base.layers 30 | self.resize_inside=resize_inside 31 | 32 | """ 33 | Get the inception features without resizing 34 | x: Image with values in range [0,255] 35 | """ 36 | 37 | def forward(self, x): 38 | bs = x.shape[0] 39 | if self.resize_inside: 40 | features = self.base(x, return_features=True).view((bs, 2048)) 41 | else: 42 | # make sure it is resized already 43 | assert x.shape[2] == 299 44 | # apply normalization 45 | x1 = x - 128 46 | x2 = x1 / 128 47 | features = self.layers.forward(x2, ).view((bs, 2048)) 48 | return features 49 | 50 | 51 | """ 52 | returns a functions that takes an image in range [0,255] 53 | and outputs a feature embedding vector 54 | """ 55 | def feature_extractor(name="torchscript_inception", device=torch.device("cuda"), resize_inside=False): 56 | if name == "torchscript_inception": 57 | model = torch.nn.DataParallel(InceptionV3W("/tmp", download=True, resize_inside=resize_inside).to(device)) 58 | #model = InceptionV3W("/tmp", download=True, resize_inside=resize_inside).to(device) 59 | model.eval() 60 | def model_fn(x): return model(x) 61 | elif name == "pytorch_inception": 62 | model = InceptionV3(output_blocks=[3], resize_input=False).to(device) 63 | model.eval() 64 | def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1) 65 | else: 66 | raise ValueError(f"{name} feature extractor not implemented") 67 | return model_fn 68 | 69 | 70 | def build_feature_extractor(mode, device=torch.device("cuda")): 71 | if mode=="legacy_pytorch": 72 | feat_model = feature_extractor(name="pytorch_inception", resize_inside=False, device=device) 73 | elif mode=="legacy_tensorflow": 74 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=True, device=device) 75 | elif mode=="clean": 76 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=False, device=device) 77 | return feat_model 78 | 79 | def 
def get_default_configs():
    """Build the default CelebA configuration.

    Returns a ml_collections.ConfigDict with the sections ``training``,
    ``sampling``, ``eval``, ``data``, ``model``, ``optim`` and ``flow`` plus
    the top-level fields ``seed``, ``device``, ``datadir``,
    ``checkpoint_meta_dir`` and ``resume``.
    """
    training_section = {
        'batch_size': 128,
        'n_iters': 13000001,
        'snapshot_freq': 10000,
        'log_freq': 100,
        'eval_freq': 100,
        # store additional checkpoints for preemption in cloud computing environments
        'snapshot_freq_for_preemption': 10000,
        # produce samples at each snapshot.
        'snapshot_sampling': True,
        'likelihood_weighting': True,
        'continuous': True,
        'reduce_mean': False,
        'importance_sampling': True,
        'unbounded_parametrization': False,
        'ddpm_score': True,
        'st': False,
        'k': 1.2,
        'truncation_time': 1e-5,
        'num_train_data': 50000,
        'reconstruction_loss': False,
    }

    sampling_section = {
        'n_steps_each': 1,
        'noise_removal': True,
        'probability_flow': False,
        'snr': 0.15,
        'batch_size': 1024,
        'truncation_time': 1e-5,
        'temperature': 1.,
        'need_sample': True,
        'idx_rand': True,
        'pc_denoise': False,
        'pc_denoise_time': 0.,
        'more_step': False,
        'num_scales': 1000,
        'pc_ratio': 1.,
        'begin_snr': 0.16,
        'end_snr': 0.16,
        'snr_scheduling': 'none',
    }

    eval_section = {
        'begin_ckpt': 1,
        'end_ckpt': 26,
        'batch_size': 200,
        'enable_sampling': True,
        'num_samples': 50000,
        'enable_loss': True,
        'enable_bpd': True,
        'bpd_dataset': 'test',
        'num_test_data': 19962,
        'residual': False,
        'score_ema': True,
        'flow_ema': False,
        'num_nelbo': 3,
        'rtol': 1e-5,
        'atol': 1e-5,
        'gap_diff': False,
        'target_ckpt': -1,
        'truncation_time': -1.,
        'data_mean': False,
        'skip_nll_wrong': False,
    }

    data_section = {
        'dataset': 'CELEBA',
        'image_size': 64,
        'random_flip': True,
        'centered': False,
        'num_channels': 3,
    }

    model_section = {
        'sigma_max': 90.,
        'sigma_min': 0.01,
        'num_scales': 1000,
        'beta_min': 0.1,
        'beta_max': 20.,
        'dropout': 0.1,
        'embedding_type': 'fourier',
        'auxiliary_resblock': True,
        'attention': True,
        'fourier_feature': False,
    }

    optim_section = {
        'optimizer': 'AdamW',
        'weight_decay': 0.01,
        'lr': 2e-4,
        'beta1': 0.9,
        'eps': 1e-8,
        'warmup': 0,
        'grad_clip': 1.,
        'num_micro_batch': 1,
        'reset': True,
        'amsgrad': False,
    }

    flow_section = {
        'model': 'identity',
        'lr': 1e-3,
        'ema_rate': 0.999,
        'optim_reset': False,
        'nblocks': '16-16',
        'intermediate_dim': 512,
        'resblock_type': 'resflow',
        'squeeze': True,
        'actnorm': False,
        'grad_in_forward': False,
        'act_fn': 'sin',
    }

    config = ml_collections.ConfigDict()
    # sections are attached in the same order as they are declared above
    config.training = ml_collections.ConfigDict(training_section)
    config.sampling = ml_collections.ConfigDict(sampling_section)
    config.eval = ml_collections.ConfigDict(eval_section)
    config.data = ml_collections.ConfigDict(data_section)
    config.model = ml_collections.ConfigDict(model_section)
    config.optim = ml_collections.ConfigDict(optim_section)
    config.flow = ml_collections.ConfigDict(flow_section)

    config.seed = 42
    # first CUDA device when one is available, CPU otherwise
    if torch.cuda.is_available():
        config.device = torch.device('cuda:0')
    else:
        config.device = torch.device('cpu')

    config.datadir = '.'
    config.checkpoint_meta_dir = '.'
    config.resume = False

    return config
47 | 48 | sampling.begin_snr = 0.16 49 | sampling.end_snr = 0.16 50 | sampling.snr_scheduling = 'none' 51 | 52 | # evaluation 53 | config.eval = evaluate = ml_collections.ConfigDict() 54 | evaluate.begin_ckpt = 9 55 | evaluate.end_ckpt = 26 56 | evaluate.batch_size = 200 57 | evaluate.enable_sampling = True 58 | evaluate.num_samples = 50000 59 | evaluate.enable_loss = True 60 | evaluate.enable_bpd = True 61 | evaluate.bpd_dataset = 'test' 62 | evaluate.num_test_data = 10000 63 | evaluate.residual = False 64 | evaluate.score_ema = True 65 | evaluate.flow_ema = False 66 | evaluate.num_nelbo = 3 67 | evaluate.rtol = 1e-5 68 | evaluate.atol = 1e-5 69 | 70 | evaluate.gap_diff = False 71 | evaluate.target_ckpt = -1 72 | evaluate.truncation_time = -1. 73 | 74 | evaluate.data_mean = False 75 | evaluate.skip_nll_wrong = False 76 | 77 | # data 78 | config.data = data = ml_collections.ConfigDict() 79 | data.dataset = 'CIFAR10' 80 | data.image_size = 32 81 | data.random_flip = True 82 | data.centered = False 83 | data.num_channels = 3 84 | 85 | 86 | # model 87 | config.model = model = ml_collections.ConfigDict() 88 | model.sigma_min = 0.01 89 | model.sigma_max = 50 90 | model.num_scales = 1000 91 | model.beta_min = 0.1 92 | model.beta_max = 20. 93 | model.dropout = 0.1 94 | model.embedding_type = 'fourier' 95 | model.auxiliary_resblock = True 96 | model.attention = True 97 | model.fourier_feature = False 98 | 99 | # optimization 100 | config.optim = optim = ml_collections.ConfigDict() 101 | optim.optimizer = 'AdamW' 102 | optim.weight_decay = 0.01 103 | optim.lr = 2e-4 104 | optim.beta1 = 0.9 105 | optim.eps = 1e-8 106 | optim.warmup = 0 107 | optim.grad_clip = 1. 
class Flow(nn.Module):
    """
    Normalizing Flow base class

    Subclasses implement ``forward``, ``backward`` and ``init``. ``fwdpass``
    and ``bwdpass`` dispatch to them depending on ``self.inverse``: an
    inverse flow swaps the roles of ``forward`` and ``backward``.
    """
    # name -> class mapping shared by every subclass (always accessed as Flow._registry)
    _registry = dict()

    def __init__(self, inverse):
        super(Flow, self).__init__()
        self.inverse = inverse

    def forward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""

        Args:
            *inputs: input [batch, *input_size]

        Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch]
            out, the output of the flow
            logdet, the log determinant of :math:`\partial output / \partial input`
        """
        raise NotImplementedError

    def backward(self, *inputs, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""

        Args:
            *inputs: input [batch, *input_size]

        Returns: out: Tensor [batch, *input_size], logdet: Tensor [batch]
            out, the output of the flow
            logdet, the log determinant of :math:`\partial output / \partial input`
        """
        raise NotImplementedError

    def init(self, *input, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Initialization pass; called by fwdpass/bwdpass when ``init=True``
        with an ``init_scale`` keyword argument.
        """
        raise NotImplementedError

    def fwdpass(self, x: torch.Tensor, *h, init=False, init_scale=1.0, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""

        Args:
            x: Tensor
                The random variable before flow
            h: list of object
                other conditional inputs
            init: bool
                perform initialization or not (default: False)
            init_scale: float
                initial scale (default: 1.0)

        Returns: y: Tensor, logdet: Tensor
            y, the random variable after flow
            logdet, the log determinant of :math:`\partial y / \partial x`
            Then the density :math:`\log(p(y)) = \log(p(x)) - logdet`

        Raises:
            RuntimeError: if ``init`` is requested on an inverse flow

        """
        if self.inverse:
            if init:
                raise RuntimeError('inverse flow should be initialized with backward pass')
            else:
                # an inverse flow runs its backward() in the forward direction
                return self.backward(x, *h, **kwargs)
        else:
            if init:
                return self.init(x, *h, init_scale=init_scale, **kwargs)
            else:
                return self.forward(x, *h, **kwargs)

    def bwdpass(self, y: torch.Tensor, *h, init=False, init_scale=1.0, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""

        Args:
            y: Tensor
                The random variable after flow
            h: list of object
                other conditional inputs
            init: bool
                perform initialization or not (default: False)
            init_scale: float
                initial scale (default: 1.0)

        Returns: x: Tensor, logdet: Tensor
            x, the random variable before flow
            logdet, the log determinant of :math:`\partial x / \partial y`
            Then the density :math:`\log(p(y)) = \log(p(x)) + logdet`

        Raises:
            RuntimeError: if ``init`` is requested on a non-inverse flow

        """
        if self.inverse:
            if init:
                return self.init(y, *h, init_scale=init_scale, **kwargs)
            else:
                # an inverse flow runs its forward() in the backward direction
                return self.forward(y, *h, **kwargs)
        else:
            if init:
                raise RuntimeError('forward flow should be initialized with forward pass')
            else:
                return self.backward(y, *h, **kwargs)

    @classmethod
    def register(cls, name: str):
        """Store ``cls`` in the shared registry under ``name`` (overwrites an existing entry)."""
        Flow._registry[name] = cls

    @classmethod
    def by_name(cls, name: str):
        """Return the class registered under ``name``; raises KeyError if there is none."""
        return Flow._registry[name]

    @classmethod
    def from_params(cls, params: Dict):
        """Alternate constructor from a parameter dict; to be implemented by subclasses."""
        raise NotImplementedError
class LocalResNetEncoderGroupNorm(Encoder):
    """
    Local ResNet Encoder with group normalization

    Stacks ``levels`` ResNetGroupNorm stages followed by ``levels``
    DeResNetGroupNorm stages (built in reverse level order) into one
    nn.Sequential stored as ``self.net``.
    """

    def __init__(self, levels, in_planes, out_planes, hidden_planes, activation, num_groups):
        super().__init__()
        # one hidden width and one group count per level
        assert len(hidden_planes) == levels
        assert len(num_groups) == levels

        stages = OrderedDict()

        # ResNet stages: in_planes -> hidden_planes[0] -> ... -> hidden_planes[-1]
        channels = in_planes
        for idx in range(levels):
            width = hidden_planes[idx]
            stages['resnet{}'.format(idx)] = ResNetGroupNorm(
                channels, [width, width], [1, 2], activation, num_groups=num_groups[idx])
            channels = width

        # DeResNet stages: the target width at level idx is the entry to the
        # left of it in the shifted list, ending at out_planes for level 0
        channels = hidden_planes[-1]
        targets = [out_planes, ] + hidden_planes
        for idx in range(levels - 1, -1, -1):
            width = targets[idx]
            # NOTE(review): output paddings here are [0, 0]; confirm this is
            # intended for the group-norm variant.
            stages['deresnet{}'.format(idx)] = DeResNetGroupNorm(
                channels, [channels, width], [1, 2], [0, 0], activation, num_groups=num_groups[idx])
            channels = width

        self.net = nn.Sequential(stages)

    def forward(self, x):
        # [batch, out_planes, h, w]
        out = self.net(x)
        return out

    def init(self, x, init_scale=1.0):
        # plain forward pass without gradient tracking; init_scale is unused
        with torch.no_grad():
            return self(x)

    @classmethod
    def from_params(cls, params: Dict) -> "LocalResNetEncoderGroupNorm":
        # params holds the keyword arguments of __init__
        return LocalResNetEncoderGroupNorm(**params)
class GaussianDiscriminator(Discriminator):
    """
    Discriminator producing a diagonal Gaussian posterior.

    The encoder output is mapped by a weight-normalized linear layer to
    2 * dim values that are split into mean and log-variance.
    """

    def __init__(self, encoder: Encoder, in_dim, dim, prior: Prior):
        super().__init__()
        self.dim = dim
        self.encoder = encoder
        # one linear layer emits mean and log-variance side by side
        self.fc = LinearWeightNorm(in_dim, 2 * dim, bias=True)
        self.prior = prior

    def forward(self, x):
        hidden = self.fc(self.encoder(x))
        # first half along dim 1 is the mean, second half the log-variance
        mu, logvar = hidden.chunk(2, dim=1)
        return mu, logvar

    @staticmethod
    def reparameterize(mu, logvar, nsamples=1, random=True):
        """Return (z, eps) with z = eps * std + mu; eps is zero when random is False."""
        # mu, logvar: [batch, dim]
        batch, dim = mu.size()[0], mu.size()[1]
        std = logvar.mul(0.5).exp()
        # eps: [batch, nsamples, dim]
        if not random:
            eps = mu.new_zeros(batch, nsamples, dim)
        else:
            eps = torch.randn(batch, nsamples, dim, device=mu.device)
        z = eps.mul(std.unsqueeze(1)).add(mu.unsqueeze(1))
        return z, eps

    @staticmethod
    def log_probability_posterior(eps, logvar):
        """Log-density of the posterior samples expressed through their noise eps."""
        dim = eps.size()[2]
        # normalization constant: dim * log(2 * pi)
        const = math.log(math.pi * 2.) * dim
        # [batch, nsamples, dim] --> [batch, nsamples]
        summed = (logvar.unsqueeze(1) + eps.pow(2)).sum(dim=2) + const
        return summed * -0.5

    @overrides
    def sample_from_prior(self, nsamples=1, device=torch.device('cpu')):
        return self.prior.sample(nsamples, self.dim, device)

    @overrides
    def sample_from_posterior(self, x, y=None, nsamples=1, random=True):
        # [batch, dim]
        mu, logvar = self(x)
        # z: [batch, nsamples, dim]
        z, eps = GaussianDiscriminator.reparameterize(mu, logvar, nsamples=nsamples, random=random)
        # [batch, nsamples]
        return z, GaussianDiscriminator.log_probability_posterior(eps, logvar)

    @overrides
    def sampling_and_KL(self, x, y=None, nsamples=1):
        mu, logvar = self(x)
        # z: [batch, nsamples, dim]
        z, eps = GaussianDiscriminator.reparameterize(mu, logvar, nsamples=nsamples, random=True)
        # the KL term is delegated to the prior
        kl = self.prior.calcKL(z, eps, mu, logvar)
        return z, kl

    @overrides
    def init(self, x, y=None, init_scale=1.0):
        with torch.no_grad():
            hidden = self.encoder.init(x, init_scale=init_scale)
            # the linear head is initialized with a 100x smaller scale
            hidden = self.fc.init(hidden, init_scale=0.01 * init_scale)
            mu, logvar = hidden.chunk(2, dim=1)
            # z: [batch, 1, dim]
            z, eps = GaussianDiscriminator.reparameterize(mu, logvar, nsamples=1, random=True)
            kl = self.prior.init(z, eps, mu, logvar, init_scale=init_scale)
            return z.squeeze(1), kl

    @overrides
    def sync(self):
        self.prior.sync()

    @classmethod
    def from_params(cls, params: Dict) -> "GaussianDiscriminator":
        # NOTE: 'encoder', 'prior' and their 'type' keys are popped, so the
        # caller's dicts are modified in place.
        encoder_cfg = params.pop('encoder')
        encoder_cls = Encoder.by_name(encoder_cfg.pop('type'))
        encoder = encoder_cls.from_params(encoder_cfg)
        prior_cfg = params.pop('prior')
        prior_cls = Prior.by_name(prior_cfg.pop('type'))
        prior = prior_cls.from_params(prior_cfg)
        # remaining entries are forwarded as constructor keyword arguments
        return GaussianDiscriminator(encoder=encoder, prior=prior, **params)
56 | 57 | """ 58 | return self.flow.bwdpass(x, h)[0] 59 | 60 | def log_probability(self, x: torch.Tensor, h: Union[None, torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: 61 | """ 62 | 63 | Args: 64 | x: Tensor [batch, channel, height, width] 65 | The input data. 66 | h: Tensor or None [batch, dim] 67 | conditional input 68 | 69 | Returns: Tensor1, Tensor2 [batch,] 70 | Tensor1: generated tensor [batch, channels, height, width] 71 | Tensor2: The tensor of the log probabilities of x [batch] 72 | 73 | """ 74 | # [batch, channel, height, width] 75 | epsilon_org, logdet = self.flow.bwdpass(x, h) 76 | if self.config.model.name == 'None': 77 | # [batch, numels] 78 | epsilon = epsilon_org.view(epsilon_org.size(0), -1) 79 | # [batch] 80 | log_probs = epsilon.mul(epsilon).sum(dim=1) + math.log(math.pi * 2.) * epsilon.size(1) 81 | return epsilon_org, log_probs.mul(-0.5) + logdet 82 | else: 83 | return epsilon_org, logdet 84 | 85 | def init(self, data: torch.Tensor, h=None, init_scale=1.0): 86 | return self.flow.bwdpass(data, h, init=True, init_scale=init_scale) 87 | 88 | @classmethod 89 | def from_params(cls, params: Dict, config=None) -> "Generator": 90 | flow_params = params.pop('flow') 91 | flow_type = flow_params.pop('type') 92 | if flow_type == 'resflow': 93 | from flow_models.wolf.flows.resflow import ResidualFlow 94 | if config.flow.squeeze: 95 | input_shape = (config.training.batch_size, config.data.num_channels * 4, config.data.image_size // 2, 96 | config.data.image_size // 2) 97 | else: 98 | input_shape = ( 99 | config.training.batch_size, config.data.num_channels, config.data.image_size, config.data.image_size) 100 | flow = ResidualFlow(config, input_shape, 101 | n_blocks=list(map(int, config.flow.nblocks.split('-'))), 102 | intermediate_dim=config.flow.intermediate_dim, 103 | vnorms='ffff', 104 | actnorm=config.flow.actnorm, 105 | grad_in_forward=config.flow.grad_in_forward, 106 | activation_fn=config.flow.act_fn).to(config.device) 107 | else: 108 | 
def build_resizer(mode):
    """
    Return the image resizing function that belongs to an evaluation mode.

    mode: "clean", "legacy_tensorflow" or "legacy_pytorch"
    raises ValueError for any other mode
    """
    target_size = (299, 299)

    if mode == "clean":
        return make_resizer("PIL", False, "bicubic", target_size)

    if mode == "legacy_tensorflow":
        # if using legacy tensorflow, do not manually resize outside the network
        def passthrough(x):
            return x
        return passthrough

    if mode == "legacy_pytorch":
        return make_resizer("PyTorch", False, "bilinear", target_size)

    raise ValueError(f"Invalid mode {mode} specified")
34 | """ 35 | def make_resizer(library, quantize_after, filter, output_size): 36 | if library == "PIL" and quantize_after: 37 | def func(x): 38 | x = Image.fromarray(x) 39 | x = x.resize(output_size, resample=dict_name_to_filter[library][filter]) 40 | x = np.asarray(x).astype(np.uint8) 41 | return x 42 | 43 | elif library == "PIL" and not quantize_after: 44 | s1, s2 = output_size 45 | 46 | def resize_single_channel(x_np): 47 | img = Image.fromarray(x_np.astype(np.float32), mode='F') 48 | img = img.resize(output_size, resample=dict_name_to_filter[library][filter]) 49 | return np.asarray(img).reshape(s1, s2, 1) 50 | def func(x): 51 | x = [resize_single_channel(x[:, :, idx]) for idx in range(3)] 52 | x = np.concatenate(x, axis=2).astype(np.float32) 53 | return x 54 | 55 | elif library == "PyTorch": 56 | import warnings 57 | # ignore the numpy warnings 58 | warnings.filterwarnings("ignore") 59 | def func(x): 60 | x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...] 61 | x = F.interpolate(x, size=output_size, mode=filter, align_corners=False) 62 | x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255) 63 | if quantize_after: 64 | x = x.astype(np.uint8) 65 | return x 66 | 67 | elif library == "TensorFlow": 68 | import warnings 69 | # ignore the numpy warnings 70 | warnings.filterwarnings("ignore") 71 | import tensorflow as tf 72 | def func(x): 73 | x = tf.constant(x)[tf.newaxis, ...] 
74 | x = tf.image.resize(x, output_size, method=filter) 75 | x = x[0, ...].numpy().clip(0, 255) 76 | if quantize_after: 77 | x = x.astype(np.uint8) 78 | return x 79 | 80 | elif library=="OpenCV": 81 | import cv2 82 | name_to_filter = { 83 | "bilinear": cv2.INTER_LINEAR, 84 | "bicubic" : cv2.INTER_CUBIC, 85 | "lanczos" : cv2.INTER_LANCZOS4, 86 | "nearest" : cv2.INTER_NEAREST, 87 | "area" : cv2.INTER_AREA 88 | } 89 | def func(x): 90 | x = cv2.resize(x, output_size, interpolation=name_to_filter[filter]) 91 | if quantize_after: x = x.astype(np.uint8) 92 | return x 93 | else: 94 | raise NotImplementedError('library [%s] is not include' % library) 95 | return func 96 | 97 | 98 | class FolderResizer(torch.utils.data.Dataset): 99 | def __init__(self, files, outpath, fn_resize, output_ext=".png"): 100 | self.files = files 101 | self.outpath = outpath 102 | self.output_ext = output_ext 103 | self.fn_resize = fn_resize 104 | 105 | def __len__(self): 106 | return len(self.files) 107 | 108 | def __getitem__(self, i): 109 | path = str(self.files[i]) 110 | img_np = np.asarray(Image.open(path)) 111 | img_resize_np = self.fn_resize(img_np) 112 | # swap the output extension 113 | basename = os.path.basename(path).split(".")[0] + self.output_ext 114 | outname = os.path.join(self.outpath, basename) 115 | if self.output_ext == ".npy": 116 | np.save(outname, img_resize_np) 117 | elif self.output_ext == ".png": 118 | img_resized_pil = Image.fromarray(img_resize_np) 119 | img_resized_pil.save(outname) 120 | else: 121 | raise ValueError("invalid output extension") 122 | return 0 123 | -------------------------------------------------------------------------------- /flow_models/resflow/toy_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sklearn 3 | import sklearn.datasets 4 | from sklearn.utils import shuffle as util_shuffle 5 | 6 | 7 | # Dataset iterator 8 | def inf_train_gen(data, batch_size=200): 9 | 10 | if data 
== "swissroll": 11 | data = sklearn.datasets.make_swiss_roll(n_samples=batch_size, noise=1.0)[0] 12 | data = data.astype("float32")[:, [0, 2]] 13 | data /= 5 14 | return data 15 | 16 | elif data == "circles": 17 | data = sklearn.datasets.make_circles(n_samples=batch_size, factor=.5, noise=0.08)[0] 18 | data = data.astype("float32") 19 | data *= 3 20 | return data 21 | 22 | elif data == "rings": 23 | n_samples4 = n_samples3 = n_samples2 = batch_size // 4 24 | n_samples1 = batch_size - n_samples4 - n_samples3 - n_samples2 25 | 26 | # so as not to have the first point = last point, we set endpoint=False 27 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False) 28 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False) 29 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False) 30 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False) 31 | 32 | circ4_x = np.cos(linspace4) 33 | circ4_y = np.sin(linspace4) 34 | circ3_x = np.cos(linspace4) * 0.75 35 | circ3_y = np.sin(linspace3) * 0.75 36 | circ2_x = np.cos(linspace2) * 0.5 37 | circ2_y = np.sin(linspace2) * 0.5 38 | circ1_x = np.cos(linspace1) * 0.25 39 | circ1_y = np.sin(linspace1) * 0.25 40 | 41 | X = np.vstack([ 42 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]), 43 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y]) 44 | ]).T * 3.0 45 | X = util_shuffle(X) 46 | 47 | # Add noise 48 | X = X + np.random.normal(scale=0.08, size=X.shape) 49 | 50 | return X.astype("float32") 51 | 52 | elif data == "moons": 53 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0] 54 | data = data.astype("float32") 55 | data = data * 2 + np.array([-1, -0.2]) 56 | return data 57 | 58 | elif data == "8gaussians": 59 | scale = 4. 60 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 61 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 62 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. 
/ np.sqrt(2))] 63 | centers = [(scale * x, scale * y) for x, y in centers] 64 | 65 | dataset = [] 66 | for i in range(batch_size): 67 | point = np.random.randn(2) * 0.5 68 | idx = np.random.randint(8) 69 | center = centers[idx] 70 | point[0] += center[0] 71 | point[1] += center[1] 72 | dataset.append(point) 73 | dataset = np.array(dataset, dtype="float32") 74 | dataset /= 1.414 75 | return dataset 76 | 77 | elif data == "pinwheel": 78 | radial_std = 0.3 79 | tangential_std = 0.1 80 | num_classes = 5 81 | num_per_class = batch_size // 5 82 | rate = 0.25 83 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 84 | 85 | features = np.random.randn(num_classes*num_per_class, 2) \ 86 | * np.array([radial_std, tangential_std]) 87 | features[:, 0] += 1. 88 | labels = np.repeat(np.arange(num_classes), num_per_class) 89 | 90 | angles = rads[labels] + rate * np.exp(features[:, 0]) 91 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 92 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 93 | 94 | return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 95 | 96 | elif data == "2spirals": 97 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 98 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 99 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 100 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 101 | x += np.random.randn(*x.shape) * 0.1 102 | return x 103 | 104 | elif data == "checkerboard": 105 | x1 = np.random.rand(batch_size) * 4 - 2 106 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 107 | x2 = x2_ + (np.floor(x1) % 2) 108 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 109 | 110 | elif data == "line": 111 | x = np.random.rand(batch_size) * 5 - 2.5 112 | y = x 113 | return np.stack((x, y), 1) 114 | elif data == "cos": 115 | x = np.random.rand(batch_size) * 5 - 2.5 116 | y = np.sin(x) * 
def plt_potential_func(potential, ax, npts=100, title="$p(x)$"):
    """Plot the unnormalized density ``exp(-U(z))`` of a potential on ``ax``.

    Args:
        potential: computes U(z_k) given z_k; it is called once with a
            ``(npts * npts, 2)`` tensor holding every grid point.
        ax: axes object that receives the plot, the tick settings and the
            title.
        npts: number of grid points per side of the ``[LOW, HIGH]`` square.
        title: title set on ``ax``.
    """
    side = np.linspace(LOW, HIGH, npts)
    xx, yy = np.meshgrid(side, side)
    z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])

    z = torch.Tensor(z)
    # Detach first so a potential whose output tracks gradients can still
    # be converted to a numpy array.
    u = potential(z).detach().cpu().numpy()
    p = np.exp(-u).reshape(npts, npts)

    # Draw on the axes that was passed in rather than on pyplot's current
    # axes, so the plot and the tick/title settings land on the same axes.
    ax.pcolormesh(xx, yy, p)
    ax.invert_yaxis()
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title(title)
def plt_flow_density(prior_logdensity, inverse_transform, ax, npts=100, memory=100, title="$q(x)$", device="cpu"):
    """Show the model density on a regular grid as an image on ``ax``.

    Every grid point ``x`` is mapped through ``inverse_transform`` and the
    plotted value is ``exp(log p(z) - delta_logp)``.

    Args:
        prior_logdensity: callable giving the prior log-density of ``z``; its
            output is flattened per sample and summed.
        inverse_transform: called as ``inverse_transform(x_chunk, zeros_chunk)``
            and expected to return the pair ``(z, delta_logp)``.
        ax: axes object that receives the image, tick settings and title.
        npts: number of grid points per side of the ``[LOW, HIGH]`` square.
        memory: the grid is processed in chunks of ``memory ** 2`` points.
        title: title set on ``ax``.
        device: device the grid tensor is moved to.
    """
    axis_pts = np.linspace(LOW, HIGH, npts)
    grid_x, grid_y = np.meshgrid(axis_pts, axis_pts)
    grid = np.hstack([grid_x.reshape(-1, 1), grid_y.reshape(-1, 1)])

    x = torch.from_numpy(grid).type(torch.float32).to(device)
    zeros = torch.zeros(x.shape[0], 1).to(x)

    # Run the inverse flow chunk by chunk, then stitch the pieces together.
    chunk_size = int(memory**2)
    all_inds = torch.arange(0, x.shape[0]).to(torch.int64)
    pieces = [inverse_transform(x[ii], zeros[ii]) for ii in torch.split(all_inds, chunk_size)]
    z = torch.cat([z_ for z_, _ in pieces], 0)
    delta_logp = torch.cat([delta_ for _, delta_ in pieces], 0)

    # log p(x) = log p(z) - delta_logp
    logpz = prior_logdensity(z).view(z.shape[0], -1).sum(1, keepdim=True)
    logpx = logpz - delta_logp
    px = np.exp(logpx.cpu().numpy()).reshape(npts, npts)

    ax.imshow(px, cmap='inferno')
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_ticks([])
    ax.set_title(title)
def visualize_transform(
    potential_or_samples, prior_sample, prior_density, transform=None, inverse_transform=None, samples=True, npts=100,
    memory=100, device="cpu"
):
    """Produces visualization for the model density and samples from the model.

    The current figure is cleared and three side-by-side panels are drawn:
    the target (samples or potential), the model density, and model samples.

    Args:
        potential_or_samples: passed to ``plt_samples`` when ``samples`` is
            true, otherwise to ``plt_potential_func``.
        prior_sample: forwarded to ``plt_flow_samples``.
        prior_density: forwarded to ``plt_flow`` / ``plt_flow_density``.
        transform: forward map; used for the density panel when
            ``inverse_transform`` is None and for the samples panel.
        inverse_transform: when given, the density panel is drawn with
            ``plt_flow_density`` instead of ``plt_flow``.
        samples: selects how ``potential_or_samples`` is plotted.
        npts: grid / histogram resolution per side.
        memory: chunking parameter forwarded to the plotting helpers.
        device: device forwarded to the plotting helpers.

    Raises:
        ValueError: if both ``transform`` and ``inverse_transform`` are None,
            since the density panel cannot be drawn without one of them.
    """
    # Fail early with a clear message instead of a "'NoneType' object is not
    # callable" error raised from inside plt_flow after the figure was cleared.
    if transform is None and inverse_transform is None:
        raise ValueError("visualize_transform needs at least one of `transform` or `inverse_transform`")

    plt.clf()
    ax = plt.subplot(1, 3, 1, aspect="equal")
    if samples:
        plt_samples(potential_or_samples, ax, npts=npts)
    else:
        plt_potential_func(potential_or_samples, ax, npts=npts)

    ax = plt.subplot(1, 3, 2, aspect="equal")
    if inverse_transform is None:
        plt_flow(prior_density, transform, ax, npts=npts, device=device)
    else:
        plt_flow_density(prior_density, inverse_transform, ax, npts=npts, memory=memory, device=device)

    # The third panel is always created; it stays empty without a transform.
    ax = plt.subplot(1, 3, 3, aspect="equal")
    if transform is not None:
        plt_flow_samples(prior_sample, transform, ax, npts=npts, memory=memory, device=device)
endpoint=False 27 | linspace4 = np.linspace(0, 2 * np.pi, n_samples4, endpoint=False) 28 | linspace3 = np.linspace(0, 2 * np.pi, n_samples3, endpoint=False) 29 | linspace2 = np.linspace(0, 2 * np.pi, n_samples2, endpoint=False) 30 | linspace1 = np.linspace(0, 2 * np.pi, n_samples1, endpoint=False) 31 | 32 | circ4_x = np.cos(linspace4) 33 | circ4_y = np.sin(linspace4) 34 | circ3_x = np.cos(linspace4) * 0.75 35 | circ3_y = np.sin(linspace3) * 0.75 36 | circ2_x = np.cos(linspace2) * 0.5 37 | circ2_y = np.sin(linspace2) * 0.5 38 | circ1_x = np.cos(linspace1) * 0.25 39 | circ1_y = np.sin(linspace1) * 0.25 40 | 41 | X = np.vstack([ 42 | np.hstack([circ4_x, circ3_x, circ2_x, circ1_x]), 43 | np.hstack([circ4_y, circ3_y, circ2_y, circ1_y]) 44 | ]).T * 3.0 45 | X = util_shuffle(X) 46 | 47 | # Add noise 48 | X = X + np.random.normal(scale=0.08, size=X.shape) 49 | 50 | return X.astype("float32") 51 | 52 | elif data == "moons": 53 | data = sklearn.datasets.make_moons(n_samples=batch_size, noise=0.1)[0] 54 | data = data.astype("float32") 55 | data = data * 2 + np.array([-1, -0.2]) 56 | return data 57 | 58 | elif data == "8gaussians": 59 | scale = 4. 60 | centers = [(1, 0), (-1, 0), (0, 1), (0, -1), (1. / np.sqrt(2), 1. / np.sqrt(2)), 61 | (1. / np.sqrt(2), -1. / np.sqrt(2)), (-1. / np.sqrt(2), 62 | 1. / np.sqrt(2)), (-1. / np.sqrt(2), -1. 
/ np.sqrt(2))] 63 | centers = [(scale * x, scale * y) for x, y in centers] 64 | 65 | dataset = [] 66 | for i in range(batch_size): 67 | point = np.random.randn(2) * 0.5 68 | idx = np.random.randint(8) 69 | center = centers[idx] 70 | point[0] += center[0] 71 | point[1] += center[1] 72 | dataset.append(point) 73 | dataset = np.array(dataset, dtype="float32") 74 | dataset /= 1.414 75 | return dataset 76 | 77 | elif data == "pinwheel": 78 | radial_std = 0.3 79 | tangential_std = 0.1 80 | num_classes = 5 81 | num_per_class = batch_size // 5 82 | rate = 0.25 83 | rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False) 84 | 85 | features = np.random.randn(num_classes*num_per_class, 2) \ 86 | * np.array([radial_std, tangential_std]) 87 | features[:, 0] += 1. 88 | labels = np.repeat(np.arange(num_classes), num_per_class) 89 | 90 | angles = rads[labels] + rate * np.exp(features[:, 0]) 91 | rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)]) 92 | rotations = np.reshape(rotations.T, (-1, 2, 2)) 93 | 94 | return 2 * np.random.permutation(np.einsum("ti,tij->tj", features, rotations)) 95 | 96 | elif data == "2spirals": 97 | n = np.sqrt(np.random.rand(batch_size // 2, 1)) * 540 * (2 * np.pi) / 360 98 | d1x = -np.cos(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 99 | d1y = np.sin(n) * n + np.random.rand(batch_size // 2, 1) * 0.5 100 | x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3 101 | x += np.random.randn(*x.shape) * 0.1 102 | return x 103 | 104 | elif data == "checkerboard": 105 | x1 = np.random.rand(batch_size) * 4 - 2 106 | x2_ = np.random.rand(batch_size) - np.random.randint(0, 2, batch_size) * 2 107 | x2 = x2_ + (np.floor(x1) % 2) 108 | return np.concatenate([x1[:, None], x2[:, None]], 1) * 2 109 | 110 | elif data == "line": 111 | x = np.random.rand(batch_size) * 5 - 2.5 112 | y = x 113 | return np.stack((x, y), 1) 114 | elif data == "cos": 115 | x = np.random.rand(batch_size) * 5 - 2.5 116 | y = np.sin(x) * 
def plt_flow(prior_logdensity, transform, ax, npts=100, title="$q(x)$", device="cpu"):
    """Plot the density obtained by pushing a grid of points through a flow.

    Args:
        prior_logdensity: callable giving the log-density of the grid points;
            its output is summed over dim 1.
        transform: computes z_k and log(q_k) given z_0 and log(q_0).
        ax: axes object that receives the plot, limits, ticks and title.
        npts: number of grid points per side of the ``[LOW, HIGH]`` square.
        title: title set on ``ax``.
        device: device the grid tensor is moved to.
    """
    side = np.linspace(LOW, HIGH, npts)
    xx, yy = np.meshgrid(side, side)
    z = np.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])

    # NOTE(review): requires_grad is kept as in the original, presumably for
    # transforms that differentiate with respect to their input -- confirm.
    z = torch.tensor(z, requires_grad=True).type(torch.float32).to(device)
    logqz = prior_logdensity(z)
    logqz = torch.sum(logqz, dim=1)[:, None]
    z, logqz = transform(z, logqz)
    logqz = torch.sum(logqz, dim=1)[:, None]

    # The input tracks gradients, so the outputs may as well; detach before
    # converting, otherwise .numpy() raises a RuntimeError.
    z = z.detach().cpu().numpy()
    xx = z[:, 0].reshape(npts, npts)
    yy = z[:, 1].reshape(npts, npts)
    qz = np.exp(logqz.detach().cpu().numpy()).reshape(npts, npts)

    # Draw on the axes that was passed in, not on pyplot's current axes.
    ax.pcolormesh(xx, yy, qz)
    ax.set_xlim(LOW, HIGH)
    ax.set_ylim(LOW, HIGH)
    # matplotlib.cm.get_cmap was removed from recent Matplotlib releases;
    # plt.get_cmap() with no name still returns the default colormap.
    cmap = plt.get_cmap()
    # Fill the area not covered by the warped grid with the colormap's lowest colour.
    ax.set_facecolor(cmap(0.))
    ax.invert_yaxis()
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    ax.set_title(title)
class LinearWeightNorm(nn.Module):
    """Linear layer wrapped with weight normalization.

    The wrapped ``nn.Linear`` is stored as ``self.linear``; ``init`` offers a
    data-dependent rescaling of the gain and bias from one batch.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(LinearWeightNorm, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight ~ N(0, 0.05^2) and bias = 0, then apply weight norm."""
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.05)
        if self.linear.bias is not None:
            nn.init.constant_(self.linear.bias, 0)
        self.linear = nn.utils.weight_norm(self.linear)

    def extra_repr(self):
        # The feature sizes and the bias live on the wrapped nn.Linear; this
        # wrapper has no such attributes of its own, so reading them from
        # `self` raised AttributeError whenever the module was printed.
        return 'in_features={}, out_features={}, bias={}'.format(
            self.linear.in_features, self.linear.out_features, self.linear.bias is not None
        )

    def init(self, x, init_scale=1.0):
        """Rescale gain and bias from the statistics of one batch.

        The per-feature mean and standard deviation of the layer output on
        ``x`` are computed; the gain is multiplied by
        ``init_scale / (std + 1e-6)`` and the bias is shifted by ``-mean``
        and scaled by the same factor.

        Args:
            x: input batch passed through the layer.
            init_scale: target scale used in the rescaling factor.

        Returns:
            The layer output on ``x`` after the rescaling.
        """
        with torch.no_grad():
            # [batch, out_features]
            out = self(x).view(-1, self.linear.out_features)
            # [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)

            self.linear.weight_g.mul_(inv_stdv.unsqueeze(1))
            if self.linear.bias is not None:
                self.linear.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        return self.linear(input)
def norm(p: torch.Tensor, dim: int):
    """Return the norm of ``p`` taken over every dimension except ``dim``.

    With ``dim=None`` the norm of the whole tensor is returned. Otherwise the
    result keeps the size of ``p`` along ``dim`` and has size 1 along every
    other dimension.
    """
    if dim is None:
        return p.norm()

    last = p.dim() - 1

    if dim == 0:
        # Flatten everything behind the leading dimension and reduce it.
        kept = p.size(0)
        flat = p.contiguous().view(kept, -1)
        out_shape = (kept,) + (1,) * last
        return flat.norm(dim=1).view(*out_shape)

    if dim == last:
        # Flatten everything in front of the trailing dimension and reduce it.
        kept = p.size(-1)
        flat = p.contiguous().view(-1, kept)
        out_shape = (1,) * last + (kept,)
        return flat.norm(dim=0).view(*out_shape)

    # Interior dimension: swap it to the front, reuse the dim == 0 case,
    # then swap back.
    swapped = p.transpose(0, dim)
    return norm(swapped, 0).transpose(0, dim)
def logPlusOne(x):
    """
    compute log(x + 1), accurately also for small x
    Args:
        x: Tensor

    Returns: Tensor
        log(x+1)

    """
    # torch.log1p is accurate near zero for every input, whereas the previous
    # implementation used a two-term series only for |x| <= 1e-4 and a plain
    # log(1 + x) elsewhere, which loses precision for small x just above that
    # threshold.
    return torch.log1p(x)
def make_positions(tensor, padding_idx):
    """Number the non-padding entries of each row along dim 1.

    Entries equal to ``padding_idx`` map to 0. Every other entry is replaced
    by the running count of non-padding entries up to and including it, so
    numbering starts at 1.
    """
    not_pad = tensor.ne(padding_idx).long()
    running_count = not_pad.cumsum(dim=1)
    # Multiplying by the mask zeroes the counts sitting on padding slots.
    return running_count * not_pad