├── .gitignore
├── images
│   ├── denoise.png
│   ├── smoothing.gif
│   ├── celeba_samples.png
│   ├── cifar10_samples.png
│   ├── mnist_samples.png
│   ├── celeba_noisy_samples.png
│   ├── mnist_noisy_samples.png
│   └── cifar10_noisy_samples.png
├── requirements.txt
├── runners
│   ├── __init__.py
│   ├── pixelcnnpp_gradient_sampler_runner.py
│   ├── pixelcnnpp_sampler_runner.py
│   ├── pixelcnnpp_smoothed_train_runner.py
│   └── pixelcnnpp_conditioned_train_runner.py
├── configs
│   ├── pixelcnnpp_smoothed_train_mnist.yml
│   ├── pixelcnnpp_smoothed_train_celeba.yml
│   ├── pixelcnnpp_smoothed_train_cifar10.yml
│   ├── pixelcnnpp_conditioned_train_celeba.yml
│   ├── pixelcnnpp_conditioned_train_mnist.yml
│   ├── pixelcnnpp_conditioned_train_cifar10.yml
│   ├── pixelcnnpp_smoothed_sample.yml
│   ├── pixelcnnpp_reverse_sample.yml
│   └── pixelcnnpp_gradient_sample.yml
├── models
│   └── ema.py
├── dataset.py
├── README.md
├── main.py
└── pixelcnnpp
    ├── layers.py
    ├── pixelcnnpp.py
    └── samplers.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | */__pycache__
3 | runs/
4 | .ipynb_checkpoints/
5 | runner.sh
6 |
7 |
--------------------------------------------------------------------------------
/images/denoise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/denoise.png
--------------------------------------------------------------------------------
/images/smoothing.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/smoothing.gif
--------------------------------------------------------------------------------
/images/celeba_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/celeba_samples.png
--------------------------------------------------------------------------------
/images/cifar10_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/cifar10_samples.png
--------------------------------------------------------------------------------
/images/mnist_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/mnist_samples.png
--------------------------------------------------------------------------------
/images/celeba_noisy_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/celeba_noisy_samples.png
--------------------------------------------------------------------------------
/images/mnist_noisy_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/mnist_noisy_samples.png
--------------------------------------------------------------------------------
/images/cifar10_noisy_samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenlin9/Autoregressive-Modeling-with-Distribution-Smoothing/HEAD/images/cifar10_noisy_samples.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.4
2 | numpy==1.19.5
3 | PyYAML==5.4.1
4 | tensorboard==2.4.1
5 | torch==1.5.0+cu101
6 | torchvision==0.6.0+cu101
7 | tqdm==4.59.0
--------------------------------------------------------------------------------
/runners/__init__.py:
--------------------------------------------------------------------------------
1 | from runners.pixelcnnpp_smoothed_train_runner import SmoothedPixelCNNPPTrainRunner
2 | from runners.pixelcnnpp_conditioned_train_runner import ReversePixelCNNPPTrainRunner
3 |
4 | from runners.pixelcnnpp_sampler_runner import PixelCNNPPSamplerRunner
5 | from runners.pixelcnnpp_gradient_sampler_runner import PixelCNNPPGradientSamplerRunner
6 |
7 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_smoothed_train_mnist.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | noise: 0.5
3 |
4 | model:
5 | ema: False
6 | ema_rate: 0.999
7 |
8 | ## MNIST
9 | data_dir: runs/mnist
10 | dataset: MNIST
11 | nr_resnet: 5
12 | nr_filters: 40
13 | nr_logistic_mix: 10
14 | input_channels: 1
15 |
16 |
17 | batch_size: 80
18 | lr: 0.0002
19 | lr_decay: 0.999995
20 | max_epochs: 1000
21 |
22 | #### logging
23 | print_every: 50
24 | save_interval: 10
25 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_smoothed_train_celeba.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | noise: 0.3
3 |
4 | model:
5 | ema: False
6 | ema_rate: 0.999
7 |
8 | ## celeba
9 | data_dir: runs/celeba
10 | dataset: celeba
11 | nr_resnet: 5
12 | nr_filters: 160
13 | nr_logistic_mix: 10
14 | input_channels: 3
15 |
16 |
17 | batch_size: 80
18 | lr: 0.0002
19 | lr_decay: 0.999995
20 | max_epochs: 1000
21 |
22 | #### logging
23 | print_every: 50
24 | save_interval: 10
25 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_smoothed_train_cifar10.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | noise: 0.3
3 |
4 | model:
5 | ema: False
6 | ema_rate: 0.999
7 |
8 | ## CIFAR-10
9 | data_dir: runs/cifar10
10 | dataset: CIFAR10
11 | nr_resnet: 5
12 | nr_filters: 160
13 | nr_logistic_mix: 10
14 | input_channels: 3
15 |
16 | batch_size: 80
17 | lr: 0.0002
18 | lr_decay: 0.999995
19 | max_epochs: 1000
20 |
21 | #### logging
22 | print_every: 50
23 | save_interval: 10
24 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_conditioned_train_celeba.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | noise: 0.3
3 | clean_noise: 0.01
4 | with_logit: False
5 |
6 | model:
7 | ema: False
8 | ema_rate: 0.999
9 |
10 | ## celeba
11 | data_dir: runs/celeba
12 | dataset: celeba
13 | nr_resnet: 5
14 | nr_filters: 160
15 | nr_logistic_mix: 10
16 | input_channels: 3
17 |
18 | batch_size: 50
19 | lr: 0.0002
20 | lr_decay: 0.999995
21 | max_epochs: 1000
22 |
23 | #### logging
24 | print_every: 50
25 | save_interval: 10
26 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_conditioned_train_mnist.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | noise: 0.5
3 | clean_noise: 0.01
4 | with_logit: False
5 |
6 | model:
7 | ema: False
8 | ema_rate: 0.999
9 |
10 | ## MNIST
11 | data_dir: runs/mnist
12 | dataset: MNIST
13 | nr_resnet: 5
14 | nr_filters: 40
15 | nr_logistic_mix: 10
16 | input_channels: 1
17 |
18 |
19 | batch_size: 50
20 | lr: 0.0002
21 | lr_decay: 0.999995
22 | max_epochs: 1000
23 |
24 | #### logging
25 | print_every: 50
26 | save_interval: 10
27 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_conditioned_train_cifar10.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | noise: 0.3
3 | clean_noise: 0.01
4 | with_logit: False
5 |
6 | model:
7 | ema: False
8 | ema_rate: 0.999
9 |
10 | ## CIFAR-10
11 | data_dir: runs/cifar10
12 | dataset: CIFAR10
13 | nr_resnet: 5
14 | nr_filters: 160
15 | nr_logistic_mix: 10
16 | input_channels: 3
17 |
18 |
19 | batch_size: 50
20 | lr: 0.0002
21 | lr_decay: 0.999995
22 | max_epochs: 1000
23 |
24 | #### logging
25 | print_every: 50
26 | save_interval: 10
27 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_smoothed_sample.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | with_logit: False
3 | reverse_sampling: False
4 | ema: False
5 | ema_rate: 0.999
6 | ckpt_path: runs/logs/cifar10_smoothed_0.3/pixelcnn_ckpts/checkpoint.pth
7 |
8 | ### CIFAR-10
9 | data_dir: runs/cifar10
10 | dataset: CIFAR10
11 | nr_resnet: 5
12 | nr_filters: 160
13 | nr_logistic_mix: 10
14 | input_channels: 3
15 |
16 | ### MNIST
17 | #data_dir: runs/mnist
18 | #dataset: MNIST
19 | #nr_resnet: 5
20 | #nr_filters: 40
21 | #nr_logistic_mix: 10
22 | #input_channels: 1
23 |
24 | ### celeba
25 | #data_dir: runs/celeba
26 | #dataset: celeba
27 | #nr_resnet: 5
28 | #nr_filters: 160 #80
29 | #nr_logistic_mix: 10
30 | #input_channels: 3
31 |
32 | batch_size: 200
33 | iteration: 1
34 | lr: 0.0002
35 | lr_decay: 0.999995
36 | max_epochs: 1000
37 |
38 | #### logging
39 | print_every: 50
40 | save_interval: 10
41 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_reverse_sample.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | with_logit: False
3 | reverse_sampling: True
4 |
5 | noisy_samples_path: "runs/logs/cifar10_0.3_images/samples/samples_CIFAR10.pth"
6 | ema: False
7 | ema_rate: 0.999
8 | ckpt_path: runs/logs/reverse_cifar10_0.3/pixelcnn_ckpts/checkpoint.pth
9 |
10 | ### CIFAR-10
11 | data_dir: runs/cifar10
12 | dataset: CIFAR10
13 | nr_resnet: 5
14 | nr_filters: 160
15 | nr_logistic_mix: 10
16 | input_channels: 3
17 |
18 | ### MNIST
19 | #data_dir: runs/mnist
20 | #dataset: MNIST
21 | #nr_resnet: 5
22 | #nr_filters: 40
23 | #nr_logistic_mix: 10
24 | #input_channels: 1
25 |
26 | ### celeba
27 | #data_dir: runs/celeba
28 | #dataset: celeba
29 | #nr_resnet: 5
30 | #nr_filters: 160 #80
31 | #nr_logistic_mix: 10
32 | #input_channels: 3
33 |
34 | batch_size: 200
35 | iteration: 1
36 | lr: 0.0002
37 | lr_decay: 0.999995
38 | max_epochs: 1000
39 |
40 | #### logging
41 | print_every: 50
42 | save_interval: 10
43 |
--------------------------------------------------------------------------------
/configs/pixelcnnpp_gradient_sample.yml:
--------------------------------------------------------------------------------
1 | #### training:
2 | with_logit: False #True
3 | ema: False #True
4 | ema_rate: 0.999
5 |
6 | noisy_samples_path: "runs/logs/cifar10_0.3_images/samples/samples_CIFAR10.pth"
7 | ckpt_path: runs/logs/cifar10_smoothed_0.3/pixelcnn_ckpts/checkpoint.pth
8 |
9 | # CIFAR-10
10 | dataset: CIFAR10
11 | nr_resnet: 5
12 | nr_filters: 160
13 | nr_logistic_mix: 10
14 | input_channels: 3
15 | noise: 0.3
16 |
17 |
18 | #### MNIST
19 | #data_dir: runs/mnist
20 | #dataset: MNIST
21 | #nr_resnet: 5
22 | #nr_filters: 40
23 | #nr_logistic_mix: 10
24 | #input_channels: 1
25 | #noise: 0.5
26 |
27 | ### celeba
28 | #data_dir: runs/celeba
29 | #dataset: celeba
30 | #nr_resnet: 5
31 | #nr_filters: 160
32 | #nr_logistic_mix: 10
33 | #input_channels: 3
34 | #noise: 0.3
35 |
36 |
37 | batch_size: 100
38 | iteration: 1
39 | lr: 0.0002
40 | lr_decay: 0.999995
41 | max_epochs: 1000
42 |
43 | #### logging
44 | print_every: 50
45 | save_interval: 10
46 |
--------------------------------------------------------------------------------
/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class EMAHelper(object):
5 | def __init__(self, mu=0.999):
6 | self.mu = mu
7 | self.shadow = {}
8 |
9 | def register(self, module):
10 | if isinstance(module, nn.DataParallel):
11 | module = module.module
12 | for name, param in module.named_parameters():
13 | if param.requires_grad:
14 | self.shadow[name] = param.data.clone()
15 |
16 | def update(self, module):
17 | if isinstance(module, nn.DataParallel):
18 | module = module.module
19 | for name, param in module.named_parameters():
20 | if param.requires_grad:
21 | self.shadow[name].data = (
22 | 1. - self.mu) * param.data + self.mu * self.shadow[name].data
23 |
24 | def ema(self, module):
25 | if isinstance(module, nn.DataParallel):
26 | module = module.module
27 | for name, param in module.named_parameters():
28 | if param.requires_grad:
29 | param.data.copy_(self.shadow[name].data)
30 |
31 | def ema_copy(self, module):
32 | if isinstance(module, nn.DataParallel):
33 | inner_module = module.module
34 | module_copy = type(inner_module)(
35 | inner_module.config).to(inner_module.config.device)
36 | module_copy.load_state_dict(inner_module.state_dict())
37 | module_copy = nn.DataParallel(module_copy)
38 | else:
39 | module_copy = type(module)(module.config).to(module.config.device)
40 | module_copy.load_state_dict(module.state_dict())
41 | # module_copy = copy.deepcopy(module)
42 | self.ema(module_copy)
43 | return module_copy
44 |
45 | def state_dict(self):
46 | return self.shadow
47 |
48 | def load_state_dict(self, state_dict):
49 | self.shadow = state_dict
--------------------------------------------------------------------------------
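
A minimal usage sketch of `EMAHelper`, mirroring how the train runners wire it in; the `nn.Linear` model and the toy loop below are stand-ins for `PixelCNN(config)` and the real training loop, and the import assumes the repo root is on `PYTHONPATH`:

```python
import torch
import torch.nn as nn
from models.ema import EMAHelper

model = nn.Linear(4, 4)                      # stand-in for PixelCNN(config)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
ema_helper = EMAHelper(mu=0.999)
ema_helper.register(model)                   # snapshot the initial parameters

for _ in range(10):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema_helper.update(model)                 # shadow <- mu*shadow + (1-mu)*param

ema_helper.ema(model)                        # overwrite live weights with the averages
```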
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torchvision.transforms as transforms
4 | import torchvision.datasets as datasets
5 |
6 | import numpy as np
7 | from torch.utils.data import DataLoader, Subset
8 | from torchvision.datasets import CIFAR10, MNIST, FashionMNIST, ImageFolder
9 |
10 |
11 | def get_dataset(config):
12 | kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}
13 | rescaling = lambda x: (x - .5) * 2.
14 | ds_transforms = transforms.Compose([transforms.ToTensor(), rescaling])
15 |
16 | if config.dataset == 'MNIST':
17 | train_loader = torch.utils.data.DataLoader(
18 | datasets.MNIST(os.path.join('runs', 'datasets', 'MNIST'), download=True,
19 | train=True, transform=ds_transforms),
20 | batch_size=config.batch_size,
21 | shuffle=True, **kwargs)
22 |
23 | test_loader = torch.utils.data.DataLoader(datasets.MNIST(config.data_dir, train=False, download=True,
24 | transform=ds_transforms), batch_size=config.batch_size,
25 | shuffle=False, **kwargs)
26 |
27 | elif config.dataset == 'FashionMNIST':
28 | train_loader = torch.utils.data.DataLoader(
29 | datasets.FashionMNIST(os.path.join('runs', 'datasets', 'FashionMNIST'), download=True,
30 | train=True, transform=ds_transforms),
31 | batch_size=config.batch_size,
32 | shuffle=True, **kwargs)
33 |
34 | test_loader = torch.utils.data.DataLoader(datasets.FashionMNIST(config.data_dir, train=False, download=True,
35 | transform=ds_transforms), batch_size=config.batch_size,
36 | shuffle=False, **kwargs)
37 |
38 |
39 | elif 'CIFAR10' in config.dataset:
40 |
41 | train_loader = torch.utils.data.DataLoader(datasets.CIFAR10(config.data_dir, train=True,
42 | download=True, transform=ds_transforms),
43 | batch_size=config.batch_size, shuffle=True, **kwargs)
44 |
45 | test_loader = torch.utils.data.DataLoader(datasets.CIFAR10(config.data_dir, train=False, download=True,
46 | transform=ds_transforms),
47 | batch_size=config.batch_size,
48 | shuffle=False, **kwargs)
49 |
50 | elif "celeba" in config.dataset:
51 | dataset = ImageFolder(
52 | root=os.path.join('/atlas/u/yangsong/sliced_score_matching/run', 'datasets', 'celeba'),
53 | transform=transforms.Compose([
54 | transforms.CenterCrop(140),
55 | transforms.Resize(32),
56 | transforms.ToTensor(),
57 | rescaling
58 | ]))
59 | num_items = len(dataset) # 202599
60 | indices = list(range(num_items))
61 | random_state = np.random.get_state()
62 | np.random.seed(2020)
63 | np.random.shuffle(indices)
64 | np.random.set_state(random_state)
65 | train_indices, test_indices = indices[:int(num_items * 0.7)], indices[
66 | int(num_items * 0.7):int(num_items * 0.8)]
67 | test_dataset = Subset(dataset, test_indices)
68 | dataset = Subset(dataset, train_indices)
69 |
70 | train_loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, **kwargs)
71 | test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=True, **kwargs)
72 |     else:
73 |         raise ValueError('dataset {} is not one of MNIST, FashionMNIST, CIFAR10, celeba'.format(config.dataset))
74 | 
75 |     return train_loader, test_loader
76 | 
--------------------------------------------------------------------------------
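
A hedged sketch of calling `get_dataset` directly. The `argparse.Namespace` mimics what `main.py`'s `dict2namespace` builds from a config file; the field values are taken from `pixelcnnpp_smoothed_train_mnist.yml`, and the import assumes the repo root is the working directory:

```python
import argparse
from dataset import get_dataset

config = argparse.Namespace(dataset='MNIST', data_dir='runs/mnist', batch_size=80)
train_loader, test_loader = get_dataset(config)

x, _ = next(iter(train_loader))
print(x.shape)                          # torch.Size([80, 1, 28, 28])
print(x.min().item(), x.max().item())   # roughly -1.0 and 1.0 after rescaling
```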
/README.md:
--------------------------------------------------------------------------------
1 | # Improved Autoregressive Modeling with Distribution Smoothing
2 |
3 | This repo contains the implementation for the paper [Improved Autoregressive Modeling with Distribution Smoothing](https://arxiv.org/abs/2103.15089)
4 |
5 | by [Chenlin Meng](https://cs.stanford.edu/~chenlin/), [Jiaming Song](http://tsong.me), [Yang Song](http://yang-song.github.io/), [Shengjia Zhao](http://szhao.me/) and [Stefano Ermon](https://cs.stanford.edu/~ermon/), Stanford AI Lab.
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ## Running Experiments
21 |
22 | ### Dependencies
23 |
24 | Run the following to install all the necessary Python packages for our code.
25 |
26 | ```bash
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ### Stage 1: Learning the smoothed distribution
31 |
32 |
33 |
34 |
35 | To train the PixelCNN++ model on the smoothed distribution for CIFAR-10 (the objective is sketched in the appendix at the end of this README), run:
36 | ```bash
37 | python main.py --runner SmoothedPixelCNNPPTrainRunner --config pixelcnnpp_smoothed_train_cifar10.yml --doc cifar10_smoothed_0.3 --ni
38 | ```
39 |
40 | ### Stage 2: Reverse smoothing
41 |
42 |
43 |
44 | To reverse the smoothing process, we train a second PixelCNN++ model conditioned on samples from the smoothed distribution (see the sketch in the appendix below).
45 | To train this model on CIFAR-10, run:
46 | 
47 | ```bash
48 | python main.py --runner ReversePixelCNNPPTrainRunner --config pixelcnnpp_conditioned_train_cifar10.yml --doc reverse_cifar10_0.3 --ni
49 | ```
50 |
51 |
52 | ### Sampling
53 | Sampling from stage 1:
54 |
55 |
56 | Before sampling, modify **pixelcnnpp_smoothed_sample.yml**:
57 |
58 | **ckpt_path**: path to the model trained on the smoothed data in stage 1.
59 |
60 | The **dataset** parameter may also need to be changed accordingly; the options are MNIST, CIFAR10, and celeba.
61 | ```bash
62 | python main.py --runner PixelCNNPPSamplerRunner --config pixelcnnpp_smoothed_sample.yml --doc cifar10_0.3_images
63 | ```
64 |
65 | Sampling from stage 2:
66 |
67 | Before sampling, modify **pixelcnnpp_reverse_sample.yml**:
68 |
69 | **noisy_samples_path**: path to the noisy samples generated in stage 1 by the model trained on the smoothed data.
70 |
71 | **ckpt_path**: path to the reverse smoothing model in stage 2.
72 |
73 | The **dataset** parameter may need to be changed accordingly; the options are MNIST, CIFAR10, and celeba.
74 | ```bash
75 | python main.py --runner PixelCNNPPSamplerRunner --config pixelcnnpp_reverse_sample.yml --doc cifar10_denoise_images
76 | ```
77 |
78 | ## References and Acknowledgements
79 | ```bibtex
80 | @article{meng2021improved,
81 | title={Improved Autoregressive Modeling with Distribution Smoothing},
82 | author={Meng, Chenlin and Song, Jiaming and Song, Yang and Zhao, Shengjia and Ermon, Stefano},
83 | journal={arXiv preprint arXiv:2103.15089},
84 | year={2021}
85 | }
86 | ```
87 |
88 |
89 | This implementation is based on / inspired by:
90 |
91 | - [https://github.com/pclucas14/pixel-cnn-pp](https://github.com/pclucas14/pixel-cnn-pp) (PixelCNN++ PyTorch implementation), and
92 | - [https://github.com/ermongroup/ncsnv2](https://github.com/ermongroup/ncsnv2) (code structure).
93 | 
94 |
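95 | ### Appendix: the two training objectives in code
96 | 
97 | A minimal sketch of what the two stages optimize, assuming a PixelCNN++ `model`, the repo's `mix_logistic_loss` as `loss_op`, and images already rescaled to `[-1, 1]`; the real loops live in `runners/pixelcnnpp_smoothed_train_runner.py` and `runners/pixelcnnpp_conditioned_train_runner.py`.
98 | 
99 | ```python
100 | import torch
101 | 
102 | def stage1_loss(model, loss_op, x, noise=0.3):
103 |     # Stage 1: fit the model to the Gaussian-smoothed data distribution
104 |     noisy = x + torch.randn_like(x) * noise
105 |     return loss_op(noisy, model(noisy))
106 | 
107 | def stage2_loss(model, loss_op, x, noise=0.3, clean_noise=0.01):
108 |     # Stage 2: stack a noisy copy on top of a lightly smoothed clean copy,
109 |     # so the raster-order model sees the whole noisy image before it
110 |     # predicts any clean pixel; only the clean (bottom) half is scored
111 |     noisy = x + torch.randn_like(x) * noise
112 |     clean = x + torch.randn_like(x) * clean_noise
113 |     stacked = torch.cat([noisy, clean], dim=2)              # (B, C, 2H, W)
114 |     output = model(stacked)[:, :, stacked.shape[-1]:, :]    # bottom half (square images)
115 |     return loss_op(clean, output)
116 | ```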
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import shutil
5 | import sys
6 | import traceback
7 | 
8 | import numpy as np
9 | import torch
10 | import yaml
11 | 
12 | from runners import *
13 |
14 |
15 | def parse_args_and_config():
16 | parser = argparse.ArgumentParser(description=globals()['__doc__'])
17 |
18 | parser.add_argument('--runner', type=str, default='SmoothedPixelCNNPPTrainRunner', help='The runner to execute')
19 | parser.add_argument('--config', type=str, default='pixelcnnpp_smoothed_train.yml', help='Path to the config file')
20 | parser.add_argument('--seed', type=int, default=1234, help='Random seed')
21 | parser.add_argument('--run', type=str, default='runs', help='Path for saving running related data.')
22 | parser.add_argument('--doc', type=str, default='0', help='A string for documentation purpose')
23 | parser.add_argument('--comment', type=str, default='', help='A string for experiment comment')
24 | parser.add_argument('--verbose', type=str, default='info', help='Verbose level: info | debug | warning | critical')
25 | parser.add_argument('--test', action='store_true', help='Whether to test the model')
26 | parser.add_argument('--resume_training', action='store_true', help='Whether to resume training')
27 | parser.add_argument('-i', '--image_folder', type=str, default='images', help="The directory of image outputs")
28 | parser.add_argument('--ni', action='store_true', help="No interaction. Suitable for Slurm Job launcher")
29 |
30 | args = parser.parse_args()
31 | args.log = os.path.join(args.run, 'logs', args.doc)
32 |
33 | # parse config file
34 | with open(os.path.join('configs', args.config), 'r') as f:
35 | config = yaml.load(f, Loader=yaml.FullLoader)
36 | new_config = dict2namespace(config)
37 |
38 | if not args.test:
39 |         if not args.resume_training and args.runner not in ('PixelCNNPPSamplerRunner', 'PixelCNNPPGradientSamplerRunner', 'PixelCNNPP_ELBO_Runner'):
40 | if os.path.exists(args.log):
41 | if args.ni:
42 | shutil.rmtree(args.log)
43 | else:
44 | answer = input("Log folder already exists. Overwrite? (Y/n)\n")
45 | if answer.lower() == 'n':
46 | sys.exit(0)
47 | else:
48 | shutil.rmtree(args.log)
49 |
50 | os.makedirs(args.log, exist_ok=True)
51 | with open(os.path.join(args.log, 'config.yml'), 'w') as f:
52 | yaml.dump(vars(new_config), f, default_flow_style=False)
53 |
54 | # setup logger
55 | level = getattr(logging, args.verbose.upper(), None)
56 | if not isinstance(level, int):
57 | raise ValueError('level {} not supported'.format(args.verbose))
58 |
59 | handler1 = logging.StreamHandler()
60 | handler2 = logging.FileHandler(os.path.join(args.log, 'stdout.txt'))
61 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
62 | handler1.setFormatter(formatter)
63 | handler2.setFormatter(formatter)
64 | logger = logging.getLogger()
65 | logger.addHandler(handler1)
66 | logger.addHandler(handler2)
67 | logger.setLevel(level)
68 |
69 | else:
70 | level = getattr(logging, args.verbose.upper(), None)
71 | if not isinstance(level, int):
72 | raise ValueError('level {} not supported'.format(args.verbose))
73 |
74 | handler1 = logging.StreamHandler()
75 | formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
76 | handler1.setFormatter(formatter)
77 | logger = logging.getLogger()
78 | logger.addHandler(handler1)
79 | logger.setLevel(level)
80 |
81 | # add device
82 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
83 | # device = torch.device('cpu')
84 | logging.info("Using device: {}".format(device))
85 | new_config.device = device
86 |
87 | # set random seed
88 | torch.manual_seed(args.seed)
89 | np.random.seed(args.seed)
90 | if torch.cuda.is_available():
91 | torch.cuda.manual_seed_all(args.seed)
92 |
93 | return args, new_config
94 |
95 |
96 | def dict2namespace(config):
97 | namespace = argparse.Namespace()
98 | for key, value in config.items():
99 | if isinstance(value, dict):
100 | new_value = dict2namespace(value)
101 | else:
102 | new_value = value
103 | setattr(namespace, key, new_value)
104 | return namespace
105 |
106 |
107 | def main():
108 | args, config = parse_args_and_config()
109 | logging.info("Writing log file to {}".format(args.log))
110 | logging.info("Exp instance id = {}".format(os.getpid()))
111 | logging.info("Exp comment = {}".format(args.comment))
112 | logging.info("Config =")
113 | print(">" * 80)
114 | print(yaml.dump(vars(config), default_flow_style=False))
115 | print("<" * 80)
116 |
117 | try:
118 | runner = eval(args.runner)(args, config)
119 | if not args.test:
120 | runner.train()
121 | else:
122 | runner.test()
123 |     except Exception:
124 | logging.error(traceback.format_exc())
125 |
126 | return 0
127 |
128 |
129 | if __name__ == '__main__':
130 | sys.exit(main())
131 |
--------------------------------------------------------------------------------
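
`dict2namespace` is what turns the YAML configs into the attribute-style `config` objects the runners expect; a small sketch of the round trip (the inline YAML is hypothetical, and importing from `main` is safe because all executable code is guarded by `__main__`):

```python
import yaml
from main import dict2namespace

cfg = yaml.safe_load("""
noise: 0.3
model:
  ema: false
  ema_rate: 0.999
""")
config = dict2namespace(cfg)
print(config.noise)            # 0.3
print(config.model.ema_rate)   # 0.999
```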
/runners/pixelcnnpp_gradient_sampler_runner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import torchvision.utils as utils
4 | from pixelcnnpp.pixelcnnpp import (PixelCNN,
5 | load_part_of_model,
6 | mix_logistic_loss_1d,
7 | mix_logistic_loss,
8 | )
9 |
10 | from functools import partial
11 | import dataset
12 | from pixelcnnpp.samplers import *
13 | from torchvision.utils import save_image
14 | import shutil
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | import numbers
18 | import math
19 | import matplotlib.pyplot as plt
20 | import time
21 |
22 | import torchvision
23 | import torch.autograd as autograd
24 |
25 |
26 |
27 | class PixelCNNPPGradientSamplerRunner(object):
28 | def __init__(self, args, config):
29 | self.args = args
30 | self.config = config
31 |
32 | def logit_transform(self, image, lambd=1e-6):
33 | image = .5 * image + .5
34 | image = lambd + (1 - 2 * lambd) * image
35 | latent_image = torch.log(image) - torch.log1p(-image)
36 | ll = F.softplus(-latent_image).sum() + F.softplus(latent_image).sum() + np.prod(
37 | image.shape) * (np.log(1 - 2 * lambd) + np.log(.5))
38 | nll = -ll
39 | return latent_image, nll
40 |
41 | def train(self):
42 | assert not self.config.ema, "ema sampling is not supported now"
43 | self.load_pixelcnnpp()
44 | self.sample()
45 |
46 | def sample(self):
47 | sample_batch_size = self.config.batch_size
48 | self.ar_model.eval()
49 | model = partial(self.ar_model, sample=True)
50 |
51 | rescaling_inv = lambda x: .5 * x + .5
52 | rescaling = lambda x: (x - .5) * 2.
53 | if self.config.dataset == 'CIFAR10' or self.config.dataset == 'celeba':
54 | x = torch.zeros(sample_batch_size, 3, 32, 32, device=self.config.device)
55 | clamp = False
56 | bisection_iter = 20
57 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF, model=model,
58 | nr_mix=self.config.nr_logistic_mix, clamp=clamp,
59 | bisection_iter=bisection_iter)
60 |
61 | elif 'MNIST' in self.config.dataset:
62 | x = torch.zeros(sample_batch_size, 1, 28, 28, device=self.config.device)
63 | clamp = False
64 | bisection_iter = 30
65 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF_1d, model=model,
66 | nr_mix=self.config.nr_logistic_mix, clamp=clamp,
67 | bisection_iter=bisection_iter)
68 |
69 | # if os.path.exists(os.path.join(self.args.log, 'images')):
70 | # shutil.rmtree(os.path.join(self.args.log, 'images'))
71 | #
72 | # if os.path.exists(os.path.join(self.args.log, 'samples')):
73 | # shutil.rmtree(os.path.join(self.args.log, 'samples'))
74 |
75 | os.makedirs(os.path.join(self.args.log, 'images'), exist_ok=True)
76 | os.makedirs(os.path.join(self.args.log, 'samples'), exist_ok=True)
77 |
78 |
79 | def sigmoid_transform(samples, lambd=1e-6):
80 | samples = torch.sigmoid(samples)
81 | samples = (samples - lambd) / (1 - 2 * lambd)
82 | return samples
83 | 
84 |         # single-step denoising with the smoothed model's score follows
85 | 
86 |
87 | noisy = torch.load(self.config.noisy_samples_path)
88 | if self.config.with_logit is True:
89 | torchvision.utils.save_image(sigmoid_transform(noisy), os.path.join(self.args.log, 'images', "noisy_samples.png"))
90 | else:
91 | torchvision.utils.save_image(rescaling_inv(noisy), os.path.join(self.args.log, 'images', "noisy_samples.png"))
92 | print(noisy.shape)
93 | # Sampling from current model
94 |
95 | images_array = []
96 | for it in range(self.config.iteration):
97 | print("{}/{}".format(it, self.config.iteration))
98 | x = noisy[it * self.config.batch_size: (it + 1) * self.config.batch_size]
99 | # x.requires_grad_(True)
100 | # output = model(x) #.detach()
101 | # # x.requires_grad_(True)
102 |
103 | output = model(x).detach()
104 | x.requires_grad_(True)
105 |
106 | if x.shape[1] == 1:
107 | log_pdf = mix_logistic_loss_1d(x, output, likelihood=True)
108 | else:
109 | log_pdf = mix_logistic_loss(x, output, likelihood=True)
110 |
111 | score = autograd.grad(log_pdf.sum(), x, create_graph=True)[0]
112 | x = x + self.config.noise ** 2 * score
113 | x = x.detach().data
114 |
115 | if it == 0:
116 | if self.config.with_logit is True:
117 | images_concat = torchvision.utils.make_grid(sigmoid_transform(x), nrow=int(x.shape[0] ** 0.5),
118 | padding=0, pad_value=0)
119 | else:
120 | images_concat = torchvision.utils.make_grid(rescaling_inv(x)[:, :, -x.shape[-1]:, :], nrow=int(x.shape[0] ** 0.5),
121 | padding=0, pad_value=0)
122 | torchvision.utils.save_image(images_concat, os.path.join(self.args.log, 'images', "gradient_denoised_samples.png"))
123 | images_array.append(x.data.cpu())
124 |             del score
125 |
126 | torch.save(torch.cat(images_array, dim=0), os.path.join(self.args.log, 'samples', "gradient_denoised_samples_{}.pkl".format(self.config.dataset)))
127 |
130 |
131 |
132 | def load_pixelcnnpp(self):
133 | def load_parallel(path=self.config.ckpt_path, loc=self.config.device):
134 | checkpoint = torch.load(path, map_location=loc)[0]
135 | state_dict = checkpoint
136 | for k in list(state_dict.keys()):
137 | if not k.startswith('module.'):
138 | # # remove prefix
139 | state_dict["module."+k] = state_dict[k]
140 | # # delete renamed or unused k
141 | del state_dict[k]
142 | return state_dict
143 |
144 | obs = (1, 56, 28) if 'MNIST' in self.config.dataset else (3, 32, 32)
145 | input_channels = obs[0]
146 |
147 | model = PixelCNN(self.config)
148 | model = model.to(self.config.device)
149 | model = torch.nn.DataParallel(model)
150 |
151 | back_compat = False
152 | if back_compat:
153 | load_part_of_model(model, self.config.ckpt_path, back_compat=True)
154 | else:
155 | model.load_state_dict(torch.load(self.config.ckpt_path, map_location=self.config.device)[0])
156 |
157 | print('model parameters loaded')
158 | self.ar_model = model
159 |
160 | def test(self):
161 | pass
162 |
--------------------------------------------------------------------------------
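
The core of `PixelCNNPPGradientSamplerRunner.sample` is the one-step denoising rule `x_hat = x + sigma^2 * grad_x log p_sigma(x)`, where `p_sigma` is the smoothed model from stage 1. A minimal sketch in isolation, with a generic `log_pdf_fn` standing in for the mixture-of-logistics likelihood:

```python
import torch

def gradient_denoise(x, log_pdf_fn, sigma):
    # one gradient step toward the data: x_hat = x + sigma^2 * score
    x = x.detach().requires_grad_(True)
    log_pdf = log_pdf_fn(x)                           # log-likelihood of the batch
    score = torch.autograd.grad(log_pdf.sum(), x)[0]  # grad of log-density w.r.t. x
    return (x + sigma ** 2 * score).detach()
```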
/runners/pixelcnnpp_sampler_runner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import torchvision.utils as utils
4 | from pixelcnnpp.pixelcnnpp import (PixelCNN, load_part_of_model)
5 |
6 | from functools import partial
7 | import dataset
8 | from pixelcnnpp.samplers import *
9 | from torchvision.utils import save_image
10 | import shutil
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import numbers
14 | import math
15 | import matplotlib.pyplot as plt
16 | import time
17 |
18 | import torchvision
19 | import torch.autograd as autograd
20 |
21 |
22 | class PixelCNNPPSamplerRunner(object):
23 | def __init__(self, args, config):
24 | self.args = args
25 | self.config = config
26 |
27 | def logit_transform(self, image, lambd=1e-6):
28 | image = .5 * image + .5
29 | image = lambd + (1 - 2 * lambd) * image
30 | latent_image = torch.log(image) - torch.log1p(-image)
31 | ll = F.softplus(-latent_image).sum() + F.softplus(latent_image).sum() + np.prod(
32 | image.shape) * (np.log(1 - 2 * lambd) + np.log(.5))
33 | nll = -ll
34 | return latent_image, nll
35 |
36 | def train(self):
37 | assert not self.config.ema, "ema sampling is not supported now"
38 | self.load_pixelcnnpp()
39 | self.sample()
40 |
41 | def sample(self):
42 | sample_batch_size = self.config.batch_size
43 | self.ar_model.eval()
44 | model = partial(self.ar_model, sample=True)
45 |
46 | rescaling_inv = lambda x: .5 * x + .5
47 | rescaling = lambda x: (x - .5) * 2.
48 | if self.config.dataset == 'CIFAR10' or self.config.dataset == 'celeba':
49 | x = torch.zeros(sample_batch_size, 3, 32, 32, device=self.config.device)
50 | clamp = False
51 | bisection_iter = 20
52 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF, model=model,
53 | nr_mix=self.config.nr_logistic_mix, clamp=clamp,
54 | bisection_iter=bisection_iter)
55 | # basic_sampler = lambda x: sample_from_discretized_mix_logistic(x, model,
56 | # self.config.nr_logistic_mix,
57 | # clamp=clamp)
58 |
59 | elif 'MNIST' in self.config.dataset:
60 | x = torch.zeros(sample_batch_size, 1, 28, 28, device=self.config.device)
61 | clamp = False
62 | bisection_iter = 30
63 | basic_sampler = partial(sample_from_discretized_mix_logistic_inverse_CDF_1d, model=model,
64 | nr_mix=self.config.nr_logistic_mix, clamp=clamp,
65 | bisection_iter=bisection_iter)
66 | # basic_sampler = lambda x: sample_from_discretized_mix_logistic_1d(x, model, self.config.nr_logistic_mix,
67 | # clamp=clamp)
68 |
69 | os.makedirs(os.path.join(self.args.log, 'images'), exist_ok=True)
70 | os.makedirs(os.path.join(self.args.log, 'samples'), exist_ok=True)
71 |
72 | def sigmoid_transform(samples, lambd=1e-6):
73 | samples = torch.sigmoid(samples)
74 | samples = (samples - lambd) / (1 - 2 * lambd)
75 | return samples
76 |
77 | if self.config.reverse_sampling:
78 | noisy = torch.load(self.config.noisy_samples_path)
79 |
80 | if self.config.with_logit is True:
81 | images_concat = torchvision.utils.make_grid(sigmoid_transform(noisy), nrow=int(self.config.batch_size ** 0.5), padding=0,
82 | pad_value=0)
83 | else:
84 | images_concat = torchvision.utils.make_grid(rescaling_inv(noisy), nrow=int(self.config.batch_size ** 0.5), padding=0,
85 | pad_value=0)
86 | torchvision.utils.save_image(images_concat, os.path.join(self.args.log, 'images', "original_noisy.png"))
87 |
88 | # Sampling from current model
89 |         images_array = []
90 |         # no gradients are needed while sampling
91 |         with torch.no_grad():
92 | for it in range(self.config.iteration):
93 | if self.config.reverse_sampling:
94 | x = noisy[it * self.config.batch_size: (it+1) * self.config.batch_size]
95 | x = torch.cat([x, x], dim=2)
96 | else:
97 | x = torch.randn_like(x)
98 |
99 | x = x.cuda(non_blocking=True)
100 | for i in range(-x.shape[-1], 0, 1):
101 | for j in range(x.shape[-1]):
102 | print(it, i, j, flush=True)
103 | samples = basic_sampler(x)
104 | x[:, :, i, j] = samples[:, :, i, j]
105 |
106 |                     if it == 0:
107 |                         if self.config.with_logit is True:
108 |                             images_concat = torchvision.utils.make_grid(sigmoid_transform(x), nrow=int(x.shape[0] ** 0.5),
109 |                                                                         padding=0, pad_value=0)
110 |                         else:
111 |                             images_concat = torchvision.utils.make_grid(rescaling_inv(x)[:, :, -x.shape[-1]:, :], nrow=int(x.shape[0] ** 0.5),
112 |                                                                         padding=0, pad_value=0)
113 |                         torchvision.utils.save_image(images_concat, os.path.join(self.args.log, 'images', "samples.png"))
114 |                     images_array.append(x)
115 |                 torch.save(torch.cat(images_array, dim=0), os.path.join(self.args.log, 'samples', "samples_{}.pth".format(self.config.dataset)))
116 |
118 |
119 |
120 | def load_pixelcnnpp(self):
121 | def load_parallel(path=self.config.ckpt_path, loc=self.config.device):
122 | checkpoint = torch.load(path, map_location=loc)[0]
123 | state_dict = checkpoint
124 | for k in list(state_dict.keys()):
125 | if not k.startswith('module.'):
126 | # remove prefix
127 | state_dict["module."+k] = state_dict[k]
128 | # delete renamed or unused k
129 | del state_dict[k]
130 | return state_dict
131 |
132 | obs = (1, 56, 28) if 'MNIST' in self.config.dataset else (3, 32, 32)
133 | input_channels = obs[0]
134 |
135 | model = PixelCNN(self.config)
136 | model = model.to(self.config.device)
137 | model = torch.nn.DataParallel(model)
138 |
139 | back_compat = False
140 | if back_compat:
141 | load_part_of_model(model, self.config.ckpt_path, back_compat=True)
142 | else:
143 | model.load_state_dict(torch.load(self.config.ckpt_path, map_location=self.config.device)[0])
144 |
145 | print('model parameters loaded')
146 | self.ar_model = model
147 |
148 | def test(self):
149 | pass
150 |
--------------------------------------------------------------------------------
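
Stripped of logging and rescaling, the loop in `PixelCNNPPSamplerRunner.sample` is plain raster-scan autoregressive generation: `basic_sampler` returns a full image drawn from the model's per-pixel conditionals, and only the current pixel is kept at each step. A minimal sketch:

```python
import torch

def raster_scan_sample(basic_sampler, x):
    # x is the initial canvas (zeros, noise, or a noisy image to refine)
    with torch.no_grad():
        for i in range(x.shape[-2]):          # top to bottom
            for j in range(x.shape[-1]):      # left to right
                samples = basic_sampler(x)    # one full forward pass per pixel
                x[:, :, i, j] = samples[:, :, i, j]
    return x
```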
/runners/pixelcnnpp_smoothed_train_runner.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | import torch.optim as optim
4 | import torch.optim.lr_scheduler as lr_scheduler
5 | import torchvision.utils as utils
6 | from pixelcnnpp.pixelcnnpp import (PixelCNN, mix_logistic_loss, mix_logistic_loss_1d)
7 |
8 | from functools import partial
9 | from torch.utils.tensorboard import SummaryWriter
10 | import dataset
11 | import torch.nn.functional as F
12 |
13 | from models.ema import EMAHelper
14 | from pixelcnnpp.samplers import *
15 |
16 |
17 | class SmoothedPixelCNNPPTrainRunner(object):
18 | def __init__(self, args, config):
19 | self.args = args
20 | self.config = config
21 |
22 | def train(self):
23 | obs = (1, 28, 28) if 'MNIST' in self.config.dataset else (3, 32, 32)
24 | input_channels = obs[0]
25 | train_loader, test_loader = dataset.get_dataset(self.config)
26 |
27 | model = PixelCNN(self.config)
28 | model = model.to(self.config.device)
29 | model = torch.nn.DataParallel(model)
30 | sample_model = partial(model, sample=True)
31 |
32 | rescaling_inv = lambda x: .5 * x + .5
33 | rescaling = lambda x: (x - .5) * 2.
34 |
35 | if 'MNIST' in self.config.dataset:
36 | loss_op = lambda real, fake: mix_logistic_loss_1d(real, fake)
37 | clamp = False
38 | sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, sample_model, self.config.nr_logistic_mix, clamp=clamp)
39 |
40 | elif 'CIFAR10' in self.config.dataset:
41 | loss_op = lambda real, fake: mix_logistic_loss(real, fake)
42 | clamp = False
43 | sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix, clamp=clamp)
44 |
45 | elif 'celeba' in self.config.dataset:
46 | loss_op = lambda real, fake: mix_logistic_loss(real, fake)
47 | clamp = False
48 | sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix, clamp=clamp)
49 |
50 | else:
51 |             raise Exception('{} dataset not in {{mnist, cifar10, celeba}}'.format(self.config.dataset))
52 |
53 | if self.config.model.ema:
54 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
55 | ema_helper.register(model)
56 | else:
57 | ema_helper = None
58 |
59 | optimizer = optim.Adam(model.parameters(), lr=self.config.lr)
60 | scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.config.lr_decay)
61 |
62 | ckpt_path = os.path.join(self.args.log, 'pixelcnn_ckpts')
63 | if not os.path.exists(ckpt_path):
64 | os.makedirs(ckpt_path)
65 |
66 | if self.args.resume_training:
67 | state_dict = torch.load(os.path.join(ckpt_path, 'checkpoint.pth'), map_location=self.config.device)
68 | model.load_state_dict(state_dict[0])
69 | optimizer.load_state_dict(state_dict[1])
70 | scheduler.load_state_dict(state_dict[2])
71 | if len(state_dict) > 3:
72 | epoch = state_dict[3]
73 | if self.config.model.ema:
74 |                 ema_helper.load_state_dict(state_dict[4])
75 | print('model parameters loaded')
76 |
77 | tb_path = os.path.join(self.args.log, 'tensorboard')
78 | if os.path.exists(tb_path):
79 | shutil.rmtree(tb_path)
80 |
81 | os.makedirs(tb_path)
82 | tb_logger = SummaryWriter(log_dir=tb_path)
83 |
84 |
85 | def debug_sample(model, data):
86 | model.eval()
87 | data = data.cuda()
88 | with torch.no_grad():
89 | for i in range(obs[1]):
90 | for j in range(obs[2]):
91 | data_v = data
92 | out_sample = sample_op(data_v)
93 | data[:, :, i, j] = out_sample.data[:, :, i, j]
94 | return data
95 |
96 | print('starting training', flush=True)
97 | writes = 0
98 | for epoch in range(self.config.max_epochs):
99 | train_loss = 0.
100 | model.train()
101 | for batch_idx, (input, _) in enumerate(train_loader):
102 | input = input.cuda(non_blocking=True)
103 | # input: [-1, 1]
104 | ## add noise to the entire image
105 | input = input + torch.randn_like(input) * self.config.noise
106 | output = model(input)
107 | loss = loss_op(input, output)
108 |
109 | optimizer.zero_grad()
110 | loss.backward()
111 | optimizer.step()
112 | if self.config.model.ema:
113 | ema_helper.update(model)
114 |
115 | train_loss += loss.item()
116 | if (batch_idx + 1) % self.config.print_every == 0:
117 | deno = self.config.print_every * self.config.batch_size * np.prod(obs) * np.log(2.)
118 | train_loss = train_loss / deno
119 | print('epoch: {}, batch: {}, loss : {:.4f}'.format(epoch, batch_idx, train_loss), flush=True)
120 | tb_logger.add_scalar('loss', train_loss, global_step=writes)
121 | train_loss = 0.
122 | writes += 1
123 |
124 | # decrease learning rate
125 | scheduler.step()
126 |
127 | if self.config.model.ema:
128 | test_model = ema_helper.ema_copy(model)
129 | else:
130 | test_model = model
131 |
132 | test_model.eval()
133 | test_loss = 0.
134 | with torch.no_grad():
135 | for batch_idx, (input_var, _) in enumerate(test_loader):
136 | input_var = input_var.cuda(non_blocking=True)
137 |
138 | input_var = input_var + torch.randn_like(input_var) * self.config.noise
139 | output = test_model(input_var)
140 | loss = loss_op(input_var, output)
141 | test_loss += loss.item()
142 | del loss, output
143 |
144 | deno = batch_idx * self.config.batch_size * np.prod(obs) * np.log(2.)
145 | test_loss = test_loss / deno
146 | print('epoch: %s, test loss : %s' % (epoch, test_loss), flush=True)
147 | tb_logger.add_scalar('test_loss', test_loss, global_step=writes)
148 |
149 | if (epoch + 1) % self.config.save_interval == 0:
150 | state_dict = [
151 | model.state_dict(),
152 | optimizer.state_dict(),
153 | scheduler.state_dict(),
154 | epoch,
155 | ]
156 | if self.config.model.ema:
157 | state_dict.append(ema_helper.state_dict())
158 |
159 | if (epoch + 1) % (self.config.save_interval * 2) == 0:
160 | torch.save(state_dict, os.path.join(ckpt_path, f'ckpt_epoch_{epoch}.pth'))
161 | torch.save(state_dict, os.path.join(ckpt_path, 'checkpoint.pth'))
162 |
163 | if epoch % 10 == 0:
164 | print('sampling...', flush=True)
165 | sample_t = debug_sample(test_model, input_var[:25])
166 | sample_t = rescaling_inv(sample_t)
167 |
168 | if not os.path.exists(os.path.join(self.args.log, 'images')):
169 | os.makedirs(os.path.join(self.args.log, 'images'))
170 | utils.save_image(sample_t, os.path.join(self.args.log, 'images', f'sample_epoch_{epoch}.png'),
171 | nrow=5, padding=0)
172 |
173 | if self.config.model.ema:
174 | del test_model
175 |
176 | def test(self):
177 | raise NotImplementedError()
178 |
--------------------------------------------------------------------------------
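
The `deno` normalization above converts a summed negative log-likelihood (in nats) into the bits-per-dimension numbers that get printed and logged; as a standalone helper (hypothetical name):

```python
import numpy as np

def bits_per_dim(nll_sum, num_images, obs=(3, 32, 32)):
    # nats summed over num_images images -> bits per pixel-channel
    return nll_sum / (num_images * np.prod(obs) * np.log(2.))
```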
/pixelcnnpp/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils import weight_norm as wn
4 | import torch.nn.functional as F
5 |
6 |
7 | def concat_elu(x):
8 | """ like concatenated ReLU (http://arxiv.org/abs/1603.05201), but then with ELU """
9 | # Pytorch ordering
10 | axis = len(x.size()) - 3
11 | return F.elu(torch.cat([x, -x], dim=axis))
12 |
13 |
14 | def log_sum_exp(x):
15 | """ numerically stable log_sum_exp implementation that prevents overflow """
16 | # TF ordering
17 | axis = len(x.size()) - 1
18 | m, _ = torch.max(x, dim=axis)
19 | m2, _ = torch.max(x, dim=axis, keepdim=True)
20 | return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
21 |
22 |
23 | def down_shift(x, pad=None):
24 | # Pytorch ordering
25 | xs = [int(y) for y in x.size()]
26 | # when downshifting, the last row is removed
27 | x = x[:, :, :xs[2] - 1, :]
28 | # padding left, padding right, padding top, padding bottom
29 | pad = nn.ZeroPad2d((0, 0, 1, 0)) if pad is None else pad
30 | return pad(x)
31 |
32 |
33 | def right_shift(x, pad=None):
34 | # Pytorch ordering
35 | xs = [int(y) for y in x.size()]
36 |     # when right-shifting, the last column is removed
37 | x = x[:, :, :, :xs[3] - 1]
38 | # padding left, padding right, padding top, padding bottom
39 | pad = nn.ZeroPad2d((1, 0, 0, 0)) if pad is None else pad
40 | return pad(x)
41 |
42 |
43 | def log_prob_from_logits(x):
44 | """ numerically stable log_softmax implementation that prevents overflow """
45 | # TF ordering
46 | axis = len(x.size()) - 1
47 | m, _ = torch.max(x, dim=axis, keepdim=True)
48 | return x - m - torch.log(torch.sum(torch.exp(x - m), dim=axis, keepdim=True))
49 |
50 |
51 | def to_one_hot(tensor, n, fill_with=1.):
52 |     # we perform one-hot encoding with respect to the last axis
53 | one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
54 | if tensor.is_cuda: one_hot = one_hot.cuda()
55 | one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
56 | return one_hot
57 |
58 |
59 | class nin(nn.Module):
60 | def __init__(self, dim_in, dim_out):
61 | super(nin, self).__init__()
62 | self.lin_a = wn(nn.Linear(dim_in, dim_out))
63 | self.dim_out = dim_out
64 |
65 | def forward(self, x):
66 | og_x = x
67 | # assumes pytorch ordering
68 | """ a network in network layer (1x1 CONV) """
69 | # TODO : try with original ordering
70 | x = x.permute(0, 2, 3, 1)
71 | shp = [int(y) for y in x.size()]
72 | out = self.lin_a(x.contiguous().view(shp[0] * shp[1] * shp[2], shp[3]))
73 | shp[-1] = self.dim_out
74 | out = out.view(shp)
75 | return out.permute(0, 3, 1, 2)
76 |
77 |
78 | class down_shifted_conv2d(nn.Module):
79 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 3), stride=(1, 1),
80 | shift_output_down=False, norm='weight_norm'):
81 | super(down_shifted_conv2d, self).__init__()
82 |
83 | assert norm in [None, 'batch_norm', 'weight_norm']
84 | self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride)
85 | self.shift_output_down = shift_output_down
86 | self.norm = norm
87 | self.pad = nn.ZeroPad2d((int((filter_size[1] - 1) / 2), # pad left
88 | int((filter_size[1] - 1) / 2), # pad right
89 | filter_size[0] - 1, # pad top
90 | 0)) # pad down
91 |
92 | if norm == 'weight_norm':
93 | self.conv = wn(self.conv)
94 | elif norm == 'batch_norm':
95 | self.bn = nn.BatchNorm2d(num_filters_out)
96 |
97 | if shift_output_down:
98 | self.down_shift = lambda x: down_shift(x, pad=nn.ZeroPad2d((0, 0, 1, 0)))
99 |
100 | def forward(self, x):
101 | x = self.pad(x)
102 | x = self.conv(x)
103 | x = self.bn(x) if self.norm == 'batch_norm' else x
104 | return self.down_shift(x) if self.shift_output_down else x
105 |
106 |
107 | class down_shifted_deconv2d(nn.Module):
108 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 3), stride=(1, 1)):
109 | super(down_shifted_deconv2d, self).__init__()
110 | self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size, stride,
111 | output_padding=1))
112 | self.filter_size = filter_size
113 | self.stride = stride
114 |
115 | def forward(self, x):
116 | x = self.deconv(x)
117 | xs = [int(y) for y in x.size()]
118 | return x[:, :, :(xs[2] - self.filter_size[0] + 1),
119 | int((self.filter_size[1] - 1) / 2):(xs[3] - int((self.filter_size[1] - 1) / 2))]
120 |
121 |
122 | class down_right_shifted_conv2d(nn.Module):
123 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 2), stride=(1, 1),
124 | shift_output_right=False, norm='weight_norm'):
125 | super(down_right_shifted_conv2d, self).__init__()
126 |
127 | assert norm in [None, 'batch_norm', 'weight_norm']
128 | self.pad = nn.ZeroPad2d((filter_size[1] - 1, 0, filter_size[0] - 1, 0))
129 | self.conv = nn.Conv2d(num_filters_in, num_filters_out, filter_size, stride=stride)
130 | self.shift_output_right = shift_output_right
131 | self.norm = norm
132 |
133 | if norm == 'weight_norm':
134 | self.conv = wn(self.conv)
135 | elif norm == 'batch_norm':
136 | self.bn = nn.BatchNorm2d(num_filters_out)
137 |
138 | if shift_output_right:
139 | self.right_shift = lambda x: right_shift(x, pad=nn.ZeroPad2d((1, 0, 0, 0)))
140 |
141 | def forward(self, x):
142 | x = self.pad(x)
143 | x = self.conv(x)
144 | x = self.bn(x) if self.norm == 'batch_norm' else x
145 | return self.right_shift(x) if self.shift_output_right else x
146 |
147 |
148 | class down_right_shifted_deconv2d(nn.Module):
149 | def __init__(self, num_filters_in, num_filters_out, filter_size=(2, 2), stride=(1, 1),
150 | shift_output_right=False):
151 | super(down_right_shifted_deconv2d, self).__init__()
152 | self.deconv = wn(nn.ConvTranspose2d(num_filters_in, num_filters_out, filter_size,
153 | stride, output_padding=1))
154 | self.filter_size = filter_size
155 | self.stride = stride
156 |
157 | def forward(self, x):
158 | x = self.deconv(x)
159 | xs = [int(y) for y in x.size()]
160 | x = x[:, :, :(xs[2] - self.filter_size[0] + 1):, :(xs[3] - self.filter_size[1] + 1)]
161 | return x
162 |
163 |
164 | '''
165 | skip connection parameter : 0 = no skip connection
166 | 1 = skip connection where skip input size === input size
167 | 2 = skip connection where skip input size === 2 * input size
168 | '''
169 |
170 |
171 | class gated_resnet(nn.Module):
172 | def __init__(self, num_filters, conv_op, nonlinearity=concat_elu, skip_connection=0):
173 | super(gated_resnet, self).__init__()
174 | self.skip_connection = skip_connection
175 | self.nonlinearity = nonlinearity
176 | self.conv_input = conv_op(2 * num_filters, num_filters) # cuz of concat elu
177 |
178 | if skip_connection != 0:
179 | self.nin_skip = nin(2 * skip_connection * num_filters, num_filters)
180 |
181 | self.dropout = nn.Dropout2d(0.5)
182 | self.conv_out = conv_op(2 * num_filters, 2 * num_filters)
183 |
184 | def forward(self, og_x, a=None):
185 | x = self.conv_input(self.nonlinearity(og_x))
186 | if a is not None:
187 | x += self.nin_skip(self.nonlinearity(a))
188 | x = self.nonlinearity(x)
189 | x = self.dropout(x)
190 | x = self.conv_out(x)
191 | a, b = torch.chunk(x, 2, dim=1)
192 | c3 = a * torch.sigmoid(b)
193 | return og_x + c3
194 |
--------------------------------------------------------------------------------
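
`concat_elu` doubles the channel dimension, which is why `gated_resnet` builds `conv_input` with `2 * num_filters` input channels; a quick shape check (assumes the repo root is on `PYTHONPATH`):

```python
import torch
from pixelcnnpp.layers import concat_elu

x = torch.randn(1, 8, 4, 4)
print(concat_elu(x).shape)   # torch.Size([1, 16, 4, 4])
```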
/runners/pixelcnnpp_conditioned_train_runner.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | import torch.optim as optim
4 | import torch.optim.lr_scheduler as lr_scheduler
5 | import torchvision.utils as utils
6 | from pixelcnnpp.pixelcnnpp import (PixelCNN, mix_logistic_loss, mix_logistic_loss_1d)
7 |
8 | from functools import partial
9 | from torch.utils.tensorboard import SummaryWriter
10 | import dataset
11 | import torch.nn.functional as F
12 | from models.ema import EMAHelper
13 | from pixelcnnpp.samplers import *
14 |
15 |
16 | class ReversePixelCNNPPTrainRunner(object):
17 | def __init__(self, args, config):
18 | self.args = args
19 | self.config = config
20 |
21 | def train(self):
22 | obs = (1, 28, 28) if 'MNIST' in self.config.dataset else (3, 32, 32)
23 | input_channels = obs[0]
24 | train_loader, test_loader = dataset.get_dataset(self.config)
25 | model = PixelCNN(self.config)
26 | model = model.to(self.config.device)
27 | model = torch.nn.DataParallel(model)
28 | sample_model = partial(model, sample=True)
29 |
30 | rescaling_inv = lambda x: .5 * x + .5
31 | rescaling = lambda x: (x - .5) * 2.
32 |         sigmoid_transform = lambda s, lambd=1e-6: (torch.sigmoid(s) - lambd) / (1 - 2 * lambd)  # inverse of the logit transform, used when with_logit is True
33 | if 'MNIST' in self.config.dataset:
34 | loss_op = lambda real, fake: mix_logistic_loss_1d(real, fake)
35 | clamp = False
36 | sample_op = lambda x: sample_from_discretized_mix_logistic_1d(x, sample_model, self.config.nr_logistic_mix,
37 | clamp=clamp)
38 |
39 | elif 'CIFAR10' in self.config.dataset:
40 | loss_op = lambda real, fake: mix_logistic_loss(real, fake)
41 | clamp = False
42 | sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix,
43 | clamp=clamp)
44 |
45 | elif 'celeba' in self.config.dataset:
46 | loss_op = lambda real, fake: mix_logistic_loss(real, fake)
47 | clamp = False
48 | sample_op = lambda x: sample_from_discretized_mix_logistic(x, sample_model, self.config.nr_logistic_mix,
49 | clamp=clamp)
50 | else:
51 |             raise Exception('{} dataset not in {{mnist, cifar10, celeba}}'.format(self.config.dataset))
52 |
53 | if self.config.model.ema:
54 | ema_helper = EMAHelper(mu=self.config.model.ema_rate)
55 | ema_helper.register(model)
56 | else:
57 | ema_helper = None
58 |
59 | optimizer = optim.Adam(model.parameters(), lr=self.config.lr)
60 | scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=self.config.lr_decay)
61 |
62 | ckpt_path = os.path.join(self.args.log, 'pixelcnn_ckpts')
63 | if not os.path.exists(ckpt_path):
64 | os.makedirs(ckpt_path)
65 |
66 | if self.args.resume_training:
67 | state_dict = torch.load(os.path.join(ckpt_path, 'checkpoint.pth'), map_location=self.config.device)
68 | model.load_state_dict(state_dict[0])
69 | optimizer.load_state_dict(state_dict[1])
70 | scheduler.load_state_dict(state_dict[2])
71 | if len(state_dict) > 3:
72 | epoch = state_dict[3]
73 | if self.config.model.ema:
74 |                 ema_helper.load_state_dict(state_dict[4])
75 | print('model parameters loaded')
76 |
77 | tb_path = os.path.join(self.args.log, 'tensorboard')
78 | if os.path.exists(tb_path):
79 | shutil.rmtree(tb_path)
80 |
81 | os.makedirs(tb_path)
82 | tb_logger = SummaryWriter(log_dir=tb_path)
83 |
84 | def debug_sample(model, noisy_image):
85 | model.eval()
86 | with torch.no_grad():
87 | data = torch.cat([noisy_image, noisy_image], dim=2)
88 | for i in range(obs[1], obs[1] * 2, 1):
89 | for j in range(obs[2]):
90 | data_v = data
91 | out_sample = sample_op(data_v)
92 | data[:, :, i, j] = out_sample.data[:, :, i, j]
93 | return data
94 |
95 | print('starting training', flush=True)
96 | writes = 0
97 | for epoch in range(self.config.max_epochs):
98 | train_loss = 0.
99 | model.train()
100 | for batch_idx, (input, _) in enumerate(train_loader):
101 | input = input.cuda(non_blocking=True)
102 |
103 | # input: [-1, 1]
104 | ## add noise to the entire image
105 | noisy_input = input + torch.randn_like(input) * self.config.noise
106 | clean_input = input + torch.randn_like(input) * self.config.clean_noise # add very small noise
107 | input = torch.cat([noisy_input, clean_input], dim=2)
108 | output = model(input)[:, :, input.shape[-1]:, :]
109 | loss = loss_op(clean_input, output)
110 | optimizer.zero_grad()
111 | loss.backward()
112 | optimizer.step()
113 |
114 | if self.config.model.ema:
115 | ema_helper.update(model)
116 |
117 | train_loss += loss.item()
118 | if (batch_idx + 1) % self.config.print_every == 0:
119 | deno = self.config.print_every * self.config.batch_size * np.prod(obs) * np.log(2.)
120 | train_loss = train_loss / deno
121 | print('epoch: {}, batch: {}, loss : {:.4f}'.format(epoch, batch_idx, train_loss), flush=True)
122 | tb_logger.add_scalar('loss', train_loss, global_step=writes)
123 | train_loss = 0.
124 | writes += 1
125 | # decrease learning rate
126 | scheduler.step()
127 |
128 | if self.config.model.ema:
129 | test_model = ema_helper.ema_copy(model)
130 | else:
131 | test_model = model
132 |
133 | test_model.eval()
134 | test_loss = 0.
135 | with torch.no_grad():
136 | for batch_idx, (input_var, _) in enumerate(test_loader):
137 | input_var = input_var.cuda(non_blocking=True)
138 |
139 | noisy_input_var = input_var + torch.randn_like(input_var) * self.config.noise
140 | clean_input_var = input_var + torch.randn_like(input_var) * self.config.clean_noise #* 0.02
141 |
142 | input_var = torch.cat([noisy_input_var, clean_input_var], dim=2)
143 | output = test_model(input_var)[:, :, input_var.shape[-1]:, :]
144 | loss = loss_op(clean_input_var, output)
145 | test_loss += loss.item()
146 | del loss, output
147 |
148 |
149 | deno = batch_idx * self.config.batch_size * np.prod(obs) * np.log(2.)
150 | test_loss = test_loss / deno
151 | print('epoch: %s, test loss : %s' % (epoch, test_loss), flush=True)
152 | tb_logger.add_scalar('test_loss', test_loss, global_step=writes)
153 |
154 | if (epoch + 1) % self.config.save_interval == 0:
155 | state_dict = [
156 | model.state_dict(),
157 | optimizer.state_dict(),
158 | scheduler.state_dict(),
159 | epoch,
160 | ]
161 | if self.config.model.ema:
162 | state_dict.append(ema_helper.state_dict())
163 |
164 | if (epoch + 1) % (self.config.save_interval * 2) == 0:
165 | torch.save(state_dict, os.path.join(ckpt_path, f'ckpt_epoch_{epoch}.pth'))
166 | torch.save(state_dict, os.path.join(ckpt_path, 'checkpoint.pth'))
167 |
168 | if epoch % 10 == 0:
169 | print('sampling...', flush=True)
170 | sample_t = debug_sample(test_model, noisy_input_var[:25])
171 |                     sample_t = torch.cat([clean_input_var[:25], sample_t], dim=2)  # prepend the clean reference images
172 |                     if self.config.with_logit:
173 | sample_t = sigmoid_transform(sample_t)
174 | else:
175 | sample_t = rescaling_inv(sample_t)
176 |
177 | if not os.path.exists(os.path.join(self.args.log, 'images')):
178 | os.makedirs(os.path.join(self.args.log, 'images'))
179 | utils.save_image(sample_t, os.path.join(self.args.log, 'images', f'sample_epoch_{epoch}.png'),
180 | nrow=5, padding=0)
181 | if self.config.model.ema:
182 | del test_model
183 |
184 | def test(self):
185 | raise NotImplementedError()
186 |
--------------------------------------------------------------------------------
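
For reference, the training step above reduces to a few lines: Gaussian noise smooths
the input, a lightly perturbed copy serves as the target, and the two are stacked
along the height axis so the autoregressive model predicts the clean half conditioned
on the noisy half. The sketch below is illustrative only; `model`, `loss_op`, and the
noise levels stand in for the runner's config-driven values, and images are assumed
to be in [-1, 1].

import torch

def smoothed_training_step(model, loss_op, x, noise_std=0.3, clean_std=0.05):
    # x: batch of images in [-1, 1], shape (B, C, H, W)
    noisy = x + torch.randn_like(x) * noise_std    # heavily smoothed copy
    clean = x + torch.randn_like(x) * clean_std    # lightly perturbed target
    stacked = torch.cat([noisy, clean], dim=2)     # (B, C, 2H, W), noisy on top
    out = model(stacked)[:, :, x.shape[-2]:, :]    # keep the bottom (clean) half
    return loss_op(clean, out)
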
/pixelcnnpp/pixelcnnpp.py:
--------------------------------------------------------------------------------
1 | from pixelcnnpp.layers import *
2 | from torch.autograd import Variable
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import torch.nn.functional as F
7 |
8 |
9 | class PixelCNNLayer_up(nn.Module):
10 | def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity):
11 | super(PixelCNNLayer_up, self).__init__()
12 | self.nr_resnet = nr_resnet
13 | # stream from pixels above
14 | self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d,
15 | resnet_nonlinearity, skip_connection=0)
16 | for _ in range(nr_resnet)])
17 |
18 |         # stream from pixels above and to the left
19 | self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d,
20 | resnet_nonlinearity, skip_connection=1)
21 | for _ in range(nr_resnet)])
22 |
23 | def forward(self, u, ul):
24 | u_list, ul_list = [], []
25 |
26 | for i in range(self.nr_resnet):
27 | u = self.u_stream[i](u)
28 | ul = self.ul_stream[i](ul, a=u)
29 | u_list += [u]
30 | ul_list += [ul]
31 |
32 | return u_list, ul_list
33 |
34 |
35 | class PixelCNNLayer_down(nn.Module):
36 | def __init__(self, nr_resnet, nr_filters, resnet_nonlinearity):
37 | super(PixelCNNLayer_down, self).__init__()
38 | self.nr_resnet = nr_resnet
39 | # stream from pixels above
40 | self.u_stream = nn.ModuleList([gated_resnet(nr_filters, down_shifted_conv2d,
41 | resnet_nonlinearity, skip_connection=1)
42 | for _ in range(nr_resnet)])
43 |
44 |         # stream from pixels above and to the left
45 | self.ul_stream = nn.ModuleList([gated_resnet(nr_filters, down_right_shifted_conv2d,
46 | resnet_nonlinearity, skip_connection=2)
47 | for _ in range(nr_resnet)])
48 |
49 | def forward(self, u, ul, u_list, ul_list):
50 | for i in range(self.nr_resnet):
51 | u = self.u_stream[i](u, a=u_list.pop())
52 | ul = self.ul_stream[i](ul, a=torch.cat((u, ul_list.pop()), 1))
53 |
54 | return u, ul
55 |
56 |
57 | class PixelCNN(nn.Module):
58 | def __init__(self, config):
59 | super(PixelCNN, self).__init__()
60 | self.config = config
61 | nr_resnet = self.config.nr_resnet
62 | nr_filters = self.config.nr_filters
63 | input_channels = self.config.input_channels
64 | nr_logistic_mix = self.config.nr_logistic_mix
65 | resnet_nonlinearity = 'concat_elu'
66 |
67 | if resnet_nonlinearity == 'concat_elu':
68 |             self.resnet_nonlinearity = concat_elu
69 |         else:
70 |             raise ValueError("only 'concat_elu' is currently supported as the resnet nonlinearity")
71 |
72 | self.nr_filters = nr_filters
73 | self.input_channels = input_channels
74 | self.nr_logistic_mix = nr_logistic_mix
75 | self.right_shift_pad = nn.ZeroPad2d((1, 0, 0, 0))
76 | self.down_shift_pad = nn.ZeroPad2d((0, 0, 1, 0))
77 |
78 | down_nr_resnet = [nr_resnet] + [nr_resnet + 1] * 2
79 | self.down_layers = nn.ModuleList([PixelCNNLayer_down(down_nr_resnet[i], nr_filters,
80 | self.resnet_nonlinearity) for i in range(3)])
81 |
82 | self.up_layers = nn.ModuleList([PixelCNNLayer_up(nr_resnet, nr_filters,
83 | self.resnet_nonlinearity) for _ in range(3)])
84 |
85 | self.downsize_u_stream = nn.ModuleList([down_shifted_conv2d(nr_filters, nr_filters,
86 | stride=(2, 2)) for _ in range(2)])
87 |
88 | self.downsize_ul_stream = nn.ModuleList([down_right_shifted_conv2d(nr_filters,
89 | nr_filters, stride=(2, 2)) for _ in
90 | range(2)])
91 |
92 | self.upsize_u_stream = nn.ModuleList([down_shifted_deconv2d(nr_filters, nr_filters,
93 | stride=(2, 2)) for _ in range(2)])
94 |
95 | self.upsize_ul_stream = nn.ModuleList([down_right_shifted_deconv2d(nr_filters,
96 | nr_filters, stride=(2, 2)) for _ in
97 | range(2)])
98 |
99 | self.u_init = down_shifted_conv2d(input_channels + 1, nr_filters, filter_size=(2, 3),
100 | shift_output_down=True)
101 |
102 | self.ul_init = nn.ModuleList([down_shifted_conv2d(input_channels + 1, nr_filters,
103 | filter_size=(1, 3), shift_output_down=True),
104 | down_right_shifted_conv2d(input_channels + 1, nr_filters,
105 | filter_size=(2, 1), shift_output_right=True)])
106 |
107 | num_mix = 3 if self.input_channels == 1 else 10
108 | self.nin_out = nin(nr_filters, num_mix * nr_logistic_mix)
109 | self.init_padding = None
110 |
111 | def forward(self, x, sample=False):
112 |         # similar to what is done in the TF repo:
113 | if self.init_padding is None and not sample:
114 | xs = [int(y) for y in x.size()]
115 | padding = Variable(torch.ones(xs[0], 1, xs[2], xs[3]), requires_grad=False)
116 | self.init_padding = padding.cuda() if x.is_cuda else padding
117 |
118 | if sample:
119 | xs = [int(y) for y in x.size()]
120 | padding = Variable(torch.ones(xs[0], 1, xs[2], xs[3]), requires_grad=False)
121 | padding = padding.cuda() if x.is_cuda else padding
122 | x = torch.cat((x, padding), 1)
123 |
124 | ### UP PASS ###
125 | x = x if sample else torch.cat((x, self.init_padding), 1)
126 | u_list = [self.u_init(x)]
127 | ul_list = [self.ul_init[0](x) + self.ul_init[1](x)]
128 | for i in range(3):
129 | # resnet block
130 | u_out, ul_out = self.up_layers[i](u_list[-1], ul_list[-1])
131 | u_list += u_out
132 | ul_list += ul_out
133 |
134 | if i != 2:
135 | # downscale (only twice)
136 | u_list += [self.downsize_u_stream[i](u_list[-1])]
137 | ul_list += [self.downsize_ul_stream[i](ul_list[-1])]
138 |
139 | ### DOWN PASS ###
140 | u = u_list.pop()
141 | ul = ul_list.pop()
142 |
143 | for i in range(3):
144 | # resnet block
145 | u, ul = self.down_layers[i](u, ul, u_list, ul_list)
146 |
147 | # upscale (only twice)
148 | if i != 2:
149 | u = self.upsize_u_stream[i](u)
150 | ul = self.upsize_ul_stream[i](ul)
151 |
152 | x_out = self.nin_out(F.elu(ul))
153 |
154 |         assert len(u_list) == len(ul_list) == 0, 'up-pass activations were not fully consumed'
155 |
156 | return x_out
157 |
158 |
159 |
160 |
161 |
162 | def mix_logistic_loss(x, l, likelihood=False):
163 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
164 | # Pytorch ordering
165 | x = x.permute(0, 2, 3, 1)
166 | l = l.permute(0, 2, 3, 1)
167 | xs = [int(y) for y in x.size()]
168 | ls = [int(y) for y in l.size()]
169 |
170 | # here and below: unpacking the params of the mixture of logistics
171 | nr_mix = int(ls[-1] / 10)
172 | logit_probs = l[:, :, :, :nr_mix]
173 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3]) # 3 for mean, scale, coef
174 | means = l[:, :, :, :, :nr_mix]
175 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
176 |
177 | coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
178 | # here and below: getting the means and adjusting them based on preceding
179 | # sub-pixels
180 | x = x.contiguous()
181 |     x = x.unsqueeze(-1) + (torch.zeros(xs + [nr_mix]).cuda()).detach()  # broadcast x over the nr_mix components (CUDA assumed)
182 | m2 = (means[:, :, :, 1, :] + coeffs[:, :, :, 0, :]
183 | * x[:, :, :, 0, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)
184 |
185 | m3 = (means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] +
186 | coeffs[:, :, :, 2, :] * x[:, :, :, 1, :]).view(xs[0], xs[1], xs[2], 1, nr_mix)
187 |
188 | means = torch.cat((means[:, :, :, 0, :].unsqueeze(3), m2, m3), dim=3)
189 | centered_x = x - means
190 | inv_stdv = torch.exp(-log_scales)
191 | mid_in = inv_stdv * centered_x
192 | log_probs = mid_in - log_scales - 2. * F.softplus(mid_in)
193 |
194 | if likelihood:
195 | log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
196 | return log_sum_exp(log_probs)
197 |
198 | log_probs = torch.sum(log_probs, dim=3) + log_prob_from_logits(logit_probs)
199 |
200 | return -torch.sum(log_sum_exp(log_probs))
201 |
202 |
203 | def mix_logistic_loss_1d(x, l, likelihood=False):
204 | """ log-likelihood for mixture of discretized logistics, assumes the data has been rescaled to [-1,1] interval """
205 | # Pytorch ordering
206 | x = x.permute(0, 2, 3, 1)
207 | l = l.permute(0, 2, 3, 1)
208 | xs = [int(y) for y in x.size()]
209 | ls = [int(y) for y in l.size()]
210 |
211 | # here and below: unpacking the params of the mixture of logistics
212 | nr_mix = int(ls[-1] / 3)
213 | logit_probs = l[:, :, :, :nr_mix]
214 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2]) # 2 for mean, scale
215 | means = l[:, :, :, :, :nr_mix]
216 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
217 | # here and below: getting the means and adjusting them based on preceding
218 | # sub-pixels
219 | x = x.contiguous()
220 | x = x.unsqueeze(-1) + (torch.zeros(xs + [nr_mix]).cuda()).detach()
221 |
222 | centered_x = x - means
223 | inv_stdv = torch.exp(-log_scales)
224 | mid_in = inv_stdv * centered_x
225 | log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
226 |
227 | if likelihood:
228 | log_probs = torch.sum(log_pdf_mid, dim=3) + log_prob_from_logits(logit_probs)
229 | return log_sum_exp(log_probs)
230 |
231 | log_probs = torch.sum(log_pdf_mid, dim=3) + log_prob_from_logits(logit_probs)
232 |
233 | return -torch.sum(log_sum_exp(log_probs))
234 |
235 |
236 | def load_part_of_model(model, path, back_compat=False):
237 | params = torch.load(path)
238 | added = 0
239 | for name, param in params.items():
240 | if back_compat:
241 | name = '.'.join(name.split('.')[1:])
242 | if name in model.state_dict().keys():
243 | try:
244 | model.state_dict()[name].copy_(param)
245 | added += 1
246 | except Exception as e:
247 | print(e)
248 |
249 |     print('copied {:.1%} of model parameters'.format(added / float(len(model.state_dict().keys()))))
250 |
--------------------------------------------------------------------------------
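
Despite the "discretized" naming carried over from PixelCNN++, the losses above use
the continuous logistic density: for a standardized value z = (x - mean) / scale, the
log-density is z - log(scale) - 2*softplus(z). The snippet below re-derives the
per-pixel mixture log-density that mix_logistic_loss_1d sums over; the function and
argument names are illustrative, not part of the repo.

import torch
import torch.nn.functional as F

def logistic_mixture_logpdf(x, logit_probs, means, log_scales):
    # x: (..., 1) values in [-1, 1]; logit_probs/means/log_scales: (..., nr_mix)
    log_scales = torch.clamp(log_scales, min=-7.)   # same floor as the loss above
    z = (x - means) * torch.exp(-log_scales)        # standardized input
    log_pdf = z - log_scales - 2. * F.softplus(z)   # log-density of each logistic
    # mix the components in log space
    return torch.logsumexp(torch.log_softmax(logit_probs, dim=-1) + log_pdf, dim=-1)
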
/pixelcnnpp/samplers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from itertools import product
4 | import tqdm
5 | import torch.autograd as autograd
6 | from torch.nn import functional as F
7 | from pixelcnnpp.layers import to_one_hot
8 |
9 |
10 | def sample_from_discretized_mix_logistic_inverse_CDF(x, model, nr_mix, noise=(), u=None, clamp=True, bisection_iter=15, T=1):
11 | # Pytorch ordering
12 | l = model(x)
13 | l = l.permute(0, 2, 3, 1)
14 | ls = [int(y) for y in l.size()]
15 | xs = ls[:-1] + [3]
16 |
17 |     # if conditioning noise is supplied, match the channels-last layout
18 |     if len(noise) != 0:
19 |         noise = noise.permute(0, 2, 3, 1)
20 |
21 |
22 | # unpack parameters
23 | logit_probs = l[:, :, :, :nr_mix] / T
24 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
25 | # sample mixture indicator from softmax
26 | if u is None:
27 | u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2] * 3)
28 | u.uniform_(1e-5, 1. - 1e-5)
29 | u = torch.log(u) - torch.log(1. - u)
30 |
31 | u_r, u_g, u_b = torch.chunk(u, chunks=3, dim=-1)
32 |
33 | u_r = u_r.reshape(ls[:-1])
34 | u_g = u_g.reshape(ls[:-1])
35 | u_b = u_b.reshape(ls[:-1])
36 |
37 | log_softmax = torch.log_softmax(logit_probs, dim=-1)
38 | coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix: 3 * nr_mix])
39 | means = l[:, :, :, :, :nr_mix]
40 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.) + np.log(T)
41 | if clamp:
42 | ubs = l.new_ones(ls[:-1])
43 | lbs = -ubs
44 | else:
45 | ubs = l.new_ones(ls[:-1]) * 20.
46 | lbs = -ubs
47 |
48 | means_r = means[..., 0, :]
49 | log_scales_r = log_scales[..., 0, :]
50 |
51 | def log_cdf_pdf_r(values, mode='cdf', mixtures=False):
52 | values = values.unsqueeze(-1)
53 | centered_values = (values - means_r) / log_scales_r.exp()
54 |
55 | if mode == 'cdf':
56 | log_logistic_cdf = -F.softplus(-centered_values)
57 | log_logistic_sf = -F.softplus(centered_values)
58 | log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
59 | log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
60 | logit = log_cdf - log_sf
61 |
62 | return logit if not mixtures else (logit, log_logistic_cdf)
63 |
64 | elif mode == 'pdf':
65 | log_logistic_pdf = -centered_values - log_scales_r - 2. * F.softplus(-centered_values)
66 | log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
67 |
68 | return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
69 |
70 | x0 = binary_search(u_r, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_r(x, mode='cdf'), bisection_iter)
71 |
72 | if len(noise) == 0:
73 | means_g = x0.unsqueeze(-1) * coeffs[:, :, :, 0, :] + means[..., 1, :]
74 | else:
75 | means_g = (x0.unsqueeze(-1) + noise[:, :, :, 0].unsqueeze(-1)) * coeffs[:, :, :, 0, :] + means[..., 1, :]
76 |
77 |     means_g = means_g.detach()  # detach so the autograd-based sampler differentiates only the current channel
78 | log_scales_g = log_scales[..., 1, :]
79 |
80 | log_p_r, log_p_r_mixtures = log_cdf_pdf_r(x0, mode='pdf', mixtures=True)
81 |
82 | def log_cdf_pdf_g(values, mode='cdf', mixtures=False):
83 | values = values.unsqueeze(-1)
84 | centered_values = (values - means_g) / log_scales_g.exp()
85 |
86 | if mode == 'cdf':
87 | log_logistic_cdf = log_p_r_mixtures - log_p_r[..., None] - F.softplus(-centered_values)
88 | log_logistic_sf = log_p_r_mixtures - log_p_r[..., None] - F.softplus(centered_values)
89 | log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
90 | log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
91 | logit = log_cdf - log_sf
92 |
93 | return logit if not mixtures else (logit, log_logistic_cdf)
94 |
95 | elif mode == 'pdf':
96 | log_logistic_pdf = log_p_r_mixtures - log_p_r[..., None] - centered_values - log_scales_g - 2. * F.softplus(
97 | -centered_values)
98 | log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
99 |
100 | return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
101 |
102 | x1 = binary_search(u_g, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_g(x, mode='cdf'), bisection_iter)
103 |
104 | if len(noise) == 0:
105 | means_b = x1.unsqueeze(-1) * coeffs[:, :, :, 2, :] + x0.unsqueeze(-1) * coeffs[:, :, :, 1, :] + means[..., 2, :]
106 | else:
107 | means_b = (x1.unsqueeze(-1) + noise[:, :, :, 1].unsqueeze(-1)) * coeffs[:, :, :, 2, :] + \
108 | (x0.unsqueeze(-1) + noise[:, :, :, 0].unsqueeze(-1)) * coeffs[:, :, :, 1, :] + means[..., 2, :]
109 |
110 |     means_b = means_b.detach()  # detach, as for the green channel above
111 | log_scales_b = log_scales[..., 2, :]
112 |
113 | log_p_g, log_p_g_mixtures = log_cdf_pdf_g(x1, mode='pdf', mixtures=True)
114 |
115 | def log_cdf_pdf_b(values, mode='cdf', mixtures=False):
116 | values = values.unsqueeze(-1)
117 | centered_values = (values - means_b) / log_scales_b.exp()
118 |
119 | if mode == 'cdf':
120 | log_logistic_cdf = log_p_g_mixtures - log_p_g[..., None] - F.softplus(-centered_values)
121 | log_logistic_sf = log_p_g_mixtures - log_p_g[..., None] - F.softplus(centered_values)
122 | log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
123 | log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
124 | logit = log_cdf - log_sf
125 |
126 | return logit if not mixtures else (logit, log_logistic_cdf)
127 |
128 | elif mode == 'pdf':
129 | log_logistic_pdf = log_p_g_mixtures - log_p_g[..., None] - centered_values - log_scales_b - 2. * F.softplus(
130 | -centered_values)
131 | log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
132 |
133 | return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
134 |
135 | x2 = binary_search(u_b, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_b(x, mode='cdf'), bisection_iter)
136 |
137 | out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3)
138 | # put back in Pytorch ordering
139 | out = out.permute(0, 3, 1, 2)
140 | return out
141 |
142 |
143 | def sample_from_discretized_mix_logistic_inverse_CDF_1d(x, model, nr_mix, u=None, clamp=True, bisection_iter=15):
144 | # Pytorch ordering
145 |
146 |     l = model(x)
147 |
148 |
149 |
150 |
151 | l = l.permute(0, 2, 3, 1)
152 | ls = [int(y) for y in l.size()]
153 | xs = ls[:-1] + [1]
154 |
155 | # unpack parameters
156 | logit_probs = l[:, :, :, :nr_mix]
157 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2])
158 | # sample mixture indicator from softmax
159 | if u is None:
160 | u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2])
161 | u.uniform_(1e-5, 1. - 1e-5)
162 | u = torch.log(u) - torch.log(1. - u)
163 |
164 | u_r = u.reshape(ls[:-1])
165 |
166 | log_softmax = torch.log_softmax(logit_probs, dim=-1)
167 | means = l[:, :, :, :, :nr_mix]
168 | log_scales = torch.clamp(l[:, :, :, :, nr_mix:2 * nr_mix], min=-7.)
169 |     if clamp:
170 |         ubs = l.new_ones(ls[:-1])
171 | lbs = -ubs
172 | else:
173 | ubs = l.new_ones(ls[:-1]) * 30.
174 | lbs = -ubs
175 |
176 | means_r = means[..., 0, :]
177 | log_scales_r = log_scales[..., 0, :]
178 |
179 | def log_cdf_pdf_r(values, mode='cdf', mixtures=False):
180 | values = values.unsqueeze(-1)
181 | centered_values = (values - means_r) / log_scales_r.exp()
182 |
183 | if mode == 'cdf':
184 | log_logistic_cdf = -F.softplus(-centered_values)
185 | log_logistic_sf = -F.softplus(centered_values)
186 | log_cdf = torch.logsumexp(log_softmax + log_logistic_cdf, dim=-1)
187 | log_sf = torch.logsumexp(log_softmax + log_logistic_sf, dim=-1)
188 | logit = log_cdf - log_sf
189 |
190 | return logit if not mixtures else (logit, log_logistic_cdf)
191 |
192 | elif mode == 'pdf':
193 | log_logistic_pdf = -centered_values - log_scales_r - 2. * F.softplus(-centered_values)
194 | log_pdf = torch.logsumexp(log_softmax + log_logistic_pdf, dim=-1)
195 |
196 | return log_pdf if not mixtures else (log_pdf, log_logistic_pdf)
197 |
198 | x0 = binary_search(u_r, lbs.clone(), ubs.clone(), lambda x: log_cdf_pdf_r(x, mode='cdf'), bisection_iter)
199 |
200 | out = x0.view(xs[:-1] + [1])
201 | out = out.permute(0, 3, 1, 2)
202 | return out
203 |
204 |
205 | def sample_from_discretized_mix_logistic_1d(x, model, nr_mix, u=None, clamp=True):
206 | # Pytorch ordering
207 | l = model(x)
208 | l = l.permute(0, 2, 3, 1)
209 | ls = [int(y) for y in l.size()]
210 |     xs = ls[:-1] + [1]  # single channel
211 |
212 | # unpack parameters
213 | logit_probs = l[:, :, :, :nr_mix]
214 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 2]) # for mean, scale
215 |
216 | # sample mixture indicator from softmax
217 | if u is None:
218 | u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2] * (nr_mix + 1))
219 | u.uniform_(1e-5, 1. - 1e-5)
220 | mixture_u, sample_u = torch.split(u, [l.shape[1] * l.shape[2] * nr_mix,
221 | l.shape[1] * l.shape[2] * 1], dim=-1)
222 | mixture_u = mixture_u.reshape(l.shape[0], l.shape[1], l.shape[2], nr_mix)
223 | sample_u = sample_u.reshape(l.shape[0], l.shape[1], l.shape[2], 1)
224 |
225 |     mixture_u = logit_probs.data - torch.log(- torch.log(mixture_u))  # Gumbel-max trick
226 | _, argmax = mixture_u.max(dim=3)
227 |
228 | one_hot = to_one_hot(argmax, nr_mix)
229 | sel = one_hot.view(xs[:-1] + [1, nr_mix])
230 | # select logistic parameters
231 | means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
232 | log_scales = torch.clamp(torch.sum(
233 | l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.)
234 |
235 | x = means + torch.exp(log_scales) * (torch.log(sample_u) - torch.log(1. - sample_u))
236 | if clamp:
237 |         x0 = torch.clamp(x[:, :, :, 0], min=-1., max=1.)
238 | else:
239 | x0 = x[:, :, :, 0]
240 |
241 | out = x0.unsqueeze(1)
242 | return out
243 |
244 |
245 | def sample_from_discretized_mix_logistic(x, model, nr_mix, u=None, T=1, clamp=True):
246 | # Pytorch ordering
247 | l = model(x)
248 | l = l.permute(0, 2, 3, 1)
249 | ls = [int(y) for y in l.size()]
250 | xs = ls[:-1] + [3]
251 |
252 | # unpack parameters
253 | logit_probs = l[:, :, :, :nr_mix] / T
254 | l = l[:, :, :, nr_mix:].contiguous().view(xs + [nr_mix * 3])
255 | # sample mixture indicator from softmax
256 | if u is None:
257 | u = l.new_empty(l.shape[0], l.shape[1] * l.shape[2] * (nr_mix + 3))
258 | u.uniform_(1e-5, 1. - 1e-5)
259 |
260 | mixture_u, sample_u = torch.split(u, [l.shape[1] * l.shape[2] * nr_mix,
261 | l.shape[1] * l.shape[2] * 3], dim=-1)
262 | mixture_u = mixture_u.reshape(l.shape[0], l.shape[1], l.shape[2], nr_mix)
263 | sample_u = sample_u.reshape(l.shape[0], l.shape[1], l.shape[2], 3)
264 |
265 | mixture_u = logit_probs.data - torch.log(- torch.log(mixture_u))
266 | _, argmax = mixture_u.max(dim=3)
267 |
268 | one_hot = to_one_hot(argmax, nr_mix)
269 | sel = one_hot.view(xs[:-1] + [1, nr_mix])
270 | # select logistic parameters
271 | means = torch.sum(l[:, :, :, :, :nr_mix] * sel, dim=4)
272 | log_scales = torch.clamp(torch.sum(
273 | l[:, :, :, :, nr_mix:2 * nr_mix] * sel, dim=4), min=-7.) + np.log(T)
274 | coeffs = torch.sum(torch.tanh(
275 | l[:, :, :, :, 2 * nr_mix:3 * nr_mix]) * sel, dim=4)
276 | # sample from logistic & clip to interval
277 | # we don't actually round to the nearest 8bit value when sampling
278 |
279 | x = means + torch.exp(log_scales) * (torch.log(sample_u) - torch.log(1. - sample_u))
280 | if clamp:
281 |         x0 = torch.clamp(x[:, :, :, 0], min=-1., max=1.)
282 | else:
283 | x0 = x[:, :, :, 0]
284 |
285 | if clamp:
286 |         x1 = torch.clamp(
287 |             x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, min=-1., max=1.)
288 | else:
289 | x1 = x[:, :, :, 1] + coeffs[:, :, :, 0] * x0
290 |
291 | if clamp:
292 |         x2 = torch.clamp(
293 |             x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, min=-1., max=1.)
294 | else:
295 | x2 = x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1
296 |
297 | out = torch.cat([x0.view(xs[:-1] + [1]), x1.view(xs[:-1] + [1]), x2.view(xs[:-1] + [1])], dim=3)
298 | # put back in Pytorch ordering
299 | out = out.permute(0, 3, 1, 2)
300 | return out
301 |
302 |
303 | def binary_search(log_cdf, lb, ub, cdf_fun, n_iter=15):  # bisect for the input whose cdf_fun value equals log_cdf
304 |     with torch.no_grad():
305 |         for _ in range(n_iter):
306 |             mid = (lb + ub) / 2.
307 |             mid_cdf_value = cdf_fun(mid)
308 |             right_idxes = mid_cdf_value < log_cdf  # target lies in the upper half
309 | left_idxes = ~right_idxes
310 | lb[right_idxes] = torch.min(mid[right_idxes], ub[right_idxes])
311 | ub[left_idxes] = torch.max(mid[left_idxes], lb[left_idxes])
312 |
313 | return mid
--------------------------------------------------------------------------------
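
The inverse-CDF samplers above all reduce to the same primitive: draw u ~ Uniform,
map it to logit space, and bisect the monotone logit-CDF with binary_search. The
self-contained toy below exercises that primitive on a hand-picked 1-D mixture of
logistics; every parameter here is made up for illustration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logit_probs = torch.tensor([0.0, 1.0])   # two mixture components
means = torch.tensor([-0.5, 0.5])
log_scales = torch.tensor([-2.0, -2.0])

def cdf_logit(v):
    # logit of the mixture CDF at v, matching the samplers' 'cdf' mode
    z = (v.unsqueeze(-1) - means) / log_scales.exp()
    log_w = torch.log_softmax(logit_probs, dim=-1)
    log_cdf = torch.logsumexp(log_w - F.softplus(-z), dim=-1)  # log F(v)
    log_sf = torch.logsumexp(log_w - F.softplus(z), dim=-1)    # log (1 - F(v))
    return log_cdf - log_sf

u = torch.rand(10000).clamp(1e-5, 1. - 1e-5)
target = torch.log(u) - torch.log(1. - u)    # logit of the uniform draw
lb = torch.full_like(target, -20.)
ub = torch.full_like(target, 20.)
for _ in range(30):                          # bisection, as in binary_search
    mid = (lb + ub) / 2.
    go_right = cdf_logit(mid) < target
    lb = torch.where(go_right, mid, lb)
    ub = torch.where(go_right, ub, mid)
samples = (lb + ub) / 2.
print(float(samples.mean()), float(samples.std()))  # should match the mixture's moments
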