├── LICENSE
├── README.md
├── examples
│   ├── images.png
│   └── time-series.png
├── image
│   ├── celeba-data
│   │   └── download-celeba.sh
│   ├── celeba_decoder.py
│   ├── celeba_encoder.py
│   ├── celeba_fid.py
│   ├── celeba_pbigan.py
│   ├── celeba_pbigan_fid.py
│   ├── celeba_pvae.py
│   ├── fid.py
│   ├── flow.py
│   ├── inception.py
│   ├── masked_celeba.py
│   ├── masked_mnist.py
│   ├── mmd.py
│   ├── mnist_decoder.py
│   ├── mnist_encoder.py
│   ├── mnist_pbigan.py
│   ├── mnist_pvae.py
│   ├── utils.py
│   └── visualize.py
├── requirements.txt
└── time-series
    ├── .gitattributes
    ├── ema.py
    ├── evaluate.py
    ├── figures
    │   ├── pbigan.png
    │   └── pvae.png
    ├── flow.py
    ├── gen_toy_data.py
    ├── layers.py
    ├── mimic3_pbigan.py
    ├── mimic3_pvae.py
    ├── mmd.py
    ├── sn_layers.py
    ├── spline_cconv.py
    ├── time_series.py
    ├── toy_layers.py
    ├── toy_pbigan.py
    ├── toy_pvae.py
    ├── toy_time_series.ipynb
    ├── tracker.py
    ├── utils.py
    └── vis.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Steven Cheng-Xian Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning from Irregularly-Sampled Time Series: A Missing Data Perspective
2 |
3 | This repository provides a PyTorch implementation of the paper
4 | ["Learning from Irregularly-Sampled Time Series: A Missing Data Perspective"](https://arxiv.org/abs/2008.07599).
5 |
6 |
7 | ## Requirements
8 |
9 | This repository requires Python 3.6 or later.
10 | The file [requirements.txt](requirements.txt) contains the full list of
11 | required Python modules and the versions we tested with.
12 | To install requirements:
13 |
14 | ```sh
15 | pip install -r requirements.txt
16 | ```
17 |
18 |
19 | ## Image
20 |
21 |
22 |
23 | ### MNIST
24 |
25 | Under the `image` directory, the following commands train P-VAE and P-BiGAN for
26 | incomplete MNIST:
27 |
28 | ```sh
29 | # P-VAE:
30 | python mnist_pvae.py
31 | # P-BiGAN:
32 | python mnist_pbigan.py
33 | ```
34 |
35 | ### CelebA
36 |
37 | For CelebA, you need to download the dataset from its
38 | [website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
39 | Specifically, you may either:
40 | * Download the file `img_align_celeba.zip` from [this link](https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM)
41 | and extract the zip file into the directory `image/celeba-data`, or
42 | * Run the script [download-celeba.sh](image/celeba-data/download-celeba.sh)
43 | under the directory `image/celeba-data`. Make sure you have
44 | [curl](https://curl.haxx.se) on your system.
45 |
46 | ```sh
47 | cd image/celeba-data && bash download-celeba.sh
48 | ```
49 |
50 | Under the `image` directory, the following commands train P-VAE and P-BiGAN for
51 | incomplete CelebA:
52 | ```sh
53 | # P-VAE:
54 | python celeba_pvae.py
55 | # P-BiGAN:
56 | python celeba_pbigan.py
57 | ```
58 |
59 | ### Command-line options
60 |
61 | For both the MNIST and CelebA scripts, use the option
62 | `--mask block --block-len n` to specify "square observation" missingness
63 | with n-by-n observed blocks, or
64 | `--mask indep --obs-prob .2` to specify "independent dropout" missingness
65 | with 80% missing pixels.
66 |
67 | Use `-h` to see all the available command-line options for each script
68 | (also for the scripts for time series described below).
69 |
70 | ## Time Series
71 |
72 | Our implementation takes as input a time series dataset in a format
73 | composed of three tensors `time`, `data`, `mask` saved as numpy's
74 | [npz](https://numpy.org/doc/stable/reference/generated/numpy.savez_compressed.html) file.
75 | A dataset of `N` time series, each of which has `C` channels
76 | with each channel having at most `L` observations (time-value pairs),
77 | is represented by three tensors `time`, `data` and `mask`,
78 | each of size `(N, C, L)`:
79 |
80 | * `mask` is the binary mask indicating which entries in `time` and `data`
81 | correspond to a missing value.
82 | `mask[n, c, k]` is 1 if the `k`-th entry of the `c`-th channel of the
83 | `n`-th time series is observed, and 0 if it is missing.
84 | * `time` stores the timestamps of the time series rescaled to the range [0, 1].
85 | Note that missing entries, i.e., those whose corresponding
86 | `mask` entry is zero, must still be set to values within [0, 1]
87 | for the decoder to work correctly.
88 | The easiest way is to set them to zero with `time *= mask`.
89 | * `data` stores the time series values associated with `time`.
90 | Missing entries may contain arbitrary values.
91 |
92 |
93 | The script [gen_toy_data.py](time-series/gen_toy_data.py) is an example
94 | of creating a synthetic time series dataset in this format.
95 |
96 | ### Synthetic data
97 |
98 | [This notebook](https://nbviewer.jupyter.org/github/steveli/partial-encoder-decoder/blob/master/time-series/toy_time_series.ipynb)
99 | provides an overview of P-VAE and P-BiGAN
100 | and demonstrates how to train them on a synthetic dataset.
101 |
102 |
103 |
104 | Under the `time-series` directory, the following commands train a P-VAE
105 | and P-BiGAN on a synthetic multivariate time series dataset:
106 |
107 | ```sh
108 | # P-VAE:
109 | python toy_pvae.py
110 | # P-BiGAN:
111 | python toy_pbigan.py
112 | ```
113 |
114 | ### MIMIC-III
115 |
116 | MIMIC-III can be downloaded following the instructions from
117 | its [website](https://mimic.physionet.org/gettingstarted/access/).
118 |
119 | For the experiments, we apply the optional preprocessing
120 | used in [this work](https://github.com/mlds-lab/interp-net)
121 | to the MIMIC-III dataset.
122 |
123 | For the time series classification task, our implementation takes as input
124 | one of the following three labeled time series data formats:
125 |
126 | 1. Unsplit format with the following 4 fields, including an additional label vector.
127 | The data will be randomly split into training/test/validation sets.
128 | * `(time|data|mask)`: numpy array of shape `(N, C, L)` as described before.
129 | * `label`: binary label of shape `(N,)`.
130 | 2. Data come with train/test split with the following 8 fields.
131 | The training set will be
132 | subsequently split into a smaller training set (80%)
133 | and a validation set (20%).
134 | * `(train|test)_(time|data|mask)`
135 | * `(train|test)_label`
136 | 3. Data come with train/test/validation split with the following 12 fields.
137 | This is useful for model selection based on the metric evaluated
138 | on the validation set with multiple runs (with different randomness).
139 | * `(train|test|val)_(time|data|mask)`
140 | * `(train|test|val)_label`
141 |
142 | The function `split_data` in [time_series.py](time-series/time_series.py)
143 | demonstrates how the data file is read and split
144 | into training/test/validation sets.
145 | You can follow this to create time series data of your own.
146 |
147 | Once the time series data is ready,
148 | run the following command under the `time-series` directory:
149 |
150 | ```sh
151 | # P-VAE:
152 | python mimic3_pvae.py
153 | # P-BiGAN:
154 | python mimic3_pbigan.py
155 | ```
156 |
157 | ## Citation
158 |
159 | If you find our work relevant to your research, please cite:
160 |
161 | ```bibtex
162 | @InProceedings{li2020learning,
163 | title = {Learning from Irregularly-Sampled Time Series: A Missing Data Perspective},
164 | author = {Li, Steven Cheng-Xian and Marlin, Benjamin M.},
165 | booktitle = {Proceedings of the 37th International Conference on Machine Learning},
166 | year = {2020}
167 | }
168 | ```
169 |
170 | ## Contact
171 |
172 | Your feedback would be greatly appreciated!
173 | Reach us at
.
174 |
--------------------------------------------------------------------------------
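The `(N, C, L)` input format described in the README's "Time Series" section can be illustrated with a short sketch. This is not the repository's `gen_toy_data.py`; the sizes, the random generator seed, the 0.7 observation rate, and the sine-valued data are all assumptions chosen for illustration. Sorting the timestamps is likewise an illustrative choice, not something the README requires.

```python
import io

import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes: cases, channels, max observations per channel.
N, C, L = 4, 2, 6

# Binary mask: 1 = observed, 0 = missing (0.7 is an arbitrary observation rate).
mask = (rng.random((N, C, L)) < 0.7).astype(np.float32)

# Timestamps already rescaled to [0, 1]; sorted here only for readability.
time = np.sort(rng.random((N, C, L)), axis=-1).astype(np.float32)

# Arbitrary synthetic values attached to each timestamp.
data = np.sin(2 * np.pi * time).astype(np.float32)

# Missing entries must still hold a time within [0, 1];
# zeroing them out is the simplest option mentioned in the README.
time *= mask

# Write to an in-memory buffer so the sketch has no side effects;
# pass a filename instead to create an actual .npz file.
buffer = io.BytesIO()
np.savez_compressed(buffer, time=time, data=data, mask=mask)

# Read the archive back to confirm the three arrays round-trip.
buffer.seek(0)
loaded = np.load(buffer)
print(loaded.files, loaded['time'].shape)
```

For the labeled formats used by the MIMIC-III scripts, the same call would presumably take additional keyword arrays such as `label` or `train_time`, following the field names listed in the README.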
/examples/images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/examples/images.png
--------------------------------------------------------------------------------
/examples/time-series.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/examples/time-series.png
--------------------------------------------------------------------------------
/image/celeba-data/download-celeba.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | fileid=0B7EVK8r0v71pZjFTYXZWM3FlRnM
4 | filename=img_align_celeba.zip
5 |
6 | curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null
7 | curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename}
8 | unzip $filename
9 | rm -f cookie $filename
10 |
--------------------------------------------------------------------------------
/image/celeba_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def dconv_bn_relu(in_dim, out_dim):
6 | return nn.Sequential(
7 | nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
8 | padding=2, output_padding=1, bias=False),
9 | nn.BatchNorm2d(out_dim),
10 | nn.ReLU())
11 |
12 |
13 | class ConvDecoder(nn.Module):
14 | def __init__(self, latent_size=128):
15 | super().__init__()
16 |
17 | self.out_channels = 3
18 | dim = 64
19 |
20 | self.l1 = nn.Sequential(
21 | nn.Linear(latent_size, dim * 8 * 4 * 4, bias=False),
22 | nn.BatchNorm1d(dim * 8 * 4 * 4),
23 | nn.ReLU())
24 |
25 | self.l2_5 = nn.Sequential(
26 | dconv_bn_relu(dim * 8, dim * 4),
27 | dconv_bn_relu(dim * 4, dim * 2),
28 | dconv_bn_relu(dim * 2, dim),
29 | nn.ConvTranspose2d(dim, self.out_channels, 5, 2,
30 | padding=2, output_padding=1))
31 |
32 | def forward(self, input):
33 | x = self.l1(input)
34 | x = x.view(x.shape[0], -1, 4, 4)
35 | x = self.l2_5(x)
36 | return x, torch.sigmoid(x)
37 |
--------------------------------------------------------------------------------
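The spatial sizes produced by `ConvDecoder` above can be checked with plain arithmetic. The sketch below applies the standard transposed-convolution output-size formula (with dilation 1) to the layer settings in `celeba_decoder.py`; the helper name `conv_transpose_out` is hypothetical and not part of the repository.

```python
def conv_transpose_out(size, kernel=5, stride=2, padding=2, output_padding=1):
    """Output length of a transposed convolution along one axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# The linear layer's output is reshaped to (dim * 8, 4, 4), so start at 4.
sizes = [4]
# Three dconv_bn_relu blocks plus the final ConvTranspose2d.
for _ in range(4):
    sizes.append(conv_transpose_out(sizes[-1]))

print(sizes)  # [4, 8, 16, 32, 64]
```

With these settings each layer exactly doubles the resolution, so the decoder emits 64x64 images with `out_channels = 3`.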
/image/celeba_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torch.distributions.normal import Normal
4 | import flow
5 |
6 |
7 | def conv_ln_lrelu(in_dim, out_dim):
8 | return nn.Sequential(
9 | nn.Conv2d(in_dim, out_dim, 5, 2, 2),
10 | nn.InstanceNorm2d(out_dim, affine=True),
11 | nn.LeakyReLU(0.2))
12 |
13 |
14 | class ConvEncoder(nn.Module):
15 | def __init__(self, latent_size, flow_depth=2, logprob=False):
16 | super().__init__()
17 |
18 | if logprob:
19 | self.encode_func = self.encode_logprob
20 | else:
21 | self.encode_func = self.encode
22 |
23 | dim = 64
24 | self.ls = nn.Sequential(
25 | nn.Conv2d(3, dim, 5, 2, 2), nn.LeakyReLU(0.2),
26 | conv_ln_lrelu(dim, dim * 2),
27 | conv_ln_lrelu(dim * 2, dim * 4),
28 | conv_ln_lrelu(dim * 4, dim * 8),
29 | nn.Conv2d(dim * 8, latent_size, 4))
30 |
31 | if flow_depth > 0:
32 | # IAF
33 | hidden_size = latent_size * 2
34 | flow_layers = [flow.InverseAutoregressiveFlow(
35 | latent_size, hidden_size, latent_size)
36 | for _ in range(flow_depth)]
37 |
38 | flow_layers.append(flow.Reverse(latent_size))
39 | self.q_z_flow = flow.FlowSequential(*flow_layers)
40 | self.enc_chunk = 3
41 | else:
42 | self.q_z_flow = None
43 | self.enc_chunk = 2
44 |
45 | fc_out_size = latent_size * self.enc_chunk
46 | self.fc = nn.Sequential(
47 | nn.Linear(latent_size, fc_out_size),
48 | nn.LayerNorm(fc_out_size),
49 | nn.LeakyReLU(0.2),
50 | nn.Linear(fc_out_size, fc_out_size),
51 | )
52 |
53 | def forward(self, input, k_samples=5):
54 | return self.encode_func(input, k_samples)
55 |
56 | def encode_logprob(self, input, k_samples=5):
57 | x = self.ls(input)
58 | x = x.view(input.shape[0], -1)
59 | fc_out = self.fc(x).chunk(self.enc_chunk, dim=1)
60 | mu, logvar = fc_out[:2]
61 | std = F.softplus(logvar)
62 | qz_x = Normal(mu, std)
63 | z = qz_x.rsample([k_samples])
64 | log_q_z = qz_x.log_prob(z)
65 | if self.q_z_flow:
66 | z, log_q_z_flow = self.q_z_flow(z, context=fc_out[2])
67 | log_q_z = (log_q_z + log_q_z_flow).sum(-1)
68 | else:
69 | log_q_z = log_q_z.sum(-1)
70 | return z, log_q_z
71 |
72 | def encode(self, input, _):
73 | x = self.ls(input)
74 | x = x.view(input.shape[0], -1)
75 | fc_out = self.fc(x).chunk(self.enc_chunk, dim=1)
76 | mu, logvar = fc_out[:2]
77 | std = F.softplus(logvar)
78 | qz_x = Normal(mu, std)
79 | z = qz_x.rsample()
80 | if self.q_z_flow:
81 | z, _ = self.q_z_flow(z, context=fc_out[2])
82 | return z
83 |
--------------------------------------------------------------------------------
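A quick size check helps explain why `ConvEncoder` can flatten the convolutional output and feed it to `nn.Linear(latent_size, fc_out_size)`. The sketch below uses the standard convolution output-size formula and assumes a 64x64 input, which matches the `Resize(size=64)` transform in `celeba_fid.py` but is not confirmed here for the training data loader. The helper name `conv_out` is hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Assumed input resolution (64x64 images).
sizes = [64]
# Four stride-2 convolutions with kernel 5 and padding 2.
for _ in range(4):
    sizes.append(conv_out(sizes[-1], kernel=5, stride=2, padding=2))
# Final Conv2d(dim * 8, latent_size, 4): kernel 4, stride 1, no padding.
sizes.append(conv_out(sizes[-1], kernel=4))

print(sizes)  # [64, 32, 16, 8, 4, 1]
```

Under that assumption the last convolution yields a `latent_size x 1 x 1` map, so `x.view(input.shape[0], -1)` produces exactly `latent_size` features per image.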
/image/celeba_fid.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from torchvision import datasets, transforms
3 | from PIL import Image
4 | from fid import FID
5 |
6 |
7 | class CelebAFID(FID):
8 | def __init__(self, batch_size=256, data_name='celeba',
9 | workers=0, verbose=True):
10 | self.batch_size = batch_size
11 | self.workers = workers
12 | super().__init__(data_name, verbose)
13 |
14 | def complete_data(self):
15 | data = datasets.ImageFolder(
16 | 'celeba',
17 | transforms.Compose([
18 | transforms.CenterCrop(108),
19 | transforms.Resize(size=64, interpolation=Image.BICUBIC),
20 | transforms.ToTensor(),
21 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
22 | ]))
23 |
24 | images = len(data)
25 | data_loader = DataLoader(
26 | data, batch_size=self.batch_size, num_workers=self.workers)
27 |
28 | return data_loader, images
29 |
--------------------------------------------------------------------------------
/image/celeba_pbigan.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import grad
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torch.utils.data import DataLoader
8 | from pathlib import Path
9 | from datetime import datetime
10 | import pprint
11 | import argparse
12 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
13 | from celeba_decoder import ConvDecoder
14 | from celeba_encoder import ConvEncoder, conv_ln_lrelu
15 | from mmd import mmd
16 | from utils import mkdir, make_scheduler
17 | from visualize import Visualizer
18 |
19 |
20 | use_cuda = torch.cuda.is_available()
21 | device = torch.device('cuda' if use_cuda else 'cpu')
22 |
23 |
24 | class PBiGAN(nn.Module):
25 | def __init__(self, encoder, decoder, ae_loss='mse'):
26 | super().__init__()
27 | self.encoder = encoder
28 | self.decoder = decoder
29 | self.ae_loss = ae_loss
30 |
31 | def forward(self, x, mask, ae=True):
32 | z_T = self.encoder(x * mask)
33 |
34 | z_gen = torch.empty_like(z_T).normal_()
35 | x_gen_logit, x_gen = self.decoder(z_gen)
36 |
37 | x_logit, x_recon = self.decoder(z_T)
38 |
39 | recon_loss = 0
40 | if ae:
41 | if self.ae_loss == 'mse':
42 | recon_loss = F.mse_loss(
43 | x_recon * mask, x * mask, reduction='none') * mask
44 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
45 | elif self.ae_loss == 'l1':
46 | recon_loss = F.l1_loss(
47 | x_recon * mask, x * mask, reduction='none') * mask
48 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
49 | elif self.ae_loss == 'smooth_l1':
50 | recon_loss = F.smooth_l1_loss(
51 | x_recon * mask, x * mask, reduction='none') * mask
52 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
53 | elif self.ae_loss == 'bce':
54 | # Bernoulli noise
55 | # recon_loss: -log p(x|z)
56 | recon_loss = F.binary_cross_entropy_with_logits(
57 | x_logit * mask, x * mask, reduction='none') * mask
58 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
59 |
60 | return z_T, z_gen, x_recon, x_gen, recon_loss
61 |
62 | def impute(self, x, mask):
63 | self.eval()
64 | with torch.no_grad():
65 | z_T = self.encoder(x * mask)
66 | _, x_recon = self.decoder(z_T)
67 | self.train()
68 | return x_recon
69 |
70 |
71 | class ConvCritic(nn.Module):
72 | def __init__(self, latent_size):
73 | super().__init__()
74 |
75 | dim = 64
76 | self.conv = nn.Sequential(
77 | nn.Conv2d(3, dim, 5, 2, 2), nn.LeakyReLU(0.2),
78 | conv_ln_lrelu(dim, dim * 2),
79 | conv_ln_lrelu(dim * 2, dim * 4),
80 | conv_ln_lrelu(dim * 4, dim * 8),
81 | nn.Conv2d(dim * 8, latent_size, 4))
82 |
83 | embed_size = 64
84 |
85 | self.z_fc = nn.Sequential(
86 | nn.Linear(latent_size, embed_size),
87 | nn.LayerNorm(embed_size),
88 | nn.LeakyReLU(0.2),
89 | nn.Linear(embed_size, embed_size),
90 | )
91 |
92 | self.x_fc = nn.Linear(latent_size, embed_size)
93 |
94 | self.xz_fc = nn.Sequential(
95 | nn.Linear(embed_size * 2, embed_size),
96 | nn.LayerNorm(embed_size),
97 | nn.LeakyReLU(0.2),
98 | nn.Linear(embed_size, 1),
99 | )
100 |
101 | def forward(self, input):
102 | x, z = input
103 | x = self.conv(x)
104 | x = x.view(x.shape[0], -1)
105 | x = self.x_fc(x)
106 | z = self.z_fc(z)
107 | xz = torch.cat((x, z), 1)
108 | xz = self.xz_fc(xz)
109 | return xz.view(-1)
110 |
111 |
112 | class GradientPenalty:
113 | def __init__(self, critic, batch_size=64, gp_lambda=10):
114 | self.critic = critic
115 | self.gp_lambda = gp_lambda
116 | # Interpolation coefficient
117 | self.eps = torch.empty(batch_size, device=device)
118 | # For computing the gradient penalty
119 | self.ones = torch.ones(batch_size).to(device)
120 |
121 | def interpolate(self, real, fake):
122 | eps = self.eps.view([-1] + [1] * (len(real.shape) - 1))
123 | return (eps * real + (1 - eps) * fake).requires_grad_()
124 |
125 | def __call__(self, real, fake):
126 | real = [x.detach() for x in real]
127 | fake = [x.detach() for x in fake]
128 | self.eps.uniform_(0, 1)
129 | interp = [self.interpolate(a, b) for a, b in zip(real, fake)]
130 | grad_d = grad(self.critic(interp),
131 | interp,
132 | grad_outputs=self.ones,
133 | create_graph=True)
134 | batch_size = real[0].shape[0]
135 | grad_d = torch.cat([g.view(batch_size, -1) for g in grad_d], 1)
136 | grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda
137 | return grad_penalty
138 |
139 |
140 | def train_pbigan(args):
141 | torch.manual_seed(args.seed)
142 |
143 | if args.mask == 'indep':
144 | data = IndepMaskedCelebA(obs_prob=args.obs_prob)
145 | mask_str = f'{args.mask}_{args.obs_prob}'
146 | elif args.mask == 'block':
147 | data = BlockMaskedCelebA(block_len=args.block_len)
148 | mask_str = f'{args.mask}_{args.block_len}'
149 |
150 | data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
151 | drop_last=True)
152 | mask_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
153 | drop_last=True)
154 |
155 | test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)
156 |
157 | decoder = ConvDecoder(args.latent)
158 | encoder = ConvEncoder(args.latent, args.flow, logprob=False)
159 | pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)
160 |
161 | critic = ConvCritic(args.latent).to(device)
162 |
163 | optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))
164 |
165 | critic_optimizer = optim.Adam(
166 | critic.parameters(), lr=args.lr, betas=(.5, .9))
167 |
168 | grad_penalty = GradientPenalty(critic, args.batch_size)
169 |
170 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)
171 |
172 | path = '{}_{}_{}'.format(
173 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
174 | output_dir = Path('results') / 'celeba-pbigan' / path
175 | mkdir(output_dir)
176 | print(output_dir)
177 |
178 | if args.save_interval > 0:
179 | model_dir = mkdir(output_dir / 'model')
180 |
181 | with (output_dir / 'args.txt').open('w') as f:
182 | print(pprint.pformat(vars(args)), file=f)
183 |
184 | vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))
185 |
186 | test_x, test_mask, index = next(iter(test_loader))
187 | test_x = test_x.to(device)
188 | test_mask = test_mask.to(device).float()
189 | bbox = None
190 | if data.mask_loc is not None:
191 | bbox = [data.mask_loc[idx] for idx in index]
192 |
193 | n_critic = 5
194 | critic_updates = 0
195 | ae_weight = 0
196 |
197 | for epoch in range(args.epoch):
198 | loss_breakdown = defaultdict(float)
199 |
200 | if epoch >= args.ae_start:
201 | ae_weight = args.ae
202 |
203 | for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
204 | x = x.to(device)
205 | mask = mask.to(device).float()
206 | mask_gen = mask_gen.to(device).float()
207 |
208 | if critic_updates < n_critic:
209 | z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)
210 |
211 | real_score = critic((x * mask, z_enc)).mean()
212 | fake_score = critic((x_gen * mask_gen, z_gen)).mean()
213 |
214 | w_dist = real_score - fake_score
215 | D_loss = -w_dist + grad_penalty((x * mask, z_enc),
216 | (x_gen * mask_gen, z_gen))
217 |
218 | critic_optimizer.zero_grad()
219 | D_loss.backward()
220 | critic_optimizer.step()
221 |
222 | loss_breakdown['D'] += D_loss.item()
223 |
224 | critic_updates += 1
225 | else:
226 | critic_updates = 0
227 |
228 | # Update generators' parameters
229 | for p in critic.parameters():
230 | p.requires_grad_(False)
231 |
232 | z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(
233 | x, mask, ae=(args.ae > 0))
234 |
235 | real_score = critic((x * mask, z_enc)).mean()
236 | fake_score = critic((x_gen * mask_gen, z_gen)).mean()
237 |
238 | G_loss = real_score - fake_score
239 |
240 | ae_loss = ae_loss * ae_weight
241 | loss = G_loss + ae_loss
242 |
243 | mmd_loss = 0
244 | if args.mmd > 0:
245 | mmd_loss = mmd(z_enc, z_gen)
246 | loss += mmd_loss * args.mmd
247 |
248 | optimizer.zero_grad()
249 | loss.backward()
250 | optimizer.step()
251 |
252 | loss_breakdown['G'] += G_loss.item()
253 | if torch.is_tensor(ae_loss):
254 | loss_breakdown['AE'] += ae_loss.item()
255 | if torch.is_tensor(mmd_loss):
256 | loss_breakdown['MMD'] += mmd_loss.item()
257 | loss_breakdown['total'] += loss.item()
258 |
259 | for p in critic.parameters():
260 | p.requires_grad_(True)
261 |
262 | if scheduler:
263 | scheduler.step()
264 |
265 | vis.plot_loss(epoch, loss_breakdown)
266 |
267 | if args.plot_interval > 0 and epoch % args.plot_interval == 0:
268 | with torch.no_grad():
269 | pbigan.eval()
270 | z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
271 | pbigan.train()
272 | vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)
273 |
274 | model_dict = {
275 | 'pbigan': pbigan.state_dict(),
276 | 'critic': critic.state_dict(),
277 | 'history': vis.history,
278 | 'epoch': epoch,
279 | 'args': args,
280 | }
281 | torch.save(model_dict, str(output_dir / 'model.pth'))
282 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
283 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
284 |
285 | print(output_dir)
286 |
287 |
288 | def main():
289 | parser = argparse.ArgumentParser()
290 |
291 | parser.add_argument('--seed', type=int, default=3,
292 | help='random seed')
293 | # training options
294 | parser.add_argument('--plot-interval', type=int, default=10,
295 | help='plot interval. 0 to disable plotting.')
296 | parser.add_argument('--save-interval', type=int, default=0,
297 | help='interval to save models. 0 to disable saving.')
298 | parser.add_argument('--mask', default='block',
299 | help='missing data mask. (options: block, indep)')
300 | # option for block: set to 0 for variable size
301 | parser.add_argument('--block-len', type=int, default=32,
302 | help='size of observed block. '
303 | 'Set to 0 to use variable size')
304 | # option for indep:
305 | parser.add_argument('--obs-prob', type=float, default=.2,
306 | help='observed probability for independent dropout')
307 |
308 | parser.add_argument('--flow', type=int, default=2,
309 | help='number of IAF layers')
310 | parser.add_argument('--lr', type=float, default=2e-4,
311 | help='learning rate')
312 | parser.add_argument('--min-lr', type=float, default=5e-5,
313 | help='min learning rate for LR scheduler. '
314 | '-1 to disable annealing')
315 |
316 | parser.add_argument('--epoch', type=int, default=500,
317 | help='number of training epochs')
318 | parser.add_argument('--batch-size', type=int, default=256,
319 | help='batch size')
320 | parser.add_argument('--ae', type=float, default=.002,
321 | help='autoencoding regularization strength')
322 | parser.add_argument('--ae-start', type=int, default=0,
323 | help='start epoch of autoencoding regularization')
324 | parser.add_argument('--prefix', default='pbigan',
325 | help='prefix of output directory')
326 | parser.add_argument('--latent', type=int, default=128,
327 | help='dimension of latent variable')
328 | parser.add_argument('--aeloss', default='smooth_l1',
329 | help='autoencoding loss. '
330 | '(options: mse, bce, smooth_l1, l1)')
331 | # --mmd 0 to disable mmd regularization
332 | parser.add_argument('--mmd', type=float, default=0,
333 | help='MMD strength for latent variable')
334 |
335 | args = parser.parse_args()
336 |
337 | train_pbigan(args)
338 |
339 |
340 | if __name__ == '__main__':
341 | main()
342 |
--------------------------------------------------------------------------------
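The alternation between critic and generator updates in `train_pbigan` above is driven by the `critic_updates` counter. The following sketch replays only that control flow, without any model, to show which batches go to which update. The function name `update_schedule` is hypothetical; `n_critic=5` mirrors the value hard-coded in the training function.

```python
def update_schedule(n_batches, n_critic=5):
    """Replay the counter logic: n_critic critic steps, then one generator step."""
    schedule = []
    critic_updates = 0
    for _ in range(n_batches):
        if critic_updates < n_critic:
            schedule.append('critic')
            critic_updates += 1
        else:
            # The counter resets and this batch is spent on the encoder/decoder.
            critic_updates = 0
            schedule.append('generator')
    return schedule


schedule = update_schedule(12)
print(schedule.count('critic'), schedule.count('generator'))  # 10 2
```

Each mini-batch is used for exactly one kind of update, so the encoder and decoder are updated on every sixth batch. In the original loop the counter is initialized outside the epoch loop, so the pattern carries over across epoch boundaries.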
/image/celeba_pbigan_fid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from fid import BaseSampler, BaseImputationSampler
6 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
7 | from celeba_fid import CelebAFID
8 | from celeba_decoder import ConvDecoder
9 | from celeba_encoder import ConvEncoder
10 | from celeba_pbigan import PBiGAN
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('root_dir')
15 | parser.add_argument('--batch-size', type=int, default=256)
16 | parser.add_argument('--only', action='store_true')
17 | args = parser.parse_args()
18 |
19 |
20 | use_cuda = torch.cuda.is_available()
21 | device = torch.device('cuda' if use_cuda else 'cpu')
22 |
23 |
24 | class Sampler(BaseSampler):
25 | def __init__(self, model, latent_size, images=60000, batch_size=256):
26 | super().__init__(images)
27 | self.model = model
28 | self.rand_z = torch.empty(batch_size, latent_size, device=device)
29 |
30 | def sample(self):
31 | self.rand_z.normal_()
32 | return self.model.decoder(self.rand_z)[1]
33 |
34 |
35 | class ImputationSampler(BaseImputationSampler):
36 | def __init__(self, data_loader, model, batch_size=256):
37 | super().__init__(data_loader)
38 | self.model = model
39 |
40 | def impute(self, data, mask):
41 | z = self.model.encoder(data * mask)
42 | x_recon = self.model.decoder(z)[1]
43 | imputed_data = data * mask + x_recon * (1 - mask)
44 | return imputed_data
45 |
46 |
47 | class Data:
48 | def __init__(self, args, batch_size):
49 | self.args = args
50 | self.batch_size = batch_size
51 | self.data_loader = None
52 |
53 | def gen_data(self):
54 | args = self.args
55 | if args.mask == 'indep':
56 | data = IndepMaskedCelebA(obs_prob=args.obs_prob)
57 | elif args.mask == 'block':
58 | data = BlockMaskedCelebA(block_len=args.block_len)
59 |
60 | self.data_size = len(data)
61 | self.data_loader = DataLoader(data, batch_size=self.batch_size)
62 |
63 | def get_data(self):
64 | if self.data_loader is None:
65 | self.gen_data()
66 | return self.data_loader, self.data_size
67 |
68 |
69 | def pretrained_misgan_fid(model_file, data_loader, data_size):
70 | model = torch.load(model_file, map_location='cpu')
71 |
72 | model_args = model['args']
73 | decoder = ConvDecoder(model_args.latent)
74 | encoder = ConvEncoder(model_args.latent, model_args.flow, logprob=False)
75 | pbigan = PBiGAN(encoder, decoder, model_args.aeloss).to(device)
76 | pbigan.load_state_dict(model['pbigan'])
77 |
78 | batch_size = args.batch_size
79 |
80 | pbigan.eval()
81 | with torch.no_grad():
82 | compute_fid = CelebAFID(batch_size=batch_size)
83 | sampler = Sampler(pbigan, model_args.latent, data_size, batch_size)
84 | gen_fid = compute_fid.fid(sampler, data_size)
85 | print('fid: {:.2f}'.format(gen_fid))
86 |
87 | imputation_sampler = ImputationSampler(data_loader, pbigan, batch_size)
88 | imp_fid = compute_fid.fid(imputation_sampler, data_size)
89 | print('impute fid: {:.2f}'.format(imp_fid))
90 |
91 | return gen_fid, imp_fid
92 |
93 |
94 | def gen_fid_file(model_file, fid_file, imp_fid_file, data):
95 | if imp_fid_file.exists():
96 | print('skip')
97 | return
98 |
99 | fid, imp_fid = pretrained_misgan_fid(model_file, *data.get_data())
100 |
101 | with fid_file.open('w') as f:
102 | print(fid, file=f)
103 |
104 | if imp_fid is not None:
105 | with imp_fid_file.open('w') as f:
106 | print(imp_fid, file=f)
107 |
108 |
109 | def main():
110 | root_dir = Path(args.root_dir)
111 | model_file = root_dir / 'model.pth'
112 | print(model_file)
113 | fid_file = root_dir / 'fid.txt'
114 | imp_fid_file = root_dir / 'impute-fid.txt'
115 |
116 | model = torch.load(model_file, map_location='cpu')
117 | data = Data(model['args'], args.batch_size)
118 |
119 | gen_fid_file(model_file, fid_file, imp_fid_file, data)
120 |
121 | if args.only:
122 | return
123 |
124 | model_dir = root_dir / 'model'
125 | for model_file in sorted(model_dir.glob('*.pth')):
126 | print(model_file)
127 | fid_file = model_dir / f'{model_file.stem}-fid.txt'
128 | imp_fid_file = model_dir / f'{model_file.stem}-impute-fid.txt'
129 | gen_fid_file(model_file, fid_file, imp_fid_file, data)
130 |
131 |
132 | if __name__ == '__main__':
133 | main()
134 |
--------------------------------------------------------------------------------
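The composition step in `ImputationSampler.impute` above keeps observed pixels and fills only the missing ones with the reconstruction. A minimal NumPy sketch of that one line follows; the arrays are made-up toy values, and `compose_imputation` is a hypothetical name rather than a function from the repository.

```python
import numpy as np


def compose_imputation(data, mask, x_recon):
    """Keep observed entries (mask == 1) and fill the rest from the reconstruction."""
    return data * mask + x_recon * (1 - mask)


# Toy 2x2 "image": the mask marks the diagonal as observed.
data = np.array([[0.9, 0.1], [0.4, 0.7]])
mask = np.array([[1.0, 0.0], [0.0, 1.0]])
# Stand-in for the decoder output.
x_recon = np.full_like(data, 0.5)

imputed = compose_imputation(data, mask, x_recon)
print(imputed)
```

Because observed entries pass through unchanged, the imputation FID reflects the quality of the filled-in region only in combination with the real observed pixels.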
/image/celeba_pvae.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.distributions.normal import Normal
7 | from torch.distributions.categorical import Categorical
8 | from torch.utils.data import DataLoader
9 | import math
10 | import sys
11 | import logging
12 | from pathlib import Path
13 | from datetime import datetime
14 | import pprint
15 | import argparse
16 | from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
17 | from celeba_decoder import ConvDecoder
18 | from celeba_encoder import ConvEncoder
19 | from utils import mkdir, make_scheduler
20 | from visualize import Visualizer
21 |
22 |
23 | use_cuda = torch.cuda.is_available()
24 | device = torch.device('cuda' if use_cuda else 'cpu')
25 |
26 |
27 | class PVAE(nn.Module):
28 | def __init__(self, encoder, decoder):
29 | super().__init__()
30 | self.encoder = encoder
31 | self.decoder = decoder
32 |
33 | def forward(self, x, mask, k_samples=5, kl_weight=1, ae_weight=1):
34 | z_T, log_q_z = self.encoder(x * mask, k_samples)
35 |
36 | pz = Normal(torch.zeros_like(z_T), torch.ones_like(z_T))
37 | log_p_z = pz.log_prob(z_T).sum(-1)
38 | # kl_loss: log q(z|x) - log p(z)
39 | kl_loss = log_q_z - log_p_z
40 |
41 | # Reshape z to accommodate modules with strict input shape requirements
42 | # such as convolutional layers.
43 | x_logit, x_recon = self.decoder(z_T.view(-1, *z_T.shape[2:]))
44 | expanded_mask = mask[None]
45 | masked_logit = x_logit.view(k_samples, *x.shape) * expanded_mask
46 | masked_x = (x * mask)[None].expand_as(masked_logit)
47 | # Bernoulli noise
48 | bce = F.binary_cross_entropy_with_logits(
49 | masked_logit, masked_x, reduction='none')
50 | recon_loss = (bce * expanded_mask).sum((2, 3, 4))
51 |
52 | # elbo = log p(x|z) + log p(z) - log q(z|x)
53 | elbo = -(recon_loss * ae_weight + kl_loss * kl_weight)
54 |
55 | # IWAE loss: -log E[p(x|z) p(z) / q(z|x)]
56 | # Here we ignore the constant shift of -log(k_samples)
57 | loss = -elbo.logsumexp(0).mean()
58 |
59 | x_recon = x_recon.view(-1, *x.shape)
60 | loss_breakdown = {
61 | 'loss': loss.item(),
62 | 'KL': kl_loss.mean().item(),
63 | 'recon': recon_loss.mean().item(),
64 | }
65 | return loss, z_T, x_recon, elbo, loss_breakdown
66 |
67 | def impute(self, x, mask, k_samples=10):
68 | self.eval()
69 | with torch.no_grad():
70 | _, z, x_recon, elbo, _ = self(x, mask, k_samples)
71 | # sampling importance resampling
72 | is_idx = Categorical(logits=elbo.t()).sample()
73 | batch_idx = torch.arange(len(x))
74 | z = z[is_idx, batch_idx]
75 | x_recon = x_recon[is_idx, batch_idx]
76 | self.train()
77 | return x_recon
78 |
79 |
80 | def train_pvae(args):
81 | torch.manual_seed(args.seed)
82 |
83 | if args.mask == 'indep':
84 | data = IndepMaskedCelebA(obs_prob=args.obs_prob)
85 | mask_str = f'{args.mask}_{args.obs_prob}'
86 | elif args.mask == 'block':
87 | data = BlockMaskedCelebA(block_len=args.block_len)
88 | mask_str = f'{args.mask}_{args.block_len}'
89 |
90 | data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
91 | drop_last=True)
92 |
93 | test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)
94 |
95 | decoder = ConvDecoder(args.latent)
96 | encoder = ConvEncoder(args.latent, args.flow, logprob=True)
97 | pvae = PVAE(encoder, decoder).to(device)
98 |
99 | optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
100 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)
101 |
102 | rand_z = torch.empty(args.batch_size, args.latent, device=device)
103 |
104 | path = '{}_{}_{}'.format(
105 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
106 | output_dir = Path('results') / 'celeba-pvae' / path
107 | mkdir(output_dir)
108 | print(output_dir)
109 |
110 | if args.save_interval > 0:
111 | model_dir = mkdir(output_dir / 'model')
112 |
113 | logging.basicConfig(
114 | level=logging.INFO,
115 | format='%(asctime)s %(message)s',
116 | datefmt='%Y-%m-%d %H:%M:%S',
117 | handlers=[
118 | logging.FileHandler(output_dir / 'log.txt'),
119 | logging.StreamHandler(sys.stdout),
120 | ],
121 | )
122 |
123 | with (output_dir / 'args.txt').open('w') as f:
124 | print(pprint.pformat(vars(args)), file=f)
125 |
126 | vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))
127 |
128 | test_x, test_mask, index = next(iter(test_loader))
129 | test_x = test_x.to(device)
130 | test_mask = test_mask.to(device).float()
131 | bbox = None
132 | if data.mask_loc is not None:
133 | bbox = [data.mask_loc[idx] for idx in index]
134 |
135 | kl_center = (args.kl_on + args.kl_off) / 2
136 | kl_scale = 1 / min(args.kl_on - args.kl_off, 1)
137 |
138 | for epoch in range(args.epoch):
139 | if epoch >= args.kl_on:
140 | kl_weight = 1
141 | elif epoch < args.kl_off:
142 | kl_weight = 0
143 | else:
144 | kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))
145 | loss_breakdown = defaultdict(float)
146 | for x, mask, _ in data_loader:
147 | x = x.to(device)
148 | mask = mask.to(device).float()
149 |
150 | optimizer.zero_grad()
151 | loss, _, _, _, loss_info = pvae(
152 | x, mask, args.k, kl_weight, args.ae)
153 | loss.backward()
154 | optimizer.step()
155 | for name, val in loss_info.items():
156 | loss_breakdown[name] += val
157 |
158 | if scheduler:
159 | scheduler.step()
160 |
161 | vis.plot_loss(epoch, loss_breakdown)
162 |
163 | if args.plot_interval > 0 and epoch % args.plot_interval == 0:
164 | x_recon = pvae.impute(test_x, test_mask, args.k)
165 | with torch.no_grad():
166 | pvae.eval()
167 | rand_z.normal_()
168 | _, x_gen = decoder(rand_z)
169 | pvae.train()
170 | vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen)
171 |
172 | model_dict = {
173 | 'pvae': pvae.state_dict(),
174 | 'history': vis.history,
175 | 'epoch': epoch,
176 | 'args': args,
177 | }
178 | torch.save(model_dict, str(output_dir / 'model.pth'))
179 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
180 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
181 |
182 | print(output_dir)
183 |
184 |
185 | def main():
186 | parser = argparse.ArgumentParser()
187 |
188 | parser.add_argument('--seed', type=int, default=3,
189 | help='random seed')
190 | # training options
191 | parser.add_argument('--plot-interval', type=int, default=50,
192 | help='plot interval. 0 to disable plotting.')
193 | parser.add_argument('--save-interval', type=int, default=50,
194 | help='interval to save models. 0 to disable saving.')
195 | parser.add_argument('--mask', default='block',
196 | help='missing data mask. (options: block, indep)')
197 | # option for block: set to 0 for variable size
198 | parser.add_argument('--block-len', type=int, default=32,
199 | help='size of observed block. '
200 | 'Set to 0 to use variable size')
201 | # option for indep:
202 | parser.add_argument('--obs-prob', type=float, default=.2,
203 | help='observed probability for independent dropout')
204 |
205 | parser.add_argument('--flow', type=int, default=2,
206 | help='number of IAF layers')
207 | parser.add_argument('--lr', type=float, default=1e-4,
208 | help='learning rate')
209 | parser.add_argument('--min-lr', type=float, default=-1,
210 | help='min learning rate for LR scheduler. '
211 | '-1 to disable annealing')
212 |
213 | parser.add_argument('--epoch', type=int, default=500,
214 | help='number of training epochs')
215 | parser.add_argument('--batch-size', type=int, default=256,
216 | help='batch size')
217 | parser.add_argument('--k', type=int, default=5,
218 | help='number of importance weights')
219 | parser.add_argument('--prefix', default='pvae',
220 | help='prefix of output directory')
221 | parser.add_argument('--latent', type=int, default=128,
222 | help='dimension of latent variable')
223 | parser.add_argument('--kl-off', type=int, default=10,
224 | help='epoch at which the KL weight starts ramping up from zero')
225 | # set --kl-on to 0 to use constant kl_weight = 1
226 | parser.add_argument('--kl-on', type=int, default=20,
227 | help='first epoch at which the KL weight is fixed at 1')
228 | parser.add_argument('--ae', type=float, default=1,
229 | help='log-likelihood weight for ELBO')
230 |
231 | args = parser.parse_args()
232 |
233 | train_pvae(args)
234 |
235 |
236 | if __name__ == '__main__':
237 | main()
238 |
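The KL warm-up used in train_pvae can be checked in isolation. The following is a minimal sketch and not part of the repository: kl_weight_at is a hypothetical helper name, and it assumes the same --kl-off/--kl-on semantics as the training loop (weight 0 before kl_off, weight 1 from kl_on onward, a logistic ramp in between).

```python
import math


def kl_weight_at(epoch, kl_off=10, kl_on=20):
    """Hypothetical helper mirroring the KL-weight branches in train_pvae."""
    if epoch >= kl_on:
        return 1.0
    if epoch < kl_off:
        return 0.0
    # Logistic ramp centred halfway between kl_off and kl_on.
    kl_center = (kl_on + kl_off) / 2
    # Same expression as in the training loop. For integer epochs with
    # kl_on > kl_off it evaluates to 1.
    kl_scale = 1 / min(kl_on - kl_off, 1)
    return 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))
```

With the default arguments, kl_weight_at(15) sits at the centre of the ramp and returns 0.5, while epochs below 10 give 0.0 and epochs from 20 onward give 1.0.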
--------------------------------------------------------------------------------
/image/fid.py:
--------------------------------------------------------------------------------
1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs
2 |
3 | **Code adapted from https://github.com/mseitzer/pytorch-fid**
4 |
5 | The FID metric calculates the distance between two distributions of images.
6 | Typically, we have summary statistics (mean & covariance matrix) of one
7 | of these distributions, while the 2nd distribution is given by a GAN.
8 |
9 | When run as a stand-alone program, it compares the distribution of
10 | images that are stored as PNG/JPEG at a specified location with a
11 | distribution given by summary statistics (in pickle format).
12 |
13 | The FID is calculated by assuming that X_1 and X_2 are the activations of
14 | the pool_3 layer of the inception net for generated samples and real world
15 | samples respectively.
16 |
17 | See --help to see further details.
18 |
19 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
20 | of TensorFlow
21 |
22 | Copyright 2018 Institute of Bioinformatics, JKU Linz
23 |
24 | Licensed under the Apache License, Version 2.0 (the "License");
25 | you may not use this file except in compliance with the License.
26 | You may obtain a copy of the License at
27 |
28 | http://www.apache.org/licenses/LICENSE-2.0
29 |
30 | Unless required by applicable law or agreed to in writing, software
31 | distributed under the License is distributed on an "AS IS" BASIS,
32 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33 | See the License for the specific language governing permissions and
34 | limitations under the License.
35 | """
36 |
37 | from pathlib import Path
38 | import torch
39 | import numpy as np
40 | from scipy import linalg
41 | import time
42 | import sys
43 | from inception import InceptionV3
44 |
45 |
46 | use_cuda = torch.cuda.is_available()
47 | device = torch.device('cuda' if use_cuda else 'cpu')
48 |
49 | FEATURE_DIM = 2048
50 | RESIZE = 299
51 |
52 |
53 | def get_activations(image_iterator, images, model, verbose=True):
54 | """Calculates the activations of the pool_3 layer for all images.
55 |
56 | Params:
57 | -- image_iterator
58 | : A generator that generates a batch of images at a time.
59 | -- images : Number of images that will be generated by
60 | image_iterator.
61 | -- model : Instance of inception model
62 | -- verbose : If set to True, progress is printed after each batch
63 | (only when stdout is a terminal).
64 | Returns:
65 | -- A numpy array of dimension (num images, dims) that contains the
66 | pool_3 activations of the inception model for the images
67 | produced by image_iterator.
68 | """
69 | model.eval()
70 |
71 | if not sys.stdout.isatty():
72 | verbose = False
73 |
74 | pred_arr = np.empty((images, FEATURE_DIM))
75 | end = 0
76 | t0 = time.time()
77 |
78 | for batch in image_iterator:
79 | if not isinstance(batch, torch.Tensor):
80 | batch = batch[0]
81 | start = end
82 | batch_size = batch.shape[0]
83 | end = start + batch_size
84 |
85 | with torch.no_grad():
86 | batch = batch.to(device)
87 | pred = model(batch)[0]
88 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1)
89 | pred_arr[start:end] = batch_feature
90 |
91 | if verbose:
92 | print('\rProcessed: {} time: {:.2f}'.format(
93 | end, time.time() - t0), end='', flush=True)
94 |
95 | assert end == images
96 |
97 | if verbose:
98 | print(' done')
99 |
100 | return pred_arr
101 |
102 |
103 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
104 | """Numpy implementation of the Frechet Distance.
105 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
106 | and X_2 ~ N(mu_2, C_2) is
107 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
108 |
109 | Stable version by Dougal J. Sutherland.
110 |
111 | Params:
112 | -- mu1 : Numpy array containing the sample mean over activations of a
113 | layer of the inception net (as returned by
114 | 'calculate_activation_statistics') for generated samples.
115 | -- mu2 : The sample mean over activations, precalculated on a
116 | representative data set.
117 | -- sigma1: The covariance matrix over activations for generated samples.
118 | -- sigma2: The covariance matrix over activations, precalculated on a
119 | representative data set.
120 |
121 | Returns:
122 | -- : The Frechet Distance.
123 | """
124 |
125 | mu1 = np.atleast_1d(mu1)
126 | mu2 = np.atleast_1d(mu2)
127 |
128 | sigma1 = np.atleast_2d(sigma1)
129 | sigma2 = np.atleast_2d(sigma2)
130 |
131 | assert mu1.shape == mu2.shape, \
132 | 'Training and test mean vectors have different lengths'
133 | assert sigma1.shape == sigma2.shape, \
134 | 'Training and test covariances have different dimensions'
135 |
136 | diff = mu1 - mu2
137 |
138 | # Product might be almost singular
139 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
140 | if not np.isfinite(covmean).all():
141 | msg = ('fid calculation produces singular product; '
142 | 'adding %s to diagonal of cov estimates') % eps
143 | print(msg)
144 | offset = np.eye(sigma1.shape[0]) * eps
145 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
146 |
147 | # Numerical error might give slight imaginary component
148 | if np.iscomplexobj(covmean):
149 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
150 | m = np.max(np.abs(covmean.imag))
151 | raise ValueError('Imaginary component {}'.format(m))
152 | covmean = covmean.real
153 |
154 | tr_covmean = np.trace(covmean)
155 |
156 | return (diff.dot(diff) + np.trace(sigma1) +
157 | np.trace(sigma2) - 2 * tr_covmean)
158 |
159 |
160 | def calculate_activation_statistics(image_iterator, images, model,
161 | verbose=False):
162 | """Calculation of the statistics used by the FID.
163 | Params:
164 | -- image_iterator
165 | : A generator that generates a batch of images at a time.
166 | -- images : Number of images that will be generated by
167 | image_iterator.
168 | -- model : Instance of inception model
169 | -- verbose : If set to True, progress is printed after each batch
170 | (only when stdout is a terminal).
171 | Returns:
172 | -- mu : The mean over samples of the activations of the pool_3 layer of
173 | the inception model.
174 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
175 | the inception model.
176 | """
177 | act = get_activations(image_iterator, images, model, verbose)
178 | mu = np.mean(act, axis=0)
179 | sigma = np.cov(act, rowvar=False)
180 | return mu, sigma
181 |
182 |
183 | class FID:
184 | def __init__(self, data_name, verbose=True):
185 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM]
186 | model = InceptionV3([block_idx], RESIZE).to(device)
187 | self.verbose = verbose
188 |
189 | stats_dir = Path('fid_stats')
190 | stats_file = stats_dir / '{}_act_{}_{}.npz'.format(
191 | data_name, FEATURE_DIM, RESIZE)
192 |
193 | try:
194 | f = np.load(str(stats_file))
195 | mu, sigma = f['mu'], f['sigma']
196 | f.close()
197 | except FileNotFoundError:
198 | data_loader, images = self.complete_data()
199 | mu, sigma = calculate_activation_statistics(
200 | data_loader, images, model, verbose)
201 | stats_dir.mkdir(parents=True, exist_ok=True)
202 | np.savez(stats_file, mu=mu, sigma=sigma)
203 |
204 | self.model = model
205 | self.stats = mu, sigma
206 |
207 | def complete_data(self):
208 | raise NotImplementedError
209 |
210 | def fid(self, image_iterator, images):
211 | mu, sigma = calculate_activation_statistics(
212 | image_iterator, images, self.model, verbose=self.verbose)
213 | return calculate_frechet_distance(mu, sigma, *self.stats)
214 |
215 |
216 | class BaseSampler:
217 | def __init__(self, images):
218 | self.images = images
219 |
220 | def __iter__(self):
221 | self.n = 0
222 | return self
223 |
224 | def __next__(self):
225 | if self.n < self.images:
226 | batch = self.sample()
227 | batch_size = batch.shape[0]
228 | self.n += batch_size
229 | if self.n > self.images:
230 | return batch[:-(self.n - self.images)]
231 | return batch
232 | else:
233 | raise StopIteration
234 |
235 | def sample(self):
236 | raise NotImplementedError
237 |
238 |
239 | class BaseImputationSampler:
240 | def __init__(self, data_loader):
241 | self.data_loader = data_loader
242 |
243 | def __iter__(self):
244 | self.data_iter = iter(self.data_loader)
245 | return self
246 |
247 | def __next__(self):
248 | data, mask = next(self.data_iter)[:2]
249 | data = data.to(device)
250 | mask = mask.float().to(device)
251 | imputed_data = self.impute(data, mask)
252 | return mask * data + (1 - mask) * imputed_data
253 |
254 | def impute(self, data, mask):
255 | raise NotImplementedError
256 |
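The formula in the calculate_frechet_distance docstring, d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)), is easy to verify by hand when both covariance matrices are diagonal. The sketch below covers only that special case and is not a substitute for the scipy-based routine in fid.py; frechet_distance_diag is a hypothetical name, and var1/var2 are assumed to hold the diagonal entries of C_1 and C_2.

```python
import math


def frechet_distance_diag(mu1, var1, mu2, var2):
    """Frechet distance between Gaussians with diagonal covariances.

    For diagonal C_1 and C_2, sqrt(C_1*C_2) is diagonal too, so the trace
    term reduces to a sum over per-dimension variances.
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    trace_term = sum(
        v1 + v2 - 2 * math.sqrt(v1 * v2) for v1, v2 in zip(var1, var2))
    return mean_term + trace_term
```

For example, in one dimension with mu_1 = 0, C_1 = 1, mu_2 = 3, C_2 = 4 the distance is 9 + (1 + 4 - 2*2) = 10, and two identical Gaussians give 0.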
--------------------------------------------------------------------------------
/image/flow.py:
--------------------------------------------------------------------------------
1 | """Code adapted from https://github.com/altosaar/variational-autoencoder"""
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class InverseAutoregressiveFlow(nn.Module):
9 | """Inverse Autoregressive Flows with LSTM-type update. One block.
10 |
11 | Eq 11-14 of https://arxiv.org/abs/1606.04934
12 | """
13 | def __init__(self, num_input, num_hidden, num_context):
14 | super().__init__()
15 | self.made = MADE(num_input=num_input, num_output=num_input * 2,
16 | num_hidden=num_hidden, num_context=num_context)
17 | # init such that sigmoid(s) is close to 1 for stability
18 | self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2)
19 | self.sigmoid = nn.Sigmoid()
20 | self.log_sigmoid = nn.LogSigmoid()
21 |
22 | def forward(self, input, context=None):
23 | m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1)
24 | s = s + self.sigmoid_arg_bias
25 | sigmoid = self.sigmoid(s)
26 | z = sigmoid * input + (1 - sigmoid) * m
27 | return z, -self.log_sigmoid(s)
28 |
29 |
30 | class FlowSequential(nn.Sequential):
31 | """Applies flow blocks in order, summing their log-prob terms."""
32 |
33 | def forward(self, input, context=None):
34 | total_log_prob = torch.zeros_like(input)
35 | for block in self._modules.values():
36 | input, log_prob = block(input, context)
37 | total_log_prob += log_prob
38 | return input, total_log_prob
39 |
40 |
41 | class MaskedLinear(nn.Module):
42 | """Linear layer with some input-output connections masked."""
43 | def __init__(self, in_features, out_features, mask, context_features=None,
44 | bias=True):
45 | super().__init__()
46 | self.linear = nn.Linear(in_features, out_features, bias)
47 | self.register_buffer("mask", mask)
48 | if context_features is not None:
49 | self.cond_linear = nn.Linear(context_features, out_features,
50 | bias=False)
51 |
52 | def forward(self, input, context=None):
53 | output = F.linear(input, self.mask * self.linear.weight,
54 | self.linear.bias)
55 | if context is None:
56 | return output
57 | else:
58 | return output + self.cond_linear(context)
59 |
60 |
61 | class MADE(nn.Module):
62 | """Implements MADE: Masked Autoencoder for Distribution Estimation.
63 |
64 | Follows https://arxiv.org/abs/1502.03509
65 |
66 | This is used to build MAF:
67 | Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057).
68 | """
69 | def __init__(self, num_input, num_output, num_hidden, num_context):
70 | super().__init__()
71 | # m corresponds to m(k), the maximum degree of a node in the MADE paper
72 | self._m = []
73 | self._masks = []
74 | self._build_masks(num_input, num_output, num_hidden, num_layers=3)
75 | self._check_masks()
76 | modules = []
77 | self.input_context_net = MaskedLinear(
78 | num_input, num_hidden, self._masks[0], num_context)
79 | modules.append(nn.ReLU())
80 | modules.append(MaskedLinear(
81 | num_hidden, num_hidden, self._masks[1], context_features=None))
82 | modules.append(nn.ReLU())
83 | modules.append(MaskedLinear(
84 | num_hidden, num_output, self._masks[2], context_features=None))
85 | self.net = nn.Sequential(*modules)
86 |
87 | def _build_masks(self, num_input, num_output, num_hidden, num_layers):
88 | """Build the masks according to Eq 12 and 13 in the MADE paper."""
89 | rng = np.random.RandomState(0)
90 | # assign input units a number between 1 and D
91 | self._m.append(np.arange(1, num_input + 1))
92 | for i in range(1, num_layers + 1):
93 | # randomly assign maximum number of input nodes to connect to
94 | if i == num_layers:
95 | # assign output layer units a number between 1 and D
96 | m = np.arange(1, num_input + 1)
97 | assert num_output % num_input == 0, (
98 | "num_output must be multiple of num_input")
99 | self._m.append(np.hstack(
100 | [m for _ in range(num_output // num_input)]))
101 | else:
102 | # assign hidden layer units a number between 1 and D-1
103 | self._m.append(rng.randint(1, num_input, size=num_hidden))
104 | # self._m.append(
105 | # np.arange(1, num_hidden + 1) % (num_input - 1) + 1)
106 | if i == num_layers:
107 | mask = self._m[i][None, :] > self._m[i - 1][:, None]
108 | else:
109 | # input to hidden & hidden to hidden
110 | mask = self._m[i][None, :] >= self._m[i - 1][:, None]
111 | # need to transpose for torch linear layer.
112 | # shape (num_output, num_input)
113 | self._masks.append(torch.from_numpy(mask.astype(np.float32).T))
114 |
115 | def _check_masks(self):
116 | """Check that the connectivity matrix between layers is lower
117 | triangular."""
118 | # (num_input, num_hidden)
119 | prev = self._masks[0].t()
120 | for i in range(1, len(self._masks)):
121 | # num_hidden is second axis
122 | prev = prev @ self._masks[i].t()
123 | final = prev.numpy()
124 | num_input = self._masks[0].shape[1]
125 | num_output = self._masks[-1].shape[0]
126 | assert final.shape == (num_input, num_output)
127 | if num_output == num_input:
128 | assert np.triu(final).all() == 0
129 | else:
130 | for submat in np.split(
131 | final, indices_or_sections=num_output // num_input,
132 | axis=1):
133 | assert np.triu(submat).all() == 0
134 |
135 | def forward(self, input, context=None):
136 | # first hidden layer receives input and context
137 | hidden = self.input_context_net(input, context)
138 | # rest of the network is conditioned on both input and context
139 | return self.net(hidden)
140 |
141 |
142 | class Reverse(nn.Module):
143 | """ An implementation of a reversing layer from
144 | Density estimation using Real NVP
145 | (https://arxiv.org/abs/1605.08803).
146 |
147 | From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py
148 | """
149 |
150 | def __init__(self, num_input):
151 | super(Reverse, self).__init__()
152 | self.perm = np.array(np.arange(0, num_input)[::-1])
153 | self.inv_perm = np.argsort(self.perm)
154 |
155 | def forward(self, inputs, context=None, mode='forward'):
156 | if mode == "forward":
157 | return inputs[..., self.perm], torch.zeros_like(inputs)
158 | elif mode == "inverse":
159 | return inputs[..., self.inv_perm], torch.zeros_like(inputs)
160 | else:
161 | raise ValueError("Mode must be one of {forward, inverse}.")
162 |
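The degree-based masks built in MADE._build_masks can be illustrated without numpy or torch. The sketch below is a simplification under stated assumptions: it uses a single hidden layer instead of the two in flow.py, stores masks as (previous layer, next layer) nested lists rather than transposed tensors, and all function names are hypothetical. It keeps the same rules as the code above: ">=" for input-to-hidden connections and strict ">" for hidden-to-output connections.

```python
import random


def build_made_masks(num_input, num_hidden, seed=0):
    """Build input->hidden and hidden->output masks from unit degrees."""
    rng = random.Random(seed)
    in_deg = list(range(1, num_input + 1))
    # Hidden degrees lie in 1..D-1 (random.randint is inclusive).
    hid_deg = [rng.randint(1, num_input - 1) for _ in range(num_hidden)]
    out_deg = list(range(1, num_input + 1))
    # mask[i][j] == 1 means unit i of the previous layer feeds unit j.
    in_to_hid = [[int(h >= d) for h in hid_deg] for d in in_deg]
    hid_to_out = [[int(o > h) for o in out_deg] for h in hid_deg]
    return in_to_hid, hid_to_out


def connectivity(in_to_hid, hid_to_out):
    """Count paths from each input to each output (a matrix product)."""
    inner = len(hid_to_out)
    cols = len(hid_to_out[0])
    return [[sum(row[k] * hid_to_out[k][j] for k in range(inner))
             for j in range(cols)] for row in in_to_hid]


def is_autoregressive(conn):
    """Output j may depend only on inputs with a strictly smaller index."""
    return all(conn[i][j] == 0
               for i in range(len(conn))
               for j in range(len(conn[0])) if i >= j)
```

A path from input i to output j requires a hidden unit whose degree is at least i + 1 and strictly below j + 1, which is only possible when i < j; is_autoregressive checks exactly that on the path-count matrix.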
--------------------------------------------------------------------------------
/image/inception.py:
--------------------------------------------------------------------------------
1 | """
2 | Code adapted from https://github.com/pytorch/vision
3 | """
4 |
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torchvision import models
8 |
9 |
10 | class InceptionV3(nn.Module):
11 | """Pretrained InceptionV3 network returning feature maps"""
12 |
13 | # Index of default block of inception to return,
14 | # corresponds to output of final average pooling
15 | DEFAULT_BLOCK_INDEX = 3
16 |
17 | # Maps feature dimensionality to their output blocks indices
18 | BLOCK_INDEX_BY_DIM = {
19 | 64: 0, # First max pooling features
20 | 192: 1, # Second max pooling features
21 | 768: 2, # Pre-aux classifier features
22 | 2048: 3 # Final average pooling features
23 | }
24 |
25 | def __init__(self,
26 | output_blocks=[DEFAULT_BLOCK_INDEX],
27 | resize_input=299, # -1: not resize
28 | normalize_input=True,
29 | requires_grad=False):
30 | """Build pretrained InceptionV3
31 |
32 | Parameters
33 | ----------
34 | output_blocks : list of int
35 | Indices of blocks to return features of. Possible values are:
36 | - 0: corresponds to output of first max pooling
37 | - 1: corresponds to output of second max pooling
38 | - 2: corresponds to output which is fed to aux classifier
39 | - 3: corresponds to output of final average pooling
40 | resize_input : int
41 | If positive, bilinearly resizes input to this width and height
42 | before feeding it to the model; -1 disables resizing. As the
43 | network without fully connected layers is fully convolutional, it
44 | should handle inputs of arbitrary size, so resizing is optional
45 | normalize_input : bool
46 | If true, normalizes the input to the statistics the pretrained
47 | Inception network expects
48 | requires_grad : bool
49 | If true, parameters of the model require gradient. Possibly useful
50 | for finetuning the network
51 | """
52 | super(InceptionV3, self).__init__()
53 |
54 | self.resize_input = resize_input
55 | self.normalize_input = normalize_input
56 | self.output_blocks = sorted(output_blocks)
57 | self.last_needed_block = max(output_blocks)
58 |
59 | assert self.last_needed_block <= 3, \
60 | 'Last possible output block index is 3'
61 |
62 | self.blocks = nn.ModuleList()
63 |
64 | inception = models.inception_v3(pretrained=True)
65 |
66 | # Block 0: input to maxpool1
67 | block0 = [
68 | inception.Conv2d_1a_3x3,
69 | inception.Conv2d_2a_3x3,
70 | inception.Conv2d_2b_3x3,
71 | nn.MaxPool2d(kernel_size=3, stride=2)
72 | ]
73 | self.blocks.append(nn.Sequential(*block0))
74 |
75 | # Block 1: maxpool1 to maxpool2
76 | if self.last_needed_block >= 1:
77 | block1 = [
78 | inception.Conv2d_3b_1x1,
79 | inception.Conv2d_4a_3x3,
80 | nn.MaxPool2d(kernel_size=3, stride=2)
81 | ]
82 | self.blocks.append(nn.Sequential(*block1))
83 |
84 | # Block 2: maxpool2 to aux classifier
85 | if self.last_needed_block >= 2:
86 | block2 = [
87 | inception.Mixed_5b,
88 | inception.Mixed_5c,
89 | inception.Mixed_5d,
90 | inception.Mixed_6a,
91 | inception.Mixed_6b,
92 | inception.Mixed_6c,
93 | inception.Mixed_6d,
94 | inception.Mixed_6e,
95 | ]
96 | self.blocks.append(nn.Sequential(*block2))
97 |
98 | # Block 3: aux classifier to final avgpool
99 | if self.last_needed_block >= 3:
100 | block3 = [
101 | inception.Mixed_7a,
102 | inception.Mixed_7b,
103 | inception.Mixed_7c,
104 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
105 | ]
106 | self.blocks.append(nn.Sequential(*block3))
107 |
108 | for param in self.parameters():
109 | param.requires_grad = requires_grad
110 |
111 | def forward(self, inp):
112 | """Get Inception feature maps
113 |
114 | Parameters
115 | ----------
116 | inp : torch.Tensor
117 | Input tensor of shape Bx3xHxW. Values are expected to be in
118 | range (0, 1)
119 |
120 | Returns
121 | -------
122 | List of torch.Tensor, corresponding to the selected output
123 | blocks, sorted ascending by index
124 | """
125 | outp = []
126 | x = inp
127 |
128 | if self.resize_input > 0:
129 | # size = 299
130 | x = F.interpolate(x, size=(self.resize_input, self.resize_input),
131 | mode='bilinear', align_corners=True)
132 |
133 | if self.normalize_input:
134 | x = x.clone()
135 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
136 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
137 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
138 |
139 | for idx, block in enumerate(self.blocks):
140 | x = block(x)
141 | if idx in self.output_blocks:
142 | outp.append(x)
143 |
144 | if idx == self.last_needed_block:
145 | break
146 |
147 | return outp
148 |
--------------------------------------------------------------------------------
/image/masked_celeba.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision import datasets, transforms
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | class MaskedCelebA(datasets.ImageFolder):
8 | def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0):
9 | transform = transforms.Compose([
10 | transforms.CenterCrop(108),
11 | transforms.Resize(size=image_size, interpolation=Image.BICUBIC),
12 | transforms.ToTensor(),
13 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
14 | ])
15 |
16 | super().__init__(data_dir, transform)
17 |
18 | self.rnd = np.random.RandomState(random_seed)
19 | torch.manual_seed(random_seed)
20 | self.image_size = image_size
21 | self.cache_images()
22 | self.generate_masks()
23 | self.mask = torch.stack(self.mask)
24 |
25 | def cache_images(self):
26 | images = []
27 | for i in range(len(self)):
28 | image, _ = super().__getitem__(i)
29 | images.append(image)
30 | self.images = torch.stack(images)
31 |
32 | def __getitem__(self, index):
33 | return self.images[index], self.mask[index], index
34 |
35 | def __len__(self):
36 | return super().__len__()
37 |
38 |
39 | class BlockMaskedCelebA(MaskedCelebA):
40 | def __init__(self, block_len=None, *args, **kwargs):
41 | self.block_len = block_len
42 | super().__init__(*args, **kwargs)
43 |
44 | def generate_masks(self):
45 | d0_len = d1_len = self.image_size
46 | d0_min_len = 12
47 | d0_max_len = d0_len - d0_min_len
48 | d1_min_len = 12
49 | d1_max_len = d1_len - d1_min_len
50 |
51 | n_masks = len(self)
52 | self.mask = [None] * n_masks
53 | self.mask_loc = [None] * n_masks
54 | for i in range(n_masks):
55 | if self.block_len == 0:
56 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len)
57 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len)
58 | else:
59 | d0_mask_len = d1_mask_len = self.block_len
60 |
61 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1)
62 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1)
63 |
64 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8)
65 | mask[d0_start:(d0_start + d0_mask_len),
66 | d1_start:(d1_start + d1_mask_len)] = 1
67 | self.mask[i] = mask[None]
68 | self.mask_loc[i] = d0_start, d1_start, d0_mask_len, d1_mask_len
69 |
70 |
71 | class IndepMaskedCelebA(MaskedCelebA):
72 | def __init__(self, obs_prob=.2, obs_prob_max=None, *args, **kwargs):
73 | self.prob = obs_prob
74 | self.prob_max = obs_prob_max
75 | self.mask_loc = None
76 | super().__init__(*args, **kwargs)
77 |
78 | def generate_masks(self):
79 | imsize = self.image_size
80 | prob = self.prob
81 | prob_max = self.prob_max
82 | n_masks = len(self)
83 | self.mask = [None] * n_masks
84 | for i in range(n_masks):
85 | if prob_max is None:
86 | p = prob
87 | else:
88 | p = self.rnd.uniform(prob, prob_max)
89 | self.mask[i] = torch.ByteTensor(1, imsize, imsize).bernoulli_(p)
90 |
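The fixed-size branch of BlockMaskedCelebA.generate_masks can be sketched with the standard library alone. This is an illustration rather than the repository's code: make_block_mask is a hypothetical name, the mask is a nested list instead of a uint8 tensor, and random.Random stands in for numpy's RandomState (note that random.randint includes its upper bound, whereas numpy's randint excludes it). As in the classes above, 1 marks an observed pixel.

```python
import random


def make_block_mask(image_size, block_len, rng):
    """Return (mask, location) with one block_len x block_len observed block."""
    # Start positions range over 0..image_size - block_len inclusive,
    # so the block always fits inside the image.
    row_start = rng.randint(0, image_size - block_len)
    col_start = rng.randint(0, image_size - block_len)
    mask = [[0] * image_size for _ in range(image_size)]
    for r in range(row_start, row_start + block_len):
        for c in range(col_start, col_start + block_len):
            mask[r][c] = 1
    return mask, (row_start, col_start, block_len, block_len)
```

Calling make_block_mask(64, 32, random.Random(0)) yields a 64x64 mask with exactly 32 * 32 observed entries, plus the same (start, start, length, length) tuple layout that mask_loc stores.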
--------------------------------------------------------------------------------
/image/masked_mnist.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torchvision import datasets, transforms
4 | import numpy as np
5 |
6 |
7 | class MaskedMNIST(Dataset):
8 | def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0):
9 | self.rnd = np.random.RandomState(random_seed)
10 | torch.manual_seed(random_seed)
11 | self.image_size = image_size
12 | if image_size == 28:
13 | self.data = datasets.MNIST(
14 | data_dir, train=True, download=True,
15 | transform=transforms.ToTensor())
16 | else:
17 | self.data = datasets.MNIST(
18 | data_dir, train=True, download=True,
19 | transform=transforms.Compose([
20 | transforms.Resize(image_size), transforms.ToTensor()]))
21 | self.generate_masks()
22 |
23 | def __getitem__(self, index):
24 | image, label = self.data[index]
25 | mask = self.mask[index]
26 | return image * mask.float(), mask[None], index
27 |
28 | def __len__(self):
29 | return len(self.data)
30 |
31 | def generate_masks(self):
32 | raise NotImplementedError
33 |
34 |
35 | class BlockMaskedMNIST(MaskedMNIST):
36 | def __init__(self, block_len=11, block_len_max=None, *args, **kwargs):
37 | self.block_len = block_len
38 | self.block_len_max = block_len_max
39 | super().__init__(*args, **kwargs)
40 |
41 | def generate_masks(self):
42 | d0_len = d1_len = self.image_size
43 | n_masks = len(self)
44 | self.mask = [None] * n_masks
45 | self.mask_loc = [None] * n_masks
46 | for i in range(n_masks):
47 | if self.block_len_max is None:
48 | d0_mask_len = d1_mask_len = self.block_len
49 | else:
50 | d0_mask_len = self.rnd.randint(
51 | self.block_len, self.block_len_max)
52 | d1_mask_len = self.rnd.randint(
53 | self.block_len, self.block_len_max)
54 |
55 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1)
56 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1)
57 |
58 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8)
59 | mask[d0_start:(d0_start + d0_mask_len),
60 | d1_start:(d1_start + d1_mask_len)] = 1
61 | self.mask[i] = mask
62 | self.mask_loc[i] = d0_start, d1_start, d0_mask_len, d1_mask_len
63 |
64 |
65 | class IndepMaskedMNIST(MaskedMNIST):
66 | def __init__(self, obs_prob=.2, obs_prob_max=None, *args, **kwargs):
67 | self.prob = obs_prob
68 | self.prob_max = obs_prob_max
69 | self.mask_loc = None
70 | super().__init__(*args, **kwargs)
71 |
72 | def generate_masks(self):
73 | imsize = self.image_size
74 | n_masks = len(self)
75 | self.mask = [None] * n_masks
76 | for i in range(n_masks):
77 | if self.prob_max is None:
78 | p = self.prob
79 | else:
80 | p = self.rnd.uniform(self.prob, self.prob_max)
81 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p)
82 |
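The block-mask construction in `BlockMaskedMNIST.generate_masks` can be illustrated without PyTorch. The following is a minimal sketch, not the repository's code: `block_mask` is a hypothetical helper, it uses plain lists instead of tensors, and it covers only the fixed-size case (`block_len_max=None`). As in the class above, 1 marks an observed pixel.

```python
import random


def block_mask(image_size, block_len, rng):
    """Return a square 0/1 mask with one observed block, plus its location."""
    # Top-left corner, chosen so the whole block stays inside the image.
    d0_start = rng.randrange(0, image_size - block_len + 1)
    d1_start = rng.randrange(0, image_size - block_len + 1)
    mask = [[0] * image_size for _ in range(image_size)]
    for r in range(d0_start, d0_start + block_len):
        for c in range(d1_start, d1_start + block_len):
            mask[r][c] = 1  # 1 = observed pixel
    return mask, (d0_start, d1_start, block_len, block_len)


rng = random.Random(0)
mask, mask_loc = block_mask(image_size=28, block_len=12, rng=rng)
observed = sum(map(sum, mask))  # number of observed pixels
```

With a 28x28 image and a 12x12 block, 144 of the 784 pixels are observed, and the returned tuple has the same layout as the entries of `mask_loc`.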
--------------------------------------------------------------------------------
/image/mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mmd(x, y):
5 | n, dim = x.shape
6 |
7 | xx = (x**2).sum(1, keepdim=True)
8 | yy = (y**2).sum(1, keepdim=True)
9 |
10 | outer_xx = torch.mm(x, x.t())
11 | outer_yy = torch.mm(y, y.t())
12 | outer_xy = torch.mm(x, y.t())
13 |
14 | diff_xx = xx + xx.t() - 2 * outer_xx
15 | diff_yy = yy + yy.t() - 2 * outer_yy
16 | diff_xy = xx + yy.t() - 2 * outer_xy
17 |
18 | C = 2. * dim
19 | k_xx = C / (C + diff_xx)
20 | k_yy = C / (C + diff_yy)
21 | k_xy = C / (C + diff_xy)
22 |
23 | mean_xx = (k_xx.sum() - k_xx.diag().sum()) / (n * (n - 1))
24 | mean_yy = (k_yy.sum() - k_yy.diag().sum()) / (n * (n - 1))
25 | mean_xy = k_xy.sum() / (n * n)
26 |
27 | return mean_xx + mean_yy - 2 * mean_xy
28 |
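To make the estimator in `mmd` easier to check, here is a loop-based sketch of the same quantity in plain Python. It assumes, as `mmd` does, that both samples have the same size `n`, and it uses the kernel `C / (C + ||a - b||^2)` with `C = 2 * dim`. The names `kernel` and `mmd_reference` are hypothetical and not part of the repository.

```python
import random


def kernel(a, b, c):
    """k(a, b) = c / (c + ||a - b||^2)."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return c / (c + sq_dist)


def mmd_reference(xs, ys):
    """Loop-based counterpart of mmd() for two equally sized samples."""
    n, dim = len(xs), len(xs[0])
    c = 2.0 * dim
    pairs = [(i, j) for i in range(n) for j in range(n) if i != j]
    # Within-sample terms skip the diagonal, matching k_xx.sum() - k_xx.diag().sum().
    mean_xx = sum(kernel(xs[i], xs[j], c) for i, j in pairs) / (n * (n - 1))
    mean_yy = sum(kernel(ys[i], ys[j], c) for i, j in pairs) / (n * (n - 1))
    # The cross term keeps all n * n pairs.
    mean_xy = sum(kernel(x, y, c) for x in xs for y in ys) / (n * n)
    return mean_xx + mean_yy - 2 * mean_xy


rng = random.Random(0)
base = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(50)]
same = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(50)]
shifted = [[rng.gauss(3, 1) for _ in range(4)] for _ in range(50)]
mmd_same = mmd_reference(base, same)
mmd_shifted = mmd_reference(base, shifted)
tiny = mmd_reference([[0.0], [1.0]], [[0.0], [1.0]])
```

Samples from the same distribution give a value near zero, while the shifted sample gives a clearly larger one. Because the cross term keeps its diagonal, passing the same sample twice yields a negative value rather than zero (here `tiny` is -1/3).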
--------------------------------------------------------------------------------
/image/mnist_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ConvDecoder(nn.Module):
6 | def __init__(self, latent_size=128):
7 | super().__init__()
8 |
9 | self.DIM = 64
10 | self.latent_size = latent_size
11 |
12 | self.preprocess = nn.Sequential(
13 | nn.Linear(latent_size, 4 * 4 * 4 * self.DIM),
14 | nn.ReLU(True),
15 | )
16 | self.block1 = nn.Sequential(
17 | nn.ConvTranspose2d(4 * self.DIM, 2 * self.DIM, 5),
18 | nn.ReLU(True),
19 | )
20 | self.block2 = nn.Sequential(
21 | nn.ConvTranspose2d(2 * self.DIM, self.DIM, 5),
22 | nn.ReLU(True),
23 | )
24 | self.deconv_out = nn.ConvTranspose2d(self.DIM, 1, 8, stride=2)
25 |
26 | def forward(self, input):
27 | net = self.preprocess(input)
28 | net = net.view(-1, 4 * self.DIM, 4, 4)
29 | net = self.block1(net)
30 | net = net[:, :, :7, :7]
31 | net = self.block2(net)
32 | net = self.deconv_out(net)
33 | net = net.view(-1, 1, 28, 28)
34 | return net, torch.sigmoid(net)
35 |
--------------------------------------------------------------------------------
/image/mnist_encoder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torch.distributions.normal import Normal
4 | import flow
5 |
6 |
7 | class ConvEncoder(nn.Module):
8 | def __init__(self, latent_size, flow_depth=2, logprob=False):
9 | super().__init__()
10 |
11 | if logprob:
12 | self.encode_func = self.encode_logprob
13 | else:
14 | self.encode_func = self.encode
15 |
16 | DIM = 64
17 | self.main = nn.Sequential(
18 | nn.Conv2d(1, DIM, 5, stride=2, padding=2),
19 | nn.ReLU(True),
20 | nn.Conv2d(DIM, 2 * DIM, 5, stride=2, padding=2),
21 | nn.ReLU(True),
22 | nn.Conv2d(2 * DIM, 4 * DIM, 5, stride=2, padding=2),
23 | nn.ReLU(True),
24 | )
25 |
26 | if flow_depth > 0:
27 | # IAF
28 | hidden_size = latent_size * 2
29 | flow_layers = [flow.InverseAutoregressiveFlow(
30 | latent_size, hidden_size, latent_size)
31 | for _ in range(flow_depth)]
32 |
33 | flow_layers.append(flow.Reverse(latent_size))
34 | self.q_z_flow = flow.FlowSequential(*flow_layers)
35 | self.enc_chunk = 3
36 | else:
37 | self.q_z_flow = None
38 | self.enc_chunk = 2
39 |
40 | fc_out_size = latent_size * self.enc_chunk
41 | conv_out_size = 4 * 4 * 4 * DIM
42 | self.fc = nn.Sequential(
43 | nn.Linear(conv_out_size, fc_out_size),
44 | nn.LayerNorm(fc_out_size),
45 | nn.LeakyReLU(0.2),
46 | nn.Linear(fc_out_size, fc_out_size),
47 | )
48 |
49 | def forward(self, input, k_samples=5):
50 | return self.encode_func(input, k_samples)
51 |
52 | def encode_logprob(self, input, k_samples=5):
53 | x = self.main(input.view(-1, 1, 28, 28))
54 | x = x.view(input.shape[0], -1)
55 | fc_out = self.fc(x).chunk(self.enc_chunk, dim=1)
56 | mu, logvar = fc_out[:2]
57 | std = F.softplus(logvar)
58 | qz_x = Normal(mu, std)
59 | z = qz_x.rsample([k_samples])
60 | log_q_z = qz_x.log_prob(z)
61 | if self.q_z_flow:
62 | z, log_q_z_flow = self.q_z_flow(z, context=fc_out[2])
63 | log_q_z = (log_q_z + log_q_z_flow).sum(-1)
64 | else:
65 | log_q_z = log_q_z.sum(-1)
66 | return z, log_q_z
67 |
68 | def encode(self, input, _):
69 | x = self.main(input.view(-1, 1, 28, 28))
70 | x = x.view(input.shape[0], -1)
71 | fc_out = self.fc(x).chunk(self.enc_chunk, dim=1)
72 | mu, logvar = fc_out[:2]
73 | std = F.softplus(logvar)
74 | qz_x = Normal(mu, std)
75 | z = qz_x.rsample()
76 | if self.q_z_flow:
77 | z, _ = self.q_z_flow(z, context=fc_out[2])
78 | return z
79 |
--------------------------------------------------------------------------------
/image/mnist_pbigan.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import torch.nn as nn
4 | from torch.autograd import grad
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torch.utils.data import DataLoader
8 | import sys
9 | import logging
10 | from pathlib import Path
11 | from datetime import datetime
12 | import pprint
13 | import argparse
14 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
15 | from mnist_decoder import ConvDecoder
16 | from mnist_encoder import ConvEncoder
17 | from utils import mkdir, make_scheduler
18 | from visualize import Visualizer
19 |
20 |
21 | use_cuda = torch.cuda.is_available()
22 | device = torch.device('cuda' if use_cuda else 'cpu')
23 |
24 |
25 | class PBiGAN(nn.Module):
26 | def __init__(self, encoder, decoder, ae_loss='bce'):
27 | super().__init__()
28 | self.encoder = encoder
29 | self.decoder = decoder
30 | self.ae_loss = ae_loss
31 |
32 | def forward(self, x, mask, ae=True):
33 | z_T = self.encoder(x * mask)
34 |
35 | z_gen = torch.empty_like(z_T).normal_()
36 | x_gen_logit, x_gen = self.decoder(z_gen)
37 |
38 | x_logit, x_recon = self.decoder(z_T)
39 |
40 | recon_loss = None
41 | if ae:
42 | if self.ae_loss == 'mse':
43 | recon_loss = F.mse_loss(
44 | x_recon * mask, x * mask, reduction='none') * mask
45 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
46 | elif self.ae_loss == 'l1':
47 | recon_loss = F.l1_loss(
48 | x_recon * mask, x * mask, reduction='none') * mask
49 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
50 | elif self.ae_loss == 'smooth_l1':
51 | recon_loss = F.smooth_l1_loss(
52 | x_recon * mask, x * mask, reduction='none') * mask
53 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
54 | elif self.ae_loss == 'bce':
55 | # Bernoulli noise
56 | # recon_loss: -log p(x|z)
57 | recon_loss = F.binary_cross_entropy_with_logits(
58 | x_logit * mask, x * mask, reduction='none') * mask
59 | recon_loss = recon_loss.sum((1, 2, 3)).mean()
60 |
61 | return z_T, z_gen, x_recon, x_gen, recon_loss
62 |
63 |
64 | class ConvCritic(nn.Module):
65 | def __init__(self, latent_size):
66 | super().__init__()
67 |
68 | self.DIM = 64
69 | self.main = nn.Sequential(
70 | nn.Conv2d(1, self.DIM, 5, stride=2, padding=2),
71 | nn.ReLU(True),
72 | nn.Conv2d(self.DIM, 2 * self.DIM, 5, stride=2, padding=2),
73 | nn.ReLU(True),
74 | nn.Conv2d(2 * self.DIM, 4 * self.DIM, 5, stride=2, padding=2),
75 | nn.ReLU(True),
76 | )
77 |
78 | embed_size = 64
79 |
80 | self.z_fc = nn.Sequential(
81 | nn.Linear(latent_size, embed_size),
82 | nn.LayerNorm(embed_size),
83 | nn.LeakyReLU(0.2),
84 | nn.Linear(embed_size, embed_size),
85 | )
86 |
87 | self.x_fc = nn.Linear(4 * 4 * 4 * self.DIM, embed_size)
88 |
89 | self.xz_fc = nn.Sequential(
90 | nn.Linear(embed_size * 2, embed_size),
91 | nn.LayerNorm(embed_size),
92 | nn.LeakyReLU(0.2),
93 | nn.Linear(embed_size, 1),
94 | )
95 |
96 | def forward(self, input):
97 | x, z = input
98 | x = x.view(-1, 1, 28, 28)
99 | x = self.main(x)
100 | x = x.view(x.shape[0], -1)
101 | x = self.x_fc(x)
102 | z = self.z_fc(z)
103 | xz = torch.cat((x, z), 1)
104 | xz = self.xz_fc(xz)
105 | return xz.view(-1)
106 |
107 |
108 | class GradientPenalty:
109 | def __init__(self, critic, batch_size=64, gp_lambda=10):
110 | self.critic = critic
111 | self.gp_lambda = gp_lambda
112 | # Interpolation coefficient
113 | self.eps = torch.empty(batch_size, device=device)
114 | # For computing the gradient penalty
115 | self.ones = torch.ones(batch_size).to(device)
116 |
117 | def interpolate(self, real, fake):
118 | eps = self.eps.view([-1] + [1] * (len(real.shape) - 1))
119 | return (eps * real + (1 - eps) * fake).requires_grad_()
120 |
121 | def __call__(self, real, fake):
122 | real = [x.detach() for x in real]
123 | fake = [x.detach() for x in fake]
124 | self.eps.uniform_(0, 1)
125 | interp = [self.interpolate(a, b) for a, b in zip(real, fake)]
126 | grad_d = grad(self.critic(interp),
127 | interp,
128 | grad_outputs=self.ones,
129 | create_graph=True)
130 | batch_size = real[0].shape[0]
131 | grad_d = torch.cat([g.view(batch_size, -1) for g in grad_d], 1)
132 | grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda
133 | return grad_penalty
134 |
135 |
136 | def train_pbigan(args):
137 | torch.manual_seed(args.seed)
138 |
139 | if args.mask == 'indep':
140 | data = IndepMaskedMNIST(obs_prob=args.obs_prob,
141 | obs_prob_max=args.obs_prob_max)
142 | mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
143 | elif args.mask == 'block':
144 | data = BlockMaskedMNIST(block_len=args.block_len,
145 | block_len_max=args.block_len_max)
146 | mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
147 |
148 | data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
149 | drop_last=True)
150 | mask_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
151 | drop_last=True)
152 |
153 | # Monitor training progress on the first batch of the training data
154 | test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)
155 |
156 | decoder = ConvDecoder(args.latent)
157 | encoder = ConvEncoder(args.latent, args.flow, logprob=False)
158 | pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)
159 |
160 | critic = ConvCritic(args.latent).to(device)
161 |
162 | lrate = 1e-4
163 | optimizer = optim.Adam(pbigan.parameters(), lr=lrate, betas=(.5, .9))
164 |
165 | critic_optimizer = optim.Adam(
166 | critic.parameters(), lr=lrate, betas=(.5, .9))
167 |
168 | grad_penalty = GradientPenalty(critic, args.batch_size)
169 |
170 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)
171 |
172 | path = '{}_{}_{}'.format(
173 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
174 | output_dir = Path('results') / 'mnist-pbigan' / path
175 | mkdir(output_dir)
176 | print(output_dir)
177 |
178 | if args.save_interval > 0:
179 | model_dir = mkdir(output_dir / 'model')
180 |
181 | logging.basicConfig(
182 | level=logging.INFO,
183 | format='%(asctime)s %(message)s',
184 | datefmt='%Y-%m-%d %H:%M:%S',
185 | handlers=[
186 | logging.FileHandler(output_dir / 'log.txt'),
187 | logging.StreamHandler(sys.stdout),
188 | ],
189 | )
190 |
191 | with (output_dir / 'args.txt').open('w') as f:
192 | print(pprint.pformat(vars(args)), file=f)
193 |
194 | vis = Visualizer(output_dir)
195 |
196 | test_x, test_mask, index = next(iter(test_loader))
197 | test_x = test_x.to(device)
198 | test_mask = test_mask.to(device).float()
199 | bbox = None
200 | if data.mask_loc is not None:
201 | bbox = [data.mask_loc[idx] for idx in index]
202 |
203 | n_critic = 5
204 | critic_updates = 0
205 | ae_weight = 0
206 | ae_flat = 100
207 |
208 | for epoch in range(args.epoch):
209 | loss_breakdown = defaultdict(float)
210 |
211 | if epoch > ae_flat:
212 | ae_weight = args.ae * (epoch - ae_flat) / (args.epoch - ae_flat)
213 |
214 | for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
215 | x = x.to(device)
216 | mask = mask.to(device).float()
217 | mask_gen = mask_gen.to(device).float()
218 |
219 | z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)
220 |
221 | real_score = critic((x * mask, z_enc)).mean()
222 | fake_score = critic((x_gen * mask_gen, z_gen)).mean()
223 |
224 | w_dist = real_score - fake_score
225 | D_loss = -w_dist + grad_penalty((x * mask, z_enc),
226 | (x_gen * mask_gen, z_gen))
227 |
228 | critic_optimizer.zero_grad()
229 | D_loss.backward()
230 | critic_optimizer.step()
231 |
232 | loss_breakdown['D'] += D_loss.item()
233 |
234 | critic_updates += 1
235 |
236 | if critic_updates == n_critic:
237 | critic_updates = 0
238 |
239 | # Update generators' parameters
240 | for p in critic.parameters():
241 | p.requires_grad_(False)
242 |
243 | z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x, mask)
244 |
245 | real_score = critic((x * mask, z_enc)).mean()
246 | fake_score = critic((x_gen * mask_gen, z_gen)).mean()
247 |
248 | G_loss = real_score - fake_score
249 |
250 | ae_loss = ae_loss * ae_weight
251 | loss = G_loss + ae_loss
252 |
253 | optimizer.zero_grad()
254 | loss.backward()
255 | optimizer.step()
256 |
257 | loss_breakdown['G'] += G_loss.item()
258 | loss_breakdown['AE'] += ae_loss.item()
259 | loss_breakdown['total'] += loss.item()
260 |
261 | for p in critic.parameters():
262 | p.requires_grad_(True)
263 |
264 | if scheduler:
265 | scheduler.step()
266 |
267 | vis.plot_loss(epoch, loss_breakdown)
268 |
269 | if args.plot_interval > 0 and epoch % args.plot_interval == 0:
270 | with torch.no_grad():
271 | pbigan.eval()
272 | z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
273 | pbigan.train()
274 | vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)
275 |
276 | model_dict = {
277 | 'pbigan': pbigan.state_dict(),
278 | 'critic': critic.state_dict(),
279 | 'history': vis.history,
280 | 'epoch': epoch,
281 | 'args': args,
282 | }
283 | torch.save(model_dict, str(output_dir / 'model.pth'))
284 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
285 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
286 |
287 | print(output_dir)
288 |
289 |
290 | def main():
291 | parser = argparse.ArgumentParser()
292 |
293 | parser.add_argument('--seed', type=int, default=3,
294 | help='random seed')
295 | # training options
296 | parser.add_argument('--plot-interval', type=int, default=50,
297 | help='plot interval. 0 to disable plotting.')
298 | parser.add_argument('--save-interval', type=int, default=0,
299 | help='interval to save models. 0 to disable saving.')
300 | parser.add_argument('--mask', default='block',
301 | help='missing data mask. (options: block, indep)')
302 | # options for block masks: variable size if --block-len-max is given
303 | parser.add_argument('--block-len', type=int, default=12,
304 | help='size of observed block')
305 | parser.add_argument('--block-len-max', type=int, default=None,
306 | help='max size of observed block. '
307 | 'Use fixed-size observed block if unspecified.')
308 | # option for indep:
309 | parser.add_argument('--obs-prob', type=float, default=.2,
310 | help='observed probability for independent dropout')
311 | parser.add_argument('--obs-prob-max', type=float, default=None,
312 | help='max observed probability for independent '
313 | 'dropout. Use fixed probability if unspecified.')
314 |
315 | parser.add_argument('--flow', type=int, default=2,
316 | help='number of IAF layers')
317 | parser.add_argument('--lr', type=float, default=1e-3,
318 | help='learning rate')
319 | parser.add_argument('--min-lr', type=float, default=-1,
320 | help='min learning rate for LR scheduler. '
321 | '-1 to disable annealing')
322 |
323 | parser.add_argument('--arch', default='conv',
324 | help='network architecture. (options: fc, conv)')
325 | parser.add_argument('--epoch', type=int, default=2000,
326 | help='number of training epochs')
327 | parser.add_argument('--batch-size', type=int, default=128,
328 | help='batch size')
329 | parser.add_argument('--ae', type=float, default=.1,
330 | help='autoencoding regularization strength')
331 | parser.add_argument('--prefix', default='pbigan',
332 | help='prefix of output directory')
333 | parser.add_argument('--latent', type=int, default=128,
334 | help='dimension of latent variable')
335 | parser.add_argument('--aeloss', default='bce',
336 | help='autoencoding loss. '
337 | '(options: mse, bce, smooth_l1, l1)')
338 |
339 | args = parser.parse_args()
340 |
341 | train_pbigan(args)
342 |
343 |
344 | if __name__ == '__main__':
345 | main()
346 |
--------------------------------------------------------------------------------
/image/mnist_pvae.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | from torch.distributions.normal import Normal
7 | from torch.distributions.categorical import Categorical
8 | from torch.utils.data import DataLoader
9 | import math
10 | import sys
11 | import logging
12 | from pathlib import Path
13 | from datetime import datetime
14 | import pprint
15 | import argparse
16 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST
17 | from mnist_decoder import ConvDecoder
18 | from mnist_encoder import ConvEncoder
19 | from utils import mkdir, make_scheduler
20 | from visualize import Visualizer
21 |
22 |
23 | use_cuda = torch.cuda.is_available()
24 | device = torch.device('cuda' if use_cuda else 'cpu')
25 |
26 |
27 | class PVAE(nn.Module):
28 | def __init__(self, encoder, decoder):
29 | super().__init__()
30 | self.encoder = encoder
31 | self.decoder = decoder
32 |
33 | def forward(self, x, mask, k_samples=5, kl_weight=1):
34 | z_T, log_q_z = self.encoder(x * mask, k_samples)
35 |
36 | pz = Normal(torch.zeros_like(z_T), torch.ones_like(z_T))
37 | log_p_z = pz.log_prob(z_T).sum(-1)
38 | # kl_loss: log q(z|x) - log p(z)
39 | kl_loss = log_q_z - log_p_z
40 |
41 | # Reshape z to accommodate modules with strict input shape requirements
42 | # such as convolutional layers.
43 | x_logit, x_recon = self.decoder(z_T.view(-1, *z_T.shape[2:]))
44 | flat_mask = mask.view(x.shape[0], -1)
45 | flat_logit = x_logit.view(*z_T.shape[:2], -1) * flat_mask
46 | flat_x = (x * mask).view(1, x.shape[0], -1).expand_as(flat_logit)
47 | # Bernoulli noise
48 | # recon_loss: -log p(x|z)
49 | recon_loss = (F.binary_cross_entropy_with_logits(
50 | flat_logit, flat_x, reduction='none') * flat_mask).sum(-1)
51 |
52 | # elbo = log p(x|z) + log p(z) - log q(z|x)
53 | elbo = -(recon_loss + kl_loss * kl_weight)
54 |
55 | # IWAE loss: -log E[p(x|z) p(z) / q(z|x)]
56 | # Here we ignore the constant shift of -log(k_samples)
57 | loss = -elbo.logsumexp(0).mean()
58 |
59 | x_recon = x_recon.view(-1, *x.shape)
60 | loss_breakdown = {
61 | 'loss': loss.item(),
62 | 'KL': kl_loss.mean().item(),
63 | 'recon': recon_loss.mean().item(),
64 | }
65 | return loss, z_T, x_recon, elbo, loss_breakdown
66 |
67 | def impute(self, x, mask, k_samples=10):
68 | self.eval()
69 | with torch.no_grad():
70 | _, z, x_recon, elbo, _ = self(x, mask, k_samples)
71 | # sampling importance resampling
72 | is_idx = Categorical(logits=elbo.t()).sample()
73 | batch_idx = torch.arange(len(x))
74 | z = z[is_idx, batch_idx]
75 | x_recon = x_recon[is_idx, batch_idx]
76 | self.train()
77 | return x_recon
78 |
79 |
80 | def train_pvae(args):
81 | torch.manual_seed(args.seed)
82 |
83 | if args.mask == 'indep':
84 | data = IndepMaskedMNIST(obs_prob=args.obs_prob,
85 | obs_prob_max=args.obs_prob_max)
86 | mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}'
87 | elif args.mask == 'block':
88 | data = BlockMaskedMNIST(block_len=args.block_len,
89 | block_len_max=args.block_len_max)
90 | mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}'
91 |
92 | data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
93 | drop_last=True)
94 |
95 | # Monitor training progress on the first batch of the training data
96 | test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)
97 |
98 | decoder = ConvDecoder(args.latent)
99 | encoder = ConvEncoder(args.latent, args.flow, logprob=True)
100 | pvae = PVAE(encoder, decoder).to(device)
101 |
102 | optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
103 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)
104 |
105 | rand_z = torch.empty(args.batch_size, args.latent, device=device)
106 |
107 | path = '{}_{}_{}'.format(
108 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
109 | output_dir = Path('results') / 'mnist-pvae' / path
110 | mkdir(output_dir)
111 | print(output_dir)
112 |
113 | if args.save_interval > 0:
114 | model_dir = mkdir(output_dir / 'model')
115 |
116 | logging.basicConfig(
117 | level=logging.INFO,
118 | format='%(asctime)s %(message)s',
119 | datefmt='%Y-%m-%d %H:%M:%S',
120 | handlers=[
121 | logging.FileHandler(output_dir / 'log.txt'),
122 | logging.StreamHandler(sys.stdout),
123 | ],
124 | )
125 |
126 | with (output_dir / 'args.txt').open('w') as f:
127 | print(pprint.pformat(vars(args)), file=f)
128 |
129 | vis = Visualizer(output_dir)
130 |
131 | test_x, test_mask, index = next(iter(test_loader))
132 | test_x = test_x.to(device)
133 | test_mask = test_mask.to(device).float()
134 | bbox = None
135 | if data.mask_loc is not None:
136 | bbox = [data.mask_loc[idx] for idx in index]
137 |
138 | kl_center = (args.kl_on + args.kl_off) / 2
139 | kl_scale = 12 / max(args.kl_on - args.kl_off, 1)
140 |
141 | for epoch in range(args.epoch):
142 | if epoch >= args.kl_on:
143 | kl_weight = 1
144 | elif epoch < args.kl_off:
145 | kl_weight = 0
146 | else:
147 | kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale))
148 | loss_breakdown = defaultdict(float)
149 | for x, mask, _ in data_loader:
150 | x = x.to(device)
151 | mask = mask.to(device).float()
152 |
153 | optimizer.zero_grad()
154 | loss, _, _, _, loss_info = pvae(
155 | x, mask, args.k, kl_weight=kl_weight)
156 | loss.backward()
157 | optimizer.step()
158 | for name, val in loss_info.items():
159 | loss_breakdown[name] += val
160 |
161 | if scheduler:
162 | scheduler.step()
163 |
164 | vis.plot_loss(epoch, loss_breakdown)
165 |
166 | if args.plot_interval > 0 and epoch % args.plot_interval == 0:
167 | x_recon = pvae.impute(test_x, test_mask, args.k)
168 | with torch.no_grad():
169 | pvae.eval()
170 | rand_z.normal_()
171 | _, x_gen = decoder(rand_z)
172 | pvae.train()
173 | vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen)
174 |
175 | model_dict = {
176 | 'pvae': pvae.state_dict(),
177 | 'history': vis.history,
178 | 'epoch': epoch,
179 | 'args': args,
180 | }
181 | torch.save(model_dict, str(output_dir / 'model.pth'))
182 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
183 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
184 |
185 | print(output_dir)
186 |
187 |
188 | def main():
189 | parser = argparse.ArgumentParser()
190 |
191 | parser.add_argument('--seed', type=int, default=3,
192 | help='random seed')
193 | # training options
194 | parser.add_argument('--plot-interval', type=int, default=50,
195 | help='plot interval. 0 to disable plotting.')
196 | parser.add_argument('--save-interval', type=int, default=50,
197 | help='interval to save models. 0 to disable saving.')
198 | parser.add_argument('--mask', default='block',
199 | help='missing data mask. (options: block, indep)')
200 | # options for block masks: variable size if --block-len-max is given
201 | parser.add_argument('--block-len', type=int, default=12,
202 | help='size of observed block '
203 | '(min size if --block-len-max is given)')
204 | parser.add_argument('--block-len-max', type=int, default=None,
205 | help='max size of observed block. '
206 | 'Use fixed-size observed block if unspecified.')
207 | # option for indep:
208 | parser.add_argument('--obs-prob', type=float, default=.2,
209 | help='observed probability for independent dropout')
210 | parser.add_argument('--obs-prob-max', type=float, default=None,
211 | help='max observed probability for independent '
212 | 'dropout. Use fixed probability if unspecified.')
213 |
214 | parser.add_argument('--flow', type=int, default=2,
215 | help='number of IAF layers')
216 | parser.add_argument('--lr', type=float, default=1e-3,
217 | help='learning rate')
218 | parser.add_argument('--min-lr', type=float, default=-1,
219 | help='min learning rate for LR scheduler. '
220 | '-1 to disable annealing')
221 |
222 | parser.add_argument('--epoch', type=int, default=4000,
223 | help='number of training epochs')
224 | parser.add_argument('--batch-size', type=int, default=128,
225 | help='batch size')
226 | parser.add_argument('--k', type=int, default=5,
227 | help='number of importance weights')
228 | parser.add_argument('--prefix', default='pvae',
229 | help='prefix of output directory')
230 | parser.add_argument('--latent', type=int, default=128,
231 | help='dimension of latent variable')
232 | parser.add_argument('--kl-off', type=int, default=200,
233 | help='epoch at which the KL weight starts ramping up from 0')
234 | # set --kl-on to 0 to use constant kl_weight = 1
235 | parser.add_argument('--kl-on', type=int, default=0,
236 | help='first epoch at which the KL weight is 1')
237 |
238 | args = parser.parse_args()
239 |
240 | train_pvae(args)
241 |
242 |
243 | if __name__ == '__main__':
244 | main()
245 |
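The KL warm-up in `train_pvae` can be examined in isolation. The sketch below mirrors its three cases (weight 0 before `--kl-off`, weight 1 from `--kl-on`, a sigmoid in between). It assumes the intent is for the sigmoid argument to run from -6 to 6 across the ramp, so the scale is written as `12 / max(kl_on - kl_off, 1)`. The function name `kl_weight_schedule` is hypothetical.

```python
import math


def kl_weight_schedule(epoch, kl_off, kl_on):
    """Sigmoid warm-up of the KL weight between epochs kl_off and kl_on."""
    if epoch >= kl_on:
        return 1.0
    if epoch < kl_off:
        return 0.0
    center = (kl_on + kl_off) / 2
    # Assumption: the sigmoid argument spans [-6, 6] over the ramp.
    scale = 12 / max(kl_on - kl_off, 1)
    return 1 / (1 + math.exp(-(epoch - center) * scale))


epochs = (0, 199, 200, 300, 399, 400)
weights = [kl_weight_schedule(e, kl_off=200, kl_on=400) for e in epochs]
```

With `kl_off=200` and `kl_on=400`, the weight is 0 up to epoch 199, close to 0 at epoch 200, exactly 0.5 at the midpoint (epoch 300), and 1 from epoch 400 on. With the default `--kl-on 0`, the first branch always applies and the weight is constant at 1.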
--------------------------------------------------------------------------------
/image/utils.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 |
3 |
4 | def count_parameters(model):
5 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
6 |
7 |
8 | def mkdir(path):
9 | path.mkdir(parents=True, exist_ok=True)
10 | return path
11 |
12 |
13 | def make_scheduler(optimizer, lr, min_lr, epochs, steps=10):
14 | if min_lr < 0:
15 | return None
16 | step_size = epochs // steps
17 | gamma = (min_lr / lr)**(1 / steps)
18 | return optim.lr_scheduler.StepLR(
19 | optimizer, step_size=step_size, gamma=gamma)
20 |
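As a worked example of `make_scheduler`: `StepLR` multiplies the learning rate by `gamma` every `step_size` epochs, so after `steps` decays a rate that started at `lr` reaches `min_lr`. The sketch below reproduces that arithmetic in plain Python. `stepped_lr` is a hypothetical helper, and the values 1e-3, 1e-5 and 2000 are illustrative only.

```python
def stepped_lr(lr, min_lr, epochs, epoch, steps=10):
    """Learning rate after `epoch` scheduler steps under the StepLR settings above."""
    step_size = epochs // steps
    gamma = (min_lr / lr) ** (1 / steps)
    # StepLR applies one decay every `step_size` epochs.
    return lr * gamma ** (epoch // step_size)


lr_start = stepped_lr(1e-3, 1e-5, epochs=2000, epoch=0)    # no decay yet
lr_mid = stepped_lr(1e-3, 1e-5, epochs=2000, epoch=1000)   # 5 decays
lr_end = stepped_lr(1e-3, 1e-5, epochs=2000, epoch=2000)   # 10 decays
```

Here `lr_mid` is about 1e-4 and `lr_end` is about 1e-5. Note that `gamma` is derived from the `lr` argument, while `StepLR` scales whatever rate the optimizer was created with, so the final rate equals `min_lr` only when the two agree.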
--------------------------------------------------------------------------------
/image/visualize.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import matplotlib.gridspec as gridspec
3 | from matplotlib.patches import Rectangle
4 | import seaborn as sns
5 | import logging
6 | import sys
7 | from collections import defaultdict
8 | from pathlib import Path
9 | from utils import mkdir
10 |
11 |
12 | class Visualizer(object):
13 | def __init__(self,
14 | output_dir,
15 | loss_xlim=None,
16 | loss_ylim=None):
17 | self.output_dir = Path(output_dir)
18 | self.recons_dir = mkdir(self.output_dir / 'recons')
19 | self.gen_dir = mkdir(self.output_dir / 'gen')
20 | self.loss_xlim = loss_xlim
21 | self.loss_ylim = loss_ylim
22 | sns.set()
23 | self.rows, self.cols = 8, 16
24 | self.at_start = True
25 |
26 | self.history = defaultdict(list)
27 | logging.basicConfig(
28 | level=logging.INFO,
29 | format='%(asctime)s %(message)s',
30 | datefmt='%Y-%m-%d %H:%M:%S',
31 | handlers=[
32 | logging.FileHandler(self.output_dir / 'log.txt'),
33 | logging.StreamHandler(sys.stdout),
34 | ],
35 | )
36 | self.print_header = True
37 |
38 | def plot_subgrid(self, images, bbox=None, filename=None):
39 | rows, cols = 8, 16
40 | scale = .75
41 | fig, ax = plt.subplots(figsize=(cols * scale, rows * scale))
42 | ax.set_axis_off()
43 | ax.set_xticks([])
44 | ax.set_yticks([])
45 |
46 | inner_grid = gridspec.GridSpec(rows, cols, fig, wspace=.05, hspace=.05,
47 | left=0, right=1, top=1, bottom=0)
48 |
49 | images = images[:(rows * cols)].cpu().numpy()
50 | if images.shape[1] == 1: # single channel
51 | images = images.squeeze(1)
52 | cmap = 'binary_r'
53 | else: # 3 channels
54 | images = images.transpose((0, 2, 3, 1))
55 | cmap = None
56 |
57 | for i, image in enumerate(images):
58 | ax = plt.Subplot(fig, inner_grid[i])
59 | ax.set_axis_off()
60 | ax.set_xticks([])
61 | ax.set_yticks([])
62 | ax.set_aspect('equal')
63 | ax.imshow(image, interpolation='none', aspect='equal',
64 | cmap=cmap, vmin=0, vmax=1)
65 |
66 | if bbox is not None:
67 | d0, d1, d0_len, d1_len = bbox[i]
68 | ax.add_patch(Rectangle(
69 | (d1 - .5, d0 - .5), d1_len, d0_len, lw=1,
70 | edgecolor='red', fill=False))
71 | fig.add_subplot(ax)
72 |
73 | if filename is not None:
74 | plt.savefig(str(filename))
75 | plt.close(fig)
76 |
77 | def plot(self, epoch, x, mask, bbox, x_recon, x_gen):
78 | if self.at_start:
79 | self.plot_subgrid(x * mask + .5 * (1 - mask), bbox,
80 | self.recons_dir / 'groundtruth.png')
81 | self.at_start = False
82 | self.plot_subgrid(x * mask + x_recon * (1 - mask), bbox,
83 | self.recons_dir / f'{epoch:04d}.png')
84 | self.plot_subgrid(x_gen, None, self.gen_dir / f'{epoch:04d}.png')
85 |
86 | def plot_loss(self, epoch, losses):
87 | for name, val in losses.items():
88 | self.history[name].append(val)
89 |
90 | fig, ax_trace = plt.subplots(figsize=(6, 4))
91 | ax_trace.set_ylabel('loss')
92 | ax_trace.set_xlabel('epochs')
93 | if self.loss_xlim is not None:
94 | ax_trace.set_xlim(self.loss_xlim)
95 | if self.loss_ylim is not None:
96 | ax_trace.set_ylim(self.loss_ylim)
97 | for label, loss in self.history.items():
98 | ax_trace.plot(loss, '-', label=label)
99 | if len(self.history) > 1:
100 | ax_trace.legend(ncol=len(self.history), loc='upper center')
101 | plt.tight_layout()
102 | plt.savefig(str(self.output_dir / 'loss.png'), dpi=300)
103 | plt.close(fig)
104 |
105 | if self.print_header:
106 | logging.info(' ' * 7 + ' '.join(
107 | f'{key:>12}' for key in sorted(losses)))
108 | self.print_header = False
109 | logging.info(f'[{epoch:4}] ' + ' '.join(
110 | f'{val:12.4f}' for _, val in sorted(losses.items())))
111 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.1.1
2 | numpy==1.17.0
3 | Pillow>=7.1.0
4 | scikit-learn==0.21.3
5 | scipy==1.4.1
6 | seaborn==0.9.0
7 | torch==1.1.0
8 | torch-spline-conv==1.1.0
9 | torchvision==0.3.0
10 |
--------------------------------------------------------------------------------
/time-series/.gitattributes:
--------------------------------------------------------------------------------
1 | *.ipynb linguist-language=Python
2 |
--------------------------------------------------------------------------------
/time-series/ema.py:
--------------------------------------------------------------------------------
1 | class EMA:
2 | def __init__(self, model, decay, start_iter=0):
3 | self.model = model
4 | self.beta = 1 - decay
5 | self.start_iter = start_iter
6 | self.iter = 0
7 | self.shadow = None
8 |
9 | def state_dict(self):
10 | """Returns the EMA state as a dictionary for serialization."""
11 | # NOTE: skip saving `model`
12 | return {
13 | 'beta': self.beta,
14 | 'start_iter': self.start_iter,
15 | 'iter': self.iter,
16 | 'shadow': self.shadow,
17 | }
18 |
19 | def update(self):
20 | if self.iter < self.start_iter:
21 | self.iter += 1
22 | else:
23 | if self.shadow is None:
24 | self.shadow = {}
25 | for name, param in self.model.named_parameters():
26 | if param.requires_grad:
27 | self.shadow[name] = param.data.clone()
28 | for name, param in self.model.named_parameters():
29 | if param.requires_grad:
30 | # shadow = shadow - (1 - decay) * (shadow - param)
31 | # = decay * shadow + (1 - decay) * param
32 | self.shadow[name].sub_(
33 | self.beta * (self.shadow[name] - param.data))
34 |
35 | def apply(self):
36 | if self.shadow is None:
37 | return
38 | for name, param in self.model.named_parameters():
39 | if param.requires_grad:
40 | self.shadow[name], param.data = param.data, self.shadow[name]
41 |
42 | def restore(self):
43 | self.apply()
44 |
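Usage note (not part of ema.py): the update in `EMA.update` is written in a subtract form, and the comment there states it equals the usual convex combination. A minimal sketch of that equivalence on plain floats follows; `ema_step` and `ema_track` are hypothetical helper names introduced only for illustration, and a scalar stands in for a parameter tensor.

```python
def ema_step(shadow, theta, decay):
    """One EMA update in the subtract form used by EMA.update."""
    beta = 1 - decay
    return shadow - beta * (shadow - theta)


def ema_track(thetas, decay):
    """Track a sequence of parameter values, starting the shadow as a copy."""
    shadow = None
    for theta in thetas:
        if shadow is None:
            # First update: the shadow is initialised from the parameter.
            shadow = theta
        shadow = ema_step(shadow, theta, decay)
    return shadow


# The subtract form matches decay * shadow + (1 - decay) * theta.
shadow, theta, decay = 2.0, 4.0, 0.75
assert abs(ema_step(shadow, theta, decay)
           - (decay * shadow + (1 - decay) * theta)) < 1e-12

final_shadow = ema_track([1.0, 2.0, 3.0], decay=0.9)
```

Because `apply` swaps the shadow and live parameters, calling it a second time (which is what `restore` does) puts the original weights back.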
--------------------------------------------------------------------------------
/time-series/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from sklearn.metrics import roc_auc_score
4 |
5 |
6 | class Evaluator:
7 | def __init__(self, model, val_loader, test_loader,
8 | log_dir, eval_args=None):
9 | self.model = model
10 | self.val_loader = val_loader
11 | self.test_loader = test_loader
12 | self.log_dir = log_dir
13 | self.eval_args = eval_args if eval_args is not None else {}
14 |
15 | self.test_auc_log, self.val_auc_log = [], []
16 | self.best_auc, self.best_val_auc = [-float('inf')] * 2
17 |
18 | def evaluate(self, epoch):
19 | self.model.eval()
20 | with torch.no_grad():
21 | val_auc = self.compute_auc(self.val_loader)
22 | print('val AUC:', val_auc)
23 | self.val_auc_log.append((epoch, val_auc))
24 |
25 | test_auc = self.compute_auc(self.test_loader)
26 | print('test AUC:', test_auc)
27 | self.test_auc_log.append((epoch, test_auc))
28 | self.model.train()
29 |
30 | if val_auc > self.best_auc:
31 | self.best_auc = val_auc
32 | torch.save({'model': self.model.state_dict()},
33 | str(self.log_dir / 'best-model.pth'))
34 |
35 | with (self.log_dir / 'val_auc.txt').open('a') as f:
36 | print(epoch, val_auc, file=f)
37 | with (self.log_dir / 'test_auc.txt').open('a') as f:
38 | print(epoch, test_auc, file=f)
39 |
40 | def compute_auc(self, data_loader):
41 | y_true, y_score = [], []
42 | for (val, idx, mask, y, _, cconv_graph) in data_loader:
43 | score = self.model.predict(
44 | val, idx, mask, cconv_graph, **self.eval_args)
45 | y_score.append(score.cpu().numpy())
46 | y_true.append(y.cpu().numpy())
47 |
48 | y_true = np.concatenate(y_true)
49 | y_score = np.concatenate(y_score)
50 | return roc_auc_score(y_true, y_score)
51 |
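Usage note (not part of evaluate.py): `compute_auc` concatenates scores and labels from every batch before a single `roc_auc_score` call. AUC is a ranking statistic over the whole set, so averaging per-batch AUCs can give a different number. The sketch below illustrates this with a small pairwise AUC; `pairwise_auc` is a hypothetical helper written for this example and is not scikit-learn's implementation.

```python
def pairwise_auc(y_true, y_score):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Two toy batches of (labels, scores); values are made up for illustration.
batches = [
    ([1, 0], [0.9, 0.2]),
    ([1, 0], [0.15, 0.1]),
]

# Pool everything first, as compute_auc does with np.concatenate.
y_true = [y for ys, _ in batches for y in ys]
y_score = [s for _, ss in batches for s in ss]
pooled_auc = pairwise_auc(y_true, y_score)

# Averaging per-batch AUCs ignores cross-batch pairs.
per_batch_auc = sum(pairwise_auc(ys, ss) for ys, ss in batches) / len(batches)
```

Here the pooled value is lower than the per-batch average because the positive in the second batch scores below the negative in the first.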
--------------------------------------------------------------------------------
/time-series/figures/pbigan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/time-series/figures/pbigan.png
--------------------------------------------------------------------------------
/time-series/figures/pvae.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/time-series/figures/pvae.png
--------------------------------------------------------------------------------
/time-series/flow.py:
--------------------------------------------------------------------------------
1 | """Code adapted from https://github.com/altosaar/variational-autoencoder"""
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class InverseAutoregressiveFlow(nn.Module):
9 | """Inverse Autoregressive Flows with LSTM-type update. One block.
10 |
11 | Eq 11-14 of https://arxiv.org/abs/1606.04934
12 | """
13 | def __init__(self, num_input, num_hidden, num_context):
14 | super().__init__()
15 | self.made = MADE(num_input=num_input, num_output=num_input * 2,
16 | num_hidden=num_hidden, num_context=num_context)
17 | # init such that sigmoid(s) is close to 1 for stability
18 | self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2)
19 | self.sigmoid = nn.Sigmoid()
20 | self.log_sigmoid = nn.LogSigmoid()
21 |
22 | def forward(self, input, context=None):
23 | m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1)
24 | s = s + self.sigmoid_arg_bias
25 | sigmoid = self.sigmoid(s)
26 | z = sigmoid * input + (1 - sigmoid) * m
27 | return z, -self.log_sigmoid(s)
28 |
29 |
30 | class FlowSequential(nn.Sequential):
31 | """Chain flow blocks, summing the log-probability term of each."""
32 |
33 | def forward(self, input, context=None):
34 | total_log_prob = torch.zeros_like(input)
35 | for block in self._modules.values():
36 | input, log_prob = block(input, context)
37 | total_log_prob += log_prob
38 | return input, total_log_prob
39 |
40 |
41 | class MaskedLinear(nn.Module):
42 | """Linear layer with some input-output connections masked."""
43 | def __init__(self, in_features, out_features, mask, context_features=None,
44 | bias=True):
45 | super().__init__()
46 | self.linear = nn.Linear(in_features, out_features, bias)
47 | self.register_buffer("mask", mask)
48 | if context_features is not None:
49 | self.cond_linear = nn.Linear(context_features, out_features,
50 | bias=False)
51 |
52 | def forward(self, input, context=None):
53 | output = F.linear(input, self.mask * self.linear.weight,
54 | self.linear.bias)
55 | if context is None:
56 | return output
57 | else:
58 | return output + self.cond_linear(context)
59 |
60 |
61 | class MADE(nn.Module):
62 | """Implements MADE: Masked Autoencoder for Distribution Estimation.
63 |
64 | Follows https://arxiv.org/abs/1502.03509
65 |
66 | This is used to build MAF:
67 | Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057).
68 | """
69 | def __init__(self, num_input, num_output, num_hidden, num_context):
70 | super().__init__()
71 | # m corresponds to m(k), the maximum degree of a node in the MADE paper
72 | self._m = []
73 | self._masks = []
74 | self._build_masks(num_input, num_output, num_hidden, num_layers=3)
75 | self._check_masks()
76 | modules = []
77 | self.input_context_net = MaskedLinear(
78 | num_input, num_hidden, self._masks[0], num_context)
79 | modules.append(nn.ReLU())
80 | modules.append(MaskedLinear(
81 | num_hidden, num_hidden, self._masks[1], context_features=None))
82 | modules.append(nn.ReLU())
83 | modules.append(MaskedLinear(
84 | num_hidden, num_output, self._masks[2], context_features=None))
85 | self.net = nn.Sequential(*modules)
86 |
87 | def _build_masks(self, num_input, num_output, num_hidden, num_layers):
88 | """Build the masks according to Eq 12 and 13 in the MADE paper."""
89 | rng = np.random.RandomState(0)
90 | # assign input units a number between 1 and D
91 | self._m.append(np.arange(1, num_input + 1))
92 | for i in range(1, num_layers + 1):
93 | # randomly assign maximum number of input nodes to connect to
94 | if i == num_layers:
95 | # assign output layer units a number between 1 and D
96 | m = np.arange(1, num_input + 1)
97 | assert num_output % num_input == 0, (
98 | "num_output must be multiple of num_input")
99 | self._m.append(np.hstack(
100 | [m for _ in range(num_output // num_input)]))
101 | else:
102 | # assign hidden layer units a number between 1 and D-1
103 | self._m.append(rng.randint(1, num_input, size=num_hidden))
104 | # self._m.append(
105 | # np.arange(1, num_hidden + 1) % (num_input - 1) + 1)
106 | if i == num_layers:
107 | mask = self._m[i][None, :] > self._m[i - 1][:, None]
108 | else:
109 | # input to hidden & hidden to hidden
110 | mask = self._m[i][None, :] >= self._m[i - 1][:, None]
111 | # need to transpose for torch linear layer.
112 | # shape (num_output, num_input)
113 | self._masks.append(torch.from_numpy(mask.astype(np.float32).T))
114 |
115 | def _check_masks(self):
116 | """Check that output j depends only on inputs i < j, so the
117 | (num_input, num_output) connectivity is zero on and below the diagonal."""
118 | # (num_input, num_hidden)
119 | prev = self._masks[0].t()
120 | for i in range(1, len(self._masks)):
121 | # num_hidden is second axis
122 | prev = prev @ self._masks[i].t()
123 | final = prev.numpy()
124 | num_input = self._masks[0].shape[1]
125 | num_output = self._masks[-1].shape[0]
126 | assert final.shape == (num_input, num_output)
127 | if num_output == num_input:
128 | assert not np.tril(final).any()
129 | else:
130 | for submat in np.split(
131 | final, indices_or_sections=num_output // num_input,
132 | axis=1):
133 | assert not np.tril(submat).any()
134 |
135 | def forward(self, input, context=None):
136 | # first hidden layer receives input and context
137 | hidden = self.input_context_net(input, context)
138 | # rest of the network is conditioned on both input and context
139 | return self.net(hidden)
140 |
141 |
142 | class Reverse(nn.Module):
143 | """ An implementation of a reversing layer from
144 | Density estimation using Real NVP
145 | (https://arxiv.org/abs/1605.08803).
146 |
147 | From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py
148 | """
149 |
150 | def __init__(self, num_input):
151 | super(Reverse, self).__init__()
152 | self.perm = np.array(np.arange(0, num_input)[::-1])
153 | self.inv_perm = np.argsort(self.perm)
154 |
155 | def forward(self, inputs, context=None, mode='forward'):
156 | if mode == "forward":
157 | return inputs[:, :, self.perm], torch.zeros_like(inputs)
158 | elif mode == "inverse":
159 | return inputs[:, :, self.inv_perm], torch.zeros_like(inputs)
160 | else:
161 | raise ValueError("Mode must be one of {forward, inverse}.")
162 |
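Usage note (not part of flow.py): the MADE masks are built from per-unit degrees, with `>=` between input/hidden layers and a strict `>` into the output layer. A minimal NumPy sketch of that construction is shown below, assuming two hidden layers as in `MADE.__init__`; `build_made_masks` and `connectivity` are hypothetical names, and masks are kept in (fan_in, fan_out) orientation rather than transposed for `nn.Linear`.

```python
import numpy as np


def build_made_masks(num_input, num_hidden, num_hidden_layers=2, seed=0):
    """Return a list of (fan_in, fan_out) binary masks for a MADE network."""
    rng = np.random.RandomState(seed)
    # Input units get degrees 1..D.
    degrees = [np.arange(1, num_input + 1)]
    for _ in range(num_hidden_layers):
        # Hidden units get random degrees in 1..D-1.
        degrees.append(rng.randint(1, num_input, size=num_hidden))
    # Output units get degrees 1..D again.
    degrees.append(np.arange(1, num_input + 1))

    masks = []
    last = len(degrees) - 1
    for i in range(1, len(degrees)):
        if i == last:
            # Strict inequality into the output layer.
            mask = degrees[i][None, :] > degrees[i - 1][:, None]
        else:
            mask = degrees[i][None, :] >= degrees[i - 1][:, None]
        masks.append(mask.astype(np.float32))
    return masks


def connectivity(masks):
    """Count input-to-output paths; entry (i, j) > 0 means j depends on i."""
    total = masks[0]
    for mask in masks[1:]:
        total = total @ mask
    return total


masks = build_made_masks(num_input=4, num_hidden=16)
paths = connectivity(masks)  # shape (num_input, num_output)

# Autoregressive property: output j may depend only on inputs i < j,
# so there is no path on or below the diagonal.
assert not np.tril(paths).any()
```

In this orientation the path-count matrix is strictly upper triangular; transposing it to (num_output, num_input) gives the strictly lower triangular form.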
--------------------------------------------------------------------------------
/time-series/gen_toy_data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torch.distributions.exponential import Exponential
4 | import math
5 |
6 |
7 | class HomogeneousPoissonProcess:
8 | def __init__(self, rate=1):
9 | self.rate = rate
10 | self.exp = Exponential(rate)
11 |
12 | def sample(self, size, max_seq_len, max_time=math.inf):
13 | gaps = self.exp.sample((size, max_seq_len))
14 | times = torch.cumsum(gaps, dim=1)
15 | masks = (times <= max_time).float()
16 | return times, masks
17 |
18 |
19 | def gen_data(n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
20 | obs_span_rate=.25, save_file=None):
21 | """Generates a 3-channel synthetic dataset.
22 |
23 | The observations fall within a window of size (max_time * obs_span_rate)
24 | placed at a random position in the time span [0, max_time].
25 |
26 | Args:
27 | n_samples:
28 | Number of data cases.
29 | seq_len:
30 | Maximum number of observations in a channel.
31 | max_time:
32 | Length of time interval [0, max_time].
33 | poisson_rate:
34 | Rate of homogeneous Poisson process.
35 | obs_span_rate:
36 | The fraction of the time span [0, max_time] covered by the
37 | contiguous window that observations are restricted to.
38 | save_file:
39 | File name that the generated data is saved to.
40 | """
41 | n_channels = 3
42 | time_unif = np.linspace(0, max_time, seq_len)
43 | time_unif_3ch = np.broadcast_to(time_unif, (n_channels, seq_len))
44 | data_unif = np.empty((n_samples, n_channels, seq_len))
45 | sparse_data, sparse_time, sparse_mask = [
46 | np.empty((n_samples, n_channels, seq_len)) for _ in range(3)]
47 | tpp = HomogeneousPoissonProcess(rate=poisson_rate)
48 |
49 | def gen_time_series(offset1, offset2, t):
50 | t1 = t[0] + offset1
51 | t2 = t[2] + offset2
52 | t1_shift = t[1] + offset1 + 20
53 | data = np.empty((3, seq_len))
54 | data[0] = np.sin(t1 * 20 + np.sin(t1 * 20)) * .8
55 | data[1] = -np.sin(t1_shift * 20 + np.sin(t1_shift * 20)) * .5
56 | data[2] = np.sin(t2 * 12)
57 | return data
58 |
59 | for i in range(n_samples):
60 | offset1 = np.random.normal(0, 10)
61 | offset2 = np.random.uniform(0, 10)
62 |
63 | # Noise-free evenly-sampled time series
64 | data_unif[i] = gen_time_series(offset1, offset2, time_unif_3ch)
65 |
66 | # Generate observations between [0, obs_span_rate].
67 | times, masks = tpp.sample(3, seq_len, max_time=obs_span_rate)
68 | # Add independent random offset Unif(0, 1 - obs_span_rate) to each
69 | # channel so that all the observations will still be within [0, 1].
70 | times += torch.rand((3, 1)) * (1 - obs_span_rate)
71 | # Scale time span from [0, 1] to [0, max_time].
72 | times *= max_time
73 | # Set time entries corresponding to unobserved samples to time 0.
74 | sparse_time[i] = times * masks
75 | sparse_mask[i] = masks
76 | sparse_data[i] = gen_time_series(offset1, offset2, times)
77 |
78 | # Add small independent Gaussian noise to each channel
79 | sparse_data += np.random.normal(0, .01, sparse_data.shape)
80 |
81 | # Pack the data to minimize the padded entries
82 | compact_len = sparse_mask.astype(int).sum(axis=2).max()
83 | compact_data, compact_time, compact_mask = [
84 | np.zeros((n_samples, 3, compact_len)) for _ in range(3)]
85 | for i in range(n_samples):
86 | for j in range(3):
87 | idx = sparse_mask[i, j] == 1
88 | n_obs = idx.sum()
89 | compact_data[i, j, :n_obs] = sparse_data[i, j, idx]
90 | compact_time[i, j, :n_obs] = sparse_time[i, j, idx]
91 | compact_mask[i, j, :n_obs] = sparse_mask[i, j, idx]
92 |
93 | if save_file:
94 | np.savez_compressed(
95 | save_file,
96 | time=compact_time,
97 | data=compact_data,
98 | mask=compact_mask,
99 | data_unif=data_unif,
100 | time_unif=time_unif,
101 | )
102 |
103 | return compact_data, compact_time, compact_mask, data_unif, time_unif
104 |
105 |
106 | def main():
107 | gen_data(n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
108 | obs_span_rate=.25, save_file='toy-data.npz')
109 |
110 |
111 | if __name__ == '__main__':
112 | main()
113 |
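Usage note (not part of gen_toy_data.py): `HomogeneousPoissonProcess.sample` draws exponential gaps, takes their cumulative sum as event times, and masks out events past `max_time`. The sketch below does the same with NumPy instead of `torch.distributions`; `sample_poisson_times` is a hypothetical name and the fixed seed is an assumption made only to keep the example reproducible.

```python
import numpy as np


def sample_poisson_times(rate, size, max_seq_len, max_time=np.inf, seed=0):
    """Sample event times of a homogeneous Poisson process per row."""
    rng = np.random.RandomState(seed)
    # Inter-arrival times of a rate-`rate` process are Exponential(rate),
    # i.e. they have mean 1 / rate.
    gaps = rng.exponential(scale=1.0 / rate, size=(size, max_seq_len))
    times = np.cumsum(gaps, axis=1)
    # Events after max_time are treated as unobserved.
    masks = (times <= max_time).astype(float)
    return times, masks


times, masks = sample_poisson_times(
    rate=50, size=3, max_seq_len=200, max_time=0.25)

# Number of observed events per channel.
n_obs = masks.sum(axis=1)
```

Because times are non-decreasing along each row, the mask is always a run of ones followed by zeros, which is what lets `gen_data` pack the observed entries into a shorter array afterwards.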
--------------------------------------------------------------------------------
/time-series/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResLinearBlock(nn.Module):
7 | def __init__(self, in_size, out_size):
8 | super().__init__()
9 | self.linear = nn.Sequential(
10 | nn.Linear(in_size, out_size),
11 | nn.Dropout(),
12 | nn.LeakyReLU(.2),
13 | nn.Linear(out_size, out_size),
14 | nn.Dropout(),
15 | nn.LeakyReLU(.2),
16 | )
17 |
18 | self.skip = nn.Sequential(
19 | nn.Linear(in_size, out_size),
20 | nn.LeakyReLU(.2),
21 | )
22 |
23 | def forward(self, x):
24 | return self.linear(x) + self.skip(x)
25 |
26 |
27 | class Classifier(nn.Module):
28 | def __init__(self, in_size, layers=1):
29 | super().__init__()
30 | blocks = []
31 | for _ in range(layers):
32 | blocks.append(ResLinearBlock(in_size, in_size))
33 | # No spectral normalization for the last layer
34 | blocks.append(nn.Linear(in_size, 1))
35 | self.res_linear = nn.Sequential(*blocks)
36 |
37 | def forward(self, x):
38 | return self.res_linear(x)
39 |
40 |
41 | class GBlock(nn.Module):
42 | def __init__(self, in_channels, out_channels):
43 | super().__init__()
44 | self.activation = nn.ReLU(inplace=False)
45 | self.bn1 = nn.BatchNorm1d(in_channels)
46 | self.bn2 = nn.BatchNorm1d(out_channels)
47 | self.conv1 = nn.Conv1d(
48 | in_channels, out_channels, kernel_size=3, padding=1)
49 | self.conv2 = nn.Conv1d(
50 | out_channels, out_channels, kernel_size=3, padding=1)
51 | self.convx = None
52 | if in_channels != out_channels:
53 | self.convx = nn.Conv1d(
54 | in_channels, out_channels, kernel_size=1)
55 |
56 | def forward(self, x):
57 | h = self.activation(self.bn1(x))
58 | h = F.interpolate(h, scale_factor=2)
59 | x = F.interpolate(x, scale_factor=2)
60 | if self.convx:
61 | x = self.convx(x)
62 | h = self.conv1(h)
63 | h = self.activation(self.bn2(h))
64 | h = self.conv2(h)
65 | return h + x
66 |
67 |
68 | class GridDecoder(nn.Module):
69 | def __init__(self, dim_z, channels, start_len=16, squash=None):
70 | super().__init__()
71 | self.activation = nn.ReLU(inplace=False)
72 | self.start_len = start_len
73 | self.linear = nn.Linear(dim_z, channels[0] * start_len)
74 | self.blocks = nn.Sequential(
75 | *[GBlock(in_channels, channels[c + 1])
76 | for c, in_channels in enumerate(channels[:-2])])
77 | self.output = nn.Sequential(
78 | nn.BatchNorm1d(channels[-2]),
79 | self.activation,
80 | nn.Conv1d(channels[-2], channels[-1], kernel_size=3, padding=1),
81 | )
82 | self.squash = squash
83 |
84 | def forward(self, z):
85 | h = self.linear(z)
86 | h = h.view(h.shape[0], -1, self.start_len)
87 | h = self.blocks(h)
88 | h = self.output(h)
89 | if self.squash:
90 | h = self.squash(h)
91 | return h
92 |
93 |
94 | class DBlock(nn.Module):
95 | def __init__(self, in_channels, out_channels, downsample=True):
96 | super().__init__()
97 | self.activation = nn.ReLU(inplace=False)
98 | self.conv1 = nn.Conv1d(
99 | in_channels, out_channels, kernel_size=3, padding=1)
100 | self.conv2 = nn.Conv1d(
101 | out_channels, out_channels, kernel_size=3, padding=1)
102 | self.convx = None
103 | if in_channels != out_channels:
104 | self.convx = nn.Conv1d(in_channels, out_channels, kernel_size=1)
105 | self.downsample = None
106 | if downsample:
107 | self.downsample = nn.AvgPool1d(2)
108 |
109 | def shortcut(self, x):
110 | if self.convx:
111 | x = self.convx(x)
112 | if self.downsample:
113 | x = self.downsample(x)
114 | return x
115 |
116 | def forward(self, x):
117 | # pre-activation
118 | h = self.activation(x)
119 | h = self.conv1(h)
120 | h = self.conv2(self.activation(h))
121 | if self.downsample:
122 | h = self.downsample(h)
123 | return h + self.shortcut(x)
124 |
125 |
126 | class GridEncoder(nn.Module):
127 | def __init__(self, channels, out_dim=1):
128 | super().__init__()
129 | self.activation = nn.ReLU(inplace=False)
130 | self.blocks = nn.Sequential(
131 | *[DBlock(in_channels, out_channels)
132 | for in_channels, out_channels
133 | in zip(channels[:-1], channels[1:])])
134 | self.linear = nn.Linear(channels[-1], out_dim)
135 |
136 | def forward(self, x):
137 | h = x
138 | h = self.blocks(h)
139 | h = self.activation(h).sum(2)
140 | return self.linear(h)
141 |
142 |
143 | class Decoder(nn.Module):
144 | def __init__(self, grid_decoder, max_time=5, kernel_bw=None, dec_ref=128):
145 | super().__init__()
146 | if kernel_bw is None:
147 | self.kernel_bw = max_time / dec_ref * 3
148 | else:
149 | self.kernel_bw = kernel_bw
150 | # ref_times are the assigned time stamps for the evenly-spaced
151 | # generated sequences by conv1d.
152 | self.register_buffer('ref_times', torch.linspace(0, max_time, dec_ref))
153 | self.ref_times = self.ref_times[:, None]
154 | self.grid_decoder = grid_decoder
155 |
156 | def forward(self, code, time, mask):
157 | """
158 | Args:
159 | code: shape (batch_size, latent_size)
160 | time: shape (batch_size, channels, max_seq_len)
161 | mask: shape (batch_size, channels, max_seq_len)
162 |
163 | Returns:
164 | interpolated tensor of shape (batch_size, channels, max_seq_len)
165 | """
166 | # shape of x: (batch_size, n_channels, dec_ref)
167 | x = self.grid_decoder(code)
168 |
169 | # t_diff shape: (batch_size, n_channels, dec_ref, max_seq_len)
170 | t_diff = time[:, :, None] - self.ref_times
171 |
172 | # Epanechnikov quadratic kernel:
173 | # K_\lambda(x_0, x) = relu(3/4 * (1 - (|x_0 - x| / \lambda)^2))
174 | # shape of w: (batch_size, n_channels, dec_ref, max_seq_len)
175 | w = F.relu((1 - (t_diff / self.kernel_bw)**2) * .75)
176 | # variant that avoids division by zero (disabled):
177 | # normalizer = torch.clamp(w.sum(2), min=1e-6)
178 | # return ((x[:, :, :, None] * w).sum(2) * mask) / normalizer
179 | ks_x = ((x[:, :, :, None] * w).sum(2) * mask) / w.sum(2)
180 | return ks_x
181 |
182 |
183 | def gan_loss(real, fake, real_target, fake_target):
184 | real_score = sum(F.binary_cross_entropy_with_logits(
185 | r, r.new_tensor(real_target).expand_as(r)) for r in real)
186 | fake_score = sum(F.binary_cross_entropy_with_logits(
187 | f, f.new_tensor(fake_target).expand_as(f)) for f in fake)
188 | return real_score + fake_score
189 |
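Usage note (not part of layers.py): `Decoder.forward` reads values at arbitrary times off an evenly spaced grid using an Epanechnikov kernel smoother. The following NumPy sketch shows the same computation for a single channel. It clamps the normalizer, mirroring the commented-out variant in the file rather than the active line, and `kernel_smooth` and `eps` are hypothetical names chosen for the example.

```python
import numpy as np


def kernel_smooth(grid_values, ref_times, query_times, bandwidth, eps=1e-6):
    """Epanechnikov-weighted average of grid_values at each query time."""
    # t_diff[r, q] = query_times[q] - ref_times[r]
    t_diff = query_times[None, :] - ref_times[:, None]
    # K(u) = max(0, 3/4 * (1 - u^2)) with u = t_diff / bandwidth
    w = np.maximum(0.75 * (1.0 - (t_diff / bandwidth) ** 2), 0.0)
    # Clamp so queries with no grid point in range give 0 instead of NaN.
    normalizer = np.maximum(w.sum(axis=0), eps)
    return (grid_values[:, None] * w).sum(axis=0) / normalizer


max_time, dec_ref = 5.0, 128
ref_times = np.linspace(0.0, max_time, dec_ref)
bandwidth = max_time / dec_ref * 3  # same default as Decoder.__init__

# A constant signal should be reproduced exactly wherever weights exist.
grid_values = np.full(dec_ref, 2.0)
smoothed = kernel_smooth(
    grid_values, ref_times, np.array([0.0, 1.234, 5.0]), bandwidth)

# A query far outside the grid has zero total weight.
outside = kernel_smooth(grid_values, ref_times, np.array([100.0]), bandwidth)
```

With the default bandwidth of three grid steps, each query inside [0, max_time] averages over a handful of neighbouring grid points.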
--------------------------------------------------------------------------------
/time-series/mimic3_pbigan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import spectral_norm
5 | import torch.optim as optim
6 | from torch.distributions.normal import Normal
7 | from torch.utils.data import DataLoader
8 | import numpy as np
9 | from datetime import datetime
10 | import time
11 | from pathlib import Path
12 | import argparse
13 | from collections import defaultdict
14 | from spline_cconv import ContinuousConv1D
15 | import time_series
16 | from ema import EMA
17 | from utils import count_parameters, mkdir, make_scheduler
18 | from mmd import mmd
19 | from tracker import Tracker
20 | from evaluate import Evaluator
21 | from sn_layers import (
22 | InvertibleLinearResNet,
23 | Classifier,
24 | GridDecoder,
25 | GridEncoder,
26 | Decoder,
27 | gan_loss,
28 | )
29 |
30 |
31 | use_cuda = torch.cuda.is_available()
32 | device = torch.device('cuda' if use_cuda else 'cpu')
33 |
34 |
35 | class Encoder(nn.Module):
36 | def __init__(self, cconv, latent_size, channels, trans_layers=1):
37 | super().__init__()
38 | self.cconv = cconv
39 | self.grid_encoder = GridEncoder(channels, latent_size * 2)
40 | self.trans = InvertibleLinearResNet(
41 | latent_size, latent_size, trans_layers).to(device)
42 |
43 | def forward(self, cconv_graph, batch_size):
44 | x = self.cconv(*cconv_graph, batch_size)
45 | mu, logvar = self.grid_encoder(x).chunk(2, dim=1)
46 | if self.training:
47 | std = F.softplus(logvar)
48 | qz_x = Normal(mu, std)
49 | z_0 = qz_x.rsample()
50 | z_T = self.trans(z_0)
51 | else:
52 | z_T = self.trans(mu)
53 | return z_T
54 |
55 |
56 | class ConvCritic(nn.Module):
57 | def __init__(self, cconv, latent_size, channels, embed_size=32):
58 | super().__init__()
59 |
60 | self.cconv = cconv
61 | self.grid_critic = GridEncoder(channels, embed_size)
62 |
63 | self.z_dis = nn.Sequential(
64 | spectral_norm(nn.Linear(latent_size, embed_size)),
65 | nn.LeakyReLU(0.2),
66 | spectral_norm(nn.Linear(embed_size, embed_size)),
67 | )
68 |
69 | self.x_linear = spectral_norm(nn.Linear(embed_size, 1))
70 |
71 | self.xz_dis = nn.Sequential(
72 | spectral_norm(nn.Linear(embed_size * 2, embed_size)),
73 | nn.LeakyReLU(0.2),
74 | spectral_norm(nn.Linear(embed_size, 1)),
75 | )
76 |
77 | def forward(self, cconv_graph, batch_size, z):
78 | x = self.cconv(*cconv_graph, batch_size)
79 | x = self.grid_critic(x)
80 | z = self.z_dis(z)
81 | xz = torch.cat((x, z), 1)
82 | xz = self.xz_dis(xz)
83 | x_out = self.x_linear(x).view(-1)
84 | xz_out = xz.view(-1)
85 | return xz_out, x_out
86 |
87 |
88 | class PBiGAN(nn.Module):
89 | def __init__(self, encoder, decoder, classifier, ae_loss='mse'):
90 | super().__init__()
91 | self.encoder = encoder
92 | self.decoder = decoder
93 | self.classifier = classifier
94 | self.ae_loss = ae_loss
95 |
96 | def forward(self, data, time, mask, y, cconv_graph, time_t, mask_t):
97 | batch_size = len(data)
98 | z_T = self.encoder(cconv_graph, batch_size)
99 |
100 | z_gen = torch.empty_like(z_T).normal_()
101 | x_gen = self.decoder(z_gen, time_t, mask_t)
102 |
103 | x_recon = self.decoder(z_T, time, mask)
104 |
105 | if self.ae_loss == 'mse':
106 | ae_loss = F.mse_loss(x_recon, data, reduction='none') * mask
107 | elif self.ae_loss == 'smooth_l1':
108 | ae_loss = F.smooth_l1_loss(x_recon, data, reduction='none') * mask
109 |
110 | ae_loss = ae_loss.sum((-1, -2))
111 |
112 | y_logit = self.classifier(z_T).view(-1)
113 |
114 | # cls_loss: -log p(y|z)
115 | cls_loss = F.binary_cross_entropy_with_logits(
116 | y_logit, y.expand_as(y_logit), reduction='none')
117 |
118 | return z_T, x_recon, z_gen, x_gen, ae_loss.mean(), cls_loss.mean()
119 |
120 | def predict(self, data, time, mask, cconv_graph):
121 | batch_size = len(data)
122 | z_T = self.encoder(cconv_graph, batch_size)
123 | y_logit = self.classifier(z_T).view(-1)
124 | return y_logit
125 |
126 |
127 | def main():
128 | parser = argparse.ArgumentParser()
129 |
130 | parser.add_argument('--data', default='mimic3.npz',
131 | help='data file')
132 | parser.add_argument('--seed', type=int, default=None,
133 | help='random seed. Randomly set if not specified.')
134 |
135 | # training options
136 | parser.add_argument('--nz', type=int, default=32,
137 | help='dimension of latent variable')
138 | parser.add_argument('--epoch', type=int, default=500,
139 | help='number of training epochs')
140 | parser.add_argument('--batch-size', type=int, default=64,
141 | help='batch size')
142 | # Use smaller test batch size to accommodate more importance samples
143 | parser.add_argument('--test-batch-size', type=int, default=32,
144 | help='batch size for validation and test set')
145 | parser.add_argument('--lr', type=float, default=2e-4,
146 | help='encoder/decoder learning rate')
147 | parser.add_argument('--dis-lr', type=float, default=3e-4,
148 | help='discriminator learning rate')
149 | parser.add_argument('--min-lr', type=float, default=1e-4,
150 | help='min encoder/decoder learning rate for LR '
151 | 'scheduler. -1 to disable annealing')
152 | parser.add_argument('--min-dis-lr', type=float, default=1.5e-4,
153 | help='min discriminator learning rate for LR '
154 | 'scheduler. -1 to disable annealing')
155 | parser.add_argument('--wd', type=float, default=1e-4,
156 | help='weight decay')
157 | parser.add_argument('--overlap', type=float, default=.5,
158 | help='kernel overlap')
159 | parser.add_argument('--cls', type=float, default=1,
160 | help='classification weight')
161 | parser.add_argument('--clsdep', type=int, default=1,
162 | help='number of layers for classifier')
163 | parser.add_argument('--eval-interval', type=int, default=1,
164 | help='AUC evaluation interval. '
165 | '0 to disable evaluation.')
166 | parser.add_argument('--save-interval', type=int, default=0,
167 | help='interval to save models. 0 to disable saving.')
168 | parser.add_argument('--prefix', default='pbigan',
169 | help='prefix of output directory')
170 | parser.add_argument('--comp', type=int, default=7,
171 | help='continuous convolution kernel size')
172 | parser.add_argument('--ae', type=float, default=1,
173 | help='autoencoding regularization strength')
174 | parser.add_argument('--aeloss', default='mse',
175 | choices=('mse', 'smooth_l1'), help='autoencoding loss')
176 | parser.add_argument('--dec-ch', default='8-16-16',
177 | help='decoder architecture')
178 | parser.add_argument('--enc-ch', default='64-32-32-16',
179 | help='encoder architecture')
180 | parser.add_argument('--dis-ch', default=None,
181 | help='discriminator architecture. Use encoder '
182 | 'architecture if unspecified.')
183 | parser.add_argument('--rescale', dest='rescale', action='store_const',
184 | const=True, default=True,
185 | help='if set, rescale time to [-1, 1]')
186 | parser.add_argument('--no-rescale', dest='rescale', action='store_const',
187 | const=False)
188 | parser.add_argument('--cconvnorm', dest='cconv_norm',
189 | action='store_const', const=True, default=True,
190 | help='if set, normalize continuous convolutional '
191 | 'layer using mean pooling')
192 | parser.add_argument('--no-cconvnorm', dest='cconv_norm',
193 | action='store_const', const=False)
194 | parser.add_argument('--cconv-ref', type=int, default=98,
195 | help='number of evenly-spaced reference locations '
196 | 'for continuous convolutional layer')
197 | parser.add_argument('--dec-ref', type=int, default=128,
198 | help='number of evenly-spaced reference locations '
199 | 'for decoder')
200 | parser.add_argument('--trans', type=int, default=2,
201 | help='number of layers in the encoder transformation')
202 | parser.add_argument('--ema', dest='ema', type=int, default=0,
203 | help='start iteration of exponential moving average '
204 | '(EMA). -1 to disable EMA')
205 | parser.add_argument('--ema-decay', type=float, default=.9999,
206 | help='EMA decay')
207 | parser.add_argument('--mmd', type=float, default=1,
208 | help='MMD strength for latent variable')
209 |
210 | args = parser.parse_args()
211 |
212 | nz = args.nz
213 |
214 | epochs = args.epoch
215 | eval_interval = args.eval_interval
216 | save_interval = args.save_interval
217 |
218 | if args.seed is None:
219 | rnd = np.random.RandomState(None)
220 | random_seed = rnd.randint(np.iinfo(np.uint32).max)
221 | else:
222 | random_seed = args.seed
223 | rnd = np.random.RandomState(random_seed)
224 | np.random.seed(random_seed)
225 | torch.manual_seed(random_seed)
226 |
227 | max_time = 5
228 | cconv_ref = args.cconv_ref
229 | overlap = args.overlap
230 | train_dataset, val_dataset, test_dataset = time_series.split_data(
231 | args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)
232 |
233 | train_loader = DataLoader(
234 | train_dataset, batch_size=args.batch_size, shuffle=True,
235 | drop_last=True, collate_fn=train_dataset.collate_fn)
236 | n_train_batch = len(train_loader)
237 |
238 | time_loader = DataLoader(
239 | train_dataset, batch_size=args.batch_size, shuffle=True,
240 | drop_last=True, collate_fn=train_dataset.collate_fn)
241 |
242 | val_loader = DataLoader(
243 | val_dataset, batch_size=args.test_batch_size, shuffle=False,
244 | collate_fn=val_dataset.collate_fn)
245 |
246 | test_loader = DataLoader(
247 | test_dataset, batch_size=args.test_batch_size, shuffle=False,
248 | collate_fn=test_dataset.collate_fn)
249 |
250 | in_channels, seq_len = train_dataset.data.shape[1:]
251 |
252 | if args.dis_ch is None:
253 | args.dis_ch = args.enc_ch
254 |
255 | dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
256 | enc_channels = [int(c) for c in args.enc_ch.split('-')]
257 | dis_channels = [int(c) for c in args.dis_ch.split('-')]
258 |
259 | out_channels = enc_channels[0]
260 |
261 | squash = torch.sigmoid
262 | if args.rescale:
263 | squash = torch.tanh
264 |
265 | dec_ch_up = 2**(len(dec_channels) - 2)
266 | assert args.dec_ref % dec_ch_up == 0, (
267 | f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.')
268 | dec_len0 = args.dec_ref // dec_ch_up
269 | grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)
270 |
271 | decoder = Decoder(
272 | grid_decoder, max_time=max_time, dec_ref=args.dec_ref).to(device)
273 | cconv = ContinuousConv1D(
274 | in_channels, out_channels, max_time, cconv_ref, overlap_rate=overlap,
275 | kernel_size=args.comp, norm=args.cconv_norm).to(device)
276 | encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device)
277 |
278 | classifier = Classifier(nz, args.clsdep).to(device)
279 |
280 | pbigan = PBiGAN(
281 | encoder, decoder, classifier, ae_loss=args.aeloss).to(device)
282 |
283 | ema = None
284 | if args.ema >= 0:
285 | ema = EMA(pbigan, args.ema_decay, args.ema)
286 |
287 | critic_cconv = ContinuousConv1D(
288 | in_channels, out_channels, max_time, cconv_ref, overlap_rate=overlap,
289 | kernel_size=args.comp, norm=args.cconv_norm).to(device)
290 | critic_embed = 32
291 | critic = ConvCritic(
292 | critic_cconv, nz, dis_channels, critic_embed).to(device)
293 |
294 | optimizer = optim.Adam(
295 | pbigan.parameters(), lr=args.lr,
296 | betas=(0, .999), weight_decay=args.wd)
297 | critic_optimizer = optim.Adam(
298 | critic.parameters(), lr=args.dis_lr,
299 | betas=(0, .999), weight_decay=args.wd)
300 |
301 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
302 | dis_scheduler = make_scheduler(
303 | critic_optimizer, args.dis_lr, args.min_dis_lr, epochs)
304 |
305 | path = '{}_{}'.format(
306 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
307 |
308 | output_dir = Path('results') / 'mimic3-pbigan' / path
309 | print(output_dir)
310 | log_dir = mkdir(output_dir / 'log')
311 | model_dir = mkdir(output_dir / 'model')
312 |
313 | start_epoch = 0
314 |
315 | with (log_dir / 'seed.txt').open('w') as f:
316 | print(random_seed, file=f)
317 | with (log_dir / 'gpu.txt').open('a') as f:
318 | print(torch.cuda.device_count(), start_epoch, file=f)
319 | with (log_dir / 'args.txt').open('w') as f:
320 | for key, val in sorted(vars(args).items()):
321 | print(f'{key}: {val}', file=f)
322 | with (log_dir / 'params.txt').open('w') as f:
323 | def print_params_count(module, name):
324 | try: # sum counts if module is a list
325 | params_count = sum(count_parameters(m) for m in module)
326 | except TypeError:
327 | params_count = count_parameters(module)
328 | print(f'{name} {params_count}', file=f)
329 | print_params_count(grid_decoder, 'grid_decoder')
330 | print_params_count(decoder, 'decoder')
331 | print_params_count(cconv, 'cconv')
332 | print_params_count(encoder, 'encoder')
333 | print_params_count(classifier, 'classifier')
334 | print_params_count(pbigan, 'pbigan')
335 | print_params_count(critic, 'critic')
336 | print_params_count([pbigan, critic], 'total')
337 |
338 | tracker = Tracker(log_dir, n_train_batch)
339 | evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir)
340 | start = time.time()
341 | epoch_start = start
342 |
343 | batch_size = args.batch_size
344 |
345 | for epoch in range(start_epoch, epochs):
346 | loss_breakdown = defaultdict(float)
347 | epoch_start = time.time()
348 |
349 | if epoch >= 40:
350 | args.cls = 200
351 |
352 | for ((val, idx, mask, y, _, cconv_graph),
353 | (_, idx_t, mask_t, _, index, _)) in zip(
354 | train_loader, time_loader):
355 |
356 | z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan(
357 | val, idx, mask, y, cconv_graph, idx_t, mask_t)
358 |
359 | cconv_graph_gen = train_dataset.make_graph(
360 | x_gen, idx_t, mask_t, index)
361 |
362 | # Don't need pbigan.requires_grad_(False);
363 | # critic takes as input only the detached tensors.
364 | real = critic(cconv_graph, batch_size, z_enc.detach())
365 | detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x
366 | for i, x in enumerate(cconv_graph_gen)]
367 | fake = critic(detached_graph, batch_size, z_gen.detach())
368 |
369 | D_loss = gan_loss(real, fake, 1, 0)
370 |
371 | critic_optimizer.zero_grad()
372 | D_loss.backward()
373 | critic_optimizer.step()
374 |
375 | for p in critic.parameters():
376 | p.requires_grad_(False)
377 | real = critic(cconv_graph, batch_size, z_enc)
378 | fake = critic(cconv_graph_gen, batch_size, z_gen)
379 |
380 | G_loss = gan_loss(real, fake, 0, 1)
381 |
382 | mmd_loss = mmd(z_enc, z_gen)
383 |
384 | loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls
385 | + mmd_loss * args.mmd)
386 |
387 | optimizer.zero_grad()
388 | loss.backward()
389 | optimizer.step()
390 | for p in critic.parameters():
391 | p.requires_grad_(True)
392 |
393 | if ema:
394 | ema.update()
395 |
396 | loss_breakdown['D'] += D_loss.item()
397 | loss_breakdown['G'] += G_loss.item()
398 | loss_breakdown['AE'] += ae_loss.item()
399 | loss_breakdown['MMD'] += mmd_loss.item()
400 | loss_breakdown['CLS'] += cls_loss.item()
401 | loss_breakdown['total'] += loss.item()
402 |
403 | if scheduler:
404 | scheduler.step()
405 | if dis_scheduler:
406 | dis_scheduler.step()
407 |
408 | cur_time = time.time()
409 | tracker.log(
410 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)
411 |
412 | if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
413 | if ema:
414 | ema.apply()
415 | evaluator.evaluate(epoch)
416 | ema.restore()
417 | else:
418 | evaluator.evaluate(epoch)
419 |
420 | model_dict = {
421 | 'pbigan': pbigan.state_dict(),
422 | 'critic': critic.state_dict(),
423 | 'ema': ema.state_dict() if ema else None,
424 | 'epoch': epoch + 1,
425 | 'args': args,
426 | }
427 | torch.save(model_dict, str(log_dir / 'model.pth'))
428 | if save_interval > 0 and (epoch + 1) % save_interval == 0:
429 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
430 |
431 | print(output_dir)
432 |
433 |
434 | if __name__ == '__main__':
435 | main()
436 |
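The training loop above calls `gan_loss(real, fake, 1, 0)` for the critic update and `gan_loss(real, fake, 0, 1)` for the encoder/decoder update, so the two players optimize the same cross-entropy with swapped targets. The following is a minimal pure-Python sketch of that swap. It assumes plain lists of scalar logits rather than the lists of tensors the critic returns, and `bce_with_logits` and `gan_loss_scalar` are illustrative names that do not exist in the repository.

```python
import math


def bce_with_logits(logit, target):
    """Numerically stable binary cross-entropy for a single logit."""
    # max(x, 0) - x * t + log(1 + exp(-|x|)) avoids overflow for large |x|.
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def gan_loss_scalar(real_logits, fake_logits, real_target, fake_target):
    """Mean BCE on the real logits plus mean BCE on the fake logits."""
    real = sum(bce_with_logits(r, real_target) for r in real_logits)
    fake = sum(bce_with_logits(f, fake_target) for f in fake_logits)
    return real / len(real_logits) + fake / len(fake_logits)


# Hypothetical critic outputs: real pairs score high, generated pairs low.
real_logits = [2.0, 3.0]
fake_logits = [-2.5, -1.5]

d_loss = gan_loss_scalar(real_logits, fake_logits, 1, 0)  # critic objective
g_loss = gan_loss_scalar(real_logits, fake_logits, 0, 1)  # swapped targets
```

With a critic that separates the two sets well, `d_loss` is small while `g_loss` is large, which is the signal that drives the encoder and decoder in the second half of each iteration.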
--------------------------------------------------------------------------------
/time-series/mimic3_pvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.optim as optim
5 | from torch.distributions.normal import Normal
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 | from datetime import datetime
9 | import time
10 | from pathlib import Path
11 | import argparse
12 | import math
13 | from collections import defaultdict
14 | from spline_cconv import ContinuousConv1D
15 | import time_series
16 | import flow
17 | from ema import EMA
18 | from utils import count_parameters, mkdir, make_scheduler
19 | from tracker import Tracker
20 | from evaluate import Evaluator
21 | from layers import (
22 | Classifier,
23 | GridDecoder,
24 | GridEncoder,
25 | Decoder,
26 | )
27 |
28 |
29 | use_cuda = torch.cuda.is_available()
30 | device = torch.device('cuda' if use_cuda else 'cpu')
31 |
32 |
33 | class Encoder(nn.Module):
34 | def __init__(self, cconv, latent_size, channels, flow_depth=2):
35 | super().__init__()
36 | self.cconv = cconv
37 |
38 | if flow_depth > 0:
39 | hidden_size = latent_size * 2
40 | flow_layers = [flow.InverseAutoregressiveFlow(
41 | latent_size, hidden_size, latent_size)
42 | for _ in range(flow_depth)]
43 |
44 | flow_layers.append(flow.Reverse(latent_size))
45 | self.q_z_flow = flow.FlowSequential(*flow_layers)
46 | self.enc_chunk = 3
47 | else:
48 | self.q_z_flow = None
49 | self.enc_chunk = 2
50 |
51 | self.grid_encoder = GridEncoder(channels, latent_size * self.enc_chunk)
52 |
53 | def forward(self, cconv_graph, batch_size, iw_samples=3):
54 | x = self.cconv(*cconv_graph, batch_size)
55 | grid_enc = self.grid_encoder(x).chunk(self.enc_chunk, dim=1)
56 | mu, logvar = grid_enc[:2]
57 | std = F.softplus(logvar)
58 | qz_x = Normal(mu, std)
59 | z_0 = qz_x.rsample([iw_samples])
60 | log_q_z_0 = qz_x.log_prob(z_0)
61 | if self.q_z_flow:
62 | z_T, log_q_z_flow = self.q_z_flow(z_0, context=grid_enc[2])
63 | log_q_z = (log_q_z_0 + log_q_z_flow).sum(-1)
64 | else:
65 | z_T, log_q_z = z_0, log_q_z_0.sum(-1)
66 | return z_T, log_q_z
67 |
68 |
69 | def masked_loss(loss_fn, pred, data, mask):
70 | # return (loss_fn(pred * mask, data * mask,
71 | # reduction='none') * mask).mean()
72 | # Expand data shape from (batch_size, d) to (iw_samples, batch_size, d)
73 | return loss_fn(pred, data.expand_as(pred), reduction='none') * mask
74 |
75 |
76 | class PVAE(nn.Module):
77 | def __init__(self, encoder, decoder, classifier, sigma=.2, cls_weight=100):
78 | super().__init__()
79 | self.encoder = encoder
80 | self.decoder = decoder
81 | self.classifier = classifier
82 | self.sigma = sigma
83 | self.cls_weight = cls_weight
84 |
85 | def forward(self, data, time, mask, y, cconv_graph, iw_samples=3,
86 | ts_lambda=1, kl_lambda=1):
87 | batch_size = len(data)
88 | z_T, log_q_z = self.encoder(cconv_graph, batch_size, iw_samples)
89 |
90 | pz = Normal(torch.zeros_like(z_T), torch.ones_like(z_T))
91 | log_p_z = pz.log_prob(z_T).sum(-1)
92 | # kl_loss: log q(z|x) - log p(z)
93 | kl_loss = log_q_z - log_p_z
94 |
95 | var2 = 2 * self.sigma**2
96 | # half_log2pivar: log(2 * pi * sigma^2) / 2
97 | half_log2pivar = .5 * math.log(math.pi * var2)
98 |
99 | # Multivariate Gaussian log-likelihood:
100 | # -D/2 * log(2*pi*sigma^2) - 1/2 \sum_{i=1}^D (x_i - mu_i)^2 / sigma^2
101 | def neg_gaussian_logp(pred, data, mask=None):
102 | se = F.mse_loss(pred, data.expand_as(pred), reduction='none')
103 | if mask is None:
104 | return se / var2 + half_log2pivar
105 | return (se / var2 + half_log2pivar) * mask
106 |
107 | # Reshape z to accommodate modules with strict input shape
108 | # requirements such as convolutional layers.
109 | # Expected shape of x_recon: (iw_samples * batch_size, C, L)
110 | z_flat = z_T.view(-1, *z_T.shape[2:])
111 | x_recon = self.decoder(
112 | z_flat,
113 | time.repeat((iw_samples, 1, 1)),
114 | mask.repeat((iw_samples, 1, 1)))
115 |
116 | # Gaussian noise for time series
117 | # data shape: (batch_size, C, L)
118 | # x_recon shape: (iw_samples * batch_size, C, L)
119 | x_recon = x_recon.view(iw_samples, *data.shape)
120 | neg_logp = neg_gaussian_logp(x_recon, data, mask)
121 | # neg_logp: -log p(x|z)
122 | neg_logp = neg_logp.sum((-1, -2))
123 |
124 | y_logit = self.classifier(z_flat).view(iw_samples, -1)
125 |
126 | # cls_loss: -log p(y|z)
127 | cls_loss = F.binary_cross_entropy_with_logits(
128 | y_logit, y.expand_as(y_logit), reduction='none')
129 |
130 | # elbo_x = log p(x|z) + log p(z) - log q(z|x)
131 | elbo_x = -(neg_logp * ts_lambda + kl_loss * kl_lambda)
132 |
133 | with torch.no_grad():
134 | is_weight = F.softmax(elbo_x, 0)
135 |
136 | # IWAE loss: -log E[p(x|z) p(z) / q(z|x)]
137 | # Here we drop the constant +log(iw_samples) term of the loss
138 | loss_x = -elbo_x.logsumexp(0).mean()
139 | loss_y = (is_weight * cls_loss).sum(0).mean()
140 | loss = loss_x + loss_y * self.cls_weight
141 |
142 | # For debugging
143 | x_se = masked_loss(F.mse_loss, x_recon, data, mask)
144 | mse = x_se.sum((-1, -2)) / mask.sum((-1, -2)).clamp(min=1)
145 |
146 | CE = (is_weight * cls_loss).sum(0).mean().item()
147 | loss_breakdown = {
148 | 'loss': loss.item(),
149 | 'reconst.': neg_logp.mean().item() * ts_lambda,
150 | 'MSE': mse.mean().item(),
151 | 'KL': kl_loss.mean().item() * kl_lambda,
152 | 'CE': CE,
153 | 'classif.': CE * self.cls_weight,
154 | }
155 | return loss, z_T, elbo_x, loss_breakdown
156 |
157 | def predict(self, data, time, mask, cconv_graph, iw_samples=50):
158 | dummy_y = data.new_zeros(len(data))
159 | _, z, elbo, _ = self(
160 | data, time, mask, dummy_y, cconv_graph, iw_samples)
161 | z_flat = z.view(-1, *z.shape[2:])
162 | pred_logit = self.classifier(z_flat).view(iw_samples, -1)
163 | is_weight = F.softmax(elbo, 0)
164 |
165 | # Importance reweighted predictive probability
166 | # p(y|x) =~ E_{q_IW(z|x)}[p(y|z)]
167 | py_z = torch.sigmoid(pred_logit)
168 | expected_py_z = (is_weight * py_z).sum(0)
169 | return expected_py_z
170 |
171 |
172 | def main():
173 | parser = argparse.ArgumentParser()
174 |
175 | parser.add_argument('--data', default='mimic3.npz',
176 | help='data file')
177 | parser.add_argument('--seed', type=int, default=None,
178 | help='random seed. Randomly set if not specified.')
179 |
180 | # training options
181 | parser.add_argument('--nz', type=int, default=32,
182 | help='dimension of latent variable')
183 | parser.add_argument('--epoch', type=int, default=200,
184 | help='number of training epochs')
185 | parser.add_argument('--batch-size', type=int, default=64,
186 | help='batch size')
187 | # Use smaller test batch size to accommodate more importance samples
188 | parser.add_argument('--test-batch-size', type=int, default=32,
189 | help='batch size for validation and test set')
190 | parser.add_argument('--train-k', type=int, default=8,
191 | help='number of importance samples for training')
192 | parser.add_argument('--test-k', type=int, default=50,
193 | help='number of importance samples for evaluation')
194 | parser.add_argument('--flow', type=int, default=2,
195 | help='number of IAF layers')
196 | parser.add_argument('--lr', type=float, default=2e-4,
197 | help='global learning rate')
198 | parser.add_argument('--enc-lr', type=float, default=1e-4,
199 | help='encoder learning rate')
200 | parser.add_argument('--dec-lr', type=float, default=1e-4,
201 | help='decoder learning rate')
202 | parser.add_argument('--min-lr', type=float, default=-1,
203 | help='min learning rate for LR scheduler. '
204 | '-1 to disable annealing')
205 | parser.add_argument('--wd', type=float, default=1e-3,
206 | help='weight decay')
207 | parser.add_argument('--overlap', type=float, default=.5,
208 | help='kernel overlap')
209 | parser.add_argument('--cls', type=float, default=200,
210 | help='classification weight')
211 | parser.add_argument('--clsdep', type=int, default=1,
212 | help='number of layers for classifier')
213 | parser.add_argument('--ts', type=float, default=1,
214 | help='log-likelihood weight for ELBO')
215 | parser.add_argument('--kl', type=float, default=.1,
216 | help='KL weight for ELBO')
217 | parser.add_argument('--eval-interval', type=int, default=1,
218 | help='AUC evaluation interval. '
219 | '0 to disable evaluation.')
220 | parser.add_argument('--save-interval', type=int, default=0,
221 | help='interval to save models. 0 to disable saving.')
222 | parser.add_argument('--prefix', default='pvae',
223 | help='prefix of output directory')
224 | parser.add_argument('--comp', type=int, default=7,
225 | help='continuous convolution kernel size')
226 | parser.add_argument('--sigma', type=float, default=.2,
227 | help='standard deviation for Gaussian likelihood')
228 | parser.add_argument('--dec-ch', default='8-16-16',
229 | help='decoder architecture')
230 | parser.add_argument('--enc-ch', default='64-32-32-16',
231 | help='encoder architecture')
232 | parser.add_argument('--rescale', dest='rescale', action='store_const',
233 | const=True, default=True,
234 | help='if set, rescale time to [-1, 1]')
235 | parser.add_argument('--no-rescale', dest='rescale', action='store_const',
236 | const=False)
237 | parser.add_argument('--cconvnorm', dest='cconv_norm',
238 | action='store_const', const=True, default=True,
239 | help='if set, normalize continuous convolutional '
240 | 'layer using mean pooling')
241 | parser.add_argument('--no-cconvnorm', dest='cconv_norm',
242 | action='store_const', const=False)
243 | parser.add_argument('--cconv-ref', type=int, default=98,
244 | help='number of evenly-spaced reference locations '
245 | 'for continuous convolutional layer')
246 | parser.add_argument('--dec-ref', type=int, default=128,
247 | help='number of evenly-spaced reference locations '
248 | 'for decoder')
249 | parser.add_argument('--ema', dest='ema', type=int, default=0,
250 | help='start epoch of exponential moving average '
251 | '(EMA). -1 to disable EMA')
252 | parser.add_argument('--ema-decay', type=float, default=.9999,
253 | help='EMA decay')
254 |
255 | args = parser.parse_args()
256 |
257 | nz = args.nz
258 |
259 | epochs = args.epoch
260 | eval_interval = args.eval_interval
261 | save_interval = args.save_interval
262 |
263 | if args.seed is None:
264 | rnd = np.random.RandomState(None)
265 | random_seed = rnd.randint(np.iinfo(np.uint32).max)
266 | else:
267 | random_seed = args.seed
268 | rnd = np.random.RandomState(random_seed)
269 | np.random.seed(random_seed)
270 | torch.manual_seed(random_seed)
271 |
272 | max_time = 5
273 | cconv_ref = args.cconv_ref
274 | overlap = args.overlap
275 | train_dataset, val_dataset, test_dataset = time_series.split_data(
276 | args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale)
277 |
278 | train_loader = DataLoader(
279 | train_dataset, batch_size=args.batch_size, shuffle=True,
280 | drop_last=True, collate_fn=train_dataset.collate_fn)
281 | n_train_batch = len(train_loader)
282 |
283 | val_loader = DataLoader(
284 | val_dataset, batch_size=args.test_batch_size, shuffle=False,
285 | collate_fn=val_dataset.collate_fn)
286 |
287 | test_loader = DataLoader(
288 | test_dataset, batch_size=args.test_batch_size, shuffle=False,
289 | collate_fn=test_dataset.collate_fn)
290 |
291 | in_channels, seq_len = train_dataset.data.shape[1:]
292 |
293 | dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels]
294 | enc_channels = [int(c) for c in args.enc_ch.split('-')]
295 |
296 | out_channels = enc_channels[0]
297 |
298 | squash = torch.sigmoid
299 | if args.rescale:
300 | squash = torch.tanh
301 |
302 | dec_ch_up = 2**(len(dec_channels) - 2)
303 | assert args.dec_ref % dec_ch_up == 0, (
304 | f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.')
305 | dec_len0 = args.dec_ref // dec_ch_up
306 | grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash)
307 |
308 | decoder = Decoder(
309 | grid_decoder, max_time=max_time, dec_ref=args.dec_ref).to(device)
310 |
311 | cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref,
312 | overlap_rate=overlap, kernel_size=args.comp,
313 | norm=args.cconv_norm).to(device)
314 | encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device)
315 |
316 | classifier = Classifier(nz, args.clsdep).to(device)
317 |
318 | pvae = PVAE(
319 | encoder, decoder, classifier, args.sigma, args.cls).to(device)
320 |
321 | ema = None
322 | if args.ema >= 0:
323 | ema = EMA(pvae, args.ema_decay, args.ema)
324 |
325 | other_params = [param for name, param in pvae.named_parameters()
326 | if not (name.startswith('decoder.grid_decoder')
327 | or name.startswith('encoder.grid_encoder'))]
328 | params = [
329 | {'params': decoder.grid_decoder.parameters(), 'lr': args.dec_lr},
330 | {'params': encoder.grid_encoder.parameters(), 'lr': args.enc_lr},
331 | {'params': other_params},
332 | ]
333 |
334 | optimizer = optim.Adam(
335 | params, lr=args.lr, weight_decay=args.wd)
336 |
337 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
338 |
339 | path = '{}_{}'.format(
340 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
341 |
342 | output_dir = Path('results') / 'mimic3-pvae' / path
343 | print(output_dir)
344 | log_dir = mkdir(output_dir / 'log')
345 | model_dir = mkdir(output_dir / 'model')
346 |
347 | start_epoch = 0
348 |
349 | with (log_dir / 'seed.txt').open('w') as f:
350 | print(random_seed, file=f)
351 | with (log_dir / 'gpu.txt').open('a') as f:
352 | print(torch.cuda.device_count(), start_epoch, file=f)
353 | with (log_dir / 'args.txt').open('w') as f:
354 | for key, val in sorted(vars(args).items()):
355 | print(f'{key}: {val}', file=f)
356 | with (log_dir / 'params.txt').open('w') as f:
357 | def print_params_count(module, name):
358 | try: # sum counts if module is a list
359 | params_count = sum(count_parameters(m) for m in module)
360 | except TypeError:
361 | params_count = count_parameters(module)
362 | print(f'{name} {params_count}', file=f)
363 | print_params_count(grid_decoder, 'grid_decoder')
364 | print_params_count(decoder, 'decoder')
365 | print_params_count(cconv, 'cconv')
366 | print_params_count(encoder, 'encoder')
367 | print_params_count(classifier, 'classifier')
368 | print_params_count(pvae, 'pvae')
369 | print_params_count(pvae, 'total')
370 |
371 | tracker = Tracker(log_dir, n_train_batch)
372 | evaluator = Evaluator(pvae, val_loader, test_loader, log_dir,
373 | eval_args={'iw_samples': args.test_k})
374 | start = time.time()
375 | epoch_start = start
376 |
377 | for epoch in range(start_epoch, epochs):
378 | loss_breakdown = defaultdict(float)
379 | epoch_start = time.time()
380 | for (val, idx, mask, y, _, cconv_graph) in train_loader:
381 | optimizer.zero_grad()
382 | loss, _, _, loss_info = pvae(
383 | val, idx, mask, y, cconv_graph, args.train_k, args.ts, args.kl)
384 | loss.backward()
385 | optimizer.step()
386 |
387 | if ema:
388 | ema.update()
389 |
390 | for loss_name, loss_val in loss_info.items():
391 | loss_breakdown[loss_name] += loss_val
392 |
393 | if scheduler:
394 | scheduler.step()
395 |
396 | cur_time = time.time()
397 | tracker.log(
398 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)
399 |
400 | if eval_interval > 0 and (epoch + 1) % eval_interval == 0:
401 | if ema:
402 | ema.apply()
403 | evaluator.evaluate(epoch)
404 | ema.restore()
405 | else:
406 | evaluator.evaluate(epoch)
407 |
408 | model_dict = {
409 | 'pvae': pvae.state_dict(),
410 | 'ema': ema.state_dict() if ema else None,
411 | 'epoch': epoch + 1,
412 | 'args': args,
413 | }
414 | torch.save(model_dict, str(log_dir / 'model.pth'))
415 | if save_interval > 0 and (epoch + 1) % save_interval == 0:
416 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
417 |
418 | print(output_dir)
419 |
420 |
421 | if __name__ == '__main__':
422 | main()
423 |
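`PVAE.forward` above combines an importance-weighted bound (`-elbo_x.logsumexp(0)`) with a classification loss that is averaged under self-normalized importance weights (`F.softmax(elbo_x, 0)`). The sketch below reproduces those two reductions for a single example in pure Python. It is an illustration under simplifying assumptions, not the repository's code: the log-weights and per-sample losses are made-up numbers, and the helper names are hypothetical.

```python
import math


def logsumexp(values):
    """log(sum(exp(v))) computed stably by factoring out the maximum."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def importance_weights(log_w):
    """Self-normalized importance weights: a softmax over the samples."""
    lse = logsumexp(log_w)
    return [math.exp(v - lse) for v in log_w]


def iwae_loss(log_w, keep_constant=False):
    """Negative importance-weighted bound for one example.

    log_w[i] = log p(x|z_i) + log p(z_i) - log q(z_i|x).
    With keep_constant=False the +log(k) term is dropped; it does not
    depend on the parameters, so gradients are unchanged.
    """
    loss = -logsumexp(log_w)
    if keep_constant:
        loss += math.log(len(log_w))
    return loss


# Made-up log-weights for k = 3 importance samples of one example.
log_w = [-10.0, -12.0, -11.0]
weights = importance_weights(log_w)

# Made-up per-sample classification losses -log p(y|z_i).
cls_loss = [0.2, 0.9, 0.5]
weighted_cls = sum(w * c for w, c in zip(weights, cls_loss))
```

Samples with a larger log-weight contribute more to `weighted_cls`, which mirrors how `loss_y` and `predict` weight the classifier outputs.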
--------------------------------------------------------------------------------
/time-series/mmd.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def mmd(x, y):
5 | n, dim = x.shape
6 |
7 | xx = (x**2).sum(1, keepdim=True)
8 | yy = (y**2).sum(1, keepdim=True)
9 |
10 | outer_xx = torch.mm(x, x.t())
11 | outer_yy = torch.mm(y, y.t())
12 | outer_xy = torch.mm(x, y.t())
13 |
14 | diff_xx = xx + xx.t() - 2 * outer_xx
15 | diff_yy = yy + yy.t() - 2 * outer_yy
16 | diff_xy = xx + yy.t() - 2 * outer_xy
17 |
18 | C = 2. * dim
19 | k_xx = C / (C + diff_xx)
20 | k_yy = C / (C + diff_yy)
21 | k_xy = C / (C + diff_xy)
22 |
23 | mean_xx = (k_xx.sum() - k_xx.diag().sum()) / (n * (n - 1))
24 | mean_yy = (k_yy.sum() - k_yy.diag().sum()) / (n * (n - 1))
25 | mean_xy = k_xy.sum() / (n * n)
26 |
27 | return mean_xx + mean_yy - 2 * mean_xy
28 |
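The vectorized `mmd` above can be cross-checked against a loop-based version. The sketch below is a pure-Python reference, assuming two equally sized lists of vectors (the vectorized code also reuses `n` from `x` for both sets). It uses the same kernel k(a, b) = C / (C + ||a - b||^2) with C = 2 * dim. The names `rational_kernel` and `mmd_reference` are illustrative only.

```python
def rational_kernel(a, b, c):
    """k(a, b) = c / (c + ||a - b||^2)."""
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return c / (c + sq_dist)


def mmd_reference(x, y):
    """Loop-based MMD estimate for two equally sized lists of vectors."""
    n, dim = len(x), len(x[0])
    assert len(y) == n, 'this sketch assumes equally sized sample sets'
    c = 2.0 * dim

    def off_diagonal_mean(s):
        # Within-set terms skip the diagonal (i == j), as in the code above.
        total = sum(rational_kernel(s[i], s[j], c)
                    for i in range(n) for j in range(n) if i != j)
        return total / (n * (n - 1))

    # The cross term averages over all n * n pairs.
    mean_xy = sum(rational_kernel(a, b, c) for a in x for b in y) / (n * n)
    return off_diagonal_mean(x) + off_diagonal_mean(y) - 2 * mean_xy


x = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
y_near = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
y_far = [[10.0, 10.0], [11.0, 10.0], [10.0, 11.0]]

near = mmd_reference(x, y_near)
far = mmd_reference(x, y_far)
```

Because the within-set means exclude the diagonal while the cross term does not, the estimate can be slightly negative when the two sets coincide; only the ordering between `near` and `far` is meaningful here.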
--------------------------------------------------------------------------------
/time-series/sn_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils import spectral_norm
4 | import torch.nn.functional as F
5 |
6 |
7 | class InvertibleLinearResNetBlock(nn.Module):
8 | def __init__(self, in_size, hid_size):
9 | super().__init__()
10 | self.res = nn.Sequential(
11 | nn.ELU(),
12 | spectral_norm(nn.Linear(in_size, hid_size)),
13 | nn.ELU(),
14 | spectral_norm(nn.Linear(hid_size, in_size)),
15 | )
16 |
17 | def forward(self, x):
18 | return x + self.res(x)
19 |
20 |
21 | class InvertibleLinearResNet(nn.Module):
22 | def __init__(self, in_size, hid_size, layers=1):
23 | super().__init__()
24 | self.res = nn.Sequential(
25 | *[InvertibleLinearResNetBlock(in_size, hid_size)
26 | for _ in range(layers)])
27 |
28 | def forward(self, x):
29 | return self.res(x)
30 |
31 |
32 | class ResLinearBlock(nn.Module):
33 | def __init__(self, in_size, out_size):
34 | super().__init__()
35 | self.linear = nn.Sequential(
36 | spectral_norm(nn.Linear(in_size, out_size)),
37 | nn.Dropout(),
38 | nn.LeakyReLU(.2),
39 | spectral_norm(nn.Linear(out_size, out_size)),
40 | nn.Dropout(),
41 | nn.LeakyReLU(.2),
42 | )
43 |
44 | self.skip = nn.Sequential(
45 | spectral_norm(nn.Linear(in_size, out_size)),
46 | nn.LeakyReLU(.2),
47 | )
48 |
49 | def forward(self, x):
50 | return self.linear(x) + self.skip(x)
51 |
52 |
53 | class Classifier(nn.Module):
54 | def __init__(self, in_size, layers=1):
55 | super().__init__()
56 | blocks = []
57 | for _ in range(layers):
58 | blocks.append(ResLinearBlock(in_size, in_size))
59 | # No spectral normalization for the last layer
60 | blocks.append(nn.Linear(in_size, 1))
61 | self.res_linear = nn.Sequential(*blocks)
62 |
63 | def forward(self, x):
64 | return self.res_linear(x)
65 |
66 |
67 | class GBlock(nn.Module):
68 | def __init__(self, in_channels, out_channels):
69 | super().__init__()
70 | self.activation = nn.ReLU(inplace=False)
71 | self.bn1 = nn.BatchNorm1d(in_channels)
72 | self.bn2 = nn.BatchNorm1d(out_channels)
73 | self.conv1 = spectral_norm(nn.Conv1d(
74 | in_channels, out_channels, kernel_size=3, padding=1))
75 | self.conv2 = spectral_norm(nn.Conv1d(
76 | out_channels, out_channels, kernel_size=3, padding=1))
77 | self.convx = None
78 | if in_channels != out_channels:
79 | self.convx = spectral_norm(nn.Conv1d(
80 | in_channels, out_channels, kernel_size=1))
81 |
82 | def forward(self, x):
83 | h = self.activation(self.bn1(x))
84 | h = F.interpolate(h, scale_factor=2)
85 | x = F.interpolate(x, scale_factor=2)
86 | if self.convx:
87 | x = self.convx(x)
88 | h = self.conv1(h)
89 | h = self.activation(self.bn2(h))
90 | h = self.conv2(h)
91 | return h + x
92 |
93 |
94 | class GridDecoder(nn.Module):
95 | def __init__(self, dim_z, channels, start_len=16, squash=None):
96 | super().__init__()
97 | self.activation = nn.ReLU(inplace=False)
98 | self.start_len = start_len
99 | self.linear = spectral_norm(nn.Linear(dim_z, channels[0] * start_len))
100 | self.blocks = nn.Sequential(
101 | *[GBlock(in_channels, channels[c + 1])
102 | for c, in_channels in enumerate(channels[:-2])])
103 | self.output = nn.Sequential(
104 | nn.BatchNorm1d(channels[-2]),
105 | self.activation,
106 | spectral_norm(nn.Conv1d(
107 | channels[-2], channels[-1], kernel_size=3, padding=1)),
108 | )
109 | self.squash = squash
110 |
111 | def forward(self, z):
112 | h = self.linear(z)
113 | h = h.view(h.shape[0], -1, self.start_len)
114 | h = self.blocks(h)
115 | h = self.output(h)
116 | if self.squash:
117 | h = self.squash(h)
118 | return h
119 |
120 |
121 | class DBlock(nn.Module):
122 | def __init__(self, in_channels, out_channels, downsample=True):
123 | super().__init__()
124 | self.activation = nn.ReLU(inplace=False)
125 | self.conv1 = spectral_norm(nn.Conv1d(
126 | in_channels, out_channels, kernel_size=3, padding=1))
127 | self.conv2 = spectral_norm(nn.Conv1d(
128 | out_channels, out_channels, kernel_size=3, padding=1))
129 | self.convx = None
130 | if in_channels != out_channels:
131 | self.convx = spectral_norm(nn.Conv1d(
132 | in_channels, out_channels, kernel_size=1))
133 | self.downsample = None
134 | if downsample:
135 | self.downsample = nn.AvgPool1d(2)
136 |
137 | def shortcut(self, x):
138 | if self.convx:
139 | x = self.convx(x)
140 | if self.downsample:
141 | x = self.downsample(x)
142 | return x
143 |
144 | def forward(self, x):
145 | # pre-activation
146 | h = self.activation(x)
147 | h = self.conv1(h)
148 | h = self.conv2(self.activation(h))
149 | if self.downsample:
150 | h = self.downsample(h)
151 | return h + self.shortcut(x)
152 |
153 |
154 | class GridEncoder(nn.Module):
155 | def __init__(self, channels, out_dim=1):
156 | super().__init__()
157 | self.activation = nn.ReLU(inplace=False)
158 | self.blocks = nn.Sequential(
159 | *[DBlock(in_channels, out_channels)
160 | for in_channels, out_channels
161 | in zip(channels[:-1], channels[1:])])
162 | self.linear = spectral_norm(nn.Linear(channels[-1], out_dim))
163 |
164 | def forward(self, x):
165 | h = x
166 | h = self.blocks(h)
167 | h = self.activation(h).sum(2)
168 | return self.linear(h)
169 |
170 |
171 | class Decoder(nn.Module):
172 | def __init__(self, grid_decoder, max_time=5, kernel_bw=None, dec_ref=128):
173 | super().__init__()
174 | if kernel_bw is None:
175 | self.kernel_bw = max_time / dec_ref * 3
176 | else:
177 | self.kernel_bw = kernel_bw
178 | # ref_times holds the time stamps assigned to the evenly spaced
179 | # sequence generated by the conv1d layers.
180 | self.register_buffer('ref_times', torch.linspace(0, max_time, dec_ref))
181 | self.ref_times = self.ref_times[:, None]
182 | self.grid_decoder = grid_decoder
183 |
184 | def forward(self, code, time, mask):
185 | """
186 | Args:
187 | code: shape (batch_size, latent_size)
188 | time: shape (batch_size, channels, max_seq_len)
189 | mask: shape (batch_size, channels, max_seq_len)
190 |
191 | Returns:
192 | interpolated tensor of shape (batch_size, channels, max_seq_len)
193 | """
194 | # shape of x: (batch_size, n_channels, dec_ref)
195 | x = self.grid_decoder(code)
196 |
197 | # t_diff shape: (batch_size, n_channels, dec_ref, max_seq_len)
198 | t_diff = time[:, :, None] - self.ref_times
199 |
200 | # Epanechnikov quadratic kernel:
201 | # K_\lambda(x_0, x) = relu(3/4 * (1 - (|x_0 - x| / \lambda)^2))
202 | # shape of w: (batch_size, n_channels, dec_ref, max_seq_len)
203 | w = F.relu((1 - (t_diff / self.kernel_bw)**2) * .75)
204 | # Alternative that avoids division by zero:
205 | # normalizer = torch.clamp(w.sum(2), min=1e-6)
206 | # return ((x[:, :, :, None] * w).sum(2) * mask) / normalizer
207 | ks_x = ((x[:, :, :, None] * w).sum(2) * mask) / w.sum(2)
208 | return ks_x
209 |
210 |
211 | def gan_loss(real, fake, real_target, fake_target):
212 | real_score = sum(F.binary_cross_entropy_with_logits(
213 | r, r.new_tensor(real_target).expand_as(r)) for r in real)
214 | fake_score = sum(F.binary_cross_entropy_with_logits(
215 | f, f.new_tensor(fake_target).expand_as(f)) for f in fake)
216 | return real_score + fake_score
217 |
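`Decoder.forward` above turns the evenly spaced output of the grid decoder into values at arbitrary time stamps by kernel smoothing with an Epanechnikov kernel. The sketch below shows the same idea for one channel in pure Python. It assumes the default settings (`max_time=5`, `dec_ref=128`, bandwidth `max_time / dec_ref * 3`) and, unlike the code above, returns 0 when no reference point falls inside the kernel support. All function names are illustrative.

```python
import math


def linspace(start, stop, num):
    """Evenly spaced points including both endpoints."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def epanechnikov(u):
    """Epanechnikov kernel: 3/4 * (1 - u^2) on |u| <= 1, else 0."""
    return max(0.0, 0.75 * (1.0 - u * u))


def kernel_smooth(grid_values, ref_times, query_times, bandwidth):
    """Weighted average of grid values around each query time."""
    out = []
    for t in query_times:
        weights = [epanechnikov((t - r) / bandwidth) for r in ref_times]
        total = sum(weights)
        if total == 0.0:
            # Assumption of this sketch: no support means output 0.
            out.append(0.0)
            continue
        out.append(sum(w * v for w, v in zip(weights, grid_values)) / total)
    return out


max_time, dec_ref = 5.0, 128
ref_times = linspace(0.0, max_time, dec_ref)
bandwidth = max_time / dec_ref * 3

# A smooth signal on the grid, queried at irregular time stamps.
grid_values = [math.sin(r) for r in ref_times]
query_times = [0.13, 1.7, 2.05, 4.9]
smoothed = kernel_smooth(grid_values, ref_times, query_times, bandwidth)
```

Since every output is a convex combination of nearby grid values, the smoothed values stay within the range of the grid and track the underlying signal closely when the bandwidth spans only a few reference points.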
--------------------------------------------------------------------------------
/time-series/spline_cconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.data.dataloader import default_collate
4 | from torch_spline_conv import SplineBasis, SplineWeighting
5 | import math
6 |
7 |
8 | def kernel_width(max_time, ref_size, overlap_rate):
9 | return max_time / (ref_size + overlap_rate - overlap_rate * ref_size)
10 |
11 |
12 | class ContinuousConv1D(nn.Module):
13 | def __init__(self,
14 | in_channels,
15 | out_channels=64,
16 | max_time=5,
17 | ref_size=98,
18 | overlap_rate=.5,
19 | kernel_size=5,
20 | norm=False,
21 | bias=True,
22 | spline_degree=1):
23 | super().__init__()
24 |
25 | self.in_channels = in_channels
26 | self.out_channels = out_channels
27 | self.ref_size = ref_size
28 | self.overlap_rate = overlap_rate
29 | self.kernel_width = kernel_width(max_time, ref_size, overlap_rate)
30 | self.spline_degree = spline_degree
31 | self.norm = norm
32 |
33 | margin = self.kernel_width / 2
34 | refs = torch.linspace(margin, max_time - margin, ref_size)
35 | self.register_buffer('refs', refs)
36 |
37 | kernel_size = torch.tensor([kernel_size], dtype=torch.long)
38 | self.register_buffer('kernel_size', kernel_size)
39 |
40 | is_open_spline = torch.tensor([1], dtype=torch.uint8)
41 | self.register_buffer('is_open_spline', is_open_spline)
42 |
43 | self.weight = nn.Parameter(
44 | torch.Tensor(kernel_size, in_channels, out_channels))
45 | if bias:
46 | self.bias = nn.Parameter(torch.Tensor(out_channels))
47 | else:
48 | self.register_parameter('bias', None)
49 |
50 | self.reset_parameters()
51 |
52 | def reset_parameters(self):
53 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
54 | if self.bias is not None:
55 | bound = 1 / math.sqrt(self.in_channels)
56 | nn.init.uniform_(self.bias, -bound, bound)
57 |
58 | def forward(self, pseudo, ref_idx, y, ref_deg, batch_size):
59 | conv_out = y[0].new_zeros(
60 | self.in_channels, self.ref_size * batch_size, self.out_channels)
61 | for c in range(self.in_channels):
62 | data = SplineBasis.apply(pseudo[c], self.kernel_size,
63 | self.is_open_spline, self.spline_degree)
64 | out = SplineWeighting.apply(
65 | y[c], self.weight[:, c].unsqueeze(1), *data)
66 | idx = ref_idx[c].expand_as(out)
67 | conv_out[c].scatter_add_(0, idx, out)
68 | if self.norm:
69 | conv_out[c].div_(ref_deg[c])
70 | conv_out = conv_out.sum(0)
71 | conv_out = conv_out.view(batch_size, self.ref_size, self.out_channels)
72 | conv_out = conv_out.transpose(1, 2)
73 | if self.bias is not None:
74 | conv_out = conv_out + self.bias[:, None]
75 | return conv_out
76 |
77 |
78 | def gen_collate_fn(channels, max_time=5, ref_size=98, overlap_rate=.5,
79 | device=None):
80 | k_width = kernel_width(max_time, ref_size, overlap_rate)
81 | margin = k_width / 2
82 | refs_ = torch.linspace(margin, max_time - margin, ref_size)
83 |
84 | def collate_fn(batch):
85 | y0 = batch[0][0]
86 | refs = refs_.to(y0.device)
87 |
88 | pseudo = [[] for _ in range(channels)]
89 | cum_ref_idx = [[] for _ in range(channels)]
90 | concat_y = [[] for _ in range(channels)]
91 | deg = [[] for _ in range(channels)]
92 |
93 | for i, ts_info in enumerate(batch):
94 | y, t, m = ts_info[:3]
95 | for c in range(channels):
96 | tc = t[c][m[c] == 1]
97 | yc = y[c][m[c] == 1]
98 | dis = (tc - refs[:, None]) / k_width + .5
99 | mask = (dis <= 1) * (dis >= 0)
100 | ref_idx, t_idx = torch.nonzero(mask).t()
101 | # Pseudo coordinates in [0, 1]
102 | pseudo[c].append(dis[mask])
103 | # Indices accumulated across mini-batch. Used for adding
104 | # convolution results to linearized padded tensor.
105 | cum_ref_idx[c].append(ref_idx + i * ref_size)
106 | concat_y[c].append(yc[t_idx])
107 | deg[c].append(y0.new_zeros(ref_size).scatter_add_(
108 | 0, ref_idx, y0.new_ones(ref_idx.shape)))
109 |
110 | for c in range(channels):
111 | pseudo[c] = torch.cat(pseudo[c]).unsqueeze(1).to(device)
112 | cum_ref_idx[c] = torch.cat(cum_ref_idx[c]).unsqueeze(1).to(device)
113 | concat_y[c] = torch.cat(concat_y[c]).unsqueeze(1).to(device)
114 | # clamp(min=1) to avoid dividing by zero
115 | deg[c] = torch.cat(deg[c]).clamp(min=1).unsqueeze(1).to(device)
116 |
117 | converted_batch = [x.to(device) for x in default_collate(batch)]
118 | return converted_batch + [(pseudo, cum_ref_idx, concat_y, deg)]
119 |
120 | return collate_fn
121 |
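The `kernel_width` formula and the pseudo-coordinate computation above determine how irregular time stamps are mapped onto the reference locations of the continuous convolution. The sketch below works through that geometry in pure Python, assuming the defaults `max_time=5`, `ref_size=98`, and `overlap_rate=.5`. `reference_points` and `pseudo_coords` are illustrative helpers, not functions from the repository.

```python
def kernel_width(max_time, ref_size, overlap_rate):
    """Width w such that ref_size kernels with the given overlap span max_time."""
    return max_time / (ref_size + overlap_rate - overlap_rate * ref_size)


def reference_points(max_time, ref_size, overlap_rate):
    """Kernel centers: evenly spaced from w/2 to max_time - w/2."""
    width = kernel_width(max_time, ref_size, overlap_rate)
    margin = width / 2
    step = (max_time - 2 * margin) / (ref_size - 1)
    return [margin + i * step for i in range(ref_size)], width


def pseudo_coords(t, refs, width):
    """(ref_index, coordinate in [0, 1]) for every kernel that covers t."""
    pairs = []
    for i, r in enumerate(refs):
        dis = (t - r) / width + 0.5
        if 0.0 <= dis <= 1.0:
            pairs.append((i, dis))
    return pairs


max_time, ref_size, overlap_rate = 5.0, 98, 0.5
refs, width = reference_points(max_time, ref_size, overlap_rate)
spacing = refs[1] - refs[0]

# A time stamp that sits exactly on the 11th kernel center.
covering = pseudo_coords(refs[10], refs, width)
```

Adjacent centers are `width * (1 - overlap_rate)` apart, so with an overlap of 0.5 a typical time stamp is covered by two kernels, and a stamp at a kernel center has pseudo-coordinate 0.5 for that kernel.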
--------------------------------------------------------------------------------
/time-series/time_series.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torch.utils.data.dataloader import default_collate
4 | import numpy as np
5 | from sklearn.model_selection import train_test_split
6 |
7 |
8 | def kernel_width(max_time, cconv_ref, overlap_rate):
9 | return max_time / (cconv_ref + overlap_rate - overlap_rate * cconv_ref)
10 |
11 |
12 | class TimeSeries(Dataset):
13 | def __init__(self, data, time, mask, label=None,
14 | max_time=5, cconv_ref=98, overlap_rate=.5, device=None):
15 | self.data = torch.tensor(data, dtype=torch.float)
16 | self.time = torch.tensor(time, dtype=torch.float)
17 | self.mask = torch.tensor(mask, dtype=torch.float)
18 |
19 | if label is None:
20 | TimeSeries.__getitem__ = lambda self, index: (
21 | self.data[index], self.time[index], self.mask[index], index)
22 | else:
23 | self.label = torch.tensor(label, dtype=torch.float)
24 | TimeSeries.__getitem__ = lambda self, index: (
25 | self.data[index], self.time[index], self.mask[index],
26 | self.label[index], index)
27 |
28 | self.data_len, self.channels = self.data.shape[:2]
29 | self.cconv_ref = cconv_ref
30 | self.device = device
31 | k_width = kernel_width(max_time, cconv_ref, overlap_rate)
32 | margin = k_width / 2
33 | refs = torch.linspace(margin, max_time - margin, cconv_ref)
34 |
35 | self.pseudo, self.deg, self.ref_idx, self.t_idx = [
36 | [[None] * self.channels for _ in range(self.data_len)]
37 | for _ in range(4)]
38 |
39 | for i, (y, t, m) in enumerate(zip(self.data, self.time, self.mask)):
40 | for c in range(self.channels):
41 | tc = t[c][m[c] == 1]
42 | dis = (tc - refs[:, None]) / k_width + .5
43 | dmask = (dis <= 1) * (dis >= 0)
44 | self.ref_idx[i][c], self.t_idx[i][c] = torch.nonzero(dmask).t()
45 | # Pseudo coordinates in [0, 1]
46 | self.pseudo[i][c] = dis[dmask]
47 | cur_deg = torch.zeros(self.cconv_ref)
48 | cur_deg.scatter_add_(0, self.ref_idx[i][c],
49 | torch.ones(self.ref_idx[i][c].shape))
50 | self.deg[i][c] = cur_deg.clamp(min=1)
51 |
52 | def __len__(self):
53 | return self.data_len
54 |
55 | def make_graph(self, data, time, mask, index):
56 | pseudo = [
57 | torch.cat([self.pseudo[idx][c] for idx in index])
58 | .to(self.device).unsqueeze_(1).requires_grad_(False)
59 | for c in range(self.channels)]
60 |
61 | # Indices accumulated across mini-batch. Used for adding
62 | # convolution results to linearized padded tensor.
63 | cum_ref_idx = [
64 | torch.cat([self.ref_idx[idx][c] + i * self.cconv_ref
65 | for i, idx in enumerate(index)])
66 | .to(self.device).unsqueeze_(1).requires_grad_(False)
67 | for c in range(self.channels)]
68 |
69 | concat_y = [
70 | torch.cat(
71 | [y[c][(m[c] == 1).requires_grad_(False)][self.t_idx[idx][c]]
72 | for y, m, idx in zip(data, mask, index)])
73 | .to(self.device).unsqueeze_(1)
74 | for c in range(self.channels)]
75 |
76 | deg = [
77 | torch.cat([self.deg[idx][c] for idx in index])
78 | .to(self.device).unsqueeze_(1).requires_grad_(False)
79 | for c in range(self.channels)]
80 |
81 | return pseudo, cum_ref_idx, concat_y, deg
82 |
83 | def collate_fn(self, batch):
84 | batch = [x.to(self.device) for x in default_collate(batch)]
85 | # For labeled data, skip the label as the 4th entry.
86 | (data, time, mask), index = batch[:3], batch[-1]
87 | graph = self.make_graph(data, time, mask, index)
88 | return batch + [graph]
89 |
90 |
91 | def split_data(data_file, rnd, max_time, cconv_ref, overlap, device,
92 | rescale=False):
93 | raw_data = np.load(data_file)
94 |
95 | if len(raw_data) == 4:
96 | time_np = raw_data['time']
97 | data_np = raw_data['data']
98 | mask_np = raw_data['mask']
99 | label_np = raw_data['label'].squeeze()
100 |
101 | (tv_time, test_time, tv_data, test_data,
102 | tv_mask, test_mask, tv_label, test_label) = train_test_split(
103 | time_np, data_np, mask_np, label_np,
104 | train_size=.8, stratify=label_np, random_state=rnd)
105 |
106 | (train_time, val_time, train_data, val_data,
107 | train_mask, val_mask, train_label, val_label) = train_test_split(
108 | tv_time, tv_data, tv_mask, tv_label,
109 | train_size=.8, stratify=tv_label, random_state=rnd)
110 |
111 | elif len(raw_data) == 8:
112 | tv_time = raw_data['train_time']
113 | tv_data = raw_data['train_data']
114 | tv_mask = raw_data['train_mask']
115 | tv_label = raw_data['train_label']
116 |
117 | test_time = raw_data['test_time']
118 | test_data = raw_data['test_data']
119 | test_mask = raw_data['test_mask']
120 | test_label = raw_data['test_label']
121 |
122 | (train_time, val_time, train_data, val_data,
123 | train_mask, val_mask, train_label, val_label) = train_test_split(
124 | tv_time, tv_data, tv_mask, tv_label,
125 | train_size=.8, stratify=tv_label, random_state=rnd)
126 |
127 | elif len(raw_data) == 12:
128 | train_time = raw_data['train_time']
129 | train_data = raw_data['train_data']
130 | train_mask = raw_data['train_mask']
131 | train_label = raw_data['train_label']
132 |
133 | test_time = raw_data['test_time']
134 | test_data = raw_data['test_data']
135 | test_mask = raw_data['test_mask']
136 | test_label = raw_data['test_label']
137 |
138 | val_time = raw_data['val_time']
139 | val_data = raw_data['val_data']
140 | val_mask = raw_data['val_mask']
141 | val_label = raw_data['val_label']
142 | else:
143 | raise ValueError('Invalid data: expected 4, 8 or 12 arrays')
144 |
145 | # Scale time
146 | train_time *= max_time
147 | test_time *= max_time
148 | val_time *= max_time
149 |
150 | # Rescale data from [0, 1] to [-1, 1]
151 | if rescale:
152 | train_data = 2 * train_data - 1
153 | val_data = 2 * val_data - 1
154 | test_data = 2 * test_data - 1
155 |
156 | train_dataset = TimeSeries(
157 | train_data, train_time, train_mask, train_label, max_time=max_time,
158 | cconv_ref=cconv_ref, overlap_rate=overlap, device=device)
159 |
160 | val_dataset = TimeSeries(
161 | val_data, val_time, val_mask, val_label, max_time=max_time,
162 | cconv_ref=cconv_ref, overlap_rate=overlap, device=device)
163 |
164 | test_dataset = TimeSeries(
165 | test_data, test_time, test_mask, test_label, max_time=max_time,
166 | cconv_ref=cconv_ref, overlap_rate=overlap, device=device)
167 |
168 | return train_dataset, val_dataset, test_dataset
169 |
--------------------------------------------------------------------------------
/time-series/toy_layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.utils import spectral_norm
3 |
4 |
5 | def dconv_bn_relu(in_dim, out_dim):
6 | return nn.Sequential(
7 | nn.ConvTranspose1d(in_dim, out_dim, 5, 2,
8 | padding=2, output_padding=1, bias=False),
9 | nn.BatchNorm1d(out_dim),
10 | nn.ReLU())
11 |
12 |
13 | class SeqGeneratorDiscrete(nn.Module):
14 | def __init__(self, n_channels=3, latent_size=128, squash=None):
15 | super().__init__()
16 |
17 | self.l1 = nn.Sequential(
18 | nn.Linear(latent_size, 2048, bias=False),
19 | nn.BatchNorm1d(2048),
20 | nn.ReLU())
21 |
22 | self.l2 = nn.Sequential(
23 | dconv_bn_relu(256, 128),
24 | dconv_bn_relu(128, 64),
25 | dconv_bn_relu(64, 32),
26 | nn.ConvTranspose1d(32, n_channels, 5, 2,
27 | padding=2, output_padding=1))
28 | self.squash = squash
29 |
30 | def forward(self, z):
31 | h = self.l1(z)
32 | h = h.view(h.shape[0], -1, 8)
33 | h = self.l2(h)
34 | if self.squash:
35 | h = self.squash(h)
36 | return h
37 |
38 |
39 | def conv_ln_lrelu(in_dim, out_dim):  # despite the name, uses spectral norm
40 | return nn.Sequential(
41 | spectral_norm(nn.Conv1d(in_dim, out_dim, 5, 2, 2)),
42 | nn.LeakyReLU(0.2))
43 |
--------------------------------------------------------------------------------
/time-series/toy_pbigan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import spectral_norm
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 | from datetime import datetime
9 | import time
10 | from pathlib import Path
11 | import argparse
12 | from collections import defaultdict
13 | from spline_cconv import ContinuousConv1D
14 | from time_series import TimeSeries
15 | from ema import EMA
16 | from mmd import mmd
17 | from tracker import Tracker
18 | from vis import Visualizer
19 | from gen_toy_data import gen_data
20 | from utils import Rescaler, mkdir, make_scheduler
21 | from layers import Decoder, gan_loss
22 | from toy_layers import SeqGeneratorDiscrete, conv_ln_lrelu
23 |
24 |
25 | use_cuda = torch.cuda.is_available()
26 | device = torch.device('cuda' if use_cuda else 'cpu')
27 |
28 |
29 | class Encoder(nn.Module):
30 | def __init__(self, cconv, latent_size, norm_trans=True):
31 | super().__init__()
32 | self.cconv = cconv
33 | self.ls = nn.Sequential(
34 | nn.LeakyReLU(0.2),
35 | conv_ln_lrelu(64, 128),
36 | conv_ln_lrelu(128, 256),
37 | # conv_ln_lrelu(256, 512),
38 | # conv_ln_lrelu(512, 64),
39 | conv_ln_lrelu(256, 32),
40 | )
41 | conv_size = 416
42 | self.fc = nn.Sequential(
43 | spectral_norm(nn.Linear(conv_size, latent_size * 2)),
44 | nn.LeakyReLU(0.2),
45 | spectral_norm(nn.Linear(latent_size * 2, latent_size * 2)),
46 | )
47 | self.norm_trans = norm_trans
48 | if norm_trans:
49 | self.fc2 = nn.Sequential(
50 | spectral_norm(nn.Linear(latent_size, latent_size)),
51 | nn.LeakyReLU(0.2),
52 | spectral_norm(nn.Linear(latent_size, latent_size)),
53 | )
54 |
55 | def forward(self, cconv_graph, batch_size):
56 | x = self.cconv(*cconv_graph, batch_size)
57 | # flattened shape: (batch_size, conv_size) = (batch_size, 32 * 13)
58 | x = self.ls(x).view(x.shape[0], -1)
59 | mu, logvar = self.fc(x).chunk(2, dim=1)
60 | if self.training:
61 | std = F.softplus(logvar)  # "logvar" is a raw scale here, not a log-variance
62 | eps = torch.empty_like(std).normal_()
63 | # return mu + eps * std, mu, logvar, eps
64 | z = mu + eps * std
65 | else:
66 | z = mu
67 | if self.norm_trans:
68 | z = self.fc2(z)
69 | return z
70 |
71 |
72 | class GridCritic(nn.Module):
73 | def __init__(self):
74 | super().__init__()
75 | self.ls = nn.Sequential(
76 | nn.LeakyReLU(0.2),
77 | conv_ln_lrelu(64, 128),
78 | conv_ln_lrelu(128, 256),
79 | # conv_ln_lrelu(256, 512),
80 | # conv_ln_lrelu(512, 1)
81 | conv_ln_lrelu(256, 1),
82 | )
83 |
84 | def forward(self, x):
85 | # output shape: (batch_size, 1, 13) when cconv_ref is 98
86 | return self.ls(x)
87 |
88 |
89 | class ConvCritic(nn.Module):
90 | def __init__(self, cconv, latent_size, embed_size=13):
91 | super().__init__()
92 | self.cconv = cconv
93 | self.grid_critic = GridCritic()
94 | # self.x_dis = spectral_norm(nn.Linear(7, embed_size))
95 |
96 | self.z_dis = nn.Sequential(
97 | spectral_norm(nn.Linear(latent_size, embed_size)),
98 | nn.LeakyReLU(0.2),
99 | spectral_norm(nn.Linear(embed_size, embed_size)),
100 | )
101 |
102 | self.x_linear = spectral_norm(nn.Linear(embed_size, 1))
103 |
104 | self.xz_dis = nn.Sequential(
105 | spectral_norm(nn.Linear(embed_size * 2, embed_size)),
106 | nn.LeakyReLU(0.2),
107 | spectral_norm(nn.Linear(embed_size, 1)),
108 | )
109 |
110 | def forward(self, cconv_graph, batch_size, z):
111 | x = self.cconv(*cconv_graph, batch_size)
112 | x = self.grid_critic(x)
113 | x = x.squeeze(1)
114 | # x = self.x_dis(x)
115 | z = self.z_dis(z)
116 | xz = torch.cat((x, z), 1)
117 | xz = self.xz_dis(xz)
118 | x_out = self.x_linear(x).view(-1)
119 | xz_out = xz.view(-1)
120 | return xz_out, x_out
121 |
122 |
123 | class PBiGAN(nn.Module):
124 | def __init__(self, encoder, decoder, ae_loss='mse'):
125 | super().__init__()
126 | self.encoder = encoder
127 | self.decoder = decoder
128 | self.ae_loss = ae_loss
129 |
130 | def forward(self, data, time, mask, cconv_graph, time_t, mask_t):
131 | batch_size = len(data)
132 | z_T = self.encoder(cconv_graph, batch_size)
133 |
134 | z_gen = torch.empty_like(z_T).normal_()
135 | x_gen = self.decoder(z_gen, time_t, mask_t)
136 |
137 | x_recon = self.decoder(z_T, time, mask)
138 |
139 | if self.ae_loss == 'mse':
140 | ae_loss = F.mse_loss(x_recon, data, reduction='none') * mask
141 | elif self.ae_loss == 'smooth_l1':
142 | ae_loss = F.smooth_l1_loss(x_recon, data, reduction='none') * mask
143 |
144 | ae_loss = ae_loss.sum((-1, -2))
145 |
146 | return z_T, x_recon, z_gen, x_gen, ae_loss.mean()
147 |
148 |
149 | def main():
150 | parser = argparse.ArgumentParser()
151 |
152 | default_dataset = 'toy-data.npz'
153 | parser.add_argument('--data', default=default_dataset,
154 | help='data file')
155 | parser.add_argument('--seed', type=int, default=None,
156 | help='random seed. Randomly set if not specified.')
157 |
158 | # training options
159 | parser.add_argument('--nz', type=int, default=32,
160 | help='dimension of latent variable')
161 | parser.add_argument('--epoch', type=int, default=1000,
162 | help='number of training epochs')
163 | parser.add_argument('--batch-size', type=int, default=128,
164 | help='batch size')
165 | parser.add_argument('--lr', type=float, default=8e-5,
166 | help='encoder/decoder learning rate')
167 | parser.add_argument('--dis-lr', type=float, default=1e-4,
168 | help='discriminator learning rate')
169 | parser.add_argument('--min-lr', type=float, default=5e-5,
170 | help='min encoder/decoder learning rate for LR '
171 | 'scheduler. -1 to disable annealing')
172 | parser.add_argument('--min-dis-lr', type=float, default=7e-5,
173 | help='min discriminator learning rate for LR '
174 | 'scheduler. -1 to disable annealing')
175 | parser.add_argument('--wd', type=float, default=0,
176 | help='weight decay')
177 | parser.add_argument('--overlap', type=float, default=.5,
178 | help='kernel overlap')
179 | parser.add_argument('--no-norm-trans', action='store_true',
180 | help='if set, use Gaussian posterior without '
181 | 'transformation')
182 | parser.add_argument('--plot-interval', type=int, default=1,
183 | help='plot interval. 0 to disable plotting.')
184 | parser.add_argument('--save-interval', type=int, default=0,
185 | help='interval to save models. 0 to disable saving.')
186 | parser.add_argument('--prefix', default='pbigan',
187 | help='prefix of output directory')
188 | parser.add_argument('--comp', type=int, default=7,
189 | help='continuous convolution kernel size')
190 | parser.add_argument('--ae', type=float, default=.2,
191 | help='autoencoding regularization strength')
192 | parser.add_argument('--aeloss', default='smooth_l1',
193 | choices=['mse', 'smooth_l1'], help='autoencoding loss')
194 | parser.add_argument('--ema', dest='ema', type=int, default=-1,
195 | help='start epoch of exponential moving average '
196 | '(EMA). -1 to disable EMA')
197 | parser.add_argument('--ema-decay', type=float, default=.9999,
198 | help='EMA decay')
199 | parser.add_argument('--mmd', type=float, default=1,
200 | help='MMD strength for latent variable')
201 |
202 | # squash is off when rescale is off
203 | parser.add_argument('--squash', dest='squash', action='store_const',
204 | const=True, default=True,
205 | help='bound the generated time series value '
206 | 'using tanh')
207 | parser.add_argument('--no-squash', dest='squash', action='store_const',
208 | const=False)
209 |
210 | # rescale to [-1, 1]
211 | parser.add_argument('--rescale', dest='rescale', action='store_const',
212 | const=True, default=True,
213 | help='rescale time series values to [-1, 1]')
214 | parser.add_argument('--no-rescale', dest='rescale', action='store_const',
215 | const=False)
216 |
217 | args = parser.parse_args()
218 |
219 | batch_size = args.batch_size
220 | nz = args.nz
221 |
222 | epochs = args.epoch
223 | plot_interval = args.plot_interval
224 | save_interval = args.save_interval
225 |
226 | try:
227 | npz = np.load(args.data)
228 | train_data = npz['data']
229 | train_time = npz['time']
230 | train_mask = npz['mask']
231 | except FileNotFoundError:
232 | if args.data != default_dataset:
233 | raise
234 | # Generate the default toy dataset from scratch
235 | train_data, train_time, train_mask, _, _ = gen_data(
236 | n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
237 | obs_span_rate=.25, save_file=default_dataset)
238 |
239 | _, in_channels, seq_len = train_data.shape
240 | train_time *= train_mask
241 |
242 | if args.seed is None:
243 | rnd = np.random.RandomState(None)
244 | random_seed = rnd.randint(np.iinfo(np.uint32).max)
245 | else:
246 | random_seed = args.seed
247 | rnd = np.random.RandomState(random_seed)
248 | np.random.seed(random_seed)
249 | torch.manual_seed(random_seed)
250 |
251 | # Scale time
252 | max_time = 5
253 | train_time *= max_time
254 |
255 | squash = None
256 | rescaler = None
257 | if args.rescale:
258 | rescaler = Rescaler(train_data)
259 | train_data = rescaler.rescale(train_data)
260 | if args.squash:
261 | squash = torch.tanh
262 |
263 | out_channels = 64
264 | cconv_ref = 98
265 |
266 | train_dataset = TimeSeries(
267 | train_data, train_time, train_mask, label=None, max_time=max_time,
268 | cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)
269 |
270 | train_loader = DataLoader(
271 | train_dataset, batch_size=batch_size, shuffle=True,
272 | drop_last=True, collate_fn=train_dataset.collate_fn)
273 | n_train_batch = len(train_loader)
274 |
275 | time_loader = DataLoader(
276 | train_dataset, batch_size=batch_size, shuffle=True,
277 | drop_last=True, collate_fn=train_dataset.collate_fn)
278 |
279 | test_loader = DataLoader(train_dataset, batch_size=batch_size,
280 | collate_fn=train_dataset.collate_fn)
281 |
282 | grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
283 | decoder = Decoder(grid_decoder, max_time=max_time).to(device)
284 |
285 | cconv = ContinuousConv1D(
286 | in_channels, out_channels, max_time, cconv_ref,
287 | overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device)
288 | encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device)
289 |
290 | pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)
291 |
292 | critic_cconv = ContinuousConv1D(
293 | in_channels, out_channels, max_time, cconv_ref,
294 | overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device)
295 | critic = ConvCritic(critic_cconv, nz).to(device)
296 |
297 | ema = None
298 | if args.ema >= 0:
299 | ema = EMA(pbigan, args.ema_decay, args.ema)
300 |
301 | optimizer = optim.Adam(
302 | pbigan.parameters(), lr=args.lr, weight_decay=args.wd)
303 | critic_optimizer = optim.Adam(
304 | critic.parameters(), lr=args.dis_lr, weight_decay=args.wd)
305 |
306 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
307 | dis_scheduler = make_scheduler(
308 | critic_optimizer, args.dis_lr, args.min_dis_lr, epochs)
309 |
310 | path = '{}_{}'.format(
311 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'))
312 |
313 | output_dir = Path('results') / 'toy-pbigan' / path
314 | print(output_dir)
315 | log_dir = mkdir(output_dir / 'log')
316 | model_dir = mkdir(output_dir / 'model')
317 |
318 | start_epoch = 0
319 |
320 | with (log_dir / 'seed.txt').open('w') as f:
321 | print(random_seed, file=f)
322 | with (log_dir / 'gpu.txt').open('a') as f:
323 | print(torch.cuda.device_count(), start_epoch, file=f)
324 | with (log_dir / 'args.txt').open('w') as f:
325 | for key, val in sorted(vars(args).items()):
326 | print(f'{key}: {val}', file=f)
327 |
328 | tracker = Tracker(log_dir, n_train_batch)
329 | visualizer = Visualizer(encoder, decoder, batch_size, max_time,
330 | test_loader, rescaler, output_dir, device)
331 | start = time.time()
332 | epoch_start = start
333 |
334 | for epoch in range(start_epoch, epochs):
335 | loss_breakdown = defaultdict(float)
336 |
337 | for ((val, idx, mask, _, cconv_graph),
338 | (_, idx_t, mask_t, index, _)) in zip(
339 | train_loader, time_loader):
340 |
341 | z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan(
342 | val, idx, mask, cconv_graph, idx_t, mask_t)
343 |
344 | cconv_graph_gen = train_dataset.make_graph(
345 | x_gen, idx_t, mask_t, index)
346 |
347 | real = critic(cconv_graph, batch_size, z_enc)
348 | fake = critic(cconv_graph_gen, batch_size, z_gen)
349 |
350 | D_loss = gan_loss(real, fake, 1, 0)
351 |
352 | critic_optimizer.zero_grad()
353 | D_loss.backward(retain_graph=True)
354 | critic_optimizer.step()
355 |
356 | G_loss = gan_loss(real, fake, 0, 1)
357 |
358 | mmd_loss = mmd(z_enc, z_gen)
359 |
360 | loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd
361 |
362 | optimizer.zero_grad()
363 | loss.backward()
364 | optimizer.step()
365 |
366 | if ema:
367 | ema.update()
368 |
369 | loss_breakdown['D'] += D_loss.item()
370 | loss_breakdown['G'] += G_loss.item()
371 | loss_breakdown['AE'] += ae_loss.item()
372 | loss_breakdown['MMD'] += mmd_loss.item()
373 | loss_breakdown['total'] += loss.item()
374 |
375 | if scheduler:
376 | scheduler.step()
377 | if dis_scheduler:
378 | dis_scheduler.step()
379 |
380 | cur_time = time.time()
381 | tracker.log(
382 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)
383 |
384 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
385 | if ema:
386 | ema.apply()
387 | visualizer.plot(epoch)
388 | ema.restore()
389 | else:
390 | visualizer.plot(epoch)
391 |
392 | model_dict = {
393 | 'pbigan': pbigan.state_dict(),
394 | 'critic': critic.state_dict(),
395 | 'ema': ema.state_dict() if ema else None,
396 | 'epoch': epoch + 1,
397 | 'args': args,
398 | }
399 | torch.save(model_dict, str(log_dir / 'model.pth'))
400 | if save_interval > 0 and (epoch + 1) % save_interval == 0:
401 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
402 |
403 | print(output_dir)
404 |
405 |
406 | if __name__ == '__main__':
407 | main()
408 |
--------------------------------------------------------------------------------
/time-series/toy_pvae.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import spectral_norm
5 | import torch.optim as optim
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 | from datetime import datetime
9 | import time
10 | from pathlib import Path
11 | import argparse
12 | from collections import defaultdict
13 | from spline_cconv import ContinuousConv1D
14 | from time_series import TimeSeries
15 | from tracker import Tracker
16 | from vis import Visualizer
17 | from gen_toy_data import gen_data
18 | from utils import Rescaler, mkdir, make_scheduler
19 | from layers import Decoder
20 | from toy_layers import SeqGeneratorDiscrete, conv_ln_lrelu
21 |
22 |
23 | use_cuda = torch.cuda.is_available()
24 | device = torch.device('cuda' if use_cuda else 'cpu')
25 |
26 |
27 | class Encoder(nn.Module):
28 | def __init__(self, latent_size, cconv):
29 | super().__init__()
30 | self.cconv = cconv
31 | self.ls = nn.Sequential(
32 | nn.LeakyReLU(0.2),
33 | conv_ln_lrelu(64, 128),
34 | conv_ln_lrelu(128, 256),
35 | conv_ln_lrelu(256, 512),
36 | conv_ln_lrelu(512, 64),
37 | )
38 | conv_size = 448
39 | self.fc = nn.Sequential(
40 | spectral_norm(nn.Linear(conv_size, latent_size * 2)),
41 | nn.LeakyReLU(0.2),
42 | spectral_norm(nn.Linear(latent_size * 2, latent_size * 2)),
43 | )
44 |
45 | def forward(self, cconv_graph, batch_size):
46 | x = self.cconv(*cconv_graph, batch_size)
47 | # expected shape: (batch_size, 448)
48 | x = self.ls(x).view(x.shape[0], -1)
49 | mu, logvar = self.fc(x).chunk(2, dim=1)
50 | std = torch.exp(logvar * .5)
51 | eps = torch.empty_like(std).normal_()
52 | return mu + eps * std, mu, logvar, eps
53 |
54 |
55 | class PVAE(nn.Module):
56 | def __init__(self, encoder, decoder, sigma=.2):
57 | super().__init__()
58 | self.encoder = encoder
59 | self.decoder = decoder
60 | self.sigma = sigma
61 |
62 | def forward(self, data, time, mask, cconv_graph):
63 | batch_size = len(data)
64 | z, mu, logvar, eps = self.encoder(cconv_graph, batch_size)
65 | x_recon = self.decoder(z, time, mask)
66 | # Gaussian noise
67 | recon_loss = (1 / (2 * self.sigma**2) * F.mse_loss(
68 | x_recon * mask, data * mask, reduction='none') * mask).sum((1, 2))
69 | kl_loss = .5 * (z**2 - logvar - eps**2).sum(1)  # one-sample KL estimate
70 | loss = recon_loss.mean() + kl_loss.mean()
71 | return loss
72 |
73 |
74 | def main():
75 | parser = argparse.ArgumentParser()
76 |
77 | default_dataset = 'toy-data.npz'
78 | parser.add_argument('--data', default=default_dataset,
79 | help='data file')
80 | parser.add_argument('--seed', type=int, default=None,
81 | help='random seed. Randomly set if not specified.')
82 |
83 | # training options
84 | parser.add_argument('--nz', type=int, default=32,
85 | help='dimension of latent variable')
86 | parser.add_argument('--epoch', type=int, default=1000,
87 | help='number of training epochs')
88 | parser.add_argument('--batch-size', type=int, default=128,
89 | help='batch size')
90 | parser.add_argument('--lr', type=float, default=1e-4,
91 | help='learning rate')
92 | parser.add_argument('--min-lr', type=float, default=5e-5,
93 | help='min learning rate for LR scheduler. '
94 | '-1 to disable annealing')
95 | parser.add_argument('--plot-interval', type=int, default=10,
96 | help='plot interval. 0 to disable plotting.')
97 | parser.add_argument('--save-interval', type=int, default=0,
98 | help='interval to save models. 0 to disable saving.')
99 | parser.add_argument('--prefix', default='pvae',
100 | help='prefix of output directory')
101 | parser.add_argument('--comp', type=int, default=5,
102 | help='continuous convolution kernel size')
103 | parser.add_argument('--sigma', type=float, default=.2,
104 | help='standard deviation for Gaussian likelihood')
105 | parser.add_argument('--overlap', type=float, default=.5,
106 | help='kernel overlap')
107 | # squash is off when rescale is off
108 | parser.add_argument('--squash', dest='squash', action='store_const',
109 | const=True, default=True,
110 | help='bound the generated time series value '
111 | 'using tanh')
112 | parser.add_argument('--no-squash', dest='squash', action='store_const',
113 | const=False)
114 |
115 | # rescale to [-1, 1]
116 | parser.add_argument('--rescale', dest='rescale', action='store_const',
117 | const=True, default=True,
118 | help='rescale time series values to [-1, 1]')
119 | parser.add_argument('--no-rescale', dest='rescale', action='store_const',
120 | const=False)
121 |
122 | args = parser.parse_args()
123 |
124 | batch_size = args.batch_size
125 | nz = args.nz
126 |
127 | epochs = args.epoch
128 | plot_interval = args.plot_interval
129 | save_interval = args.save_interval
130 |
131 | try:
132 | npz = np.load(args.data)
133 | train_data = npz['data']
134 | train_time = npz['time']
135 | train_mask = npz['mask']
136 | except FileNotFoundError:
137 | if args.data != default_dataset:
138 | raise
139 | # Generate the default toy dataset from scratch
140 | train_data, train_time, train_mask, _, _ = gen_data(
141 | n_samples=10000, seq_len=200, max_time=1, poisson_rate=50,
142 | obs_span_rate=.25, save_file=default_dataset)
143 |
144 | _, in_channels, seq_len = train_data.shape
145 | train_time *= train_mask
146 |
147 | if args.seed is None:
148 | rnd = np.random.RandomState(None)
149 | random_seed = rnd.randint(np.iinfo(np.uint32).max)
150 | else:
151 | random_seed = args.seed
152 | rnd = np.random.RandomState(random_seed)
153 | np.random.seed(random_seed)
154 | torch.manual_seed(random_seed)
155 |
156 | # Scale time
157 | max_time = 5
158 | train_time *= max_time
159 |
160 | squash = None
161 | rescaler = None
162 | if args.rescale:
163 | rescaler = Rescaler(train_data)
164 | train_data = rescaler.rescale(train_data)
165 | if args.squash:
166 | squash = torch.tanh
167 |
168 | out_channels = 64
169 | cconv_ref = 98
170 |
171 | train_dataset = TimeSeries(
172 | train_data, train_time, train_mask, label=None, max_time=max_time,
173 | cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device)
174 |
175 | train_loader = DataLoader(
176 | train_dataset, batch_size=batch_size, shuffle=True,
177 | drop_last=True, collate_fn=train_dataset.collate_fn)
178 | n_train_batch = len(train_loader)
179 |
180 | test_batch_size = 64
181 | test_loader = DataLoader(train_dataset, batch_size=test_batch_size,
182 | collate_fn=train_dataset.collate_fn)
183 |
184 | grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash)
185 | decoder = Decoder(grid_decoder, max_time=max_time).to(device)
186 |
187 | cconv = ContinuousConv1D(
188 | in_channels, out_channels, max_time, cconv_ref,
189 | overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device)
190 |
191 | encoder = Encoder(nz, cconv).to(device)
192 |
193 | pvae = PVAE(encoder, decoder, sigma=args.sigma).to(device)
194 |
195 | optimizer = optim.Adam(pvae.parameters(), lr=args.lr)
196 |
197 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs)
198 |
199 | path = '{}_{}_{}'.format(
200 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'),
201 | '_'.join([f'lr_{args.lr:g}']))
202 |
203 | output_dir = Path('results') / 'toy-pvae' / path
204 | print(output_dir)
205 | log_dir = mkdir(output_dir / 'log')
206 | model_dir = mkdir(output_dir / 'model')
207 |
208 | start_epoch = 0
209 |
210 | with (log_dir / 'seed.txt').open('w') as f:
211 | print(random_seed, file=f)
212 | with (log_dir / 'gpu.txt').open('a') as f:
213 | print(torch.cuda.device_count(), start_epoch, file=f)
214 | with (log_dir / 'args.txt').open('w') as f:
215 | for key, val in sorted(vars(args).items()):
216 | print(f'{key}: {val}', file=f)
217 |
218 | tracker = Tracker(log_dir, n_train_batch)
219 | visualizer = Visualizer(encoder, decoder, test_batch_size, max_time,
220 | test_loader, rescaler, output_dir, device)
221 | start = time.time()
222 | epoch_start = start
223 |
224 | for epoch in range(start_epoch, epochs):
225 | loss_breakdown = defaultdict(float)
226 | for val, idx, mask, _, cconv_graph in train_loader:
227 | optimizer.zero_grad()
228 | loss = pvae(val, idx, mask, cconv_graph)
229 | loss.backward()
230 | optimizer.step()
231 | loss_breakdown['loss'] += loss.item()
232 |
233 | if scheduler:
234 | scheduler.step()
235 |
236 | cur_time = time.time()
237 | tracker.log(
238 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start)
239 |
240 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
241 | visualizer.plot(epoch)
242 |
243 | model_dict = {
244 | 'pvae': pvae.state_dict(),
245 | 'epoch': epoch + 1,
246 | 'args': args,
247 | }
248 | torch.save(model_dict, str(log_dir / 'model.pth'))
249 | if save_interval > 0 and (epoch + 1) % save_interval == 0:
250 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))
251 |
252 | print(output_dir)
253 |
254 |
255 | if __name__ == '__main__':
256 | main()
257 |
--------------------------------------------------------------------------------
/time-series/tracker.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import defaultdict
3 | import logging
4 | import sys
5 |
6 |
7 | class Tracker:
8 | def __init__(self, log_dir, n_train_batch):
9 | self.log_dir = log_dir
10 | self.n_train_batch = n_train_batch
11 | self.loss = defaultdict(list)
12 |
13 | logging.basicConfig(
14 | level=logging.INFO,
15 | format='%(asctime)s %(message)s',
16 | datefmt='%Y-%m-%d %H:%M:%S',
17 | handlers=[
18 | logging.FileHandler(log_dir / 'log.txt'),
19 | logging.StreamHandler(sys.stdout),
20 | ],
21 | )
22 | self.print_header = True
23 |
24 | def log(self, epoch, loss_breakdown, epoch_time, time_elapsed):
25 | for loss_name, loss_val in loss_breakdown.items():
26 | self.loss[loss_name].append(loss_val / self.n_train_batch)
27 |
28 | if self.print_header:
29 | logging.info(' ' * 7 + ' '.join(
30 | f'{key:>12}' for key in sorted(self.loss)))
31 | self.print_header = False
32 | logging.info(f'[{epoch:4}] ' + ' '.join(
33 | f'{val[-1]:12.4f}' for _, val in sorted(self.loss.items())))
34 |
35 | torch.save(self.loss, str(self.log_dir / 'log.pth'))
36 |
37 | with (self.log_dir / 'time.txt').open('a') as f:
38 | print(epoch, epoch_time, time_elapsed, file=f)
39 |
--------------------------------------------------------------------------------
/time-series/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import numpy as np
4 |
5 |
6 | def to_numpy(v):
7 | if torch.is_tensor(v):
8 | return v.cpu().numpy()
9 | return v
10 |
11 |
12 | def count_parameters(model):
13 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
14 |
15 |
16 | class Rescaler:
17 |     def __init__(self, data):
18 |         channels = data.shape[1]
19 |         ch_min = np.array([data[:, i].min() for i in range(channels)])
20 |         ch_max = np.array([data[:, i].max() for i in range(channels)])
21 |         self.ch_min, self.ch_max = ch_min[:, None], ch_max[:, None]
22 |
23 |     def rescale(self, data):
24 |         return 2 * (data - self.ch_min) / (self.ch_max - self.ch_min) - 1
25 |
26 |     def unrescale(self, data):
27 |         return .5 * (data + 1) * (self.ch_max - self.ch_min) + self.ch_min
28 |
29 |
30 | def mkdir(path):
31 |     path.mkdir(parents=True, exist_ok=True)
32 |     return path
33 |
34 |
35 | def make_scheduler(optimizer, lr, min_lr, epochs, steps=10):
36 |     if min_lr < 0:
37 |         return None
38 |     step_size = max(1, epochs // steps)  # StepLR needs a period of at least 1
39 |     gamma = (min_lr / lr)**(1 / steps)
40 |     return optim.lr_scheduler.StepLR(
41 |         optimizer, step_size=step_size, gamma=gamma)
42 |
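
Note on utils.py: the sketch below works through the arithmetic behind `Rescaler` and `make_scheduler` on toy values, as a rough illustration rather than the repository's own code. It assumes data shaped `(batch, channels, time)`, assumes `epochs` is at least `steps`, and relies on the documented StepLR rule that the rate after `e` epochs is `lr * gamma ** (e // step_size)`. `step_lr_values` is a hypothetical helper, and the vectorized min/max is an equivalent rewrite of the per-channel loop.

```python
import math

import numpy as np

# Rescaler arithmetic: map each channel to [-1, 1] and back.
data = np.array([[[0., 5., 10.], [100., 150., 200.]]])  # (batch=1, channels=2, time=3)
ch_min = data.min(axis=(0, 2))[:, None]  # shape (channels, 1), broadcasts over time
ch_max = data.max(axis=(0, 2))[:, None]
scaled = 2 * (data - ch_min) / (ch_max - ch_min) - 1
restored = .5 * (scaled + 1) * (ch_max - ch_min) + ch_min


def step_lr_values(lr, min_lr, epochs, steps=10):
    """Hypothetical helper: the rate a StepLR schedule would use at each epoch."""
    step_size = epochs // steps           # epochs between two decays
    gamma = (min_lr / lr) ** (1 / steps)  # `steps` decays multiply lr by min_lr / lr
    return [lr * gamma ** (epoch // step_size) for epoch in range(epochs + 1)]


values = step_lr_values(lr=1e-3, min_lr=1e-5, epochs=100)
print(scaled)
print(values[0], values[10], values[100])
```

Choosing `gamma` as the `steps`-th root of `min_lr / lr` makes the rate land on `min_lr` after exactly `steps` decays. A channel whose minimum equals its maximum would make `rescale` divide by zero, so constant channels need separate handling.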
--------------------------------------------------------------------------------
/time-series/vis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import matplotlib.pyplot as plt
3 | import matplotlib.gridspec as gridspec
4 | import seaborn as sns
5 | from utils import mkdir, to_numpy
6 |
7 |
8 | sns.set()
9 |
10 | sns.set_style('darkgrid', {
11 |     'axes.spines.left': True,
12 |     'axes.spines.bottom': True,
13 |     'axes.spines.right': True,
14 |     'axes.spines.top': True,
15 |     'patch.edgecolor': 'k',
16 |     'axes.edgecolor': '.1',
17 |     'axes.facecolor': '.95',
18 | })
19 |
20 |
21 | def plot_samples(data_unif, time_unif, data=None, time=None, mask=None,
22 |                  rescaler=None, max_time=1, img_path=None, nrows=2, ncols=4):
23 |     data_unif = to_numpy(data_unif)
24 |     time_unif = to_numpy(time_unif)
25 |     n_channels = data_unif.shape[1]
26 |     if rescaler:
27 |         data_unif = rescaler.unrescale(data_unif)
28 |
29 |     if data is not None:
30 |         data = to_numpy(data)
31 |         time = to_numpy(time)
32 |         mask = to_numpy(mask)
33 |         if rescaler:
34 |             data = rescaler.unrescale(data)
35 |
36 |     fig = plt.figure(figsize=(6 * ncols, 2 * n_channels * nrows))
37 |     gs = gridspec.GridSpec(nrows, ncols, wspace=.1, hspace=.1)
38 |
39 |     for i in range(nrows * ncols):
40 |         outer_ax = plt.subplot(gs[i])
41 |         outer_ax.set_xticks([])
42 |         outer_ax.set_yticks([])
43 |
44 |         inner_grid = gridspec.GridSpecFromSubplotSpec(
45 |             n_channels, 1, subplot_spec=gs[i], hspace=0)
46 |         for k in range(n_channels):
47 |             ax = plt.Subplot(fig, inner_grid[k])
48 |             ax.plot(time_unif[i, k], data_unif[i, k],
49 |                     'k-', alpha=.5, linewidth=.6)
50 |             if data is not None:
51 |                 ax.scatter(time[i, k, mask[i, k] == 1],
52 |                            data[i, k, mask[i, k] == 1], c='r',
53 |                            s=30,
54 |                            linewidth=1.5,
55 |                            marker='x')
56 |             ax.set_ylim(-1.2, 1.2)
57 |             ax.set_xlim(0, max_time)
58 |             ax.axes.xaxis.set_ticklabels([])
59 |             ax.axes.yaxis.set_ticklabels([])
60 |             fig.add_subplot(ax)
61 |
62 |     if img_path:
63 |         plt.savefig(img_path, bbox_inches='tight')
64 |     plt.close(fig)
65 |
66 |
67 | class Visualizer:
68 |     def __init__(self, encoder, decoder, batch_size, max_time, test_loader,
69 |                  rescaler, output_dir, device):
70 |         self.encoder = encoder
71 |         self.decoder = decoder
72 |         self.batch_size = batch_size
73 |         self.rescaler = rescaler
74 |         (self.test_val, self.test_idx, self.test_mask,
75 |          _, self.test_cconv_graph) = next(iter(test_loader))
76 |         in_channels = self.test_val.shape[1]
77 |         self.max_time = max_time
78 |         t = torch.linspace(0, max_time, 200, device=device)
79 |         self.t = t.expand(batch_size, in_channels, len(t)).contiguous()
80 |         self.t_mask = torch.ones_like(self.t)
81 |
82 |         self.gen_data_dir = mkdir(output_dir / 'gen')
83 |         self.imp_data_dir = mkdir(output_dir / 'imp')
84 |
85 |     def plot(self, epoch):
86 |         filename = f'{epoch:04d}.png'
87 |
88 |         self.encoder.eval()
89 |         self.decoder.eval()
90 |         with torch.no_grad():
91 |             z = self.encoder(self.test_cconv_graph, self.batch_size)
92 |             if not torch.is_tensor(z):  # P-VAE encoder returns a list
93 |                 z = z[0]
94 |             imp_data = self.decoder(z, self.t, self.t_mask)
95 |             plot_samples(imp_data, self.t,
96 |                          self.test_val, self.test_idx, self.test_mask,
97 |                          rescaler=self.rescaler,
98 |                          max_time=self.max_time,
99 |                          img_path=f'{self.imp_data_dir / filename}')
100 |
101 |             data_noise = torch.empty_like(z).normal_()
102 |             gen_data = self.decoder(data_noise, self.t, self.t_mask)
103 |             plot_samples(gen_data, self.t,
104 |                          rescaler=self.rescaler,
105 |                          max_time=self.max_time,
106 |                          img_path=f'{self.gen_data_dir / filename}')
107 |         self.decoder.train()
108 |         self.encoder.train()
109 |
--------------------------------------------------------------------------------