├── .gitignore
├── images
│   ├── denoise.png
│   ├── smoothing.gif
│   ├── celeba_samples.png
│   ├── cifar10_samples.png
│   ├── mnist_samples.png
│   ├── celeba_noisy_samples.png
│   ├── mnist_noisy_samples.png
│   └── cifar10_noisy_samples.png
├── requirements.txt
├── runners
│   ├── __init__.py
│   ├── pixelcnnpp_gradient_sampler_runner.py
│   ├── pixelcnnpp_sampler_runner.py
│   ├── pixelcnnpp_smoothed_train_runner.py
│   └── pixelcnnpp_conditioned_train_runner.py
├── configs
│   ├── pixelcnnpp_smoothed_train_mnist.yml
│   ├── pixelcnnpp_smoothed_train_celeba.yml
│   ├── pixelcnnpp_smoothed_train_cifar10.yml
│   ├── pixelcnnpp_conditioned_train_celeba.yml
│   ├── pixelcnnpp_conditioned_train_mnist.yml
│   ├── pixelcnnpp_conditioned_train_cifar10.yml
│   ├── pixelcnnpp_smoothed_sample.yml
│   ├── pixelcnnpp_reverse_sample.yml
│   └── pixelcnnpp_gradient_sample.yml
├── models
│   └── ema.py
├── dataset.py
├── README.md
├── main.py
└── pixelcnnpp
    ├── layers.py
    ├── pixelcnnpp.py
    └── samplers.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
*/__pycache__
runs/
.ipynb_checkpoints/
runner.sh
--------------------------------------------------------------------------------
/images/denoise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/denoise.png
--------------------------------------------------------------------------------
/images/smoothing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/smoothing.gif
--------------------------------------------------------------------------------
/images/celeba_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/celeba_samples.png
--------------------------------------------------------------------------------
/images/cifar10_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/cifar10_samples.png
--------------------------------------------------------------------------------
/images/mnist_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/mnist_samples.png
--------------------------------------------------------------------------------
/images/celeba_noisy_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/celeba_noisy_samples.png
--------------------------------------------------------------------------------
/images/mnist_noisy_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/mnist_noisy_samples.png
--------------------------------------------------------------------------------
/images/cifar10_noisy_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/cifar10_noisy_samples.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.4 2 | numpy==1.19.5 3 | PyYAML==5.4.1 4 | tensorboard==2.4.1 5 | torch==1.5.0+cu101 6 | torchvision==0.6.0+cu101 7 | tqdm==4.59.0 -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from runners.pixelcnnpp_smoothed_train_runner import SmoothedPixelCNNPPTrainRunner 2 | from runners.pixelcnnpp_conditioned_train_runner import ReversePixelCNNPPTrainRunner 3 | 4 | from runners.pixelcnnpp_sampler_runner import PixelCNNPPSamplerRunner 5 | from runners.pixelcnnpp_gradient_sampler_runner import PixelCNNPPGradientSamplerRunner 6 | 7 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_smoothed_train_mnist.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | noise: 0.5 3 | 4 | model: 5 | ema: False 6 | ema_rate: 0.999 7 | 8 | ## MNIST 9 | data_dir: runs/mnist 10 | dataset: MNIST 11 | nr_resnet: 5 12 | nr_filters: 40 13 | nr_logistic_mix: 10 14 | input_channels: 1 15 | 16 | 17 | batch_size: 80 18 | lr: 0.0002 19 | lr_decay: 0.999995 20 | max_epochs: 1000 21 | 22 | #### logging 23 | print_every: 50 24 | save_interval: 10 25 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_smoothed_train_celeba.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | noise: 0.3 3 | 4 | model: 5 | ema: False 6 | ema_rate: 0.999 7 | 8 | ## celeba 9 | data_dir: runs/celeba 10 | dataset: celeba 11 | nr_resnet: 5 12 | nr_filters: 160 13 | nr_logistic_mix: 10 14 | input_channels: 3 15 | 16 | 17 | batch_size: 80 18 | lr: 0.0002 19 | lr_decay: 0.999995 20 | max_epochs: 1000 21 | 22 | #### logging 23 | print_every: 50 24 | save_interval: 10 25 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_smoothed_train_cifar10.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | noise: 0.3 3 | 4 | model: 5 | ema: False 6 | ema_rate: 0.999 7 | 8 | ## CIFAR-10 9 | data_dir: runs/cifar10 10 | dataset: CIFAR10 11 | nr_resnet: 5 12 | nr_filters: 160 13 | nr_logistic_mix: 10 14 | input_channels: 3 15 | 16 | batch_size: 80 17 | lr: 0.0002 18 | lr_decay: 0.999995 19 | max_epochs: 1000 20 | 21 | #### logging 22 | print_every: 50 23 | save_interval: 10 24 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_conditioned_train_celeba.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | noise: 0.3 3 | clean_noise: 0.01 4 | with_logit: False 5 | 6 | model: 7 | ema: False 8 | ema_rate: 0.999 9 | 10 | ## celeba 11 | data_dir: runs/celeba 12 | dataset: celeba 13 | nr_resnet: 5 14 | nr_filters: 160 15 | nr_logistic_mix: 10 16 | input_channels: 3 17 | 18 | batch_size: 50 19 | lr: 0.0002 20 | lr_decay: 0.999995 21 | max_epochs: 1000 22 | 23 | #### logging 24 | print_every: 50 25 | save_interval: 10 26 | 
-------------------------------------------------------------------------------- /configs/pixelcnnpp_conditioned_train_mnist.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | noise: 0.5 3 | clean_noise: 0.01 4 | with_logit: False 5 | 6 | model: 7 | ema: False 8 | ema_rate: 0.999 9 | 10 | ## MNIST 11 | data_dir: runs/mnist 12 | dataset: MNIST 13 | nr_resnet: 5 14 | nr_filters: 40 15 | nr_logistic_mix: 10 16 | input_channels: 1 17 | 18 | 19 | batch_size: 50 20 | lr: 0.0002 21 | lr_decay: 0.999995 22 | max_epochs: 1000 23 | 24 | #### logging 25 | print_every: 50 26 | save_interval: 10 27 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_conditioned_train_cifar10.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | noise: 0.3 3 | clean_noise: 0.01 4 | with_logit: False 5 | 6 | model: 7 | ema: False 8 | ema_rate: 0.999 9 | 10 | ## CIFAR-10 11 | data_dir: runs/cifar10 12 | dataset: CIFAR10 13 | nr_resnet: 5 14 | nr_filters: 160 15 | nr_logistic_mix: 10 16 | input_channels: 3 17 | 18 | 19 | batch_size: 50 20 | lr: 0.0002 21 | lr_decay: 0.999995 22 | max_epochs: 1000 23 | 24 | #### logging 25 | print_every: 50 26 | save_interval: 10 27 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_smoothed_sample.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | with_logit: False 3 | reverse_sampling: False 4 | ema: False 5 | ema_rate: 0.999 6 | ckpt_path: runs/logs/cifar10_smoothed_0.3/pixelcnn_ckpts/checkpoint.pth 7 | 8 | ### CIFAR-10 9 | data_dir: runs/cifar10 10 | dataset: CIFAR10 11 | nr_resnet: 5 12 | nr_filters: 160 13 | nr_logistic_mix: 10 14 | input_channels: 3 15 | 16 | ### MNIST 17 | #data_dir: runs/mnist 18 | #dataset: MNIST 19 | #nr_resnet: 5 20 | #nr_filters: 40 21 | #nr_logistic_mix: 10 22 | #input_channels: 1 23 | 24 | ### celeba 25 | #data_dir: runs/celeba 26 | #dataset: celeba 27 | #nr_resnet: 5 28 | #nr_filters: 160 #80 29 | #nr_logistic_mix: 10 30 | #input_channels: 3 31 | 32 | batch_size: 200 33 | iteration: 1 34 | lr: 0.0002 35 | lr_decay: 0.999995 36 | max_epochs: 1000 37 | 38 | #### logging 39 | print_every: 50 40 | save_interval: 10 41 | -------------------------------------------------------------------------------- /configs/pixelcnnpp_reverse_sample.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | with_logit: False 3 | reverse_sampling: True 4 | 5 | noisy_samples_path: "runs/logs/cifar10_0.3_images/samples/samples_CIFAR10.pth" 6 | ema: False 7 | ema_rate: 0.999 8 | ckpt_path: runs/logs/reverse_cifar10_0.3/pixelcnn_ckpts/checkpoint.pth 9 | 10 | ### CIFAR-10 11 | data_dir: runs/cifar10 12 | dataset: CIFAR10 13 | nr_resnet: 5 14 | nr_filters: 160 15 | nr_logistic_mix: 10 16 | input_channels: 3 17 | 18 | ### MNIST 19 | #data_dir: runs/mnist 20 | #dataset: MNIST 21 | #nr_resnet: 5 22 | #nr_filters: 40 23 | #nr_logistic_mix: 10 24 | #input_channels: 1 25 | 26 | ### celeba 27 | #data_dir: runs/celeba 28 | #dataset: celeba 29 | #nr_resnet: 5 30 | #nr_filters: 160 #80 31 | #nr_logistic_mix: 10 32 | #input_channels: 3 33 | 34 | batch_size: 200 35 | iteration: 1 36 | lr: 0.0002 37 | lr_decay: 0.999995 38 | max_epochs: 1000 39 | 40 | #### logging 41 | print_every: 50 42 | save_interval: 10 43 | 
-------------------------------------------------------------------------------- /configs/pixelcnnpp_gradient_sample.yml: -------------------------------------------------------------------------------- 1 | #### training: 2 | with_logit: False #True 3 | ema: False #True 4 | ema_rate: 0.999 5 | 6 | noisy_samples_path: "runs/logs/cifar10_0.3_images/samples/samples_CIFAR10.pth" 7 | ckpt_path: runs/logs/cifar10_smoothed_0.3/pixelcnn_ckpts/checkpoint.pth 8 | 9 | # CIFAR-10 10 | dataset: CIFAR10 11 | nr_resnet: 5 12 | nr_filters: 160 13 | nr_logistic_mix: 10 14 | input_channels: 3 15 | noise: 0.3 16 | 17 | 18 | #### MNIST 19 | #data_dir: runs/mnist 20 | #dataset: MNIST 21 | #nr_resnet: 5 22 | #nr_filters: 40 23 | #nr_logistic_mix: 10 24 | #input_channels: 1 25 | #noise: 0.5 26 | 27 | ### celeba 28 | #data_dir: runs/celeba 29 | #dataset: celeba 30 | #nr_resnet: 5 31 | #nr_filters: 160 32 | #nr_logistic_mix: 10 33 | #input_channels: 3 34 | #noise: 0.3 35 | 36 | 37 | batch_size: 100 38 | iteration: 1 39 | lr: 0.0002 40 | lr_decay: 0.999995 41 | max_epochs: 1000 42 | 43 | #### logging 44 | print_every: 50 45 | save_interval: 10 46 | -------------------------------------------------------------------------------- /models/ema.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class EMAHelper(object): 5 | def __init__(self, mu=0.999): 6 | self.mu = mu 7 | self.shadow = {} 8 | 9 | def register(self, module): 10 | if isinstance(module, nn.DataParallel): 11 | module = module.module 12 | for name, param in module.named_parameters(): 13 | if param.requires_grad: 14 | self.shadow[name] = param.data.clone() 15 | 16 | def update(self, module): 17 | if isinstance(module, nn.DataParallel): 18 | module = module.module 19 | for name, param in module.named_parameters(): 20 | if param.requires_grad: 21 | self.shadow[name].data = ( 22 | 1. - self.mu) * param.data + self.mu * self.shadow[name].data 23 | 24 | def ema(self, module): 25 | if isinstance(module, nn.DataParallel): 26 | module = module.module 27 | for name, param in module.named_parameters(): 28 | if param.requires_grad: 29 | param.data.copy_(self.shadow[name].data) 30 | 31 | def ema_copy(self, module): 32 | if isinstance(module, nn.DataParallel): 33 | inner_module = module.module 34 | module_copy = type(inner_module)( 35 | inner_module.config).to(inner_module.config.device) 36 | module_copy.load_state_dict(inner_module.state_dict()) 37 | module_copy = nn.DataParallel(module_copy) 38 | else: 39 | module_copy = type(module)(module.config).to(module.config.device) 40 | module_copy.load_state_dict(module.state_dict()) 41 | # module_copy = copy.deepcopy(module) 42 | self.ema(module_copy) 43 | return module_copy 44 | 45 | def state_dict(self): 46 | return self.shadow 47 | 48 | def load_state_dict(self, state_dict): 49 | self.shadow = state_dict -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | 6 | import numpy as np 7 | from torch.utils.data import DataLoader, Subset 8 | from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, ImageFolder 9 | 10 | 11 | def get_dataset(config): 12 | kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True} 13 | rescaling = lambda x: (x - .5) * 2. 
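    # ToTensor yields values in [0, 1]; `rescaling` maps them to [-1, 1], the
    # range the discretized mixture-of-logistics loss assumes.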
14 | ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling]) 15 | 16 | if config.dataset == 'MNIST': 17 | train_loader = torch.utils.data.DataLoader( 18 | datasets.MNIST(os.path.join('runs', 'datasets', 'MNIST'), download=True, 19 | train=True, transform=ds_transforms), 20 | batch_size=config.batch_size, 21 | shuffle=True, **kwargs) 22 | 23 | test_loader = torch.utils.data.DataLoader(datasets.MNIST(config.data_dir, train=False, download=True, 24 | transform=ds_transforms), batch_size=config.batch_size, 25 | shuffle=False, **kwargs) 26 | 27 | elif config.dataset == 'FashionMNIST': 28 | train_loader = torch.utils.data.DataLoader( 29 | datasets.FashionMNIST(os.path.join('runs', 'datasets', 'FashionMNIST'), download=True, 30 | train=True, transform=ds_transforms), 31 | batch_size=config.batch_size, 32 | shuffle=True, **kwargs) 33 | 34 | test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(config.data_dir, train=False, download=True, 35 | transform=ds_transforms), batch_size=config.batch_size, 36 | shuffle=False, **kwargs) 37 | 38 | 39 | elif 'CIFAR10' in config.dataset: 40 | 41 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(config.data_dir, train=True, 42 | download=True, transform=ds_transforms), 43 | batch_size=config.batch_size, shuffle=True, **kwargs) 44 | 45 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(config.data_dir, train=False, download=True, 46 | transform=ds_transforms), 47 | batch_size=config.batch_size, 48 | shuffle=False, **kwargs) 49 | 50 | elif "celeba" in config.dataset: 51 | dataset = ImageFolder( 52 | root=os.path.join('/atlas/u/yangsong/sliced_score_matching/run', 'datasets', 'celeba'), 53 | transform=transforms.Compose([ 54 | transforms.CenterCrop(140), 55 | transforms.Resize(32), 56 | transforms.ToTensor(), 57 | rescaling 58 | ])) 59 | num_items = len(dataset) # 202599 60 | indices = list(range(num_items)) 61 | random_state = np.random.get_state() 62 | np.random.seed(2020) 63 | np.random.shuffle(indices) 64 | np.random.set_state(random_state) 65 | train_indices, test_indices = indices[:int(num_items * 0.7)], indices[ 66 | int(num_items * 0.7):int(num_items * 0.8)] 67 | test_dataset = Subset(dataset, test_indices) 68 | dataset = Subset(dataset, train_indices) 69 | 70 | train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, **kwargs) 71 | test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True, **kwargs) 72 | 73 | return train_loader, test_loader 74 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improved Autoregressive Modeling with Distribution Smoothing 2 | 3 | This repo contains the implementation for the paper Improved Autoregressive Modeling with Distribution Smoothing 4 | 5 | by [Chenlin Meng](https://cs.stanford.edu/~chenlin/), [Jiaming Song](http://tsong.me), [Yang Song](http://yang-song.github.io/), [Shengjia Zhao](http://szhao.me/) and [Stefano Ermon](https://cs.stanford.edu/~ermon/), Stanford AI Lab. 6 | 7 |
<p align="center"><img src="images/smoothing.gif" /></p>

<p align="center"><img src="images/denoise.png" /></p>
## Running Experiments

### Dependencies

Run the following to install all the Python packages our code requires.

```bash
pip install -r requirements.txt
```

### Stage 1: Learning the smoothed distribution
<p align="center"><img src="images/mnist_noisy_samples.png" /> <img src="images/cifar10_noisy_samples.png" /> <img src="images/celeba_noisy_samples.png" /></p>
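Stage 1 fits PixelCNN++ to a smoothed version of the data: every training batch is perturbed with Gaussian noise of standard deviation `noise` (set in the config) before the usual mixture-of-logistics likelihood is computed. A minimal sketch of one training step, reusing the names defined in `runners/pixelcnnpp_smoothed_train_runner.py` (`model`, `optimizer`, `config`, `mix_logistic_loss`):

```python
# x: a training batch already rescaled to [-1, 1]
noisy = x + torch.randn_like(x) * config.noise  # draw from the smoothed distribution
output = model(noisy)
loss = mix_logistic_loss(noisy, output)         # NLL of the *noisy* image
optimizer.zero_grad()
loss.backward()
optimizer.step()
```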
To train the PixelCNN++ model on the smoothed distribution for CIFAR-10, run:
```
python main.py --runner SmoothedPixelCNNPPTrainRunner --config pixelcnnpp_smoothed_train_cifar10.yml --doc cifar10_smoothed_0.3 --ni
```
The analogous configs for MNIST and CelebA are pixelcnnpp_smoothed_train_mnist.yml and pixelcnnpp_smoothed_train_celeba.yml.

### Stage 2: Reverse smoothing
<p align="center"><img src="images/mnist_samples.png" /> <img src="images/cifar10_samples.png" /> <img src="images/celeba_samples.png" /></p>
To reverse the smoothing process, we train a second PixelCNN++ model that generates the clean image conditioned on its smoothed (noisy) version.
To train the model on CIFAR-10, run:

```
python main.py --runner ReversePixelCNNPPTrainRunner --config pixelcnnpp_conditioned_train_cifar10.yml --doc reverse_cifar10_0.3 --ni
```


### Sampling
Sampling from stage 1:

**pixelcnnpp_smoothed_sample.yml** needs to be modified first.

**ckpt_path**: path to the model trained on the smoothed data in stage 1.

The **dataset** parameter might also need to be changed accordingly; the options are MNIST, CIFAR10, or celeba.
```
python main.py --runner PixelCNNPPSamplerRunner --config pixelcnnpp_smoothed_sample.yml --doc cifar10_0.3_images
```

Sampling from stage 2:

**pixelcnnpp_reverse_sample.yml** needs to be modified first.

**noisy_samples_path**: path to the noisy samples generated by the stage-1 model in the previous step.

**ckpt_path**: path to the reverse-smoothing model trained in stage 2.

The **dataset** parameter might need to be changed accordingly; the options are MNIST, CIFAR10, or celeba.
```
python main.py --runner PixelCNNPPSamplerRunner --config pixelcnnpp_reverse_sample.yml --doc cifar10_denoise_images
```

A single-step, gradient-based denoising alternative is described at the end of this README.

## References and Acknowledgements
```
@article{meng2021improved,
  title={Improved Autoregressive Modeling with Distribution Smoothing},
  author={Meng, Chenlin and Song, Jiaming and Song, Yang and Zhao, Shengjia and Ermon, Stefano},
  journal={arXiv preprint arXiv:2103.15089},
  year={2021}
}
```

This implementation is based on / inspired by:

- [https://github.com/pclucas14/pixel-cnn-pp](https://github.com/pclucas14/pixel-cnn-pp) (PixelCNN++ PyTorch implementation), and
- [https://github.com/ermongroup/ncsnv2](https://github.com/ermongroup/ncsnv2) (code structure).
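### Single-step gradient-based denoising (optional)

Besides the stage-2 autoregressive reversal, `runners/pixelcnnpp_gradient_sampler_runner.py` can denoise the stage-1 samples in a single step: it differentiates the smoothed model's log-likelihood to obtain the score and applies `x <- x + noise^2 * grad_x log p(x)`. Set **noisy_samples_path** and the stage-1 **ckpt_path** in **pixelcnnpp_gradient_sample.yml**; a command of the same form as the other runners should work (the `--doc` name here is only an example):
```
python main.py --runner PixelCNNPPGradientSamplerRunner --config pixelcnnpp_gradient_sample.yml --doc cifar10_gradient_denoised
```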
94 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import traceback 4 | import time 5 | import shutil 6 | import logging 7 | import yaml 8 | import sys 9 | import os 10 | import torch 11 | import numpy as np 12 | from runners import * 13 | 14 | 15 | def parse_args_and_config(): 16 | parser = argparse.ArgumentParser(description=globals()['__doc__']) 17 | 18 | parser.add_argument('--runner', type=str, default='SmoothedPixelCNNPPTrainRunner', help='The runner to execute') 19 | parser.add_argument('--config', type=str, default='pixelcnnpp_smoothed_train.yml', help='Path to the config file') 20 | parser.add_argument('--seed', type=int, default=1234, help='Random seed') 21 | parser.add_argument('--run', type=str, default='runs', help='Path for saving running related data.') 22 | parser.add_argument('--doc', type=str, default='0', help='A string for documentation purpose') 23 | parser.add_argument('--comment', type=str, default='', help='A string for experiment comment') 24 | parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical') 25 | parser.add_argument('--test', action='store_true', help='Whether to test the model') 26 | parser.add_argument('--resume_training', action='store_true', help='Whether to resume training') 27 | parser.add_argument('-i', '--image_folder', type=str, default='images', help="The directory of image outputs") 28 | parser.add_argument('--ni', action='store_true', help="No interaction. Suitable for Slurm Job launcher") 29 | 30 | args = parser.parse_args() 31 | args.log = os.path.join(args.run, 'logs', args.doc) 32 | 33 | # parse config file 34 | with open(os.path.join('configs', args.config), 'r') as f: 35 | config = yaml.load(f, Loader=yaml.FullLoader) 36 | new_config = dict2namespace(config) 37 | 38 | if not args.test: 39 | if not args.resume_training and not (args.runner=='PixelCNNPPSamplerRunner') and not (args.runner=='PixelCNNPP_ELBO_Runner'): 40 | if os.path.exists(args.log): 41 | if args.ni: 42 | shutil.rmtree(args.log) 43 | else: 44 | answer = input("Log folder already exists. Overwrite? 
(Y/n)\n") 45 | if answer.lower() == 'n': 46 | sys.exit(0) 47 | else: 48 | shutil.rmtree(args.log) 49 | 50 | os.makedirs(args.log, exist_ok=True) 51 | with open(os.path.join(args.log, 'config.yml'), 'w') as f: 52 | yaml.dump(vars(new_config), f, default_flow_style=False) 53 | 54 | # setup logger 55 | level = getattr(logging, args.verbose.upper(), None) 56 | if not isinstance(level, int): 57 | raise ValueError('level {} not supported'.format(args.verbose)) 58 | 59 | handler1 = logging.StreamHandler() 60 | handler2 = logging.FileHandler(os.path.join(args.log, 'stdout.txt')) 61 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 62 | handler1.setFormatter(formatter) 63 | handler2.setFormatter(formatter) 64 | logger = logging.getLogger() 65 | logger.addHandler(handler1) 66 | logger.addHandler(handler2) 67 | logger.setLevel(level) 68 | 69 | else: 70 | level = getattr(logging, args.verbose.upper(), None) 71 | if not isinstance(level, int): 72 | raise ValueError('level {} not supported'.format(args.verbose)) 73 | 74 | handler1 = logging.StreamHandler() 75 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s') 76 | handler1.setFormatter(formatter) 77 | logger = logging.getLogger() 78 | logger.addHandler(handler1) 79 | logger.setLevel(level) 80 | 81 | # add device 82 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 83 | # device = torch.device('cpu') 84 | logging.info("Using device: {}".format(device)) 85 | new_config.device = device 86 | 87 | # set random seed 88 | torch.manual_seed(args.seed) 89 | np.random.seed(args.seed) 90 | if torch.cuda.is_available(): 91 | torch.cuda.manual_seed_all(args.seed) 92 | 93 | return args, new_config 94 | 95 | 96 | def dict2namespace(config): 97 | namespace = argparse.Namespace() 98 | for key, value in config.items(): 99 | if isinstance(value, dict): 100 | new_value = dict2namespace(value) 101 | else: 102 | new_value = value 103 | setattr(namespace, key, new_value) 104 | return namespace 105 | 106 | 107 | def main(): 108 | args, config = parse_args_and_config() 109 | logging.info("Writing log file to {}".format(args.log)) 110 | logging.info("Exp instance id = {}".format(os.getpid())) 111 | logging.info("Exp comment = {}".format(args.comment)) 112 | logging.info("Config =") 113 | print(">" * 80) 114 | print(yaml.dump(vars(config), default_flow_style=False)) 115 | print("<" * 80) 116 | 117 | try: 118 | runner = eval(args.runner)(args, config) 119 | if not args.test: 120 | runner.train() 121 | else: 122 | runner.test() 123 | except: 124 | logging.error(traceback.format_exc()) 125 | 126 | return 0 127 | 128 | 129 | if __name__ == '__main__': 130 | sys.exit(main()) 131 | -------------------------------------------------------------------------------- /runners/pixelcnnpp_gradient_sampler_runner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torchvision.utils as utils 4 | from pixelcnnpp.pixelcnnpp import (PixelCNN, 5 | load_part_of_model, 6 | mix_logistic_loss_1d, 7 | mix_logistic_loss, 8 | ) 9 | 10 | from functools import partial 11 | import dataset 12 | from pixelcnnpp.samplers import * 13 | from torchvision.utils import save_image 14 | import shutil 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | import numbers 18 | import math 19 | import matplotlib.pyplot as plt 20 | import time 21 | 22 | import torchvision 23 | import torch.autograd as autograd 24 | 25 | 
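# Note: this runner denoises in a single step rather than autoregressively.
# Given noisy samples x produced by the stage-1 (smoothed) model, sample() below
# differentiates the model's mixture-of-logistics log-likelihood to obtain the
# score grad_x log p(x) and applies one update x <- x + noise**2 * score.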
26 | 27 | class PixelCNNPPGradientSamplerRunner(object): 28 | def __init__(self, args, config): 29 | self.args = args 30 | self.config = config 31 | 32 | def logit_transform(self, image, lambd=1e-6): 33 | image = .5 * image + .5 34 | image = lambd + (1 - 2 * lambd) * image 35 | latent_image = torch.log(image) - torch.log1p(-image) 36 | ll = F.softplus(-latent_image).sum() + F.softplus(latent_image).sum() + np.prod( 37 | image.shape) * (np.log(1 - 2 * lambd) + np.log(.5)) 38 | nll = -ll 39 | return latent_image, nll 40 | 41 | def train(self): 42 | assert not self.config.ema, "ema sampling is not supported now" 43 | self.load_pixelcnnpp() 44 | self.sample() 45 | 46 | def sample(self): 47 | sample_batch_size = self.config.batch_size 48 | self.ar_model.eval() 49 | model = partial(self.ar_model, sample=True) 50 | 51 | rescaling_inv = lambda x: .5 * x + .5 52 | rescaling = lambda x: (x - .5) * 2. 53 | if self.config.dataset == 'CIFAR10' or self.config.dataset == 'celeba': 54 | x = torch.zeros(sample_batch_size, 3, 32, 32, device=self.config.device) 55 | clamp = False 56 | bisection_iter = 20 57 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF, model=model, 58 | nr_mix=self.config.nr_logistic_mix, clamp=clamp, 59 | bisection_iter=bisection_iter) 60 | 61 | elif 'MNIST' in self.config.dataset: 62 | x = torch.zeros(sample_batch_size, 1, 28, 28, device=self.config.device) 63 | clamp = False 64 | bisection_iter = 30 65 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF_1d, model=model, 66 | nr_mix=self.config.nr_logistic_mix, clamp=clamp, 67 | bisection_iter=bisection_iter) 68 | 69 | # if os.path.exists(os.path.join(self.args.log, 'images')): 70 | # shutil.rmtree(os.path.join(self.args.log, 'images')) 71 | # 72 | # if os.path.exists(os.path.join(self.args.log, 'samples')): 73 | # shutil.rmtree(os.path.join(self.args.log, 'samples')) 74 | 75 | os.makedirs(os.path.join(self.args.log, 'images'), exist_ok=True) 76 | os.makedirs(os.path.join(self.args.log, 'samples'), exist_ok=True) 77 | 78 | 79 | def sigmoid_transform(samples, lambd=1e-6): 80 | samples = torch.sigmoid(samples) 81 | samples = (samples - lambd) / (1 - 2 * lambd) 82 | return samples 83 | 84 | import pickle 85 | import torch.autograd as autograd 86 | 87 | noisy = torch.load(self.config.noisy_samples_path) 88 | if self.config.with_logit is True: 89 | torchvision.utils.save_image(sigmoid_transform(noisy), os.path.join(self.args.log, 'images', "noisy_samples.png")) 90 | else: 91 | torchvision.utils.save_image(rescaling_inv(noisy), os.path.join(self.args.log, 'images', "noisy_samples.png")) 92 | print(noisy.shape) 93 | # Sampling from current model 94 | 95 | images_array = [] 96 | for it in range(self.config.iteration): 97 | print("{}/{}".format(it, self.config.iteration)) 98 | x = noisy[it * self.config.batch_size: (it + 1) * self.config.batch_size] 99 | # x.requires_grad_(True) 100 | # output = model(x) #.detach() 101 | # # x.requires_grad_(True) 102 | 103 | output = model(x).detach() 104 | x.requires_grad_(True) 105 | 106 | if x.shape[1] == 1: 107 | log_pdf = mix_logistic_loss_1d(x, output, likelihood=True) 108 | else: 109 | log_pdf = mix_logistic_loss(x, output, likelihood=True) 110 | 111 | score = autograd.grad(log_pdf.sum(), x, create_graph=True)[0] 112 | x = x + self.config.noise ** 2 * score 113 | x = x.detach().data 114 | 115 | if it == 0: 116 | if self.config.with_logit is True: 117 | images_concat = torchvision.utils.make_grid(sigmoid_transform(x), nrow=int(x.shape[0] ** 0.5), 118 | 
padding=0, pad_value=0) 119 | else: 120 | images_concat = torchvision.utils.make_grid(rescaling_inv(x)[:, :, -x.shape[-1]:, :], nrow=int(x.shape[0] ** 0.5), 121 | padding=0, pad_value=0) 122 | torchvision.utils.save_image(images_concat, os.path.join(self.args.log, 'images', "gradient_denoised_samples.png")) 123 | images_array.append(x.data.cpu()) 124 | del(score) 125 | 126 | torch.save(torch.cat(images_array, dim=0), os.path.join(self.args.log, 'samples', "gradient_denoised_samples_{}.pkl".format(self.config.dataset))) 127 | 128 | torch.save(torch.cat(images_array, dim=0), 129 | os.path.join(self.args.log, 'samples', "gradient_denoised_samples_{}.pkl".format(self.config.dataset))) 130 | 131 | 132 | def load_pixelcnnpp(self): 133 | def load_parallel(path=self.config.ckpt_path, loc=self.config.device): 134 | checkpoint = torch.load(path, map_location=loc)[0] 135 | state_dict = checkpoint 136 | for k in list(state_dict.keys()): 137 | if not k.startswith('module.'): 138 | # # remove prefix 139 | state_dict["module."+k] = state_dict[k] 140 | # # delete renamed or unused k 141 | del state_dict[k] 142 | return state_dict 143 | 144 | obs = (1, 56, 28) if 'MNIST' in self.config.dataset else (3, 32, 32) 145 | input_channels = obs[0] 146 | 147 | model = PixelCNN(self.config) 148 | model = model.to(self.config.device) 149 | model = torch.nn.DataParallel(model) 150 | 151 | back_compat = False 152 | if back_compat: 153 | load_part_of_model(model, self.config.ckpt_path, back_compat=True) 154 | else: 155 | model.load_state_dict(torch.load(self.config.ckpt_path, map_location=self.config.device)[0]) 156 | 157 | print('model parameters loaded') 158 | self.ar_model = model 159 | 160 | def test(self): 161 | pass 162 | -------------------------------------------------------------------------------- /runners/pixelcnnpp_sampler_runner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import torchvision.utils as utils 4 | from pixelcnnpp.pixelcnnpp import (PixelCNN, load_part_of_model) 5 | 6 | from functools import partial 7 | import dataset 8 | from pixelcnnpp.samplers import * 9 | from torchvision.utils import save_image 10 | import shutil 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import numbers 14 | import math 15 | import matplotlib.pyplot as plt 16 | import time 17 | 18 | import torchvision 19 | import torch.autograd as autograd 20 | 21 | 22 | class PixelCNNPPSamplerRunner(object): 23 | def __init__(self, args, config): 24 | self.args = args 25 | self.config = config 26 | 27 | def logit_transform(self, image, lambd=1e-6): 28 | image = .5 * image + .5 29 | image = lambd + (1 - 2 * lambd) * image 30 | latent_image = torch.log(image) - torch.log1p(-image) 31 | ll = F.softplus(-latent_image).sum() + F.softplus(latent_image).sum() + np.prod( 32 | image.shape) * (np.log(1 - 2 * lambd) + np.log(.5)) 33 | nll = -ll 34 | return latent_image, nll 35 | 36 | def train(self): 37 | assert not self.config.ema, "ema sampling is not supported now" 38 | self.load_pixelcnnpp() 39 | self.sample() 40 | 41 | def sample(self): 42 | sample_batch_size = self.config.batch_size 43 | self.ar_model.eval() 44 | model = partial(self.ar_model, sample=True) 45 | 46 | rescaling_inv = lambda x: .5 * x + .5 47 | rescaling = lambda x: (x - .5) * 2. 
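        # Pixels are drawn by inverting the logistic-mixture CDF with bisection
        # (inverse-CDF sampling); the original discretized PixelCNN++ sampler is
        # kept in the commented-out lines below.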
48 | if self.config.dataset == 'CIFAR10' or self.config.dataset == 'celeba': 49 | x = torch.zeros(sample_batch_size, 3, 32, 32, device=self.config.device) 50 | clamp = False 51 | bisection_iter = 20 52 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF, model=model, 53 | nr_mix=self.config.nr_logistic_mix, clamp=clamp, 54 | bisection_iter=bisection_iter) 55 | # basic_sampler = lambda x: sample_from_discretized_mix_logistic(x, model, 56 | # self.config.nr_logistic_mix, 57 | # clamp=clamp) 58 | 59 | elif 'MNIST' in self.config.dataset: 60 | x = torch.zeros(sample_batch_size, 1, 28, 28, device=self.config.device) 61 | clamp = False 62 | bisection_iter = 30 63 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF_1d, model=model, 64 | nr_mix=self.config.nr_logistic_mix, clamp=clamp, 65 | bisection_iter=bisection_iter) 66 | # basic_sampler = lambda x: sample_from_discretized_mix_logistic_1d(x, model, self.config.nr_logistic_mix, 67 | # clamp=clamp) 68 | 69 | os.makedirs(os.path.join(self.args.log, 'images'), exist_ok=True) 70 | os.makedirs(os.path.join(self.args.log, 'samples'), exist_ok=True) 71 | 72 | def sigmoid_transform(samples, lambd=1e-6): 73 | samples = torch.sigmoid(samples) 74 | samples = (samples - lambd) / (1 - 2 * lambd) 75 | return samples 76 | 77 | if self.config.reverse_sampling: 78 | noisy = torch.load(self.config.noisy_samples_path) 79 | 80 | if self.config.with_logit is True: 81 | images_concat = torchvision.utils.make_grid(sigmoid_transform(noisy), nrow=int(self.config.batch_size ** 0.5), padding=0, 82 | pad_value=0) 83 | else: 84 | images_concat = torchvision.utils.make_grid(rescaling_inv(noisy), nrow=int(self.config.batch_size ** 0.5), padding=0, 85 | pad_value=0) 86 | torchvision.utils.save_image(images_concat, os.path.join(self.args.log, 'images', "original_noisy.png")) 87 | 88 | # Sampling from current model 89 | with torch.no_grad(): 90 | images_array = [] 91 | with torch.no_grad(): 92 | for it in range(self.config.iteration): 93 | if self.config.reverse_sampling: 94 | x = noisy[it * self.config.batch_size: (it+1) * self.config.batch_size] 95 | x = torch.cat([x, x], dim=2) 96 | else: 97 | x = torch.randn_like(x) 98 | 99 | x = x.cuda(non_blocking=True) 100 | for i in range(-x.shape[-1], 0, 1): 101 | for j in range(x.shape[-1]): 102 | print(it, i, j, flush=True) 103 | samples = basic_sampler(x) 104 | x[:, :, i, j] = samples[:, :, i, j] 105 | 106 | if it == 0: 107 | if self.config.with_logit is True: 108 | images_concat = torchvision.utils.make_grid(sigmoid_transform(x), nrow=int(x.shape[0] ** 0.5), 109 | padding=0, pad_value=0) 110 | else: 111 | images_concat = torchvision.utils.make_grid(rescaling_inv(x)[:, :, -x.shape[-1]:, :], nrow=int(x.shape[0] ** 0.5), 112 | padding=0, pad_value=0) 113 | torchvision.utils.save_image(images_concat, os.path.join(self.args.log, 'images', "samples.png")) 114 | images_array.append(x) 115 | torch.save(torch.cat(images_array, dim=0), os.path.join(self.args.log, 'samples', "samples_{}.pth".format(self.config.dataset))) 116 | 117 | torch.save(torch.cat(images_array, dim=0), os.path.join(self.args.log, 'samples', "samples_{}.pth".format(self.config.dataset))) 118 | 119 | 120 | def load_pixelcnnpp(self): 121 | def load_parallel(path=self.config.ckpt_path, loc=self.config.device): 122 | checkpoint = torch.load(path, map_location=loc)[0] 123 | state_dict = checkpoint 124 | for k in list(state_dict.keys()): 125 | if not k.startswith('module.'): 126 | # remove prefix 127 | state_dict["module."+k] = 
state_dict[k] 128 | # delete renamed or unused k 129 | del state_dict[k] 130 | return state_dict 131 | 132 | obs = (1, 56, 28) if 'MNIST' in self.config.dataset else (3, 32, 32) 133 | input_channels = obs[0] 134 | 135 | model = PixelCNN(self.config) 136 | model = model.to(self.config.device) 137 | model = torch.nn.DataParallel(model) 138 | 139 | back_compat = False 140 | if back_compat: 141 | load_part_of_model(model, self.config.ckpt_path, back_compat=True) 142 | else: 143 | model.load_state_dict(torch.load(self.config.ckpt_path, map_location=self.config.device)[0]) 144 | 145 | print('model parameters loaded') 146 | self.ar_model = model 147 | 148 | def test(self): 149 | pass 150 | -------------------------------------------------------------------------------- /runners/pixelcnnpp_smoothed_train_runner.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import torch.optim as optim 4 | import torch.optim.lr_scheduler as lr_scheduler 5 | import torchvision.utils as utils 6 | from pixelcnnpp.pixelcnnpp import (PixelCNN, mix_logistic_loss, mix_logistic_loss_1d) 7 | 8 | from functools import partial 9 | from torch.utils.tensorboard import SummaryWriter 10 | import dataset 11 | import torch.nn.functional as F 12 | 13 | from models.ema import EMAHelper 14 | from pixelcnnpp.samplers import * 15 | 16 | 17 | class SmoothedPixelCNNPPTrainRunner(object): 18 | def __init__(self, args, config): 19 | self.args = args 20 | self.config = config 21 | 22 | def train(self): 23 | obs = (1, 28, 28) if 'MNIST' in self.config.dataset else (3, 32, 32) 24 | input_channels = obs[0] 25 | train_loader, test_loader = dataset.get_dataset(self.config) 26 | 27 | model = PixelCNN(self.config) 28 | model = model.to(self.config.device) 29 | model = torch.nn.DataParallel(model) 30 | sample_model = partial(model, sample=True) 31 | 32 | rescaling_inv = lambda x: .5 * x + .5 33 | rescaling = lambda x: (x - .5) * 2. 
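        # config.noise is the smoothing level sigma: every batch below is
        # perturbed with N(0, sigma^2) noise, so the model fits the smoothed
        # density (the data distribution convolved with a Gaussian).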
        if 'MNIST' in self.config.dataset:
            loss_op = lambda real, fake: mix_logistic_loss_1d(real, fake)
            clamp = False
            sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, sample_model, self.config.nr_logistic_mix, clamp=clamp)

        elif 'CIFAR10' in self.config.dataset:
            loss_op = lambda real, fake: mix_logistic_loss(real, fake)
            clamp = False
            sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix, clamp=clamp)

        elif 'celeba' in self.config.dataset:
            loss_op = lambda real, fake: mix_logistic_loss(real, fake)
            clamp = False
            sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix, clamp=clamp)

        else:
            raise Exception('{} dataset not in {mnist, cifar10, celeba}'.format(self.config.dataset))

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        optimizer = optim.Adam(model.parameters(), lr=self.config.lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.config.lr_decay)

        ckpt_path = os.path.join(self.args.log, 'pixelcnn_ckpts')
        if not os.path.exists(ckpt_path):
            os.makedirs(ckpt_path)

        if self.args.resume_training:
            state_dict = torch.load(os.path.join(ckpt_path, 'checkpoint.pth'), map_location=self.config.device)
            model.load_state_dict(state_dict[0])
            optimizer.load_state_dict(state_dict[1])
            scheduler.load_state_dict(state_dict[2])
            if len(state_dict) > 3:
                epoch = state_dict[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(state_dict[4])
            print('model parameters loaded')

        tb_path = os.path.join(self.args.log, 'tensorboard')
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        os.makedirs(tb_path)
        tb_logger = SummaryWriter(log_dir=tb_path)

        def debug_sample(model, data):
            model.eval()
            data = data.cuda()
            with torch.no_grad():
                for i in range(obs[1]):
                    for j in range(obs[2]):
                        data_v = data
                        out_sample = sample_op(data_v)
                        data[:, :, i, j] = out_sample.data[:, :, i, j]
            return data

        print('starting training', flush=True)
        writes = 0
        for epoch in range(self.config.max_epochs):
            train_loss = 0.
            model.train()
            for batch_idx, (input, _) in enumerate(train_loader):
                input = input.cuda(non_blocking=True)
                # input: [-1, 1]
                ## add noise to the entire image
                input = input + torch.randn_like(input) * self.config.noise
                output = model(input)
                loss = loss_op(input, output)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if self.config.model.ema:
                    ema_helper.update(model)

                train_loss += loss.item()
                if (batch_idx + 1) % self.config.print_every == 0:
                    deno = self.config.print_every * self.config.batch_size * np.prod(obs) * np.log(2.)
                    train_loss = train_loss / deno
                    print('epoch: {}, batch: {}, loss : {:.4f}'.format(epoch, batch_idx, train_loss), flush=True)
                    tb_logger.add_scalar('loss', train_loss, global_step=writes)
                    train_loss = 0.
                    writes += 1

            # decrease learning rate
            scheduler.step()

            if self.config.model.ema:
                test_model = ema_helper.ema_copy(model)
            else:
                test_model = model

            test_model.eval()
            test_loss = 0.
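            # Evaluate on the test set perturbed with the same sigma; `deno`
            # below converts the accumulated loss from nats to bits per dimension.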
            with torch.no_grad():
                for batch_idx, (input_var, _) in enumerate(test_loader):
                    input_var = input_var.cuda(non_blocking=True)

                    input_var = input_var + torch.randn_like(input_var) * self.config.noise
                    output = test_model(input_var)
                    loss = loss_op(input_var, output)
                    test_loss += loss.item()
                    del loss, output

            deno = batch_idx * self.config.batch_size * np.prod(obs) * np.log(2.)
            test_loss = test_loss / deno
            print('epoch: %s, test loss : %s' % (epoch, test_loss), flush=True)
            tb_logger.add_scalar('test_loss', test_loss, global_step=writes)

            if (epoch + 1) % self.config.save_interval == 0:
                state_dict = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    scheduler.state_dict(),
                    epoch,
                ]
                if self.config.model.ema:
                    state_dict.append(ema_helper.state_dict())

                if (epoch + 1) % (self.config.save_interval * 2) == 0:
                    torch.save(state_dict, os.path.join(ckpt_path, f'ckpt_epoch_{epoch}.pth'))
                torch.save(state_dict, os.path.join(ckpt_path, 'checkpoint.pth'))

            if epoch % 10 == 0:
                print('sampling...', flush=True)
                sample_t = debug_sample(test_model, input_var[:25])
                sample_t = rescaling_inv(sample_t)

                if not os.path.exists(os.path.join(self.args.log, 'images')):
                    os.makedirs(os.path.join(self.args.log, 'images'))
                utils.save_image(sample_t, os.path.join(self.args.log, 'images', f'sample_epoch_{epoch}.png'),
                                 nrow=5, padding=0)

            if self.config.model.ema:
                del test_model

    def test(self):
        raise NotImplementedError()
--------------------------------------------------------------------------------
/pixelcnnpp/layers.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm as wn
import torch.nn.functional as F


def concat_elu(x):
    """ like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU """
    # Pytorch ordering
    axis = len(x.size()) - 3
    return F.elu(torch.cat([x, -x], dim=axis))


def log_sum_exp(x):
    """ numerically stable log_sum_exp implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))


def down_shift(x, pad=None):
    # Pytorch ordering
    xs = [int(y) for y in x.size()]
    # when downshifting, the last row is removed
    x = x[:, :, :xs[2] - 1, :]
    # padding left, padding right, padding top, padding bottom
    pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad
    return pad(x)


def right_shift(x, pad=None):
    # Pytorch ordering
    xs = [int(y) for y in x.size()]
    # when rightshifting, the last column is removed
    x = x[:, :, :, :xs[3] - 1]
    # padding left, padding right, padding top, padding bottom
    pad = nn.ZeroPad2d((1, 0, 0, 0)) if pad is None else pad
    return pad(x)


def log_prob_from_logits(x):
    """ numerically stable log_softmax implementation that prevents overflow """
    # TF ordering
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis, keepdim=True)
    return x - m - torch.log(torch.sum(torch.exp(x - m), dim=axis, keepdim=True))


def to_one_hot(tensor, n, fill_with=1.):
    # we perform one-hot encoding with respect to the last axis
    one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
    if tensor.is_cuda: one_hot = one_hot.cuda()
    one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
    return one_hot


class nin(nn.Module):
    def __init__(self, dim_in, dim_out):
        super(nin, self).__init__()
        self.lin_a = wn(nn.Linear(dim_in, dim_out))
        self.dim_out = dim_out

    def forward(self, x):
        og_x = x
        # assumes pytorch ordering
        """ a network in network layer (1x1 CONV) """
        # TODO : try with original ordering
        x = x.permute(0, 2, 3, 1)
        shp = [int(y) for y in x.size()]
        out = self.lin_a(x.contiguous().view(shp[0] * shp[1] * shp[2], shp[3]))
        shp[-1] = self.dim_out
        out = out.view(shp)
        return out.permute(0, 3, 1, 2)


class down_shifted_conv2d(nn.Module):
    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 3), stride=(1, 1),
                 shift_output_down=False, norm='weight_norm'):
        super(down_shifted_conv2d, self).__init__()

        assert norm in [None, 'batch_norm', 'weight_norm']
        self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride)
        self.shift_output_down = shift_output_down
        self.norm = norm
        self.pad = nn.ZeroPad2d((int((filter_size[1] - 1) / 2),  # pad left
                                 int((filter_size[1] - 1) / 2),  # pad right
                                 filter_size[0] - 1,             # pad top
                                 0))                             # pad down

        if norm == 'weight_norm':
            self.conv = wn(self.conv)
        elif norm == 'batch_norm':
            self.bn = nn.BatchNorm2d(num_filters_out)

        if shift_output_down:
            self.down_shift = lambda x: down_shift(x, pad=nn.ZeroPad2d((0, 0, 1, 0)))

    def forward(self, x):
        x = self.pad(x)
        x = self.conv(x)
        x = self.bn(x) if self.norm == 'batch_norm' else x
        return self.down_shift(x) if self.shift_output_down else x


class down_shifted_deconv2d(nn.Module):
    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 3), stride=(1, 1)):
        super(down_shifted_deconv2d, self).__init__()
        self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size, stride,
                                            output_padding=1))
        self.filter_size = filter_size
        self.stride = stride

    def forward(self, x):
        x = self.deconv(x)
        xs = [int(y) for y in x.size()]
        return x[:, :, :(xs[2] - self.filter_size[0] + 1),
                 int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))]


class down_right_shifted_conv2d(nn.Module):
    def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 2), stride=(1, 1),
                 shift_output_right=False, norm='weight_norm'):
        super(down_right_shifted_conv2d, self).__init__()

        assert norm in [None, 'batch_norm', 'weight_norm']
        self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0] - 1, 0))
        self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride=stride)
        self.shift_output_right = shift_output_right
        self.norm = norm

        if norm == 'weight_norm':
            self.conv = wn(self.conv)
        elif norm == 'batch_norm':
            self.bn = nn.BatchNorm2d(num_filters_out)

        if shift_output_right:
            self.right_shift = lambda x: right_shift(x, pad=nn.ZeroPad2d((1, 0, 0, 0)))

    def forward(self, x):
        x = self.pad(x)
        x = self.conv(x)
        x = self.bn(x) if self.norm == 'batch_norm' else x
return self.right_shift(x) if self.shift_output_right else x 146 | 147 | 148 | class down_right_shifted_deconv2d(nn.Module): 149 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 2), stride=(1, 1), 150 | shift_output_right=False): 151 | super(down_right_shifted_deconv2d, self).__init__() 152 | self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size, 153 | stride, output_padding=1)) 154 | self.filter_size = filter_size 155 | self.stride = stride 156 | 157 | def forward(self, x): 158 | x = self.deconv(x) 159 | xs = [int(y) for y in x.size()] 160 | x = x[:, :, :(xs[2] - self.filter_size[0] + 1):, :(xs[3] - self.filter_size[1] + 1)] 161 | return x 162 | 163 | 164 | ''' 165 | skip connection parameter : 0 = no skip connection 166 | 1 = skip connection where skip input size === input size 167 | 2 = skip connection where skip input size === 2 * input size 168 | ''' 169 | 170 | 171 | class gated_resnet(nn.Module): 172 | def __init__(self, num_filters, conv_op, nonlinearity=concat_elu, skip_connection=0): 173 | super(gated_resnet, self).__init__() 174 | self.skip_connection = skip_connection 175 | self.nonlinearity = nonlinearity 176 | self.conv_input = conv_op(2 * num_filters, num_filters) # cuz of concat elu 177 | 178 | if skip_connection != 0: 179 | self.nin_skip = nin(2 * skip_connection * num_filters, num_filters) 180 | 181 | self.dropout = nn.Dropout2d(0.5) 182 | self.conv_out = conv_op(2 * num_filters, 2 * num_filters) 183 | 184 | def forward(self, og_x, a=None): 185 | x = self.conv_input(self.nonlinearity(og_x)) 186 | if a is not None: 187 | x += self.nin_skip(self.nonlinearity(a)) 188 | x = self.nonlinearity(x) 189 | x = self.dropout(x) 190 | x = self.conv_out(x) 191 | a, b = torch.chunk(x, 2, dim=1) 192 | c3 = a * torch.sigmoid(b) 193 | return og_x + c3 194 | -------------------------------------------------------------------------------- /runners/pixelcnnpp_conditioned_train_runner.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import torch.optim as optim 4 | import torch.optim.lr_scheduler as lr_scheduler 5 | import torchvision.utils as utils 6 | from pixelcnnpp.pixelcnnpp import (PixelCNN, mix_logistic_loss, mix_logistic_loss_1d) 7 | 8 | from functools import partial 9 | from torch.utils.tensorboard import SummaryWriter 10 | import dataset 11 | import torch.nn.functional as F 12 | from models.ema import EMAHelper 13 | from pixelcnnpp.samplers import * 14 | 15 | 16 | class ReversePixelCNNPPTrainRunner(object): 17 | def __init__(self, args, config): 18 | self.args = args 19 | self.config = config 20 | 21 | def train(self): 22 | obs = (1, 28, 28) if 'MNIST' in self.config.dataset else (3, 32, 32) 23 | input_channels = obs[0] 24 | train_loader, test_loader = dataset.get_dataset(self.config) 25 | model = PixelCNN(self.config) 26 | model = model.to(self.config.device) 27 | model = torch.nn.DataParallel(model) 28 | sample_model = partial(model, sample=True) 29 | 30 | rescaling_inv = lambda x: .5 * x + .5 31 | rescaling = lambda x: (x - .5) * 2. 
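        # Reverse (denoising) model: the noisy image is stacked on top of a
        # lightly-noised "clean" copy along the height axis (torch.cat(..., dim=2)
        # below), so the autoregressive model sees the entire noisy image before
        # predicting each clean pixel; the loss covers the bottom (clean) half only.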
        if 'MNIST' in self.config.dataset:
            loss_op = lambda real, fake: mix_logistic_loss_1d(real, fake)
            clamp = False
            sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, sample_model, self.config.nr_logistic_mix,
                                                                          clamp=clamp)

        elif 'CIFAR10' in self.config.dataset:
            loss_op = lambda real, fake: mix_logistic_loss(real, fake)
            clamp = False
            sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix,
                                                                       clamp=clamp)

        elif 'celeba' in self.config.dataset:
            loss_op = lambda real, fake: mix_logistic_loss(real, fake)
            clamp = False
            sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix,
                                                                       clamp=clamp)
        else:
            raise Exception('{} dataset not in {mnist, cifar10, celeba}'.format(self.config.dataset))

        if self.config.model.ema:
            ema_helper = EMAHelper(mu=self.config.model.ema_rate)
            ema_helper.register(model)
        else:
            ema_helper = None

        optimizer = optim.Adam(model.parameters(), lr=self.config.lr)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.config.lr_decay)

        ckpt_path = os.path.join(self.args.log, 'pixelcnn_ckpts')
        if not os.path.exists(ckpt_path):
            os.makedirs(ckpt_path)

        if self.args.resume_training:
            state_dict = torch.load(os.path.join(ckpt_path, 'checkpoint.pth'), map_location=self.config.device)
            model.load_state_dict(state_dict[0])
            optimizer.load_state_dict(state_dict[1])
            scheduler.load_state_dict(state_dict[2])
            if len(state_dict) > 3:
                epoch = state_dict[3]
            if self.config.model.ema:
                ema_helper.load_state_dict(state_dict[4])
            print('model parameters loaded')

        tb_path = os.path.join(self.args.log, 'tensorboard')
        if os.path.exists(tb_path):
            shutil.rmtree(tb_path)

        os.makedirs(tb_path)
        tb_logger = SummaryWriter(log_dir=tb_path)

        def debug_sample(model, noisy_image):
            model.eval()
            with torch.no_grad():
                data = torch.cat([noisy_image, noisy_image], dim=2)
                for i in range(obs[1], obs[1] * 2, 1):
                    for j in range(obs[2]):
                        data_v = data
                        out_sample = sample_op(data_v)
                        data[:, :, i, j] = out_sample.data[:, :, i, j]
            return data

        print('starting training', flush=True)
        writes = 0
        for epoch in range(self.config.max_epochs):
            train_loss = 0.
            model.train()
            for batch_idx, (input, _) in enumerate(train_loader):
                input = input.cuda(non_blocking=True)

                # input: [-1, 1]
                ## add noise to the entire image
                noisy_input = input + torch.randn_like(input) * self.config.noise
                clean_input = input + torch.randn_like(input) * self.config.clean_noise  # add very small noise
                input = torch.cat([noisy_input, clean_input], dim=2)
                output = model(input)[:, :, input.shape[-1]:, :]
                loss = loss_op(clean_input, output)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if self.config.model.ema:
                    ema_helper.update(model)

                train_loss += loss.item()
                if (batch_idx + 1) % self.config.print_every == 0:
                    deno = self.config.print_every * self.config.batch_size * np.prod(obs) * np.log(2.)
                    train_loss = train_loss / deno
                    print('epoch: {}, batch: {}, loss : {:.4f}'.format(epoch, batch_idx, train_loss), flush=True)
                    tb_logger.add_scalar('loss', train_loss, global_step=writes)
                    train_loss = 0.
                    writes += 1
            # decrease learning rate
            scheduler.step()

            if self.config.model.ema:
                test_model = ema_helper.ema_copy(model)
            else:
                test_model = model

            test_model.eval()
            test_loss = 0.
            with torch.no_grad():
                for batch_idx, (input_var, _) in enumerate(test_loader):
                    input_var = input_var.cuda(non_blocking=True)

                    noisy_input_var = input_var + torch.randn_like(input_var) * self.config.noise
                    clean_input_var = input_var + torch.randn_like(input_var) * self.config.clean_noise  #* 0.02

                    input_var = torch.cat([noisy_input_var, clean_input_var], dim=2)
                    output = test_model(input_var)[:, :, input_var.shape[-1]:, :]
                    loss = loss_op(clean_input_var, output)
                    test_loss += loss.item()
                    del loss, output

            deno = batch_idx * self.config.batch_size * np.prod(obs) * np.log(2.)
            test_loss = test_loss / deno
            print('epoch: %s, test loss : %s' % (epoch, test_loss), flush=True)
            tb_logger.add_scalar('test_loss', test_loss, global_step=writes)

            if (epoch + 1) % self.config.save_interval == 0:
                state_dict = [
                    model.state_dict(),
                    optimizer.state_dict(),
                    scheduler.state_dict(),
                    epoch,
                ]
                if self.config.model.ema:
                    state_dict.append(ema_helper.state_dict())

                if (epoch + 1) % (self.config.save_interval * 2) == 0:
                    torch.save(state_dict, os.path.join(ckpt_path, f'ckpt_epoch_{epoch}.pth'))
                torch.save(state_dict, os.path.join(ckpt_path, 'checkpoint.pth'))

            if epoch % 10 == 0:
                print('sampling...', flush=True)
                sample_t = debug_sample(test_model, noisy_input_var[:25])
                sample_t = torch.cat([clean_input_var[:25], sample_t], dim=2)  # add original sample
                if self.config.with_logit is True:
                    lambd = 1e-6
                    sample_t = (torch.sigmoid(sample_t) - lambd) / (1 - 2 * lambd)  # invert the logit transform
                else:
                    sample_t = rescaling_inv(sample_t)

                if not os.path.exists(os.path.join(self.args.log, 'images')):
                    os.makedirs(os.path.join(self.args.log, 'images'))
                utils.save_image(sample_t, os.path.join(self.args.log, 'images', f'sample_epoch_{epoch}.png'),
                                 nrow=5, padding=0)
            if self.config.model.ema:
                del test_model

    def test(self):
        raise NotImplementedError()
--------------------------------------------------------------------------------
/pixelcnnpp/pixelcnnpp.py:
--------------------------------------------------------------------------------
from pixelcnnpp.layers import *
from torch.autograd import Variable
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F


class PixelCNNLayer_up(nn.Module):
    def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity):
        super(PixelCNNLayer_up, self).__init__()
        self.nr_resnet = nr_resnet
        # stream from pixels above
        self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d,
                                                    resnet_nonlinearity, skip_connection=0)
                                       for _ in range(nr_resnet)])

        # stream from pixels above and to the left
        self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d,
                                                     resnet_nonlinearity, skip_connection=1)
                                        for _ in range(nr_resnet)])

    def forward(self, u, ul):
        u_list, ul_list = [], []

        for i in range(self.nr_resnet):
            u = self.u_stream[i](u)
            ul = self.ul_stream[i](ul, a=u)
            u_list += [u]
            ul_list += [ul]

        return u_list, ul_list


35 | class PixelCNNLayer_down(nn.Module):
36 |     def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity):
37 |         super(PixelCNNLayer_down, self).__init__()
38 |         self.nr_resnet = nr_resnet
39 |         # stream from pixels above
40 |         self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d,
41 |                                                     resnet_nonlinearity, skip_connection=1)
42 |                                        for _ in range(nr_resnet)])
43 | 
44 |         # stream from pixels above and to the left
45 |         self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d,
46 |                                                      resnet_nonlinearity, skip_connection=2)
47 |                                         for _ in range(nr_resnet)])
48 | 
49 |     def forward(self, u, ul, u_list, ul_list):
50 |         for i in range(self.nr_resnet):
51 |             u = self.u_stream[i](u, a=u_list.pop())
52 |             ul = self.ul_stream[i](ul, a=torch.cat((u, ul_list.pop()), 1))
53 | 
54 |         return u, ul
55 | 
56 | 
57 | class PixelCNN(nn.Module):
58 |     def __init__(self, config):
59 |         super(PixelCNN, self).__init__()
60 |         self.config = config
61 |         nr_resnet = self.config.nr_resnet
62 |         nr_filters = self.config.nr_filters
63 |         input_channels = self.config.input_channels
64 |         nr_logistic_mix = self.config.nr_logistic_mix
65 |         resnet_nonlinearity = 'concat_elu'
66 | 
67 |         if resnet_nonlinearity == 'concat_elu':
68 |             self.resnet_nonlinearity = concat_elu
69 |         else:
70 |             raise Exception('right now only concat_elu is supported as the resnet nonlinearity.')
71 | 
72 |         self.nr_filters = nr_filters
73 |         self.input_channels = input_channels
74 |         self.nr_logistic_mix = nr_logistic_mix
75 |         self.right_shift_pad = nn.ZeroPad2d((1, 0, 0, 0))
76 |         self.down_shift_pad = nn.ZeroPad2d((0, 0, 1, 0))
77 | 
78 |         down_nr_resnet = [nr_resnet] + [nr_resnet + 1] * 2
79 |         self.down_layers = nn.ModuleList([PixelCNNLayer_down(down_nr_resnet[i], nr_filters,
80 |                                                              self.resnet_nonlinearity) for i in range(3)])
81 | 
82 |         self.up_layers = nn.ModuleList([PixelCNNLayer_up(nr_resnet, nr_filters,
83 |                                                          self.resnet_nonlinearity) for _ in range(3)])
84 | 
85 |         self.downsize_u_stream = nn.ModuleList([down_shifted_conv2d(nr_filters, nr_filters,
86 |                                                                     stride=(2, 2)) for _ in range(2)])
87 | 
88 |         self.downsize_ul_stream = nn.ModuleList([down_right_shifted_conv2d(nr_filters,
89 |                                                                            nr_filters, stride=(2, 2)) for _ in
90 |                                                  range(2)])
91 | 
92 |         self.upsize_u_stream = nn.ModuleList([down_shifted_deconv2d(nr_filters, nr_filters,
93 |                                                                     stride=(2, 2)) for _ in range(2)])
94 | 
95 |         self.upsize_ul_stream = nn.ModuleList([down_right_shifted_deconv2d(nr_filters,
96 |                                                                            nr_filters, stride=(2, 2)) for _ in
97 |                                                range(2)])
98 | 
99 |         self.u_init = down_shifted_conv2d(input_channels + 1, nr_filters, filter_size=(2, 3),
100 |                                           shift_output_down=True)
101 | 
102 |         self.ul_init = nn.ModuleList([down_shifted_conv2d(input_channels + 1, nr_filters,
103 |                                                           filter_size=(1, 3), shift_output_down=True),
104 |                                       down_right_shifted_conv2d(input_channels + 1, nr_filters,
105 |                                                                 filter_size=(2, 1), shift_output_right=True)])
106 | 
107 |         num_mix = 3 if self.input_channels == 1 else 10
108 |         self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix)
109 |         self.init_padding = None
110 | 
111 |     def forward(self, x, sample=False):
112 |         # similar to what is done in the tf repo:
113 |         if self.init_padding is None and not sample:
114 |             xs = [int(y) for y in x.size()]
115 |             padding = Variable(torch.ones(xs[0], 1, xs[2], xs[3]), requires_grad=False)
116 |             self.init_padding = padding.cuda() if x.is_cuda else padding
117 | 
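        # `u_init`/`ul_init` take input_channels + 1 because an all-ones plane
        # is concatenated to the input below; as in the original tf repo, the
        # extra channel lets the shifted convolutions distinguish the image
        # content from the zero padding introduced at the border.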
118 |         if sample:
119 |             xs = [int(y) for y in x.size()]
120 |             padding = Variable(torch.ones(xs[0], 1, xs[2], xs[3]), requires_grad=False)
121 |             padding = padding.cuda() if x.is_cuda else padding
122 |             x = torch.cat((x, padding), 1)
123 | 
124 |         ### UP PASS ###
125 |         x = x if sample else torch.cat((x, self.init_padding), 1)
126 |         u_list = [self.u_init(x)]
127 |         ul_list = [self.ul_init[0](x) + self.ul_init[1](x)]
128 |         for i in range(3):
129 |             # resnet block
130 |             u_out, ul_out = self.up_layers[i](u_list[-1], ul_list[-1])
131 |             u_list += u_out
132 |             ul_list += ul_out
133 | 
134 |             if i != 2:
135 |                 # downscale (only twice)
136 |                 u_list += [self.downsize_u_stream[i](u_list[-1])]
137 |                 ul_list += [self.downsize_ul_stream[i](ul_list[-1])]
138 | 
139 |         ### DOWN PASS ###
140 |         u = u_list.pop()
141 |         ul = ul_list.pop()
142 | 
143 |         for i in range(3):
144 |             # resnet block
145 |             u, ul = self.down_layers[i](u, ul, u_list, ul_list)
146 | 
147 |             # upscale (only twice)
148 |             if i != 2:
149 |                 u = self.upsize_u_stream[i](u)
150 |                 ul = self.upsize_ul_stream[i](ul)
151 | 
152 |         x_out = self.nin_out(F.elu(ul))
153 | 
154 |         assert len(u_list) == len(ul_list) == 0, 'up-pass activations were not fully consumed by the down pass'
155 | 
156 |         return x_out
157 | 
158 | 
159 | 
160 | 
161 | 
162 | def mix_logistic_loss(x, l, likelihood=False):
163 |     """ log-likelihood for a continuous mixture of logistics, assumes the data has been rescaled to the [-1,1] interval """
164 |     # Pytorch ordering
165 |     x = x.permute(0, 2, 3, 1)
166 |     l = l.permute(0, 2, 3, 1)
167 |     xs = [int(y) for y in x.size()]
168 |     ls = [int(y) for y in l.size()]
169 | 
170 |     # here and below: unpacking the params of the mixture of logistics
171 |     nr_mix = int(ls[-1] / 10)
172 |     logit_probs = l[:, :, :, :nr_mix]
173 |     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])  # 3 for mean, scale, coef
174 |     means = l[:, :, :, :, :nr_mix]
175 |     log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
176 | 
177 |     coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
178 |     # here and below: getting the means and adjusting them based on preceding
179 |     # sub-pixels
180 |     x = x.contiguous()
181 |     x = x.unsqueeze(-1) + torch.zeros(xs + [nr_mix], device=x.device)
182 |     m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :]
183 |           * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)
184 | 
185 |     m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
186 |           coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)
187 | 
188 |     means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
189 |     centered_x = x - means
190 |     inv_stdv = torch.exp(-log_scales)
191 |     mid_in = inv_stdv * centered_x
192 |     log_probs = mid_in - log_scales - 2. * F.softplus(mid_in)
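    # per-component log-density of a continuous logistic evaluated at mid_in:
    # with z = (x - mu) / s, log pdf = -z - log s - 2*softplus(-z), which equals
    # the z - log s - 2*softplus(z) used above, since -z - 2*softplus(-z) = z - 2*softplus(z)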
193 | 
194 |     log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
195 |     if likelihood:
196 |         # per-pixel log-likelihoods
197 |         return log_sum_exp(log_probs)
198 | 
199 |     # negative log-likelihood, summed over the batch
200 |     return -torch.sum(log_sum_exp(log_probs))
201 | 
202 | 
203 | def mix_logistic_loss_1d(x, l, likelihood=False):
204 |     """ log-likelihood for a continuous mixture of logistics, assumes the data has been rescaled to the [-1,1] interval """
205 |     # Pytorch ordering
206 |     x = x.permute(0, 2, 3, 1)
207 |     l = l.permute(0, 2, 3, 1)
208 |     xs = [int(y) for y in x.size()]
209 |     ls = [int(y) for y in l.size()]
210 | 
211 |     # here and below: unpacking the params of the mixture of logistics
212 |     nr_mix = int(ls[-1] / 3)
213 |     logit_probs = l[:, :, :, :nr_mix]
214 |     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])  # 2 for mean, scale
215 |     means = l[:, :, :, :, :nr_mix]
216 |     log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
217 |     # here and below: getting the means and adjusting them based on preceding
218 |     # sub-pixels
219 |     x = x.contiguous()
220 |     x = x.unsqueeze(-1) + torch.zeros(xs + [nr_mix], device=x.device)
221 | 
222 |     centered_x = x - means
223 |     inv_stdv = torch.exp(-log_scales)
224 |     mid_in = inv_stdv * centered_x
225 |     log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
226 | 
227 |     log_probs = torch.sum(log_pdf_mid, dim=3) + log_prob_from_logits(logit_probs)
228 |     if likelihood:
229 |         # per-pixel log-likelihoods
230 |         return log_sum_exp(log_probs)
231 | 
232 |     # negative log-likelihood, summed over the batch
233 |     return -torch.sum(log_sum_exp(log_probs))
234 | 
235 | 
236 | def load_part_of_model(model, path, back_compat=False):
237 |     params = torch.load(path)
238 |     added = 0
239 |     for name, param in params.items():
240 |         if back_compat:
241 |             name = '.'.join(name.split('.')[1:])
242 |         if name in model.state_dict().keys():
243 |             try:
244 |                 model.state_dict()[name].copy_(param)
245 |                 added += 1
246 |             except Exception as e:
247 |                 print(e)
248 | 
249 |     print('loaded %.1f%% of the model parameters' % (100. * added / len(model.state_dict())))
--------------------------------------------------------------------------------
/pixelcnnpp/samplers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from itertools import product
4 | import tqdm
5 | import torch.autograd as autograd
6 | from torch.nn import functional as F
7 | from pixelcnnpp.layers import to_one_hot
8 | 
9 | 
10 | def sample_from_discretized_mix_logistic_inverse_CDF(x, model, nr_mix, noise=[], u=None, clamp=True, bisection_iter=15, T=1):
11 |     # Pytorch ordering
12 |     l = model(x)
13 |     l = l.permute(0, 2, 3, 1)
14 |     ls = [int(y) for y in l.size()]
15 |     xs = ls[:-1] + [3]
16 | 
17 |     # optional per-pixel noise that was added to the conditioning image,
18 |     # permuted from NCHW to NHWC to match l
19 |     if len(noise) != 0:
20 |         noise = noise.permute(0, 2, 3, 1)
21 | 
22 |     # unpack parameters; T is a temperature on the mixture logits
23 |     logit_probs = l[:, :, :, :nr_mix] / T
24 |     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
25 |     # draw one uniform variate per channel for inverse-CDF sampling
26 |     if u is None:
27 |         u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2] * 3)
28 |         u.uniform_(1e-5, 1. - 1e-5)
29 |     u = torch.log(u) - torch.log(1. - u)
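    # binary_search below compares against the *logit* of the mixture CDF, so
    # the uniform variate is mapped to logit space: solving
    # logit(CDF(x)) = logit(u) is the same as solving CDF(x) = u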
30 | 
31 |     u_r, u_g, u_b = torch.chunk(u, chunks=3, dim=-1)
32 | 
33 |     u_r = u_r.reshape(ls[:-1])
34 |     u_g = u_g.reshape(ls[:-1])
35 |     u_b = u_b.reshape(ls[:-1])
36 | 
37 |     log_softmax = torch.log_softmax(logit_probs, dim=-1)
38 |     coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix: 3 * nr_mix])
39 |     means = l[:, :, :, :, :nr_mix]
40 |     log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) + np.log(T)
41 |     if clamp:
42 |         ubs = l.new_ones(ls[:-1])  # bisection stays inside the data range [-1, 1]
43 |         lbs = -ubs
44 |     else:
45 |         ubs = l.new_ones(ls[:-1]) * 20.  # generous bounds for unclamped samples
46 |         lbs = -ubs
47 | 
48 |     means_r = means[..., 0, :]
49 |     log_scales_r = log_scales[..., 0, :]
50 | 
51 |     def log_cdf_pdf_r(values, mode='cdf', mixtures=False):
52 |         values = values.unsqueeze(-1)
53 |         centered_values = (values - means_r) / log_scales_r.exp()
54 | 
55 |         if mode == 'cdf':
56 |             log_logistic_cdf = -F.softplus(-centered_values)
57 |             log_logistic_sf = -F.softplus(centered_values)
58 |             log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
59 |             log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
60 |             logit = log_cdf - log_sf
61 | 
62 |             return logit if not mixtures else (logit, log_logistic_cdf)
63 | 
64 |         elif mode == 'pdf':
65 |             log_logistic_pdf = -centered_values - log_scales_r - 2. * F.softplus(-centered_values)
66 |             log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
67 | 
68 |             return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
69 | 
70 |     x0 = binary_search(u_r, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_r(x, mode='cdf'), bisection_iter)
71 | 
72 |     if len(noise) == 0:
73 |         means_g = x0.unsqueeze(-1) * coeffs[:, :, :, 0, :] + means[..., 1, :]
74 |     else:
75 |         means_g = (x0.unsqueeze(-1) + noise[:, :, :, 0].unsqueeze(-1)) * coeffs[:, :, :, 0, :] + means[..., 1, :]
76 | 
77 |     means_g = means_g.detach()  # stop gradients here so autograd-based sampling stays correct
78 |     log_scales_g = log_scales[..., 1, :]
79 | 
80 |     log_p_r, log_p_r_mixtures = log_cdf_pdf_r(x0, mode='pdf', mixtures=True)
81 | 
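    # sampling green conditions on the red value just drawn: the mixture
    # weights are reweighted by each component's responsibility for x0
    # (log_p_r_mixtures - log_p_r), so log_cdf_pdf_g below evaluates the
    # conditional CDF/pdf of the green channel given red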
82 |     def log_cdf_pdf_g(values, mode='cdf', mixtures=False):
83 |         values = values.unsqueeze(-1)
84 |         centered_values = (values - means_g) / log_scales_g.exp()
85 | 
86 |         if mode == 'cdf':
87 |             log_logistic_cdf = log_p_r_mixtures - log_p_r[..., None] - F.softplus(-centered_values)
88 |             log_logistic_sf = log_p_r_mixtures - log_p_r[..., None] - F.softplus(centered_values)
89 |             log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
90 |             log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
91 |             logit = log_cdf - log_sf
92 | 
93 |             return logit if not mixtures else (logit, log_logistic_cdf)
94 | 
95 |         elif mode == 'pdf':
96 |             log_logistic_pdf = log_p_r_mixtures - log_p_r[..., None] - centered_values - log_scales_g - 2. * F.softplus(
97 |                 -centered_values)
98 |             log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
99 | 
100 |             return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
101 | 
102 |     x1 = binary_search(u_g, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_g(x, mode='cdf'), bisection_iter)
103 | 
104 |     if len(noise) == 0:
105 |         means_b = x1.unsqueeze(-1) * coeffs[:, :, :, 2, :] + x0.unsqueeze(-1) * coeffs[:, :, :, 1, :] + means[..., 2, :]
106 |     else:
107 |         means_b = (x1.unsqueeze(-1) + noise[:, :, :, 1].unsqueeze(-1)) * coeffs[:, :, :, 2, :] + \
108 |                   (x0.unsqueeze(-1) + noise[:, :, :, 0].unsqueeze(-1)) * coeffs[:, :, :, 1, :] + means[..., 2, :]
109 | 
110 |     means_b = means_b.detach()  # stop gradients here so autograd-based sampling stays correct
111 |     log_scales_b = log_scales[..., 2, :]
112 | 
113 |     log_p_g, log_p_g_mixtures = log_cdf_pdf_g(x1, mode='pdf', mixtures=True)
114 | 
115 |     def log_cdf_pdf_b(values, mode='cdf', mixtures=False):
116 |         values = values.unsqueeze(-1)
117 |         centered_values = (values - means_b) / log_scales_b.exp()
118 | 
119 |         if mode == 'cdf':
120 |             log_logistic_cdf = log_p_g_mixtures - log_p_g[..., None] - F.softplus(-centered_values)
121 |             log_logistic_sf = log_p_g_mixtures - log_p_g[..., None] - F.softplus(centered_values)
122 |             log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
123 |             log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
124 |             logit = log_cdf - log_sf
125 | 
126 |             return logit if not mixtures else (logit, log_logistic_cdf)
127 | 
128 |         elif mode == 'pdf':
129 |             log_logistic_pdf = log_p_g_mixtures - log_p_g[..., None] - centered_values - log_scales_b - 2. * F.softplus(
130 |                 -centered_values)
131 |             log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
132 | 
133 |             return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
134 | 
135 |     x2 = binary_search(u_b, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_b(x, mode='cdf'), bisection_iter)
136 | 
137 |     out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3)
138 |     # put back in Pytorch ordering
139 |     out = out.permute(0, 3, 1, 2)
140 |     return out
141 | 
142 | 
143 | def sample_from_discretized_mix_logistic_inverse_CDF_1d(x, model, nr_mix, u=None, clamp=True, bisection_iter=15):
144 |     # Pytorch ordering
145 | 
146 |     l = model(x)
147 | 
148 |     # the network output holds, per pixel: nr_mix mixture logits followed by
149 |     # nr_mix * 2 channels for the means and log-scales of each logistic
150 | 
151 |     l = l.permute(0, 2, 3, 1)
152 |     ls = [int(y) for y in l.size()]
153 |     xs = ls[:-1] + [1]
154 | 
155 |     # unpack parameters
156 |     logit_probs = l[:, :, :, :nr_mix]
157 |     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])
158 |     # draw one uniform variate per pixel for inverse-CDF sampling
159 |     if u is None:
160 |         u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2])
161 |         u.uniform_(1e-5, 1. - 1e-5)
162 |     u = torch.log(u) - torch.log(1. - u)
163 | 
164 |     u_r = u.reshape(ls[:-1])
165 | 
166 |     log_softmax = torch.log_softmax(logit_probs, dim=-1)
167 |     means = l[:, :, :, :, :nr_mix]
168 |     log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
169 |     if clamp:
170 |         ubs = l.new_ones(ls[:-1])  # bisection stays inside the data range [-1, 1]
171 |         lbs = -ubs
172 |     else:
173 |         ubs = l.new_ones(ls[:-1]) * 30.
174 |         lbs = -ubs
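    # with clamp=False the search interval is widened well beyond [-1, 1]:
    # samples from the smoothed model live on an unbounded support because of
    # the added Gaussian noise, so the inverse CDF may land outside the data range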
175 | 
176 |     means_r = means[..., 0, :]
177 |     log_scales_r = log_scales[..., 0, :]
178 | 
179 |     def log_cdf_pdf_r(values, mode='cdf', mixtures=False):
180 |         values = values.unsqueeze(-1)
181 |         centered_values = (values - means_r) / log_scales_r.exp()
182 | 
183 |         if mode == 'cdf':
184 |             log_logistic_cdf = -F.softplus(-centered_values)
185 |             log_logistic_sf = -F.softplus(centered_values)
186 |             log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
187 |             log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
188 |             logit = log_cdf - log_sf
189 | 
190 |             return logit if not mixtures else (logit, log_logistic_cdf)
191 | 
192 |         elif mode == 'pdf':
193 |             log_logistic_pdf = -centered_values - log_scales_r - 2. * F.softplus(-centered_values)
194 |             log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
195 | 
196 |             return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
197 | 
198 |     x0 = binary_search(u_r, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_r(x, mode='cdf'), bisection_iter)
199 | 
200 |     out = x0.view(xs[:-1] + [1])
201 |     out = out.permute(0, 3, 1, 2)
202 |     return out
203 | 
204 | 
205 | def sample_from_discretized_mix_logistic_1d(x, model, nr_mix, u=None, clamp=True):
206 |     # Pytorch ordering
207 |     l = model(x)
208 |     l = l.permute(0, 2, 3, 1)
209 |     ls = [int(y) for y in l.size()]
210 |     xs = ls[:-1] + [1]
211 | 
212 |     # unpack parameters
213 |     logit_probs = l[:, :, :, :nr_mix]
214 |     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])  # 2 for mean, scale
215 | 
216 |     # sample mixture indicator from softmax
217 |     if u is None:
218 |         u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2] * (nr_mix + 1))
219 |         u.uniform_(1e-5, 1. - 1e-5)
220 |     mixture_u, sample_u = torch.split(u, [l.shape[1] * l.shape[2] * nr_mix,
221 |                                           l.shape[1] * l.shape[2] * 1], dim=-1)
222 |     mixture_u = mixture_u.reshape(l.shape[0], l.shape[1], l.shape[2], nr_mix)
223 |     sample_u = sample_u.reshape(l.shape[0], l.shape[1], l.shape[2], 1)
224 | 
225 |     # Gumbel-max trick: argmax over logits + Gumbel noise samples the mixture indicator
226 |     mixture_u = logit_probs.data - torch.log(- torch.log(mixture_u))
227 |     _, argmax = mixture_u.max(dim=3)
228 | 
229 |     one_hot = to_one_hot(argmax, nr_mix)
230 |     sel = one_hot.view(xs[:-1] + [1, nr_mix])
231 |     # select logistic parameters
232 |     means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
233 |     log_scales = torch.clamp(torch.sum(
234 |         l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
235 | 
236 |     x = means + torch.exp(log_scales) * (torch.log(sample_u) - torch.log(1. - sample_u))
237 |     if clamp:
238 |         x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.)
239 |     else:
240 |         x0 = x[:, :, :, 0]
241 | 
242 |     out = x0.unsqueeze(1)
243 |     return out
244 | 
245 | 
246 | def sample_from_discretized_mix_logistic(x, model, nr_mix, u=None, T=1, clamp=True):
247 |     # Pytorch ordering
248 |     l = model(x)
249 |     l = l.permute(0, 2, 3, 1)
250 |     ls = [int(y) for y in l.size()]
251 |     xs = ls[:-1] + [3]
252 | 
253 |     # unpack parameters; T is a temperature on the mixture logits
254 |     logit_probs = l[:, :, :, :nr_mix] / T
255 |     l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
256 |     # sample mixture indicator from softmax
257 |     if u is None:
258 |         u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2] * (nr_mix + 3))
259 |         u.uniform_(1e-5, 1. - 1e-5)
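        # (nr_mix + 3) uniforms per pixel: nr_mix for the Gumbel-max mixture
        # choice plus one per RGB channel for the logistic draws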
260 | 
261 |     mixture_u, sample_u = torch.split(u, [l.shape[1] * l.shape[2] * nr_mix,
262 |                                           l.shape[1] * l.shape[2] * 3], dim=-1)
263 |     mixture_u = mixture_u.reshape(l.shape[0], l.shape[1], l.shape[2], nr_mix)
264 |     sample_u = sample_u.reshape(l.shape[0], l.shape[1], l.shape[2], 3)
265 | 
266 |     # Gumbel-max trick: argmax over logits + Gumbel noise samples the mixture indicator
267 |     mixture_u = logit_probs.data - torch.log(- torch.log(mixture_u))
268 |     _, argmax = mixture_u.max(dim=3)
269 | 
270 |     one_hot = to_one_hot(argmax, nr_mix)
271 |     sel = one_hot.view(xs[:-1] + [1, nr_mix])
272 |     # select logistic parameters
273 |     means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
274 |     log_scales = torch.clamp(torch.sum(
275 |         l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) + np.log(T)
276 |     coeffs = torch.sum(torch.tanh(
277 |         l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, dim=4)
278 |     # sample from logistic & clip to interval
279 |     # we don't actually round to the nearest 8bit value when sampling
280 | 
281 |     x = means + torch.exp(log_scales) * (torch.log(sample_u) - torch.log(1. - sample_u))
282 |     if clamp:
283 |         x0 = torch.clamp(torch.clamp(x[:, :, :, 0], min=-1.), max=1.)
284 |     else:
285 |         x0 = x[:, :, :, 0]
286 | 
287 |     if clamp:
288 |         x1 = torch.clamp(torch.clamp(
289 |             x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-1.), max=1.)
290 |     else:
291 |         x1 = x[:, :, :, 1] + coeffs[:, :, :, 0] * x0
292 | 
293 |     if clamp:
294 |         x2 = torch.clamp(torch.clamp(
295 |             x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, min=-1.), max=1.)
296 |     else:
297 |         x2 = x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1
298 | 
299 |     out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3)
300 |     # put back in Pytorch ordering
301 |     out = out.permute(0, 3, 1, 2)
302 |     return out
303 | 
304 | 
305 | def binary_search(log_cdf, lb, ub, cdf_fun, n_iter=15):
306 |     # invert the monotone cdf_fun by bisection: find x with cdf_fun(x) close to log_cdf;
307 |     # lb and ub are updated in place, so callers pass clones
308 |     with torch.no_grad():
309 |         for i in range(n_iter):
310 |             mid = (lb + ub) / 2.
311 |             mid_cdf_value = cdf_fun(mid)
312 |             right_idxes = mid_cdf_value < log_cdf  # the target lies in the upper half
313 |             left_idxes = ~right_idxes
314 |             lb[right_idxes] = torch.min(mid[right_idxes], ub[right_idxes])
315 |             ub[left_idxes] = torch.max(mid[left_idxes], lb[left_idxes])
316 | 
317 |         return mid
--------------------------------------------------------------------------------
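A quick way to see what the samplers above are doing is to run binary_search against a hand-built mixture for a single scalar "pixel". The sketch below is illustrative only and not part of the repository (all parameter values are made up); it reproduces the quantity that log_cdf_pdf_r computes in 'cdf' mode for a two-component logistic mixture, and checks that bisection inverts the logit of the mixture CDF: pushing the recovered sample back through the CDF returns the original uniform variate.

    import torch
    from torch.nn import functional as F
    from pixelcnnpp.samplers import binary_search

    torch.manual_seed(0)
    # hand-picked two-component logistic mixture for one scalar "pixel"
    log_softmax = torch.log_softmax(torch.tensor([[0.3, -0.1]]), dim=-1)
    means = torch.tensor([[-0.4, 0.5]])
    log_scales = torch.tensor([[-2.0, -1.5]])

    def logit_cdf(values):
        # logit of the mixture CDF, as in log_cdf_pdf_r's 'cdf' branch
        z = (values.unsqueeze(-1) - means) / log_scales.exp()
        log_cdf = torch.logsumexp(log_softmax - F.softplus(-z), dim=-1)
        log_sf = torch.logsumexp(log_softmax - F.softplus(z), dim=-1)
        return log_cdf - log_sf

    u = torch.rand(1).clamp(1e-5, 1. - 1e-5)
    target = torch.log(u) - torch.log(1. - u)   # uniform variate in logit space
    x = binary_search(target, -20. * torch.ones(1), 20. * torch.ones(1),
                      logit_cdf, n_iter=40)
    # CDF(x) should recover u, since sigmoid(logit(CDF(x))) == CDF(x)
    assert torch.allclose(torch.sigmoid(logit_cdf(x)), u, atol=1e-4)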