├── .gitattributes ├── LICENSE ├── README.md ├── img ├── misgan-impute.png └── misgan.png ├── misgan.ipynb ├── requirements.txt ├── src-torch1.6 ├── celeba_critic.py ├── celeba_fid.py ├── celeba_generator.py ├── celeba_misgan.py ├── celeba_misgan_impute.py ├── fcnet.py ├── fid.py ├── imputer.py ├── inception.py ├── masked_celeba.py ├── masked_mnist.py ├── misgan.ipynb ├── misgan.py ├── misgan_impute.py ├── mnist_critic.py ├── mnist_fid.py ├── mnist_generator.py ├── mnist_imputer.py ├── mnist_misgan.py ├── mnist_misgan_impute.py ├── mnist_model.py ├── plot.py ├── requirements.txt ├── unet.py └── utils.py └── src ├── celeba_critic.py ├── celeba_fid.py ├── celeba_generator.py ├── celeba_misgan.py ├── celeba_misgan_impute.py ├── fcnet.py ├── fid.py ├── imputer.py ├── inception.py ├── masked_celeba.py ├── masked_mnist.py ├── misgan.py ├── misgan_impute.py ├── mnist_critic.py ├── mnist_fid.py ├── mnist_generator.py ├── mnist_imputer.py ├── mnist_misgan.py ├── mnist_misgan_impute.py ├── mnist_model.py ├── plot.py ├── unet.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | misgan.ipynb linguist-language=Python 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Steven Cheng-Xian Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or 
substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MisGAN: Learning from Incomplete Data with GANs 2 | 3 | This repository provides a PyTorch implementation of 4 | [MisGAN](https://arxiv.org/abs/1902.09599), 5 | a GAN-based framework for learning from incomplete data. 6 | 7 | **Note:** Please check out our 8 | [follow-up work](https://github.com/steveli/partial-encoder-decoder) 9 | on models that can be trained faster and more stably. 10 | 11 | 12 | ## Requirements 13 | 14 | The code requires Python 3.6 or later. 15 | The file [requirements.txt](requirements.txt) contains the full list of 16 | required Python modules. 17 | 18 | 19 | ## Jupyter notebook 20 | 21 | We provide a [notebook](misgan.ipynb) that includes an overview of MisGAN 22 | as well as the annotated implementation that runs on MNIST. 23 | The notebook can be viewed from 24 | [here](https://nbviewer.jupyter.org/github/steveli/misgan/blob/master/misgan.ipynb). 25 | 26 | 27 | ## Usage 28 | 29 | The source code can be found in the `src` directory. 30 | Separate scripts are provided to run MisGAN on MNIST and CelebA datasets. 
31 | 32 | For CelebA, you will need to download the dataset from its 33 | [website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html): 34 | 35 | * Download the file `img_align_celeba.zip` (available from [this link](https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM)). 36 | * Extract the zip file into the directory `src/celeba-data` that you create. 37 | 38 | The commands below need to be run under the `src` directory. 39 | 40 | MisGAN on MNIST: 41 | ```bash 42 | python mnist_misgan.py 43 | ``` 44 | 45 | MisGAN imputation on MNIST: 46 | ```bash 47 | python mnist_misgan_impute.py 48 | ``` 49 | 50 | MisGAN on CelebA: 51 | ```bash 52 | python celeba_misgan.py 53 | ``` 54 | 55 | MisGAN imputation on CelebA: 56 | ```bash 57 | python celeba_misgan_impute.py 58 | ``` 59 | 60 | Use `-h` to see all available command-line arguments for each script. 61 | 62 | 63 | ## References 64 | 65 | Steven Cheng-Xian Li, Bo Jiang, Benjamin Marlin. 66 | "MisGAN: Learning from Incomplete Data with Generative Adversarial Networks." 67 | ICLR 2019. 68 | \[[arXiv](https://arxiv.org/abs/1902.09599)\] 69 | 70 | 71 | ## Contact 72 | 73 | Your feedback would be greatly appreciated! 74 | Reach us at . 
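When looking for the output of a CelebA run, it can help to know how the run directory is named. The following is a minimal sketch that mirrors the naming logic in `src-torch1.6/celeba_misgan.py`, where results go under `results/celeba/<name>`. The helper names `mask_tag` and `run_dir_name` and the `now` parameter are hypothetical and exist only for this illustration; the scripts build the name inline inside `main()`.

```python
from datetime import datetime


def mask_tag(mask, block_len=32, obs_prob=0.2, obs_prob_high=None):
    """Return the mask part of the run name (hypothetical helper)."""
    if mask == 'indep':
        if obs_prob_high is None:
            return f'indep_{obs_prob:g}'
        return f'indep_{obs_prob:g}_{obs_prob_high:g}'
    if mask == 'block':
        # --block-len 0 selects variable-size blocks
        return 'block_{}'.format(block_len if block_len else 'varsize')
    raise NotImplementedError(mask)


def run_dir_name(prefix='misgan', tau=0.5, alpha=0.1, maskgen='fusion',
                 mask='block', now=None, **mask_options):
    """Assemble '<prefix>_<timestamp>_<options>' (hypothetical helper)."""
    now = now or datetime.now()
    return '{}_{}_{}'.format(
        prefix, now.strftime('%m%d.%H%M%S'),
        '_'.join([
            f'tau_{tau:g}',
            f'alpha_{alpha:g}',
            f'maskgen_{maskgen}',
            mask_tag(mask, **mask_options),
        ]))


# The defaults correspond to the default arguments of celeba_misgan.py.
print(run_dir_name(now=datetime(2019, 5, 6, 7, 8, 9)))
```

With the default options, a run started on May 6 at 07:08:09 would therefore be stored in a directory whose name starts with `misgan_0506.070809_tau_0.5`. The imputation script uses the same pattern with a `coef_<alpha>_<beta>_<gamma>` component in place of `alpha_<alpha>`.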
75 | -------------------------------------------------------------------------------- /img/misgan-impute.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steveli/misgan/f30dd73ebe602c81b1a0cfb72708c41687fb13d1/img/misgan-impute.png -------------------------------------------------------------------------------- /img/misgan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steveli/misgan/f30dd73ebe602c81b1a0cfb72708c41687fb13d1/img/misgan.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib>=2.2.2 2 | numpy>=1.14.5 3 | Pillow>=5.1.0 4 | scipy>=1.1.0 5 | seaborn>=0.8.1 6 | torch==1.1.0 7 | torchvision==0.3.0 8 | -------------------------------------------------------------------------------- /src-torch1.6/celeba_critic.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def conv_ln_lrelu(in_dim, out_dim): 5 | return nn.Sequential( 6 | nn.Conv2d(in_dim, out_dim, 5, 2, 2), 7 | nn.InstanceNorm2d(out_dim, affine=True), 8 | nn.LeakyReLU(0.2)) 9 | 10 | 11 | class ConvCritic(nn.Module): 12 | def __init__(self, n_channels): 13 | super().__init__() 14 | dim = 64 15 | self.ls = nn.Sequential( 16 | nn.Conv2d(n_channels, dim, 5, 2, 2), nn.LeakyReLU(0.2), 17 | conv_ln_lrelu(dim, dim * 2), 18 | conv_ln_lrelu(dim * 2, dim * 4), 19 | conv_ln_lrelu(dim * 4, dim * 8), 20 | nn.Conv2d(dim * 8, 1, 4)) 21 | 22 | def forward(self, input): 23 | net = self.ls(input) 24 | return net.view(-1) 25 | -------------------------------------------------------------------------------- /src-torch1.6/celeba_fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 
import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import datasets, transforms 7 | from PIL import Image 8 | from celeba_generator import ConvDataGenerator 9 | from fid import BaseSampler, BaseImputationSampler 10 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA 11 | from imputer import UNetImputer 12 | from fid import FID 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('root_dir') 17 | parser.add_argument('--batch-size', type=int, default=256) 18 | parser.add_argument('--workers', type=int, default=0) 19 | parser.add_argument('--skip-exist', action='store_true') 20 | args = parser.parse_args() 21 | 22 | 23 | use_cuda = torch.cuda.is_available() 24 | device = torch.device('cuda' if use_cuda else 'cpu') 25 | 26 | 27 | class CelebAFID(FID): 28 | def __init__(self, batch_size=256, data_name='celeba', 29 | workers=0, verbose=True): 30 | self.batch_size = batch_size 31 | self.workers = workers 32 | super().__init__(data_name, verbose) 33 | 34 | def complete_data(self): 35 | data = datasets.ImageFolder( 36 | 'celeba', 37 | transforms.Compose([ 38 | transforms.CenterCrop(108), 39 | transforms.Resize(size=64, interpolation=Image.BICUBIC), 40 | transforms.ToTensor(), 41 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)), 42 | ])) 43 | 44 | images = len(data) 45 | data_loader = DataLoader( 46 | data, batch_size=self.batch_size, num_workers=self.workers) 47 | 48 | return data_loader, images 49 | 50 | 51 | class MisGANSampler(BaseSampler): 52 | def __init__(self, data_gen, images=60000, batch_size=256): 53 | super().__init__(images) 54 | self.data_gen = data_gen 55 | self.batch_size = batch_size 56 | latent_dim = 128 57 | self.data_noise = torch.FloatTensor(batch_size, latent_dim).to(device) 58 | 59 | def sample(self): 60 | self.data_noise.normal_() 61 | return self.data_gen(self.data_noise) 62 | 63 | 64 | class MisGANImputationSampler(BaseImputationSampler): 65 | def 
__init__(self, data_loader, imputer, batch_size=256): 66 | super().__init__(data_loader) 67 | self.imputer = imputer 68 | self.impu_noise = torch.FloatTensor(batch_size, 3, 64, 64).to(device) 69 | 70 | def impute(self, data, mask): 71 | if data.shape[0] != self.impu_noise.shape[0]: 72 | self.impu_noise.resize_(data.shape) 73 | self.impu_noise.uniform_() 74 | return self.imputer(data, mask, self.impu_noise) 75 | 76 | 77 | def get_data_loader(args, batch_size): 78 | if args.mask == 'indep': 79 | data = IndepMaskedCelebA( 80 | data_dir=args.data_dir, 81 | obs_prob=args.obs_prob, obs_prob_high=args.obs_prob_high) 82 | elif args.mask == 'block': 83 | data = BlockMaskedCelebA( 84 | data_dir=args.data_dir, block_len=args.block_len) 85 | 86 | data_size = len(data) 87 | data_loader = DataLoader( 88 | data, batch_size=batch_size, num_workers=args.workers) 89 | return data_loader, data_size 90 | 91 | 92 | def parallelize(model): 93 | return nn.DataParallel(model).to(device) 94 | 95 | 96 | def pretrained_misgan_fid(model_file, samples=202599): 97 | model = torch.load(model_file, map_location='cpu') 98 | data_gen = parallelize(ConvDataGenerator()) 99 | data_gen.load_state_dict(model['data_gen']) 100 | 101 | batch_size = args.batch_size 102 | 103 | compute_fid = CelebAFID(batch_size=batch_size) 104 | sampler = MisGANSampler(data_gen, samples, batch_size) 105 | gen_fid = compute_fid.fid(sampler, samples) 106 | print(f'fid: {gen_fid:.2f}') 107 | 108 | imp_fid = None 109 | if 'imputer' in model: 110 | imputer = UNetImputer().to(device) 111 | imputer.load_state_dict(model['imputer']) 112 | data_loader, data_size = get_data_loader(model['args'], batch_size) 113 | imputation_sampler = MisGANImputationSampler( 114 | data_loader, imputer, batch_size) 115 | imp_fid = compute_fid.fid(imputation_sampler, data_size) 116 | print(f'impute fid: {imp_fid:.2f}') 117 | 118 | return gen_fid, imp_fid 119 | 120 | 121 | def main(): 122 | root_dir = Path(args.root_dir) 123 | fid_file = root_dir / 
'fid.txt' 124 | if args.skip_exist and fid_file.exists(): 125 | return 126 | try: 127 | model_file = max((root_dir / 'model').glob('*.pth')) 128 | except ValueError: 129 | return 130 | 131 | print(root_dir.name) 132 | fid, imp_fid = pretrained_misgan_fid(model_file) 133 | 134 | with fid_file.open('w') as f: 135 | print(fid, file=f) 136 | 137 | if imp_fid is not None: 138 | with (root_dir / 'impute-fid.txt').open('w') as f: 139 | print(imp_fid, file=f) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /src-torch1.6/celeba_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)): 7 | """ 8 | hard_sigmoid: 9 | False: use sigmoid only 10 | True: hard thresholding 11 | (a, b): hard thresholding on rescaled sigmoid 12 | """ 13 | self.temperature = temperature 14 | self.hard_sigmoid = hard_sigmoid 15 | 16 | if hard_sigmoid is False: 17 | self.transform = lambda x: torch.sigmoid(x / temperature) 18 | elif hard_sigmoid is True: 19 | self.transform = lambda x: F.hardtanh( 20 | x / temperature, 0, 1) 21 | else: 22 | a, b = hard_sigmoid 23 | self.transform = lambda x: F.hardtanh( 24 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1) 25 | 26 | 27 | def dconv_bn_relu(in_dim, out_dim): 28 | return nn.Sequential( 29 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 30 | padding=2, output_padding=1, bias=False), 31 | nn.BatchNorm2d(out_dim), 32 | nn.ReLU()) 33 | 34 | 35 | # Must sub-class ConvGenerator to provide transform() 36 | class ConvGenerator(nn.Module): 37 | def __init__(self, latent_size=128): 38 | super().__init__() 39 | 40 | dim = 64 41 | 42 | self.l1 = nn.Sequential( 43 | nn.Linear(latent_size, dim * 8 * 4 * 4, bias=False), 44 | nn.BatchNorm1d(dim * 8 * 4 * 4), 45 | nn.ReLU()) 46 | 
47 | self.l2_5 = nn.Sequential( 48 | dconv_bn_relu(dim * 8, dim * 4), 49 | dconv_bn_relu(dim * 4, dim * 2), 50 | dconv_bn_relu(dim * 2, dim), 51 | nn.ConvTranspose2d(dim, self.out_channels, 5, 2, 52 | padding=2, output_padding=1)) 53 | 54 | def forward(self, input): 55 | net = self.l1(input) 56 | net = net.view(net.shape[0], -1, 4, 4) 57 | net = self.l2_5(net) 58 | return self.transform(net) 59 | 60 | 61 | class ConvDataGenerator(ConvGenerator): 62 | def __init__(self, latent_size=128): 63 | self.out_channels = 3 64 | super().__init__(latent_size=latent_size) 65 | self.transform = lambda x: torch.sigmoid(x) 66 | 67 | 68 | class ConvMaskGenerator(ConvGenerator): 69 | def __init__(self, latent_size=128, temperature=.66, 70 | hard_sigmoid=(-.1, 1.1)): 71 | self.out_channels = 1 72 | super().__init__(latent_size=latent_size) 73 | add_mask_transformer(self, temperature, hard_sigmoid) 74 | -------------------------------------------------------------------------------- /src-torch1.6/celeba_misgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datetime import datetime 4 | from pathlib import Path 5 | import argparse 6 | from celeba_generator import ConvDataGenerator, ConvMaskGenerator 7 | from celeba_critic import ConvCritic 8 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA 9 | from misgan import misgan 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def parallelize(model): 17 | return nn.DataParallel(model).to(device) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | 23 | # resume from checkpoint 24 | parser.add_argument('--resume') 25 | 26 | # path of CelebA dataset 27 | parser.add_argument('--data-dir', default='celeba-data') 28 | 29 | # training options 30 | parser.add_argument('--epoch', type=int, default=600) 31 | parser.add_argument('--batch-size', type=int, 
default=256) 32 | 33 | # log options: 0 to disable plot-interval or save-interval 34 | parser.add_argument('--plot-interval', type=int, default=100) 35 | parser.add_argument('--save-interval', type=int, default=0) 36 | parser.add_argument('--prefix', default='misgan') 37 | 38 | # mask options (data): block|indep 39 | parser.add_argument('--mask', default='block') 40 | # option for block: set to 0 for variable size 41 | parser.add_argument('--block-len', type=int, default=32) 42 | # option for indep: 43 | parser.add_argument('--obs-prob', type=float, default=.2) 44 | parser.add_argument('--obs-prob-high', type=float, default=None) 45 | 46 | # model options 47 | parser.add_argument('--tau', type=float, default=.5) 48 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate 49 | # options for mask generator: sigmoid, hardsigmoid, fusion 50 | parser.add_argument('--maskgen', default='fusion') 51 | parser.add_argument('--gp-lambda', type=float, default=10) 52 | parser.add_argument('--n-critic', type=int, default=5) 53 | parser.add_argument('--n-latent', type=int, default=128) 54 | 55 | args = parser.parse_args() 56 | 57 | checkpoint = None 58 | # Resume from previously stored checkpoint 59 | if args.resume: 60 | print(f'Resume: {args.resume}') 61 | output_dir = Path(args.resume) 62 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 63 | map_location='cpu') 64 | for key, arg in vars(checkpoint['args']).items(): 65 | if key not in ['resume']: 66 | setattr(args, key, arg) 67 | 68 | if args.maskgen == 'sigmoid': 69 | hard_sigmoid = False 70 | elif args.maskgen == 'hardsigmoid': 71 | hard_sigmoid = True 72 | elif args.maskgen == 'fusion': 73 | hard_sigmoid = -.1, 1.1 74 | else: 75 | raise NotImplementedError 76 | 77 | mask = args.mask 78 | obs_prob = args.obs_prob 79 | obs_prob_high = args.obs_prob_high 80 | block_len = args.block_len 81 | if block_len == 0: 82 | block_len = None 83 | if mask == 'indep': 84 | if obs_prob_high is None: 85 | 
mask_str = f'indep_{obs_prob:g}' 86 | else: 87 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 88 | elif mask == 'block': 89 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 90 | else: 91 | raise NotImplementedError 92 | 93 | path = '{}_{}_{}'.format( 94 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 95 | '_'.join([ 96 | f'tau_{args.tau:g}', 97 | f'alpha_{args.alpha:g}', 98 | f'maskgen_{args.maskgen}', 99 | mask_str, 100 | ])) 101 | 102 | if not args.resume: 103 | output_dir = Path('results') / 'celeba' / path 104 | print(output_dir) 105 | 106 | if mask == 'indep': 107 | data = IndepMaskedCelebA( 108 | data_dir=args.data_dir, 109 | obs_prob=obs_prob, obs_prob_high=obs_prob_high) 110 | elif mask == 'block': 111 | data = BlockMaskedCelebA( 112 | data_dir=args.data_dir, block_len=block_len) 113 | n_gpu = torch.cuda.device_count() 114 | print(f'Use {n_gpu} GPUs.') 115 | 116 | data_gen = parallelize(ConvDataGenerator()) 117 | mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid)) 118 | 119 | data_critic = parallelize(ConvCritic(n_channels=3)) 120 | mask_critic = parallelize(ConvCritic(n_channels=1)) 121 | 122 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data, 123 | output_dir, checkpoint) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /src-torch1.6/celeba_misgan_impute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datetime import datetime 4 | from pathlib import Path 5 | import argparse 6 | from celeba_generator import ConvDataGenerator, ConvMaskGenerator 7 | from celeba_critic import ConvCritic 8 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA 9 | from imputer import UNetImputer 10 | from misgan_impute import misgan_impute 11 | 12 | 13 | use_cuda = torch.cuda.is_available() 14 | device = 
torch.device('cuda' if use_cuda else 'cpu') 15 | 16 | 17 | def parallelize(model): 18 | return nn.DataParallel(model).to(device) 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | 24 | # resume from checkpoint 25 | parser.add_argument('--resume') 26 | 27 | # path of CelebA dataset 28 | parser.add_argument('--data-dir', default='celeba-data') 29 | 30 | # training options 31 | parser.add_argument('--workers', type=int, default=0) 32 | parser.add_argument('--epoch', type=int, default=800) 33 | parser.add_argument('--batch-size', type=int, default=512) 34 | parser.add_argument('--pretrain', default=None) 35 | parser.add_argument('--imputeronly', action='store_true') 36 | 37 | # log options: 0 to disable plot-interval or save-interval 38 | parser.add_argument('--plot-interval', type=int, default=50) 39 | parser.add_argument('--save-interval', type=int, default=0) 40 | parser.add_argument('--prefix', default='impute') 41 | 42 | # mask options (data): block|indep 43 | parser.add_argument('--mask', default='block') 44 | # option for block: set to 0 for variable size 45 | parser.add_argument('--block-len', type=int, default=32) 46 | # option for indep: 47 | parser.add_argument('--obs-prob', type=float, default=.2) 48 | parser.add_argument('--obs-prob-high', type=float, default=None) 49 | 50 | # model options 51 | parser.add_argument('--tau', type=float, default=.5) 52 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate 53 | parser.add_argument('--beta', type=float, default=.1) 54 | parser.add_argument('--gamma', type=float, default=0) 55 | # options for mask generator: sigmoid, hardsigmoid, fusion 56 | parser.add_argument('--maskgen', default='fusion') 57 | parser.add_argument('--gp-lambda', type=float, default=10) 58 | parser.add_argument('--n-critic', type=int, default=5) 59 | parser.add_argument('--n-latent', type=int, default=128) 60 | 61 | args = parser.parse_args() 62 | 63 | checkpoint = None 64 | # Resume from previously stored 
checkpoint 65 | if args.resume: 66 | print(f'Resume: {args.resume}') 67 | output_dir = Path(args.resume) 68 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 69 | map_location='cpu') 70 | for key, arg in vars(checkpoint['args']).items(): 71 | if key not in ['resume']: 72 | setattr(args, key, arg) 73 | 74 | if args.imputeronly: 75 | assert args.pretrain is not None 76 | 77 | mask = args.mask 78 | obs_prob = args.obs_prob 79 | obs_prob_high = args.obs_prob_high 80 | block_len = args.block_len 81 | if block_len == 0: 82 | block_len = None 83 | 84 | if args.maskgen == 'sigmoid': 85 | hard_sigmoid = False 86 | elif args.maskgen == 'hardsigmoid': 87 | hard_sigmoid = True 88 | elif args.maskgen == 'fusion': 89 | hard_sigmoid = -.1, 1.1 90 | else: 91 | raise NotImplementedError 92 | 93 | if mask == 'indep': 94 | if obs_prob_high is None: 95 | mask_str = f'indep_{obs_prob:g}' 96 | else: 97 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 98 | elif mask == 'block': 99 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 100 | else: 101 | raise NotImplementedError 102 | 103 | path = '{}_{}_{}'.format( 104 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 105 | '_'.join([ 106 | f'tau_{args.tau:g}', 107 | f'maskgen_{args.maskgen}', 108 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}', 109 | mask_str, 110 | ])) 111 | 112 | if not args.resume: 113 | output_dir = Path('results') / 'celeba' / path 114 | print(output_dir) 115 | 116 | if mask == 'indep': 117 | data = IndepMaskedCelebA( 118 | data_dir=args.data_dir, 119 | obs_prob=obs_prob, obs_prob_high=obs_prob_high) 120 | elif mask == 'block': 121 | data = BlockMaskedCelebA( 122 | data_dir=args.data_dir, block_len=block_len) 123 | 124 | n_gpu = torch.cuda.device_count() 125 | print(f'Use {n_gpu} GPUs.') 126 | data_gen = parallelize(ConvDataGenerator()) 127 | mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid)) 128 | imputer = UNetImputer().to(device) 129 | 130 | 
data_critic = parallelize(ConvCritic(n_channels=3)) 131 | mask_critic = parallelize(ConvCritic(n_channels=1)) 132 | impu_critic = parallelize(ConvCritic(n_channels=3)) 133 | 134 | misgan_impute(args, data_gen, mask_gen, imputer, 135 | data_critic, mask_critic, impu_critic, 136 | data, output_dir, checkpoint) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /src-torch1.6/fcnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FullyConnectedNet(nn.Module): 5 | def __init__(self, weights, output_shape=None): 6 | super().__init__() 7 | n_layers = len(weights) - 1 8 | 9 | layers = [nn.Linear(weights[0], weights[1])] 10 | for i in range(1, n_layers): 11 | layers.extend([nn.ReLU(), nn.Linear(weights[i], weights[i + 1])]) 12 | 13 | self.model = nn.Sequential(*layers) 14 | self.output_shape = output_shape 15 | 16 | def forward(self, input): 17 | output = self.model(input.view(input.shape[0], -1)) 18 | if self.output_shape is not None: 19 | output = output.view(self.output_shape) 20 | return output 21 | -------------------------------------------------------------------------------- /src-torch1.6/fid.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/mseitzer/pytorch-fid 2 | """ 3 | from pathlib import Path 4 | import torch 5 | import numpy as np 6 | from scipy import linalg 7 | import time 8 | import sys 9 | from inception import InceptionV3 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | FEATURE_DIM = 2048 16 | RESIZE = 299 17 | 18 | 19 | def get_activations(image_iterator, images, model, verbose=True): 20 | """Calculates the activations of the pool_3 layer for all images. 
21 | 22 | Params: 23 | -- image_iterator 24 | : A generator that generates a batch of images at a time. 25 | -- images : Number of images that will be generated by 26 | image_iterator. 27 | -- model : Instance of inception model 28 | -- verbose : If set to True and stdout is a terminal, the number of 29 | processed images and the elapsed time are reported. 30 | Returns: 31 | -- A numpy array of dimension (num images, dims) that contains the 32 | activations of the given tensor when feeding inception with the 33 | query tensor. 34 | """ 35 | model.eval() 36 | 37 | if not sys.stdout.isatty(): 38 | verbose = False 39 | 40 | pred_arr = np.empty((images, FEATURE_DIM)) 41 | end = 0 42 | t0 = time.time() 43 | 44 | for batch in image_iterator: 45 | if not isinstance(batch, torch.Tensor): 46 | batch = batch[0] 47 | start = end 48 | batch_size = batch.shape[0] 49 | end = start + batch_size 50 | 51 | with torch.no_grad(): 52 | batch = batch.to(device) 53 | pred = model(batch)[0] 54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1) 55 | pred_arr[start:end] = batch_feature 56 | 57 | if verbose: 58 | print('\rProcessed: {} time: {:.2f}'.format( 59 | end, time.time() - t0), end='', flush=True) 60 | 61 | assert end == images 62 | 63 | if verbose: 64 | print(' done') 65 | 66 | return pred_arr 67 | 68 | 69 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 70 | """Numpy implementation of the Frechet Distance. 71 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 72 | and X_2 ~ N(mu_2, C_2) is 73 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 74 | 75 | Stable version by Dougal J. Sutherland. 76 | 77 | Params: 78 | -- mu1 : The sample mean over activations of a layer of the 79 | inception net (as computed by 'calculate_activation_statistics') 80 | for generated samples. 81 | -- mu2 : The sample mean over activations, precalculated on a 82 | representative data set. 
83 | -- sigma1: The covariance matrix over activations for generated samples. 84 | -- sigma2: The covariance matrix over activations, precalculated on a 85 | representative data set. 86 | 87 | Returns: 88 | -- : The Frechet Distance. 89 | """ 90 | 91 | mu1 = np.atleast_1d(mu1) 92 | mu2 = np.atleast_1d(mu2) 93 | 94 | sigma1 = np.atleast_2d(sigma1) 95 | sigma2 = np.atleast_2d(sigma2) 96 | 97 | assert mu1.shape == mu2.shape, \ 98 | 'Training and test mean vectors have different lengths' 99 | assert sigma1.shape == sigma2.shape, \ 100 | 'Training and test covariances have different dimensions' 101 | 102 | diff = mu1 - mu2 103 | 104 | # Product might be almost singular 105 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 106 | if not np.isfinite(covmean).all(): 107 | msg = ('fid calculation produces singular product; ' 108 | 'adding %s to diagonal of cov estimates') % eps 109 | print(msg) 110 | offset = np.eye(sigma1.shape[0]) * eps 111 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 112 | 113 | # Numerical error might give slight imaginary component 114 | if np.iscomplexobj(covmean): 115 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 116 | m = np.max(np.abs(covmean.imag)) 117 | raise ValueError('Imaginary component {}'.format(m)) 118 | covmean = covmean.real 119 | 120 | tr_covmean = np.trace(covmean) 121 | 122 | return (diff.dot(diff) + np.trace(sigma1) + 123 | np.trace(sigma2) - 2 * tr_covmean) 124 | 125 | 126 | def calculate_activation_statistics(image_iterator, images, model, 127 | verbose=False): 128 | """Calculation of the statistics used by the FID. 129 | Params: 130 | -- image_iterator 131 | : A generator that generates a batch of images at a time. 132 | -- images : Number of images that will be generated by 133 | image_iterator. 134 | -- model : Instance of inception model 135 | -- verbose : If set to True and stdout is a terminal, the 136 | number of processed images and the elapsed time are reported. 
137 | Returns: 138 | -- mu : The mean over samples of the activations of the pool_3 layer of 139 | the inception model. 140 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 141 | the inception model. 142 | """ 143 | act = get_activations(image_iterator, images, model, verbose) 144 | mu = np.mean(act, axis=0) 145 | sigma = np.cov(act, rowvar=False) 146 | return mu, sigma 147 | 148 | 149 | class FID: 150 | def __init__(self, data_name, verbose=True): 151 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM] 152 | model = InceptionV3([block_idx], RESIZE).to(device) 153 | self.verbose = verbose 154 | 155 | stats_dir = Path('fid_stats') 156 | stats_file = stats_dir / '{}_act_{}_{}.npz'.format( 157 | data_name, FEATURE_DIM, RESIZE) 158 | 159 | try: 160 | f = np.load(str(stats_file)) 161 | mu, sigma = f['mu'], f['sigma'] 162 | f.close() 163 | except FileNotFoundError: 164 | data_loader, images = self.complete_data() 165 | mu, sigma = calculate_activation_statistics( 166 | data_loader, images, model, verbose) 167 | stats_dir.mkdir(parents=True, exist_ok=True) 168 | np.savez(stats_file, mu=mu, sigma=sigma) 169 | 170 | self.model = model 171 | self.stats = mu, sigma 172 | 173 | def complete_data(self): 174 | raise NotImplementedError 175 | 176 | def fid(self, image_iterator, images): 177 | mu, sigma = calculate_activation_statistics( 178 | image_iterator, images, self.model, verbose=self.verbose) 179 | return calculate_frechet_distance(mu, sigma, *self.stats) 180 | 181 | 182 | class BaseSampler: 183 | def __init__(self, images): 184 | self.images = images 185 | 186 | def __iter__(self): 187 | self.n = 0 188 | return self 189 | 190 | def __next__(self): 191 | if self.n < self.images: 192 | batch = self.sample() 193 | batch_size = batch.shape[0] 194 | self.n += batch_size 195 | if self.n > self.images: 196 | return batch[:-(self.n - self.images)] 197 | return batch 198 | else: 199 | raise StopIteration 200 | 201 | def sample(self): 202 | raise 
NotImplementedError 203 | 204 | 205 | class BaseImputationSampler: 206 | def __init__(self, data_loader): 207 | self.data_loader = data_loader 208 | 209 | def __iter__(self): 210 | self.data_iter = iter(self.data_loader) 211 | return self 212 | 213 | def __next__(self): 214 | data, mask = next(self.data_iter)[:2] 215 | data = data.to(device) 216 | mask = mask.float()[:, None].to(device) 217 | imputed_data = self.impute(data, mask) 218 | return mask * data + (1 - mask) * imputed_data 219 | 220 | def impute(self, data, mask): 221 | raise NotImplementedError 222 | -------------------------------------------------------------------------------- /src-torch1.6/imputer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fcnet import FullyConnectedNet 4 | from unet import UnetSkipConnectionBlock 5 | 6 | 7 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 8 | class UNet(nn.Module): 9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, layers=5, 10 | norm_layer=nn.BatchNorm2d): 11 | super().__init__() 12 | 13 | mid_layers = layers - 2 14 | fact = 2**mid_layers 15 | 16 | unet_block = UnetSkipConnectionBlock( 17 | ngf * fact, ngf * fact, input_nc=None, submodule=None, 18 | norm_layer=norm_layer, innermost=True) 19 | 20 | for _ in range(mid_layers): 21 | half_fact = fact // 2 22 | unet_block = UnetSkipConnectionBlock( 23 | ngf * half_fact, ngf * fact, input_nc=None, 24 | submodule=unet_block, norm_layer=norm_layer) 25 | fact = half_fact 26 | 27 | unet_block = UnetSkipConnectionBlock( 28 | output_nc, ngf, input_nc=input_nc, submodule=unet_block, 29 | outermost=True, norm_layer=norm_layer) 30 | 31 | self.model = unet_block 32 | 33 | def forward(self, input): 34 | return self.model(input) 35 | 36 | 37 | class Imputer(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | self.transform = lambda x: torch.sigmoid(x) 41 | 42 | def forward(self, input, mask, 
noise): 43 | net = input * mask + noise * (1 - mask) 44 | net = self.imputer_net(net) 45 | net = self.transform(net) 46 | # NOT replacing observed part with input data for computing 47 | # autoencoding loss. 48 | # return input * mask + net * (1 - mask) 49 | return net 50 | 51 | 52 | class UNetImputer(Imputer): 53 | def __init__(self, *args, **kwargs): 54 | super().__init__() 55 | self.imputer_net = UNet(*args, **kwargs) 56 | 57 | 58 | class FullyConnectedImputer(Imputer): 59 | def __init__(self, *args, **kwargs): 60 | super().__init__() 61 | self.imputer_net = FullyConnectedNet(*args, **kwargs) 62 | -------------------------------------------------------------------------------- /src-torch1.6/inception.py: -------------------------------------------------------------------------------- 1 | """Code from https://github.com/mseitzer/pytorch-fid 2 | """ 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | 7 | 8 | class InceptionV3(nn.Module): 9 | """Pretrained InceptionV3 network returning feature maps""" 10 | 11 | # Index of default block of inception to return, 12 | # corresponds to output of final average pooling 13 | DEFAULT_BLOCK_INDEX = 3 14 | 15 | # Maps feature dimensionality to the corresponding output block index 16 | BLOCK_INDEX_BY_DIM = { 17 | 64: 0, # First max pooling features 18 | 192: 1, # Second max pooling features 19 | 768: 2, # Pre-aux classifier features 20 | 2048: 3 # Final average pooling features 21 | } 22 | 23 | def __init__(self, 24 | output_blocks=[DEFAULT_BLOCK_INDEX], 25 | resize_input=299, # -1: do not resize 26 | normalize_input=True, 27 | requires_grad=False): 28 | """Build pretrained InceptionV3 29 | 30 | Parameters 31 | ---------- 32 | output_blocks : list of int 33 | Indices of blocks to return features of.
Possible values are: 34 | - 0: corresponds to output of first max pooling 35 | - 1: corresponds to output of second max pooling 36 | - 2: corresponds to output which is fed to aux classifier 37 | - 3: corresponds to output of final average pooling 38 | resize_input : int 39 | If positive, bilinearly resizes input to this width and height before 40 | feeding input to model (-1: no resizing). As the network without fully connected 41 | layers is fully convolutional, it should be able to handle inputs 42 | of arbitrary size, so resizing might not be strictly needed 43 | normalize_input : bool 44 | If true, normalizes the input to the statistics the pretrained 45 | Inception network expects 46 | requires_grad : bool 47 | If true, parameters of the model require gradient. Possibly useful 48 | for finetuning the network 49 | """ 50 | super(InceptionV3, self).__init__() 51 | 52 | self.resize_input = resize_input 53 | self.normalize_input = normalize_input 54 | self.output_blocks = sorted(output_blocks) 55 | self.last_needed_block = max(output_blocks) 56 | 57 | assert self.last_needed_block <= 3, \ 58 | 'Last possible output block index is 3' 59 | 60 | self.blocks = nn.ModuleList() 61 | 62 | inception = models.inception_v3(pretrained=True) 63 | 64 | # Block 0: input to maxpool1 65 | block0 = [ 66 | inception.Conv2d_1a_3x3, 67 | inception.Conv2d_2a_3x3, 68 | inception.Conv2d_2b_3x3, 69 | nn.MaxPool2d(kernel_size=3, stride=2) 70 | ] 71 | self.blocks.append(nn.Sequential(*block0)) 72 | 73 | # Block 1: maxpool1 to maxpool2 74 | if self.last_needed_block >= 1: 75 | block1 = [ 76 | inception.Conv2d_3b_1x1, 77 | inception.Conv2d_4a_3x3, 78 | nn.MaxPool2d(kernel_size=3, stride=2) 79 | ] 80 | self.blocks.append(nn.Sequential(*block1)) 81 | 82 | # Block 2: maxpool2 to aux classifier 83 | if self.last_needed_block >= 2: 84 | block2 = [ 85 | inception.Mixed_5b, 86 | inception.Mixed_5c, 87 | inception.Mixed_5d, 88 | inception.Mixed_6a, 89 | inception.Mixed_6b, 90 | inception.Mixed_6c, 91 |
inception.Mixed_6d, 92 | inception.Mixed_6e, 93 | ] 94 | self.blocks.append(nn.Sequential(*block2)) 95 | 96 | # Block 3: aux classifier to final avgpool 97 | if self.last_needed_block >= 3: 98 | block3 = [ 99 | inception.Mixed_7a, 100 | inception.Mixed_7b, 101 | inception.Mixed_7c, 102 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 103 | ] 104 | self.blocks.append(nn.Sequential(*block3)) 105 | 106 | for param in self.parameters(): 107 | param.requires_grad = requires_grad 108 | 109 | def forward(self, inp): 110 | """Get Inception feature maps 111 | 112 | Parameters 113 | ---------- 114 | inp : torch.autograd.Variable 115 | Input tensor of shape Bx3xHxW. Values are expected to be in 116 | range (0, 1) 117 | 118 | Returns 119 | ------- 120 | List of torch.autograd.Variable, corresponding to the selected output 121 | block, sorted ascending by index 122 | """ 123 | outp = [] 124 | x = inp 125 | 126 | if self.resize_input > 0: 127 | # size = 299 128 | x = F.interpolate(x, size=(self.resize_input, self.resize_input), 129 | mode='bilinear', align_corners=True) 130 | 131 | if self.normalize_input: 132 | x = x.clone() 133 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 134 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 135 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 136 | 137 | for idx, block in enumerate(self.blocks): 138 | x = block(x) 139 | if idx in self.output_blocks: 140 | outp.append(x) 141 | 142 | if idx == self.last_needed_block: 143 | break 144 | 145 | return outp 146 | -------------------------------------------------------------------------------- /src-torch1.6/masked_celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | class MaskedCelebA(datasets.ImageFolder): 8 | def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0): 9 | transform = 
transforms.Compose([ 10 | transforms.CenterCrop(108), 11 | transforms.Resize(size=image_size, interpolation=Image.BICUBIC), 12 | transforms.ToTensor(), 13 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)), 14 | ]) 15 | 16 | super().__init__(data_dir, transform) 17 | 18 | self.rnd = np.random.RandomState(random_seed) 19 | self.image_size = image_size 20 | self.generate_masks() 21 | 22 | def __getitem__(self, index): 23 | image, label = super().__getitem__(index) 24 | return image, self.mask[index], label, index 25 | 26 | def __len__(self): 27 | return super().__len__() 28 | 29 | 30 | class BlockMaskedCelebA(MaskedCelebA): 31 | def __init__(self, block_len=None, *args, **kwargs): 32 | self.block_len = block_len 33 | super().__init__(*args, **kwargs) 34 | 35 | def generate_masks(self): 36 | d0_len = d1_len = self.image_size 37 | d0_min_len = 12 38 | d0_max_len = d0_len - d0_min_len 39 | d1_min_len = 12 40 | d1_max_len = d1_len - d1_min_len 41 | 42 | n_masks = len(self) 43 | self.mask = [None] * n_masks 44 | self.mask_info = [None] * n_masks 45 | for i in range(n_masks): 46 | if self.block_len is None: 47 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len) 48 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len) 49 | else: 50 | d0_mask_len = d1_mask_len = self.block_len 51 | 52 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1) 53 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1) 54 | 55 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8) 56 | mask[d0_start:(d0_start + d0_mask_len), 57 | d1_start:(d1_start + d1_mask_len)] = 1 58 | self.mask[i] = mask 59 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len 60 | 61 | 62 | class IndepMaskedCelebA(MaskedCelebA): 63 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs): 64 | self.prob = obs_prob 65 | self.prob_high = obs_prob_high 66 | super().__init__(*args, **kwargs) 67 | 68 | def generate_masks(self): 69 | imsize = self.image_size 70 | prob = 
self.prob 71 | prob_high = self.prob_high 72 | n_masks = len(self) 73 | self.mask = [None] * n_masks 74 | for i in range(n_masks): 75 | if prob_high is None: 76 | p = prob 77 | else: 78 | p = self.rnd.uniform(prob, prob_high) 79 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p) 80 | -------------------------------------------------------------------------------- /src-torch1.6/masked_mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision import datasets, transforms 4 | import numpy as np 5 | 6 | 7 | class MaskedMNIST(Dataset): 8 | def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0): 9 | self.rnd = np.random.RandomState(random_seed) 10 | self.image_size = image_size 11 | if image_size == 28: 12 | self.data = datasets.MNIST( 13 | data_dir, train=True, download=True, 14 | transform=transforms.ToTensor()) 15 | else: 16 | self.data = datasets.MNIST( 17 | data_dir, train=True, download=True, 18 | transform=transforms.Compose([ 19 | transforms.Resize(image_size), transforms.ToTensor()])) 20 | self.generate_masks() 21 | 22 | def __getitem__(self, index): 23 | image, label = self.data[index] 24 | return image, self.mask[index], label, index 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | def generate_masks(self): 30 | raise NotImplementedError 31 | 32 | 33 | class BlockMaskedMNIST(MaskedMNIST): 34 | def __init__(self, block_len=None, *args, **kwargs): 35 | self.block_len = block_len 36 | super().__init__(*args, **kwargs) 37 | 38 | def generate_masks(self): 39 | d0_len = d1_len = self.image_size 40 | d0_min_len = 7 41 | d0_max_len = d0_len - d0_min_len 42 | d1_min_len = 7 43 | d1_max_len = d1_len - d1_min_len 44 | 45 | n_masks = len(self) 46 | self.mask = [None] * n_masks 47 | self.mask_info = [None] * n_masks 48 | for i in range(n_masks): 49 | if self.block_len is None: 50 | d0_mask_len = 
self.rnd.randint(d0_min_len, d0_max_len) 51 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len) 52 | else: 53 | d0_mask_len = d1_mask_len = self.block_len 54 | 55 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1) 56 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1) 57 | 58 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8) 59 | mask[d0_start:(d0_start + d0_mask_len), 60 | d1_start:(d1_start + d1_mask_len)] = 1 61 | self.mask[i] = mask 62 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len 63 | 64 | 65 | class IndepMaskedMNIST(MaskedMNIST): 66 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs): 67 | self.prob = obs_prob 68 | self.prob_high = obs_prob_high 69 | super().__init__(*args, **kwargs) 70 | 71 | def generate_masks(self): 72 | imsize = self.image_size 73 | prob = self.prob 74 | prob_high = self.prob_high 75 | n_masks = len(self) 76 | self.mask = [None] * n_masks 77 | for i in range(n_masks): 78 | if prob_high is None: 79 | p = prob 80 | else: 81 | p = self.rnd.uniform(prob, prob_high) 82 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p) 83 | -------------------------------------------------------------------------------- /src-torch1.6/misgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | import time 5 | import pylab as plt 6 | import seaborn as sns 7 | from collections import defaultdict 8 | from plot import plot_samples 9 | from utils import CriticUpdater, mkdir, mask_data 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def misgan(args, data_gen, mask_gen, data_critic, mask_critic, data, 17 | output_dir, checkpoint=None): 18 | n_critic = args.n_critic 19 | gp_lambda = args.gp_lambda 20 | batch_size = args.batch_size 21 | nz = args.n_latent 22 | epochs = args.epoch 23 | 
plot_interval = args.plot_interval 24 | save_interval = args.save_interval 25 | alpha = args.alpha 26 | tau = args.tau 27 | 28 | gen_data_dir = mkdir(output_dir / 'img') 29 | gen_mask_dir = mkdir(output_dir / 'mask') 30 | log_dir = mkdir(output_dir / 'log') 31 | model_dir = mkdir(output_dir / 'model') 32 | 33 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, 34 | drop_last=True) 35 | n_batch = len(data_loader) 36 | 37 | data_noise = torch.FloatTensor(batch_size, nz).to(device) 38 | mask_noise = torch.FloatTensor(batch_size, nz).to(device) 39 | 40 | # Interpolation coefficient 41 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device) 42 | 43 | # For computing gradient penalty 44 | ones = torch.ones(batch_size).to(device) 45 | 46 | lrate = 1e-4 47 | # lrate = 1e-5 48 | data_gen_optimizer = optim.Adam( 49 | data_gen.parameters(), lr=lrate, betas=(.5, .9)) 50 | mask_gen_optimizer = optim.Adam( 51 | mask_gen.parameters(), lr=lrate, betas=(.5, .9)) 52 | 53 | data_critic_optimizer = optim.Adam( 54 | data_critic.parameters(), lr=lrate, betas=(.5, .9)) 55 | mask_critic_optimizer = optim.Adam( 56 | mask_critic.parameters(), lr=lrate, betas=(.5, .9)) 57 | 58 | update_data_critic = CriticUpdater( 59 | data_critic, data_critic_optimizer, eps, ones, gp_lambda) 60 | update_mask_critic = CriticUpdater( 61 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda) 62 | 63 | start_epoch = 0 64 | critic_updates = 0 65 | log = defaultdict(list) 66 | 67 | if checkpoint: 68 | data_gen.load_state_dict(checkpoint['data_gen']) 69 | mask_gen.load_state_dict(checkpoint['mask_gen']) 70 | data_critic.load_state_dict(checkpoint['data_critic']) 71 | mask_critic.load_state_dict(checkpoint['mask_critic']) 72 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt']) 73 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt']) 74 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt']) 75 | 
mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt']) 76 | start_epoch = checkpoint['epoch'] 77 | critic_updates = checkpoint['critic_updates'] 78 | log = checkpoint['log'] 79 | 80 | with (log_dir / 'gpu.txt').open('a') as f: 81 | print(torch.cuda.device_count(), start_epoch, file=f) 82 | 83 | def save_model(path, epoch, critic_updates=0): 84 | torch.save({ 85 | 'data_gen': data_gen.state_dict(), 86 | 'mask_gen': mask_gen.state_dict(), 87 | 'data_critic': data_critic.state_dict(), 88 | 'mask_critic': mask_critic.state_dict(), 89 | 'data_gen_opt': data_gen_optimizer.state_dict(), 90 | 'mask_gen_opt': mask_gen_optimizer.state_dict(), 91 | 'data_critic_opt': data_critic_optimizer.state_dict(), 92 | 'mask_critic_opt': mask_critic_optimizer.state_dict(), 93 | 'epoch': epoch + 1, 94 | 'critic_updates': critic_updates, 95 | 'log': log, 96 | 'args': args, 97 | }, str(path)) 98 | 99 | sns.set() 100 | 101 | start = time.time() 102 | epoch_start = start 103 | 104 | for epoch in range(start_epoch, epochs): 105 | sum_data_loss, sum_mask_loss = 0, 0 106 | for real_data, real_mask, _, _ in data_loader: 107 | # Assume real_data and mask have the same number of channels. 108 | # Could be modified to handle multi-channel images and 109 | # single-channel masks. 
110 | real_mask = real_mask.float()[:, None] 111 | 112 | real_data = real_data.to(device) 113 | real_mask = real_mask.to(device) 114 | 115 | masked_real_data = mask_data(real_data, real_mask, tau) 116 | 117 | # Update discriminators' parameters 118 | data_noise.normal_() 119 | mask_noise.normal_() 120 | 121 | fake_data = data_gen(data_noise) 122 | fake_mask = mask_gen(mask_noise) 123 | 124 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 125 | 126 | update_data_critic(masked_real_data, masked_fake_data) 127 | update_mask_critic(real_mask, fake_mask) 128 | 129 | sum_data_loss += update_data_critic.loss_value 130 | sum_mask_loss += update_mask_critic.loss_value 131 | 132 | critic_updates += 1 133 | 134 | if critic_updates == n_critic: 135 | critic_updates = 0 136 | 137 | # Update generators' parameters 138 | for p in data_critic.parameters(): 139 | p.requires_grad_(False) 140 | for p in mask_critic.parameters(): 141 | p.requires_grad_(False) 142 | 143 | data_noise.normal_() 144 | mask_noise.normal_() 145 | 146 | fake_data = data_gen(data_noise) 147 | fake_mask = mask_gen(mask_noise) 148 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 149 | 150 | data_loss = -data_critic(masked_fake_data).mean() 151 | data_gen.zero_grad() 152 | data_loss.backward() 153 | data_gen_optimizer.step() 154 | 155 | data_noise.normal_() 156 | mask_noise.normal_() 157 | 158 | fake_data = data_gen(data_noise) 159 | fake_mask = mask_gen(mask_noise) 160 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 161 | 162 | data_loss = -data_critic(masked_fake_data).mean() 163 | mask_loss = -mask_critic(fake_mask).mean() 164 | mask_gen.zero_grad() 165 | (mask_loss + data_loss * alpha).backward() 166 | mask_gen_optimizer.step() 167 | 168 | for p in data_critic.parameters(): 169 | p.requires_grad_(True) 170 | for p in mask_critic.parameters(): 171 | p.requires_grad_(True) 172 | 173 | mean_data_loss = sum_data_loss / n_batch 174 | mean_mask_loss = sum_mask_loss / n_batch 175 | 
log['data loss', 'data_loss'].append(mean_data_loss) 176 | log['mask loss', 'mask_loss'].append(mean_mask_loss) 177 | 178 | for (name, shortname), trace in log.items(): 179 | fig, ax = plt.subplots(figsize=(6, 4)) 180 | ax.plot(trace) 181 | ax.set_ylabel(name) 182 | ax.set_xlabel('epoch') 183 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300) 184 | plt.close(fig) 185 | 186 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0: 187 | print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}') 188 | 189 | filename = f'{epoch:04d}.png' 190 | 191 | data_gen.eval() 192 | mask_gen.eval() 193 | 194 | with torch.no_grad(): 195 | data_noise.normal_() 196 | mask_noise.normal_() 197 | 198 | data_samples = data_gen(data_noise) 199 | plot_samples(data_samples, str(gen_data_dir / filename)) 200 | 201 | mask_samples = mask_gen(mask_noise) 202 | plot_samples(mask_samples, str(gen_mask_dir / filename)) 203 | 204 | data_gen.train() 205 | mask_gen.train() 206 | 207 | if save_interval > 0 and (epoch + 1) % save_interval == 0: 208 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates) 209 | 210 | epoch_end = time.time() 211 | time_elapsed = epoch_end - start 212 | epoch_time = epoch_end - epoch_start 213 | epoch_start = epoch_end 214 | with (log_dir / 'time.txt').open('a') as f: 215 | print(epoch, epoch_time, time_elapsed, file=f) 216 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates) 217 | 218 | print(output_dir) 219 | -------------------------------------------------------------------------------- /src-torch1.6/misgan_impute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | import time 5 | import pylab as plt 6 | import seaborn as sns 7 | from collections import defaultdict 8 | from plot import plot_grid, plot_samples 9 | from utils import CriticUpdater, mask_norm, mkdir, mask_data 10 | 11 | 12 | use_cuda = 
torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def misgan_impute(args, data_gen, mask_gen, imputer, 17 | data_critic, mask_critic, impu_critic, 18 | data, output_dir, checkpoint=None): 19 | n_critic = args.n_critic 20 | gp_lambda = args.gp_lambda 21 | batch_size = args.batch_size 22 | nz = args.n_latent 23 | epochs = args.epoch 24 | plot_interval = args.plot_interval 25 | save_model_interval = args.save_interval 26 | alpha = args.alpha 27 | beta = args.beta 28 | gamma = args.gamma 29 | tau = args.tau 30 | update_all_networks = not args.imputeronly 31 | 32 | gen_data_dir = mkdir(output_dir / 'img') 33 | gen_mask_dir = mkdir(output_dir / 'mask') 34 | impute_dir = mkdir(output_dir / 'impute') 35 | log_dir = mkdir(output_dir / 'log') 36 | model_dir = mkdir(output_dir / 'model') 37 | 38 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, 39 | drop_last=True, num_workers=args.workers) 40 | n_batch = len(data_loader) 41 | data_shape = data[0][0].shape 42 | 43 | data_noise = torch.FloatTensor(batch_size, nz).to(device) 44 | mask_noise = torch.FloatTensor(batch_size, nz).to(device) 45 | impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device) 46 | 47 | # Interpolation coefficient 48 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device) 49 | 50 | # For computing gradient penalty 51 | ones = torch.ones(batch_size).to(device) 52 | 53 | lrate = 1e-4 54 | imputer_lrate = 2e-4 55 | data_gen_optimizer = optim.Adam( 56 | data_gen.parameters(), lr=lrate, betas=(.5, .9)) 57 | mask_gen_optimizer = optim.Adam( 58 | mask_gen.parameters(), lr=lrate, betas=(.5, .9)) 59 | imputer_optimizer = optim.Adam( 60 | imputer.parameters(), lr=imputer_lrate, betas=(.5, .9)) 61 | 62 | data_critic_optimizer = optim.Adam( 63 | data_critic.parameters(), lr=lrate, betas=(.5, .9)) 64 | mask_critic_optimizer = optim.Adam( 65 | mask_critic.parameters(), lr=lrate, betas=(.5, .9)) 66 | impu_critic_optimizer = optim.Adam( 67 | 
impu_critic.parameters(), lr=imputer_lrate, betas=(.5, .9)) 68 | 69 | update_data_critic = CriticUpdater( 70 | data_critic, data_critic_optimizer, eps, ones, gp_lambda) 71 | update_mask_critic = CriticUpdater( 72 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda) 73 | update_impu_critic = CriticUpdater( 74 | impu_critic, impu_critic_optimizer, eps, ones, gp_lambda) 75 | 76 | start_epoch = 0 77 | critic_updates = 0 78 | log = defaultdict(list) 79 | 80 | if args.resume: 81 | data_gen.load_state_dict(checkpoint['data_gen']) 82 | mask_gen.load_state_dict(checkpoint['mask_gen']) 83 | imputer.load_state_dict(checkpoint['imputer']) 84 | data_critic.load_state_dict(checkpoint['data_critic']) 85 | mask_critic.load_state_dict(checkpoint['mask_critic']) 86 | impu_critic.load_state_dict(checkpoint['impu_critic']) 87 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt']) 88 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt']) 89 | imputer_optimizer.load_state_dict(checkpoint['imputer_opt']) 90 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt']) 91 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt']) 92 | impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt']) 93 | start_epoch = checkpoint['epoch'] 94 | critic_updates = checkpoint['critic_updates'] 95 | log = checkpoint['log'] 96 | elif args.pretrain: 97 | pretrain = torch.load(args.pretrain, map_location='cpu') 98 | data_gen.load_state_dict(pretrain['data_gen']) 99 | mask_gen.load_state_dict(pretrain['mask_gen']) 100 | data_critic.load_state_dict(pretrain['data_critic']) 101 | mask_critic.load_state_dict(pretrain['mask_critic']) 102 | if 'imputer' in pretrain: 103 | imputer.load_state_dict(pretrain['imputer']) 104 | impu_critic.load_state_dict(pretrain['impu_critic']) 105 | 106 | with (log_dir / 'gpu.txt').open('a') as f: 107 | print(torch.cuda.device_count(), start_epoch, file=f) 108 | 109 | def save_model(path, epoch, critic_updates=0): 110 | 
torch.save({ 111 | 'data_gen': data_gen.state_dict(), 112 | 'mask_gen': mask_gen.state_dict(), 113 | 'imputer': imputer.state_dict(), 114 | 'data_critic': data_critic.state_dict(), 115 | 'mask_critic': mask_critic.state_dict(), 116 | 'impu_critic': impu_critic.state_dict(), 117 | 'data_gen_opt': data_gen_optimizer.state_dict(), 118 | 'mask_gen_opt': mask_gen_optimizer.state_dict(), 119 | 'imputer_opt': imputer_optimizer.state_dict(), 120 | 'data_critic_opt': data_critic_optimizer.state_dict(), 121 | 'mask_critic_opt': mask_critic_optimizer.state_dict(), 122 | 'impu_critic_opt': impu_critic_optimizer.state_dict(), 123 | 'epoch': epoch + 1, 124 | 'critic_updates': critic_updates, 125 | 'log': log, 126 | 'args': args, 127 | }, str(path)) 128 | 129 | sns.set() 130 | start = time.time() 131 | epoch_start = start 132 | 133 | for epoch in range(start_epoch, epochs): 134 | sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0 135 | for real_data, real_mask, _, index in data_loader: 136 | # Assume real_data and real_mask have the same number of channels. 137 | # Could be modified to handle multi-channel images and 138 | # single-channel masks. 
139 | real_mask = real_mask.float()[:, None] 140 | 141 | real_data = real_data.to(device) 142 | real_mask = real_mask.to(device) 143 | 144 | masked_real_data = mask_data(real_data, real_mask, tau) 145 | 146 | # Update discriminators' parameters 147 | data_noise.normal_() 148 | fake_data = data_gen(data_noise) 149 | 150 | impu_noise.uniform_() 151 | imputed_data = imputer(real_data, real_mask, impu_noise) 152 | masked_imputed_data = mask_data(real_data, real_mask, imputed_data) 153 | 154 | if update_all_networks: 155 | mask_noise.normal_() 156 | fake_mask = mask_gen(mask_noise) 157 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 158 | update_data_critic(masked_real_data, masked_fake_data) 159 | update_mask_critic(real_mask, fake_mask) 160 | 161 | sum_data_loss += update_data_critic.loss_value 162 | sum_mask_loss += update_mask_critic.loss_value 163 | 164 | update_impu_critic(fake_data, masked_imputed_data) 165 | sum_impu_loss += update_impu_critic.loss_value 166 | 167 | critic_updates += 1 168 | 169 | if critic_updates == n_critic: 170 | critic_updates = 0 171 | 172 | # Update generators' parameters 173 | if update_all_networks: 174 | for p in data_critic.parameters(): 175 | p.requires_grad_(False) 176 | for p in mask_critic.parameters(): 177 | p.requires_grad_(False) 178 | for p in impu_critic.parameters(): 179 | p.requires_grad_(False) 180 | 181 | impu_noise.uniform_() 182 | imputed_data = imputer(real_data, real_mask, impu_noise) 183 | masked_imputed_data = mask_data(real_data, real_mask, 184 | imputed_data) 185 | impu_loss = -impu_critic(masked_imputed_data).mean() 186 | 187 | if update_all_networks: 188 | data_noise.normal_() 189 | fake_data = data_gen(data_noise) 190 | mask_noise.normal_() 191 | fake_mask = mask_gen(mask_noise) 192 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 193 | data_loss = -data_critic(masked_fake_data).mean() 194 | mask_loss = -mask_critic(fake_mask).mean() 195 | 196 | mask_gen.zero_grad() 197 | (mask_loss + 
data_loss * alpha).backward(retain_graph=True) 198 | mask_gen_optimizer.step() 199 | 200 | data_noise.normal_() 201 | fake_data = data_gen(data_noise) 202 | mask_noise.normal_() 203 | fake_mask = mask_gen(mask_noise) 204 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 205 | data_loss = -data_critic(masked_fake_data).mean() 206 | 207 | data_gen.zero_grad() 208 | (data_loss + impu_loss * beta).backward(retain_graph=True) 209 | data_gen_optimizer.step() 210 | 211 | imputer.zero_grad() 212 | if gamma > 0: 213 | imputer_mismatch_loss = mask_norm( 214 | (imputed_data - real_data)**2, real_mask) 215 | (impu_loss + imputer_mismatch_loss * gamma).backward() 216 | else: 217 | impu_loss.backward() 218 | imputer_optimizer.step() 219 | 220 | if update_all_networks: 221 | for p in data_critic.parameters(): 222 | p.requires_grad_(True) 223 | for p in mask_critic.parameters(): 224 | p.requires_grad_(True) 225 | for p in impu_critic.parameters(): 226 | p.requires_grad_(True) 227 | 228 | if update_all_networks: 229 | mean_data_loss = sum_data_loss / n_batch 230 | mean_mask_loss = sum_mask_loss / n_batch 231 | log['data loss', 'data_loss'].append(mean_data_loss) 232 | log['mask loss', 'mask_loss'].append(mean_mask_loss) 233 | mean_impu_loss = sum_impu_loss / n_batch 234 | log['imputer loss', 'impu_loss'].append(mean_impu_loss) 235 | 236 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0: 237 | if update_all_networks: 238 | print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format( 239 | epoch, mean_data_loss, mean_mask_loss, mean_impu_loss)) 240 | else: 241 | print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss)) 242 | 243 | filename = f'{epoch:04d}.png' 244 | with torch.no_grad(): 245 | data_gen.eval() 246 | mask_gen.eval() 247 | imputer.eval() 248 | 249 | data_noise.normal_() 250 | mask_noise.normal_() 251 | 252 | data_samples = data_gen(data_noise) 253 | plot_samples(data_samples, str(gen_data_dir / filename)) 254 | 255 | mask_samples = mask_gen(mask_noise) 256 | 
plot_samples(mask_samples, str(gen_mask_dir / filename)) 257 | 258 | # Plot imputation results 259 | impu_noise.uniform_() 260 | imputed_data = imputer(real_data, real_mask, impu_noise) 261 | imputed_data = mask_data(real_data, real_mask, imputed_data) 262 | if hasattr(data, 'mask_info'): 263 | bbox = [data.mask_info[idx] for idx in index] 264 | else: 265 | bbox = None 266 | plot_grid(imputed_data, bbox, gap=2, 267 | save_file=str(impute_dir / filename)) 268 | 269 | data_gen.train() 270 | mask_gen.train() 271 | imputer.train() 272 | 273 | for (name, shortname), trace in log.items(): 274 | fig, ax = plt.subplots(figsize=(6, 4)) 275 | ax.plot(trace) 276 | ax.set_ylabel(name) 277 | ax.set_xlabel('epoch') 278 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300) 279 | plt.close(fig) 280 | 281 | if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0: 282 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates) 283 | 284 | epoch_end = time.time() 285 | time_elapsed = epoch_end - start 286 | epoch_time = epoch_end - epoch_start 287 | epoch_start = epoch_end 288 | with (log_dir / 'epoch-time.txt').open('a') as f: 289 | print(epoch, epoch_time, time_elapsed, file=f) 290 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates) 291 | 292 | print(output_dir) 293 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_critic.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ConvCritic(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | self.DIM = 64 9 | main = nn.Sequential( 10 | nn.Conv2d(1, self.DIM, 5, stride=2, padding=2), 11 | nn.ReLU(True), 12 | nn.Conv2d(self.DIM, 2 * self.DIM, 5, stride=2, padding=2), 13 | nn.ReLU(True), 14 | nn.Conv2d(2 * self.DIM, 4 * self.DIM, 5, stride=2, padding=2), 15 | nn.ReLU(True), 16 | ) 17 | self.main = main 18 | self.output = nn.Linear(4 * 4 * 4 * self.DIM, 1) 19 | 
20 | def forward(self, input): 21 | input = input.view(-1, 1, 28, 28) 22 | net = self.main(input) 23 | net = net.view(-1, 4 * 4 * 4 * self.DIM) 24 | net = self.output(net) 25 | return net.view(-1) 26 | 27 | 28 | class FCCritic(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | 32 | self.in_dim = 784 33 | self.main = nn.Sequential( 34 | nn.Linear(self.in_dim, 512), 35 | nn.ReLU(True), 36 | nn.Linear(512, 256), 37 | nn.ReLU(True), 38 | nn.Linear(256, 128), 39 | nn.ReLU(True), 40 | nn.Linear(128, 1), 41 | ) 42 | 43 | def forward(self, input): 44 | input = input.view(input.size(0), -1) 45 | out = self.main(input) 46 | return out.view(-1) 47 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_fid.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/mseitzer/pytorch-fid 2 | """ 3 | import torch 4 | import numpy as np 5 | from scipy import linalg 6 | from torch.utils.data import DataLoader 7 | from torchvision import datasets, transforms 8 | import argparse 9 | 10 | import mnist_model 11 | from mnist_generator import ConvDataGenerator, FCDataGenerator 12 | from mnist_imputer import ComplementImputer, MaskImputer, FixedNoiseDimImputer 13 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 14 | from pathlib import Path 15 | 16 | 17 | use_cuda = torch.cuda.is_available() 18 | device = torch.device('cuda' if use_cuda else 'cpu') 19 | 20 | feature_layer = 0 21 | 22 | 23 | def get_activations(image_generator, images, model, verbose=False): 24 | """Calculates the activations of the pool_3 layer for all images. 25 | 26 | Params: 27 | -- image_generator 28 | : A generator that generates a batch of images at a time. 29 | -- images : Number of images that will be generated by 30 | image_generator. 
31 | -- model : Instance of inception model 32 | -- verbose : If set to True, progress is printed after each 33 | propagated batch. 34 | Returns: 35 | -- A numpy array of dimension (num images, dims) that contains the 36 | activations of the given tensor when feeding inception with the 37 | query tensor. 38 | """ 39 | model.eval() 40 | 41 | pred_arr = None 42 | end = 0 43 | for i, batch in enumerate(image_generator): 44 | if verbose: 45 | print('\rPropagating batch %d' % (i + 1), end='', flush=True) 46 | start = end 47 | batch_size = batch.shape[0] 48 | end = start + batch_size 49 | batch = batch.to(device) 50 | 51 | with torch.no_grad(): 52 | model(batch) 53 | pred = model.feature[feature_layer] 54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1) 55 | if pred_arr is None: 56 | pred_arr = np.empty((images, batch_feature.shape[1])) 57 | pred_arr[start:end] = batch_feature 58 | 59 | if verbose: 60 | print(' done') 61 | 62 | return pred_arr 63 | 64 | 65 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 66 | """Numpy implementation of the Frechet Distance. 67 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 68 | and X_2 ~ N(mu_2, C_2) is 69 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 70 | 71 | Stable version by Dougal J. Sutherland. 72 | 73 | Params: 74 | -- mu1 : Numpy array containing the activations of a layer of the 75 | inception net (like returned by the function 'get_predictions') 76 | for generated samples. 77 | -- mu2 : The sample mean over activations, precalculated on a 78 | representative data set. 79 | -- sigma1: The covariance matrix over activations for generated samples. 80 | -- sigma2: The covariance matrix over activations, precalculated on a 81 | representative data set. 82 | 83 | Returns: 84 | -- : The Frechet Distance.
85 | """ 86 | 87 | mu1 = np.atleast_1d(mu1) 88 | mu2 = np.atleast_1d(mu2) 89 | 90 | sigma1 = np.atleast_2d(sigma1) 91 | sigma2 = np.atleast_2d(sigma2) 92 | 93 | assert mu1.shape == mu2.shape, \ 94 | 'Training and test mean vectors have different lengths' 95 | assert sigma1.shape == sigma2.shape, \ 96 | 'Training and test covariances have different dimensions' 97 | 98 | diff = mu1 - mu2 99 | 100 | # Product might be almost singular 101 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 102 | if not np.isfinite(covmean).all(): 103 | msg = ('fid calculation produces singular product; ' 104 | 'adding %s to diagonal of cov estimates') % eps 105 | print(msg) 106 | offset = np.eye(sigma1.shape[0]) * eps 107 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 108 | 109 | # Numerical error might give slight imaginary component 110 | if np.iscomplexobj(covmean): 111 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 112 | m = np.max(np.abs(covmean.imag)) 113 | raise ValueError(f'Imaginary component {m}') 114 | covmean = covmean.real 115 | 116 | tr_covmean = np.trace(covmean) 117 | 118 | return (diff.dot(diff) + np.trace(sigma1) + 119 | np.trace(sigma2) - 2 * tr_covmean) 120 | 121 | 122 | def calculate_activation_statistics(image_generator, images, model, 123 | verbose=False, weight=None): 124 | """Calculation of the statistics used by the FID. 125 | Params: 126 | -- image_generator 127 | : A generator that generates a batch of images at a time. 128 | -- images : Number of images that will be generated by 129 | image_generator. 130 | -- model : Instance of inception model 131 | -- verbose : If set to True and parameter out_step is given, the 132 | number of calculated batches is reported. 133 | Returns: 134 | -- mu : The mean over samples of the activations of the pool_3 layer of 135 | the inception model. 136 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 137 | the inception model. 
138 | """ 139 | act = get_activations(image_generator, images, model, verbose) 140 | if weight is None: 141 | mu = np.mean(act, axis=0) 142 | sigma = np.cov(act, rowvar=False) 143 | else: 144 | mu = np.average(act, axis=0, weights=weight) 145 | sigma = np.cov(act, rowvar=False, aweights=weight) 146 | return mu, sigma 147 | 148 | 149 | class MNISTModel: 150 | def __init__(self): 151 | model = mnist_model.Net().to(device) 152 | model.eval() 153 | map_location = None if use_cuda else 'cpu' 154 | model.load_state_dict( 155 | torch.load('mnist.pth', map_location=map_location)) 156 | 157 | stats_file = f'mnist_act_{feature_layer}.npz' 158 | try: 159 | f = np.load(stats_file) 160 | m_mnist, s_mnist = f['mu'][:], f['sigma'][:] 161 | f.close() 162 | except FileNotFoundError: 163 | data = datasets.MNIST('data', train=True, download=True, 164 | transform=transforms.ToTensor()) 165 | images = len(data) 166 | batch_size = 64 167 | data_loader = DataLoader([image for image, _ in data], 168 | batch_size=batch_size) 169 | m_mnist, s_mnist = calculate_activation_statistics( 170 | data_loader, images, model, verbose=True) 171 | np.savez(stats_file, mu=m_mnist, sigma=s_mnist) 172 | 173 | self.model = model 174 | self.mnist_stats = m_mnist, s_mnist 175 | 176 | def get_feature(self, samples): 177 | self.model(samples) 178 | feature = self.model.feature[feature_layer] 179 | return feature.cpu().numpy().reshape(samples.shape[0], -1) 180 | 181 | def fid(self, features): 182 | mu = np.mean(features, axis=0) 183 | sigma = np.cov(features, rowvar=False) 184 | return calculate_frechet_distance(mu, sigma, *self.mnist_stats) 185 | 186 | 187 | def data_generator_fid(data_generator, 188 | n_samples=60000, batch_size=64, verbose=False): 189 | mnist_model = MNISTModel() 190 | latent_size = 128 191 | data_noise = torch.FloatTensor(batch_size, latent_size).to(device) 192 | 193 | with torch.no_grad(): 194 | count = 0 195 | features = None 196 | while count < n_samples: 197 | data_noise.normal_() 198 | 
samples = data_generator(data_noise) 199 | batch_feature = mnist_model.get_feature(samples) 200 | 201 | if features is None: 202 | features = np.empty((n_samples, batch_feature.shape[1])) 203 | 204 | if count + batch_size > n_samples: 205 | batch_size = n_samples - count 206 | features[count:] = batch_feature[:batch_size] 207 | else: 208 | features[count:(count + batch_size)] = batch_feature 209 | 210 | count += batch_size 211 | if verbose: 212 | print(f'\rGenerate images {count}', end='', flush=True) 213 | if verbose: 214 | print(' done') 215 | return mnist_model.fid(features) 216 | 217 | 218 | def imputer_fid(imputer, data, batch_size=64, verbose=False): 219 | mnist_model = MNISTModel() 220 | impu_noise = torch.FloatTensor(batch_size, 1, 28, 28).to(device) 221 | data_loader = DataLoader(data, batch_size=batch_size, drop_last=True) 222 | n_samples = len(data_loader) * batch_size 223 | 224 | with torch.no_grad(): 225 | start = 0 226 | features = None 227 | for real_data, real_mask, _, index in data_loader: 228 | real_mask = real_mask.float()[:, None] 229 | real_data = real_data.to(device) 230 | real_mask = real_mask.to(device) 231 | impu_noise.uniform_() 232 | imputed_data = imputer(real_data, real_mask, impu_noise) 233 | 234 | batch_feature = mnist_model.get_feature(imputed_data) 235 | if features is None: 236 | features = np.empty((n_samples, batch_feature.shape[1])) 237 | features[start:(start + batch_size)] = batch_feature 238 | start += batch_size 239 | if verbose: 240 | print(f'\rGenerate images {start}', end='', flush=True) 241 | if verbose: 242 | print(' done') 243 | return mnist_model.fid(features) 244 | 245 | 246 | def pretrained_misgan_fid(model_file, samples=60000, batch_size=64): 247 | model = torch.load(model_file, map_location='cpu') 248 | args = model['args'] 249 | if args.generator == 'conv': 250 | DataGenerator = ConvDataGenerator 251 | elif args.generator == 'fc': 252 | DataGenerator = FCDataGenerator 253 | data_gen = DataGenerator().to(device) 
254 | data_gen.load_state_dict(model['data_gen']) 255 | return data_generator_fid(data_gen, verbose=True) 256 | 257 | 258 | def pretrained_imputer_fid(model_file, save_file, batch_size=64): 259 | model = torch.load(model_file, map_location='cpu') 260 | if 'imputer' not in model: 261 | return 262 | args = model['args'] 263 | 264 | if args.imputer == 'comp': 265 | Imputer = ComplementImputer 266 | elif args.imputer == 'mask': 267 | Imputer = MaskImputer 268 | elif args.imputer == 'fix': 269 | Imputer = FixedNoiseDimImputer 270 | 271 | hid_lens = [int(n) for n in args.arch.split('-')] 272 | imputer = Imputer(arch=hid_lens).to(device) 273 | imputer.load_state_dict(model['imputer']) 274 | 275 | block_len = args.block_len 276 | if block_len == 0: 277 | block_len = None 278 | 279 | if args.mask == 'indep': 280 | data = IndepMaskedMNIST(obs_prob=args.obs_prob, 281 | obs_prob_high=args.obs_prob_high) 282 | elif args.mask == 'block': 283 | data = BlockMaskedMNIST(block_len=block_len) 284 | 285 | fid = imputer_fid(imputer, data, verbose=True) 286 | with save_file.open('w') as f: 287 | print(fid, file=f) 288 | print('imputer fid:', fid) 289 | 290 | 291 | def main(): 292 | parser = argparse.ArgumentParser() 293 | parser.add_argument('root_dir') 294 | parser.add_argument('--skip-exist', action='store_true') 295 | args = parser.parse_args() 296 | 297 | skip_exist = args.skip_exist 298 | 299 | root_dir = Path(args.root_dir) 300 | fid_file = root_dir / f'fid-{feature_layer}.txt' 301 | if skip_exist and fid_file.exists(): 302 | return 303 | try: 304 | model_file = max((root_dir / 'model').glob('*.pth')) 305 | except ValueError: 306 | return 307 | 308 | fid = pretrained_misgan_fid(model_file) 309 | print(f'{root_dir.name}: {fid}') 310 | with fid_file.open('w') as f: 311 | print(fid, file=f) 312 | 313 | # Compute FID for the imputer if it is in the model 314 | imputer_fid_file = root_dir / f'impute-fid-{feature_layer}.txt' 315 | pretrained_imputer_fid(model_file, imputer_fid_file) 316 
| 317 | 318 | if __name__ == '__main__': 319 | main() 320 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def add_data_transformer(self): 7 | self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28) 8 | 9 | 10 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)): 11 | """ 12 | hard_sigmoid: 13 | False: use sigmoid only 14 | True: hard thresholding 15 | (a, b): hard thresholding on rescaled sigmoid 16 | """ 17 | self.temperature = temperature 18 | self.hard_sigmoid = hard_sigmoid 19 | 20 | view = -1, 1, 28, 28 21 | if hard_sigmoid is False: 22 | self.transform = lambda x: torch.sigmoid(x / temperature).view(*view) 23 | elif hard_sigmoid is True: 24 | self.transform = lambda x: F.hardtanh( 25 | x / temperature, 0, 1).view(*view) 26 | else: 27 | a, b = hard_sigmoid 28 | self.transform = lambda x: F.hardtanh( 29 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1).view(*view) 30 | 31 | 32 | # Must sub-class ConvGenerator to provide transform() 33 | class ConvGenerator(nn.Module): 34 | def __init__(self, latent_size=128): 35 | super().__init__() 36 | 37 | self.DIM = 64 38 | self.latent_size = latent_size 39 | 40 | self.preprocess = nn.Sequential( 41 | nn.Linear(latent_size, 4 * 4 * 4 * self.DIM), 42 | nn.ReLU(True), 43 | ) 44 | self.block1 = nn.Sequential( 45 | nn.ConvTranspose2d(4 * self.DIM, 2 * self.DIM, 5), 46 | nn.ReLU(True), 47 | ) 48 | self.block2 = nn.Sequential( 49 | nn.ConvTranspose2d(2 * self.DIM, self.DIM, 5), 50 | nn.ReLU(True), 51 | ) 52 | self.deconv_out = nn.ConvTranspose2d(self.DIM, 1, 8, stride=2) 53 | 54 | def forward(self, input): 55 | net = self.preprocess(input) 56 | net = net.view(-1, 4 * self.DIM, 4, 4) 57 | net = self.block1(net) 58 | net = net[:, :, :7, :7] 59 | net = 
self.block2(net) 60 | net = self.deconv_out(net) 61 | return self.transform(net) 62 | 63 | 64 | # Must sub-class FCGenerator to provide transform() 65 | class FCGenerator(nn.Module): 66 | def __init__(self, latent_size=128): 67 | super().__init__() 68 | self.latent_size = latent_size 69 | self.fc = nn.Sequential( 70 | nn.Linear(latent_size, 256), 71 | nn.ReLU(True), 72 | nn.Linear(256, 512), 73 | nn.ReLU(True), 74 | nn.Linear(512, 784), 75 | ) 76 | 77 | def forward(self, input): 78 | net = self.fc(input) 79 | return self.transform(net) 80 | 81 | 82 | class ConvDataGenerator(ConvGenerator): 83 | def __init__(self, latent_size=128): 84 | super().__init__(latent_size=latent_size) 85 | add_data_transformer(self) 86 | 87 | 88 | class FCDataGenerator(FCGenerator): 89 | def __init__(self, latent_size=128): 90 | super().__init__(latent_size=latent_size) 91 | add_data_transformer(self) 92 | 93 | 94 | class ConvMaskGenerator(ConvGenerator): 95 | def __init__(self, latent_size=128, temperature=.66, 96 | hard_sigmoid=(-.1, 1.1)): 97 | super().__init__(latent_size=latent_size) 98 | add_mask_transformer(self, temperature, hard_sigmoid) 99 | 100 | 101 | class FCMaskGenerator(FCGenerator): 102 | def __init__(self, latent_size=128, temperature=.66, 103 | hard_sigmoid=(-.1, 1.1)): 104 | super().__init__(latent_size=latent_size) 105 | add_mask_transformer(self, temperature, hard_sigmoid) 106 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_imputer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Must sub-class Imputer to provide fc1 7 | class Imputer(nn.Module): 8 | def __init__(self, arch=(784, 784)): 9 | super().__init__() 10 | # self.fc1 = nn.Linear(784, arch[0]) 11 | self.fc2 = nn.Linear(arch[0], arch[1]) 12 | self.fc3 = nn.Linear(arch[1], arch[0]) 13 | self.fc4 = nn.Linear(arch[0], 784) 14 | 
self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28) 15 | 16 | def forward(self, input, data, mask): 17 | net = input.view(input.size(0), -1) 18 | net = F.relu(self.fc1(net)) 19 | net = F.relu(self.fc2(net)) 20 | net = F.relu(self.fc3(net)) 21 | net = self.fc4(net) 22 | net = self.transform(net) 23 | # return data * mask + net * (1 - mask) 24 | # NOT replacing observed part with input data for computing 25 | # autoencoding loss. 26 | return net 27 | 28 | 29 | class ComplementImputer(Imputer): 30 | def __init__(self, arch=(784, 784)): 31 | super().__init__(arch=arch) 32 | self.fc1 = nn.Linear(784, arch[0]) 33 | 34 | def forward(self, input, mask, noise): 35 | net = input * mask + noise * (1 - mask) 36 | return super().forward(net, input, mask) 37 | 38 | 39 | class MaskImputer(Imputer): 40 | def __init__(self, arch=(784, 784)): 41 | super().__init__(arch=arch) 42 | self.fc1 = nn.Linear(784 * 2, arch[0]) 43 | 44 | def forward(self, input, mask, noise): 45 | batch_size = input.size(0) 46 | net = torch.cat( 47 | [(input * mask + noise * (1 - mask)).view(batch_size, -1), 48 | mask.view(batch_size, -1)], 1) 49 | return super().forward(net, input, mask) 50 | 51 | 52 | class FixedNoiseDimImputer(Imputer): 53 | def __init__(self, arch=(784, 784)): 54 | super().__init__(arch=arch) 55 | self.fc1 = nn.Linear(784 * 3, arch[0]) 56 | 57 | def forward(self, input, mask, noise): 58 | batch_size = input.size(0) 59 | net = torch.cat([(input * mask).view(batch_size, -1), 60 | mask.view(batch_size, -1), 61 | noise.view(batch_size, -1)], 1) 62 | return super().forward(net, input, mask) 63 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_misgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | from pathlib import Path 4 | import argparse 5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator, 6 | ConvMaskGenerator, 
FCMaskGenerator) 7 | from mnist_critic import ConvCritic, FCCritic 8 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 9 | from misgan import misgan 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | 19 | # resume from checkpoint 20 | parser.add_argument('--resume') 21 | # training options 22 | parser.add_argument('--epoch', type=int, default=500) 23 | parser.add_argument('--batch-size', type=int, default=64) 24 | 25 | # log options: 0 to disable plot-interval or save-interval 26 | parser.add_argument('--plot-interval', type=int, default=50) 27 | parser.add_argument('--save-interval', type=int, default=0) 28 | parser.add_argument('--prefix', default='misgan') 29 | 30 | # mask options (data): block|indep 31 | parser.add_argument('--mask', default='block') 32 | # option for block: set to 0 for variable size 33 | parser.add_argument('--block-len', type=int, default=14) 34 | # option for indep: 35 | parser.add_argument('--obs-prob', type=float, default=.2) 36 | parser.add_argument('--obs-prob-high', type=float, default=None) 37 | 38 | # model options 39 | parser.add_argument('--tau', type=float, default=0) 40 | parser.add_argument('--generator', default='conv') # conv|fc 41 | parser.add_argument('--critic', default='conv') # conv|fc 42 | # parser.add_argument('--alpha', type=float, default=.1) # 0: separate 43 | parser.add_argument('--alpha', type=float, default=.2) # 0: separate 44 | # options for mask generator: sigmoid, hardsigmoid, fusion 45 | # parser.add_argument('--maskgen', default='fusion') 46 | parser.add_argument('--maskgen', default='sigmoid') 47 | parser.add_argument('--gp-lambda', type=float, default=10) 48 | parser.add_argument('--n-critic', type=int, default=5) 49 | parser.add_argument('--n-latent', type=int, default=128) 50 | 51 | args = parser.parse_args() 52 | 53 | checkpoint = None 54 | # Resume from 
previously stored checkpoint 55 | if args.resume: 56 | print(f'Resume: {args.resume}') 57 | output_dir = Path(args.resume) 58 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 59 | map_location='cpu') 60 | for key, arg in vars(checkpoint['args']).items(): 61 | if key not in ['resume']: 62 | setattr(args, key, arg) 63 | 64 | if args.generator == 'conv': 65 | DataGenerator = ConvDataGenerator 66 | MaskGenerator = ConvMaskGenerator 67 | elif args.generator == 'fc': 68 | DataGenerator = FCDataGenerator 69 | MaskGenerator = FCMaskGenerator 70 | else: 71 | raise NotImplementedError 72 | 73 | if args.critic == 'conv': 74 | Critic = ConvCritic 75 | elif args.critic == 'fc': 76 | Critic = FCCritic 77 | else: 78 | raise NotImplementedError 79 | 80 | if args.maskgen == 'sigmoid': 81 | hard_sigmoid = False 82 | elif args.maskgen == 'hardsigmoid': 83 | hard_sigmoid = True 84 | elif args.maskgen == 'fusion': 85 | hard_sigmoid = -.1, 1.1 86 | else: 87 | raise NotImplementedError 88 | 89 | mask = args.mask 90 | obs_prob = args.obs_prob 91 | obs_prob_high = args.obs_prob_high 92 | block_len = args.block_len 93 | if block_len == 0: 94 | block_len = None 95 | 96 | if mask == 'indep': 97 | if obs_prob_high is None: 98 | mask_str = f'indep_{obs_prob:g}' 99 | else: 100 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 101 | elif mask == 'block': 102 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 103 | else: 104 | raise NotImplementedError 105 | 106 | path = '{}_{}_{}'.format( 107 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 108 | '_'.join([ 109 | f'gen_{args.generator}', 110 | f'critic_{args.critic}', 111 | f'tau_{args.tau:g}', 112 | f'alpha_{args.alpha:g}', 113 | f'maskgen_{args.maskgen}', 114 | mask_str, 115 | ])) 116 | 117 | if not args.resume: 118 | output_dir = Path('results') / 'mnist' / path 119 | print(output_dir) 120 | 121 | if mask == 'indep': 122 | data = IndepMaskedMNIST(obs_prob=obs_prob, obs_prob_high=obs_prob_high) 
123 | elif mask == 'block': 124 | data = BlockMaskedMNIST(block_len=block_len) 125 | 126 | data_gen = DataGenerator().to(device) 127 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device) 128 | 129 | data_critic = Critic().to(device) 130 | mask_critic = Critic().to(device) 131 | 132 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data, 133 | output_dir, checkpoint) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_misgan_impute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | from pathlib import Path 4 | import argparse 5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator, 6 | ConvMaskGenerator, FCMaskGenerator) 7 | from mnist_imputer import (ComplementImputer, 8 | MaskImputer, 9 | FixedNoiseDimImputer) 10 | from mnist_critic import ConvCritic, FCCritic 11 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 12 | from misgan_impute import misgan_impute 13 | 14 | 15 | use_cuda = torch.cuda.is_available() 16 | device = torch.device('cuda' if use_cuda else 'cpu') 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | 22 | # resume from checkpoint 23 | parser.add_argument('--resume') 24 | 25 | # training options 26 | parser.add_argument('--workers', type=int, default=0) 27 | parser.add_argument('--epoch', type=int, default=1000) 28 | parser.add_argument('--batch-size', type=int, default=64) 29 | parser.add_argument('--pretrain', default=None) 30 | parser.add_argument('--imputeronly', action='store_true') 31 | 32 | # log options: 0 to disable plot-interval or save-interval 33 | parser.add_argument('--plot-interval', type=int, default=100) 34 | parser.add_argument('--save-interval', type=int, default=0) 35 | parser.add_argument('--prefix', default='impute') 36 | 37 | # mask options (data): block|indep 38 | 
parser.add_argument('--mask', default='block') 39 | # option for block: set to 0 for variable size 40 | parser.add_argument('--block-len', type=int, default=14) 41 | # option for indep: 42 | parser.add_argument('--obs-prob', type=float, default=.2) 43 | parser.add_argument('--obs-prob-high', type=float, default=None) 44 | 45 | # model options 46 | parser.add_argument('--tau', type=float, default=0) 47 | parser.add_argument('--generator', default='conv') # conv|fc 48 | parser.add_argument('--critic', default='conv') # conv|fc 49 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate 50 | parser.add_argument('--beta', type=float, default=.1) 51 | parser.add_argument('--gamma', type=float, default=0) 52 | parser.add_argument('--arch', default='784-784') 53 | parser.add_argument('--imputer', default='comp') # comp|mask|fix 54 | # options for mask generator: sigmoid, hardsigmoid, fusion 55 | parser.add_argument('--maskgen', default='fusion') 56 | parser.add_argument('--gp-lambda', type=float, default=10) 57 | parser.add_argument('--n-critic', type=int, default=5) 58 | parser.add_argument('--n-latent', type=int, default=128) 59 | 60 | args = parser.parse_args() 61 | 62 | checkpoint = None 63 | # Resume from previously stored checkpoint 64 | if args.resume: 65 | print(f'Resume: {args.resume}') 66 | output_dir = Path(args.resume) 67 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 68 | map_location='cpu') 69 | for key, arg in vars(checkpoint['args']).items(): 70 | if key not in ['resume']: 71 | setattr(args, key, arg) 72 | 73 | if args.imputeronly: 74 | assert args.pretrain is not None 75 | 76 | arch = args.arch 77 | imputer_type = args.imputer 78 | mask = args.mask 79 | obs_prob = args.obs_prob 80 | obs_prob_high = args.obs_prob_high 81 | block_len = args.block_len 82 | if block_len == 0: 83 | block_len = None 84 | 85 | if args.generator == 'conv': 86 | DataGenerator = ConvDataGenerator 87 | MaskGenerator = ConvMaskGenerator 88 | elif 
args.generator == 'fc': 89 | DataGenerator = FCDataGenerator 90 | MaskGenerator = FCMaskGenerator 91 | else: 92 | raise NotImplementedError 93 | 94 | if imputer_type == 'comp': 95 | Imputer = ComplementImputer 96 | elif imputer_type == 'mask': 97 | Imputer = MaskImputer 98 | elif imputer_type == 'fix': 99 | Imputer = FixedNoiseDimImputer 100 | else: 101 | raise NotImplementedError 102 | 103 | if args.critic == 'conv': 104 | Critic = ConvCritic 105 | elif args.critic == 'fc': 106 | Critic = FCCritic 107 | else: 108 | raise NotImplementedError 109 | 110 | if args.maskgen == 'sigmoid': 111 | hard_sigmoid = False 112 | elif args.maskgen == 'hardsigmoid': 113 | hard_sigmoid = True 114 | elif args.maskgen == 'fusion': 115 | hard_sigmoid = -.1, 1.1 116 | else: 117 | raise NotImplementedError 118 | 119 | if mask == 'indep': 120 | if obs_prob_high is None: 121 | mask_str = f'indep_{obs_prob:g}' 122 | else: 123 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 124 | elif mask == 'block': 125 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 126 | else: 127 | raise NotImplementedError 128 | 129 | path = '{}_{}_{}'.format( 130 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 131 | '_'.join([ 132 | f'gen_{args.generator}', 133 | f'critic_{args.critic}', 134 | f'imp_{args.imputer}', 135 | f'tau_{args.tau:g}', 136 | f'arch_{args.arch}', 137 | f'maskgen_{args.maskgen}', 138 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}', 139 | mask_str 140 | ])) 141 | 142 | if not args.resume: 143 | output_dir = Path('results') / 'mnist' / path 144 | print(output_dir) 145 | 146 | if mask == 'indep': 147 | data = IndepMaskedMNIST( 148 | obs_prob=obs_prob, obs_prob_high=obs_prob_high) 149 | elif mask == 'block': 150 | data = BlockMaskedMNIST(block_len=block_len) 151 | 152 | data_gen = DataGenerator().to(device) 153 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device) 154 | 155 | hid_lens = [int(n) for n in arch.split('-')] 156 | imputer = 
Imputer(arch=hid_lens).to(device) 157 | 158 | data_critic = Critic().to(device) 159 | mask_critic = Critic().to(device) 160 | impu_critic = Critic().to(device) 161 | 162 | misgan_impute(args, data_gen, mask_gen, imputer, 163 | data_critic, mask_critic, impu_critic, 164 | data, output_dir, checkpoint) 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /src-torch1.6/mnist_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py 3 | """ 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | 11 | 12 | class Net(nn.Module): 13 | def __init__(self): 14 | super(Net, self).__init__() 15 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 16 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 17 | self.conv2_drop = nn.Dropout2d() 18 | self.fc1 = nn.Linear(320, 50) 19 | self.fc2 = nn.Linear(50, 10) 20 | 21 | def forward(self, x): 22 | feature = [] 23 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 24 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 25 | x = x.view(-1, 320) 26 | x = self.fc1(x) 27 | feature.append(x) 28 | x = F.relu(x) 29 | x = F.dropout(x, training=self.training) 30 | x = self.fc2(x) 31 | feature.append(x) 32 | self.feature = feature 33 | return F.log_softmax(x, dim=1) 34 | 35 | 36 | def main(): 37 | # Training settings 38 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 39 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 40 | help='input batch size for training (default: 64)') 41 | parser.add_argument('--test-batch-size', type=int, 42 | default=1000, metavar='N', 43 | help='input batch size for testing (default: 1000)') 44 | parser.add_argument('--epochs', type=int, 
default=100, metavar='N', 45 | help='number of epochs to train (default: 100)') 46 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 47 | help='learning rate (default: 0.01)') 48 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 49 | help='SGD momentum (default: 0.5)') 50 | parser.add_argument('--no-cuda', action='store_true', default=False, 51 | help='disables CUDA training') 52 | parser.add_argument('--seed', type=int, default=1, metavar='S', 53 | help='random seed (default: 1)') 54 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 55 | help='number of batches to wait before logging ' 56 | 'training status') 57 | args = parser.parse_args() 58 | args.cuda = not args.no_cuda and torch.cuda.is_available() 59 | 60 | torch.manual_seed(args.seed) 61 | if args.cuda: 62 | torch.cuda.manual_seed(args.seed) 63 | 64 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 65 | train_loader = torch.utils.data.DataLoader( 66 | datasets.MNIST('../data', train=True, download=True, 67 | transform=transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.1307,), (0.3081,)) 70 | ])), 71 | batch_size=args.batch_size, shuffle=True, **kwargs) 72 | test_loader = torch.utils.data.DataLoader( 73 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.1307,), (0.3081,)) 76 | ])), 77 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 78 | 79 | model = Net() 80 | if args.cuda: 81 | model.cuda() 82 | 83 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 84 | momentum=args.momentum) 85 | 86 | def train(epoch): 87 | model.train() 88 | for batch_idx, (data, target) in enumerate(train_loader): 89 | if args.cuda: 90 | data, target = data.cuda(), target.cuda() 91 | optimizer.zero_grad() 92 | output = model(data) 93 | loss = F.nll_loss(output, target) 94 | loss.backward() 95 | optimizer.step() 96 | if batch_idx 
% args.log_interval == 0: 97 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 98 | epoch, batch_idx * len(data), len(train_loader.dataset), 99 | 100. * batch_idx / len(train_loader), loss.item())) 100 | 101 | def test(): 102 | model.eval() 103 | test_loss = 0 104 | correct = 0 105 | with torch.no_grad(): 106 | for data, target in test_loader: 107 | if args.cuda: 108 | data, target = data.cuda(), target.cuda() 109 | output = model(data) 110 | # sum up batch loss 111 | test_loss += F.nll_loss(output, target, reduction='sum').item() 112 | # get the index of the max log-probability 113 | pred = output.argmax(dim=1, keepdim=True) 114 | correct += (pred == target.view_as(pred)).long().cpu().sum() 115 | 116 | test_loss /= len(test_loader.dataset) 117 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n' 118 | .format(test_loss, correct, len(test_loader.dataset), 119 | 100. * correct / len(test_loader.dataset))) 120 | 121 | for epoch in range(1, args.epochs + 1): 122 | train(epoch) 123 | test() 124 | 125 | torch.save(model.state_dict(), 'mnist.pth') 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /src-torch1.6/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pylab as plt 3 | from matplotlib.patches import Rectangle 4 | from PIL import Image 5 | 6 | 7 | def plot_grid(image, bbox=None, gap=0, gap_value=1, 8 | nrow=4, ncol=8, save_file=None): 9 | image = image.cpu().numpy() 10 | channels, len0, len1 = image[0].shape 11 | grid = np.empty( 12 | (nrow * (len0 + gap) - gap, ncol * (len1 + gap) - gap, channels)) 13 | # Convert from N, C, H, W to N, H, W, C 14 | image = image.transpose((0, 2, 3, 1)) 15 | grid.fill(gap_value) 16 | 17 | for i, x in enumerate(image): 18 | if i >= nrow * ncol: 19 | break 20 | p0 = (i // ncol) * (len0 + gap) 21 | p1 = (i % ncol) * (len1 + gap) 22 | grid[p0:(p0 + len0),
p1:(p1 + len1)] = x 23 | 24 | # figsize = np.r_[ncol, nrow] * .75 25 | scale = 2.5 26 | figsize = ncol * scale, nrow * scale # FIXME: scale by len0, len1 27 | fig = plt.figure(figsize=figsize) 28 | ax = plt.Axes(fig, [0, 0, 1, 1]) 29 | ax.set_axis_off() 30 | fig.add_axes(ax) 31 | grid = grid.squeeze() 32 | ax.imshow(grid, cmap='binary_r', interpolation='none', aspect='equal') 33 | 34 | if bbox is not None: 35 | nplot = min(len(image), nrow * ncol) 36 | for i in range(nplot): 37 | if len(bbox) == 1: 38 | d0, d1, d0_len, d1_len = bbox[0] 39 | else: 40 | d0, d1, d0_len, d1_len = bbox[i] 41 | p0 = (i // ncol) * (len0 + gap) 42 | p1 = (i % ncol) * (len1 + gap) 43 | offset = np.array([p1 + d1, p0 + d0]) - .5 44 | ax.add_patch(Rectangle( 45 | offset, d1_len, d0_len, lw=4, edgecolor='red', fill=False)) 46 | 47 | if save_file: 48 | fig.savefig(save_file) 49 | plt.close(fig) 50 | 51 | 52 | def plot_samples(samples, save_file, nrow=4, ncol=8): 53 | x = samples.cpu().numpy() 54 | channels, len0, len1 = x[0].shape 55 | x_merge = np.zeros((nrow * len0, ncol * len1, channels)) 56 | 57 | for i, x_ in enumerate(x): 58 | if i >= nrow * ncol: 59 | break 60 | p0 = (i // ncol) * len0 61 | p1 = (i % ncol) * len1 62 | x_merge[p0:(p0 + len0), p1:(p1 + len1)] = x_.transpose((1, 2, 0)) 63 | 64 | x_merge = (x_merge * 255).clip(0, 255).astype(np.uint8) 65 | # squeeze() to remove the last dimension for the single-channel image. 
66 | im = Image.fromarray(x_merge.squeeze()) 67 | im.save(save_file) 68 | -------------------------------------------------------------------------------- /src-torch1.6/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.3.2 2 | numpy==1.19.2 3 | Pillow==8.0.1 4 | scipy==1.5.3 5 | seaborn==0.11.0 6 | torch==1.6.0 7 | torchvision==0.7.0 8 | -------------------------------------------------------------------------------- /src-torch1.6/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 6 | # 7 | # Defines the submodule with skip connection. 8 | # X -------------------identity---------------------- X 9 | # |-- downsampling -- |submodule| -- upsampling --| 10 | class UnetSkipConnectionBlock(nn.Module): 11 | def __init__(self, outer_nc, inner_nc, input_nc=None, 12 | submodule=None, outermost=False, innermost=False, 13 | norm_layer=nn.BatchNorm2d): 14 | super().__init__() 15 | self.outermost = outermost 16 | use_bias = norm_layer == nn.InstanceNorm2d 17 | if input_nc is None: 18 | input_nc = outer_nc 19 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 20 | stride=2, padding=1, bias=use_bias) 21 | downrelu = nn.LeakyReLU(0.2, True) 22 | if norm_layer is not None: 23 | downnorm = norm_layer(inner_nc) 24 | upnorm = norm_layer(outer_nc) 25 | uprelu = nn.ReLU(True) 26 | 27 | if outermost: 28 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 29 | kernel_size=4, stride=2, 30 | padding=1) 31 | down = [downconv] 32 | up = [uprelu, upconv] 33 | model = down + [submodule] + up 34 | elif innermost: 35 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 36 | kernel_size=4, stride=2, 37 | padding=1, bias=use_bias) 38 | down = [downrelu, downconv] 39 | up = [uprelu, upconv] 40 | if norm_layer is not None: 41 | up.append(upnorm) 42 | model = down 
+ up 43 | else: 44 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 45 | kernel_size=4, stride=2, 46 | padding=1, bias=use_bias) 47 | down = [downrelu, downconv] 48 | up = [uprelu, upconv] 49 | if norm_layer is not None: 50 | down.append(downnorm) 51 | up.append(upnorm) 52 | 53 | model = down + [submodule] + up 54 | 55 | self.model = nn.Sequential(*model) 56 | 57 | def forward(self, x): 58 | if self.outermost: 59 | return self.model(x) 60 | else: 61 | return torch.cat([x, self.model(x)], 1) 62 | -------------------------------------------------------------------------------- /src-torch1.6/utils.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import grad 2 | 3 | 4 | class CriticUpdater: 5 | def __init__(self, critic, critic_optimizer, eps, ones, gp_lambda=10): 6 | self.critic = critic 7 | self.critic_optimizer = critic_optimizer 8 | self.eps = eps 9 | self.ones = ones 10 | self.gp_lambda = gp_lambda 11 | 12 | def __call__(self, real, fake): 13 | real = real.detach() 14 | fake = fake.detach() 15 | self.critic.zero_grad() 16 | self.eps.uniform_(0, 1) 17 | interp = (self.eps * real + (1 - self.eps) * fake).requires_grad_() 18 | grad_d = grad(self.critic(interp), interp, grad_outputs=self.ones, 19 | create_graph=True)[0] 20 | grad_d = grad_d.view(real.shape[0], -1) 21 | grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda 22 | w_dist = self.critic(fake).mean() - self.critic(real).mean() 23 | loss = w_dist + grad_penalty 24 | loss.backward() 25 | self.critic_optimizer.step() 26 | self.loss_value = loss.item() 27 | 28 | 29 | def mask_norm(diff, mask): 30 | """Mask normalization""" 31 | dim = 1, 2, 3 32 | # Assume mask.sum(1) is non-zero throughout 33 | return ((diff * mask).sum(dim) / mask.sum(dim)).mean() 34 | 35 | 36 | def mkdir(path): 37 | path.mkdir(parents=True, exist_ok=True) 38 | return path 39 | 40 | 41 | def mask_data(data, mask, tau): 42 | return mask * data + (1 - mask) * tau 43 | 
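The masking helpers at the end of `utils.py` reduce to simple element-wise arithmetic. As a minimal sketch, assuming flat Python lists in place of the 4-D tensors the repository operates on (the `_1d` function names are hypothetical stand-ins, not part of `utils.py`), the following shows what `mask_data` and `mask_norm` compute for a single example:

```python
def mask_data_1d(data, mask, tau):
    """Keep observed entries (mask == 1); fill missing ones with tau."""
    return [m * x + (1 - m) * tau for x, m in zip(data, mask)]


def mask_norm_1d(diff, mask):
    """Average diff over the observed entries only.

    Like the original, this assumes the mask has at least one observed
    entry; otherwise the division fails.
    """
    return sum(d * m for d, m in zip(diff, mask)) / sum(mask)


data = [0.9, 0.1, 0.4, 0.7]
mask = [1, 0, 1, 0]
print(mask_data_1d(data, mask, tau=0.5))           # [0.9, 0.5, 0.4, 0.5]
print(mask_norm_1d([0.25, 0.5, 0.75, 1.0], mask))  # 0.5
```

The tensor version of `mask_norm` additionally averages this per-example ratio over the batch dimension.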
-------------------------------------------------------------------------------- /src/celeba_critic.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def conv_ln_lrelu(in_dim, out_dim): 5 | return nn.Sequential( 6 | nn.Conv2d(in_dim, out_dim, 5, 2, 2), 7 | nn.InstanceNorm2d(out_dim, affine=True), 8 | nn.LeakyReLU(0.2)) 9 | 10 | 11 | class ConvCritic(nn.Module): 12 | def __init__(self, n_channels): 13 | super().__init__() 14 | dim = 64 15 | self.ls = nn.Sequential( 16 | nn.Conv2d(n_channels, dim, 5, 2, 2), nn.LeakyReLU(0.2), 17 | conv_ln_lrelu(dim, dim * 2), 18 | conv_ln_lrelu(dim * 2, dim * 4), 19 | conv_ln_lrelu(dim * 4, dim * 8), 20 | nn.Conv2d(dim * 8, 1, 4)) 21 | 22 | def forward(self, input): 23 | net = self.ls(input) 24 | return net.view(-1) 25 | -------------------------------------------------------------------------------- /src/celeba_fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import torch 4 | import torch.nn as nn 5 | from torch.utils.data import DataLoader 6 | from torchvision import datasets, transforms 7 | from PIL import Image 8 | from celeba_generator import ConvDataGenerator 9 | from fid import BaseSampler, BaseImputationSampler 10 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA 11 | from imputer import UNetImputer 12 | from fid import FID 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('root_dir') 17 | parser.add_argument('--batch-size', type=int, default=256) 18 | parser.add_argument('--workers', type=int, default=0) 19 | parser.add_argument('--skip-exist', action='store_true') 20 | args = parser.parse_args() 21 | 22 | 23 | use_cuda = torch.cuda.is_available() 24 | device = torch.device('cuda' if use_cuda else 'cpu') 25 | 26 | 27 | class CelebAFID(FID): 28 | def __init__(self, batch_size=256, data_name='celeba', 29 | workers=0, 
verbose=True): 30 | self.batch_size = batch_size 31 | self.workers = workers 32 | super().__init__(data_name, verbose) 33 | 34 | def complete_data(self): 35 | data = datasets.ImageFolder( 36 | 'celeba', 37 | transforms.Compose([ 38 | transforms.CenterCrop(108), 39 | transforms.Resize(size=64, interpolation=Image.BICUBIC), 40 | transforms.ToTensor(), 41 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)), 42 | ])) 43 | 44 | images = len(data) 45 | data_loader = DataLoader( 46 | data, batch_size=self.batch_size, num_workers=self.workers) 47 | 48 | return data_loader, images 49 | 50 | 51 | class MisGANSampler(BaseSampler): 52 | def __init__(self, data_gen, images=60000, batch_size=256): 53 | super().__init__(images) 54 | self.data_gen = data_gen 55 | self.batch_size = batch_size 56 | latent_dim = 128 57 | self.data_noise = torch.FloatTensor(batch_size, latent_dim).to(device) 58 | 59 | def sample(self): 60 | self.data_noise.normal_() 61 | return self.data_gen(self.data_noise) 62 | 63 | 64 | class MisGANImputationSampler(BaseImputationSampler): 65 | def __init__(self, data_loader, imputer, batch_size=256): 66 | super().__init__(data_loader) 67 | self.imputer = imputer 68 | self.impu_noise = torch.FloatTensor(batch_size, 3, 64, 64).to(device) 69 | 70 | def impute(self, data, mask): 71 | if data.shape[0] != self.impu_noise.shape[0]: 72 | self.impu_noise.resize_(data.shape) 73 | self.impu_noise.uniform_() 74 | return self.imputer(data, mask, self.impu_noise) 75 | 76 | 77 | def get_data_loader(args, batch_size): 78 | if args.mask == 'indep': 79 | data = IndepMaskedCelebA( 80 | data_dir=args.data_dir, 81 | obs_prob=args.obs_prob, obs_prob_high=args.obs_prob_high) 82 | elif args.mask == 'block': 83 | data = BlockMaskedCelebA( 84 | data_dir=args.data_dir, block_len=args.block_len) 85 | 86 | data_size = len(data) 87 | data_loader = DataLoader( 88 | data, batch_size=batch_size, num_workers=args.workers) 89 | return data_loader, data_size 90 | 91 | 92 | def 
parallelize(model): 93 | return nn.DataParallel(model).to(device) 94 | 95 | 96 | def pretrained_misgan_fid(model_file, samples=202599): 97 | model = torch.load(model_file, map_location='cpu') 98 | data_gen = parallelize(ConvDataGenerator()) 99 | data_gen.load_state_dict(model['data_gen']) 100 | 101 | batch_size = args.batch_size 102 | 103 | compute_fid = CelebAFID(batch_size=batch_size) 104 | sampler = MisGANSampler(data_gen, samples, batch_size) 105 | gen_fid = compute_fid.fid(sampler, samples) 106 | print(f'fid: {gen_fid:.2f}') 107 | 108 | imp_fid = None 109 | if 'imputer' in model: 110 | imputer = UNetImputer().to(device) 111 | imputer.load_state_dict(model['imputer']) 112 | data_loader, data_size = get_data_loader(model['args'], batch_size) 113 | imputation_sampler = MisGANImputationSampler( 114 | data_loader, imputer, batch_size) 115 | imp_fid = compute_fid.fid(imputation_sampler, data_size) 116 | print(f'impute fid: {imp_fid:.2f}') 117 | 118 | return gen_fid, imp_fid 119 | 120 | 121 | def main(): 122 | root_dir = Path(args.root_dir) 123 | fid_file = root_dir / 'fid.txt' 124 | if args.skip_exist and fid_file.exists(): 125 | return 126 | try: 127 | model_file = max((root_dir / 'model').glob('*.pth')) 128 | except ValueError: 129 | return 130 | 131 | print(root_dir.name) 132 | fid, imp_fid = pretrained_misgan_fid(model_file) 133 | 134 | with fid_file.open('w') as f: 135 | print(fid, file=f) 136 | 137 | if imp_fid is not None: 138 | with (root_dir / 'impute-fid.txt').open('w') as f: 139 | print(imp_fid, file=f) 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /src/celeba_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)): 7 | """ 8 | hard_sigmoid: 9 | False: use 
sigmoid only 10 | True: hard thresholding 11 | (a, b): hard thresholding on rescaled sigmoid 12 | """ 13 | self.temperature = temperature 14 | self.hard_sigmoid = hard_sigmoid 15 | 16 | if hard_sigmoid is False: 17 | self.transform = lambda x: torch.sigmoid(x / temperature) 18 | elif hard_sigmoid is True: 19 | self.transform = lambda x: F.hardtanh( 20 | x / temperature, 0, 1) 21 | else: 22 | a, b = hard_sigmoid 23 | self.transform = lambda x: F.hardtanh( 24 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1) 25 | 26 | 27 | def dconv_bn_relu(in_dim, out_dim): 28 | return nn.Sequential( 29 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2, 30 | padding=2, output_padding=1, bias=False), 31 | nn.BatchNorm2d(out_dim), 32 | nn.ReLU()) 33 | 34 | 35 | # Must sub-class ConvGenerator to provide transform() 36 | class ConvGenerator(nn.Module): 37 | def __init__(self, latent_size=128): 38 | super().__init__() 39 | 40 | dim = 64 41 | 42 | self.l1 = nn.Sequential( 43 | nn.Linear(latent_size, dim * 8 * 4 * 4, bias=False), 44 | nn.BatchNorm1d(dim * 8 * 4 * 4), 45 | nn.ReLU()) 46 | 47 | self.l2_5 = nn.Sequential( 48 | dconv_bn_relu(dim * 8, dim * 4), 49 | dconv_bn_relu(dim * 4, dim * 2), 50 | dconv_bn_relu(dim * 2, dim), 51 | nn.ConvTranspose2d(dim, self.out_channels, 5, 2, 52 | padding=2, output_padding=1)) 53 | 54 | def forward(self, input): 55 | net = self.l1(input) 56 | net = net.view(net.shape[0], -1, 4, 4) 57 | net = self.l2_5(net) 58 | return self.transform(net) 59 | 60 | 61 | class ConvDataGenerator(ConvGenerator): 62 | def __init__(self, latent_size=128): 63 | self.out_channels = 3 64 | super().__init__(latent_size=latent_size) 65 | self.transform = lambda x: torch.sigmoid(x) 66 | 67 | 68 | class ConvMaskGenerator(ConvGenerator): 69 | def __init__(self, latent_size=128, temperature=.66, 70 | hard_sigmoid=(-.1, 1.1)): 71 | self.out_channels = 1 72 | super().__init__(latent_size=latent_size) 73 | add_mask_transformer(self, temperature, hard_sigmoid) 74 | 
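The three `hard_sigmoid` modes of `add_mask_transformer` can be compared on scalars. The sketch below is an illustration only: `make_mask_transform` is a hypothetical helper, and it uses `math` on single floats instead of `torch.sigmoid` and `F.hardtanh` on tensors. It is meant to show why the default `(-.1, 1.1)` setting can emit exact zeros and ones while a plain sigmoid cannot.

```python
import math


def make_mask_transform(temperature=0.66, hard_sigmoid=(-0.1, 1.1)):
    """Scalar stand-in for the transform chosen by add_mask_transformer."""
    def clamp01(v):
        return min(max(v, 0.0), 1.0)   # same effect as hardtanh(v, 0, 1)

    def sigmoid(v):
        return 1.0 / (1.0 + math.exp(-v))

    if hard_sigmoid is False:
        return lambda x: sigmoid(x / temperature)
    if hard_sigmoid is True:
        return lambda x: clamp01(x / temperature)
    a, b = hard_sigmoid
    return lambda x: clamp01(sigmoid(x / temperature) * (b - a) + a)


soft = make_mask_transform(hard_sigmoid=False)
fusion = make_mask_transform()

# The plain sigmoid only approaches 0 and 1; the rescaled-and-clamped
# variant reaches them exactly for large enough logits.
print(soft(5.0) < 1.0, fusion(5.0) == 1.0, fusion(-5.0) == 0.0)  # True True True
```

Stretching the sigmoid to the interval (a, b) before clamping is what lets the mask generator produce strictly binary values in saturated regions while remaining differentiable in between.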
-------------------------------------------------------------------------------- /src/celeba_misgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datetime import datetime 4 | from pathlib import Path 5 | import argparse 6 | from celeba_generator import ConvDataGenerator, ConvMaskGenerator 7 | from celeba_critic import ConvCritic 8 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA 9 | from misgan import misgan 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def parallelize(model): 17 | return nn.DataParallel(model).to(device) 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | 23 | # resume from checkpoint 24 | parser.add_argument('--resume') 25 | 26 | # path of CelebA dataset 27 | parser.add_argument('--data-dir', default='celeba-data') 28 | 29 | # training options 30 | parser.add_argument('--epoch', type=int, default=600) 31 | parser.add_argument('--batch-size', type=int, default=256) 32 | 33 | # log options: 0 to disable plot-interval or save-interval 34 | parser.add_argument('--plot-interval', type=int, default=100) 35 | parser.add_argument('--save-interval', type=int, default=0) 36 | parser.add_argument('--prefix', default='misgan') 37 | 38 | # mask options (data): block|indep 39 | parser.add_argument('--mask', default='block') 40 | # option for block: set to 0 for variable size 41 | parser.add_argument('--block-len', type=int, default=32) 42 | # option for indep: 43 | parser.add_argument('--obs-prob', type=float, default=.2) 44 | parser.add_argument('--obs-prob-high', type=float, default=None) 45 | 46 | # model options 47 | parser.add_argument('--tau', type=float, default=.5) 48 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate 49 | # options for mask generator: sigmoid, hardsigmoid, fusion 50 | parser.add_argument('--maskgen', 
default='fusion') 51 | parser.add_argument('--gp-lambda', type=float, default=10) 52 | parser.add_argument('--n-critic', type=int, default=5) 53 | parser.add_argument('--n-latent', type=int, default=128) 54 | 55 | args = parser.parse_args() 56 | 57 | checkpoint = None 58 | # Resume from previously stored checkpoint 59 | if args.resume: 60 | print(f'Resume: {args.resume}') 61 | output_dir = Path(args.resume) 62 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 63 | map_location='cpu') 64 | for key, arg in vars(checkpoint['args']).items(): 65 | if key not in ['resume']: 66 | setattr(args, key, arg) 67 | 68 | if args.maskgen == 'sigmoid': 69 | hard_sigmoid = False 70 | elif args.maskgen == 'hardsigmoid': 71 | hard_sigmoid = True 72 | elif args.maskgen == 'fusion': 73 | hard_sigmoid = -.1, 1.1 74 | else: 75 | raise NotImplementedError 76 | 77 | mask = args.mask 78 | obs_prob = args.obs_prob 79 | obs_prob_high = args.obs_prob_high 80 | block_len = args.block_len 81 | if block_len == 0: 82 | block_len = None 83 | if mask == 'indep': 84 | if obs_prob_high is None: 85 | mask_str = f'indep_{obs_prob:g}' 86 | else: 87 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 88 | elif mask == 'block': 89 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 90 | else: 91 | raise NotImplementedError 92 | 93 | path = '{}_{}_{}'.format( 94 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 95 | '_'.join([ 96 | f'tau_{args.tau:g}', 97 | f'alpha_{args.alpha:g}', 98 | f'maskgen_{args.maskgen}', 99 | mask_str, 100 | ])) 101 | 102 | if not args.resume: 103 | output_dir = Path('results') / 'celeba' / path 104 | print(output_dir) 105 | 106 | if mask == 'indep': 107 | data = IndepMaskedCelebA( 108 | data_dir=args.data_dir, 109 | obs_prob=obs_prob, obs_prob_high=obs_prob_high) 110 | elif mask == 'block': 111 | data = BlockMaskedCelebA( 112 | data_dir=args.data_dir, block_len=block_len) 113 | n_gpu = torch.cuda.device_count() 114 | print(f'Use {n_gpu} 
GPUs.') 115 | 116 | data_gen = parallelize(ConvDataGenerator()) 117 | mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid)) 118 | 119 | data_critic = parallelize(ConvCritic(n_channels=3)) 120 | mask_critic = parallelize(ConvCritic(n_channels=1)) 121 | 122 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data, 123 | output_dir, checkpoint) 124 | 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /src/celeba_misgan_impute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from datetime import datetime 4 | from pathlib import Path 5 | import argparse 6 | from celeba_generator import ConvDataGenerator, ConvMaskGenerator 7 | from celeba_critic import ConvCritic 8 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA 9 | from imputer import UNetImputer 10 | from misgan_impute import misgan_impute 11 | 12 | 13 | use_cuda = torch.cuda.is_available() 14 | device = torch.device('cuda' if use_cuda else 'cpu') 15 | 16 | 17 | def parallelize(model): 18 | return nn.DataParallel(model).to(device) 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | 24 | # resume from checkpoint 25 | parser.add_argument('--resume') 26 | 27 | # path of CelebA dataset 28 | parser.add_argument('--data-dir', default='celeba-data') 29 | 30 | # training options 31 | parser.add_argument('--workers', type=int, default=0) 32 | parser.add_argument('--epoch', type=int, default=800) 33 | parser.add_argument('--batch-size', type=int, default=512) 34 | parser.add_argument('--pretrain', default=None) 35 | parser.add_argument('--imputeronly', action='store_true') 36 | 37 | # log options: 0 to disable plot-interval or save-interval 38 | parser.add_argument('--plot-interval', type=int, default=50) 39 | parser.add_argument('--save-interval', type=int, default=0) 40 | 
parser.add_argument('--prefix', default='impute') 41 | 42 | # mask options (data): block|indep 43 | parser.add_argument('--mask', default='block') 44 | # option for block: set to 0 for variable size 45 | parser.add_argument('--block-len', type=int, default=32) 46 | # option for indep: 47 | parser.add_argument('--obs-prob', type=float, default=.2) 48 | parser.add_argument('--obs-prob-high', type=float, default=None) 49 | 50 | # model options 51 | parser.add_argument('--tau', type=float, default=.5) 52 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate 53 | parser.add_argument('--beta', type=float, default=.1) 54 | parser.add_argument('--gamma', type=float, default=0) 55 | # options for mask generator: sigmoid, hardsigmoid, fusion 56 | parser.add_argument('--maskgen', default='fusion') 57 | parser.add_argument('--gp-lambda', type=float, default=10) 58 | parser.add_argument('--n-critic', type=int, default=5) 59 | parser.add_argument('--n-latent', type=int, default=128) 60 | 61 | args = parser.parse_args() 62 | 63 | checkpoint = None 64 | # Resume from previously stored checkpoint 65 | if args.resume: 66 | print(f'Resume: {args.resume}') 67 | output_dir = Path(args.resume) 68 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 69 | map_location='cpu') 70 | for key, arg in vars(checkpoint['args']).items(): 71 | if key not in ['resume']: 72 | setattr(args, key, arg) 73 | 74 | if args.imputeronly: 75 | assert args.pretrain is not None 76 | 77 | mask = args.mask 78 | obs_prob = args.obs_prob 79 | obs_prob_high = args.obs_prob_high 80 | block_len = args.block_len 81 | if block_len == 0: 82 | block_len = None 83 | 84 | if args.maskgen == 'sigmoid': 85 | hard_sigmoid = False 86 | elif args.maskgen == 'hardsigmoid': 87 | hard_sigmoid = True 88 | elif args.maskgen == 'fusion': 89 | hard_sigmoid = -.1, 1.1 90 | else: 91 | raise NotImplementedError 92 | 93 | if mask == 'indep': 94 | if obs_prob_high is None: 95 | mask_str = 
f'indep_{obs_prob:g}' 96 | else: 97 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 98 | elif mask == 'block': 99 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 100 | else: 101 | raise NotImplementedError 102 | 103 | path = '{}_{}_{}'.format( 104 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 105 | '_'.join([ 106 | f'tau_{args.tau:g}', 107 | f'maskgen_{args.maskgen}', 108 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}', 109 | mask_str, 110 | ])) 111 | 112 | if not args.resume: 113 | output_dir = Path('results') / 'celeba' / path 114 | print(output_dir) 115 | 116 | if mask == 'indep': 117 | data = IndepMaskedCelebA( 118 | data_dir=args.data_dir, 119 | obs_prob=obs_prob, obs_prob_high=obs_prob_high) 120 | elif mask == 'block': 121 | data = BlockMaskedCelebA( 122 | data_dir=args.data_dir, block_len=block_len) 123 | 124 | n_gpu = torch.cuda.device_count() 125 | print(f'Use {n_gpu} GPUs.') 126 | data_gen = parallelize(ConvDataGenerator()) 127 | mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid)) 128 | imputer = UNetImputer().to(device) 129 | 130 | data_critic = parallelize(ConvCritic(n_channels=3)) 131 | mask_critic = parallelize(ConvCritic(n_channels=1)) 132 | impu_critic = parallelize(ConvCritic(n_channels=3)) 133 | 134 | misgan_impute(args, data_gen, mask_gen, imputer, 135 | data_critic, mask_critic, impu_critic, 136 | data, output_dir, checkpoint) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /src/fcnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class FullyConnectedNet(nn.Module): 5 | def __init__(self, weights, output_shape=None): 6 | super().__init__() 7 | n_layers = len(weights) - 1 8 | 9 | layers = [nn.Linear(weights[0], weights[1])] 10 | for i in range(1, n_layers): 11 | layers.extend([nn.ReLU(), nn.Linear(weights[i], weights[i 
+ 1])]) 12 | 13 | self.model = nn.Sequential(*layers) 14 | self.output_shape = output_shape 15 | 16 | def forward(self, input): 17 | output = self.model(input.view(input.shape[0], -1)) 18 | if self.output_shape is not None: 19 | output = output.view(self.output_shape) 20 | return output 21 | -------------------------------------------------------------------------------- /src/fid.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/mseitzer/pytorch-fid 2 | """ 3 | from pathlib import Path 4 | import torch 5 | import numpy as np 6 | from scipy import linalg 7 | import time 8 | import sys 9 | from inception import InceptionV3 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | FEATURE_DIM = 2048 16 | RESIZE = 299 17 | 18 | 19 | def get_activations(image_iterator, images, model, verbose=True): 20 | """Calculates the activations of the pool_3 layer for all images. 21 | 22 | Params: 23 | -- image_iterator 24 | : A generator that generates a batch of images at a time. 25 | -- images : Number of images that will be generated by 26 | image_iterator. 27 | -- model : Instance of inception model 28 | -- verbose : If set to True and parameter out_step is given, the number 29 | of calculated batches is reported. 30 | Returns: 31 | -- A numpy array of dimension (num images, dims) that contains the 32 | activations of the given tensor when feeding inception with the 33 | query tensor. 
34 | """ 35 | model.eval() 36 | 37 | if not sys.stdout.isatty(): 38 | verbose = False 39 | 40 | pred_arr = np.empty((images, FEATURE_DIM)) 41 | end = 0 42 | t0 = time.time() 43 | 44 | for batch in image_iterator: 45 | if not isinstance(batch, torch.Tensor): 46 | batch = batch[0] 47 | start = end 48 | batch_size = batch.shape[0] 49 | end = start + batch_size 50 | 51 | with torch.no_grad(): 52 | batch = batch.to(device) 53 | pred = model(batch)[0] 54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1) 55 | pred_arr[start:end] = batch_feature 56 | 57 | if verbose: 58 | print('\rProcessed: {} time: {:.2f}'.format( 59 | end, time.time() - t0), end='', flush=True) 60 | 61 | assert end == images 62 | 63 | if verbose: 64 | print(' done') 65 | 66 | return pred_arr 67 | 68 | 69 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 70 | """Numpy implementation of the Frechet Distance. 71 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 72 | and X_2 ~ N(mu_2, C_2) is 73 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 74 | 75 | Stable version by Dougal J. Sutherland. 76 | 77 | Params: 78 | -- mu1 : The sample mean over activations of a layer of the 79 | inception net (as computed by 'calculate_activation_statistics') 80 | for generated samples. 81 | -- mu2 : The sample mean over activations, precalculated on a 82 | representative data set. 83 | -- sigma1: The covariance matrix over activations for generated samples. 84 | -- sigma2: The covariance matrix over activations, precalculated on a 85 | representative data set. 86 | 87 | Returns: 88 | -- : The Frechet Distance.
89 | """ 90 | 91 | mu1 = np.atleast_1d(mu1) 92 | mu2 = np.atleast_1d(mu2) 93 | 94 | sigma1 = np.atleast_2d(sigma1) 95 | sigma2 = np.atleast_2d(sigma2) 96 | 97 | assert mu1.shape == mu2.shape, \ 98 | 'Training and test mean vectors have different lengths' 99 | assert sigma1.shape == sigma2.shape, \ 100 | 'Training and test covariances have different dimensions' 101 | 102 | diff = mu1 - mu2 103 | 104 | # Product might be almost singular 105 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 106 | if not np.isfinite(covmean).all(): 107 | msg = ('fid calculation produces singular product; ' 108 | 'adding %s to diagonal of cov estimates') % eps 109 | print(msg) 110 | offset = np.eye(sigma1.shape[0]) * eps 111 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 112 | 113 | # Numerical error might give slight imaginary component 114 | if np.iscomplexobj(covmean): 115 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 116 | m = np.max(np.abs(covmean.imag)) 117 | raise ValueError('Imaginary component {}'.format(m)) 118 | covmean = covmean.real 119 | 120 | tr_covmean = np.trace(covmean) 121 | 122 | return (diff.dot(diff) + np.trace(sigma1) + 123 | np.trace(sigma2) - 2 * tr_covmean) 124 | 125 | 126 | def calculate_activation_statistics(image_iterator, images, model, 127 | verbose=False): 128 | """Calculation of the statistics used by the FID. 129 | Params: 130 | -- image_iterator 131 | : A generator that generates a batch of images at a time. 132 | -- images : Number of images that will be generated by 133 | image_iterator. 134 | -- model : Instance of inception model 135 | -- verbose : If set to True and parameter out_step is given, the 136 | number of calculated batches is reported. 137 | Returns: 138 | -- mu : The mean over samples of the activations of the pool_3 layer of 139 | the inception model. 140 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 141 | the inception model. 
142 | """ 143 | act = get_activations(image_iterator, images, model, verbose) 144 | mu = np.mean(act, axis=0) 145 | sigma = np.cov(act, rowvar=False) 146 | return mu, sigma 147 | 148 | 149 | class FID: 150 | def __init__(self, data_name, verbose=True): 151 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM] 152 | model = InceptionV3([block_idx], RESIZE).to(device) 153 | self.verbose = verbose 154 | 155 | stats_dir = Path('fid_stats') 156 | stats_file = stats_dir / '{}_act_{}_{}.npz'.format( 157 | data_name, FEATURE_DIM, RESIZE) 158 | 159 | try: 160 | f = np.load(str(stats_file)) 161 | mu, sigma = f['mu'], f['sigma'] 162 | f.close() 163 | except FileNotFoundError: 164 | data_loader, images = self.complete_data() 165 | mu, sigma = calculate_activation_statistics( 166 | data_loader, images, model, verbose) 167 | stats_dir.mkdir(parents=True, exist_ok=True) 168 | np.savez(stats_file, mu=mu, sigma=sigma) 169 | 170 | self.model = model 171 | self.stats = mu, sigma 172 | 173 | def complete_data(self): 174 | raise NotImplementedError 175 | 176 | def fid(self, image_iterator, images): 177 | mu, sigma = calculate_activation_statistics( 178 | image_iterator, images, self.model, verbose=self.verbose) 179 | return calculate_frechet_distance(mu, sigma, *self.stats) 180 | 181 | 182 | class BaseSampler: 183 | def __init__(self, images): 184 | self.images = images 185 | 186 | def __iter__(self): 187 | self.n = 0 188 | return self 189 | 190 | def __next__(self): 191 | if self.n < self.images: 192 | batch = self.sample() 193 | batch_size = batch.shape[0] 194 | self.n += batch_size 195 | if self.n > self.images: 196 | return batch[:-(self.n - self.images)] 197 | return batch 198 | else: 199 | raise StopIteration 200 | 201 | def sample(self): 202 | raise NotImplementedError 203 | 204 | 205 | class BaseImputationSampler: 206 | def __init__(self, data_loader): 207 | self.data_loader = data_loader 208 | 209 | def __iter__(self): 210 | self.data_iter = iter(self.data_loader) 211 | 
return self 212 | 213 | def __next__(self): 214 | data, mask = next(self.data_iter)[:2] 215 | data = data.to(device) 216 | mask = mask.float()[:, None].to(device) 217 | imputed_data = self.impute(data, mask) 218 | return mask * data + (1 - mask) * imputed_data 219 | 220 | def impute(self, data, mask): 221 | raise NotImplementedError 222 | -------------------------------------------------------------------------------- /src/imputer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from fcnet import FullyConnectedNet 4 | from unet import UnetSkipConnectionBlock 5 | 6 | 7 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 8 | class UNet(nn.Module): 9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, layers=5, 10 | norm_layer=nn.BatchNorm2d): 11 | super().__init__() 12 | 13 | mid_layers = layers - 2 14 | fact = 2**mid_layers 15 | 16 | unet_block = UnetSkipConnectionBlock( 17 | ngf * fact, ngf * fact, input_nc=None, submodule=None, 18 | norm_layer=norm_layer, innermost=True) 19 | 20 | for _ in range(mid_layers): 21 | half_fact = fact // 2 22 | unet_block = UnetSkipConnectionBlock( 23 | ngf * half_fact, ngf * fact, input_nc=None, 24 | submodule=unet_block, norm_layer=norm_layer) 25 | fact = half_fact 26 | 27 | unet_block = UnetSkipConnectionBlock( 28 | output_nc, ngf, input_nc=input_nc, submodule=unet_block, 29 | outermost=True, norm_layer=norm_layer) 30 | 31 | self.model = unet_block 32 | 33 | def forward(self, input): 34 | return self.model(input) 35 | 36 | 37 | class Imputer(nn.Module): 38 | def __init__(self): 39 | super().__init__() 40 | self.transform = lambda x: torch.sigmoid(x) 41 | 42 | def forward(self, input, mask, noise): 43 | net = input * mask + noise * (1 - mask) 44 | net = self.imputer_net(net) 45 | net = self.transform(net) 46 | # NOT replacing observed part with input data for computing 47 | # autoencoding loss. 
48 | # return input * mask + net * (1 - mask) 49 | return net 50 | 51 | 52 | class UNetImputer(Imputer): 53 | def __init__(self, *args, **kwargs): 54 | super().__init__() 55 | self.imputer_net = UNet(*args, **kwargs) 56 | 57 | 58 | class FullyConnectedImputer(Imputer): 59 | def __init__(self, *args, **kwargs): 60 | super().__init__() 61 | self.imputer_net = FullyConnectedNet(*args, **kwargs) 62 | -------------------------------------------------------------------------------- /src/inception.py: -------------------------------------------------------------------------------- 1 | """Code from https://github.com/mseitzer/pytorch-fid 2 | """ 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchvision import models 6 | 7 | 8 | class InceptionV3(nn.Module): 9 | """Pretrained InceptionV3 network returning feature maps""" 10 | 11 | # Index of default block of inception to return, 12 | # corresponds to output of final average pooling 13 | DEFAULT_BLOCK_INDEX = 3 14 | 15 | # Maps feature dimensionality to their output blocks indices 16 | BLOCK_INDEX_BY_DIM = { 17 | 64: 0, # First max pooling features 18 | 192: 1, # Second max pooling features 19 | 768: 2, # Pre-aux classifier features 20 | 2048: 3 # Final average pooling features 21 | } 22 | 23 | def __init__(self, 24 | output_blocks=[DEFAULT_BLOCK_INDEX], 25 | resize_input=299, # -1: not resize 26 | normalize_input=True, 27 | requires_grad=False): 28 | """Build pretrained InceptionV3 29 | 30 | Parameters 31 | ---------- 32 | output_blocks : list of int 33 | Indices of blocks to return features of. Possible values are: 34 | - 0: corresponds to output of first max pooling 35 | - 1: corresponds to output of second max pooling 36 | - 2: corresponds to output which is fed to aux classifier 37 | - 3: corresponds to output of final average pooling 38 | resize_input : int 39 | If positive, bilinearly resizes input to this width and height 40 | (default 299) before feeding input to model; -1 disables resizing.
As the network without fully connected 41 | layers is fully convolutional, it should be able to handle inputs 42 | of arbitrary size, so resizing might not be strictly needed 43 | normalize_input : bool 44 | If true, normalizes the input to the statistics the pretrained 45 | Inception network expects 46 | requires_grad : bool 47 | If true, parameters of the model require gradient. Possibly useful 48 | for finetuning the network 49 | """ 50 | super(InceptionV3, self).__init__() 51 | 52 | self.resize_input = resize_input 53 | self.normalize_input = normalize_input 54 | self.output_blocks = sorted(output_blocks) 55 | self.last_needed_block = max(output_blocks) 56 | 57 | assert self.last_needed_block <= 3, \ 58 | 'Last possible output block index is 3' 59 | 60 | self.blocks = nn.ModuleList() 61 | 62 | inception = models.inception_v3(pretrained=True) 63 | 64 | # Block 0: input to maxpool1 65 | block0 = [ 66 | inception.Conv2d_1a_3x3, 67 | inception.Conv2d_2a_3x3, 68 | inception.Conv2d_2b_3x3, 69 | nn.MaxPool2d(kernel_size=3, stride=2) 70 | ] 71 | self.blocks.append(nn.Sequential(*block0)) 72 | 73 | # Block 1: maxpool1 to maxpool2 74 | if self.last_needed_block >= 1: 75 | block1 = [ 76 | inception.Conv2d_3b_1x1, 77 | inception.Conv2d_4a_3x3, 78 | nn.MaxPool2d(kernel_size=3, stride=2) 79 | ] 80 | self.blocks.append(nn.Sequential(*block1)) 81 | 82 | # Block 2: maxpool2 to aux classifier 83 | if self.last_needed_block >= 2: 84 | block2 = [ 85 | inception.Mixed_5b, 86 | inception.Mixed_5c, 87 | inception.Mixed_5d, 88 | inception.Mixed_6a, 89 | inception.Mixed_6b, 90 | inception.Mixed_6c, 91 | inception.Mixed_6d, 92 | inception.Mixed_6e, 93 | ] 94 | self.blocks.append(nn.Sequential(*block2)) 95 | 96 | # Block 3: aux classifier to final avgpool 97 | if self.last_needed_block >= 3: 98 | block3 = [ 99 | inception.Mixed_7a, 100 | inception.Mixed_7b, 101 | inception.Mixed_7c, 102 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 103 | ] 104 | self.blocks.append(nn.Sequential(*block3)) 
105 | 106 | for param in self.parameters(): 107 | param.requires_grad = requires_grad 108 | 109 | def forward(self, inp): 110 | """Get Inception feature maps 111 | 112 | Parameters 113 | ---------- 114 | inp : torch.autograd.Variable 115 | Input tensor of shape Bx3xHxW. Values are expected to be in 116 | range (0, 1) 117 | 118 | Returns 119 | ------- 120 | List of torch.autograd.Variable, corresponding to the selected output 121 | block, sorted ascending by index 122 | """ 123 | outp = [] 124 | x = inp 125 | 126 | if self.resize_input > 0: 127 | # size = 299 128 | x = F.interpolate(x, size=(self.resize_input, self.resize_input), 129 | mode='bilinear', align_corners=True) 130 | 131 | if self.normalize_input: 132 | x = x.clone() 133 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 134 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 135 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 136 | 137 | for idx, block in enumerate(self.blocks): 138 | x = block(x) 139 | if idx in self.output_blocks: 140 | outp.append(x) 141 | 142 | if idx == self.last_needed_block: 143 | break 144 | 145 | return outp 146 | -------------------------------------------------------------------------------- /src/masked_celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | class MaskedCelebA(datasets.ImageFolder): 8 | def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0): 9 | transform = transforms.Compose([ 10 | transforms.CenterCrop(108), 11 | transforms.Resize(size=image_size, interpolation=Image.BICUBIC), 12 | transforms.ToTensor(), 13 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)), 14 | ]) 15 | 16 | super().__init__(data_dir, transform) 17 | 18 | self.rnd = np.random.RandomState(random_seed) 19 | self.image_size = image_size 20 | self.generate_masks() 21 | 22 | def 
__getitem__(self, index): 23 | image, label = super().__getitem__(index) 24 | return image, self.mask[index], label, index 25 | 26 | def __len__(self): 27 | return super().__len__() 28 | 29 | 30 | class BlockMaskedCelebA(MaskedCelebA): 31 | def __init__(self, block_len=None, *args, **kwargs): 32 | self.block_len = block_len 33 | super().__init__(*args, **kwargs) 34 | 35 | def generate_masks(self): 36 | d0_len = d1_len = self.image_size 37 | d0_min_len = 12 38 | d0_max_len = d0_len - d0_min_len 39 | d1_min_len = 12 40 | d1_max_len = d1_len - d1_min_len 41 | 42 | n_masks = len(self) 43 | self.mask = [None] * n_masks 44 | self.mask_info = [None] * n_masks 45 | for i in range(n_masks): 46 | if self.block_len is None: 47 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len) 48 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len) 49 | else: 50 | d0_mask_len = d1_mask_len = self.block_len 51 | 52 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1) 53 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1) 54 | 55 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8) 56 | mask[d0_start:(d0_start + d0_mask_len), 57 | d1_start:(d1_start + d1_mask_len)] = 1 58 | self.mask[i] = mask 59 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len 60 | 61 | 62 | class IndepMaskedCelebA(MaskedCelebA): 63 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs): 64 | self.prob = obs_prob 65 | self.prob_high = obs_prob_high 66 | super().__init__(*args, **kwargs) 67 | 68 | def generate_masks(self): 69 | imsize = self.image_size 70 | prob = self.prob 71 | prob_high = self.prob_high 72 | n_masks = len(self) 73 | self.mask = [None] * n_masks 74 | for i in range(n_masks): 75 | if prob_high is None: 76 | p = prob 77 | else: 78 | p = self.rnd.uniform(prob, prob_high) 79 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p) 80 | -------------------------------------------------------------------------------- /src/masked_mnist.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision import datasets, transforms 4 | import numpy as np 5 | 6 | 7 | class MaskedMNIST(Dataset): 8 | def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0): 9 | self.rnd = np.random.RandomState(random_seed) 10 | self.image_size = image_size 11 | if image_size == 28: 12 | self.data = datasets.MNIST( 13 | data_dir, train=True, download=True, 14 | transform=transforms.ToTensor()) 15 | else: 16 | self.data = datasets.MNIST( 17 | data_dir, train=True, download=True, 18 | transform=transforms.Compose([ 19 | transforms.Resize(image_size), transforms.ToTensor()])) 20 | self.generate_masks() 21 | 22 | def __getitem__(self, index): 23 | image, label = self.data[index] 24 | return image, self.mask[index], label, index 25 | 26 | def __len__(self): 27 | return len(self.data) 28 | 29 | def generate_masks(self): 30 | raise NotImplementedError 31 | 32 | 33 | class BlockMaskedMNIST(MaskedMNIST): 34 | def __init__(self, block_len=None, *args, **kwargs): 35 | self.block_len = block_len 36 | super().__init__(*args, **kwargs) 37 | 38 | def generate_masks(self): 39 | d0_len = d1_len = self.image_size 40 | d0_min_len = 7 41 | d0_max_len = d0_len - d0_min_len 42 | d1_min_len = 7 43 | d1_max_len = d1_len - d1_min_len 44 | 45 | n_masks = len(self) 46 | self.mask = [None] * n_masks 47 | self.mask_info = [None] * n_masks 48 | for i in range(n_masks): 49 | if self.block_len is None: 50 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len) 51 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len) 52 | else: 53 | d0_mask_len = d1_mask_len = self.block_len 54 | 55 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1) 56 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1) 57 | 58 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8) 59 | mask[d0_start:(d0_start + d0_mask_len), 60 | d1_start:(d1_start + 
d1_mask_len)] = 1 61 | self.mask[i] = mask 62 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len 63 | 64 | 65 | class IndepMaskedMNIST(MaskedMNIST): 66 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs): 67 | self.prob = obs_prob 68 | self.prob_high = obs_prob_high 69 | super().__init__(*args, **kwargs) 70 | 71 | def generate_masks(self): 72 | imsize = self.image_size 73 | prob = self.prob 74 | prob_high = self.prob_high 75 | n_masks = len(self) 76 | self.mask = [None] * n_masks 77 | for i in range(n_masks): 78 | if prob_high is None: 79 | p = prob 80 | else: 81 | p = self.rnd.uniform(prob, prob_high) 82 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p) 83 | -------------------------------------------------------------------------------- /src/misgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | import time 5 | import pylab as plt 6 | import seaborn as sns 7 | from collections import defaultdict 8 | from plot import plot_samples 9 | from utils import CriticUpdater, mkdir, mask_data 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def misgan(args, data_gen, mask_gen, data_critic, mask_critic, data, 17 | output_dir, checkpoint=None): 18 | n_critic = args.n_critic 19 | gp_lambda = args.gp_lambda 20 | batch_size = args.batch_size 21 | nz = args.n_latent 22 | epochs = args.epoch 23 | plot_interval = args.plot_interval 24 | save_interval = args.save_interval 25 | alpha = args.alpha 26 | tau = args.tau 27 | 28 | gen_data_dir = mkdir(output_dir / 'img') 29 | gen_mask_dir = mkdir(output_dir / 'mask') 30 | log_dir = mkdir(output_dir / 'log') 31 | model_dir = mkdir(output_dir / 'model') 32 | 33 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, 34 | drop_last=True) 35 | n_batch = len(data_loader) 36 | 
37 | data_noise = torch.FloatTensor(batch_size, nz).to(device) 38 | mask_noise = torch.FloatTensor(batch_size, nz).to(device) 39 | 40 | # Interpolation coefficient 41 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device) 42 | 43 | # For computing gradient penalty 44 | ones = torch.ones(batch_size).to(device) 45 | 46 | lrate = 1e-4 47 | # lrate = 1e-5 48 | data_gen_optimizer = optim.Adam( 49 | data_gen.parameters(), lr=lrate, betas=(.5, .9)) 50 | mask_gen_optimizer = optim.Adam( 51 | mask_gen.parameters(), lr=lrate, betas=(.5, .9)) 52 | 53 | data_critic_optimizer = optim.Adam( 54 | data_critic.parameters(), lr=lrate, betas=(.5, .9)) 55 | mask_critic_optimizer = optim.Adam( 56 | mask_critic.parameters(), lr=lrate, betas=(.5, .9)) 57 | 58 | update_data_critic = CriticUpdater( 59 | data_critic, data_critic_optimizer, eps, ones, gp_lambda) 60 | update_mask_critic = CriticUpdater( 61 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda) 62 | 63 | start_epoch = 0 64 | critic_updates = 0 65 | log = defaultdict(list) 66 | 67 | if checkpoint: 68 | data_gen.load_state_dict(checkpoint['data_gen']) 69 | mask_gen.load_state_dict(checkpoint['mask_gen']) 70 | data_critic.load_state_dict(checkpoint['data_critic']) 71 | mask_critic.load_state_dict(checkpoint['mask_critic']) 72 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt']) 73 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt']) 74 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt']) 75 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt']) 76 | start_epoch = checkpoint['epoch'] 77 | critic_updates = checkpoint['critic_updates'] 78 | log = checkpoint['log'] 79 | 80 | with (log_dir / 'gpu.txt').open('a') as f: 81 | print(torch.cuda.device_count(), start_epoch, file=f) 82 | 83 | def save_model(path, epoch, critic_updates=0): 84 | torch.save({ 85 | 'data_gen': data_gen.state_dict(), 86 | 'mask_gen': mask_gen.state_dict(), 87 | 'data_critic': 
data_critic.state_dict(), 88 | 'mask_critic': mask_critic.state_dict(), 89 | 'data_gen_opt': data_gen_optimizer.state_dict(), 90 | 'mask_gen_opt': mask_gen_optimizer.state_dict(), 91 | 'data_critic_opt': data_critic_optimizer.state_dict(), 92 | 'mask_critic_opt': mask_critic_optimizer.state_dict(), 93 | 'epoch': epoch + 1, 94 | 'critic_updates': critic_updates, 95 | 'log': log, 96 | 'args': args, 97 | }, str(path)) 98 | 99 | sns.set() 100 | 101 | start = time.time() 102 | epoch_start = start 103 | 104 | for epoch in range(start_epoch, epochs): 105 | sum_data_loss, sum_mask_loss = 0, 0 106 | for real_data, real_mask, _, _ in data_loader: 107 | # Assume real_data and mask have the same number of channels. 108 | # Could be modified to handle multi-channel images and 109 | # single-channel masks. 110 | real_mask = real_mask.float()[:, None] 111 | 112 | real_data = real_data.to(device) 113 | real_mask = real_mask.to(device) 114 | 115 | masked_real_data = mask_data(real_data, real_mask, tau) 116 | 117 | # Update discriminators' parameters 118 | data_noise.normal_() 119 | mask_noise.normal_() 120 | 121 | fake_data = data_gen(data_noise) 122 | fake_mask = mask_gen(mask_noise) 123 | 124 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 125 | 126 | update_data_critic(masked_real_data, masked_fake_data) 127 | update_mask_critic(real_mask, fake_mask) 128 | 129 | sum_data_loss += update_data_critic.loss_value 130 | sum_mask_loss += update_mask_critic.loss_value 131 | 132 | critic_updates += 1 133 | 134 | if critic_updates == n_critic: 135 | critic_updates = 0 136 | 137 | # Update generators' parameters 138 | 139 | for p in data_critic.parameters(): 140 | p.requires_grad_(False) 141 | for p in mask_critic.parameters(): 142 | p.requires_grad_(False) 143 | 144 | data_gen.zero_grad() 145 | mask_gen.zero_grad() 146 | 147 | data_noise.normal_() 148 | mask_noise.normal_() 149 | 150 | fake_data = data_gen(data_noise) 151 | fake_mask = mask_gen(mask_noise) 152 | masked_fake_data 
= mask_data(fake_data, fake_mask, tau) 153 | 154 | data_loss = -data_critic(masked_fake_data).mean() 155 | data_loss.backward(retain_graph=True) 156 | data_gen_optimizer.step() 157 | 158 | mask_loss = -mask_critic(fake_mask).mean() 159 | (mask_loss + data_loss * alpha).backward() 160 | mask_gen_optimizer.step() 161 | 162 | for p in data_critic.parameters(): 163 | p.requires_grad_(True) 164 | for p in mask_critic.parameters(): 165 | p.requires_grad_(True) 166 | 167 | mean_data_loss = sum_data_loss / n_batch 168 | mean_mask_loss = sum_mask_loss / n_batch 169 | log['data loss', 'data_loss'].append(mean_data_loss) 170 | log['mask loss', 'mask_loss'].append(mean_mask_loss) 171 | 172 | for (name, shortname), trace in log.items(): 173 | fig, ax = plt.subplots(figsize=(6, 4)) 174 | ax.plot(trace) 175 | ax.set_ylabel(name) 176 | ax.set_xlabel('epoch') 177 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300) 178 | plt.close(fig) 179 | 180 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0: 181 | print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}') 182 | 183 | filename = f'{epoch:04d}.png' 184 | 185 | data_gen.eval() 186 | mask_gen.eval() 187 | 188 | with torch.no_grad(): 189 | data_noise.normal_() 190 | mask_noise.normal_() 191 | 192 | data_samples = data_gen(data_noise) 193 | plot_samples(data_samples, str(gen_data_dir / filename)) 194 | 195 | mask_samples = mask_gen(mask_noise) 196 | plot_samples(mask_samples, str(gen_mask_dir / filename)) 197 | 198 | data_gen.train() 199 | mask_gen.train() 200 | 201 | if save_interval > 0 and (epoch + 1) % save_interval == 0: 202 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates) 203 | 204 | epoch_end = time.time() 205 | time_elapsed = epoch_end - start 206 | epoch_time = epoch_end - epoch_start 207 | epoch_start = epoch_end 208 | with (log_dir / 'time.txt').open('a') as f: 209 | print(epoch, epoch_time, time_elapsed, file=f) 210 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates) 
211 | 212 | print(output_dir) 213 | -------------------------------------------------------------------------------- /src/misgan_impute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | from torch.utils.data import DataLoader 4 | import time 5 | import pylab as plt 6 | import seaborn as sns 7 | from collections import defaultdict 8 | from plot import plot_grid, plot_samples 9 | from utils import CriticUpdater, mask_norm, mkdir, mask_data 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def misgan_impute(args, data_gen, mask_gen, imputer, 17 | data_critic, mask_critic, impu_critic, 18 | data, output_dir, checkpoint=None): 19 | n_critic = args.n_critic 20 | gp_lambda = args.gp_lambda 21 | batch_size = args.batch_size 22 | nz = args.n_latent 23 | epochs = args.epoch 24 | plot_interval = args.plot_interval 25 | save_model_interval = args.save_interval 26 | alpha = args.alpha 27 | beta = args.beta 28 | gamma = args.gamma 29 | tau = args.tau 30 | update_all_networks = not args.imputeronly 31 | 32 | gen_data_dir = mkdir(output_dir / 'img') 33 | gen_mask_dir = mkdir(output_dir / 'mask') 34 | impute_dir = mkdir(output_dir / 'impute') 35 | log_dir = mkdir(output_dir / 'log') 36 | model_dir = mkdir(output_dir / 'model') 37 | 38 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, 39 | drop_last=True, num_workers=args.workers) 40 | n_batch = len(data_loader) 41 | data_shape = data[0][0].shape 42 | 43 | data_noise = torch.FloatTensor(batch_size, nz).to(device) 44 | mask_noise = torch.FloatTensor(batch_size, nz).to(device) 45 | impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device) 46 | 47 | # Interpolation coefficient 48 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device) 49 | 50 | # For computing gradient penalty 51 | ones = torch.ones(batch_size).to(device) 52 | 53 | lrate = 1e-4 
54 | imputer_lrate = 2e-4 55 | data_gen_optimizer = optim.Adam( 56 | data_gen.parameters(), lr=lrate, betas=(.5, .9)) 57 | mask_gen_optimizer = optim.Adam( 58 | mask_gen.parameters(), lr=lrate, betas=(.5, .9)) 59 | imputer_optimizer = optim.Adam( 60 | imputer.parameters(), lr=imputer_lrate, betas=(.5, .9)) 61 | 62 | data_critic_optimizer = optim.Adam( 63 | data_critic.parameters(), lr=lrate, betas=(.5, .9)) 64 | mask_critic_optimizer = optim.Adam( 65 | mask_critic.parameters(), lr=lrate, betas=(.5, .9)) 66 | impu_critic_optimizer = optim.Adam( 67 | impu_critic.parameters(), lr=imputer_lrate, betas=(.5, .9)) 68 | 69 | update_data_critic = CriticUpdater( 70 | data_critic, data_critic_optimizer, eps, ones, gp_lambda) 71 | update_mask_critic = CriticUpdater( 72 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda) 73 | update_impu_critic = CriticUpdater( 74 | impu_critic, impu_critic_optimizer, eps, ones, gp_lambda) 75 | 76 | start_epoch = 0 77 | critic_updates = 0 78 | log = defaultdict(list) 79 | 80 | if args.resume: 81 | data_gen.load_state_dict(checkpoint['data_gen']) 82 | mask_gen.load_state_dict(checkpoint['mask_gen']) 83 | imputer.load_state_dict(checkpoint['imputer']) 84 | data_critic.load_state_dict(checkpoint['data_critic']) 85 | mask_critic.load_state_dict(checkpoint['mask_critic']) 86 | impu_critic.load_state_dict(checkpoint['impu_critic']) 87 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt']) 88 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt']) 89 | imputer_optimizer.load_state_dict(checkpoint['imputer_opt']) 90 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt']) 91 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt']) 92 | impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt']) 93 | start_epoch = checkpoint['epoch'] 94 | critic_updates = checkpoint['critic_updates'] 95 | log = checkpoint['log'] 96 | elif args.pretrain: 97 | pretrain = torch.load(args.pretrain, 
map_location='cpu') 98 | data_gen.load_state_dict(pretrain['data_gen']) 99 | mask_gen.load_state_dict(pretrain['mask_gen']) 100 | data_critic.load_state_dict(pretrain['data_critic']) 101 | mask_critic.load_state_dict(pretrain['mask_critic']) 102 | if 'imputer' in pretrain: 103 | imputer.load_state_dict(pretrain['imputer']) 104 | impu_critic.load_state_dict(pretrain['impu_critic']) 105 | 106 | with (log_dir / 'gpu.txt').open('a') as f: 107 | print(torch.cuda.device_count(), start_epoch, file=f) 108 | 109 | def save_model(path, epoch, critic_updates=0): 110 | torch.save({ 111 | 'data_gen': data_gen.state_dict(), 112 | 'mask_gen': mask_gen.state_dict(), 113 | 'imputer': imputer.state_dict(), 114 | 'data_critic': data_critic.state_dict(), 115 | 'mask_critic': mask_critic.state_dict(), 116 | 'impu_critic': impu_critic.state_dict(), 117 | 'data_gen_opt': data_gen_optimizer.state_dict(), 118 | 'mask_gen_opt': mask_gen_optimizer.state_dict(), 119 | 'imputer_opt': imputer_optimizer.state_dict(), 120 | 'data_critic_opt': data_critic_optimizer.state_dict(), 121 | 'mask_critic_opt': mask_critic_optimizer.state_dict(), 122 | 'impu_critic_opt': impu_critic_optimizer.state_dict(), 123 | 'epoch': epoch + 1, 124 | 'critic_updates': critic_updates, 125 | 'log': log, 126 | 'args': args, 127 | }, str(path)) 128 | 129 | sns.set() 130 | start = time.time() 131 | epoch_start = start 132 | 133 | for epoch in range(start_epoch, epochs): 134 | sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0 135 | for real_data, real_mask, _, index in data_loader: 136 | # Assume real_data and real_mask have the same number of channels. 137 | # Could be modified to handle multi-channel images and 138 | # single-channel masks. 
139 | real_mask = real_mask.float()[:, None] 140 | 141 | real_data = real_data.to(device) 142 | real_mask = real_mask.to(device) 143 | 144 | masked_real_data = mask_data(real_data, real_mask, tau) 145 | 146 | # Update discriminators' parameters 147 | data_noise.normal_() 148 | fake_data = data_gen(data_noise) 149 | 150 | impu_noise.uniform_() 151 | imputed_data = imputer(real_data, real_mask, impu_noise) 152 | masked_imputed_data = mask_data(real_data, real_mask, imputed_data) 153 | 154 | if update_all_networks: 155 | mask_noise.normal_() 156 | fake_mask = mask_gen(mask_noise) 157 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 158 | update_data_critic(masked_real_data, masked_fake_data) 159 | update_mask_critic(real_mask, fake_mask) 160 | 161 | sum_data_loss += update_data_critic.loss_value 162 | sum_mask_loss += update_mask_critic.loss_value 163 | 164 | update_impu_critic(fake_data, masked_imputed_data) 165 | sum_impu_loss += update_impu_critic.loss_value 166 | 167 | critic_updates += 1 168 | 169 | if critic_updates == n_critic: 170 | critic_updates = 0 171 | 172 | # Update generators' parameters 173 | if update_all_networks: 174 | for p in data_critic.parameters(): 175 | p.requires_grad_(False) 176 | for p in mask_critic.parameters(): 177 | p.requires_grad_(False) 178 | for p in impu_critic.parameters(): 179 | p.requires_grad_(False) 180 | 181 | data_noise.normal_() 182 | fake_data = data_gen(data_noise) 183 | 184 | if update_all_networks: 185 | mask_noise.normal_() 186 | fake_mask = mask_gen(mask_noise) 187 | masked_fake_data = mask_data(fake_data, fake_mask, tau) 188 | data_loss = -data_critic(masked_fake_data).mean() 189 | mask_loss = -mask_critic(fake_mask).mean() 190 | 191 | impu_noise.uniform_() 192 | imputed_data = imputer(real_data, real_mask, impu_noise) 193 | masked_imputed_data = mask_data(real_data, real_mask, 194 | imputed_data) 195 | impu_loss = -impu_critic(masked_imputed_data).mean() 196 | 197 | if update_all_networks: 198 | 
mask_gen.zero_grad() 199 | (mask_loss + data_loss * alpha).backward(retain_graph=True) 200 | mask_gen_optimizer.step() 201 | 202 | data_gen.zero_grad() 203 | (data_loss + impu_loss * beta).backward(retain_graph=True) 204 | data_gen_optimizer.step() 205 | 206 | imputer.zero_grad() 207 | if gamma > 0: 208 | imputer_mismatch_loss = mask_norm( 209 | (imputed_data - real_data)**2, real_mask) 210 | (impu_loss + imputer_mismatch_loss * gamma).backward() 211 | else: 212 | impu_loss.backward() 213 | imputer_optimizer.step() 214 | 215 | if update_all_networks: 216 | for p in data_critic.parameters(): 217 | p.requires_grad_(True) 218 | for p in mask_critic.parameters(): 219 | p.requires_grad_(True) 220 | for p in impu_critic.parameters(): 221 | p.requires_grad_(True) 222 | 223 | if update_all_networks: 224 | mean_data_loss = sum_data_loss / n_batch 225 | mean_mask_loss = sum_mask_loss / n_batch 226 | log['data loss', 'data_loss'].append(mean_data_loss) 227 | log['mask loss', 'mask_loss'].append(mean_mask_loss) 228 | mean_impu_loss = sum_impu_loss / n_batch 229 | log['imputer loss', 'impu_loss'].append(mean_impu_loss) 230 | 231 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0: 232 | if update_all_networks: 233 | print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format( 234 | epoch, mean_data_loss, mean_mask_loss, mean_impu_loss)) 235 | else: 236 | print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss)) 237 | 238 | filename = f'{epoch:04d}.png' 239 | with torch.no_grad(): 240 | data_gen.eval() 241 | mask_gen.eval() 242 | imputer.eval() 243 | 244 | data_noise.normal_() 245 | mask_noise.normal_() 246 | 247 | data_samples = data_gen(data_noise) 248 | plot_samples(data_samples, str(gen_data_dir / filename)) 249 | 250 | mask_samples = mask_gen(mask_noise) 251 | plot_samples(mask_samples, str(gen_mask_dir / filename)) 252 | 253 | # Plot imputation results 254 | impu_noise.uniform_() 255 | imputed_data = imputer(real_data, real_mask, impu_noise) 256 | imputed_data = 
mask_data(real_data, real_mask, imputed_data) 257 | if hasattr(data, 'mask_info'): 258 | bbox = [data.mask_info[idx] for idx in index] 259 | else: 260 | bbox = None 261 | plot_grid(imputed_data, bbox, gap=2, 262 | save_file=str(impute_dir / filename)) 263 | 264 | data_gen.train() 265 | mask_gen.train() 266 | imputer.train() 267 | 268 | for (name, shortname), trace in log.items(): 269 | fig, ax = plt.subplots(figsize=(6, 4)) 270 | ax.plot(trace) 271 | ax.set_ylabel(name) 272 | ax.set_xlabel('epoch') 273 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300) 274 | plt.close(fig) 275 | 276 | if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0: 277 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates) 278 | 279 | epoch_end = time.time() 280 | time_elapsed = epoch_end - start 281 | epoch_time = epoch_end - epoch_start 282 | epoch_start = epoch_end 283 | with (log_dir / 'epoch-time.txt').open('a') as f: 284 | print(epoch, epoch_time, time_elapsed, file=f) 285 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates) 286 | 287 | print(output_dir) 288 | -------------------------------------------------------------------------------- /src/mnist_critic.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ConvCritic(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | self.DIM = 64 9 | main = nn.Sequential( 10 | nn.Conv2d(1, self.DIM, 5, stride=2, padding=2), 11 | nn.ReLU(True), 12 | nn.Conv2d(self.DIM, 2 * self.DIM, 5, stride=2, padding=2), 13 | nn.ReLU(True), 14 | nn.Conv2d(2 * self.DIM, 4 * self.DIM, 5, stride=2, padding=2), 15 | nn.ReLU(True), 16 | ) 17 | self.main = main 18 | self.output = nn.Linear(4 * 4 * 4 * self.DIM, 1) 19 | 20 | def forward(self, input): 21 | input = input.view(-1, 1, 28, 28) 22 | net = self.main(input) 23 | net = net.view(-1, 4 * 4 * 4 * self.DIM) 24 | net = self.output(net) 25 | return net.view(-1) 26 | 27 | 28 | class 
FCCritic(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | 32 | self.in_dim = 784 33 | self.main = nn.Sequential( 34 | nn.Linear(self.in_dim, 512), 35 | nn.ReLU(True), 36 | nn.Linear(512, 256), 37 | nn.ReLU(True), 38 | nn.Linear(256, 128), 39 | nn.ReLU(True), 40 | nn.Linear(128, 1), 41 | ) 42 | 43 | def forward(self, input): 44 | input = input.view(input.size(0), -1) 45 | out = self.main(input) 46 | return out.view(-1) 47 | -------------------------------------------------------------------------------- /src/mnist_fid.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/mseitzer/pytorch-fid 2 | """ 3 | import torch 4 | import numpy as np 5 | from scipy import linalg 6 | from torch.utils.data import DataLoader 7 | from torchvision import datasets, transforms 8 | import argparse 9 | 10 | import mnist_model 11 | from mnist_generator import ConvDataGenerator, FCDataGenerator 12 | from mnist_imputer import ComplementImputer, MaskImputer, FixedNoiseDimImputer 13 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 14 | from pathlib import Path 15 | 16 | 17 | use_cuda = torch.cuda.is_available() 18 | device = torch.device('cuda' if use_cuda else 'cpu') 19 | 20 | feature_layer = 0 21 | 22 | 23 | def get_activations(image_generator, images, model, verbose=False): 24 | """Calculates the activations of the pool_3 layer for all images. 25 | 26 | Params: 27 | -- image_generator 28 | : A generator that generates a batch of images at a time. 29 | -- images : Number of images that will be generated by 30 | image_generator. 31 | -- model : Instance of inception model 32 | -- verbose : If set to True and parameter out_step is given, the number 33 | of calculated batches is reported. 34 | Returns: 35 | -- A numpy array of dimension (num images, dims) that contains the 36 | activations of the given tensor when feeding inception with the 37 | query tensor. 
38 | """ 39 | model.eval() 40 | 41 | pred_arr = None 42 | end = 0 43 | for i, batch in enumerate(image_generator): 44 | if verbose: 45 | print('\rPropagating batch %d' % (i + 1), end='', flush=True) 46 | start = end 47 | batch_size = batch.shape[0] 48 | end = start + batch_size 49 | batch = batch.to(device) 50 | 51 | with torch.no_grad(): 52 | model(batch) 53 | pred = model.feature[feature_layer] 54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1) 55 | if pred_arr is None: 56 | pred_arr = np.empty((images, batch_feature.shape[1])) 57 | pred_arr[start:end] = batch_feature 58 | 59 | if verbose: 60 | print(' done') 61 | 62 | return pred_arr 63 | 64 | 65 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 66 | """Numpy implementation of the Frechet Distance. 67 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 68 | and X_2 ~ N(mu_2, C_2) is 69 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 70 | 71 | Stable version by Dougal J. Sutherland. 72 | 73 | Params: 74 | -- mu1 : The sample mean over activations of a layer of the 75 | feature network (as returned by 'calculate_activation_statistics') 76 | for generated samples. 77 | -- mu2 : The sample mean over activations, precalculated on a 78 | representative data set. 79 | -- sigma1: The covariance matrix over activations for generated samples. 80 | -- sigma2: The covariance matrix over activations, precalculated on a 81 | representative data set. 82 | 83 | Returns: 84 | -- : The Frechet Distance.
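Illustrative note (not part of the original module): the formula above can be checked by hand on small diagonal covariances. The sketch below uses a hypothetical NumPy-only helper, `frechet_distance_sketch`, which obtains Tr(sqrt(C_1*C_2)) from the eigenvalues of C_1*C_2 rather than from `scipy.linalg.sqrtm` as this function does. It assumes both covariances are symmetric positive semi-definite and skips the singular-product fallback.

```python
import numpy as np


def frechet_distance_sketch(mu1, sigma1, mu2, sigma2):
    # Hypothetical helper for illustration only; the module's own
    # calculate_frechet_distance relies on scipy.linalg.sqrtm instead.
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Tr(sqrt(C_1*C_2)) equals the sum of the square roots of the
    # eigenvalues of C_1*C_2, which are real and non-negative when both
    # covariances are positive semi-definite (assumed here).
    eigvals = np.linalg.eigvals(sigma1.dot(sigma2))
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    return (diff.dot(diff) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * tr_covmean)


# Two 2-D Gaussians with diagonal covariances (made-up numbers).
mu_a, cov_a = np.array([0.0, 0.0]), np.diag([1.0, 4.0])
mu_b, cov_b = np.array([3.0, 4.0]), np.diag([4.0, 1.0])

# ||mu_a - mu_b||^2 = 9 + 16 = 25
# Tr(C_a) + Tr(C_b) = 5 + 5 = 10
# C_a*C_b = diag(4, 4), so Tr(sqrt(C_a*C_b)) = 2 + 2 = 4
# d^2 = 25 + 10 - 2 * 4 = 27
d2 = frechet_distance_sketch(mu_a, cov_a, mu_b, cov_b)

# The distance between a Gaussian and itself is zero.
d2_self = frechet_distance_sketch(mu_a, cov_a, mu_a, cov_a)
```

The eigenvalue shortcut only yields the trace term, which is all the distance needs; it is not a general replacement for a matrix square root.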
85 | """ 86 | 87 | mu1 = np.atleast_1d(mu1) 88 | mu2 = np.atleast_1d(mu2) 89 | 90 | sigma1 = np.atleast_2d(sigma1) 91 | sigma2 = np.atleast_2d(sigma2) 92 | 93 | assert mu1.shape == mu2.shape, \ 94 | 'Training and test mean vectors have different lengths' 95 | assert sigma1.shape == sigma2.shape, \ 96 | 'Training and test covariances have different dimensions' 97 | 98 | diff = mu1 - mu2 99 | 100 | # Product might be almost singular 101 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 102 | if not np.isfinite(covmean).all(): 103 | msg = ('fid calculation produces singular product; ' 104 | 'adding %s to diagonal of cov estimates') % eps 105 | print(msg) 106 | offset = np.eye(sigma1.shape[0]) * eps 107 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 108 | 109 | # Numerical error might give slight imaginary component 110 | if np.iscomplexobj(covmean): 111 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 112 | m = np.max(np.abs(covmean.imag)) 113 | raise ValueError(f'Imaginary component {m}') 114 | covmean = covmean.real 115 | 116 | tr_covmean = np.trace(covmean) 117 | 118 | return (diff.dot(diff) + np.trace(sigma1) + 119 | np.trace(sigma2) - 2 * tr_covmean) 120 | 121 | 122 | def calculate_activation_statistics(image_generator, images, model, 123 | verbose=False, weight=None): 124 | """Calculation of the statistics used by the FID. 125 | Params: 126 | -- image_generator 127 | : A generator that generates a batch of images at a time. 128 | -- images : Number of images that will be generated by 129 | image_generator. 130 | -- model : Instance of inception model 131 | -- verbose : If set to True and parameter out_step is given, the 132 | number of calculated batches is reported. 133 | Returns: 134 | -- mu : The mean over samples of the activations of the pool_3 layer of 135 | the inception model. 136 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 137 | the inception model. 
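Illustrative note (not part of the original module): a minimal sketch of how a mean vector and covariance matrix could be derived from an activation matrix with NumPy, with and without per-sample weights. The array `act` and the weight vectors are made-up values; rows are samples and columns are features, hence `rowvar=False`.

```python
import numpy as np

# Three samples with two features each (made-up activations).
act = np.array([[1.0, 2.0],
                [3.0, 6.0],
                [5.0, 10.0]])

# Unweighted statistics: per-feature mean and (features x features) covariance.
mu = np.mean(act, axis=0)
sigma = np.cov(act, rowvar=False)

# Weighted statistics: np.average takes `weights`, np.cov takes `aweights`.
weight = np.array([1.0, 1.0, 2.0])
mu_w = np.average(act, axis=0, weights=weight)
sigma_w = np.cov(act, rowvar=False, aweights=weight)

# Uniform weights reproduce the unweighted covariance.
sigma_uniform = np.cov(act, rowvar=False, aweights=np.ones(len(act)))
```

Passing `rowvar=False` matters here: by default `np.cov` treats each row as a variable, which would produce a samples-by-samples matrix instead.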
138 | """ 139 | act = get_activations(image_generator, images, model, verbose) 140 | if weight is None: 141 | mu = np.mean(act, axis=0) 142 | sigma = np.cov(act, rowvar=False) 143 | else: 144 | mu = np.average(act, axis=0, weights=weight) 145 | sigma = np.cov(act, rowvar=False, aweights=weight) 146 | return mu, sigma 147 | 148 | 149 | class MNISTModel: 150 | def __init__(self): 151 | model = mnist_model.Net().to(device) 152 | model.eval() 153 | map_location = None if use_cuda else 'cpu' 154 | model.load_state_dict( 155 | torch.load('mnist.pth', map_location=map_location)) 156 | 157 | stats_file = f'mnist_act_{feature_layer}.npz' 158 | try: 159 | f = np.load(stats_file) 160 | m_mnist, s_mnist = f['mu'][:], f['sigma'][:] 161 | f.close() 162 | except FileNotFoundError: 163 | data = datasets.MNIST('data', train=True, download=True, 164 | transform=transforms.ToTensor()) 165 | images = len(data) 166 | batch_size = 64 167 | data_loader = DataLoader([image for image, _ in data], 168 | batch_size=batch_size) 169 | m_mnist, s_mnist = calculate_activation_statistics( 170 | data_loader, images, model, verbose=True) 171 | np.savez(stats_file, mu=m_mnist, sigma=s_mnist) 172 | 173 | self.model = model 174 | self.mnist_stats = m_mnist, s_mnist 175 | 176 | def get_feature(self, samples): 177 | self.model(samples) 178 | feature = self.model.feature[feature_layer] 179 | return feature.cpu().numpy().reshape(samples.shape[0], -1) 180 | 181 | def fid(self, features): 182 | mu = np.mean(features, axis=0) 183 | sigma = np.cov(features, rowvar=False) 184 | return calculate_frechet_distance(mu, sigma, *self.mnist_stats) 185 | 186 | 187 | def data_generator_fid(data_generator, 188 | n_samples=60000, batch_size=64, verbose=False): 189 | mnist_model = MNISTModel() 190 | latent_size = 128 191 | data_noise = torch.FloatTensor(batch_size, latent_size).to(device) 192 | 193 | with torch.no_grad(): 194 | count = 0 195 | features = None 196 | while count < n_samples: 197 | data_noise.normal_() 198 | 
samples = data_generator(data_noise) 199 | batch_feature = mnist_model.get_feature(samples) 200 | 201 | if features is None: 202 | features = np.empty((n_samples, batch_feature.shape[1])) 203 | 204 | if count + batch_size > n_samples: 205 | batch_size = n_samples - count 206 | features[count:] = batch_feature[:batch_size] 207 | else: 208 | features[count:(count + batch_size)] = batch_feature 209 | 210 | count += batch_size 211 | if verbose: 212 | print(f'\rGenerate images {count}', end='', flush=True) 213 | if verbose: 214 | print(' done') 215 | return mnist_model.fid(features) 216 | 217 | 218 | def imputer_fid(imputer, data, batch_size=64, verbose=False): 219 | mnist_model = MNISTModel() 220 | impu_noise = torch.FloatTensor(batch_size, 1, 28, 28).to(device) 221 | data_loader = DataLoader(data, batch_size=batch_size, drop_last=True) 222 | n_samples = len(data_loader) * batch_size 223 | 224 | with torch.no_grad(): 225 | start = 0 226 | features = None 227 | for real_data, real_mask, _, index in data_loader: 228 | real_mask = real_mask.float()[:, None] 229 | real_data = real_data.to(device) 230 | real_mask = real_mask.to(device) 231 | impu_noise.uniform_() 232 | imputed_data = imputer(real_data, real_mask, impu_noise) 233 | 234 | batch_feature = mnist_model.get_feature(imputed_data) 235 | if features is None: 236 | features = np.empty((n_samples, batch_feature.shape[1])) 237 | features[start:(start + batch_size)] = batch_feature 238 | start += batch_size 239 | if verbose: 240 | print(f'\rGenerate images {start}', end='', flush=True) 241 | if verbose: 242 | print(' done') 243 | return mnist_model.fid(features) 244 | 245 | 246 | def pretrained_misgan_fid(model_file, samples=60000, batch_size=64): 247 | model = torch.load(model_file, map_location='cpu') 248 | args = model['args'] 249 | if args.generator == 'conv': 250 | DataGenerator = ConvDataGenerator 251 | elif args.generator == 'fc': 252 | DataGenerator = FCDataGenerator 253 | data_gen = DataGenerator().to(device) 
254 | data_gen.load_state_dict(model['data_gen']) 255 | return data_generator_fid(data_gen, n_samples=samples, batch_size=batch_size, verbose=True) 256 | 257 | 258 | def pretrained_imputer_fid(model_file, save_file, batch_size=64): 259 | model = torch.load(model_file, map_location='cpu') 260 | if 'imputer' not in model: 261 | return 262 | args = model['args'] 263 | 264 | if args.imputer == 'comp': 265 | Imputer = ComplementImputer 266 | elif args.imputer == 'mask': 267 | Imputer = MaskImputer 268 | elif args.imputer == 'fix': 269 | Imputer = FixedNoiseDimImputer 270 | 271 | hid_lens = [int(n) for n in args.arch.split('-')] 272 | imputer = Imputer(arch=hid_lens).to(device) 273 | imputer.load_state_dict(model['imputer']) 274 | 275 | block_len = args.block_len 276 | if block_len == 0: 277 | block_len = None 278 | 279 | if args.mask == 'indep': 280 | data = IndepMaskedMNIST(obs_prob=args.obs_prob, 281 | obs_prob_high=args.obs_prob_high) 282 | elif args.mask == 'block': 283 | data = BlockMaskedMNIST(block_len=block_len) 284 | 285 | fid = imputer_fid(imputer, data, batch_size=batch_size, verbose=True) 286 | with save_file.open('w') as f: 287 | print(fid, file=f) 288 | print('imputer fid:', fid) 289 | 290 | 291 | def main(): 292 | parser = argparse.ArgumentParser() 293 | parser.add_argument('root_dir') 294 | parser.add_argument('--skip-exist', action='store_true') 295 | args = parser.parse_args() 296 | 297 | skip_exist = args.skip_exist 298 | 299 | root_dir = Path(args.root_dir) 300 | fid_file = root_dir / f'fid-{feature_layer}.txt' 301 | if skip_exist and fid_file.exists(): 302 | return 303 | try: 304 | model_file = max((root_dir / 'model').glob('*.pth')) 305 | except ValueError: 306 | return 307 | 308 | fid = pretrained_misgan_fid(model_file) 309 | print(f'{root_dir.name}: {fid}') 310 | with fid_file.open('w') as f: 311 | print(fid, file=f) 312 | 313 | # Compute FID for the imputer if it is in the model 314 | imputer_fid_file = root_dir / f'impute-fid-{feature_layer}.txt' 315 | pretrained_imputer_fid(model_file, imputer_fid_file) 316
| 317 | 318 | if __name__ == '__main__': 319 | main() 320 | -------------------------------------------------------------------------------- /src/mnist_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def add_data_transformer(self): 7 | self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28) 8 | 9 | 10 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)): 11 | """ 12 | hard_sigmoid: 13 | False: use sigmoid only 14 | True: hard thresholding 15 | (a, b): hard thresholding on rescaled sigmoid 16 | """ 17 | self.temperature = temperature 18 | self.hard_sigmoid = hard_sigmoid 19 | 20 | view = -1, 1, 28, 28 21 | if hard_sigmoid is False: 22 | self.transform = lambda x: torch.sigmoid(x / temperature).view(*view) 23 | elif hard_sigmoid is True: 24 | self.transform = lambda x: F.hardtanh( 25 | x / temperature, 0, 1).view(*view) 26 | else: 27 | a, b = hard_sigmoid 28 | self.transform = lambda x: F.hardtanh( 29 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1).view(*view) 30 | 31 | 32 | # Must sub-class ConvGenerator to provide transform() 33 | class ConvGenerator(nn.Module): 34 | def __init__(self, latent_size=128): 35 | super().__init__() 36 | 37 | self.DIM = 64 38 | self.latent_size = latent_size 39 | 40 | self.preprocess = nn.Sequential( 41 | nn.Linear(latent_size, 4 * 4 * 4 * self.DIM), 42 | nn.ReLU(True), 43 | ) 44 | self.block1 = nn.Sequential( 45 | nn.ConvTranspose2d(4 * self.DIM, 2 * self.DIM, 5), 46 | nn.ReLU(True), 47 | ) 48 | self.block2 = nn.Sequential( 49 | nn.ConvTranspose2d(2 * self.DIM, self.DIM, 5), 50 | nn.ReLU(True), 51 | ) 52 | self.deconv_out = nn.ConvTranspose2d(self.DIM, 1, 8, stride=2) 53 | 54 | def forward(self, input): 55 | net = self.preprocess(input) 56 | net = net.view(-1, 4 * self.DIM, 4, 4) 57 | net = self.block1(net) 58 | net = net[:, :, :7, :7] 59 | net = self.block2(net) 60 
| net = self.deconv_out(net) 61 | return self.transform(net) 62 | 63 | 64 | # Must sub-class FCGenerator to provide transform() 65 | class FCGenerator(nn.Module): 66 | def __init__(self, latent_size=128): 67 | super().__init__() 68 | self.latent_size = latent_size 69 | self.fc = nn.Sequential( 70 | nn.Linear(latent_size, 256), 71 | nn.ReLU(True), 72 | nn.Linear(256, 512), 73 | nn.ReLU(True), 74 | nn.Linear(512, 784), 75 | ) 76 | 77 | def forward(self, input): 78 | net = self.fc(input) 79 | return self.transform(net) 80 | 81 | 82 | class ConvDataGenerator(ConvGenerator): 83 | def __init__(self, latent_size=128): 84 | super().__init__(latent_size=latent_size) 85 | add_data_transformer(self) 86 | 87 | 88 | class FCDataGenerator(FCGenerator): 89 | def __init__(self, latent_size=128): 90 | super().__init__(latent_size=latent_size) 91 | add_data_transformer(self) 92 | 93 | 94 | class ConvMaskGenerator(ConvGenerator): 95 | def __init__(self, latent_size=128, temperature=.66, 96 | hard_sigmoid=(-.1, 1.1)): 97 | super().__init__(latent_size=latent_size) 98 | add_mask_transformer(self, temperature, hard_sigmoid) 99 | 100 | 101 | class FCMaskGenerator(FCGenerator): 102 | def __init__(self, latent_size=128, temperature=.66, 103 | hard_sigmoid=(-.1, 1.1)): 104 | super().__init__(latent_size=latent_size) 105 | add_mask_transformer(self, temperature, hard_sigmoid) 106 | -------------------------------------------------------------------------------- /src/mnist_imputer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Must sub-class Imputer to provide fc1 7 | class Imputer(nn.Module): 8 | def __init__(self, arch=(784, 784)): 9 | super().__init__() 10 | # self.fc1 = nn.Linear(784, arch[0]) 11 | self.fc2 = nn.Linear(arch[0], arch[1]) 12 | self.fc3 = nn.Linear(arch[1], arch[0]) 13 | self.fc4 = nn.Linear(arch[0], 784) 14 | self.transform = lambda x: 
torch.sigmoid(x).view(-1, 1, 28, 28) 15 | 16 | def forward(self, input, data, mask): 17 | net = input.view(input.size(0), -1) 18 | net = F.relu(self.fc1(net)) 19 | net = F.relu(self.fc2(net)) 20 | net = F.relu(self.fc3(net)) 21 | net = self.fc4(net) 22 | net = self.transform(net) 23 | # return data * mask + net * (1 - mask) 24 | # NOT replacing observed part with input data for computing 25 | # autoencoding loss. 26 | return net 27 | 28 | 29 | class ComplementImputer(Imputer): 30 | def __init__(self, arch=(784, 784)): 31 | super().__init__(arch=arch) 32 | self.fc1 = nn.Linear(784, arch[0]) 33 | 34 | def forward(self, input, mask, noise): 35 | net = input * mask + noise * (1 - mask) 36 | return super().forward(net, input, mask) 37 | 38 | 39 | class MaskImputer(Imputer): 40 | def __init__(self, arch=(784, 784)): 41 | super().__init__(arch=arch) 42 | self.fc1 = nn.Linear(784 * 2, arch[0]) 43 | 44 | def forward(self, input, mask, noise): 45 | batch_size = input.size(0) 46 | net = torch.cat( 47 | [(input * mask + noise * (1 - mask)).view(batch_size, -1), 48 | mask.view(batch_size, -1)], 1) 49 | return super().forward(net, input, mask) 50 | 51 | 52 | class FixedNoiseDimImputer(Imputer): 53 | def __init__(self, arch=(784, 784)): 54 | super().__init__(arch=arch) 55 | self.fc1 = nn.Linear(784 * 3, arch[0]) 56 | 57 | def forward(self, input, mask, noise): 58 | batch_size = input.size(0) 59 | net = torch.cat([(input * mask).view(batch_size, -1), 60 | mask.view(batch_size, -1), 61 | noise.view(batch_size, -1)], 1) 62 | return super().forward(net, input, mask) 63 | -------------------------------------------------------------------------------- /src/mnist_misgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | from pathlib import Path 4 | import argparse 5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator, 6 | ConvMaskGenerator, FCMaskGenerator) 7 | from mnist_critic 
import ConvCritic, FCCritic 8 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 9 | from misgan import misgan 10 | 11 | 12 | use_cuda = torch.cuda.is_available() 13 | device = torch.device('cuda' if use_cuda else 'cpu') 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | 19 | # resume from checkpoint 20 | parser.add_argument('--resume') 21 | # training options 22 | parser.add_argument('--epoch', type=int, default=500) 23 | parser.add_argument('--batch-size', type=int, default=64) 24 | 25 | # log options: 0 to disable plot-interval or save-interval 26 | parser.add_argument('--plot-interval', type=int, default=50) 27 | parser.add_argument('--save-interval', type=int, default=0) 28 | parser.add_argument('--prefix', default='misgan') 29 | 30 | # mask options (data): block|indep 31 | parser.add_argument('--mask', default='block') 32 | # option for block: set to 0 for variable size 33 | parser.add_argument('--block-len', type=int, default=14) 34 | # option for indep: 35 | parser.add_argument('--obs-prob', type=float, default=.2) 36 | parser.add_argument('--obs-prob-high', type=float, default=None) 37 | 38 | # model options 39 | parser.add_argument('--tau', type=float, default=0) 40 | parser.add_argument('--generator', default='conv') # conv|fc 41 | parser.add_argument('--critic', default='conv') # conv|fc 42 | # parser.add_argument('--alpha', type=float, default=.1) # 0: separate 43 | parser.add_argument('--alpha', type=float, default=.2) # 0: separate 44 | # options for mask generator: sigmoid, hardsigmoid, fusion 45 | # parser.add_argument('--maskgen', default='fusion') 46 | parser.add_argument('--maskgen', default='sigmoid') 47 | parser.add_argument('--gp-lambda', type=float, default=10) 48 | parser.add_argument('--n-critic', type=int, default=5) 49 | parser.add_argument('--n-latent', type=int, default=128) 50 | 51 | args = parser.parse_args() 52 | 53 | checkpoint = None 54 | # Resume from previously stored checkpoint 55 | if 
args.resume: 56 | print(f'Resume: {args.resume}') 57 | output_dir = Path(args.resume) 58 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 59 | map_location='cpu') 60 | for key, arg in vars(checkpoint['args']).items(): 61 | if key not in ['resume']: 62 | setattr(args, key, arg) 63 | 64 | if args.generator == 'conv': 65 | DataGenerator = ConvDataGenerator 66 | MaskGenerator = ConvMaskGenerator 67 | elif args.generator == 'fc': 68 | DataGenerator = FCDataGenerator 69 | MaskGenerator = FCMaskGenerator 70 | else: 71 | raise NotImplementedError 72 | 73 | if args.critic == 'conv': 74 | Critic = ConvCritic 75 | elif args.critic == 'fc': 76 | Critic = FCCritic 77 | else: 78 | raise NotImplementedError 79 | 80 | if args.maskgen == 'sigmoid': 81 | hard_sigmoid = False 82 | elif args.maskgen == 'hardsigmoid': 83 | hard_sigmoid = True 84 | elif args.maskgen == 'fusion': 85 | hard_sigmoid = -.1, 1.1 86 | else: 87 | raise NotImplementedError 88 | 89 | mask = args.mask 90 | obs_prob = args.obs_prob 91 | obs_prob_high = args.obs_prob_high 92 | block_len = args.block_len 93 | if block_len == 0: 94 | block_len = None 95 | 96 | if mask == 'indep': 97 | if obs_prob_high is None: 98 | mask_str = f'indep_{obs_prob:g}' 99 | else: 100 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 101 | elif mask == 'block': 102 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 103 | else: 104 | raise NotImplementedError 105 | 106 | path = '{}_{}_{}'.format( 107 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 108 | '_'.join([ 109 | f'gen_{args.generator}', 110 | f'critic_{args.critic}', 111 | f'tau_{args.tau:g}', 112 | f'alpha_{args.alpha:g}', 113 | f'maskgen_{args.maskgen}', 114 | mask_str, 115 | ])) 116 | 117 | if not args.resume: 118 | output_dir = Path('results') / 'mnist' / path 119 | print(output_dir) 120 | 121 | if mask == 'indep': 122 | data = IndepMaskedMNIST(obs_prob=obs_prob, obs_prob_high=obs_prob_high) 123 | elif mask == 'block': 124 | data 
= BlockMaskedMNIST(block_len=block_len) 125 | 126 | data_gen = DataGenerator().to(device) 127 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device) 128 | 129 | data_critic = Critic().to(device) 130 | mask_critic = Critic().to(device) 131 | 132 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data, 133 | output_dir, checkpoint) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /src/mnist_misgan_impute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | from pathlib import Path 4 | import argparse 5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator, 6 | ConvMaskGenerator, FCMaskGenerator) 7 | from mnist_imputer import (ComplementImputer, 8 | MaskImputer, 9 | FixedNoiseDimImputer) 10 | from mnist_critic import ConvCritic, FCCritic 11 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 12 | from misgan_impute import misgan_impute 13 | 14 | 15 | use_cuda = torch.cuda.is_available() 16 | device = torch.device('cuda' if use_cuda else 'cpu') 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | 22 | # resume from checkpoint 23 | parser.add_argument('--resume') 24 | 25 | # training options 26 | parser.add_argument('--workers', type=int, default=0) 27 | parser.add_argument('--epoch', type=int, default=1000) 28 | parser.add_argument('--batch-size', type=int, default=64) 29 | parser.add_argument('--pretrain', default=None) 30 | parser.add_argument('--imputeronly', action='store_true') 31 | 32 | # log options: 0 to disable plot-interval or save-interval 33 | parser.add_argument('--plot-interval', type=int, default=100) 34 | parser.add_argument('--save-interval', type=int, default=0) 35 | parser.add_argument('--prefix', default='impute') 36 | 37 | # mask options (data): block|indep 38 | parser.add_argument('--mask', default='block') 
39 | # option for block: set to 0 for variable size 40 | parser.add_argument('--block-len', type=int, default=14) 41 | # option for indep: 42 | parser.add_argument('--obs-prob', type=float, default=.2) 43 | parser.add_argument('--obs-prob-high', type=float, default=None) 44 | 45 | # model options 46 | parser.add_argument('--tau', type=float, default=0) 47 | parser.add_argument('--generator', default='conv') # conv|fc 48 | parser.add_argument('--critic', default='conv') # conv|fc 49 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate 50 | parser.add_argument('--beta', type=float, default=.1) 51 | parser.add_argument('--gamma', type=float, default=0) 52 | parser.add_argument('--arch', default='784-784') 53 | parser.add_argument('--imputer', default='comp') # comp|mask|fix 54 | # options for mask generator: sigmoid, hardsigmoid, fusion 55 | parser.add_argument('--maskgen', default='fusion') 56 | parser.add_argument('--gp-lambda', type=float, default=10) 57 | parser.add_argument('--n-critic', type=int, default=5) 58 | parser.add_argument('--n-latent', type=int, default=128) 59 | 60 | args = parser.parse_args() 61 | 62 | checkpoint = None 63 | # Resume from previously stored checkpoint 64 | if args.resume: 65 | print(f'Resume: {args.resume}') 66 | output_dir = Path(args.resume) 67 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'), 68 | map_location='cpu') 69 | for key, arg in vars(checkpoint['args']).items(): 70 | if key not in ['resume']: 71 | setattr(args, key, arg) 72 | 73 | if args.imputeronly: 74 | assert args.pretrain is not None 75 | 76 | arch = args.arch 77 | imputer_type = args.imputer 78 | mask = args.mask 79 | obs_prob = args.obs_prob 80 | obs_prob_high = args.obs_prob_high 81 | block_len = args.block_len 82 | if block_len == 0: 83 | block_len = None 84 | 85 | if args.generator == 'conv': 86 | DataGenerator = ConvDataGenerator 87 | MaskGenerator = ConvMaskGenerator 88 | elif args.generator == 'fc': 89 | DataGenerator = 
FCDataGenerator 90 | MaskGenerator = FCMaskGenerator 91 | else: 92 | raise NotImplementedError 93 | 94 | if imputer_type == 'comp': 95 | Imputer = ComplementImputer 96 | elif imputer_type == 'mask': 97 | Imputer = MaskImputer 98 | elif imputer_type == 'fix': 99 | Imputer = FixedNoiseDimImputer 100 | else: 101 | raise NotImplementedError 102 | 103 | if args.critic == 'conv': 104 | Critic = ConvCritic 105 | elif args.critic == 'fc': 106 | Critic = FCCritic 107 | else: 108 | raise NotImplementedError 109 | 110 | if args.maskgen == 'sigmoid': 111 | hard_sigmoid = False 112 | elif args.maskgen == 'hardsigmoid': 113 | hard_sigmoid = True 114 | elif args.maskgen == 'fusion': 115 | hard_sigmoid = -.1, 1.1 116 | else: 117 | raise NotImplementedError 118 | 119 | if mask == 'indep': 120 | if obs_prob_high is None: 121 | mask_str = f'indep_{obs_prob:g}' 122 | else: 123 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}' 124 | elif mask == 'block': 125 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize') 126 | else: 127 | raise NotImplementedError 128 | 129 | path = '{}_{}_{}'.format( 130 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 131 | '_'.join([ 132 | f'gen_{args.generator}', 133 | f'critic_{args.critic}', 134 | f'imp_{args.imputer}', 135 | f'tau_{args.tau:g}', 136 | f'arch_{args.arch}', 137 | f'maskgen_{args.maskgen}', 138 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}', 139 | mask_str 140 | ])) 141 | 142 | if not args.resume: 143 | output_dir = Path('results') / 'mnist' / path 144 | print(output_dir) 145 | 146 | if mask == 'indep': 147 | data = IndepMaskedMNIST( 148 | obs_prob=obs_prob, obs_prob_high=obs_prob_high) 149 | elif mask == 'block': 150 | data = BlockMaskedMNIST(block_len=block_len) 151 | 152 | data_gen = DataGenerator().to(device) 153 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device) 154 | 155 | hid_lens = [int(n) for n in arch.split('-')] 156 | imputer = Imputer(arch=hid_lens).to(device) 157 | 158 | data_critic 
= Critic().to(device) 159 | mask_critic = Critic().to(device) 160 | impu_critic = Critic().to(device) 161 | 162 | misgan_impute(args, data_gen, mask_gen, imputer, 163 | data_critic, mask_critic, impu_critic, 164 | data, output_dir, checkpoint) 165 | 166 | 167 | if __name__ == '__main__': 168 | main() 169 | -------------------------------------------------------------------------------- /src/mnist_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py 3 | """ 4 | import argparse 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torchvision import datasets, transforms 10 | 11 | 12 | class Net(nn.Module): 13 | def __init__(self): 14 | super(Net, self).__init__() 15 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5) 16 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5) 17 | self.conv2_drop = nn.Dropout2d() 18 | self.fc1 = nn.Linear(320, 50) 19 | self.fc2 = nn.Linear(50, 10) 20 | 21 | def forward(self, x): 22 | feature = [] 23 | x = F.relu(F.max_pool2d(self.conv1(x), 2)) 24 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) 25 | x = x.view(-1, 320) 26 | x = self.fc1(x) 27 | feature.append(x) 28 | x = F.relu(x) 29 | x = F.dropout(x, training=self.training) 30 | x = self.fc2(x) 31 | feature.append(x) 32 | self.feature = feature 33 | return F.log_softmax(x, dim=1) 34 | 35 | 36 | def main(): 37 | # Training settings 38 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example') 39 | parser.add_argument('--batch-size', type=int, default=64, metavar='N', 40 | help='input batch size for training (default: 64)') 41 | parser.add_argument('--test-batch-size', type=int, 42 | default=1000, metavar='N', 43 | help='input batch size for testing (default: 1000)') 44 | parser.add_argument('--epochs', type=int, default=100, metavar='N', 45 | help='number of epochs to train 
(default: 100)') 46 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR', 47 | help='learning rate (default: 0.01)') 48 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 49 | help='SGD momentum (default: 0.5)') 50 | parser.add_argument('--no-cuda', action='store_true', default=False, 51 | help='disables CUDA training') 52 | parser.add_argument('--seed', type=int, default=1, metavar='S', 53 | help='random seed (default: 1)') 54 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 55 | help='number of batches to wait before logging ' 56 | 'training status') 57 | args = parser.parse_args() 58 | args.cuda = not args.no_cuda and torch.cuda.is_available() 59 | 60 | torch.manual_seed(args.seed) 61 | if args.cuda: 62 | torch.cuda.manual_seed(args.seed) 63 | 64 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 65 | train_loader = torch.utils.data.DataLoader( 66 | datasets.MNIST('../data', train=True, download=True, 67 | transform=transforms.Compose([ 68 | transforms.ToTensor(), 69 | transforms.Normalize((0.1307,), (0.3081,)) 70 | ])), 71 | batch_size=args.batch_size, shuffle=True, **kwargs) 72 | test_loader = torch.utils.data.DataLoader( 73 | datasets.MNIST('../data', train=False, transform=transforms.Compose([ 74 | transforms.ToTensor(), 75 | transforms.Normalize((0.1307,), (0.3081,)) 76 | ])), 77 | batch_size=args.test_batch_size, shuffle=True, **kwargs) 78 | 79 | model = Net() 80 | if args.cuda: 81 | model.cuda() 82 | 83 | optimizer = optim.SGD(model.parameters(), lr=args.lr, 84 | momentum=args.momentum) 85 | 86 | def train(epoch): 87 | model.train() 88 | for batch_idx, (data, target) in enumerate(train_loader): 89 | if args.cuda: 90 | data, target = data.cuda(), target.cuda() 91 | optimizer.zero_grad() 92 | output = model(data) 93 | loss = F.nll_loss(output, target) 94 | loss.backward() 95 | optimizer.step() 96 | if batch_idx % args.log_interval == 0: 97 | print('Train Epoch: {} [{}/{} 
({:.0f}%)]\tLoss: {:.6f}'.format( 98 | epoch, batch_idx * len(data), len(train_loader.dataset), 99 | 100. * batch_idx / len(train_loader), loss.item())) 100 | 101 | def test(): 102 | model.eval() 103 | test_loss = 0 104 | correct = 0 105 | with torch.no_grad(): 106 | for data, target in test_loader: 107 | if args.cuda: 108 | data, target = data.cuda(), target.cuda() 109 | output = model(data) 110 | # sum up batch loss 111 | test_loss += F.nll_loss(output, target, reduction='sum').item() 112 | # get the index of the max log-probability 113 | pred = output.argmax(dim=1, keepdim=True) 114 | correct += (pred == target.view_as(pred)).long().cpu().sum() 115 | 116 | test_loss /= len(test_loader.dataset) 117 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n' 118 | .format(test_loss, correct, len(test_loader.dataset), 119 | 100. * correct / len(test_loader.dataset))) 120 | 121 | for epoch in range(1, args.epochs + 1): 122 | train(epoch) 123 | test() 124 | 125 | torch.save(model.state_dict(), 'mnist.pth') 126 | 127 | 128 | if __name__ == '__main__': 129 | main() 130 | -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pylab as plt 3 | from matplotlib.patches import Rectangle 4 | from PIL import Image 5 | 6 | 7 | def plot_grid(image, bbox=None, gap=0, gap_value=1, 8 | nrow=4, ncol=8, save_file=None): 9 | image = image.cpu().numpy() 10 | channels, len0, len1 = image[0].shape 11 | grid = np.empty( 12 | (nrow * (len0 + gap) - gap, ncol * (len1 + gap) - gap, channels)) 13 | # Convert from N, C, H, W to N, H, W, C 14 | image = image.transpose((0, 2, 3, 1)) 15 | grid.fill(gap_value) 16 | 17 | for i, x in enumerate(image): 18 | if i >= nrow * ncol: 19 | break 20 | p0 = (i // ncol) * (len0 + gap) 21 | p1 = (i % ncol) * (len1 + gap) 22 | grid[p0:(p0 + len0), p1:(p1 + len1)] = x 23 | 24 | # figsize = np.r_[ncol, nrow] * .75 25
| scale = 2.5 26 | figsize = ncol * scale, nrow * scale # FIXME: scale by len0, len1 27 | fig = plt.figure(figsize=figsize) 28 | ax = plt.Axes(fig, [0, 0, 1, 1]) 29 | ax.set_axis_off() 30 | fig.add_axes(ax) 31 | grid = grid.squeeze() 32 | ax.imshow(grid, cmap='binary_r', interpolation='none', aspect='equal') 33 | 34 | if bbox is not None: 35 | nplot = min(len(image), nrow * ncol) 36 | for i in range(nplot): 37 | if len(bbox) == 1: 38 | d0, d1, d0_len, d1_len = bbox[0] 39 | else: 40 | d0, d1, d0_len, d1_len = bbox[i] 41 | p0 = (i // ncol) * (len0 + gap) 42 | p1 = (i % ncol) * (len1 + gap) 43 | offset = np.array([p1 + d1, p0 + d0]) - .5 44 | ax.add_patch(Rectangle( 45 | offset, d1_len, d0_len, lw=4, edgecolor='red', fill=False)) 46 | 47 | if save_file: 48 | fig.savefig(save_file) 49 | plt.close(fig) 50 | 51 | 52 | def plot_samples(samples, save_file, nrow=4, ncol=8): 53 | x = samples.cpu().numpy() 54 | channels, len0, len1 = x[0].shape 55 | x_merge = np.zeros((nrow * len0, ncol * len1, channels)) 56 | 57 | for i, x_ in enumerate(x): 58 | if i >= nrow * ncol: 59 | break 60 | p0 = (i // ncol) * len0 61 | p1 = (i % ncol) * len1 62 | x_merge[p0:(p0 + len0), p1:(p1 + len1)] = x_.transpose((1, 2, 0)) 63 | 64 | x_merge = (x_merge * 255).clip(0, 255).astype(np.uint8) 65 | # squeeze() to remove the last dimension for the single-channel image. 66 | im = Image.fromarray(x_merge.squeeze()) 67 | im.save(save_file) 68 | -------------------------------------------------------------------------------- /src/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 6 | # 7 | # Defines the submodule with skip connection. 
8 | # X -------------------identity---------------------- X 9 | # |-- downsampling -- |submodule| -- upsampling --| 10 | class UnetSkipConnectionBlock(nn.Module): 11 | def __init__(self, outer_nc, inner_nc, input_nc=None, 12 | submodule=None, outermost=False, innermost=False, 13 | norm_layer=nn.BatchNorm2d): 14 | super().__init__() 15 | self.outermost = outermost 16 | use_bias = norm_layer == nn.InstanceNorm2d 17 | if input_nc is None: 18 | input_nc = outer_nc 19 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 20 | stride=2, padding=1, bias=use_bias) 21 | downrelu = nn.LeakyReLU(0.2, True) 22 | if norm_layer is not None: 23 | downnorm = norm_layer(inner_nc) 24 | upnorm = norm_layer(outer_nc) 25 | uprelu = nn.ReLU(True) 26 | 27 | if outermost: 28 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 29 | kernel_size=4, stride=2, 30 | padding=1) 31 | down = [downconv] 32 | up = [uprelu, upconv] 33 | model = down + [submodule] + up 34 | elif innermost: 35 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 36 | kernel_size=4, stride=2, 37 | padding=1, bias=use_bias) 38 | down = [downrelu, downconv] 39 | up = [uprelu, upconv] 40 | if norm_layer is not None: 41 | up.append(upnorm) 42 | model = down + up 43 | else: 44 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 45 | kernel_size=4, stride=2, 46 | padding=1, bias=use_bias) 47 | down = [downrelu, downconv] 48 | up = [uprelu, upconv] 49 | if norm_layer is not None: 50 | down.append(downnorm) 51 | up.append(upnorm) 52 | 53 | model = down + [submodule] + up 54 | 55 | self.model = nn.Sequential(*model) 56 | 57 | def forward(self, x): 58 | if self.outermost: 59 | return self.model(x) 60 | else: 61 | return torch.cat([x, self.model(x)], 1) 62 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import grad 2 | 3 | 4 | class CriticUpdater: 5 | def __init__(self, 
critic, critic_optimizer, eps, ones, gp_lambda=10): 6 | self.critic = critic 7 | self.critic_optimizer = critic_optimizer 8 | self.eps = eps 9 | self.ones = ones 10 | self.gp_lambda = gp_lambda 11 | 12 | def __call__(self, real, fake): 13 | real = real.detach() 14 | fake = fake.detach() 15 | self.critic.zero_grad() 16 | self.eps.uniform_(0, 1) 17 | interp = (self.eps * real + (1 - self.eps) * fake).requires_grad_() 18 | grad_d = grad(self.critic(interp), interp, grad_outputs=self.ones, 19 | create_graph=True)[0] 20 | grad_d = grad_d.view(real.shape[0], -1) 21 | grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda 22 | w_dist = self.critic(fake).mean() - self.critic(real).mean() 23 | loss = w_dist + grad_penalty 24 | loss.backward() 25 | self.critic_optimizer.step() 26 | self.loss_value = loss.item() 27 | 28 | 29 | def mask_norm(diff, mask): 30 | """Mean of diff over observed entries per sample, averaged over the batch""" 31 | dim = 1, 2, 3 32 | # Assume mask.sum(dim) is non-zero for every sample 33 | return ((diff * mask).sum(dim) / mask.sum(dim)).mean() 34 | 35 | 36 | def mkdir(path): 37 | path.mkdir(parents=True, exist_ok=True) 38 | return path 39 | 40 | 41 | def mask_data(data, mask, tau): 42 | return mask * data + (1 - mask) * tau 43 | --------------------------------------------------------------------------------
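The two masking helpers in `src/utils.py` are easier to follow on small concrete values. Below is a minimal sketch that restates `mask_data` and `mask_norm` on plain Python lists instead of tensors. It is an illustration under stated assumptions, not the repository's code: the list-based layout, the sample values, and the `ValueError` guard for an all-zero mask are all additions made here (the original simply assumes the mask sum is non-zero).

```python
def mask_data(data, mask, tau=0.0):
    """Keep observed entries (mask == 1) and fill missing ones with tau."""
    return [m * x + (1 - m) * tau for x, m in zip(data, mask)]


def mask_norm(diff, mask):
    """Average diff over observed entries per sample, then over the batch.

    diff and mask are lists of equal-length flat lists, one per sample.
    The guard below is an assumption of this sketch; the tensor version
    in utils.py does not check for empty masks.
    """
    per_sample = []
    for d_row, m_row in zip(diff, mask):
        observed = sum(m_row)
        if observed == 0:
            raise ValueError('mask has no observed entries')
        per_sample.append(sum(d * m for d, m in zip(d_row, m_row)) / observed)
    return sum(per_sample) / len(per_sample)


# Hypothetical 4-pixel "image" with the 2nd and 4th pixels missing.
pixels = [0.2, 0.9, 0.5, 0.7]
mask = [1, 0, 1, 0]
masked = mask_data(pixels, mask, tau=0.0)

# Two samples; only the entries where the mask is 1 contribute.
diff = [[1.0, 5.0, 3.0, 5.0], [2.0, 2.0, 9.0, 9.0]]
masks = [[1, 0, 1, 0], [1, 1, 0, 0]]
score = mask_norm(diff, masks)
```

With `tau=0` the missing pixels become zeros, which matches the default of the `--tau` option in the training scripts. Normalizing by the number of observed entries keeps samples with few observed pixels from being down-weighted relative to densely observed ones.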