├── .gitattributes
├── LICENSE
├── README.md
├── img
    ├── misgan-impute.png
    └── misgan.png
├── misgan.ipynb
├── requirements.txt
├── src-torch1.6
    ├── celeba_critic.py
    ├── celeba_fid.py
    ├── celeba_generator.py
    ├── celeba_misgan.py
    ├── celeba_misgan_impute.py
    ├── fcnet.py
    ├── fid.py
    ├── imputer.py
    ├── inception.py
    ├── masked_celeba.py
    ├── masked_mnist.py
    ├── misgan.ipynb
    ├── misgan.py
    ├── misgan_impute.py
    ├── mnist_critic.py
    ├── mnist_fid.py
    ├── mnist_generator.py
    ├── mnist_imputer.py
    ├── mnist_misgan.py
    ├── mnist_misgan_impute.py
    ├── mnist_model.py
    ├── plot.py
    ├── requirements.txt
    ├── unet.py
    └── utils.py
└── src
    ├── celeba_critic.py
    ├── celeba_fid.py
    ├── celeba_generator.py
    ├── celeba_misgan.py
    ├── celeba_misgan_impute.py
    ├── fcnet.py
    ├── fid.py
    ├── imputer.py
    ├── inception.py
    ├── masked_celeba.py
    ├── masked_mnist.py
    ├── misgan.py
    ├── misgan_impute.py
    ├── mnist_critic.py
    ├── mnist_fid.py
    ├── mnist_generator.py
    ├── mnist_imputer.py
    ├── mnist_misgan.py
    ├── mnist_misgan_impute.py
    ├── mnist_model.py
    ├── plot.py
    ├── unet.py
    └── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
misgan.ipynb linguist-language=Python
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Steven Cheng-Xian Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# MisGAN: Learning from Incomplete Data with GANs

This repository provides a PyTorch implementation of
[MisGAN](https://arxiv.org/abs/1902.09599),
a GAN-based framework for learning from incomplete data.

**Note:** Please check out our
[follow-up work](https://github.com/steveli/partial-encoder-decoder)
on models that can be trained faster and more stably.


## Requirements

The code requires Python 3.6 or later.
The file [requirements.txt](requirements.txt) contains the full list of
required Python modules.


## Jupyter notebook

We provide a [notebook](misgan.ipynb) that includes an overview of MisGAN
as well as an annotated implementation that runs on MNIST.
The notebook can be viewed
[here](https://nbviewer.jupyter.org/github/steveli/misgan/blob/master/misgan.ipynb).


## Usage

The source code can be found in the `src` directory.
Separate scripts are provided to run MisGAN on the MNIST and CelebA datasets.

For CelebA, you will need to download the dataset from its
[website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html):

* Download the file `img_align_celeba.zip` (available from [this link](https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM)).
* Create the directory `src/celeba-data` and extract the zip file into it.
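
The extraction step can also be scripted. The following is a minimal sketch using only the Python standard library. The helper name `extract_celeba` and its default paths are illustrative assumptions and are not part of this repository; it assumes `img_align_celeba.zip` has already been downloaded.

```python
import zipfile
from pathlib import Path


def extract_celeba(zip_path='img_align_celeba.zip',
                   target_dir='src/celeba-data'):
    """Extract the CelebA archive into target_dir (created if missing).

    Hypothetical helper for illustration only; not part of the repository.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(target)
    return target
```

Calling `extract_celeba()` from the repository root would then populate `src/celeba-data`, which matches the default of the `--data-dir` option when the scripts are run from `src`.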

Run the commands below from the `src` directory.

MisGAN on MNIST:
```bash
python mnist_misgan.py
```

MisGAN imputation on MNIST:
```bash
python mnist_misgan_impute.py
```

MisGAN on CelebA:
```bash
python celeba_misgan.py
```

MisGAN imputation on CelebA:
```bash
python celeba_misgan_impute.py
```

Use `-h` to see all available command-line arguments for each script.
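
Among those arguments, the CelebA scripts accept `--mask block|indep` together with `--block-len` and `--obs-prob`. The sketch below illustrates what these two missingness patterns could look like. It is not the repository's `masked_celeba.py`: it assumes that `block` keeps a single square of observed pixels at a random location, that `indep` observes each pixel independently, and that a mask value of 1 means "observed". The function names are hypothetical, and the defaults mirror the script defaults (`--block-len 32`, `--obs-prob .2`) on 64x64 images.

```python
import random


def block_mask(size=64, block_len=32, rng=random):
    """Return a size x size 0/1 mask whose observed entries (1) form one
    block_len x block_len square placed uniformly at random."""
    top = rng.randrange(size - block_len + 1)
    left = rng.randrange(size - block_len + 1)
    mask = [[0] * size for _ in range(size)]
    for row in range(top, top + block_len):
        for col in range(left, left + block_len):
            mask[row][col] = 1
    return mask


def indep_mask(size=64, obs_prob=0.2, rng=random):
    """Return a size x size 0/1 mask where each pixel is observed (1)
    independently with probability obs_prob."""
    return [[1 if rng.random() < obs_prob else 0 for _ in range(size)]
            for _ in range(size)]
```

The variable-size block mode (`--block-len 0`) and the `--obs-prob-high` option are not modeled in this sketch.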


## References

Steven Cheng-Xian Li, Bo Jiang, Benjamin Marlin.
"MisGAN: Learning from Incomplete Data with Generative Adversarial Networks."
ICLR 2019.
\[[arXiv](https://arxiv.org/abs/1902.09599)\]


## Contact

Your feedback would be greatly appreciated!
Reach us at
.
--------------------------------------------------------------------------------
/img/misgan-impute.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/misgan/f30dd73ebe602c81b1a0cfb72708c41687fb13d1/img/misgan-impute.png
--------------------------------------------------------------------------------
/img/misgan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/misgan/f30dd73ebe602c81b1a0cfb72708c41687fb13d1/img/misgan.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
matplotlib>=2.2.2
numpy>=1.14.5
Pillow>=5.1.0
scipy>=1.1.0
seaborn>=0.8.1
torch==1.1.0
torchvision==0.3.0
--------------------------------------------------------------------------------
/src-torch1.6/celeba_critic.py:
--------------------------------------------------------------------------------
import torch.nn as nn


def conv_ln_lrelu(in_dim, out_dim):
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, 5, 2, 2),
        nn.InstanceNorm2d(out_dim, affine=True),
        nn.LeakyReLU(0.2))


class ConvCritic(nn.Module):
    def __init__(self, n_channels):
        super().__init__()
        dim = 64
        self.ls = nn.Sequential(
            nn.Conv2d(n_channels, dim, 5, 2, 2), nn.LeakyReLU(0.2),
            conv_ln_lrelu(dim, dim * 2),
            conv_ln_lrelu(dim * 2, dim * 4),
            conv_ln_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, 1, 4))

    def forward(self, input):
        net = self.ls(input)
        return net.view(-1)
--------------------------------------------------------------------------------
/src-torch1.6/celeba_fid.py:
--------------------------------------------------------------------------------
import argparse
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image
from celeba_generator import ConvDataGenerator
from fid import BaseSampler, BaseImputationSampler
from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
from imputer import UNetImputer
from fid import FID


parser = argparse.ArgumentParser()
parser.add_argument('root_dir')
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--workers', type=int, default=0)
parser.add_argument('--skip-exist', action='store_true')
args = parser.parse_args()


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


class CelebAFID(FID):
    def __init__(self, batch_size=256, data_name='celeba',
                 workers=0, verbose=True):
        self.batch_size = batch_size
        self.workers = workers
        super().__init__(data_name, verbose)

    def complete_data(self):
        data = datasets.ImageFolder(
            'celeba',
            transforms.Compose([
                transforms.CenterCrop(108),
                transforms.Resize(size=64, interpolation=Image.BICUBIC),
                transforms.ToTensor(),
                # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
            ]))

        images = len(data)
        data_loader = DataLoader(
            data, batch_size=self.batch_size, num_workers=self.workers)

        return data_loader, images


class MisGANSampler(BaseSampler):
    def __init__(self, data_gen, images=60000, batch_size=256):
        super().__init__(images)
        self.data_gen = data_gen
        self.batch_size = batch_size
        latent_dim = 128
        self.data_noise = torch.FloatTensor(batch_size, latent_dim).to(device)

    def sample(self):
        self.data_noise.normal_()
        return self.data_gen(self.data_noise)


class MisGANImputationSampler(BaseImputationSampler):
    def __init__(self, data_loader, imputer, batch_size=256):
        super().__init__(data_loader)
        self.imputer = imputer
        self.impu_noise = torch.FloatTensor(batch_size, 3, 64, 64).to(device)

    def impute(self, data, mask):
        if data.shape[0] != self.impu_noise.shape[0]:
            self.impu_noise.resize_(data.shape)
        self.impu_noise.uniform_()
        return self.imputer(data, mask, self.impu_noise)


def get_data_loader(args, batch_size):
    if args.mask == 'indep':
        data = IndepMaskedCelebA(
            data_dir=args.data_dir,
            obs_prob=args.obs_prob, obs_prob_high=args.obs_prob_high)
    elif args.mask == 'block':
        data = BlockMaskedCelebA(
            data_dir=args.data_dir, block_len=args.block_len)

    data_size = len(data)
    data_loader = DataLoader(
        data, batch_size=batch_size, num_workers=args.workers)
    return data_loader, data_size


def parallelize(model):
    return nn.DataParallel(model).to(device)


def pretrained_misgan_fid(model_file, samples=202599):
    model = torch.load(model_file, map_location='cpu')
    data_gen = parallelize(ConvDataGenerator())
    data_gen.load_state_dict(model['data_gen'])

    batch_size = args.batch_size

    compute_fid = CelebAFID(batch_size=batch_size)
    sampler = MisGANSampler(data_gen, samples, batch_size)
    gen_fid = compute_fid.fid(sampler, samples)
    print(f'fid: {gen_fid:.2f}')

    imp_fid = None
    if 'imputer' in model:
        imputer = UNetImputer().to(device)
        imputer.load_state_dict(model['imputer'])
        data_loader, data_size = get_data_loader(model['args'], batch_size)
        imputation_sampler = MisGANImputationSampler(
            data_loader, imputer, batch_size)
        imp_fid = compute_fid.fid(imputation_sampler, data_size)
        print(f'impute fid: {imp_fid:.2f}')

    return gen_fid, imp_fid


def main():
    root_dir = Path(args.root_dir)
    fid_file = root_dir / 'fid.txt'
    if args.skip_exist and fid_file.exists():
        return
    try:
        model_file = max((root_dir / 'model').glob('*.pth'))
    except ValueError:
        return

    print(root_dir.name)
    fid, imp_fid = pretrained_misgan_fid(model_file)

    with fid_file.open('w') as f:
        print(fid, file=f)

    if imp_fid is not None:
        with (root_dir / 'impute-fid.txt').open('w') as f:
            print(imp_fid, file=f)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src-torch1.6/celeba_generator.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)):
    """
    hard_sigmoid:
        False:  use sigmoid only
        True:   hard thresholding
        (a, b): hard thresholding on rescaled sigmoid
    """
    self.temperature = temperature
    self.hard_sigmoid = hard_sigmoid

    if hard_sigmoid is False:
        self.transform = lambda x: torch.sigmoid(x / temperature)
    elif hard_sigmoid is True:
        self.transform = lambda x: F.hardtanh(
            x / temperature, 0, 1)
    else:
        a, b = hard_sigmoid
        self.transform = lambda x: F.hardtanh(
            torch.sigmoid(x / temperature) * (b - a) + a, 0, 1)


def dconv_bn_relu(in_dim, out_dim):
    return nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                           padding=2, output_padding=1, bias=False),
        nn.BatchNorm2d(out_dim),
        nn.ReLU())


# Must sub-class ConvGenerator to provide transform()
class ConvGenerator(nn.Module):
    def __init__(self, latent_size=128):
        super().__init__()

        dim = 64

        self.l1 = nn.Sequential(
            nn.Linear(latent_size, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU())

        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, self.out_channels, 5, 2,
                               padding=2, output_padding=1))

    def forward(self, input):
        net = self.l1(input)
        net = net.view(net.shape[0], -1, 4, 4)
        net = self.l2_5(net)
        return self.transform(net)


class ConvDataGenerator(ConvGenerator):
    def __init__(self, latent_size=128):
        self.out_channels = 3
        super().__init__(latent_size=latent_size)
        self.transform = lambda x: torch.sigmoid(x)


class ConvMaskGenerator(ConvGenerator):
    def __init__(self, latent_size=128, temperature=.66,
                 hard_sigmoid=(-.1, 1.1)):
        self.out_channels = 1
        super().__init__(latent_size=latent_size)
        add_mask_transformer(self, temperature, hard_sigmoid)
--------------------------------------------------------------------------------
/src-torch1.6/celeba_misgan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
import argparse
from celeba_generator import ConvDataGenerator, ConvMaskGenerator
from celeba_critic import ConvCritic
from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
from misgan import misgan


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


def parallelize(model):
    return nn.DataParallel(model).to(device)


def main():
    parser = argparse.ArgumentParser()

    # resume from checkpoint
    parser.add_argument('--resume')

    # path of CelebA dataset
    parser.add_argument('--data-dir', default='celeba-data')

    # training options
    parser.add_argument('--epoch', type=int, default=600)
    parser.add_argument('--batch-size', type=int, default=256)

    # log options: 0 to disable plot-interval or save-interval
    parser.add_argument('--plot-interval', type=int, default=100)
    parser.add_argument('--save-interval', type=int, default=0)
    parser.add_argument('--prefix', default='misgan')

    # mask options (data): block|indep
    parser.add_argument('--mask', default='block')
    # option for block: set to 0 for variable size
    parser.add_argument('--block-len', type=int, default=32)
    # option for indep:
    parser.add_argument('--obs-prob', type=float, default=.2)
    parser.add_argument('--obs-prob-high', type=float, default=None)

    # model options
    parser.add_argument('--tau', type=float, default=.5)
    parser.add_argument('--alpha', type=float, default=.1)   # 0: separate
    # options for mask generator: sigmoid, hardsigmoid, fusion
    parser.add_argument('--maskgen', default='fusion')
    parser.add_argument('--gp-lambda', type=float, default=10)
    parser.add_argument('--n-critic', type=int, default=5)
    parser.add_argument('--n-latent', type=int, default=128)

    args = parser.parse_args()

    checkpoint = None
    # Resume from previously stored checkpoint
    if args.resume:
        print(f'Resume: {args.resume}')
        output_dir = Path(args.resume)
        checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
                                map_location='cpu')
        for key, arg in vars(checkpoint['args']).items():
            if key not in ['resume']:
                setattr(args, key, arg)

    if args.maskgen == 'sigmoid':
        hard_sigmoid = False
    elif args.maskgen == 'hardsigmoid':
        hard_sigmoid = True
    elif args.maskgen == 'fusion':
        hard_sigmoid = -.1, 1.1
    else:
        raise NotImplementedError

    mask = args.mask
    obs_prob = args.obs_prob
    obs_prob_high = args.obs_prob_high
    block_len = args.block_len
    if block_len == 0:
        block_len = None
    if mask == 'indep':
        if obs_prob_high is None:
            mask_str = f'indep_{obs_prob:g}'
        else:
            mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
    elif mask == 'block':
        mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
    else:
        raise NotImplementedError

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
        '_'.join([
            f'tau_{args.tau:g}',
            f'alpha_{args.alpha:g}',
            f'maskgen_{args.maskgen}',
            mask_str,
        ]))

    if not args.resume:
        output_dir = Path('results') / 'celeba' / path
        print(output_dir)

    if mask == 'indep':
        data = IndepMaskedCelebA(
            data_dir=args.data_dir,
            obs_prob=obs_prob, obs_prob_high=obs_prob_high)
    elif mask == 'block':
        data = BlockMaskedCelebA(
            data_dir=args.data_dir, block_len=block_len)

    n_gpu = torch.cuda.device_count()
    print(f'Use {n_gpu} GPUs.')

    data_gen = parallelize(ConvDataGenerator())
    mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid))

    data_critic = parallelize(ConvCritic(n_channels=3))
    mask_critic = parallelize(ConvCritic(n_channels=1))

    misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
           output_dir, checkpoint)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src-torch1.6/celeba_misgan_impute.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
import argparse
from celeba_generator import ConvDataGenerator, ConvMaskGenerator
from celeba_critic import ConvCritic
from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
from imputer import UNetImputer
from misgan_impute import misgan_impute


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


def parallelize(model):
    return nn.DataParallel(model).to(device)


def main():
    parser = argparse.ArgumentParser()

    # resume from checkpoint
    parser.add_argument('--resume')

    # path of CelebA dataset
    parser.add_argument('--data-dir', default='celeba-data')

    # training options
    parser.add_argument('--workers', type=int, default=0)
    parser.add_argument('--epoch', type=int, default=800)
    parser.add_argument('--batch-size', type=int, default=512)
    parser.add_argument('--pretrain', default=None)
    parser.add_argument('--imputeronly', action='store_true')

    # log options: 0 to disable plot-interval or save-interval
    parser.add_argument('--plot-interval', type=int, default=50)
    parser.add_argument('--save-interval', type=int, default=0)
    parser.add_argument('--prefix', default='impute')

    # mask options (data): block|indep
    parser.add_argument('--mask', default='block')
    # option for block: set to 0 for variable size
    parser.add_argument('--block-len', type=int, default=32)
    # option for indep:
    parser.add_argument('--obs-prob', type=float, default=.2)
    parser.add_argument('--obs-prob-high', type=float, default=None)

    # model options
    parser.add_argument('--tau', type=float, default=.5)
    parser.add_argument('--alpha', type=float, default=.1)   # 0: separate
    parser.add_argument('--beta', type=float, default=.1)
    parser.add_argument('--gamma', type=float, default=0)
    # options for mask generator: sigmoid, hardsigmoid, fusion
    parser.add_argument('--maskgen', default='fusion')
    parser.add_argument('--gp-lambda', type=float, default=10)
    parser.add_argument('--n-critic', type=int, default=5)
    parser.add_argument('--n-latent', type=int, default=128)

    args = parser.parse_args()

    checkpoint = None
    # Resume from previously stored checkpoint
    if args.resume:
        print(f'Resume: {args.resume}')
        output_dir = Path(args.resume)
        checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
                                map_location='cpu')
        for key, arg in vars(checkpoint['args']).items():
            if key not in ['resume']:
                setattr(args, key, arg)

    if args.imputeronly:
        assert args.pretrain is not None

    mask = args.mask
    obs_prob = args.obs_prob
    obs_prob_high = args.obs_prob_high
    block_len = args.block_len
    if block_len == 0:
        block_len = None

    if args.maskgen == 'sigmoid':
        hard_sigmoid = False
    elif args.maskgen == 'hardsigmoid':
        hard_sigmoid = True
    elif args.maskgen == 'fusion':
        hard_sigmoid = -.1, 1.1
    else:
        raise NotImplementedError

    if mask == 'indep':
        if obs_prob_high is None:
            mask_str = f'indep_{obs_prob:g}'
        else:
            mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
    elif mask == 'block':
        mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
    else:
        raise NotImplementedError

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
        '_'.join([
            f'tau_{args.tau:g}',
            f'maskgen_{args.maskgen}',
            f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}',
            mask_str,
        ]))

    if not args.resume:
        output_dir = Path('results') / 'celeba' / path
        print(output_dir)

    if mask == 'indep':
        data = IndepMaskedCelebA(
            data_dir=args.data_dir,
            obs_prob=obs_prob, obs_prob_high=obs_prob_high)
    elif mask == 'block':
        data = BlockMaskedCelebA(
            data_dir=args.data_dir, block_len=block_len)

    n_gpu = torch.cuda.device_count()
    print(f'Use {n_gpu} GPUs.')
    data_gen = parallelize(ConvDataGenerator())
    mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid))
    imputer = UNetImputer().to(device)

    data_critic = parallelize(ConvCritic(n_channels=3))
    mask_critic = parallelize(ConvCritic(n_channels=1))
    impu_critic = parallelize(ConvCritic(n_channels=3))

    misgan_impute(args, data_gen, mask_gen, imputer,
                  data_critic, mask_critic, impu_critic,
                  data, output_dir, checkpoint)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src-torch1.6/fcnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class FullyConnectedNet(nn.Module):
    def __init__(self, weights, output_shape=None):
        super().__init__()
        n_layers = len(weights) - 1

        layers = [nn.Linear(weights[0], weights[1])]
        for i in range(1, n_layers):
            layers.extend([nn.ReLU(), nn.Linear(weights[i], weights[i + 1])])

        self.model = nn.Sequential(*layers)
        self.output_shape = output_shape

    def forward(self, input):
        output = self.model(input.view(input.shape[0], -1))
        if self.output_shape is not None:
            output = output.view(self.output_shape)
        return output
--------------------------------------------------------------------------------
/src-torch1.6/fid.py:
--------------------------------------------------------------------------------
"""Code adapted from https://github.com/mseitzer/pytorch-fid
"""
from pathlib import Path
import torch
import numpy as np
from scipy import linalg
import time
import sys
from inception import InceptionV3


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

FEATURE_DIM = 2048
RESIZE = 299


def get_activations(image_iterator, images, model, verbose=True):
    """Calculates the activations of the pool_3 layer for all images.

    Params:
    -- image_iterator
                   : A generator that generates a batch of images at a time.
    -- images      : Number of images that will be generated by
                     image_iterator.
    -- model       : Instance of inception model
    -- verbose     : If set to True and stdout is a terminal, the number of
                     processed images is reported.
    Returns:
    -- A numpy array of dimension (num images, dims) that contains the
       activations of the given tensor when feeding inception with the
       query tensor.
    """
    model.eval()

    if not sys.stdout.isatty():
        verbose = False

    pred_arr = np.empty((images, FEATURE_DIM))
    end = 0
    t0 = time.time()

    for batch in image_iterator:
        if not isinstance(batch, torch.Tensor):
            batch = batch[0]
        start = end
        batch_size = batch.shape[0]
        end = start + batch_size

        with torch.no_grad():
            batch = batch.to(device)
            pred = model(batch)[0]
            batch_feature = pred.cpu().numpy().reshape(batch_size, -1)
            pred_arr[start:end] = batch_feature

        if verbose:
            print('\rProcessed: {}   time: {:.2f}'.format(
                end, time.time() - t0), end='', flush=True)

    assert end == images

    if verbose:
        print(' done')

    return pred_arr


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.
    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : The sample mean over activations of a layer of the
               inception net (computed from the output of
               'get_activations') for generated samples.
    -- mu2   : The sample mean over activations, precalculated on a
               representative data set.
    -- sigma1: The covariance matrix over activations for generated samples.
    -- sigma2: The covariance matrix over activations, precalculated on a
               representative data set.

    Returns:
    --   : The Frechet Distance.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) +
            np.trace(sigma2) - 2 * tr_covmean)


def calculate_activation_statistics(image_iterator, images, model,
                                    verbose=False):
    """Calculation of the statistics used by the FID.
    Params:
    -- image_iterator
                   : A generator that generates a batch of images at a time.
    -- images      : Number of images that will be generated by
                     image_iterator.
    -- model       : Instance of inception model
    -- verbose     : If set to True and stdout is a terminal, the number of
                     processed images is reported.
    Returns:
    -- mu    : The mean over samples of the activations of the pool_3 layer of
               the inception model.
    -- sigma : The covariance matrix of the activations of the pool_3 layer of
               the inception model.
    """
    act = get_activations(image_iterator, images, model, verbose)
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    return mu, sigma


class FID:
    def __init__(self, data_name, verbose=True):
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM]
        model = InceptionV3([block_idx], RESIZE).to(device)
        self.verbose = verbose

        stats_dir = Path('fid_stats')
        stats_file = stats_dir / '{}_act_{}_{}.npz'.format(
            data_name, FEATURE_DIM, RESIZE)

        try:
            f = np.load(str(stats_file))
            mu, sigma = f['mu'], f['sigma']
            f.close()
        except FileNotFoundError:
            data_loader, images = self.complete_data()
            mu, sigma = calculate_activation_statistics(
                data_loader, images, model, verbose)
            stats_dir.mkdir(parents=True, exist_ok=True)
            np.savez(stats_file, mu=mu, sigma=sigma)

        self.model = model
        self.stats = mu, sigma

    def complete_data(self):
        raise NotImplementedError

    def fid(self, image_iterator, images):
        mu, sigma = calculate_activation_statistics(
            image_iterator, images, self.model, verbose=self.verbose)
        return calculate_frechet_distance(mu, sigma, *self.stats)


class BaseSampler:
    def __init__(self, images):
        self.images = images

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < self.images:
            batch = self.sample()
            batch_size = batch.shape[0]
            self.n += batch_size
            if self.n > self.images:
                return batch[:-(self.n - self.images)]
            return batch
        else:
            raise StopIteration

    def sample(self):
        raise NotImplementedError


class BaseImputationSampler:
    def __init__(self, data_loader):
        self.data_loader = data_loader

    def __iter__(self):
        self.data_iter = iter(self.data_loader)
        return self

    def __next__(self):
        data, mask = next(self.data_iter)[:2]
        data = data.to(device)
        mask = mask.float()[:, None].to(device)
        imputed_data = self.impute(data, mask)
        return mask * data + (1 - mask) * imputed_data

    def impute(self, data, mask):
        raise NotImplementedError
--------------------------------------------------------------------------------
/src-torch1.6/imputer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fcnet import FullyConnectedNet
4 | from unet import UnetSkipConnectionBlock
5 |
6 |
7 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
8 | class UNet(nn.Module):
9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, layers=5,
10 | norm_layer=nn.BatchNorm2d):
11 | super().__init__()
12 |
13 | mid_layers = layers - 2
14 | fact = 2**mid_layers
15 |
16 | unet_block = UnetSkipConnectionBlock(
17 | ngf * fact, ngf * fact, input_nc=None, submodule=None,
18 | norm_layer=norm_layer, innermost=True)
19 |
20 | for _ in range(mid_layers):
21 | half_fact = fact // 2
22 | unet_block = UnetSkipConnectionBlock(
23 | ngf * half_fact, ngf * fact, input_nc=None,
24 | submodule=unet_block, norm_layer=norm_layer)
25 | fact = half_fact
26 |
27 | unet_block = UnetSkipConnectionBlock(
28 | output_nc, ngf, input_nc=input_nc, submodule=unet_block,
29 | outermost=True, norm_layer=norm_layer)
30 |
31 | self.model = unet_block
32 |
33 | def forward(self, input):
34 | return self.model(input)
35 |
36 |
37 | class Imputer(nn.Module):
38 | def __init__(self):
39 | super().__init__()
40 | self.transform = torch.sigmoid
41 |
42 | def forward(self, input, mask, noise):
43 | net = input * mask + noise * (1 - mask)
44 | net = self.imputer_net(net)
45 | net = self.transform(net)
46 | # NOT replacing observed part with input data for computing
47 | # autoencoding loss.
48 | # return input * mask + net * (1 - mask)
49 | return net
50 |
51 |
52 | class UNetImputer(Imputer):
53 | def __init__(self, *args, **kwargs):
54 | super().__init__()
55 | self.imputer_net = UNet(*args, **kwargs)
56 |
57 |
58 | class FullyConnectedImputer(Imputer):
59 | def __init__(self, *args, **kwargs):
60 | super().__init__()
61 | self.imputer_net = FullyConnectedNet(*args, **kwargs)
62 |
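The first line of `Imputer.forward` keeps observed entries and fills the missing ones with noise before the network sees them. A minimal sketch of that blend, assuming flat Python lists in place of tensors (`blend` is a hypothetical helper name used only here):

```python
def blend(data, mask, fill):
    """Elementwise mask * data + (1 - mask) * fill on flat lists."""
    return [m * d + (1 - m) * f for d, m, f in zip(data, mask, fill)]


data = [0.9, 0.1, 0.5, 0.7]
mask = [1, 0, 1, 0]          # 1 = observed, 0 = missing
noise = [0.3, 0.4, 0.6, 0.8]

# Observed positions come from `data`, missing positions from `noise`.
imputer_input = blend(data, mask, noise)
```

The same expression, with the imputer output in place of the noise, is what `BaseImputationSampler.__next__` uses to compose the final imputed sample.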
--------------------------------------------------------------------------------
/src-torch1.6/inception.py:
--------------------------------------------------------------------------------
1 | """Code from https://github.com/mseitzer/pytorch-fid
2 | """
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import models
6 |
7 |
8 | class InceptionV3(nn.Module):
9 | """Pretrained InceptionV3 network returning feature maps"""
10 |
11 | # Index of default block of inception to return,
12 | # corresponds to output of final average pooling
13 | DEFAULT_BLOCK_INDEX = 3
14 |
15 | # Maps feature dimensionality to their output blocks indices
16 | BLOCK_INDEX_BY_DIM = {
17 | 64: 0, # First max pooling features
18 | 192: 1, # Second max pooling features
19 | 768: 2, # Pre-aux classifier features
20 | 2048: 3 # Final average pooling features
21 | }
22 |
23 | def __init__(self,
24 | output_blocks=(DEFAULT_BLOCK_INDEX,),
25 | resize_input=299, # -1: not resize
26 | normalize_input=True,
27 | requires_grad=False):
28 | """Build pretrained InceptionV3
29 |
30 | Parameters
31 | ----------
32 | output_blocks : list of int
33 | Indices of blocks to return features of. Possible values are:
34 | - 0: corresponds to output of first max pooling
35 | - 1: corresponds to output of second max pooling
36 | - 2: corresponds to output which is fed to aux classifier
37 | - 3: corresponds to output of final average pooling
38 | resize_input : int
39 | If positive, bilinearly resizes input to this width and height
40 | (299 by default) before feeding it to the model; -1 disables
41 | resizing. The network without fully connected layers is fully
42 | convolutional, so resizing might not be strictly needed
43 | normalize_input : bool
44 | If true, normalizes the input to the statistics the pretrained
45 | Inception network expects
46 | requires_grad : bool
47 | If true, parameters of the model require gradient. Possibly useful
48 | for finetuning the network
49 | """
50 | super(InceptionV3, self).__init__()
51 |
52 | self.resize_input = resize_input
53 | self.normalize_input = normalize_input
54 | self.output_blocks = sorted(output_blocks)
55 | self.last_needed_block = max(output_blocks)
56 |
57 | assert self.last_needed_block <= 3, \
58 | 'Last possible output block index is 3'
59 |
60 | self.blocks = nn.ModuleList()
61 |
62 | inception = models.inception_v3(pretrained=True)
63 |
64 | # Block 0: input to maxpool1
65 | block0 = [
66 | inception.Conv2d_1a_3x3,
67 | inception.Conv2d_2a_3x3,
68 | inception.Conv2d_2b_3x3,
69 | nn.MaxPool2d(kernel_size=3, stride=2)
70 | ]
71 | self.blocks.append(nn.Sequential(*block0))
72 |
73 | # Block 1: maxpool1 to maxpool2
74 | if self.last_needed_block >= 1:
75 | block1 = [
76 | inception.Conv2d_3b_1x1,
77 | inception.Conv2d_4a_3x3,
78 | nn.MaxPool2d(kernel_size=3, stride=2)
79 | ]
80 | self.blocks.append(nn.Sequential(*block1))
81 |
82 | # Block 2: maxpool2 to aux classifier
83 | if self.last_needed_block >= 2:
84 | block2 = [
85 | inception.Mixed_5b,
86 | inception.Mixed_5c,
87 | inception.Mixed_5d,
88 | inception.Mixed_6a,
89 | inception.Mixed_6b,
90 | inception.Mixed_6c,
91 | inception.Mixed_6d,
92 | inception.Mixed_6e,
93 | ]
94 | self.blocks.append(nn.Sequential(*block2))
95 |
96 | # Block 3: aux classifier to final avgpool
97 | if self.last_needed_block >= 3:
98 | block3 = [
99 | inception.Mixed_7a,
100 | inception.Mixed_7b,
101 | inception.Mixed_7c,
102 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
103 | ]
104 | self.blocks.append(nn.Sequential(*block3))
105 |
106 | for param in self.parameters():
107 | param.requires_grad = requires_grad
108 |
109 | def forward(self, inp):
110 | """Get Inception feature maps
111 |
112 | Parameters
113 | ----------
114 | inp : torch.Tensor
115 | Input tensor of shape Bx3xHxW. Values are expected to be in
116 | range (0, 1)
117 |
118 | Returns
119 | -------
120 | List of torch.Tensor, corresponding to the selected output
121 | block, sorted ascending by index
122 | """
123 | outp = []
124 | x = inp
125 |
126 | if self.resize_input > 0:
127 | # size = 299
128 | x = F.interpolate(x, size=(self.resize_input, self.resize_input),
129 | mode='bilinear', align_corners=True)
130 |
131 | if self.normalize_input:
132 | x = x.clone()
133 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
134 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
135 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
136 |
137 | for idx, block in enumerate(self.blocks):
138 | x = block(x)
139 | if idx in self.output_blocks:
140 | outp.append(x)
141 |
142 | if idx == self.last_needed_block:
143 | break
144 |
145 | return outp
146 |
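The three per-channel lines under `normalize_input` in `InceptionV3.forward` are an affine map. Algebraically, if a value `p` has been standardized with the usual ImageNet mean and standard deviation, the map returns `(p - 0.5) / 0.5`, that is, the same pixel rescaled to the range (-1, 1). The sketch below checks this numerically; the function names are hypothetical and the constants are copied from the listing.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def imagenet_standardize(p, c):
    """Standardize pixel value p of channel c with ImageNet statistics."""
    return (p - IMAGENET_MEAN[c]) / IMAGENET_STD[c]


def listing_rescale(y, c):
    # Same affine map as the per-channel lines in InceptionV3.forward.
    return y * (IMAGENET_STD[c] / 0.5) + (IMAGENET_MEAN[c] - 0.5) / 0.5


# Largest deviation from (p - 0.5) / 0.5 over a few pixel values.
max_error = max(
    abs(listing_rescale(imagenet_standardize(p, c), c) - (p - 0.5) / 0.5)
    for c in range(3)
    for p in (0.0, 0.25, 0.5, 1.0)
)
```

This only describes what the arithmetic does; which input range the callers in this repository actually pass is not established by the check.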
--------------------------------------------------------------------------------
/src-torch1.6/masked_celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | class MaskedCelebA(datasets.ImageFolder):
8 | def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0):
9 | transform = transforms.Compose([
10 | transforms.CenterCrop(108),
11 | transforms.Resize(size=image_size, interpolation=Image.BICUBIC),
12 | transforms.ToTensor(),
13 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
14 | ])
15 |
16 | super().__init__(data_dir, transform)
17 |
18 | self.rnd = np.random.RandomState(random_seed)
19 | self.image_size = image_size
20 | self.generate_masks()
21 |
22 | def __getitem__(self, index):
23 | image, label = super().__getitem__(index)
24 | return image, self.mask[index], label, index
25 |
26 | def __len__(self):
27 | return super().__len__()
28 |
29 |
30 | class BlockMaskedCelebA(MaskedCelebA):
31 | def __init__(self, block_len=None, *args, **kwargs):
32 | self.block_len = block_len
33 | super().__init__(*args, **kwargs)
34 |
35 | def generate_masks(self):
36 | d0_len = d1_len = self.image_size
37 | d0_min_len = 12
38 | d0_max_len = d0_len - d0_min_len
39 | d1_min_len = 12
40 | d1_max_len = d1_len - d1_min_len
41 |
42 | n_masks = len(self)
43 | self.mask = [None] * n_masks
44 | self.mask_info = [None] * n_masks
45 | for i in range(n_masks):
46 | if self.block_len is None:
47 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len)
48 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len)
49 | else:
50 | d0_mask_len = d1_mask_len = self.block_len
51 |
52 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1)
53 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1)
54 |
55 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8)
56 | mask[d0_start:(d0_start + d0_mask_len),
57 | d1_start:(d1_start + d1_mask_len)] = 1
58 | self.mask[i] = mask
59 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len
60 |
61 |
62 | class IndepMaskedCelebA(MaskedCelebA):
63 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
64 | self.prob = obs_prob
65 | self.prob_high = obs_prob_high
66 | super().__init__(*args, **kwargs)
67 |
68 | def generate_masks(self):
69 | imsize = self.image_size
70 | prob = self.prob
71 | prob_high = self.prob_high
72 | n_masks = len(self)
73 | self.mask = [None] * n_masks
74 | for i in range(n_masks):
75 | if prob_high is None:
76 | p = prob
77 | else:
78 | p = self.rnd.uniform(prob, prob_high)
79 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p)
80 |
--------------------------------------------------------------------------------
/src-torch1.6/masked_mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torchvision import datasets, transforms
4 | import numpy as np
5 |
6 |
7 | class MaskedMNIST(Dataset):
8 | def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0):
9 | self.rnd = np.random.RandomState(random_seed)
10 | self.image_size = image_size
11 | if image_size == 28:
12 | self.data = datasets.MNIST(
13 | data_dir, train=True, download=True,
14 | transform=transforms.ToTensor())
15 | else:
16 | self.data = datasets.MNIST(
17 | data_dir, train=True, download=True,
18 | transform=transforms.Compose([
19 | transforms.Resize(image_size), transforms.ToTensor()]))
20 | self.generate_masks()
21 |
22 | def __getitem__(self, index):
23 | image, label = self.data[index]
24 | return image, self.mask[index], label, index
25 |
26 | def __len__(self):
27 | return len(self.data)
28 |
29 | def generate_masks(self):
30 | raise NotImplementedError
31 |
32 |
33 | class BlockMaskedMNIST(MaskedMNIST):
34 | def __init__(self, block_len=None, *args, **kwargs):
35 | self.block_len = block_len
36 | super().__init__(*args, **kwargs)
37 |
38 | def generate_masks(self):
39 | d0_len = d1_len = self.image_size
40 | d0_min_len = 7
41 | d0_max_len = d0_len - d0_min_len
42 | d1_min_len = 7
43 | d1_max_len = d1_len - d1_min_len
44 |
45 | n_masks = len(self)
46 | self.mask = [None] * n_masks
47 | self.mask_info = [None] * n_masks
48 | for i in range(n_masks):
49 | if self.block_len is None:
50 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len)
51 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len)
52 | else:
53 | d0_mask_len = d1_mask_len = self.block_len
54 |
55 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1)
56 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1)
57 |
58 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8)
59 | mask[d0_start:(d0_start + d0_mask_len),
60 | d1_start:(d1_start + d1_mask_len)] = 1
61 | self.mask[i] = mask
62 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len
63 |
64 |
65 | class IndepMaskedMNIST(MaskedMNIST):
66 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
67 | self.prob = obs_prob
68 | self.prob_high = obs_prob_high
69 | super().__init__(*args, **kwargs)
70 |
71 | def generate_masks(self):
72 | imsize = self.image_size
73 | prob = self.prob
74 | prob_high = self.prob_high
75 | n_masks = len(self)
76 | self.mask = [None] * n_masks
77 | for i in range(n_masks):
78 | if prob_high is None:
79 | p = prob
80 | else:
81 | p = self.rnd.uniform(prob, prob_high)
82 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p)
83 |
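The block-mask construction in `BlockMaskedMNIST.generate_masks` can be sketched with the standard library alone. This is an illustration under stated assumptions, not the repository's implementation: `block_mask` is a hypothetical function, masks are nested lists instead of `torch.uint8` tensors, and `random.Random` replaces `numpy.random.RandomState`, so the sampled values differ from the original for the same seed.

```python
import random


def block_mask(image_size, rng, min_len=7, block_len=None):
    """Return (mask, info); mask rows hold 1 inside the observed block."""
    max_len = image_size - min_len
    if block_len is None:
        # randrange excludes the upper bound, like RandomState.randint.
        h = rng.randrange(min_len, max_len)
        w = rng.randrange(min_len, max_len)
    else:
        h = w = block_len
    # Choose the top-left corner so the block fits inside the image.
    top = rng.randrange(0, image_size - h + 1)
    left = rng.randrange(0, image_size - w + 1)
    mask = [[1 if top <= r < top + h and left <= c < left + w else 0
             for c in range(image_size)]
            for r in range(image_size)]
    return mask, (top, left, h, w)


mask, (top, left, h, w) = block_mask(28, random.Random(0))
observed = sum(map(sum, mask))  # number of pixels set to 1
```

The returned tuple mirrors the `mask_info` entries in the listing (start row, start column, height, width), which the imputation training loop later passes to `plot_grid` as bounding boxes.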
--------------------------------------------------------------------------------
/src-torch1.6/misgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | import time
5 | import pylab as plt
6 | import seaborn as sns
7 | from collections import defaultdict
8 | from plot import plot_samples
9 | from utils import CriticUpdater, mkdir, mask_data
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
17 | output_dir, checkpoint=None):
18 | n_critic = args.n_critic
19 | gp_lambda = args.gp_lambda
20 | batch_size = args.batch_size
21 | nz = args.n_latent
22 | epochs = args.epoch
23 | plot_interval = args.plot_interval
24 | save_interval = args.save_interval
25 | alpha = args.alpha
26 | tau = args.tau
27 |
28 | gen_data_dir = mkdir(output_dir / 'img')
29 | gen_mask_dir = mkdir(output_dir / 'mask')
30 | log_dir = mkdir(output_dir / 'log')
31 | model_dir = mkdir(output_dir / 'model')
32 |
33 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
34 | drop_last=True)
35 | n_batch = len(data_loader)
36 |
37 | data_noise = torch.FloatTensor(batch_size, nz).to(device)
38 | mask_noise = torch.FloatTensor(batch_size, nz).to(device)
39 |
40 | # Interpolation coefficient
41 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)
42 |
43 | # For computing gradient penalty
44 | ones = torch.ones(batch_size).to(device)
45 |
46 | lrate = 1e-4
47 | # lrate = 1e-5
48 | data_gen_optimizer = optim.Adam(
49 | data_gen.parameters(), lr=lrate, betas=(.5, .9))
50 | mask_gen_optimizer = optim.Adam(
51 | mask_gen.parameters(), lr=lrate, betas=(.5, .9))
52 |
53 | data_critic_optimizer = optim.Adam(
54 | data_critic.parameters(), lr=lrate, betas=(.5, .9))
55 | mask_critic_optimizer = optim.Adam(
56 | mask_critic.parameters(), lr=lrate, betas=(.5, .9))
57 |
58 | update_data_critic = CriticUpdater(
59 | data_critic, data_critic_optimizer, eps, ones, gp_lambda)
60 | update_mask_critic = CriticUpdater(
61 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda)
62 |
63 | start_epoch = 0
64 | critic_updates = 0
65 | log = defaultdict(list)
66 |
67 | if checkpoint:
68 | data_gen.load_state_dict(checkpoint['data_gen'])
69 | mask_gen.load_state_dict(checkpoint['mask_gen'])
70 | data_critic.load_state_dict(checkpoint['data_critic'])
71 | mask_critic.load_state_dict(checkpoint['mask_critic'])
72 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
73 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
74 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
75 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
76 | start_epoch = checkpoint['epoch']
77 | critic_updates = checkpoint['critic_updates']
78 | log = checkpoint['log']
79 |
80 | with (log_dir / 'gpu.txt').open('a') as f:
81 | print(torch.cuda.device_count(), start_epoch, file=f)
82 |
83 | def save_model(path, epoch, critic_updates=0):
84 | torch.save({
85 | 'data_gen': data_gen.state_dict(),
86 | 'mask_gen': mask_gen.state_dict(),
87 | 'data_critic': data_critic.state_dict(),
88 | 'mask_critic': mask_critic.state_dict(),
89 | 'data_gen_opt': data_gen_optimizer.state_dict(),
90 | 'mask_gen_opt': mask_gen_optimizer.state_dict(),
91 | 'data_critic_opt': data_critic_optimizer.state_dict(),
92 | 'mask_critic_opt': mask_critic_optimizer.state_dict(),
93 | 'epoch': epoch + 1,
94 | 'critic_updates': critic_updates,
95 | 'log': log,
96 | 'args': args,
97 | }, str(path))
98 |
99 | sns.set()
100 |
101 | start = time.time()
102 | epoch_start = start
103 |
104 | for epoch in range(start_epoch, epochs):
105 | sum_data_loss, sum_mask_loss = 0, 0
106 | for real_data, real_mask, _, _ in data_loader:
107 | # Assume real_data and real_mask have the same number of channels.
108 | # Could be modified to handle multi-channel images and
109 | # single-channel masks.
110 | real_mask = real_mask.float()[:, None]
111 |
112 | real_data = real_data.to(device)
113 | real_mask = real_mask.to(device)
114 |
115 | masked_real_data = mask_data(real_data, real_mask, tau)
116 |
117 | # Update discriminators' parameters
118 | data_noise.normal_()
119 | mask_noise.normal_()
120 |
121 | fake_data = data_gen(data_noise)
122 | fake_mask = mask_gen(mask_noise)
123 |
124 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
125 |
126 | update_data_critic(masked_real_data, masked_fake_data)
127 | update_mask_critic(real_mask, fake_mask)
128 |
129 | sum_data_loss += update_data_critic.loss_value
130 | sum_mask_loss += update_mask_critic.loss_value
131 |
132 | critic_updates += 1
133 |
134 | if critic_updates == n_critic:
135 | critic_updates = 0
136 |
137 | # Update generators' parameters
138 | for p in data_critic.parameters():
139 | p.requires_grad_(False)
140 | for p in mask_critic.parameters():
141 | p.requires_grad_(False)
142 |
143 | data_noise.normal_()
144 | mask_noise.normal_()
145 |
146 | fake_data = data_gen(data_noise)
147 | fake_mask = mask_gen(mask_noise)
148 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
149 |
150 | data_loss = -data_critic(masked_fake_data).mean()
151 | data_gen.zero_grad()
152 | data_loss.backward()
153 | data_gen_optimizer.step()
154 |
155 | data_noise.normal_()
156 | mask_noise.normal_()
157 |
158 | fake_data = data_gen(data_noise)
159 | fake_mask = mask_gen(mask_noise)
160 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
161 |
162 | data_loss = -data_critic(masked_fake_data).mean()
163 | mask_loss = -mask_critic(fake_mask).mean()
164 | mask_gen.zero_grad()
165 | (mask_loss + data_loss * alpha).backward()
166 | mask_gen_optimizer.step()
167 |
168 | for p in data_critic.parameters():
169 | p.requires_grad_(True)
170 | for p in mask_critic.parameters():
171 | p.requires_grad_(True)
172 |
173 | mean_data_loss = sum_data_loss / n_batch
174 | mean_mask_loss = sum_mask_loss / n_batch
175 | log['data loss', 'data_loss'].append(mean_data_loss)
176 | log['mask loss', 'mask_loss'].append(mean_mask_loss)
177 |
178 | for (name, shortname), trace in log.items():
179 | fig, ax = plt.subplots(figsize=(6, 4))
180 | ax.plot(trace)
181 | ax.set_ylabel(name)
182 | ax.set_xlabel('epoch')
183 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
184 | plt.close(fig)
185 |
186 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
187 | print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}')
188 |
189 | filename = f'{epoch:04d}.png'
190 |
191 | data_gen.eval()
192 | mask_gen.eval()
193 |
194 | with torch.no_grad():
195 | data_noise.normal_()
196 | mask_noise.normal_()
197 |
198 | data_samples = data_gen(data_noise)
199 | plot_samples(data_samples, str(gen_data_dir / filename))
200 |
201 | mask_samples = mask_gen(mask_noise)
202 | plot_samples(mask_samples, str(gen_mask_dir / filename))
203 |
204 | data_gen.train()
205 | mask_gen.train()
206 |
207 | if save_interval > 0 and (epoch + 1) % save_interval == 0:
208 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)
209 |
210 | epoch_end = time.time()
211 | time_elapsed = epoch_end - start
212 | epoch_time = epoch_end - epoch_start
213 | epoch_start = epoch_end
214 | with (log_dir / 'time.txt').open('a') as f:
215 | print(epoch, epoch_time, time_elapsed, file=f)
216 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)
217 |
218 | print(output_dir)
219 |
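The training loop in `misgan` updates the critics on every mini-batch and the generators only once every `n_critic` critic updates, with the counter carried across epochs and stored in checkpoints. A minimal sketch of that schedule, assuming nothing beyond the counter logic shown in the listing (`generator_update_steps` is a hypothetical helper):

```python
def generator_update_steps(n_steps, n_critic, critic_updates=0):
    """Return the 0-based critic steps after which the generators update."""
    steps = []
    for step in range(n_steps):
        critic_updates += 1          # one critic update per mini-batch
        if critic_updates == n_critic:
            critic_updates = 0       # reset, then update the generators
            steps.append(step)
    return steps


schedule = generator_update_steps(n_steps=12, n_critic=5)
resumed = generator_update_steps(n_steps=12, n_critic=5, critic_updates=3)
```

The second call shows why `critic_updates` is saved in the checkpoint: resuming with a partially filled counter keeps the ratio of critic to generator updates unchanged across restarts.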
--------------------------------------------------------------------------------
/src-torch1.6/misgan_impute.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | import time
5 | import pylab as plt
6 | import seaborn as sns
7 | from collections import defaultdict
8 | from plot import plot_grid, plot_samples
9 | from utils import CriticUpdater, mask_norm, mkdir, mask_data
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def misgan_impute(args, data_gen, mask_gen, imputer,
17 | data_critic, mask_critic, impu_critic,
18 | data, output_dir, checkpoint=None):
19 | n_critic = args.n_critic
20 | gp_lambda = args.gp_lambda
21 | batch_size = args.batch_size
22 | nz = args.n_latent
23 | epochs = args.epoch
24 | plot_interval = args.plot_interval
25 | save_model_interval = args.save_interval
26 | alpha = args.alpha
27 | beta = args.beta
28 | gamma = args.gamma
29 | tau = args.tau
30 | update_all_networks = not args.imputeronly
31 |
32 | gen_data_dir = mkdir(output_dir / 'img')
33 | gen_mask_dir = mkdir(output_dir / 'mask')
34 | impute_dir = mkdir(output_dir / 'impute')
35 | log_dir = mkdir(output_dir / 'log')
36 | model_dir = mkdir(output_dir / 'model')
37 |
38 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
39 | drop_last=True, num_workers=args.workers)
40 | n_batch = len(data_loader)
41 | data_shape = data[0][0].shape
42 |
43 | data_noise = torch.FloatTensor(batch_size, nz).to(device)
44 | mask_noise = torch.FloatTensor(batch_size, nz).to(device)
45 | impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device)
46 |
47 | # Interpolation coefficient
48 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)
49 |
50 | # For computing gradient penalty
51 | ones = torch.ones(batch_size).to(device)
52 |
53 | lrate = 1e-4
54 | imputer_lrate = 2e-4
55 | data_gen_optimizer = optim.Adam(
56 | data_gen.parameters(), lr=lrate, betas=(.5, .9))
57 | mask_gen_optimizer = optim.Adam(
58 | mask_gen.parameters(), lr=lrate, betas=(.5, .9))
59 | imputer_optimizer = optim.Adam(
60 | imputer.parameters(), lr=imputer_lrate, betas=(.5, .9))
61 |
62 | data_critic_optimizer = optim.Adam(
63 | data_critic.parameters(), lr=lrate, betas=(.5, .9))
64 | mask_critic_optimizer = optim.Adam(
65 | mask_critic.parameters(), lr=lrate, betas=(.5, .9))
66 | impu_critic_optimizer = optim.Adam(
67 | impu_critic.parameters(), lr=imputer_lrate, betas=(.5, .9))
68 |
69 | update_data_critic = CriticUpdater(
70 | data_critic, data_critic_optimizer, eps, ones, gp_lambda)
71 | update_mask_critic = CriticUpdater(
72 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda)
73 | update_impu_critic = CriticUpdater(
74 | impu_critic, impu_critic_optimizer, eps, ones, gp_lambda)
75 |
76 | start_epoch = 0
77 | critic_updates = 0
78 | log = defaultdict(list)
79 |
80 | if args.resume:
81 | data_gen.load_state_dict(checkpoint['data_gen'])
82 | mask_gen.load_state_dict(checkpoint['mask_gen'])
83 | imputer.load_state_dict(checkpoint['imputer'])
84 | data_critic.load_state_dict(checkpoint['data_critic'])
85 | mask_critic.load_state_dict(checkpoint['mask_critic'])
86 | impu_critic.load_state_dict(checkpoint['impu_critic'])
87 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
88 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
89 | imputer_optimizer.load_state_dict(checkpoint['imputer_opt'])
90 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
91 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
92 | impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt'])
93 | start_epoch = checkpoint['epoch']
94 | critic_updates = checkpoint['critic_updates']
95 | log = checkpoint['log']
96 | elif args.pretrain:
97 | pretrain = torch.load(args.pretrain, map_location='cpu')
98 | data_gen.load_state_dict(pretrain['data_gen'])
99 | mask_gen.load_state_dict(pretrain['mask_gen'])
100 | data_critic.load_state_dict(pretrain['data_critic'])
101 | mask_critic.load_state_dict(pretrain['mask_critic'])
102 | if 'imputer' in pretrain:
103 | imputer.load_state_dict(pretrain['imputer'])
104 | impu_critic.load_state_dict(pretrain['impu_critic'])
105 |
106 | with (log_dir / 'gpu.txt').open('a') as f:
107 | print(torch.cuda.device_count(), start_epoch, file=f)
108 |
109 | def save_model(path, epoch, critic_updates=0):
110 | torch.save({
111 | 'data_gen': data_gen.state_dict(),
112 | 'mask_gen': mask_gen.state_dict(),
113 | 'imputer': imputer.state_dict(),
114 | 'data_critic': data_critic.state_dict(),
115 | 'mask_critic': mask_critic.state_dict(),
116 | 'impu_critic': impu_critic.state_dict(),
117 | 'data_gen_opt': data_gen_optimizer.state_dict(),
118 | 'mask_gen_opt': mask_gen_optimizer.state_dict(),
119 | 'imputer_opt': imputer_optimizer.state_dict(),
120 | 'data_critic_opt': data_critic_optimizer.state_dict(),
121 | 'mask_critic_opt': mask_critic_optimizer.state_dict(),
122 | 'impu_critic_opt': impu_critic_optimizer.state_dict(),
123 | 'epoch': epoch + 1,
124 | 'critic_updates': critic_updates,
125 | 'log': log,
126 | 'args': args,
127 | }, str(path))
128 |
129 | sns.set()
130 | start = time.time()
131 | epoch_start = start
132 |
133 | for epoch in range(start_epoch, epochs):
134 | sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0
135 | for real_data, real_mask, _, index in data_loader:
136 | # Assume real_data and real_mask have the same number of channels.
137 | # Could be modified to handle multi-channel images and
138 | # single-channel masks.
139 | real_mask = real_mask.float()[:, None]
140 |
141 | real_data = real_data.to(device)
142 | real_mask = real_mask.to(device)
143 |
144 | masked_real_data = mask_data(real_data, real_mask, tau)
145 |
146 | # Update discriminators' parameters
147 | data_noise.normal_()
148 | fake_data = data_gen(data_noise)
149 |
150 | impu_noise.uniform_()
151 | imputed_data = imputer(real_data, real_mask, impu_noise)
152 | masked_imputed_data = mask_data(real_data, real_mask, imputed_data)
153 |
154 | if update_all_networks:
155 | mask_noise.normal_()
156 | fake_mask = mask_gen(mask_noise)
157 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
158 | update_data_critic(masked_real_data, masked_fake_data)
159 | update_mask_critic(real_mask, fake_mask)
160 |
161 | sum_data_loss += update_data_critic.loss_value
162 | sum_mask_loss += update_mask_critic.loss_value
163 |
164 | update_impu_critic(fake_data, masked_imputed_data)
165 | sum_impu_loss += update_impu_critic.loss_value
166 |
167 | critic_updates += 1
168 |
169 | if critic_updates == n_critic:
170 | critic_updates = 0
171 |
172 | # Update generators' parameters
173 | if update_all_networks:
174 | for p in data_critic.parameters():
175 | p.requires_grad_(False)
176 | for p in mask_critic.parameters():
177 | p.requires_grad_(False)
178 | for p in impu_critic.parameters():
179 | p.requires_grad_(False)
180 |
181 | impu_noise.uniform_()
182 | imputed_data = imputer(real_data, real_mask, impu_noise)
183 | masked_imputed_data = mask_data(real_data, real_mask,
184 | imputed_data)
185 | impu_loss = -impu_critic(masked_imputed_data).mean()
186 |
187 | if update_all_networks:
188 | data_noise.normal_()
189 | fake_data = data_gen(data_noise)
190 | mask_noise.normal_()
191 | fake_mask = mask_gen(mask_noise)
192 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
193 | data_loss = -data_critic(masked_fake_data).mean()
194 | mask_loss = -mask_critic(fake_mask).mean()
195 |
196 | mask_gen.zero_grad()
197 | (mask_loss + data_loss * alpha).backward(retain_graph=True)
198 | mask_gen_optimizer.step()
199 |
200 | data_noise.normal_()
201 | fake_data = data_gen(data_noise)
202 | mask_noise.normal_()
203 | fake_mask = mask_gen(mask_noise)
204 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
205 | data_loss = -data_critic(masked_fake_data).mean()
206 |
207 | data_gen.zero_grad()
208 | (data_loss + impu_loss * beta).backward(retain_graph=True)
209 | data_gen_optimizer.step()
210 |
211 | imputer.zero_grad()
212 | if gamma > 0:
213 | imputer_mismatch_loss = mask_norm(
214 | (imputed_data - real_data)**2, real_mask)
215 | (impu_loss + imputer_mismatch_loss * gamma).backward()
216 | else:
217 | impu_loss.backward()
218 | imputer_optimizer.step()
219 |
220 | if update_all_networks:
221 | for p in data_critic.parameters():
222 | p.requires_grad_(True)
223 | for p in mask_critic.parameters():
224 | p.requires_grad_(True)
225 | for p in impu_critic.parameters():
226 | p.requires_grad_(True)
227 |
228 | if update_all_networks:
229 | mean_data_loss = sum_data_loss / n_batch
230 | mean_mask_loss = sum_mask_loss / n_batch
231 | log['data loss', 'data_loss'].append(mean_data_loss)
232 | log['mask loss', 'mask_loss'].append(mean_mask_loss)
233 | mean_impu_loss = sum_impu_loss / n_batch
234 | log['imputer loss', 'impu_loss'].append(mean_impu_loss)
235 |
236 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
237 | if update_all_networks:
238 | print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format(
239 | epoch, mean_data_loss, mean_mask_loss, mean_impu_loss))
240 | else:
241 | print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss))
242 |
243 | filename = f'{epoch:04d}.png'
244 | with torch.no_grad():
245 | data_gen.eval()
246 | mask_gen.eval()
247 | imputer.eval()
248 |
249 | data_noise.normal_()
250 | mask_noise.normal_()
251 |
252 | data_samples = data_gen(data_noise)
253 | plot_samples(data_samples, str(gen_data_dir / filename))
254 |
255 | mask_samples = mask_gen(mask_noise)
256 | plot_samples(mask_samples, str(gen_mask_dir / filename))
257 |
258 | # Plot imputation results
259 | impu_noise.uniform_()
260 | imputed_data = imputer(real_data, real_mask, impu_noise)
261 | imputed_data = mask_data(real_data, real_mask, imputed_data)
262 | if hasattr(data, 'mask_info'):
263 | bbox = [data.mask_info[idx] for idx in index]
264 | else:
265 | bbox = None
266 | plot_grid(imputed_data, bbox, gap=2,
267 | save_file=str(impute_dir / filename))
268 |
269 | data_gen.train()
270 | mask_gen.train()
271 | imputer.train()
272 |
273 | for (name, shortname), trace in log.items():
274 | fig, ax = plt.subplots(figsize=(6, 4))
275 | ax.plot(trace)
276 | ax.set_ylabel(name)
277 | ax.set_xlabel('epoch')
278 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
279 | plt.close(fig)
280 |
281 | if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0:
282 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)
283 |
284 | epoch_end = time.time()
285 | time_elapsed = epoch_end - start
286 | epoch_time = epoch_end - epoch_start
287 | epoch_start = epoch_end
288 | with (log_dir / 'epoch-time.txt').open('a') as f:
289 | print(epoch, epoch_time, time_elapsed, file=f)
290 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)
291 |
292 | print(output_dir)
293 |
--------------------------------------------------------------------------------
/src-torch1.6/mnist_critic.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class ConvCritic(nn.Module):
5 | def __init__(self):
6 | super().__init__()
7 |
8 | self.DIM = 64
9 | main = nn.Sequential(
10 | nn.Conv2d(1, self.DIM, 5, stride=2, padding=2),
11 | nn.ReLU(True),
12 | nn.Conv2d(self.DIM, 2 * self.DIM, 5, stride=2, padding=2),
13 | nn.ReLU(True),
14 | nn.Conv2d(2 * self.DIM, 4 * self.DIM, 5, stride=2, padding=2),
15 | nn.ReLU(True),
16 | )
17 | self.main = main
18 | self.output = nn.Linear(4 * 4 * 4 * self.DIM, 1)
19 |
20 | def forward(self, input):
21 | input = input.view(-1, 1, 28, 28)
22 | net = self.main(input)
23 | net = net.view(-1, 4 * 4 * 4 * self.DIM)
24 | net = self.output(net)
25 | return net.view(-1)
26 |
27 |
28 | class FCCritic(nn.Module):
29 | def __init__(self):
30 | super().__init__()
31 |
32 | self.in_dim = 784
33 | self.main = nn.Sequential(
34 | nn.Linear(self.in_dim, 512),
35 | nn.ReLU(True),
36 | nn.Linear(512, 256),
37 | nn.ReLU(True),
38 | nn.Linear(256, 128),
39 | nn.ReLU(True),
40 | nn.Linear(128, 1),
41 | )
42 |
43 | def forward(self, input):
44 | input = input.view(input.size(0), -1)
45 | out = self.main(input)
46 | return out.view(-1)
47 |
--------------------------------------------------------------------------------
/src-torch1.6/mnist_fid.py:
--------------------------------------------------------------------------------
1 | """Code adapted from https://github.com/mseitzer/pytorch-fid
2 | """
3 | import torch
4 | import numpy as np
5 | from scipy import linalg
6 | from torch.utils.data import DataLoader
7 | from torchvision import datasets, transforms
8 | import argparse
9 |
10 | import mnist_model
11 | from mnist_generator import ConvDataGenerator, FCDataGenerator
12 | from mnist_imputer import ComplementImputer, MaskImputer, FixedNoiseDimImputer
13 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
14 | from pathlib import Path
15 |
16 |
17 | use_cuda = torch.cuda.is_available()
18 | device = torch.device('cuda' if use_cuda else 'cpu')
19 |
20 | feature_layer = 0
21 |
22 |
23 | def get_activations(image_generator, images, model, verbose=False):
24 | """Calculates the activations of the chosen feature layer for all images.
25 |
26 | Params:
27 | -- image_generator
28 | : A generator that generates a batch of images at a time.
29 | -- images : Number of images that will be generated by
30 | image_generator.
31 | -- model : Instance of a model whose forward pass stores its
32 | intermediate activations in model.feature
33 | -- verbose : If set to True, the number of calculated batches is reported.
34 | Returns:
35 | -- A numpy array of dimension (num images, dims) that contains the
36 | activations of model.feature[feature_layer] for the generated
37 | images.
38 | """
39 | model.eval()
40 |
41 | pred_arr = None
42 | end = 0
43 | for i, batch in enumerate(image_generator):
44 | if verbose:
45 | print('\rPropagating batch %d' % (i + 1), end='', flush=True)
46 | start = end
47 | batch_size = batch.shape[0]
48 | end = start + batch_size
49 | batch = batch.to(device)
50 |
51 | with torch.no_grad():
52 | model(batch)
53 | pred = model.feature[feature_layer]
54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1)
55 | if pred_arr is None:
56 | pred_arr = np.empty((images, batch_feature.shape[1]))
57 | pred_arr[start:end] = batch_feature
58 |
59 | if verbose:
60 | print(' done')
61 |
62 | return pred_arr
63 |
64 |
65 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
66 | """Numpy implementation of the Frechet Distance.
67 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
68 | and X_2 ~ N(mu_2, C_2) is
69 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
70 |
71 | Stable version by Dougal J. Sutherland.
72 |
73 | Params:
74 | -- mu1 : The sample mean over activations of the chosen feature layer
75 | (as returned by the function 'calculate_activation_statistics')
76 | for generated samples.
77 | -- mu2 : The sample mean over activations, precalculated on a
78 | representative data set.
79 | -- sigma1: The covariance matrix over activations for generated samples.
80 | -- sigma2: The covariance matrix over activations, precalculated on a
81 | representative data set.
82 |
83 | Returns:
84 | -- : The Frechet Distance.
85 | """
86 |
87 | mu1 = np.atleast_1d(mu1)
88 | mu2 = np.atleast_1d(mu2)
89 |
90 | sigma1 = np.atleast_2d(sigma1)
91 | sigma2 = np.atleast_2d(sigma2)
92 |
93 | assert mu1.shape == mu2.shape, \
94 | 'Training and test mean vectors have different lengths'
95 | assert sigma1.shape == sigma2.shape, \
96 | 'Training and test covariances have different dimensions'
97 |
98 | diff = mu1 - mu2
99 |
100 | # Product might be almost singular
101 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
102 | if not np.isfinite(covmean).all():
103 | msg = ('fid calculation produces singular product; '
104 | 'adding %s to diagonal of cov estimates') % eps
105 | print(msg)
106 | offset = np.eye(sigma1.shape[0]) * eps
107 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
108 |
109 | # Numerical error might give slight imaginary component
110 | if np.iscomplexobj(covmean):
111 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
112 | m = np.max(np.abs(covmean.imag))
113 | raise ValueError(f'Imaginary component {m}')
114 | covmean = covmean.real
115 |
116 | tr_covmean = np.trace(covmean)
117 |
118 | return (diff.dot(diff) + np.trace(sigma1) +
119 | np.trace(sigma2) - 2 * tr_covmean)
120 |
121 |
122 | def calculate_activation_statistics(image_generator, images, model,
123 | verbose=False, weight=None):
124 | """Calculation of the statistics used by the FID.
125 | Params:
126 | -- image_generator
127 | : A generator that generates a batch of images at a time.
128 | -- images : Number of images that will be generated by
129 | image_generator.
130 | -- model : Instance of a model that exposes model.feature
131 | -- verbose : If set to True, the number of calculated batches is reported.
132 | -- weight : Optional per-sample weights used for mu and sigma.
133 | Returns:
134 | -- mu : The mean over samples of the activations of the chosen
135 | feature layer of the model.
136 | -- sigma : The covariance matrix of the activations of the chosen
137 | feature layer of the model.
138 | """
139 | act = get_activations(image_generator, images, model, verbose)
140 | if weight is None:
141 | mu = np.mean(act, axis=0)
142 | sigma = np.cov(act, rowvar=False)
143 | else:
144 | mu = np.average(act, axis=0, weights=weight)
145 | sigma = np.cov(act, rowvar=False, aweights=weight)
146 | return mu, sigma
147 |
148 |
149 | class MNISTModel:
150 | def __init__(self):
151 | model = mnist_model.Net().to(device)
152 | model.eval()
153 | map_location = None if use_cuda else 'cpu'
154 | model.load_state_dict(
155 | torch.load('mnist.pth', map_location=map_location))
156 |
157 | stats_file = f'mnist_act_{feature_layer}.npz'
158 | try:
159 | f = np.load(stats_file)
160 | m_mnist, s_mnist = f['mu'][:], f['sigma'][:]
161 | f.close()
162 | except FileNotFoundError:
163 | data = datasets.MNIST('data', train=True, download=True,
164 | transform=transforms.ToTensor())
165 | images = len(data)
166 | batch_size = 64
167 | data_loader = DataLoader([image for image, _ in data],
168 | batch_size=batch_size)
169 | m_mnist, s_mnist = calculate_activation_statistics(
170 | data_loader, images, model, verbose=True)
171 | np.savez(stats_file, mu=m_mnist, sigma=s_mnist)
172 |
173 | self.model = model
174 | self.mnist_stats = m_mnist, s_mnist
175 |
176 | def get_feature(self, samples):
177 | self.model(samples)
178 | feature = self.model.feature[feature_layer]
179 | return feature.cpu().numpy().reshape(samples.shape[0], -1)
180 |
181 | def fid(self, features):
182 | mu = np.mean(features, axis=0)
183 | sigma = np.cov(features, rowvar=False)
184 | return calculate_frechet_distance(mu, sigma, *self.mnist_stats)
185 |
186 |
187 | def data_generator_fid(data_generator,
188 | n_samples=60000, batch_size=64, verbose=False):
189 | mnist_model = MNISTModel()
190 | latent_size = 128
191 | data_noise = torch.FloatTensor(batch_size, latent_size).to(device)
192 |
193 | with torch.no_grad():
194 | count = 0
195 | features = None
196 | while count < n_samples:
197 | data_noise.normal_()
198 | samples = data_generator(data_noise)
199 | batch_feature = mnist_model.get_feature(samples)
200 |
201 | if features is None:
202 | features = np.empty((n_samples, batch_feature.shape[1]))
203 |
204 | if count + batch_size > n_samples:
205 | batch_size = n_samples - count
206 | features[count:] = batch_feature[:batch_size]
207 | else:
208 | features[count:(count + batch_size)] = batch_feature
209 |
210 | count += batch_size
211 | if verbose:
212 | print(f'\rGenerate images {count}', end='', flush=True)
213 | if verbose:
214 | print(' done')
215 | return mnist_model.fid(features)
216 |
217 |
218 | def imputer_fid(imputer, data, batch_size=64, verbose=False):
219 | mnist_model = MNISTModel()
220 | impu_noise = torch.FloatTensor(batch_size, 1, 28, 28).to(device)
221 | data_loader = DataLoader(data, batch_size=batch_size, drop_last=True)
222 | n_samples = len(data_loader) * batch_size
223 |
224 | with torch.no_grad():
225 | start = 0
226 | features = None
227 | for real_data, real_mask, _, index in data_loader:
228 | real_mask = real_mask.float()[:, None]
229 | real_data = real_data.to(device)
230 | real_mask = real_mask.to(device)
231 | impu_noise.uniform_()
232 | imputed_data = imputer(real_data, real_mask, impu_noise)
233 |
234 | batch_feature = mnist_model.get_feature(imputed_data)
235 | if features is None:
236 | features = np.empty((n_samples, batch_feature.shape[1]))
237 | features[start:(start + batch_size)] = batch_feature
238 | start += batch_size
239 | if verbose:
240 | print(f'\rGenerate images {start}', end='', flush=True)
241 | if verbose:
242 | print(' done')
243 | return mnist_model.fid(features)
244 |
245 |
246 | def pretrained_misgan_fid(model_file, samples=60000, batch_size=64):
247 | model = torch.load(model_file, map_location='cpu')
248 | args = model['args']
249 | if args.generator == 'conv':
250 | DataGenerator = ConvDataGenerator
251 | elif args.generator == 'fc':
252 | DataGenerator = FCDataGenerator
253 | data_gen = DataGenerator().to(device)
254 | data_gen.load_state_dict(model['data_gen'])
255 | return data_generator_fid(data_gen, samples, batch_size, verbose=True)
256 |
257 |
258 | def pretrained_imputer_fid(model_file, save_file, batch_size=64):
259 | model = torch.load(model_file, map_location='cpu')
260 | if 'imputer' not in model:
261 | return
262 | args = model['args']
263 |
264 | if args.imputer == 'comp':
265 | Imputer = ComplementImputer
266 | elif args.imputer == 'mask':
267 | Imputer = MaskImputer
268 | elif args.imputer == 'fix':
269 | Imputer = FixedNoiseDimImputer
270 |
271 | hid_lens = [int(n) for n in args.arch.split('-')]
272 | imputer = Imputer(arch=hid_lens).to(device)
273 | imputer.load_state_dict(model['imputer'])
274 |
275 | block_len = args.block_len
276 | if block_len == 0:
277 | block_len = None
278 |
279 | if args.mask == 'indep':
280 | data = IndepMaskedMNIST(obs_prob=args.obs_prob,
281 | obs_prob_high=args.obs_prob_high)
282 | elif args.mask == 'block':
283 | data = BlockMaskedMNIST(block_len=block_len)
284 |
285 | fid = imputer_fid(imputer, data, batch_size, verbose=True)
286 | with save_file.open('w') as f:
287 | print(fid, file=f)
288 | print('imputer fid:', fid)
289 |
290 |
291 | def main():
292 | parser = argparse.ArgumentParser()
293 | parser.add_argument('root_dir')
294 | parser.add_argument('--skip-exist', action='store_true')
295 | args = parser.parse_args()
296 |
297 | skip_exist = args.skip_exist
298 |
299 | root_dir = Path(args.root_dir)
300 | fid_file = root_dir / f'fid-{feature_layer}.txt'
301 | if skip_exist and fid_file.exists():
302 | return
303 | try:
304 | model_file = max((root_dir / 'model').glob('*.pth'))
305 | except ValueError:
306 | return
307 |
308 | fid = pretrained_misgan_fid(model_file)
309 | print(f'{root_dir.name}: {fid}')
310 | with fid_file.open('w') as f:
311 | print(fid, file=f)
312 |
313 | # Compute FID for the imputer if it is in the model
314 | imputer_fid_file = root_dir / f'impute-fid-{feature_layer}.txt'
315 | pretrained_imputer_fid(model_file, imputer_fid_file)
316 |
317 |
318 | if __name__ == '__main__':
319 | main()
320 |
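The formula in the `calculate_frechet_distance` docstring is easier to check by hand in one dimension, where the matrix square root reduces to an ordinary square root. The sketch below covers only that univariate special case, not the general matrix computation used above; the function name `frechet_distance_1d` is hypothetical.

```python
import math


def frechet_distance_1d(mu1, var1, mu2, var2):
    """Squared Frechet distance between N(mu1, var1) and N(mu2, var2).

    Univariate case of d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)),
    where the covariances are plain non-negative variances.
    """
    return (mu1 - mu2) ** 2 + var1 + var2 - 2.0 * math.sqrt(var1 * var2)


# Identical Gaussians are at distance zero.
print(frechet_distance_1d(0.0, 1.0, 0.0, 1.0))
# Means differ by 3 and variances are 1 and 4: 9 + 1 + 4 - 2 * 2.
print(frechet_distance_1d(0.0, 1.0, 3.0, 4.0))
```

The multivariate version needs `scipy.linalg.sqrtm` because `C_1 * C_2` is a matrix product, which is where the near-singular and imaginary-component handling above comes from.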
--------------------------------------------------------------------------------
/src-torch1.6/mnist_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def add_data_transformer(self):
7 | self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28)
8 |
9 |
10 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)):
11 | """
12 | hard_sigmoid:
13 | False: use sigmoid only
14 | True: hard thresholding
15 | (a, b): hard thresholding on rescaled sigmoid
16 | """
17 | self.temperature = temperature
18 | self.hard_sigmoid = hard_sigmoid
19 |
20 | view = -1, 1, 28, 28
21 | if hard_sigmoid is False:
22 | self.transform = lambda x: torch.sigmoid(x / temperature).view(*view)
23 | elif hard_sigmoid is True:
24 | self.transform = lambda x: F.hardtanh(
25 | x / temperature, 0, 1).view(*view)
26 | else:
27 | a, b = hard_sigmoid
28 | self.transform = lambda x: F.hardtanh(
29 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1).view(*view)
30 |
31 |
32 | # Must sub-class ConvGenerator to provide transform()
33 | class ConvGenerator(nn.Module):
34 | def __init__(self, latent_size=128):
35 | super().__init__()
36 |
37 | self.DIM = 64
38 | self.latent_size = latent_size
39 |
40 | self.preprocess = nn.Sequential(
41 | nn.Linear(latent_size, 4 * 4 * 4 * self.DIM),
42 | nn.ReLU(True),
43 | )
44 | self.block1 = nn.Sequential(
45 | nn.ConvTranspose2d(4 * self.DIM, 2 * self.DIM, 5),
46 | nn.ReLU(True),
47 | )
48 | self.block2 = nn.Sequential(
49 | nn.ConvTranspose2d(2 * self.DIM, self.DIM, 5),
50 | nn.ReLU(True),
51 | )
52 | self.deconv_out = nn.ConvTranspose2d(self.DIM, 1, 8, stride=2)
53 |
54 | def forward(self, input):
55 | net = self.preprocess(input)
56 | net = net.view(-1, 4 * self.DIM, 4, 4)
57 | net = self.block1(net)
58 | net = net[:, :, :7, :7]
59 | net = self.block2(net)
60 | net = self.deconv_out(net)
61 | return self.transform(net)
62 |
63 |
64 | # Must sub-class FCGenerator to provide transform()
65 | class FCGenerator(nn.Module):
66 | def __init__(self, latent_size=128):
67 | super().__init__()
68 | self.latent_size = latent_size
69 | self.fc = nn.Sequential(
70 | nn.Linear(latent_size, 256),
71 | nn.ReLU(True),
72 | nn.Linear(256, 512),
73 | nn.ReLU(True),
74 | nn.Linear(512, 784),
75 | )
76 |
77 | def forward(self, input):
78 | net = self.fc(input)
79 | return self.transform(net)
80 |
81 |
82 | class ConvDataGenerator(ConvGenerator):
83 | def __init__(self, latent_size=128):
84 | super().__init__(latent_size=latent_size)
85 | add_data_transformer(self)
86 |
87 |
88 | class FCDataGenerator(FCGenerator):
89 | def __init__(self, latent_size=128):
90 | super().__init__(latent_size=latent_size)
91 | add_data_transformer(self)
92 |
93 |
94 | class ConvMaskGenerator(ConvGenerator):
95 | def __init__(self, latent_size=128, temperature=.66,
96 | hard_sigmoid=(-.1, 1.1)):
97 | super().__init__(latent_size=latent_size)
98 | add_mask_transformer(self, temperature, hard_sigmoid)
99 |
100 |
101 | class FCMaskGenerator(FCGenerator):
102 | def __init__(self, latent_size=128, temperature=.66,
103 | hard_sigmoid=(-.1, 1.1)):
104 | super().__init__(latent_size=latent_size)
105 | add_mask_transformer(self, temperature, hard_sigmoid)
106 |
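Two details of this file can be checked with plain arithmetic: how `ConvGenerator` grows a 4x4 map into a 28x28 image, and what the default "fusion" mask transform `hard_sigmoid=(-.1, 1.1)` does to a single logit. The sketch below is a scalar, framework-free illustration under the assumption of the standard transposed-convolution size formula (no output padding, dilation 1); the helper names are hypothetical.

```python
import math


def conv_transpose2d_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial axis of a transposed convolution."""
    return (size - 1) * stride - 2 * padding + kernel


def generator_output_size(size=4):
    """Follow the spatial size through ConvGenerator.forward."""
    size = conv_transpose2d_out(size, kernel=5)            # block1: 4 -> 8
    size = min(size, 7)                                    # crop net[:, :, :7, :7]
    size = conv_transpose2d_out(size, kernel=5)            # block2: 7 -> 11
    return conv_transpose2d_out(size, kernel=8, stride=2)  # deconv_out: 11 -> 28


def stretched_sigmoid(x, temperature=0.66, bounds=(-0.1, 1.1)):
    """Scalar version of the (a, b) branch of add_mask_transformer."""
    a, b = bounds
    s = 1.0 / (1.0 + math.exp(-x / temperature))  # tempered sigmoid in (0, 1)
    return min(1.0, max(0.0, s * (b - a) + a))    # rescale to (a, b), clamp to [0, 1]


print(generator_output_size())
print(stretched_sigmoid(-10.0), stretched_sigmoid(0.0), stretched_sigmoid(10.0))
```

Because the sigmoid is stretched slightly beyond [0, 1] before clamping, sufficiently large or small logits produce exact 0 or 1 mask values, which a plain sigmoid only approaches.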
--------------------------------------------------------------------------------
/src-torch1.6/mnist_imputer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # Must sub-class Imputer to provide fc1
7 | class Imputer(nn.Module):
8 | def __init__(self, arch=(784, 784)):
9 | super().__init__()
10 | # self.fc1 = nn.Linear(784, arch[0])
11 | self.fc2 = nn.Linear(arch[0], arch[1])
12 | self.fc3 = nn.Linear(arch[1], arch[0])
13 | self.fc4 = nn.Linear(arch[0], 784)
14 | self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28)
15 |
16 | def forward(self, input, data, mask):
17 | net = input.view(input.size(0), -1)
18 | net = F.relu(self.fc1(net))
19 | net = F.relu(self.fc2(net))
20 | net = F.relu(self.fc3(net))
21 | net = self.fc4(net)
22 | net = self.transform(net)
23 | # return data * mask + net * (1 - mask)
24 | # NOT replacing observed part with input data for computing
25 | # autoencoding loss.
26 | return net
27 |
28 |
29 | class ComplementImputer(Imputer):
30 | def __init__(self, arch=(784, 784)):
31 | super().__init__(arch=arch)
32 | self.fc1 = nn.Linear(784, arch[0])
33 |
34 | def forward(self, input, mask, noise):
35 | net = input * mask + noise * (1 - mask)
36 | return super().forward(net, input, mask)
37 |
38 |
39 | class MaskImputer(Imputer):
40 | def __init__(self, arch=(784, 784)):
41 | super().__init__(arch=arch)
42 | self.fc1 = nn.Linear(784 * 2, arch[0])
43 |
44 | def forward(self, input, mask, noise):
45 | batch_size = input.size(0)
46 | net = torch.cat(
47 | [(input * mask + noise * (1 - mask)).view(batch_size, -1),
48 | mask.view(batch_size, -1)], 1)
49 | return super().forward(net, input, mask)
50 |
51 |
52 | class FixedNoiseDimImputer(Imputer):
53 | def __init__(self, arch=(784, 784)):
54 | super().__init__(arch=arch)
55 | self.fc1 = nn.Linear(784 * 3, arch[0])
56 |
57 | def forward(self, input, mask, noise):
58 | batch_size = input.size(0)
59 | net = torch.cat([(input * mask).view(batch_size, -1),
60 | mask.view(batch_size, -1),
61 | noise.view(batch_size, -1)], 1)
62 | return super().forward(net, input, mask)
63 |
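`ComplementImputer` and `MaskImputer` both feed the network `input * mask + noise * (1 - mask)`, and the commented-out return in `Imputer.forward` uses the same pattern to merge observed pixels with imputed ones. The sketch below shows that blend on flat Python lists instead of tensors; `mix_by_mask` is a hypothetical helper, not a function from the repository.

```python
def mix_by_mask(observed, mask, fill):
    """Keep `observed` where mask is 1 and take `fill` where mask is 0."""
    return [x * m + f * (1 - m) for x, m, f in zip(observed, mask, fill)]


pixels = [0.9, 0.2, 0.7]
mask = [1, 0, 1]          # 1 = observed, 0 = missing
noise = [0.5, 0.5, 0.5]   # stands in for the sampled noise tensor

# The middle pixel is missing, so it is replaced by noise.
print(mix_by_mask(pixels, mask, noise))
```

Since `Imputer.forward` returns the raw network output, a caller that wants observed pixels preserved would have to apply this blend itself.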
--------------------------------------------------------------------------------
/src-torch1.6/mnist_misgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | from pathlib import Path
4 | import argparse
5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator,
6 | ConvMaskGenerator, FCMaskGenerator)
7 | from mnist_critic import ConvCritic, FCCritic
8 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
9 | from misgan import misgan
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def main():
17 | parser = argparse.ArgumentParser()
18 |
19 | # resume from checkpoint
20 | parser.add_argument('--resume')
21 | # training options
22 | parser.add_argument('--epoch', type=int, default=500)
23 | parser.add_argument('--batch-size', type=int, default=64)
24 |
25 | # log options: 0 to disable plot-interval or save-interval
26 | parser.add_argument('--plot-interval', type=int, default=50)
27 | parser.add_argument('--save-interval', type=int, default=0)
28 | parser.add_argument('--prefix', default='misgan')
29 |
30 | # mask options (data): block|indep
31 | parser.add_argument('--mask', default='block')
32 | # option for block: set to 0 for variable size
33 | parser.add_argument('--block-len', type=int, default=14)
34 | # option for indep:
35 | parser.add_argument('--obs-prob', type=float, default=.2)
36 | parser.add_argument('--obs-prob-high', type=float, default=None)
37 |
38 | # model options
39 | parser.add_argument('--tau', type=float, default=0)
40 | parser.add_argument('--generator', default='conv') # conv|fc
41 | parser.add_argument('--critic', default='conv') # conv|fc
42 | # parser.add_argument('--alpha', type=float, default=.1) # 0: separate
43 | parser.add_argument('--alpha', type=float, default=.2) # 0: separate
44 | # options for mask generator: sigmoid, hardsigmoid, fusion
45 | # parser.add_argument('--maskgen', default='fusion')
46 | parser.add_argument('--maskgen', default='sigmoid')
47 | parser.add_argument('--gp-lambda', type=float, default=10)
48 | parser.add_argument('--n-critic', type=int, default=5)
49 | parser.add_argument('--n-latent', type=int, default=128)
50 |
51 | args = parser.parse_args()
52 |
53 | checkpoint = None
54 | # Resume from previously stored checkpoint
55 | if args.resume:
56 | print(f'Resume: {args.resume}')
57 | output_dir = Path(args.resume)
58 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
59 | map_location='cpu')
60 | for key, arg in vars(checkpoint['args']).items():
61 | if key not in ['resume']:
62 | setattr(args, key, arg)
63 |
64 | if args.generator == 'conv':
65 | DataGenerator = ConvDataGenerator
66 | MaskGenerator = ConvMaskGenerator
67 | elif args.generator == 'fc':
68 | DataGenerator = FCDataGenerator
69 | MaskGenerator = FCMaskGenerator
70 | else:
71 | raise NotImplementedError
72 |
73 | if args.critic == 'conv':
74 | Critic = ConvCritic
75 | elif args.critic == 'fc':
76 | Critic = FCCritic
77 | else:
78 | raise NotImplementedError
79 |
80 | if args.maskgen == 'sigmoid':
81 | hard_sigmoid = False
82 | elif args.maskgen == 'hardsigmoid':
83 | hard_sigmoid = True
84 | elif args.maskgen == 'fusion':
85 | hard_sigmoid = -.1, 1.1
86 | else:
87 | raise NotImplementedError
88 |
89 | mask = args.mask
90 | obs_prob = args.obs_prob
91 | obs_prob_high = args.obs_prob_high
92 | block_len = args.block_len
93 | if block_len == 0:
94 | block_len = None
95 |
96 | if mask == 'indep':
97 | if obs_prob_high is None:
98 | mask_str = f'indep_{obs_prob:g}'
99 | else:
100 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
101 | elif mask == 'block':
102 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
103 | else:
104 | raise NotImplementedError
105 |
106 | path = '{}_{}_{}'.format(
107 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
108 | '_'.join([
109 | f'gen_{args.generator}',
110 | f'critic_{args.critic}',
111 | f'tau_{args.tau:g}',
112 | f'alpha_{args.alpha:g}',
113 | f'maskgen_{args.maskgen}',
114 | mask_str,
115 | ]))
116 |
117 | if not args.resume:
118 | output_dir = Path('results') / 'mnist' / path
119 | print(output_dir)
120 |
121 | if mask == 'indep':
122 | data = IndepMaskedMNIST(obs_prob=obs_prob, obs_prob_high=obs_prob_high)
123 | elif mask == 'block':
124 | data = BlockMaskedMNIST(block_len=block_len)
125 |
126 | data_gen = DataGenerator().to(device)
127 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device)
128 |
129 | data_critic = Critic().to(device)
130 | mask_critic = Critic().to(device)
131 |
132 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
133 | output_dir, checkpoint)
134 |
135 |
136 | if __name__ == '__main__':
137 | main()
138 |
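The output directory name encodes the mask configuration through `mask_str`. The sketch below isolates that branch as a standalone function so the resulting tags can be inspected without starting a training run; the function name `mask_tag` is hypothetical, and its defaults mirror the argparse defaults above.

```python
def mask_tag(mask, block_len=14, obs_prob=0.2, obs_prob_high=None):
    """Rebuild the mask part of the results directory name."""
    if mask == 'indep':
        if obs_prob_high is None:
            return f'indep_{obs_prob:g}'
        return f'indep_{obs_prob:g}_{obs_prob_high:g}'
    if mask == 'block':
        # A block length of 0 (or None) means variable-size blocks.
        return 'block_{}'.format(block_len if block_len else 'varsize')
    raise NotImplementedError(mask)


for tag in (mask_tag('block'),
            mask_tag('block', block_len=0),
            mask_tag('indep'),
            mask_tag('indep', obs_prob=0.2, obs_prob_high=0.8)):
    print(tag)
```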
--------------------------------------------------------------------------------
/src-torch1.6/mnist_misgan_impute.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | from pathlib import Path
4 | import argparse
5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator,
6 | ConvMaskGenerator, FCMaskGenerator)
7 | from mnist_imputer import (ComplementImputer,
8 | MaskImputer,
9 | FixedNoiseDimImputer)
10 | from mnist_critic import ConvCritic, FCCritic
11 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
12 | from misgan_impute import misgan_impute
13 |
14 |
15 | use_cuda = torch.cuda.is_available()
16 | device = torch.device('cuda' if use_cuda else 'cpu')
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser()
21 |
22 | # resume from checkpoint
23 | parser.add_argument('--resume')
24 |
25 | # training options
26 | parser.add_argument('--workers', type=int, default=0)
27 | parser.add_argument('--epoch', type=int, default=1000)
28 | parser.add_argument('--batch-size', type=int, default=64)
29 | parser.add_argument('--pretrain', default=None)
30 | parser.add_argument('--imputeronly', action='store_true')
31 |
32 | # log options: 0 to disable plot-interval or save-interval
33 | parser.add_argument('--plot-interval', type=int, default=100)
34 | parser.add_argument('--save-interval', type=int, default=0)
35 | parser.add_argument('--prefix', default='impute')
36 |
37 | # mask options (data): block|indep
38 | parser.add_argument('--mask', default='block')
39 | # option for block: set to 0 for variable size
40 | parser.add_argument('--block-len', type=int, default=14)
41 | # option for indep:
42 | parser.add_argument('--obs-prob', type=float, default=.2)
43 | parser.add_argument('--obs-prob-high', type=float, default=None)
44 |
45 | # model options
46 | parser.add_argument('--tau', type=float, default=0)
47 | parser.add_argument('--generator', default='conv') # conv|fc
48 | parser.add_argument('--critic', default='conv') # conv|fc
49 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate
50 | parser.add_argument('--beta', type=float, default=.1)
51 | parser.add_argument('--gamma', type=float, default=0)
52 | parser.add_argument('--arch', default='784-784')
53 | parser.add_argument('--imputer', default='comp') # comp|mask|fix
54 | # options for mask generator: sigmoid, hardsigmoid, fusion
55 | parser.add_argument('--maskgen', default='fusion')
56 | parser.add_argument('--gp-lambda', type=float, default=10)
57 | parser.add_argument('--n-critic', type=int, default=5)
58 | parser.add_argument('--n-latent', type=int, default=128)
59 |
60 | args = parser.parse_args()
61 |
62 | checkpoint = None
63 | # Resume from previously stored checkpoint
64 | if args.resume:
65 | print(f'Resume: {args.resume}')
66 | output_dir = Path(args.resume)
67 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
68 | map_location='cpu')
69 | for key, arg in vars(checkpoint['args']).items():
70 | if key not in ['resume']:
71 | setattr(args, key, arg)
72 |
73 | if args.imputeronly:
74 | assert args.pretrain is not None
75 |
76 | arch = args.arch
77 | imputer_type = args.imputer
78 | mask = args.mask
79 | obs_prob = args.obs_prob
80 | obs_prob_high = args.obs_prob_high
81 | block_len = args.block_len
82 | if block_len == 0:
83 | block_len = None
84 |
85 | if args.generator == 'conv':
86 | DataGenerator = ConvDataGenerator
87 | MaskGenerator = ConvMaskGenerator
88 | elif args.generator == 'fc':
89 | DataGenerator = FCDataGenerator
90 | MaskGenerator = FCMaskGenerator
91 | else:
92 | raise NotImplementedError
93 |
94 | if imputer_type == 'comp':
95 | Imputer = ComplementImputer
96 | elif imputer_type == 'mask':
97 | Imputer = MaskImputer
98 | elif imputer_type == 'fix':
99 | Imputer = FixedNoiseDimImputer
100 | else:
101 | raise NotImplementedError
102 |
103 | if args.critic == 'conv':
104 | Critic = ConvCritic
105 | elif args.critic == 'fc':
106 | Critic = FCCritic
107 | else:
108 | raise NotImplementedError
109 |
110 | if args.maskgen == 'sigmoid':
111 | hard_sigmoid = False
112 | elif args.maskgen == 'hardsigmoid':
113 | hard_sigmoid = True
114 | elif args.maskgen == 'fusion':
115 | hard_sigmoid = -.1, 1.1
116 | else:
117 | raise NotImplementedError
118 |
119 | if mask == 'indep':
120 | if obs_prob_high is None:
121 | mask_str = f'indep_{obs_prob:g}'
122 | else:
123 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
124 | elif mask == 'block':
125 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
126 | else:
127 | raise NotImplementedError
128 |
129 | path = '{}_{}_{}'.format(
130 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
131 | '_'.join([
132 | f'gen_{args.generator}',
133 | f'critic_{args.critic}',
134 | f'imp_{args.imputer}',
135 | f'tau_{args.tau:g}',
136 | f'arch_{args.arch}',
137 | f'maskgen_{args.maskgen}',
138 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}',
139 | mask_str
140 | ]))
141 |
142 | if not args.resume:
143 | output_dir = Path('results') / 'mnist' / path
144 | print(output_dir)
145 |
146 | if mask == 'indep':
147 | data = IndepMaskedMNIST(
148 | obs_prob=obs_prob, obs_prob_high=obs_prob_high)
149 | elif mask == 'block':
150 | data = BlockMaskedMNIST(block_len=block_len)
151 |
152 | data_gen = DataGenerator().to(device)
153 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device)
154 |
155 | hid_lens = [int(n) for n in arch.split('-')]
156 | imputer = Imputer(arch=hid_lens).to(device)
157 |
158 | data_critic = Critic().to(device)
159 | mask_critic = Critic().to(device)
160 | impu_critic = Critic().to(device)
161 |
162 | misgan_impute(args, data_gen, mask_gen, imputer,
163 | data_critic, mask_critic, impu_critic,
164 | data, output_dir, checkpoint)
165 |
166 |
167 | if __name__ == '__main__':
168 | main()
169 |
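The `--arch` string and the `--imputer` choice together determine the shapes of the four linear layers in the imputer. The sketch below derives those shapes from the class definitions in `mnist_imputer.py`; `imputer_layer_shapes` and `INPUT_WIDTH` are hypothetical names used only for illustration.

```python
# Width of the concatenated input seen by fc1 for each imputer type:
# comp -> blended image; mask -> blended image + mask; fix -> image + mask + noise.
INPUT_WIDTH = {'comp': 784, 'mask': 784 * 2, 'fix': 784 * 3}


def imputer_layer_shapes(arch='784-784', imputer='comp'):
    """Return (in_features, out_features) for fc1..fc4."""
    hid = [int(n) for n in arch.split('-')]  # same parsing as in main()
    return [
        (INPUT_WIDTH[imputer], hid[0]),  # fc1, defined in the subclass
        (hid[0], hid[1]),                # fc2
        (hid[1], hid[0]),                # fc3
        (hid[0], 784),                   # fc4, back to a 28x28 image
    ]


print(imputer_layer_shapes())
print(imputer_layer_shapes('512-256', 'fix'))
```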
--------------------------------------------------------------------------------
/src-torch1.6/mnist_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py
3 | """
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torchvision import datasets, transforms
10 |
11 |
12 | class Net(nn.Module):
13 | def __init__(self):
14 | super(Net, self).__init__()
15 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
16 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
17 | self.conv2_drop = nn.Dropout2d()
18 | self.fc1 = nn.Linear(320, 50)
19 | self.fc2 = nn.Linear(50, 10)
20 |
21 | def forward(self, x):
22 | feature = []
23 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
24 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
25 | x = x.view(-1, 320)
26 | x = self.fc1(x)
27 | feature.append(x)
28 | x = F.relu(x)
29 | x = F.dropout(x, training=self.training)
30 | x = self.fc2(x)
31 | feature.append(x)
32 | self.feature = feature
33 | return F.log_softmax(x, dim=1)
34 |
35 |
36 | def main():
37 | # Training settings
38 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
39 | parser.add_argument('--batch-size', type=int, default=64, metavar='N',
40 | help='input batch size for training (default: 64)')
41 | parser.add_argument('--test-batch-size', type=int,
42 | default=1000, metavar='N',
43 | help='input batch size for testing (default: 1000)')
44 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
45 | help='number of epochs to train (default: 100)')
46 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
47 | help='learning rate (default: 0.01)')
48 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
49 | help='SGD momentum (default: 0.5)')
50 | parser.add_argument('--no-cuda', action='store_true', default=False,
51 | help='disables CUDA training')
52 | parser.add_argument('--seed', type=int, default=1, metavar='S',
53 | help='random seed (default: 1)')
54 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
55 | help='number of batches to wait before logging '
56 | 'training status')
57 | args = parser.parse_args()
58 | args.cuda = not args.no_cuda and torch.cuda.is_available()
59 |
60 | torch.manual_seed(args.seed)
61 | if args.cuda:
62 | torch.cuda.manual_seed(args.seed)
63 |
64 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
65 | train_loader = torch.utils.data.DataLoader(
66 | datasets.MNIST('../data', train=True, download=True,
67 | transform=transforms.Compose([
68 | transforms.ToTensor(),
69 | transforms.Normalize((0.1307,), (0.3081,))
70 | ])),
71 | batch_size=args.batch_size, shuffle=True, **kwargs)
72 | test_loader = torch.utils.data.DataLoader(
73 | datasets.MNIST('../data', train=False, transform=transforms.Compose([
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.1307,), (0.3081,))
76 | ])),
77 | batch_size=args.test_batch_size, shuffle=True, **kwargs)
78 |
79 | model = Net()
80 | if args.cuda:
81 | model.cuda()
82 |
83 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
84 | momentum=args.momentum)
85 |
86 | def train(epoch):
87 | model.train()
88 | for batch_idx, (data, target) in enumerate(train_loader):
89 | if args.cuda:
90 | data, target = data.cuda(), target.cuda()
91 | optimizer.zero_grad()
92 | output = model(data)
93 | loss = F.nll_loss(output, target)
94 | loss.backward()
95 | optimizer.step()
96 | if batch_idx % args.log_interval == 0:
97 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
98 | epoch, batch_idx * len(data), len(train_loader.dataset),
99 | 100. * batch_idx / len(train_loader), loss.item()))
100 |
101 | def test():
102 | model.eval()
103 | test_loss = 0
104 | correct = 0
105 | with torch.no_grad():
106 | for data, target in test_loader:
107 | if args.cuda:
108 | data, target = data.cuda(), target.cuda()
109 | output = model(data)
110 | # sum up batch loss
111 | test_loss += F.nll_loss(output, target, reduction='sum').item()
112 | # get the index of the max log-probability
113 | pred = output.argmax(dim=1, keepdim=True)
114 | correct += (pred == target.view_as(pred)).long().cpu().sum()
115 |
116 | test_loss /= len(test_loader.dataset)
117 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
118 | .format(test_loss, correct, len(test_loader.dataset),
119 | 100. * correct / len(test_loader.dataset)))
120 |
121 | for epoch in range(1, args.epochs + 1):
122 | train(epoch)
123 | test()
124 |
125 | torch.save(model.state_dict(), 'mnist.pth')
126 |
127 |
128 | if __name__ == '__main__':
129 | main()
130 |
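The `x.view(-1, 320)` in `Net.forward` depends on the spatial size left after two unpadded 5x5 convolutions, each followed by 2x2 max pooling. The sketch below reproduces that count for a 28x28 input; the helper names are hypothetical.

```python
def conv_out(size, kernel):
    """Unpadded, stride-1 convolution along one axis."""
    return size - kernel + 1


def pool_out(size, window):
    """Non-overlapping max pooling along one axis."""
    return size // window


def flattened_features(size=28, channels=20):
    size = pool_out(conv_out(size, 5), 2)  # conv1 + pool: 28 -> 24 -> 12
    size = pool_out(conv_out(size, 5), 2)  # conv2 + pool: 12 -> 8 -> 4
    return channels * size * size          # 20 channels from conv2


print(flattened_features())
```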
--------------------------------------------------------------------------------
/src-torch1.6/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pylab as plt
3 | from matplotlib.patches import Rectangle
4 | from PIL import Image
5 |
6 |
7 | def plot_grid(image, bbox=None, gap=0, gap_value=1,
8 | nrow=4, ncol=8, save_file=None):
9 | image = image.cpu().numpy()
10 | channels, len0, len1 = image[0].shape
11 | grid = np.empty(
12 | (nrow * (len0 + gap) - gap, ncol * (len1 + gap) - gap, channels))
13 | # Convert from (N, C, H, W) to (N, H, W, C)
14 | image = image.transpose((0, 2, 3, 1))
15 | grid.fill(gap_value)
16 |
17 | for i, x in enumerate(image):
18 | if i >= nrow * ncol:
19 | break
20 | p0 = (i // ncol) * (len0 + gap)
21 | p1 = (i % ncol) * (len1 + gap)
22 | grid[p0:(p0 + len0), p1:(p1 + len1)] = x
23 |
24 | # figsize = np.r_[ncol, nrow] * .75
25 | scale = 2.5
26 | figsize = ncol * scale, nrow * scale # FIXME: scale by len0, len1
27 | fig = plt.figure(figsize=figsize)
28 | ax = plt.Axes(fig, [0, 0, 1, 1])
29 | ax.set_axis_off()
30 | fig.add_axes(ax)
31 | grid = grid.squeeze()
32 | ax.imshow(grid, cmap='binary_r', interpolation='none', aspect='equal')
33 |
34 | if bbox is not None:
35 | nplot = min(len(image), nrow * ncol)
36 | for i in range(nplot):
37 | if len(bbox) == 1:
38 | d0, d1, d0_len, d1_len = bbox[0]
39 | else:
40 | d0, d1, d0_len, d1_len = bbox[i]
41 | p0 = (i // ncol) * (len0 + gap)
42 | p1 = (i % ncol) * (len1 + gap)
43 | offset = np.array([p1 + d1, p0 + d0]) - .5
44 | ax.add_patch(Rectangle(
45 | offset, d1_len, d0_len, lw=4, edgecolor='red', fill=False))
46 |
47 | if save_file:
48 | fig.savefig(save_file)
49 | plt.close(fig)
50 |
51 |
52 | def plot_samples(samples, save_file, nrow=4, ncol=8):
53 | x = samples.cpu().numpy()
54 | channels, len0, len1 = x[0].shape
55 | x_merge = np.zeros((nrow * len0, ncol * len1, channels))
56 |
57 | for i, x_ in enumerate(x):
58 | if i >= nrow * ncol:
59 | break
60 | p0 = (i // ncol) * len0
61 | p1 = (i % ncol) * len1
62 | x_merge[p0:(p0 + len0), p1:(p1 + len1)] = x_.transpose((1, 2, 0))
63 |
64 | x_merge = (x_merge * 255).clip(0, 255).astype(np.uint8)
65 | # squeeze() to remove the last dimension for the single-channel image.
66 | im = Image.fromarray(x_merge.squeeze())
67 | im.save(save_file)
68 |
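The tile placement used by plot_grid and plot_samples above can be restated as plain index arithmetic. The following is a minimal sketch, not part of the repository: the helper names grid_shape and tile_origin are hypothetical, and the 28 x 28 image size is an assumed example value.

```python
def grid_shape(nrow, ncol, len0, len1, gap=0):
    """Canvas (height, width) allocated for nrow x ncol tiles.

    Each tile takes len + gap pixels, except that no gap follows
    the last row or column, hence the trailing "- gap".
    """
    return nrow * (len0 + gap) - gap, ncol * (len1 + gap) - gap


def tile_origin(i, ncol, len0, len1, gap=0):
    """Top-left (row, col) pixel of the i-th image, filled row by row."""
    return (i // ncol) * (len0 + gap), (i % ncol) * (len1 + gap)


# Assumed example: a 4 x 8 grid of 28 x 28 images with a 2-pixel gap.
print(grid_shape(4, 8, 28, 28, gap=2))   # (118, 238)
print(tile_origin(9, 8, 28, 28, gap=2))  # (30, 30)
```

With gap=0 this reduces to the layout used by plot_samples, where tiles are packed edge to edge.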
--------------------------------------------------------------------------------
/src-torch1.6/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.3.2
2 | numpy==1.19.2
3 | Pillow==8.0.1
4 | scipy==1.5.3
5 | seaborn==0.11.0
6 | torch==1.6.0
7 | torchvision==0.7.0
8 |
--------------------------------------------------------------------------------
/src-torch1.6/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
6 | #
7 | # Defines the submodule with skip connection.
8 | # X -------------------identity---------------------- X
9 | # |-- downsampling -- |submodule| -- upsampling --|
10 | class UnetSkipConnectionBlock(nn.Module):
11 | def __init__(self, outer_nc, inner_nc, input_nc=None,
12 | submodule=None, outermost=False, innermost=False,
13 | norm_layer=nn.BatchNorm2d):
14 | super().__init__()
15 | self.outermost = outermost
16 | use_bias = norm_layer == nn.InstanceNorm2d
17 | if input_nc is None:
18 | input_nc = outer_nc
19 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
20 | stride=2, padding=1, bias=use_bias)
21 | downrelu = nn.LeakyReLU(0.2, True)
22 | if norm_layer is not None:
23 | downnorm = norm_layer(inner_nc)
24 | upnorm = norm_layer(outer_nc)
25 | uprelu = nn.ReLU(True)
26 |
27 | if outermost:
28 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
29 | kernel_size=4, stride=2,
30 | padding=1)
31 | down = [downconv]
32 | up = [uprelu, upconv]
33 | model = down + [submodule] + up
34 | elif innermost:
35 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
36 | kernel_size=4, stride=2,
37 | padding=1, bias=use_bias)
38 | down = [downrelu, downconv]
39 | up = [uprelu, upconv]
40 | if norm_layer is not None:
41 | up.append(upnorm)
42 | model = down + up
43 | else:
44 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
45 | kernel_size=4, stride=2,
46 | padding=1, bias=use_bias)
47 | down = [downrelu, downconv]
48 | up = [uprelu, upconv]
49 | if norm_layer is not None:
50 | down.append(downnorm)
51 | up.append(upnorm)
52 |
53 | model = down + [submodule] + up
54 |
55 | self.model = nn.Sequential(*model)
56 |
57 | def forward(self, x):
58 | if self.outermost:
59 | return self.model(x)
60 | else:
61 | return torch.cat([x, self.model(x)], 1)
62 |
--------------------------------------------------------------------------------
/src-torch1.6/utils.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import grad
2 |
3 |
4 | class CriticUpdater:
5 | def __init__(self, critic, critic_optimizer, eps, ones, gp_lambda=10):
6 | self.critic = critic
7 | self.critic_optimizer = critic_optimizer
8 | self.eps = eps
9 | self.ones = ones
10 | self.gp_lambda = gp_lambda
11 |
12 | def __call__(self, real, fake):
13 | real = real.detach()
14 | fake = fake.detach()
15 | self.critic.zero_grad()
16 | self.eps.uniform_(0, 1)
17 | interp = (self.eps * real + (1 - self.eps) * fake).requires_grad_()
18 | grad_d = grad(self.critic(interp), interp, grad_outputs=self.ones,
19 | create_graph=True)[0]
20 | grad_d = grad_d.view(real.shape[0], -1)
21 | grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda
22 | w_dist = self.critic(fake).mean() - self.critic(real).mean()
23 | loss = w_dist + grad_penalty
24 | loss.backward()
25 | self.critic_optimizer.step()
26 | self.loss_value = loss.item()
27 |
28 |
29 | def mask_norm(diff, mask):
30 | """Per-sample mean of diff over observed (mask == 1) entries, averaged over the batch"""
31 | dim = 1, 2, 3
32 | # Assume mask.sum(dim) is non-zero for every sample
33 | return ((diff * mask).sum(dim) / mask.sum(dim)).mean()
34 |
35 |
36 | def mkdir(path):
37 | path.mkdir(parents=True, exist_ok=True)
38 | return path
39 |
40 |
41 | def mask_data(data, mask, tau):
42 | return mask * data + (1 - mask) * tau
43 |
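To make the arithmetic of mask_data and mask_norm concrete, here is a plain-Python restatement for a single flattened sample. It is a sketch under stated assumptions: the list-based helpers below (including the hypothetical name masked_mean) are illustrative and are not the tensor code used in the repository, and the sample values are made up.

```python
def mask_data(data, mask, tau):
    """Keep observed entries (mask == 1); fill the rest with tau."""
    return [m * x + (1 - m) * tau for x, m in zip(data, mask)]


def masked_mean(diff, mask):
    """Sum of diff over observed entries divided by their count.

    mask_norm computes this per sample and then averages over the
    batch. As in the original, the mask is assumed to be non-empty.
    """
    return sum(d * m for d, m in zip(diff, mask)) / sum(mask)


data = [0.2, 0.8, 0.4, 0.6]
mask = [1, 0, 1, 0]

print(mask_data(data, mask, tau=0.5))            # [0.2, 0.5, 0.4, 0.5]
print(masked_mean([1.0, 9.0, 3.0, 9.0], mask))   # 2.0
```

The second call shows that entries at unobserved positions (the 9.0 values) do not contribute to the normalized value.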
--------------------------------------------------------------------------------
/src/celeba_critic.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def conv_ln_lrelu(in_dim, out_dim):
5 | return nn.Sequential(
6 | nn.Conv2d(in_dim, out_dim, 5, 2, 2),
7 | nn.InstanceNorm2d(out_dim, affine=True),
8 | nn.LeakyReLU(0.2))
9 |
10 |
11 | class ConvCritic(nn.Module):
12 | def __init__(self, n_channels):
13 | super().__init__()
14 | dim = 64
15 | self.ls = nn.Sequential(
16 | nn.Conv2d(n_channels, dim, 5, 2, 2), nn.LeakyReLU(0.2),
17 | conv_ln_lrelu(dim, dim * 2),
18 | conv_ln_lrelu(dim * 2, dim * 4),
19 | conv_ln_lrelu(dim * 4, dim * 8),
20 | nn.Conv2d(dim * 8, 1, 4))
21 |
22 | def forward(self, input):
23 | net = self.ls(input)
24 | return net.view(-1)
25 |
--------------------------------------------------------------------------------
/src/celeba_fid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import torch
4 | import torch.nn as nn
5 | from torch.utils.data import DataLoader
6 | from torchvision import datasets, transforms
7 | from PIL import Image
8 | from celeba_generator import ConvDataGenerator
9 | from fid import BaseSampler, BaseImputationSampler
10 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
11 | from imputer import UNetImputer
12 | from fid import FID
13 |
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('root_dir')
17 | parser.add_argument('--batch-size', type=int, default=256)
18 | parser.add_argument('--workers', type=int, default=0)
19 | parser.add_argument('--skip-exist', action='store_true')
20 | args = parser.parse_args()
21 |
22 |
23 | use_cuda = torch.cuda.is_available()
24 | device = torch.device('cuda' if use_cuda else 'cpu')
25 |
26 |
27 | class CelebAFID(FID):
28 | def __init__(self, batch_size=256, data_name='celeba',
29 | workers=0, verbose=True):
30 | self.batch_size = batch_size
31 | self.workers = workers
32 | super().__init__(data_name, verbose)
33 |
34 | def complete_data(self):
35 | data = datasets.ImageFolder(
36 | 'celeba',
37 | transforms.Compose([
38 | transforms.CenterCrop(108),
39 | transforms.Resize(size=64, interpolation=Image.BICUBIC),
40 | transforms.ToTensor(),
41 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
42 | ]))
43 |
44 | images = len(data)
45 | data_loader = DataLoader(
46 | data, batch_size=self.batch_size, num_workers=self.workers)
47 |
48 | return data_loader, images
49 |
50 |
51 | class MisGANSampler(BaseSampler):
52 | def __init__(self, data_gen, images=60000, batch_size=256):
53 | super().__init__(images)
54 | self.data_gen = data_gen
55 | self.batch_size = batch_size
56 | latent_dim = 128
57 | self.data_noise = torch.FloatTensor(batch_size, latent_dim).to(device)
58 |
59 | def sample(self):
60 | self.data_noise.normal_()
61 | return self.data_gen(self.data_noise)
62 |
63 |
64 | class MisGANImputationSampler(BaseImputationSampler):
65 | def __init__(self, data_loader, imputer, batch_size=256):
66 | super().__init__(data_loader)
67 | self.imputer = imputer
68 | self.impu_noise = torch.FloatTensor(batch_size, 3, 64, 64).to(device)
69 |
70 | def impute(self, data, mask):
71 | if data.shape[0] != self.impu_noise.shape[0]:
72 | self.impu_noise.resize_(data.shape)
73 | self.impu_noise.uniform_()
74 | return self.imputer(data, mask, self.impu_noise)
75 |
76 |
77 | def get_data_loader(args, batch_size):
78 | if args.mask == 'indep':
79 | data = IndepMaskedCelebA(
80 | data_dir=args.data_dir,
81 | obs_prob=args.obs_prob, obs_prob_high=args.obs_prob_high)
82 | elif args.mask == 'block':
83 | data = BlockMaskedCelebA(
84 | data_dir=args.data_dir, block_len=args.block_len)
85 |
86 | data_size = len(data)
87 | data_loader = DataLoader(
88 | data, batch_size=batch_size, num_workers=args.workers)
89 | return data_loader, data_size
90 |
91 |
92 | def parallelize(model):
93 | return nn.DataParallel(model).to(device)
94 |
95 |
96 | def pretrained_misgan_fid(model_file, samples=202599):
97 | model = torch.load(model_file, map_location='cpu')
98 | data_gen = parallelize(ConvDataGenerator())
99 | data_gen.load_state_dict(model['data_gen'])
100 |
101 | batch_size = args.batch_size
102 |
103 | compute_fid = CelebAFID(batch_size=batch_size)
104 | sampler = MisGANSampler(data_gen, samples, batch_size)
105 | gen_fid = compute_fid.fid(sampler, samples)
106 | print(f'fid: {gen_fid:.2f}')
107 |
108 | imp_fid = None
109 | if 'imputer' in model:
110 | imputer = UNetImputer().to(device)
111 | imputer.load_state_dict(model['imputer'])
112 | data_loader, data_size = get_data_loader(model['args'], batch_size)
113 | imputation_sampler = MisGANImputationSampler(
114 | data_loader, imputer, batch_size)
115 | imp_fid = compute_fid.fid(imputation_sampler, data_size)
116 | print(f'impute fid: {imp_fid:.2f}')
117 |
118 | return gen_fid, imp_fid
119 |
120 |
121 | def main():
122 | root_dir = Path(args.root_dir)
123 | fid_file = root_dir / 'fid.txt'
124 | if args.skip_exist and fid_file.exists():
125 | return
126 | try:
127 | model_file = max((root_dir / 'model').glob('*.pth'))
128 | except ValueError:
129 | return
130 |
131 | print(root_dir.name)
132 | fid, imp_fid = pretrained_misgan_fid(model_file)
133 |
134 | with fid_file.open('w') as f:
135 | print(fid, file=f)
136 |
137 | if imp_fid is not None:
138 | with (root_dir / 'impute-fid.txt').open('w') as f:
139 | print(imp_fid, file=f)
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
--------------------------------------------------------------------------------
/src/celeba_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)):
7 | """
8 | hard_sigmoid:
9 | False: use sigmoid only
10 | True: hard thresholding
11 | (a, b): hard thresholding on rescaled sigmoid
12 | """
13 | self.temperature = temperature
14 | self.hard_sigmoid = hard_sigmoid
15 |
16 | if hard_sigmoid is False:
17 | self.transform = lambda x: torch.sigmoid(x / temperature)
18 | elif hard_sigmoid is True:
19 | self.transform = lambda x: F.hardtanh(
20 | x / temperature, 0, 1)
21 | else:
22 | a, b = hard_sigmoid
23 | self.transform = lambda x: F.hardtanh(
24 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1)
25 |
26 |
27 | def dconv_bn_relu(in_dim, out_dim):
28 | return nn.Sequential(
29 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
30 | padding=2, output_padding=1, bias=False),
31 | nn.BatchNorm2d(out_dim),
32 | nn.ReLU())
33 |
34 |
35 | # Subclasses must set self.out_channels before super().__init__() and provide transform()
36 | class ConvGenerator(nn.Module):
37 | def __init__(self, latent_size=128):
38 | super().__init__()
39 |
40 | dim = 64
41 |
42 | self.l1 = nn.Sequential(
43 | nn.Linear(latent_size, dim * 8 * 4 * 4, bias=False),
44 | nn.BatchNorm1d(dim * 8 * 4 * 4),
45 | nn.ReLU())
46 |
47 | self.l2_5 = nn.Sequential(
48 | dconv_bn_relu(dim * 8, dim * 4),
49 | dconv_bn_relu(dim * 4, dim * 2),
50 | dconv_bn_relu(dim * 2, dim),
51 | nn.ConvTranspose2d(dim, self.out_channels, 5, 2,
52 | padding=2, output_padding=1))
53 |
54 | def forward(self, input):
55 | net = self.l1(input)
56 | net = net.view(net.shape[0], -1, 4, 4)
57 | net = self.l2_5(net)
58 | return self.transform(net)
59 |
60 |
61 | class ConvDataGenerator(ConvGenerator):
62 | def __init__(self, latent_size=128):
63 | self.out_channels = 3
64 | super().__init__(latent_size=latent_size)
65 | self.transform = lambda x: torch.sigmoid(x)
66 |
67 |
68 | class ConvMaskGenerator(ConvGenerator):
69 | def __init__(self, latent_size=128, temperature=.66,
70 | hard_sigmoid=(-.1, 1.1)):
71 | self.out_channels = 1
72 | super().__init__(latent_size=latent_size)
73 | add_mask_transformer(self, temperature, hard_sigmoid)
74 |
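The three hard_sigmoid modes of add_mask_transformer differ in whether the mask values can reach exactly 0 and 1. The scalar sketch below mirrors the three branches with the standard library only; it is an illustration, not the tensor code above, and the function name mask_transform is hypothetical.

```python
import math


def mask_transform(x, temperature=0.66, hard_sigmoid=(-0.1, 1.1)):
    """Scalar version of the three mask activations."""
    if hard_sigmoid is False:
        # Plain sigmoid: output stays strictly inside (0, 1).
        return 1 / (1 + math.exp(-x / temperature))
    if hard_sigmoid is True:
        # Hard thresholding: clip the scaled input to [0, 1].
        return min(max(x / temperature, 0.0), 1.0)
    # "fusion": stretch the sigmoid to (a, b), then clip to [0, 1].
    a, b = hard_sigmoid
    s = 1 / (1 + math.exp(-x / temperature))
    return min(max(s * (b - a) + a, 0.0), 1.0)


for x in (-5.0, 0.0, 5.0):
    print(x, mask_transform(x, hard_sigmoid=False), mask_transform(x))
```

With the default (-.1, 1.1) range, sufficiently large positive or negative inputs are clipped to exactly 1 or 0, while the plain sigmoid only approaches those values.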
--------------------------------------------------------------------------------
/src/celeba_misgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from datetime import datetime
4 | from pathlib import Path
5 | import argparse
6 | from celeba_generator import ConvDataGenerator, ConvMaskGenerator
7 | from celeba_critic import ConvCritic
8 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
9 | from misgan import misgan
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def parallelize(model):
17 | return nn.DataParallel(model).to(device)
18 |
19 |
20 | def main():
21 | parser = argparse.ArgumentParser()
22 |
23 | # resume from checkpoint
24 | parser.add_argument('--resume')
25 |
26 | # path of CelebA dataset
27 | parser.add_argument('--data-dir', default='celeba-data')
28 |
29 | # training options
30 | parser.add_argument('--epoch', type=int, default=600)
31 | parser.add_argument('--batch-size', type=int, default=256)
32 |
33 | # log options: 0 to disable plot-interval or save-interval
34 | parser.add_argument('--plot-interval', type=int, default=100)
35 | parser.add_argument('--save-interval', type=int, default=0)
36 | parser.add_argument('--prefix', default='misgan')
37 |
38 | # mask options (data): block|indep
39 | parser.add_argument('--mask', default='block')
40 | # option for block: set to 0 for variable size
41 | parser.add_argument('--block-len', type=int, default=32)
42 | # option for indep:
43 | parser.add_argument('--obs-prob', type=float, default=.2)
44 | parser.add_argument('--obs-prob-high', type=float, default=None)
45 |
46 | # model options
47 | parser.add_argument('--tau', type=float, default=.5)
48 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate
49 | # options for mask generator: sigmoid, hardsigmoid, fusion
50 | parser.add_argument('--maskgen', default='fusion')
51 | parser.add_argument('--gp-lambda', type=float, default=10)
52 | parser.add_argument('--n-critic', type=int, default=5)
53 | parser.add_argument('--n-latent', type=int, default=128)
54 |
55 | args = parser.parse_args()
56 |
57 | checkpoint = None
58 | # Resume from previously stored checkpoint
59 | if args.resume:
60 | print(f'Resume: {args.resume}')
61 | output_dir = Path(args.resume)
62 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
63 | map_location='cpu')
64 | for key, arg in vars(checkpoint['args']).items():
65 | if key not in ['resume']:
66 | setattr(args, key, arg)
67 |
68 | if args.maskgen == 'sigmoid':
69 | hard_sigmoid = False
70 | elif args.maskgen == 'hardsigmoid':
71 | hard_sigmoid = True
72 | elif args.maskgen == 'fusion':
73 | hard_sigmoid = -.1, 1.1
74 | else:
75 | raise NotImplementedError
76 |
77 | mask = args.mask
78 | obs_prob = args.obs_prob
79 | obs_prob_high = args.obs_prob_high
80 | block_len = args.block_len
81 | if block_len == 0:
82 | block_len = None
83 | if mask == 'indep':
84 | if obs_prob_high is None:
85 | mask_str = f'indep_{obs_prob:g}'
86 | else:
87 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
88 | elif mask == 'block':
89 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
90 | else:
91 | raise NotImplementedError
92 |
93 | path = '{}_{}_{}'.format(
94 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
95 | '_'.join([
96 | f'tau_{args.tau:g}',
97 | f'alpha_{args.alpha:g}',
98 | f'maskgen_{args.maskgen}',
99 | mask_str,
100 | ]))
101 |
102 | if not args.resume:
103 | output_dir = Path('results') / 'celeba' / path
104 | print(output_dir)
105 |
106 | if mask == 'indep':
107 | data = IndepMaskedCelebA(
108 | data_dir=args.data_dir,
109 | obs_prob=obs_prob, obs_prob_high=obs_prob_high)
110 | elif mask == 'block':
111 | data = BlockMaskedCelebA(
112 | data_dir=args.data_dir, block_len=block_len)
113 | n_gpu = torch.cuda.device_count()
114 | print(f'Use {n_gpu} GPUs.')
115 |
116 | data_gen = parallelize(ConvDataGenerator())
117 | mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid))
118 |
119 | data_critic = parallelize(ConvCritic(n_channels=3))
120 | mask_critic = parallelize(ConvCritic(n_channels=1))
121 |
122 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
123 | output_dir, checkpoint)
124 |
125 |
126 | if __name__ == '__main__':
127 | main()
128 |
--------------------------------------------------------------------------------
/src/celeba_misgan_impute.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from datetime import datetime
4 | from pathlib import Path
5 | import argparse
6 | from celeba_generator import ConvDataGenerator, ConvMaskGenerator
7 | from celeba_critic import ConvCritic
8 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
9 | from imputer import UNetImputer
10 | from misgan_impute import misgan_impute
11 |
12 |
13 | use_cuda = torch.cuda.is_available()
14 | device = torch.device('cuda' if use_cuda else 'cpu')
15 |
16 |
17 | def parallelize(model):
18 | return nn.DataParallel(model).to(device)
19 |
20 |
21 | def main():
22 | parser = argparse.ArgumentParser()
23 |
24 | # resume from checkpoint
25 | parser.add_argument('--resume')
26 |
27 | # path of CelebA dataset
28 | parser.add_argument('--data-dir', default='celeba-data')
29 |
30 | # training options
31 | parser.add_argument('--workers', type=int, default=0)
32 | parser.add_argument('--epoch', type=int, default=800)
33 | parser.add_argument('--batch-size', type=int, default=512)
34 | parser.add_argument('--pretrain', default=None)
35 | parser.add_argument('--imputeronly', action='store_true')
36 |
37 | # log options: 0 to disable plot-interval or save-interval
38 | parser.add_argument('--plot-interval', type=int, default=50)
39 | parser.add_argument('--save-interval', type=int, default=0)
40 | parser.add_argument('--prefix', default='impute')
41 |
42 | # mask options (data): block|indep
43 | parser.add_argument('--mask', default='block')
44 | # option for block: set to 0 for variable size
45 | parser.add_argument('--block-len', type=int, default=32)
46 | # option for indep:
47 | parser.add_argument('--obs-prob', type=float, default=.2)
48 | parser.add_argument('--obs-prob-high', type=float, default=None)
49 |
50 | # model options
51 | parser.add_argument('--tau', type=float, default=.5)
52 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate
53 | parser.add_argument('--beta', type=float, default=.1)
54 | parser.add_argument('--gamma', type=float, default=0)
55 | # options for mask generator: sigmoid, hardsigmoid, fusion
56 | parser.add_argument('--maskgen', default='fusion')
57 | parser.add_argument('--gp-lambda', type=float, default=10)
58 | parser.add_argument('--n-critic', type=int, default=5)
59 | parser.add_argument('--n-latent', type=int, default=128)
60 |
61 | args = parser.parse_args()
62 |
63 | checkpoint = None
64 | # Resume from previously stored checkpoint
65 | if args.resume:
66 | print(f'Resume: {args.resume}')
67 | output_dir = Path(args.resume)
68 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
69 | map_location='cpu')
70 | for key, arg in vars(checkpoint['args']).items():
71 | if key not in ['resume']:
72 | setattr(args, key, arg)
73 |
74 | if args.imputeronly:
75 | assert args.pretrain is not None
76 |
77 | mask = args.mask
78 | obs_prob = args.obs_prob
79 | obs_prob_high = args.obs_prob_high
80 | block_len = args.block_len
81 | if block_len == 0:
82 | block_len = None
83 |
84 | if args.maskgen == 'sigmoid':
85 | hard_sigmoid = False
86 | elif args.maskgen == 'hardsigmoid':
87 | hard_sigmoid = True
88 | elif args.maskgen == 'fusion':
89 | hard_sigmoid = -.1, 1.1
90 | else:
91 | raise NotImplementedError
92 |
93 | if mask == 'indep':
94 | if obs_prob_high is None:
95 | mask_str = f'indep_{obs_prob:g}'
96 | else:
97 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
98 | elif mask == 'block':
99 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
100 | else:
101 | raise NotImplementedError
102 |
103 | path = '{}_{}_{}'.format(
104 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
105 | '_'.join([
106 | f'tau_{args.tau:g}',
107 | f'maskgen_{args.maskgen}',
108 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}',
109 | mask_str,
110 | ]))
111 |
112 | if not args.resume:
113 | output_dir = Path('results') / 'celeba' / path
114 | print(output_dir)
115 |
116 | if mask == 'indep':
117 | data = IndepMaskedCelebA(
118 | data_dir=args.data_dir,
119 | obs_prob=obs_prob, obs_prob_high=obs_prob_high)
120 | elif mask == 'block':
121 | data = BlockMaskedCelebA(
122 | data_dir=args.data_dir, block_len=block_len)
123 |
124 | n_gpu = torch.cuda.device_count()
125 | print(f'Use {n_gpu} GPUs.')
126 | data_gen = parallelize(ConvDataGenerator())
127 | mask_gen = parallelize(ConvMaskGenerator(hard_sigmoid=hard_sigmoid))
128 | imputer = UNetImputer().to(device)
129 |
130 | data_critic = parallelize(ConvCritic(n_channels=3))
131 | mask_critic = parallelize(ConvCritic(n_channels=1))
132 | impu_critic = parallelize(ConvCritic(n_channels=3))
133 |
134 | misgan_impute(args, data_gen, mask_gen, imputer,
135 | data_critic, mask_critic, impu_critic,
136 | data, output_dir, checkpoint)
137 |
138 |
139 | if __name__ == '__main__':
140 | main()
141 |
--------------------------------------------------------------------------------
/src/fcnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class FullyConnectedNet(nn.Module):
5 | def __init__(self, weights, output_shape=None):
6 | super().__init__()
7 | n_layers = len(weights) - 1
8 |
9 | layers = [nn.Linear(weights[0], weights[1])]
10 | for i in range(1, n_layers):
11 | layers.extend([nn.ReLU(), nn.Linear(weights[i], weights[i + 1])])
12 |
13 | self.model = nn.Sequential(*layers)
14 | self.output_shape = output_shape
15 |
16 | def forward(self, input):
17 | output = self.model(input.view(input.shape[0], -1))
18 | if self.output_shape is not None:
19 | output = output.view(self.output_shape)
20 | return output
21 |
--------------------------------------------------------------------------------
/src/fid.py:
--------------------------------------------------------------------------------
1 | """Code adapted from https://github.com/mseitzer/pytorch-fid
2 | """
3 | from pathlib import Path
4 | import torch
5 | import numpy as np
6 | from scipy import linalg
7 | import time
8 | import sys
9 | from inception import InceptionV3
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 | FEATURE_DIM = 2048
16 | RESIZE = 299
17 |
18 |
19 | def get_activations(image_iterator, images, model, verbose=True):
20 | """Calculates the activations of the pool_3 layer for all images.
21 |
22 | Params:
23 | -- image_iterator
24 | : A generator that generates a batch of images at a time.
25 | -- images : Number of images that will be generated by
26 | image_iterator.
27 | -- model : Instance of inception model
28 | -- verbose : If set to True, progress (images processed and elapsed
29 | time) is reported while stdout is a terminal.
30 | Returns:
31 | -- A numpy array of dimension (num images, dims) that contains the
32 | activations of the given tensor when feeding inception with the
33 | query tensor.
34 | """
35 | model.eval()
36 |
37 | if not sys.stdout.isatty():
38 | verbose = False
39 |
40 | pred_arr = np.empty((images, FEATURE_DIM))
41 | end = 0
42 | t0 = time.time()
43 |
44 | for batch in image_iterator:
45 | if not isinstance(batch, torch.Tensor):
46 | batch = batch[0]
47 | start = end
48 | batch_size = batch.shape[0]
49 | end = start + batch_size
50 |
51 | with torch.no_grad():
52 | batch = batch.to(device)
53 | pred = model(batch)[0]
54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1)
55 | pred_arr[start:end] = batch_feature
56 |
57 | if verbose:
58 | print('\rProcessed: {} time: {:.2f}'.format(
59 | end, time.time() - t0), end='', flush=True)
60 |
61 | assert end == images
62 |
63 | if verbose:
64 | print(' done')
65 |
66 | return pred_arr
67 |
68 |
69 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
70 | """Numpy implementation of the Frechet Distance.
71 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
72 | and X_2 ~ N(mu_2, C_2) is
73 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
74 |
75 | Stable version by Dougal J. Sutherland.
76 |
77 | Params:
78 | -- mu1 : The sample mean over activations of a layer of the
79 | inception net (computed from the output of 'get_activations')
80 | for generated samples.
81 | -- mu2 : The sample mean over activations, precalculated on a
82 | representative data set.
83 | -- sigma1: The covariance matrix over activations for generated samples.
84 | -- sigma2: The covariance matrix over activations, precalculated on a
85 | representative data set.
86 |
87 | Returns:
88 | -- : The Frechet Distance.
89 | """
90 |
91 | mu1 = np.atleast_1d(mu1)
92 | mu2 = np.atleast_1d(mu2)
93 |
94 | sigma1 = np.atleast_2d(sigma1)
95 | sigma2 = np.atleast_2d(sigma2)
96 |
97 | assert mu1.shape == mu2.shape, \
98 | 'Training and test mean vectors have different lengths'
99 | assert sigma1.shape == sigma2.shape, \
100 | 'Training and test covariances have different dimensions'
101 |
102 | diff = mu1 - mu2
103 |
104 | # Product might be almost singular
105 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
106 | if not np.isfinite(covmean).all():
107 | msg = ('fid calculation produces singular product; '
108 | 'adding %s to diagonal of cov estimates') % eps
109 | print(msg)
110 | offset = np.eye(sigma1.shape[0]) * eps
111 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
112 |
113 | # Numerical error might give slight imaginary component
114 | if np.iscomplexobj(covmean):
115 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
116 | m = np.max(np.abs(covmean.imag))
117 | raise ValueError('Imaginary component {}'.format(m))
118 | covmean = covmean.real
119 |
120 | tr_covmean = np.trace(covmean)
121 |
122 | return (diff.dot(diff) + np.trace(sigma1) +
123 | np.trace(sigma2) - 2 * tr_covmean)
124 |
125 |
126 | def calculate_activation_statistics(image_iterator, images, model,
127 | verbose=False):
128 | """Calculation of the statistics used by the FID.
129 | Params:
130 | -- image_iterator
131 | : A generator that generates a batch of images at a time.
132 | -- images : Number of images that will be generated by
133 | image_iterator.
134 | -- model : Instance of inception model
135 | -- verbose : If set to True, progress (images processed and elapsed
136 | time) is reported while stdout is a terminal.
137 | Returns:
138 | -- mu : The mean over samples of the activations of the pool_3 layer of
139 | the inception model.
140 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
141 | the inception model.
142 | """
143 | act = get_activations(image_iterator, images, model, verbose)
144 | mu = np.mean(act, axis=0)
145 | sigma = np.cov(act, rowvar=False)
146 | return mu, sigma
147 |
148 |
149 | class FID:
150 | def __init__(self, data_name, verbose=True):
151 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM]
152 | model = InceptionV3([block_idx], RESIZE).to(device)
153 | self.verbose = verbose
154 |
155 | stats_dir = Path('fid_stats')
156 | stats_file = stats_dir / '{}_act_{}_{}.npz'.format(
157 | data_name, FEATURE_DIM, RESIZE)
158 |
159 | try:
160 | f = np.load(str(stats_file))
161 | mu, sigma = f['mu'], f['sigma']
162 | f.close()
163 | except FileNotFoundError:
164 | data_loader, images = self.complete_data()
165 | mu, sigma = calculate_activation_statistics(
166 | data_loader, images, model, verbose)
167 | stats_dir.mkdir(parents=True, exist_ok=True)
168 | np.savez(stats_file, mu=mu, sigma=sigma)
169 |
170 | self.model = model
171 | self.stats = mu, sigma
172 |
173 | def complete_data(self):
174 | raise NotImplementedError
175 |
176 | def fid(self, image_iterator, images):
177 | mu, sigma = calculate_activation_statistics(
178 | image_iterator, images, self.model, verbose=self.verbose)
179 | return calculate_frechet_distance(mu, sigma, *self.stats)
180 |
181 |
182 | class BaseSampler:
183 | def __init__(self, images):
184 | self.images = images
185 |
186 | def __iter__(self):
187 | self.n = 0
188 | return self
189 |
190 | def __next__(self):
191 | if self.n < self.images:
192 | batch = self.sample()
193 | batch_size = batch.shape[0]
194 | self.n += batch_size
195 | if self.n > self.images:
196 | return batch[:-(self.n - self.images)]
197 | return batch
198 | else:
199 | raise StopIteration
200 |
201 | def sample(self):
202 | raise NotImplementedError
203 |
204 |
205 | class BaseImputationSampler:
206 | def __init__(self, data_loader):
207 | self.data_loader = data_loader
208 |
209 | def __iter__(self):
210 | self.data_iter = iter(self.data_loader)
211 | return self
212 |
213 | def __next__(self):
214 | data, mask = next(self.data_iter)[:2]
215 | data = data.to(device)
216 | mask = mask.float()[:, None].to(device)
217 | imputed_data = self.impute(data, mask)
218 | return mask * data + (1 - mask) * imputed_data
219 |
220 | def impute(self, data, mask):
221 | raise NotImplementedError
222 |
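For intuition about the formula in calculate_frechet_distance, the one-dimensional case can be worked by hand: the covariance matrices become variances and the matrix square root becomes an ordinary square root. The sketch below assumes scalar Gaussians and uses a hypothetical helper name; it is not a replacement for the matrix implementation above.

```python
import math


def frechet_distance_1d(mu1, var1, mu2, var2):
    """d^2 = (mu1 - mu2)^2 + var1 + var2 - 2 * sqrt(var1 * var2)."""
    return (mu1 - mu2) ** 2 + var1 + var2 - 2 * math.sqrt(var1 * var2)


# N(0, 1) versus N(3, 4): 9 + 1 + 4 - 2 * 2 = 10
print(frechet_distance_1d(0.0, 1.0, 3.0, 4.0))  # 10.0

# Identical distributions are at distance zero.
print(frechet_distance_1d(1.5, 2.0, 1.5, 2.0))  # 0.0
```

In one dimension the last three terms equal the squared difference of the standard deviations, so the distance is the sum of a mean term and a spread term.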
--------------------------------------------------------------------------------
/src/imputer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from fcnet import FullyConnectedNet
4 | from unet import UnetSkipConnectionBlock
5 |
6 |
7 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
8 | class UNet(nn.Module):
9 | def __init__(self, input_nc=3, output_nc=3, ngf=64, layers=5,
10 | norm_layer=nn.BatchNorm2d):
11 | super().__init__()
12 |
13 | mid_layers = layers - 2
14 | fact = 2**mid_layers
15 |
16 | unet_block = UnetSkipConnectionBlock(
17 | ngf * fact, ngf * fact, input_nc=None, submodule=None,
18 | norm_layer=norm_layer, innermost=True)
19 |
20 | for _ in range(mid_layers):
21 | half_fact = fact // 2
22 | unet_block = UnetSkipConnectionBlock(
23 | ngf * half_fact, ngf * fact, input_nc=None,
24 | submodule=unet_block, norm_layer=norm_layer)
25 | fact = half_fact
26 |
27 | unet_block = UnetSkipConnectionBlock(
28 | output_nc, ngf, input_nc=input_nc, submodule=unet_block,
29 | outermost=True, norm_layer=norm_layer)
30 |
31 | self.model = unet_block
32 |
33 | def forward(self, input):
34 | return self.model(input)
35 |
36 |
37 | class Imputer(nn.Module):
38 | def __init__(self):
39 | super().__init__()
40 | self.transform = torch.sigmoid
41 |
42 | def forward(self, input, mask, noise):
43 | net = input * mask + noise * (1 - mask)
44 | net = self.imputer_net(net)
45 | net = self.transform(net)
46 | # NOT replacing observed part with input data for computing
47 | # autoencoding loss.
48 | # return input * mask + net * (1 - mask)
49 | return net
50 |
51 |
52 | class UNetImputer(Imputer):
53 | def __init__(self, *args, **kwargs):
54 | super().__init__()
55 | self.imputer_net = UNet(*args, **kwargs)
56 |
57 |
58 | class FullyConnectedImputer(Imputer):
59 | def __init__(self, *args, **kwargs):
60 | super().__init__()
61 | self.imputer_net = FullyConnectedNet(*args, **kwargs)
62 |
--------------------------------------------------------------------------------
/src/inception.py:
--------------------------------------------------------------------------------
1 | """Code from https://github.com/mseitzer/pytorch-fid
2 | """
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import models
6 |
7 |
8 | class InceptionV3(nn.Module):
9 | """Pretrained InceptionV3 network returning feature maps"""
10 |
11 | # Index of default block of inception to return,
12 | # corresponds to output of final average pooling
13 | DEFAULT_BLOCK_INDEX = 3
14 |
15 | # Maps feature dimensionality to their output blocks indices
16 | BLOCK_INDEX_BY_DIM = {
17 | 64: 0, # First max pooling features
18 | 192: 1, # Second max pooling features
19 | 768: 2, # Pre-aux classifier features
20 | 2048: 3 # Final average pooling features
21 | }
22 |
23 | def __init__(self,
24 | output_blocks=[DEFAULT_BLOCK_INDEX],
25 | resize_input=299, # -1: not resize
26 | normalize_input=True,
27 | requires_grad=False):
28 | """Build pretrained InceptionV3
29 |
30 | Parameters
31 | ----------
32 | output_blocks : list of int
33 | Indices of blocks to return features of. Possible values are:
34 | - 0: corresponds to output of first max pooling
35 | - 1: corresponds to output of second max pooling
36 | - 2: corresponds to output which is fed to aux classifier
37 | - 3: corresponds to output of final average pooling
38 | resize_input : int
39 | Target width and height for bilinear resizing of the input
40 | (default 299); a non-positive value such as -1 disables resizing.
41 | As the network without fully connected layers is fully convolutional,
42 | it should handle inputs of arbitrary size, so resizing might not be strictly needed
43 | normalize_input : bool
44 | If true, normalizes the input to the statistics the pretrained
45 | Inception network expects
46 | requires_grad : bool
47 | If true, parameters of the model require gradient. Possibly useful
48 | for finetuning the network
49 | """
50 | super(InceptionV3, self).__init__()
51 |
52 | self.resize_input = resize_input
53 | self.normalize_input = normalize_input
54 | self.output_blocks = sorted(output_blocks)
55 | self.last_needed_block = max(output_blocks)
56 |
57 | assert self.last_needed_block <= 3, \
58 | 'Last possible output block index is 3'
59 |
60 | self.blocks = nn.ModuleList()
61 |
62 | inception = models.inception_v3(pretrained=True)
63 |
64 | # Block 0: input to maxpool1
65 | block0 = [
66 | inception.Conv2d_1a_3x3,
67 | inception.Conv2d_2a_3x3,
68 | inception.Conv2d_2b_3x3,
69 | nn.MaxPool2d(kernel_size=3, stride=2)
70 | ]
71 | self.blocks.append(nn.Sequential(*block0))
72 |
73 | # Block 1: maxpool1 to maxpool2
74 | if self.last_needed_block >= 1:
75 | block1 = [
76 | inception.Conv2d_3b_1x1,
77 | inception.Conv2d_4a_3x3,
78 | nn.MaxPool2d(kernel_size=3, stride=2)
79 | ]
80 | self.blocks.append(nn.Sequential(*block1))
81 |
82 | # Block 2: maxpool2 to aux classifier
83 | if self.last_needed_block >= 2:
84 | block2 = [
85 | inception.Mixed_5b,
86 | inception.Mixed_5c,
87 | inception.Mixed_5d,
88 | inception.Mixed_6a,
89 | inception.Mixed_6b,
90 | inception.Mixed_6c,
91 | inception.Mixed_6d,
92 | inception.Mixed_6e,
93 | ]
94 | self.blocks.append(nn.Sequential(*block2))
95 |
96 | # Block 3: aux classifier to final avgpool
97 | if self.last_needed_block >= 3:
98 | block3 = [
99 | inception.Mixed_7a,
100 | inception.Mixed_7b,
101 | inception.Mixed_7c,
102 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
103 | ]
104 | self.blocks.append(nn.Sequential(*block3))
105 |
106 | for param in self.parameters():
107 | param.requires_grad = requires_grad
108 |
109 | def forward(self, inp):
110 | """Get Inception feature maps
111 |
112 | Parameters
113 | ----------
114 | inp : torch.autograd.Variable
115 | Input tensor of shape Bx3xHxW. Values are expected to be in
116 | range (0, 1)
117 |
118 | Returns
119 | -------
120 | List of torch.autograd.Variable, corresponding to the selected output
121 | block, sorted ascending by index
122 | """
123 | outp = []
124 | x = inp
125 |
126 | if self.resize_input > 0:
127 | # size = 299
128 | x = F.interpolate(x, size=(self.resize_input, self.resize_input),
129 | mode='bilinear', align_corners=True)
130 |
131 | if self.normalize_input:
132 | x = x.clone()
133 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
134 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
135 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
136 |
137 | for idx, block in enumerate(self.blocks):
138 | x = block(x)
139 | if idx in self.output_blocks:
140 | outp.append(x)
141 |
142 | if idx == self.last_needed_block:
143 | break
144 |
145 | return outp
146 |
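The per-channel affine map in `forward` can be checked numerically. The sketch below assumes, purely for illustration, that a pixel value `p` in [0, 1] was first standardized with the ImageNet channel mean and standard deviation; under that assumption the map is algebraically the same as `(p - 0.5) / 0.5`. The helper names `standardize` and `remap` are hypothetical.

```python
# Per-channel constants that appear in InceptionV3.forward.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def standardize(p, channel):
    """Assumed preprocessing: ImageNet-style standardization of a pixel."""
    return (p - MEAN[channel]) / STD[channel]


def remap(x, channel):
    """The affine map applied in forward when normalize_input is true."""
    return x * (STD[channel] / 0.5) + (MEAN[channel] - 0.5) / 0.5


max_error = 0.0
for channel in range(3):
    for p in (0.0, 0.25, 0.5, 1.0):
        expected = (p - 0.5) / 0.5  # rescaling of [0, 1] to [-1, 1]
        actual = remap(standardize(p, channel), channel)
        max_error = max(max_error, abs(actual - expected))

print(max_error < 1e-9)
```

Whether the inputs are in fact standardized this way depends on the data pipeline feeding the network, which this sketch does not model.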
--------------------------------------------------------------------------------
/src/masked_celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | class MaskedCelebA(datasets.ImageFolder):
8 | def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0):
9 | transform = transforms.Compose([
10 | transforms.CenterCrop(108),
11 | transforms.Resize(size=image_size, interpolation=Image.BICUBIC),
12 | transforms.ToTensor(),
13 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
14 | ])
15 |
16 | super().__init__(data_dir, transform)
17 |
18 | self.rnd = np.random.RandomState(random_seed)
19 | self.image_size = image_size
20 | self.generate_masks()
21 |
22 | def __getitem__(self, index):
23 | image, label = super().__getitem__(index)
24 | return image, self.mask[index], label, index
25 |
26 | def __len__(self):
27 | return super().__len__()
28 |
29 |
30 | class BlockMaskedCelebA(MaskedCelebA):
31 | def __init__(self, block_len=None, *args, **kwargs):
32 | self.block_len = block_len
33 | super().__init__(*args, **kwargs)
34 |
35 | def generate_masks(self):
36 | d0_len = d1_len = self.image_size
37 | d0_min_len = 12
38 | d0_max_len = d0_len - d0_min_len
39 | d1_min_len = 12
40 | d1_max_len = d1_len - d1_min_len
41 |
42 | n_masks = len(self)
43 | self.mask = [None] * n_masks
44 | self.mask_info = [None] * n_masks
45 | for i in range(n_masks):
46 | if self.block_len is None:
47 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len)
48 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len)
49 | else:
50 | d0_mask_len = d1_mask_len = self.block_len
51 |
52 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1)
53 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1)
54 |
55 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8)
56 | mask[d0_start:(d0_start + d0_mask_len),
57 | d1_start:(d1_start + d1_mask_len)] = 1
58 | self.mask[i] = mask
59 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len
60 |
61 |
62 | class IndepMaskedCelebA(MaskedCelebA):
63 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
64 | self.prob = obs_prob
65 | self.prob_high = obs_prob_high
66 | super().__init__(*args, **kwargs)
67 |
68 | def generate_masks(self):
69 | imsize = self.image_size
70 | prob = self.prob
71 | prob_high = self.prob_high
72 | n_masks = len(self)
73 | self.mask = [None] * n_masks
74 | for i in range(n_masks):
75 | if prob_high is None:
76 | p = prob
77 | else:
78 | p = self.rnd.uniform(prob, prob_high)
79 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p)
80 |
--------------------------------------------------------------------------------
/src/masked_mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torchvision import datasets, transforms
4 | import numpy as np
5 |
6 |
7 | class MaskedMNIST(Dataset):
8 | def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0):
9 | self.rnd = np.random.RandomState(random_seed)
10 | self.image_size = image_size
11 | if image_size == 28:
12 | self.data = datasets.MNIST(
13 | data_dir, train=True, download=True,
14 | transform=transforms.ToTensor())
15 | else:
16 | self.data = datasets.MNIST(
17 | data_dir, train=True, download=True,
18 | transform=transforms.Compose([
19 | transforms.Resize(image_size), transforms.ToTensor()]))
20 | self.generate_masks()
21 |
22 | def __getitem__(self, index):
23 | image, label = self.data[index]
24 | return image, self.mask[index], label, index
25 |
26 | def __len__(self):
27 | return len(self.data)
28 |
29 | def generate_masks(self):
30 | raise NotImplementedError
31 |
32 |
33 | class BlockMaskedMNIST(MaskedMNIST):
34 | def __init__(self, block_len=None, *args, **kwargs):
35 | self.block_len = block_len
36 | super().__init__(*args, **kwargs)
37 |
38 | def generate_masks(self):
39 | d0_len = d1_len = self.image_size
40 | d0_min_len = 7
41 | d0_max_len = d0_len - d0_min_len
42 | d1_min_len = 7
43 | d1_max_len = d1_len - d1_min_len
44 |
45 | n_masks = len(self)
46 | self.mask = [None] * n_masks
47 | self.mask_info = [None] * n_masks
48 | for i in range(n_masks):
49 | if self.block_len is None:
50 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len)
51 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len)
52 | else:
53 | d0_mask_len = d1_mask_len = self.block_len
54 |
55 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1)
56 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1)
57 |
58 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8)
59 | mask[d0_start:(d0_start + d0_mask_len),
60 | d1_start:(d1_start + d1_mask_len)] = 1
61 | self.mask[i] = mask
62 | self.mask_info[i] = d0_start, d1_start, d0_mask_len, d1_mask_len
63 |
64 |
65 | class IndepMaskedMNIST(MaskedMNIST):
66 | def __init__(self, obs_prob=.2, obs_prob_high=None, *args, **kwargs):
67 | self.prob = obs_prob
68 | self.prob_high = obs_prob_high
69 | super().__init__(*args, **kwargs)
70 |
71 | def generate_masks(self):
72 | imsize = self.image_size
73 | prob = self.prob
74 | prob_high = self.prob_high
75 | n_masks = len(self)
76 | self.mask = [None] * n_masks
77 | for i in range(n_masks):
78 | if prob_high is None:
79 | p = prob
80 | else:
81 | p = self.rnd.uniform(prob, prob_high)
82 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p)
83 |
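The block-mask construction in `generate_masks` can be sketched without PyTorch. This is an illustrative stand-in, assuming nested lists in place of a `torch.uint8` tensor and the standard-library `random.Random` in place of `np.random.RandomState`; `make_block_mask` is a hypothetical helper. As in the original, the sampled side length excludes the upper bound, and the block of ones is placed so that it always fits inside the image.

```python
import random


def make_block_mask(image_size, min_len, rng, block_len=None):
    """Return (mask, info) for a square image with one rectangular block of ones."""
    max_len = image_size - min_len
    if block_len is None:
        # randrange excludes the upper bound, like RandomState.randint.
        h = rng.randrange(min_len, max_len)
        w = rng.randrange(min_len, max_len)
    else:
        h = w = block_len
    top = rng.randrange(0, image_size - h + 1)
    left = rng.randrange(0, image_size - w + 1)
    mask = [[0] * image_size for _ in range(image_size)]
    for r in range(top, top + h):
        for c in range(left, left + w):
            mask[r][c] = 1
    return mask, (top, left, h, w)


rng = random.Random(0)
mask, (top, left, h, w) = make_block_mask(28, 7, rng)
observed = sum(map(sum, mask))
print(observed == h * w)
```

The returned tuple mirrors the `mask_info` entries stored alongside each mask: the two start offsets followed by the two block lengths.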
--------------------------------------------------------------------------------
/src/misgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | import time
5 | import pylab as plt
6 | import seaborn as sns
7 | from collections import defaultdict
8 | from plot import plot_samples
9 | from utils import CriticUpdater, mkdir, mask_data
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
17 | output_dir, checkpoint=None):
18 | n_critic = args.n_critic
19 | gp_lambda = args.gp_lambda
20 | batch_size = args.batch_size
21 | nz = args.n_latent
22 | epochs = args.epoch
23 | plot_interval = args.plot_interval
24 | save_interval = args.save_interval
25 | alpha = args.alpha
26 | tau = args.tau
27 |
28 | gen_data_dir = mkdir(output_dir / 'img')
29 | gen_mask_dir = mkdir(output_dir / 'mask')
30 | log_dir = mkdir(output_dir / 'log')
31 | model_dir = mkdir(output_dir / 'model')
32 |
33 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
34 | drop_last=True)
35 | n_batch = len(data_loader)
36 |
37 | data_noise = torch.FloatTensor(batch_size, nz).to(device)
38 | mask_noise = torch.FloatTensor(batch_size, nz).to(device)
39 |
40 | # Interpolation coefficient
41 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)
42 |
43 | # For computing gradient penalty
44 | ones = torch.ones(batch_size).to(device)
45 |
46 | lrate = 1e-4
47 | # lrate = 1e-5
48 | data_gen_optimizer = optim.Adam(
49 | data_gen.parameters(), lr=lrate, betas=(.5, .9))
50 | mask_gen_optimizer = optim.Adam(
51 | mask_gen.parameters(), lr=lrate, betas=(.5, .9))
52 |
53 | data_critic_optimizer = optim.Adam(
54 | data_critic.parameters(), lr=lrate, betas=(.5, .9))
55 | mask_critic_optimizer = optim.Adam(
56 | mask_critic.parameters(), lr=lrate, betas=(.5, .9))
57 |
58 | update_data_critic = CriticUpdater(
59 | data_critic, data_critic_optimizer, eps, ones, gp_lambda)
60 | update_mask_critic = CriticUpdater(
61 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda)
62 |
63 | start_epoch = 0
64 | critic_updates = 0
65 | log = defaultdict(list)
66 |
67 | if checkpoint:
68 | data_gen.load_state_dict(checkpoint['data_gen'])
69 | mask_gen.load_state_dict(checkpoint['mask_gen'])
70 | data_critic.load_state_dict(checkpoint['data_critic'])
71 | mask_critic.load_state_dict(checkpoint['mask_critic'])
72 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
73 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
74 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
75 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
76 | start_epoch = checkpoint['epoch']
77 | critic_updates = checkpoint['critic_updates']
78 | log = checkpoint['log']
79 |
80 | with (log_dir / 'gpu.txt').open('a') as f:
81 | print(torch.cuda.device_count(), start_epoch, file=f)
82 |
83 | def save_model(path, epoch, critic_updates=0):
84 | torch.save({
85 | 'data_gen': data_gen.state_dict(),
86 | 'mask_gen': mask_gen.state_dict(),
87 | 'data_critic': data_critic.state_dict(),
88 | 'mask_critic': mask_critic.state_dict(),
89 | 'data_gen_opt': data_gen_optimizer.state_dict(),
90 | 'mask_gen_opt': mask_gen_optimizer.state_dict(),
91 | 'data_critic_opt': data_critic_optimizer.state_dict(),
92 | 'mask_critic_opt': mask_critic_optimizer.state_dict(),
93 | 'epoch': epoch + 1,
94 | 'critic_updates': critic_updates,
95 | 'log': log,
96 | 'args': args,
97 | }, str(path))
98 |
99 | sns.set()
100 |
101 | start = time.time()
102 | epoch_start = start
103 |
104 | for epoch in range(start_epoch, epochs):
105 | sum_data_loss, sum_mask_loss = 0, 0
106 | for real_data, real_mask, _, _ in data_loader:
107 | # Assume real_data and real_mask have the same number of channels.
108 | # Could be modified to handle multi-channel images and
109 | # single-channel masks.
110 | real_mask = real_mask.float()[:, None]
111 |
112 | real_data = real_data.to(device)
113 | real_mask = real_mask.to(device)
114 |
115 | masked_real_data = mask_data(real_data, real_mask, tau)
116 |
117 | # Update discriminators' parameters
118 | data_noise.normal_()
119 | mask_noise.normal_()
120 |
121 | fake_data = data_gen(data_noise)
122 | fake_mask = mask_gen(mask_noise)
123 |
124 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
125 |
126 | update_data_critic(masked_real_data, masked_fake_data)
127 | update_mask_critic(real_mask, fake_mask)
128 |
129 | sum_data_loss += update_data_critic.loss_value
130 | sum_mask_loss += update_mask_critic.loss_value
131 |
132 | critic_updates += 1
133 |
134 | if critic_updates == n_critic:
135 | critic_updates = 0
136 |
137 | # Update generators' parameters
138 |
139 | for p in data_critic.parameters():
140 | p.requires_grad_(False)
141 | for p in mask_critic.parameters():
142 | p.requires_grad_(False)
143 |
144 | data_gen.zero_grad()
145 | mask_gen.zero_grad()
146 |
147 | data_noise.normal_()
148 | mask_noise.normal_()
149 |
150 | fake_data = data_gen(data_noise)
151 | fake_mask = mask_gen(mask_noise)
152 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
153 |
154 | data_loss = -data_critic(masked_fake_data).mean()
155 | data_loss.backward(retain_graph=True)
156 | data_gen_optimizer.step()
157 |
158 | mask_loss = -mask_critic(fake_mask).mean()
159 | (mask_loss + data_loss * alpha).backward()
160 | mask_gen_optimizer.step()
161 |
162 | for p in data_critic.parameters():
163 | p.requires_grad_(True)
164 | for p in mask_critic.parameters():
165 | p.requires_grad_(True)
166 |
167 | mean_data_loss = sum_data_loss / n_batch
168 | mean_mask_loss = sum_mask_loss / n_batch
169 | log['data loss', 'data_loss'].append(mean_data_loss)
170 | log['mask loss', 'mask_loss'].append(mean_mask_loss)
171 |
172 | for (name, shortname), trace in log.items():
173 | fig, ax = plt.subplots(figsize=(6, 4))
174 | ax.plot(trace)
175 | ax.set_ylabel(name)
176 | ax.set_xlabel('epoch')
177 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
178 | plt.close(fig)
179 |
180 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
181 | print(f'[{epoch:4}] {mean_data_loss:12.4f} {mean_mask_loss:12.4f}')
182 |
183 | filename = f'{epoch:04d}.png'
184 |
185 | data_gen.eval()
186 | mask_gen.eval()
187 |
188 | with torch.no_grad():
189 | data_noise.normal_()
190 | mask_noise.normal_()
191 |
192 | data_samples = data_gen(data_noise)
193 | plot_samples(data_samples, str(gen_data_dir / filename))
194 |
195 | mask_samples = mask_gen(mask_noise)
196 | plot_samples(mask_samples, str(gen_mask_dir / filename))
197 |
198 | data_gen.train()
199 | mask_gen.train()
200 |
201 | if save_interval > 0 and (epoch + 1) % save_interval == 0:
202 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)
203 |
204 | epoch_end = time.time()
205 | time_elapsed = epoch_end - start
206 | epoch_time = epoch_end - epoch_start
207 | epoch_start = epoch_end
208 | with (log_dir / 'time.txt').open('a') as f:
209 | print(epoch, epoch_time, time_elapsed, file=f)
210 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)
211 |
212 | print(output_dir)
213 |
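The update schedule in the training loop above can be isolated in a few lines: the critics are updated on every batch, while the generators are updated only when the running counter reaches `n_critic`, and the counter is not reset at epoch boundaries. The function below is a hypothetical stand-alone simulation that only records which batches would trigger a generator step; it performs no training.

```python
def generator_steps(n_epochs, n_batch, n_critic, critic_updates=0):
    """Return (epoch, batch) pairs at which the generators would be updated."""
    steps = []
    for epoch in range(n_epochs):
        for batch in range(n_batch):
            critic_updates += 1  # one critic update per batch
            if critic_updates == n_critic:
                critic_updates = 0
                steps.append((epoch, batch))
    return steps


steps = generator_steps(n_epochs=2, n_batch=7, n_critic=5)
print(steps)
```

Because the counter carries over between epochs (and is also stored in the checkpoint), the generator steps need not fall on the same batch index in every epoch.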
--------------------------------------------------------------------------------
/src/misgan_impute.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | from torch.utils.data import DataLoader
4 | import time
5 | import pylab as plt
6 | import seaborn as sns
7 | from collections import defaultdict
8 | from plot import plot_grid, plot_samples
9 | from utils import CriticUpdater, mask_norm, mkdir, mask_data
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def misgan_impute(args, data_gen, mask_gen, imputer,
17 | data_critic, mask_critic, impu_critic,
18 | data, output_dir, checkpoint=None):
19 | n_critic = args.n_critic
20 | gp_lambda = args.gp_lambda
21 | batch_size = args.batch_size
22 | nz = args.n_latent
23 | epochs = args.epoch
24 | plot_interval = args.plot_interval
25 | save_model_interval = args.save_interval
26 | alpha = args.alpha
27 | beta = args.beta
28 | gamma = args.gamma
29 | tau = args.tau
30 | update_all_networks = not args.imputeronly
31 |
32 | gen_data_dir = mkdir(output_dir / 'img')
33 | gen_mask_dir = mkdir(output_dir / 'mask')
34 | impute_dir = mkdir(output_dir / 'impute')
35 | log_dir = mkdir(output_dir / 'log')
36 | model_dir = mkdir(output_dir / 'model')
37 |
38 | data_loader = DataLoader(data, batch_size=batch_size, shuffle=True,
39 | drop_last=True, num_workers=args.workers)
40 | n_batch = len(data_loader)
41 | data_shape = data[0][0].shape
42 |
43 | data_noise = torch.FloatTensor(batch_size, nz).to(device)
44 | mask_noise = torch.FloatTensor(batch_size, nz).to(device)
45 | impu_noise = torch.FloatTensor(batch_size, *data_shape).to(device)
46 |
47 | # Interpolation coefficient
48 | eps = torch.FloatTensor(batch_size, 1, 1, 1).to(device)
49 |
50 | # For computing gradient penalty
51 | ones = torch.ones(batch_size).to(device)
52 |
53 | lrate = 1e-4
54 | imputer_lrate = 2e-4
55 | data_gen_optimizer = optim.Adam(
56 | data_gen.parameters(), lr=lrate, betas=(.5, .9))
57 | mask_gen_optimizer = optim.Adam(
58 | mask_gen.parameters(), lr=lrate, betas=(.5, .9))
59 | imputer_optimizer = optim.Adam(
60 | imputer.parameters(), lr=imputer_lrate, betas=(.5, .9))
61 |
62 | data_critic_optimizer = optim.Adam(
63 | data_critic.parameters(), lr=lrate, betas=(.5, .9))
64 | mask_critic_optimizer = optim.Adam(
65 | mask_critic.parameters(), lr=lrate, betas=(.5, .9))
66 | impu_critic_optimizer = optim.Adam(
67 | impu_critic.parameters(), lr=imputer_lrate, betas=(.5, .9))
68 |
69 | update_data_critic = CriticUpdater(
70 | data_critic, data_critic_optimizer, eps, ones, gp_lambda)
71 | update_mask_critic = CriticUpdater(
72 | mask_critic, mask_critic_optimizer, eps, ones, gp_lambda)
73 | update_impu_critic = CriticUpdater(
74 | impu_critic, impu_critic_optimizer, eps, ones, gp_lambda)
75 |
76 | start_epoch = 0
77 | critic_updates = 0
78 | log = defaultdict(list)
79 |
80 | if args.resume:
81 | data_gen.load_state_dict(checkpoint['data_gen'])
82 | mask_gen.load_state_dict(checkpoint['mask_gen'])
83 | imputer.load_state_dict(checkpoint['imputer'])
84 | data_critic.load_state_dict(checkpoint['data_critic'])
85 | mask_critic.load_state_dict(checkpoint['mask_critic'])
86 | impu_critic.load_state_dict(checkpoint['impu_critic'])
87 | data_gen_optimizer.load_state_dict(checkpoint['data_gen_opt'])
88 | mask_gen_optimizer.load_state_dict(checkpoint['mask_gen_opt'])
89 | imputer_optimizer.load_state_dict(checkpoint['imputer_opt'])
90 | data_critic_optimizer.load_state_dict(checkpoint['data_critic_opt'])
91 | mask_critic_optimizer.load_state_dict(checkpoint['mask_critic_opt'])
92 | impu_critic_optimizer.load_state_dict(checkpoint['impu_critic_opt'])
93 | start_epoch = checkpoint['epoch']
94 | critic_updates = checkpoint['critic_updates']
95 | log = checkpoint['log']
96 | elif args.pretrain:
97 | pretrain = torch.load(args.pretrain, map_location='cpu')
98 | data_gen.load_state_dict(pretrain['data_gen'])
99 | mask_gen.load_state_dict(pretrain['mask_gen'])
100 | data_critic.load_state_dict(pretrain['data_critic'])
101 | mask_critic.load_state_dict(pretrain['mask_critic'])
102 | if 'imputer' in pretrain:
103 | imputer.load_state_dict(pretrain['imputer'])
104 | impu_critic.load_state_dict(pretrain['impu_critic'])
105 |
106 | with (log_dir / 'gpu.txt').open('a') as f:
107 | print(torch.cuda.device_count(), start_epoch, file=f)
108 |
109 | def save_model(path, epoch, critic_updates=0):
110 | torch.save({
111 | 'data_gen': data_gen.state_dict(),
112 | 'mask_gen': mask_gen.state_dict(),
113 | 'imputer': imputer.state_dict(),
114 | 'data_critic': data_critic.state_dict(),
115 | 'mask_critic': mask_critic.state_dict(),
116 | 'impu_critic': impu_critic.state_dict(),
117 | 'data_gen_opt': data_gen_optimizer.state_dict(),
118 | 'mask_gen_opt': mask_gen_optimizer.state_dict(),
119 | 'imputer_opt': imputer_optimizer.state_dict(),
120 | 'data_critic_opt': data_critic_optimizer.state_dict(),
121 | 'mask_critic_opt': mask_critic_optimizer.state_dict(),
122 | 'impu_critic_opt': impu_critic_optimizer.state_dict(),
123 | 'epoch': epoch + 1,
124 | 'critic_updates': critic_updates,
125 | 'log': log,
126 | 'args': args,
127 | }, str(path))
128 |
129 | sns.set()
130 | start = time.time()
131 | epoch_start = start
132 |
133 | for epoch in range(start_epoch, epochs):
134 | sum_data_loss, sum_mask_loss, sum_impu_loss = 0, 0, 0
135 | for real_data, real_mask, _, index in data_loader:
136 | # Assume real_data and real_mask have the same number of channels.
137 | # Could be modified to handle multi-channel images and
138 | # single-channel masks.
139 | real_mask = real_mask.float()[:, None]
140 |
141 | real_data = real_data.to(device)
142 | real_mask = real_mask.to(device)
143 |
144 | masked_real_data = mask_data(real_data, real_mask, tau)
145 |
146 | # Update discriminators' parameters
147 | data_noise.normal_()
148 | fake_data = data_gen(data_noise)
149 |
150 | impu_noise.uniform_()
151 | imputed_data = imputer(real_data, real_mask, impu_noise)
152 | masked_imputed_data = mask_data(real_data, real_mask, imputed_data)
153 |
154 | if update_all_networks:
155 | mask_noise.normal_()
156 | fake_mask = mask_gen(mask_noise)
157 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
158 | update_data_critic(masked_real_data, masked_fake_data)
159 | update_mask_critic(real_mask, fake_mask)
160 |
161 | sum_data_loss += update_data_critic.loss_value
162 | sum_mask_loss += update_mask_critic.loss_value
163 |
164 | update_impu_critic(fake_data, masked_imputed_data)
165 | sum_impu_loss += update_impu_critic.loss_value
166 |
167 | critic_updates += 1
168 |
169 | if critic_updates == n_critic:
170 | critic_updates = 0
171 |
172 | # Update generators' parameters
173 | if update_all_networks:
174 | for p in data_critic.parameters():
175 | p.requires_grad_(False)
176 | for p in mask_critic.parameters():
177 | p.requires_grad_(False)
178 | for p in impu_critic.parameters():
179 | p.requires_grad_(False)
180 |
181 | data_noise.normal_()
182 | fake_data = data_gen(data_noise)
183 |
184 | if update_all_networks:
185 | mask_noise.normal_()
186 | fake_mask = mask_gen(mask_noise)
187 | masked_fake_data = mask_data(fake_data, fake_mask, tau)
188 | data_loss = -data_critic(masked_fake_data).mean()
189 | mask_loss = -mask_critic(fake_mask).mean()
190 |
191 | impu_noise.uniform_()
192 | imputed_data = imputer(real_data, real_mask, impu_noise)
193 | masked_imputed_data = mask_data(real_data, real_mask,
194 | imputed_data)
195 | impu_loss = -impu_critic(masked_imputed_data).mean()
196 |
197 | if update_all_networks:
198 | mask_gen.zero_grad()
199 | (mask_loss + data_loss * alpha).backward(retain_graph=True)
200 | mask_gen_optimizer.step()
201 |
202 | data_gen.zero_grad()
203 | (data_loss + impu_loss * beta).backward(retain_graph=True)
204 | data_gen_optimizer.step()
205 |
206 | imputer.zero_grad()
207 | if gamma > 0:
208 | imputer_mismatch_loss = mask_norm(
209 | (imputed_data - real_data)**2, real_mask)
210 | (impu_loss + imputer_mismatch_loss * gamma).backward()
211 | else:
212 | impu_loss.backward()
213 | imputer_optimizer.step()
214 |
215 | if update_all_networks:
216 | for p in data_critic.parameters():
217 | p.requires_grad_(True)
218 | for p in mask_critic.parameters():
219 | p.requires_grad_(True)
220 | for p in impu_critic.parameters():
221 | p.requires_grad_(True)
222 |
223 | if update_all_networks:
224 | mean_data_loss = sum_data_loss / n_batch
225 | mean_mask_loss = sum_mask_loss / n_batch
226 | log['data loss', 'data_loss'].append(mean_data_loss)
227 | log['mask loss', 'mask_loss'].append(mean_mask_loss)
228 | mean_impu_loss = sum_impu_loss / n_batch
229 | log['imputer loss', 'impu_loss'].append(mean_impu_loss)
230 |
231 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
232 | if update_all_networks:
233 | print('[{:4}] {:12.4f} {:12.4f} {:12.4f}'.format(
234 | epoch, mean_data_loss, mean_mask_loss, mean_impu_loss))
235 | else:
236 | print('[{:4}] {:12.4f}'.format(epoch, mean_impu_loss))
237 |
238 | filename = f'{epoch:04d}.png'
239 | with torch.no_grad():
240 | data_gen.eval()
241 | mask_gen.eval()
242 | imputer.eval()
243 |
244 | data_noise.normal_()
245 | mask_noise.normal_()
246 |
247 | data_samples = data_gen(data_noise)
248 | plot_samples(data_samples, str(gen_data_dir / filename))
249 |
250 | mask_samples = mask_gen(mask_noise)
251 | plot_samples(mask_samples, str(gen_mask_dir / filename))
252 |
253 | # Plot imputation results
254 | impu_noise.uniform_()
255 | imputed_data = imputer(real_data, real_mask, impu_noise)
256 | imputed_data = mask_data(real_data, real_mask, imputed_data)
257 | if hasattr(data, 'mask_info'):
258 | bbox = [data.mask_info[idx] for idx in index]
259 | else:
260 | bbox = None
261 | plot_grid(imputed_data, bbox, gap=2,
262 | save_file=str(impute_dir / filename))
263 |
264 | data_gen.train()
265 | mask_gen.train()
266 | imputer.train()
267 |
268 | for (name, shortname), trace in log.items():
269 | fig, ax = plt.subplots(figsize=(6, 4))
270 | ax.plot(trace)
271 | ax.set_ylabel(name)
272 | ax.set_xlabel('epoch')
273 | fig.savefig(str(log_dir / f'{shortname}.png'), dpi=300)
274 | plt.close(fig)
275 |
276 | if save_model_interval > 0 and (epoch + 1) % save_model_interval == 0:
277 | save_model(model_dir / f'{epoch:04d}.pth', epoch, critic_updates)
278 |
279 | epoch_end = time.time()
280 | time_elapsed = epoch_end - start
281 | epoch_time = epoch_end - epoch_start
282 | epoch_start = epoch_end
283 | with (log_dir / 'epoch-time.txt').open('a') as f:
284 | print(epoch, epoch_time, time_elapsed, file=f)
285 | save_model(log_dir / 'checkpoint.pth', epoch, critic_updates)
286 |
287 | print(output_dir)
288 |
--------------------------------------------------------------------------------
/src/mnist_critic.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class ConvCritic(nn.Module):
5 | def __init__(self):
6 | super().__init__()
7 |
8 | self.DIM = 64
9 | main = nn.Sequential(
10 | nn.Conv2d(1, self.DIM, 5, stride=2, padding=2),
11 | nn.ReLU(True),
12 | nn.Conv2d(self.DIM, 2 * self.DIM, 5, stride=2, padding=2),
13 | nn.ReLU(True),
14 | nn.Conv2d(2 * self.DIM, 4 * self.DIM, 5, stride=2, padding=2),
15 | nn.ReLU(True),
16 | )
17 | self.main = main
18 | self.output = nn.Linear(4 * 4 * 4 * self.DIM, 1)
19 |
20 | def forward(self, input):
21 | input = input.view(-1, 1, 28, 28)
22 | net = self.main(input)
23 | net = net.view(-1, 4 * 4 * 4 * self.DIM)
24 | net = self.output(net)
25 | return net.view(-1)
26 |
27 |
28 | class FCCritic(nn.Module):
29 | def __init__(self):
30 | super().__init__()
31 |
32 | self.in_dim = 784
33 | self.main = nn.Sequential(
34 | nn.Linear(self.in_dim, 512),
35 | nn.ReLU(True),
36 | nn.Linear(512, 256),
37 | nn.ReLU(True),
38 | nn.Linear(256, 128),
39 | nn.ReLU(True),
40 | nn.Linear(128, 1),
41 | )
42 |
43 | def forward(self, input):
44 | input = input.view(input.size(0), -1)
45 | out = self.main(input)
46 | return out.view(-1)
47 |
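The `4 * 4 * 4 * self.DIM` input size of the final linear layer in `ConvCritic` follows from the standard convolution output-size formula, floor((n + 2p - k) / s) + 1. The sketch below applies it to the three strided convolutions on a 28x28 input; `conv_out` is a hypothetical helper.

```python
def conv_out(n, kernel, stride, padding):
    """Spatial output size of a convolution (dilation 1)."""
    return (n + 2 * padding - kernel) // stride + 1


DIM = 64
size = 28
sizes = [size]
for _ in range(3):  # three Conv2d layers with kernel 5, stride 2, padding 2
    size = conv_out(size, kernel=5, stride=2, padding=2)
    sizes.append(size)

# The last convolution has 4 * DIM output channels.
flat_features = size * size * 4 * DIM
print(sizes, flat_features)
```

The spatial size shrinks from 28 to 14, 7, and finally 4, which is why the flattened feature vector has 4 * 4 spatial positions times 4 * DIM channels.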
--------------------------------------------------------------------------------
/src/mnist_fid.py:
--------------------------------------------------------------------------------
1 | """Code adapted from https://github.com/mseitzer/pytorch-fid
2 | """
3 | import torch
4 | import numpy as np
5 | from scipy import linalg
6 | from torch.utils.data import DataLoader
7 | from torchvision import datasets, transforms
8 | import argparse
9 |
10 | import mnist_model
11 | from mnist_generator import ConvDataGenerator, FCDataGenerator
12 | from mnist_imputer import ComplementImputer, MaskImputer, FixedNoiseDimImputer
13 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
14 | from pathlib import Path
15 |
16 |
17 | use_cuda = torch.cuda.is_available()
18 | device = torch.device('cuda' if use_cuda else 'cpu')
19 |
20 | feature_layer = 0
21 |
22 |
23 | def get_activations(image_generator, images, model, verbose=False):
24 | """Calculates the activations of the selected feature layer for all images.
25 |
26 | Params:
27 | -- image_generator
28 | : A generator that generates a batch of images at a time.
29 | -- images : Number of images that will be generated by
30 | image_generator.
31 | -- model : Instance of the feature network (mnist_model.Net here)
32 | -- verbose : If set to True, the number of processed batches is
33 | reported.
34 | Returns:
35 | -- A numpy array of dimension (num images, dims) that contains the
36 | flattened activations of model.feature[feature_layer] for each
37 | image.
38 | """
39 | model.eval()
40 |
41 | pred_arr = None
42 | end = 0
43 | for i, batch in enumerate(image_generator):
44 | if verbose:
45 | print('\rPropagating batch %d' % (i + 1), end='', flush=True)
46 | start = end
47 | batch_size = batch.shape[0]
48 | end = start + batch_size
49 | batch = batch.to(device)
50 |
51 | with torch.no_grad():
52 | model(batch)
53 | pred = model.feature[feature_layer]
54 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1)
55 | if pred_arr is None:
56 | pred_arr = np.empty((images, batch_feature.shape[1]))
57 | pred_arr[start:end] = batch_feature
58 |
59 | if verbose:
60 | print(' done')
61 |
62 | return pred_arr
63 |
64 |
65 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
66 | """Numpy implementation of the Frechet Distance.
67 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
68 | and X_2 ~ N(mu_2, C_2) is
69 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
70 |
71 | Stable version by Dougal J. Sutherland.
72 |
73 | Params:
74 | -- mu1 : The sample mean over activations of a layer of the
75 | feature network (as returned by the function
76 | 'calculate_activation_statistics') for generated samples.
77 | -- mu2 : The sample mean over activations, precalculated on a
78 | representative data set.
79 | -- sigma1: The covariance matrix over activations for generated samples.
80 | -- sigma2: The covariance matrix over activations, precalculated on a
81 | representative data set.
82 |
83 | Returns:
84 | -- : The Frechet Distance.
85 | """
86 |
87 | mu1 = np.atleast_1d(mu1)
88 | mu2 = np.atleast_1d(mu2)
89 |
90 | sigma1 = np.atleast_2d(sigma1)
91 | sigma2 = np.atleast_2d(sigma2)
92 |
93 | assert mu1.shape == mu2.shape, \
94 | 'Training and test mean vectors have different lengths'
95 | assert sigma1.shape == sigma2.shape, \
96 | 'Training and test covariances have different dimensions'
97 |
98 | diff = mu1 - mu2
99 |
100 | # Product might be almost singular
101 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
102 | if not np.isfinite(covmean).all():
103 | msg = ('fid calculation produces singular product; '
104 | 'adding %s to diagonal of cov estimates') % eps
105 | print(msg)
106 | offset = np.eye(sigma1.shape[0]) * eps
107 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
108 |
109 | # Numerical error might give slight imaginary component
110 | if np.iscomplexobj(covmean):
111 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
112 | m = np.max(np.abs(covmean.imag))
113 | raise ValueError(f'Imaginary component {m}')
114 | covmean = covmean.real
115 |
116 | tr_covmean = np.trace(covmean)
117 |
118 | return (diff.dot(diff) + np.trace(sigma1) +
119 | np.trace(sigma2) - 2 * tr_covmean)
120 |
121 |
122 | def calculate_activation_statistics(image_generator, images, model,
123 | verbose=False, weight=None):
124 | """Calculation of the statistics used by the FID.
125 | Params:
126 | -- image_generator
127 | : A generator that generates a batch of images at a time.
128 | -- images : Number of images that will be generated by
129 | image_generator.
130 | -- model : Instance of the feature network (mnist_model.Net here)
131 | -- verbose : If set to True, the number of processed batches is reported.
132 | -- weight : Optional per-sample weights for the mean and covariance.
133 | Returns:
134 | -- mu : The (weighted) mean over samples of the activations of the
135 | selected feature layer.
136 | -- sigma : The (weighted) covariance matrix of the activations of the
137 | selected feature layer.
138 | """
139 | act = get_activations(image_generator, images, model, verbose)
140 | if weight is None:
141 | mu = np.mean(act, axis=0)
142 | sigma = np.cov(act, rowvar=False)
143 | else:
144 | mu = np.average(act, axis=0, weights=weight)
145 | sigma = np.cov(act, rowvar=False, aweights=weight)
146 | return mu, sigma
147 |
148 |
149 | class MNISTModel:
150 | def __init__(self):
151 | model = mnist_model.Net().to(device)
152 | model.eval()
153 | map_location = None if use_cuda else 'cpu'
154 | model.load_state_dict(
155 | torch.load('mnist.pth', map_location=map_location))
156 |
157 | stats_file = f'mnist_act_{feature_layer}.npz'
158 | try:
159 | f = np.load(stats_file)
160 | m_mnist, s_mnist = f['mu'][:], f['sigma'][:]
161 | f.close()
162 | except FileNotFoundError:
163 | data = datasets.MNIST('data', train=True, download=True,
164 | transform=transforms.ToTensor())
165 | images = len(data)
166 | batch_size = 64
167 | data_loader = DataLoader([image for image, _ in data],
168 | batch_size=batch_size)
169 | m_mnist, s_mnist = calculate_activation_statistics(
170 | data_loader, images, model, verbose=True)
171 | np.savez(stats_file, mu=m_mnist, sigma=s_mnist)
172 |
173 | self.model = model
174 | self.mnist_stats = m_mnist, s_mnist
175 |
176 | def get_feature(self, samples):
177 | self.model(samples)
178 | feature = self.model.feature[feature_layer]
179 | return feature.cpu().numpy().reshape(samples.shape[0], -1)
180 |
181 | def fid(self, features):
182 | mu = np.mean(features, axis=0)
183 | sigma = np.cov(features, rowvar=False)
184 | return calculate_frechet_distance(mu, sigma, *self.mnist_stats)
185 |
186 |
187 | def data_generator_fid(data_generator,
188 | n_samples=60000, batch_size=64, verbose=False):
189 | mnist_model = MNISTModel()
190 | latent_size = 128
191 | data_noise = torch.empty(batch_size, latent_size, device=device)
192 |
193 | with torch.no_grad():
194 | count = 0
195 | features = None
196 | while count < n_samples:
197 | data_noise.normal_()
198 | samples = data_generator(data_noise)
199 | batch_feature = mnist_model.get_feature(samples)
200 |
201 | if features is None:
202 | features = np.empty((n_samples, batch_feature.shape[1]))
203 |
204 | if count + batch_size > n_samples:
205 | batch_size = n_samples - count
206 | features[count:] = batch_feature[:batch_size]
207 | else:
208 | features[count:(count + batch_size)] = batch_feature
209 |
210 | count += batch_size
211 | if verbose:
212 | print(f'\rGenerate images {count}', end='', flush=True)
213 | if verbose:
214 | print(' done')
215 | return mnist_model.fid(features)
216 |
217 |
218 | def imputer_fid(imputer, data, batch_size=64, verbose=False):
219 | mnist_model = MNISTModel()
220 | impu_noise = torch.empty(batch_size, 1, 28, 28, device=device)
221 | data_loader = DataLoader(data, batch_size=batch_size, drop_last=True)
222 | n_samples = len(data_loader) * batch_size
223 |
224 | with torch.no_grad():
225 | start = 0
226 | features = None
227 | for real_data, real_mask, _, index in data_loader:
228 | real_mask = real_mask.float()[:, None]
229 | real_data = real_data.to(device)
230 | real_mask = real_mask.to(device)
231 | impu_noise.uniform_()
232 | imputed_data = imputer(real_data, real_mask, impu_noise)
233 |
234 | batch_feature = mnist_model.get_feature(imputed_data)
235 | if features is None:
236 | features = np.empty((n_samples, batch_feature.shape[1]))
237 | features[start:(start + batch_size)] = batch_feature
238 | start += batch_size
239 | if verbose:
240 | print(f'\rGenerate images {start}', end='', flush=True)
241 | if verbose:
242 | print(' done')
243 | return mnist_model.fid(features)
244 |
245 |
246 | def pretrained_misgan_fid(model_file, samples=60000, batch_size=64):
247 | model = torch.load(model_file, map_location='cpu')
248 | args = model['args']
249 | if args.generator == 'conv':
250 | DataGenerator = ConvDataGenerator
251 | elif args.generator == 'fc':
252 | DataGenerator = FCDataGenerator
253 | data_gen = DataGenerator().to(device)
254 | data_gen.load_state_dict(model['data_gen'])
255 | return data_generator_fid(data_gen, samples, batch_size, verbose=True)
256 |
257 |
258 | def pretrained_imputer_fid(model_file, save_file, batch_size=64):
259 | model = torch.load(model_file, map_location='cpu')
260 | if 'imputer' not in model:
261 | return
262 | args = model['args']
263 |
264 | if args.imputer == 'comp':
265 | Imputer = ComplementImputer
266 | elif args.imputer == 'mask':
267 | Imputer = MaskImputer
268 | elif args.imputer == 'fix':
269 | Imputer = FixedNoiseDimImputer
270 |
271 | hid_lens = [int(n) for n in args.arch.split('-')]
272 | imputer = Imputer(arch=hid_lens).to(device)
273 | imputer.load_state_dict(model['imputer'])
274 |
275 | block_len = args.block_len
276 | if block_len == 0:
277 | block_len = None
278 |
279 | if args.mask == 'indep':
280 | data = IndepMaskedMNIST(obs_prob=args.obs_prob,
281 | obs_prob_high=args.obs_prob_high)
282 | elif args.mask == 'block':
283 | data = BlockMaskedMNIST(block_len=block_len)
284 |
285 | fid = imputer_fid(imputer, data, batch_size, verbose=True)
286 | with save_file.open('w') as f:
287 | print(fid, file=f)
288 | print('imputer fid:', fid)
289 |
290 |
291 | def main():
292 | parser = argparse.ArgumentParser()
293 | parser.add_argument('root_dir')
294 | parser.add_argument('--skip-exist', action='store_true')
295 | args = parser.parse_args()
296 |
297 | skip_exist = args.skip_exist
298 |
299 | root_dir = Path(args.root_dir)
300 | fid_file = root_dir / f'fid-{feature_layer}.txt'
301 | if skip_exist and fid_file.exists():
302 | return
303 | try:
304 | model_file = max((root_dir / 'model').glob('*.pth'))
305 | except ValueError:
306 | return
307 |
308 | fid = pretrained_misgan_fid(model_file)
309 | print(f'{root_dir.name}: {fid}')
310 | with fid_file.open('w') as f:
311 | print(fid, file=f)
312 |
313 | # Compute FID for the imputer if it is in the model
314 | imputer_fid_file = root_dir / f'impute-fid-{feature_layer}.txt'
315 | pretrained_imputer_fid(model_file, imputer_fid_file)
316 |
317 |
318 | if __name__ == '__main__':
319 | main()
320 |
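A minimal, NumPy-only sketch of the Frechet distance formula documented in `calculate_frechet_distance` above. It is not the repository's implementation: instead of `scipy.linalg.sqrtm`, it obtains the trace of the matrix square root from the eigenvalues of `sigma1 @ sigma2`, which assumes both covariance matrices are positive semi-definite. The name `frechet_distance_sketch` and the random features are hypothetical.

```python
import numpy as np


def frechet_distance_sketch(mu1, sigma1, mu2, sigma2):
    """d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 sigma2))."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Assumption: both covariances are PSD, so their product has real,
    # non-negative eigenvalues and Tr(sqrt(sigma1 sigma2)) is the sum of
    # the square roots of those eigenvalues.
    eigvals = np.linalg.eigvals(sigma1 @ sigma2)
    tr_covmean = np.sqrt(np.clip(eigvals.real, 0, None)).sum()
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


# Hypothetical stand-ins for two sets of network activations.
rng = np.random.default_rng(0)
feats_a = rng.normal(size=(500, 4))
feats_b = rng.normal(loc=0.5, size=(500, 4))
stats_a = feats_a.mean(axis=0), np.cov(feats_a, rowvar=False)
stats_b = feats_b.mean(axis=0), np.cov(feats_b, rowvar=False)

same = frechet_distance_sketch(*stats_a, *stats_a)
shifted = frechet_distance_sketch(*stats_a, *stats_b)
print(same, shifted)
```

Comparing a set of statistics with itself should give a value near zero, and the shifted set should score higher. The eigenvalue shortcut yields only the trace, so it cannot reproduce the `eps` fallback used above for nearly singular products.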
--------------------------------------------------------------------------------
/src/mnist_generator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | def add_data_transformer(self):
7 | self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28)
8 |
9 |
10 | def add_mask_transformer(self, temperature=.66, hard_sigmoid=(-.1, 1.1)):
11 | """
12 | hard_sigmoid:
13 | False: sigmoid(x / temperature) only
14 | True: clamp x / temperature to [0, 1] (hard thresholding)
15 | (a, b): sigmoid rescaled to [a, b], then clamped to [0, 1]
16 | """
17 | self.temperature = temperature
18 | self.hard_sigmoid = hard_sigmoid
19 |
20 | view = -1, 1, 28, 28
21 | if hard_sigmoid is False:
22 | self.transform = lambda x: torch.sigmoid(x / temperature).view(*view)
23 | elif hard_sigmoid is True:
24 | self.transform = lambda x: F.hardtanh(
25 | x / temperature, 0, 1).view(*view)
26 | else:
27 | a, b = hard_sigmoid
28 | self.transform = lambda x: F.hardtanh(
29 | torch.sigmoid(x / temperature) * (b - a) + a, 0, 1).view(*view)
30 |
31 |
32 | # Must sub-class ConvGenerator to provide transform()
33 | class ConvGenerator(nn.Module):
34 | def __init__(self, latent_size=128):
35 | super().__init__()
36 |
37 | self.DIM = 64
38 | self.latent_size = latent_size
39 |
40 | self.preprocess = nn.Sequential(
41 | nn.Linear(latent_size, 4 * 4 * 4 * self.DIM),
42 | nn.ReLU(True),
43 | )
44 | self.block1 = nn.Sequential(
45 | nn.ConvTranspose2d(4 * self.DIM, 2 * self.DIM, 5),
46 | nn.ReLU(True),
47 | )
48 | self.block2 = nn.Sequential(
49 | nn.ConvTranspose2d(2 * self.DIM, self.DIM, 5),
50 | nn.ReLU(True),
51 | )
52 | self.deconv_out = nn.ConvTranspose2d(self.DIM, 1, 8, stride=2)
53 |
54 | def forward(self, input):
55 | net = self.preprocess(input)
56 | net = net.view(-1, 4 * self.DIM, 4, 4)
57 | net = self.block1(net)
58 | net = net[:, :, :7, :7]
59 | net = self.block2(net)
60 | net = self.deconv_out(net)
61 | return self.transform(net)
62 |
63 |
64 | # Must sub-class FCGenerator to provide transform()
65 | class FCGenerator(nn.Module):
66 | def __init__(self, latent_size=128):
67 | super().__init__()
68 | self.latent_size = latent_size
69 | self.fc = nn.Sequential(
70 | nn.Linear(latent_size, 256),
71 | nn.ReLU(True),
72 | nn.Linear(256, 512),
73 | nn.ReLU(True),
74 | nn.Linear(512, 784),
75 | )
76 |
77 | def forward(self, input):
78 | net = self.fc(input)
79 | return self.transform(net)
80 |
81 |
82 | class ConvDataGenerator(ConvGenerator):
83 | def __init__(self, latent_size=128):
84 | super().__init__(latent_size=latent_size)
85 | add_data_transformer(self)
86 |
87 |
88 | class FCDataGenerator(FCGenerator):
89 | def __init__(self, latent_size=128):
90 | super().__init__(latent_size=latent_size)
91 | add_data_transformer(self)
92 |
93 |
94 | class ConvMaskGenerator(ConvGenerator):
95 | def __init__(self, latent_size=128, temperature=.66,
96 | hard_sigmoid=(-.1, 1.1)):
97 | super().__init__(latent_size=latent_size)
98 | add_mask_transformer(self, temperature, hard_sigmoid)
99 |
100 |
101 | class FCMaskGenerator(FCGenerator):
102 | def __init__(self, latent_size=128, temperature=.66,
103 | hard_sigmoid=(-.1, 1.1)):
104 | super().__init__(latent_size=latent_size)
105 | add_mask_transformer(self, temperature, hard_sigmoid)
106 |
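The three `hard_sigmoid` modes of `add_mask_transformer` can be sketched without PyTorch. The NumPy function below is an illustrative mirror, not part of the repository; `mask_transform_sketch` is a hypothetical name, and `np.clip` stands in for `F.hardtanh(..., 0, 1)`.

```python
import numpy as np


def mask_transform_sketch(logits, temperature=0.66, hard_sigmoid=(-0.1, 1.1)):
    """NumPy mirror of the mask transforms (without the final reshape)."""
    scaled = np.asarray(logits, dtype=float) / temperature
    if hard_sigmoid is True:
        # Hard thresholding: clamp the scaled logits directly to [0, 1].
        return np.clip(scaled, 0.0, 1.0)
    soft = 1.0 / (1.0 + np.exp(-scaled))
    if hard_sigmoid is False:
        # Plain tempered sigmoid: values stay strictly inside (0, 1).
        return soft
    a, b = hard_sigmoid
    # Stretch the sigmoid to [a, b], then clamp, so that confident logits
    # can reach exactly 0 or 1.
    return np.clip(soft * (b - a) + a, 0.0, 1.0)


logits = np.array([-10.0, 0.0, 10.0])
fusion = mask_transform_sketch(logits)
soft = mask_transform_sketch(logits, hard_sigmoid=False)
print(fusion)
print(soft)
```

With the default `(a, b) = (-0.1, 1.1)`, confident logits saturate to exact zeros and ones, which a plain sigmoid never reaches.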
--------------------------------------------------------------------------------
/src/mnist_imputer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # Must sub-class Imputer to provide fc1
7 | class Imputer(nn.Module):
8 | def __init__(self, arch=(784, 784)):
9 | super().__init__()
10 | # self.fc1 = nn.Linear(784, arch[0])
11 | self.fc2 = nn.Linear(arch[0], arch[1])
12 | self.fc3 = nn.Linear(arch[1], arch[0])
13 | self.fc4 = nn.Linear(arch[0], 784)
14 | self.transform = lambda x: torch.sigmoid(x).view(-1, 1, 28, 28)
15 |
16 | def forward(self, input, data, mask):
17 | net = input.view(input.size(0), -1)
18 | net = F.relu(self.fc1(net))
19 | net = F.relu(self.fc2(net))
20 | net = F.relu(self.fc3(net))
21 | net = self.fc4(net)
22 | net = self.transform(net)
23 | # Return the raw reconstruction instead of
24 | # data * mask + net * (1 - mask): the observed part is NOT replaced
25 | # with the input data so that the autoencoding loss can be computed.
26 | return net
27 |
28 |
29 | class ComplementImputer(Imputer):
30 | def __init__(self, arch=(784, 784)):
31 | super().__init__(arch=arch)
32 | self.fc1 = nn.Linear(784, arch[0])
33 |
34 | def forward(self, input, mask, noise):
35 | net = input * mask + noise * (1 - mask)
36 | return super().forward(net, input, mask)
37 |
38 |
39 | class MaskImputer(Imputer):
40 | def __init__(self, arch=(784, 784)):
41 | super().__init__(arch=arch)
42 | self.fc1 = nn.Linear(784 * 2, arch[0])
43 |
44 | def forward(self, input, mask, noise):
45 | batch_size = input.size(0)
46 | net = torch.cat(
47 | [(input * mask + noise * (1 - mask)).view(batch_size, -1),
48 | mask.view(batch_size, -1)], 1)
49 | return super().forward(net, input, mask)
50 |
51 |
52 | class FixedNoiseDimImputer(Imputer):
53 | def __init__(self, arch=(784, 784)):
54 | super().__init__(arch=arch)
55 | self.fc1 = nn.Linear(784 * 3, arch[0])
56 |
57 | def forward(self, input, mask, noise):
58 | batch_size = input.size(0)
59 | net = torch.cat([(input * mask).view(batch_size, -1),
60 | mask.view(batch_size, -1),
61 | noise.view(batch_size, -1)], 1)
62 | return super().forward(net, input, mask)
63 |
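The three imputers above differ only in what they feed to `fc1`. The sketch below rebuilds those inputs with NumPy to show where the widths 784, 784 * 2 and 784 * 3 come from. The batch size, the 14x14 observed block and all variable names are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 2
data = rng.uniform(size=(batch_size, 1, 28, 28))
noise = rng.uniform(size=data.shape)
mask = np.zeros_like(data)
mask[:, :, 7:21, 7:21] = 1.0  # hypothetical observed 14x14 block

# ComplementImputer: noise fills only the missing pixels (784 features).
comp_in = (data * mask + noise * (1 - mask)).reshape(batch_size, -1)

# MaskImputer: the same filled image plus the mask itself (784 * 2).
mask_in = np.concatenate([comp_in, mask.reshape(batch_size, -1)], axis=1)

# FixedNoiseDimImputer: masked data, mask and full noise (784 * 3).
fix_in = np.concatenate([(data * mask).reshape(batch_size, -1),
                         mask.reshape(batch_size, -1),
                         noise.reshape(batch_size, -1)], axis=1)

print(comp_in.shape, mask_in.shape, fix_in.shape)
```

Only the `ComplementImputer` input hides which pixels were observed; the other two variants pass the mask explicitly.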
--------------------------------------------------------------------------------
/src/mnist_misgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | from pathlib import Path
4 | import argparse
5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator,
6 | ConvMaskGenerator, FCMaskGenerator)
7 | from mnist_critic import ConvCritic, FCCritic
8 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
9 | from misgan import misgan
10 |
11 |
12 | use_cuda = torch.cuda.is_available()
13 | device = torch.device('cuda' if use_cuda else 'cpu')
14 |
15 |
16 | def main():
17 | parser = argparse.ArgumentParser()
18 |
19 | # resume from checkpoint
20 | parser.add_argument('--resume')
21 | # training options
22 | parser.add_argument('--epoch', type=int, default=500)
23 | parser.add_argument('--batch-size', type=int, default=64)
24 |
25 | # log options: 0 to disable plot-interval or save-interval
26 | parser.add_argument('--plot-interval', type=int, default=50)
27 | parser.add_argument('--save-interval', type=int, default=0)
28 | parser.add_argument('--prefix', default='misgan')
29 |
30 | # mask options (data): block|indep
31 | parser.add_argument('--mask', default='block')
32 | # option for block: set to 0 for variable size
33 | parser.add_argument('--block-len', type=int, default=14)
34 | # option for indep:
35 | parser.add_argument('--obs-prob', type=float, default=.2)
36 | parser.add_argument('--obs-prob-high', type=float, default=None)
37 |
38 | # model options
39 | parser.add_argument('--tau', type=float, default=0)
40 | parser.add_argument('--generator', default='conv') # conv|fc
41 | parser.add_argument('--critic', default='conv') # conv|fc
42 | # parser.add_argument('--alpha', type=float, default=.1) # 0: separate
43 | parser.add_argument('--alpha', type=float, default=.2) # 0: separate
44 | # options for mask generator: sigmoid, hardsigmoid, fusion
45 | # parser.add_argument('--maskgen', default='fusion')
46 | parser.add_argument('--maskgen', default='sigmoid')
47 | parser.add_argument('--gp-lambda', type=float, default=10)
48 | parser.add_argument('--n-critic', type=int, default=5)
49 | parser.add_argument('--n-latent', type=int, default=128)
50 |
51 | args = parser.parse_args()
52 |
53 | checkpoint = None
54 | # Resume from previously stored checkpoint
55 | if args.resume:
56 | print(f'Resume: {args.resume}')
57 | output_dir = Path(args.resume)
58 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
59 | map_location='cpu')
60 | for key, arg in vars(checkpoint['args']).items():
61 | if key not in ['resume']:
62 | setattr(args, key, arg)
63 |
64 | if args.generator == 'conv':
65 | DataGenerator = ConvDataGenerator
66 | MaskGenerator = ConvMaskGenerator
67 | elif args.generator == 'fc':
68 | DataGenerator = FCDataGenerator
69 | MaskGenerator = FCMaskGenerator
70 | else:
71 | raise NotImplementedError
72 |
73 | if args.critic == 'conv':
74 | Critic = ConvCritic
75 | elif args.critic == 'fc':
76 | Critic = FCCritic
77 | else:
78 | raise NotImplementedError
79 |
80 | if args.maskgen == 'sigmoid':
81 | hard_sigmoid = False
82 | elif args.maskgen == 'hardsigmoid':
83 | hard_sigmoid = True
84 | elif args.maskgen == 'fusion':
85 | hard_sigmoid = -.1, 1.1
86 | else:
87 | raise NotImplementedError
88 |
89 | mask = args.mask
90 | obs_prob = args.obs_prob
91 | obs_prob_high = args.obs_prob_high
92 | block_len = args.block_len
93 | if block_len == 0:
94 | block_len = None
95 |
96 | if mask == 'indep':
97 | if obs_prob_high is None:
98 | mask_str = f'indep_{obs_prob:g}'
99 | else:
100 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
101 | elif mask == 'block':
102 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
103 | else:
104 | raise NotImplementedError
105 |
106 | path = '{}_{}_{}'.format(
107 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
108 | '_'.join([
109 | f'gen_{args.generator}',
110 | f'critic_{args.critic}',
111 | f'tau_{args.tau:g}',
112 | f'alpha_{args.alpha:g}',
113 | f'maskgen_{args.maskgen}',
114 | mask_str,
115 | ]))
116 |
117 | if not args.resume:
118 | output_dir = Path('results') / 'mnist' / path
119 | print(output_dir)
120 |
121 | if mask == 'indep':
122 | data = IndepMaskedMNIST(obs_prob=obs_prob, obs_prob_high=obs_prob_high)
123 | elif mask == 'block':
124 | data = BlockMaskedMNIST(block_len=block_len)
125 |
126 | data_gen = DataGenerator().to(device)
127 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device)
128 |
129 | data_critic = Critic().to(device)
130 | mask_critic = Critic().to(device)
131 |
132 | misgan(args, data_gen, mask_gen, data_critic, mask_critic, data,
133 | output_dir, checkpoint)
134 |
135 |
136 | if __name__ == '__main__':
137 | main()
138 |
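The resume logic in `main()` overwrites every parsed option with the value stored in the checkpoint, except `resume` itself. A standalone sketch of that pattern, using `argparse.Namespace` objects in place of a real checkpoint; `restore_args` and the example values are hypothetical.

```python
import argparse


def restore_args(args, saved_args, keep=('resume',)):
    """Copy saved options onto args, leaving the keys in `keep` untouched."""
    for key, value in vars(saved_args).items():
        if key not in keep:
            setattr(args, key, value)
    return args


# Options from the current command line (hypothetical values).
cli = argparse.Namespace(resume='results/mnist/run', epoch=500, alpha=0.2)
# Options stored with the checkpoint when the run was first started.
saved = argparse.Namespace(resume=None, epoch=300, alpha=0.1)

restore_args(cli, saved)
print(cli)
```

A consequence of this design is that options passed together with `--resume` are silently replaced by the checkpointed ones.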
--------------------------------------------------------------------------------
/src/mnist_misgan_impute.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | from pathlib import Path
4 | import argparse
5 | from mnist_generator import (ConvDataGenerator, FCDataGenerator,
6 | ConvMaskGenerator, FCMaskGenerator)
7 | from mnist_imputer import (ComplementImputer,
8 | MaskImputer,
9 | FixedNoiseDimImputer)
10 | from mnist_critic import ConvCritic, FCCritic
11 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
12 | from misgan_impute import misgan_impute
13 |
14 |
15 | use_cuda = torch.cuda.is_available()
16 | device = torch.device('cuda' if use_cuda else 'cpu')
17 |
18 |
19 | def main():
20 | parser = argparse.ArgumentParser()
21 |
22 | # resume from checkpoint
23 | parser.add_argument('--resume')
24 |
25 | # training options
26 | parser.add_argument('--workers', type=int, default=0)
27 | parser.add_argument('--epoch', type=int, default=1000)
28 | parser.add_argument('--batch-size', type=int, default=64)
29 | parser.add_argument('--pretrain', default=None)
30 | parser.add_argument('--imputeronly', action='store_true')
31 |
32 | # log options: 0 to disable plot-interval or save-interval
33 | parser.add_argument('--plot-interval', type=int, default=100)
34 | parser.add_argument('--save-interval', type=int, default=0)
35 | parser.add_argument('--prefix', default='impute')
36 |
37 | # mask options (data): block|indep
38 | parser.add_argument('--mask', default='block')
39 | # option for block: set to 0 for variable size
40 | parser.add_argument('--block-len', type=int, default=14)
41 | # option for indep:
42 | parser.add_argument('--obs-prob', type=float, default=.2)
43 | parser.add_argument('--obs-prob-high', type=float, default=None)
44 |
45 | # model options
46 | parser.add_argument('--tau', type=float, default=0)
47 | parser.add_argument('--generator', default='conv') # conv|fc
48 | parser.add_argument('--critic', default='conv') # conv|fc
49 | parser.add_argument('--alpha', type=float, default=.1) # 0: separate
50 | parser.add_argument('--beta', type=float, default=.1)
51 | parser.add_argument('--gamma', type=float, default=0)
52 | parser.add_argument('--arch', default='784-784')
53 | parser.add_argument('--imputer', default='comp') # comp|mask|fix
54 | # options for mask generator: sigmoid, hardsigmoid, fusion
55 | parser.add_argument('--maskgen', default='fusion')
56 | parser.add_argument('--gp-lambda', type=float, default=10)
57 | parser.add_argument('--n-critic', type=int, default=5)
58 | parser.add_argument('--n-latent', type=int, default=128)
59 |
60 | args = parser.parse_args()
61 |
62 | checkpoint = None
63 | # Resume from previously stored checkpoint
64 | if args.resume:
65 | print(f'Resume: {args.resume}')
66 | output_dir = Path(args.resume)
67 | checkpoint = torch.load(str(output_dir / 'log' / 'checkpoint.pth'),
68 | map_location='cpu')
69 | for key, arg in vars(checkpoint['args']).items():
70 | if key not in ['resume']:
71 | setattr(args, key, arg)
72 |
73 | if args.imputeronly:
74 | assert args.pretrain is not None
75 |
76 | arch = args.arch
77 | imputer_type = args.imputer
78 | mask = args.mask
79 | obs_prob = args.obs_prob
80 | obs_prob_high = args.obs_prob_high
81 | block_len = args.block_len
82 | if block_len == 0:
83 | block_len = None
84 |
85 | if args.generator == 'conv':
86 | DataGenerator = ConvDataGenerator
87 | MaskGenerator = ConvMaskGenerator
88 | elif args.generator == 'fc':
89 | DataGenerator = FCDataGenerator
90 | MaskGenerator = FCMaskGenerator
91 | else:
92 | raise NotImplementedError
93 |
94 | if imputer_type == 'comp':
95 | Imputer = ComplementImputer
96 | elif imputer_type == 'mask':
97 | Imputer = MaskImputer
98 | elif imputer_type == 'fix':
99 | Imputer = FixedNoiseDimImputer
100 | else:
101 | raise NotImplementedError
102 |
103 | if args.critic == 'conv':
104 | Critic = ConvCritic
105 | elif args.critic == 'fc':
106 | Critic = FCCritic
107 | else:
108 | raise NotImplementedError
109 |
110 | if args.maskgen == 'sigmoid':
111 | hard_sigmoid = False
112 | elif args.maskgen == 'hardsigmoid':
113 | hard_sigmoid = True
114 | elif args.maskgen == 'fusion':
115 | hard_sigmoid = -.1, 1.1
116 | else:
117 | raise NotImplementedError
118 |
119 | if mask == 'indep':
120 | if obs_prob_high is None:
121 | mask_str = f'indep_{obs_prob:g}'
122 | else:
123 | mask_str = f'indep_{obs_prob:g}_{obs_prob_high:g}'
124 | elif mask == 'block':
125 | mask_str = 'block_{}'.format(block_len if block_len else 'varsize')
126 | else:
127 | raise NotImplementedError
128 |
129 | path = '{}_{}_{}'.format(
130 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
131 | '_'.join([
132 | f'gen_{args.generator}',
133 | f'critic_{args.critic}',
134 | f'imp_{args.imputer}',
135 | f'tau_{args.tau:g}',
136 | f'arch_{args.arch}',
137 | f'maskgen_{args.maskgen}',
138 | f'coef_{args.alpha:g}_{args.beta:g}_{args.gamma:g}',
139 | mask_str
140 | ]))
141 |
142 | if not args.resume:
143 | output_dir = Path('results') / 'mnist' / path
144 | print(output_dir)
145 |
146 | if mask == 'indep':
147 | data = IndepMaskedMNIST(
148 | obs_prob=obs_prob, obs_prob_high=obs_prob_high)
149 | elif mask == 'block':
150 | data = BlockMaskedMNIST(block_len=block_len)
151 |
152 | data_gen = DataGenerator().to(device)
153 | mask_gen = MaskGenerator(hard_sigmoid=hard_sigmoid).to(device)
154 |
155 | hid_lens = [int(n) for n in arch.split('-')]
156 | imputer = Imputer(arch=hid_lens).to(device)
157 |
158 | data_critic = Critic().to(device)
159 | mask_critic = Critic().to(device)
160 | impu_critic = Critic().to(device)
161 |
162 | misgan_impute(args, data_gen, mask_gen, imputer,
163 | data_critic, mask_critic, impu_critic,
164 | data, output_dir, checkpoint)
165 |
166 |
167 | if __name__ == '__main__':
168 | main()
169 |
--------------------------------------------------------------------------------
/src/mnist_model.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py
3 | """
4 | import argparse
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.optim as optim
9 | from torchvision import datasets, transforms
10 |
11 |
12 | class Net(nn.Module):
13 | def __init__(self):
14 | super(Net, self).__init__()
15 | self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
16 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
17 | self.conv2_drop = nn.Dropout2d()
18 | self.fc1 = nn.Linear(320, 50)
19 | self.fc2 = nn.Linear(50, 10)
20 |
21 | def forward(self, x):
22 | feature = []
23 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
24 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
25 | x = x.view(-1, 320)
26 | x = self.fc1(x)
27 | feature.append(x)
28 | x = F.relu(x)
29 | x = F.dropout(x, training=self.training)
30 | x = self.fc2(x)
31 | feature.append(x)
32 | self.feature = feature
33 | return F.log_softmax(x, dim=1)
34 |
35 |
36 | def main():
37 | # Training settings
38 | parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
39 | parser.add_argument('--batch-size', type=int, default=64, metavar='N',
40 | help='input batch size for training (default: 64)')
41 | parser.add_argument('--test-batch-size', type=int,
42 | default=1000, metavar='N',
43 | help='input batch size for testing (default: 1000)')
44 | parser.add_argument('--epochs', type=int, default=100, metavar='N',
45 | help='number of epochs to train (default: 100)')
46 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
47 | help='learning rate (default: 0.01)')
48 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
49 | help='SGD momentum (default: 0.5)')
50 | parser.add_argument('--no-cuda', action='store_true', default=False,
51 | help='disables CUDA training')
52 | parser.add_argument('--seed', type=int, default=1, metavar='S',
53 | help='random seed (default: 1)')
54 | parser.add_argument('--log-interval', type=int, default=10, metavar='N',
55 | help='number of batches to wait before logging '
56 | 'training status')
57 | args = parser.parse_args()
58 | args.cuda = not args.no_cuda and torch.cuda.is_available()
59 |
60 | torch.manual_seed(args.seed)
61 | if args.cuda:
62 | torch.cuda.manual_seed(args.seed)
63 |
64 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
65 | train_loader = torch.utils.data.DataLoader(
66 | datasets.MNIST('../data', train=True, download=True,
67 | transform=transforms.Compose([
68 | transforms.ToTensor(),
69 | transforms.Normalize((0.1307,), (0.3081,))
70 | ])),
71 | batch_size=args.batch_size, shuffle=True, **kwargs)
72 | test_loader = torch.utils.data.DataLoader(
73 | datasets.MNIST('../data', train=False, transform=transforms.Compose([
74 | transforms.ToTensor(),
75 | transforms.Normalize((0.1307,), (0.3081,))
76 | ])),
77 | batch_size=args.test_batch_size, shuffle=True, **kwargs)
78 |
79 | model = Net()
80 | if args.cuda:
81 | model.cuda()
82 |
83 | optimizer = optim.SGD(model.parameters(), lr=args.lr,
84 | momentum=args.momentum)
85 |
86 | def train(epoch):
87 | model.train()
88 | for batch_idx, (data, target) in enumerate(train_loader):
89 | if args.cuda:
90 | data, target = data.cuda(), target.cuda()
91 | optimizer.zero_grad()
92 | output = model(data)
93 | loss = F.nll_loss(output, target)
94 | loss.backward()
95 | optimizer.step()
96 | if batch_idx % args.log_interval == 0:
97 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
98 | epoch, batch_idx * len(data), len(train_loader.dataset),
99 | 100. * batch_idx / len(train_loader), loss.item()))
100 |
101 | def test():
102 | model.eval()
103 | test_loss = 0
104 | correct = 0
105 | with torch.no_grad():
106 | for data, target in test_loader:
107 | if args.cuda:
108 | data, target = data.cuda(), target.cuda()
109 | output = model(data)
110 | # sum up batch loss
111 | test_loss += F.nll_loss(output, target, reduction='sum').item()
112 | # get the index of the max log-probability
113 | pred = output.argmax(dim=1, keepdim=True)
114 | correct += (pred == target.view_as(pred)).sum().item()
115 |
116 | test_loss /= len(test_loader.dataset)
117 | print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
118 | .format(test_loss, correct, len(test_loader.dataset),
119 | 100. * correct / len(test_loader.dataset)))
120 |
121 | for epoch in range(1, args.epochs + 1):
122 | train(epoch)
123 | test()
124 |
125 | torch.save(model.state_dict(), 'mnist.pth')
126 |
127 |
128 | if __name__ == '__main__':
129 | main()
130 |
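The `320` in `x.view(-1, 320)` and `nn.Linear(320, 50)` follows from the layer shapes of `Net`. A short sketch of that arithmetic, assuming unpadded stride-1 convolutions and 2x2 max pooling as in the code above; `conv_out` and `pool_out` are hypothetical helper names.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Spatial output size of max pooling with stride equal to the kernel."""
    return size // kernel


size = 28                              # MNIST images are 28x28
size = pool_out(conv_out(size, 5), 2)  # conv1: 28 -> 24, pool: 24 -> 12
size = pool_out(conv_out(size, 5), 2)  # conv2: 12 -> 8, pool: 8 -> 4
flat_features = 20 * size * size       # conv2 has 20 output channels
print(flat_features)                   # 320
```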
--------------------------------------------------------------------------------
/src/plot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from matplotlib.patches import Rectangle
4 | from PIL import Image
5 |
6 |
7 | def plot_grid(image, bbox=None, gap=0, gap_value=1,
8 | nrow=4, ncol=8, save_file=None):
9 | image = image.cpu().numpy()
10 | channels, len0, len1 = image[0].shape
11 | grid = np.empty(
12 | (nrow * (len0 + gap) - gap, ncol * (len1 + gap) - gap, channels))
13 | # Convert from N, C, H, W to N, H, W, C
14 | image = image.transpose((0, 2, 3, 1))
15 | grid.fill(gap_value)
16 |
17 | for i, x in enumerate(image):
18 | if i >= nrow * ncol:
19 | break
20 | p0 = (i // ncol) * (len0 + gap)
21 | p1 = (i % ncol) * (len1 + gap)
22 | grid[p0:(p0 + len0), p1:(p1 + len1)] = x
23 |
24 | # figsize = np.r_[ncol, nrow] * .75
25 | scale = 2.5
26 | figsize = ncol * scale, nrow * scale # FIXME: scale by len0, len1
27 | fig = plt.figure(figsize=figsize)
28 | ax = plt.Axes(fig, [0, 0, 1, 1])
29 | ax.set_axis_off()
30 | fig.add_axes(ax)
31 | grid = grid.squeeze()
32 | ax.imshow(grid, cmap='binary_r', interpolation='none', aspect='equal')
33 |
34 | if bbox is not None:
35 | nplot = min(len(image), nrow * ncol)
36 | for i in range(nplot):
37 | if len(bbox) == 1:
38 | d0, d1, d0_len, d1_len = bbox[0]
39 | else:
40 | d0, d1, d0_len, d1_len = bbox[i]
41 | p0 = (i // ncol) * (len0 + gap)
42 | p1 = (i % ncol) * (len1 + gap)
43 | offset = np.array([p1 + d1, p0 + d0]) - .5
44 | ax.add_patch(Rectangle(
45 | offset, d1_len, d0_len, lw=4, edgecolor='red', fill=False))
46 |
47 | if save_file:
48 | fig.savefig(save_file)
49 | plt.close(fig)
50 |
51 |
52 | def plot_samples(samples, save_file, nrow=4, ncol=8):
53 |     x = samples.cpu().numpy()
54 |     channels, len0, len1 = x[0].shape
55 |     x_merge = np.zeros((nrow * len0, ncol * len1, channels))
56 |
57 |     for i, x_ in enumerate(x):
58 |         if i >= nrow * ncol:
59 |             break
60 |         p0 = (i // ncol) * len0
61 |         p1 = (i % ncol) * len1
62 |         x_merge[p0:(p0 + len0), p1:(p1 + len1)] = x_.transpose((1, 2, 0))
63 |
64 |     x_merge = (x_merge * 255).clip(0, 255).astype(np.uint8)
65 |     # squeeze() to remove the last dimension for the single-channel image.
66 |     im = Image.fromarray(x_merge.squeeze())
67 |     im.save(save_file)
68 |
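The grid layout used by `plot_grid` above can be checked in isolation. The sketch below repeats only the index arithmetic from that function in plain Python; the helper names `grid_shape` and `cell_origin` are hypothetical and do not exist in `plot.py`.

```python
def grid_shape(nrow, ncol, len0, len1, gap=0):
    """Height and width of the canvas that holds nrow x ncol images.

    Every cell is padded by `gap` pixels, except that no gap is added
    after the last row or the last column.
    """
    return nrow * (len0 + gap) - gap, ncol * (len1 + gap) - gap


def cell_origin(i, ncol, len0, len1, gap=0):
    """Top-left (row, column) pixel of the i-th image, filled row by row."""
    p0 = (i // ncol) * (len0 + gap)
    p1 = (i % ncol) * (len1 + gap)
    return p0, p1


# Example: a 4 x 8 grid of 28 x 28 images separated by 2-pixel gaps.
print(grid_shape(4, 8, 28, 28, gap=2))
print(cell_origin(9, 8, 28, 28, gap=2))
```

With `gap=0` the same formulas reduce to the tiling used in `plot_samples`.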
--------------------------------------------------------------------------------
/src/unet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Code adapted from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
6 | #
7 | # Defines the submodule with skip connection.
8 | # X -------------------identity---------------------- X
9 | #   |-- downsampling -- |submodule| -- upsampling --|
10 | class UnetSkipConnectionBlock(nn.Module):
11 |     def __init__(self, outer_nc, inner_nc, input_nc=None,
12 |                  submodule=None, outermost=False, innermost=False,
13 |                  norm_layer=nn.BatchNorm2d):
14 |         super().__init__()
15 |         self.outermost = outermost
16 |         use_bias = norm_layer == nn.InstanceNorm2d
17 |         if input_nc is None:
18 |             input_nc = outer_nc
19 |         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
20 |                              stride=2, padding=1, bias=use_bias)
21 |         downrelu = nn.LeakyReLU(0.2, True)
22 |         if norm_layer is not None:
23 |             downnorm = norm_layer(inner_nc)
24 |             upnorm = norm_layer(outer_nc)
25 |         uprelu = nn.ReLU(True)
26 |
27 |         if outermost:
28 |             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
29 |                                         kernel_size=4, stride=2,
30 |                                         padding=1)
31 |             down = [downconv]
32 |             up = [uprelu, upconv]
33 |             model = down + [submodule] + up
34 |         elif innermost:
35 |             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
36 |                                         kernel_size=4, stride=2,
37 |                                         padding=1, bias=use_bias)
38 |             down = [downrelu, downconv]
39 |             up = [uprelu, upconv]
40 |             if norm_layer is not None:
41 |                 up.append(upnorm)
42 |             model = down + up
43 |         else:
44 |             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
45 |                                         kernel_size=4, stride=2,
46 |                                         padding=1, bias=use_bias)
47 |             down = [downrelu, downconv]
48 |             up = [uprelu, upconv]
49 |             if norm_layer is not None:
50 |                 down.append(downnorm)
51 |                 up.append(upnorm)
52 |
53 |             model = down + [submodule] + up
54 |
55 |         self.model = nn.Sequential(*model)
56 |
57 |     def forward(self, x):
58 |         if self.outermost:
59 |             return self.model(x)
60 |         else:
61 |             return torch.cat([x, self.model(x)], 1)
62 |
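The `inner_nc * 2` input width of the non-innermost `upconv` layers follows from the concatenation in `forward`. The sketch below works through that shape arithmetic in plain Python, using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with `dilation=1` and no `output_padding`. The helper names are hypothetical and are not part of `unet.py`.

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after the downsampling convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel=4, stride=2, padding=1):
    """Spatial size after the upsampling transposed convolution."""
    return (size - 1) * stride - 2 * padding + kernel


def block_out_channels(outer_nc, input_nc=None, outermost=False):
    """Channels returned by a block's forward().

    Non-outermost blocks concatenate their input (input_nc channels)
    with the upsampled result (outer_nc channels) along dimension 1.
    """
    if input_nc is None:
        input_nc = outer_nc
    return outer_nc if outermost else input_nc + outer_nc


# A child block is built with outer_nc equal to its parent's inner_nc,
# so the parent's upconv receives twice that many channels.
parent_inner_nc = 64
print(block_out_channels(outer_nc=parent_inner_nc))

# Each block halves the resolution on the way down and restores it on
# the way up, which is what makes the concatenation shape-compatible.
print(conv_out(32), deconv_out(conv_out(32)))
```

The resolution round-trip is exact only for even input sizes. For an odd size such as 7, `conv_out` gives 3 and `deconv_out` gives 6, so the skip concatenation would fail.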
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import grad
2 |
3 |
4 | class CriticUpdater:
5 |     def __init__(self, critic, critic_optimizer, eps, ones, gp_lambda=10):
6 |         self.critic = critic
7 |         self.critic_optimizer = critic_optimizer
8 |         self.eps = eps
9 |         self.ones = ones
10 |         self.gp_lambda = gp_lambda
11 |
12 |     def __call__(self, real, fake):
13 |         real = real.detach()
14 |         fake = fake.detach()
15 |         self.critic.zero_grad()
16 |         self.eps.uniform_(0, 1)
17 |         interp = (self.eps * real + (1 - self.eps) * fake).requires_grad_()
18 |         grad_d = grad(self.critic(interp), interp, grad_outputs=self.ones,
19 |                       create_graph=True)[0]
20 |         grad_d = grad_d.view(real.shape[0], -1)
21 |         grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda
22 |         w_dist = self.critic(fake).mean() - self.critic(real).mean()
23 |         loss = w_dist + grad_penalty
24 |         loss.backward()
25 |         self.critic_optimizer.step()
26 |         self.loss_value = loss.item()
27 |
28 |
29 | def mask_norm(diff, mask):
30 |     """Per-sample mean of diff over observed entries, averaged over batch"""
31 |     dim = 1, 2, 3
32 |     # Assume mask.sum(dim) is non-zero for every sample
33 |     return ((diff * mask).sum(dim) / mask.sum(dim)).mean()
34 |
35 |
36 | def mkdir(path):
37 |     path.mkdir(parents=True, exist_ok=True)
38 |     return path
39 |
40 |
41 | def mask_data(data, mask, tau):
42 |     return mask * data + (1 - mask) * tau
43 |
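A small worked example may help with `mask_data` and `mask_norm` above. The sketch below applies the same formulas to a single flattened sample stored as a Python list, so it runs without torch. The `_flat` helpers are hypothetical stand-ins: the real functions operate on 4-D tensors, and `mask_norm` additionally averages the per-sample result over the batch.

```python
def mask_data_flat(data, mask, tau):
    """Keep observed entries (mask == 1); fill missing ones with tau."""
    return [m * x + (1 - m) * tau for x, m in zip(data, mask)]


def mask_norm_flat(diff, mask):
    """Mean of diff over the observed entries of one sample.

    Like the original, this assumes at least one entry is observed,
    otherwise the division by sum(mask) fails.
    """
    return sum(d * m for d, m in zip(diff, mask)) / sum(mask)


data = [0.2, 0.8, 0.5, 0.1]
mask = [1, 0, 1, 0]          # 1 = observed, 0 = missing

# Missing positions are replaced by the constant tau.
print(mask_data_flat(data, mask, tau=0.5))

# Only the observed positions (values 1.0 and 3.0) enter the mean.
print(mask_norm_flat([1.0, 4.0, 3.0, 8.0], mask))
```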
--------------------------------------------------------------------------------