├── LICENSE
├── README.md
├── examples
│   ├── images.png
│   └── time-series.png
├── image
│   ├── celeba-data
│   │   └── download-celeba.sh
│   ├── celeba_decoder.py
│   ├── celeba_encoder.py
│   ├── celeba_fid.py
│   ├── celeba_pbigan.py
│   ├── celeba_pbigan_fid.py
│   ├── celeba_pvae.py
│   ├── fid.py
│   ├── flow.py
│   ├── inception.py
│   ├── masked_celeba.py
│   ├── masked_mnist.py
│   ├── mmd.py
│   ├── mnist_decoder.py
│   ├── mnist_encoder.py
│   ├── mnist_pbigan.py
│   ├── mnist_pvae.py
│   ├── utils.py
│   └── visualize.py
├── requirements.txt
└── time-series
    ├── .gitattributes
    ├── ema.py
    ├── evaluate.py
    ├── figures
    │   ├── pbigan.png
    │   └── pvae.png
    ├── flow.py
    ├── gen_toy_data.py
    ├── layers.py
    ├── mimic3_pbigan.py
    ├── mimic3_pvae.py
    ├── mmd.py
    ├── sn_layers.py
    ├── spline_cconv.py
    ├── time_series.py
    ├── toy_layers.py
    ├── toy_pbigan.py
    ├── toy_pvae.py
    ├── toy_time_series.ipynb
    ├── tracker.py
    ├── utils.py
    └── vis.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Steven Cheng-Xian Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Learning from Irregularly-Sampled Time Series: A Missing Data Perspective

This repository provides a PyTorch implementation of the paper
["Learning from Irregularly-Sampled Time Series: A Missing Data Perspective"](https://arxiv.org/abs/2008.07599).


## Requirements

This repository requires Python 3.6 or later.
The file [requirements.txt](requirements.txt) contains the full list of
required Python modules and the versions we tested them with.
To install the requirements:

```sh
pip install -r requirements.txt
```


## Image

[Figure: image completion]

### MNIST

Under the `image` directory, the following commands train P-VAE and P-BiGAN for
incomplete MNIST:

```sh
# P-VAE:
python mnist_pvae.py
# P-BiGAN:
python mnist_pbigan.py
```

### CelebA

For CelebA, you need to download the dataset from its
[website](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).
Specifically, you may either:
* Download the file `img_align_celeba.zip` from [this link](https://drive.google.com/uc?export=download&id=0B7EVK8r0v71pZjFTYXZWM3FlRnM)
  and extract the zip file into the directory `image/celeba-data`, or
* Run the script [download-celeba.sh](image/celeba-data/download-celeba.sh)
  under the directory `image/celeba-data`. Make sure you have
  [curl](https://curl.haxx.se) on your system.
```sh
cd image/celeba-data && bash download-celeba.sh
```

Under the `image` directory, the following commands train P-VAE and P-BiGAN for
incomplete CelebA:
```sh
# P-VAE:
python celeba_pvae.py
# P-BiGAN:
python celeba_pbigan.py
```

### Command-line options

For both the MNIST and CelebA scripts, use the option
`--mask block --block-len n` to specify "square observation" missingness
with an n-by-n observed block, and
`--mask indep --obs-prob .2` to specify "independent dropout" missingness
with 80% of the pixels missing.

Use `-h` to see all the available command-line options for each script
(this also applies to the time series scripts described below).

## Time Series

Our implementation takes as input a time series dataset consisting of
three tensors, `time`, `data` and `mask`, saved as a NumPy
[npz](https://numpy.org/doc/stable/reference/generated/numpy.savez_compressed.html) file.
A dataset of `N` time series, each with `C` channels and
at most `L` observations (time-value pairs) per channel,
is represented by the three tensors `time`, `data` and `mask`,
each of shape `(N, C, L)`:

* `mask` is the binary mask indicating which entries in `time` and `data`
  correspond to a missing value.
  `mask[n, c, k]` is 1 if the `k`-th entry of the `c`-th channel of the
  `n`-th time series is observed, and 0 if it is missing.
* `time` stores the timestamps of the time series rescaled to the range [0, 1].
  Missing entries (those whose corresponding `mask` entry is zero)
  must still be set to values within [0, 1] for the decoder to work correctly.
  The easiest way is to set them to zero with `time *= mask`.
* `data` stores the time series values associated with `time`.
  Missing entries may contain arbitrary values.
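As a rough illustration of the layout described above, the sketch below builds a small random dataset with NumPy and writes the three `(N, C, L)` tensors to an `.npz` file. It is not code from this repository (`gen_toy_data.py` is the authors' own example). The helper name `make_irregular_dataset`, the noisy sine-wave values, the choice to place each channel's observations in its first slots, and the output file name `toy-data.npz` are all assumptions made for illustration.

```python
import numpy as np


def make_irregular_dataset(n_series=4, n_channels=2, max_len=10, seed=0):
    """Return (time, data, mask), each of shape (N, C, L) = (n_series, n_channels, max_len).

    Hypothetical helper for illustration only; not part of this repository.
    """
    rng = np.random.default_rng(seed)
    shape = (n_series, n_channels, max_len)
    time = np.zeros(shape, dtype=np.float32)
    data = np.zeros(shape, dtype=np.float32)
    mask = np.zeros(shape, dtype=np.float32)
    for n in range(n_series):
        for c in range(n_channels):
            # Each channel gets its own number of observations k (1 <= k <= max_len).
            k = int(rng.integers(1, max_len + 1))
            # Sorted timestamps drawn directly in [0, 1]; observed entries fill
            # the first k slots of the channel (an assumption of this sketch).
            t = np.sort(rng.uniform(0.0, 1.0, size=k))
            time[n, c, :k] = t
            data[n, c, :k] = np.sin(2 * np.pi * t) + 0.1 * rng.standard_normal(k)
            mask[n, c, :k] = 1.0
    # Timestamps of missing entries must stay within [0, 1]; zeroing them is simplest.
    # They are already zero here, so this line only acts as a safeguard.
    time *= mask
    return time, data, mask


if __name__ == '__main__':
    time, data, mask = make_irregular_dataset()
    # Hypothetical output file name.
    np.savez_compressed('toy-data.npz', time=time, data=data, mask=mask)
    with np.load('toy-data.npz') as f:
        print({key: f[key].shape for key in f.files})
```

The arrays are stored as float32 to keep the file small; whether the training scripts expect a particular dtype is not stated here, so checking `split_data` in `time_series.py` before relying on this choice is advisable.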
The script [gen_toy_data.py](time-series/gen_toy_data.py) is an example
of creating a synthetic time series dataset in this format.

### Synthetic data

[This notebook](https://nbviewer.jupyter.org/github/steveli/partial-encoder-decoder/blob/master/time-series/toy_time_series.ipynb)
provides an overview of P-VAE and P-BiGAN
and demonstrates how to train them on a synthetic dataset.

[Figure: time series imputation]

Under the `time-series` directory, the following commands train a P-VAE
and a P-BiGAN on a synthetic multivariate time series dataset:

```sh
# P-VAE:
python toy_pvae.py
# P-BiGAN:
python toy_pbigan.py
```

### MIMIC-III

MIMIC-III can be downloaded following the instructions on
its [website](https://mimic.physionet.org/gettingstarted/access/).

For the experiments, we apply the optional preprocessing
used in [this work](https://github.com/mlds-lab/interp-net)
to the MIMIC-III dataset.

For the time series classification task, our implementation takes as input
labeled time series data in one of the following three formats:

1. Unsplit data with an additional label vector, stored in the following 4 fields.
   The data will be randomly split into training/test/validation sets.
   * `(time|data|mask)`: numpy array of shape `(N, C, L)` as described before.
   * `label`: binary label of shape `(N,)`.
2. Data with a train/test split, stored in the following 8 fields.
   The training set will subsequently be split into a smaller
   training set (80%) and a validation set (20%).
   * `(train|test)_(time|data|mask)`
   * `(train|test)_label`
3. Data with a train/test/validation split, stored in the following 12 fields.
   This is useful for model selection based on the metric evaluated
   on the validation set across multiple runs (with different randomness).
   * `(train|test|val)_(time|data|mask)`
   * `(train|test|val)_label`

The function `split_data` in [time_series.py](time-series/time_series.py)
demonstrates how the data file is read and split
into training/test/validation sets.
You can follow it to create time series data of your own.

Once the time series data is ready,
run the following commands under the `time-series` directory:

```sh
# P-VAE:
python mimic3_pvae.py
# P-BiGAN:
python mimic3_pbigan.py
```

## Citation

If you find our work relevant to your research, please cite:

```bibtex
@InProceedings{li2020learning,
  title     = {Learning from Irregularly-Sampled Time Series: A Missing Data Perspective},
  author    = {Li, Steven Cheng-Xian and Marlin, Benjamin M.},
  booktitle = {Proceedings of the 37th International Conference on Machine Learning},
  year      = {2020}
}
```

## Contact

Your feedback would be greatly appreciated!
Reach us at .
--------------------------------------------------------------------------------
/examples/images.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/examples/images.png
--------------------------------------------------------------------------------
/examples/time-series.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/examples/time-series.png
--------------------------------------------------------------------------------
/image/celeba-data/download-celeba.sh:
--------------------------------------------------------------------------------
#!/bin/bash

fileid=0B7EVK8r0v71pZjFTYXZWM3FlRnM
filename=img_align_celeba.zip

curl -c ./cookie -s -L "https://drive.google.com/uc?export=download&id=${fileid}" > /dev/null
curl -Lb ./cookie "https://drive.google.com/uc?export=download&confirm=`awk '/download/ {print $NF}' ./cookie`&id=${fileid}" -o ${filename}
unzip $filename
rm -f cookie $filename

--------------------------------------------------------------------------------
/image/celeba_decoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


def dconv_bn_relu(in_dim, out_dim):
    return nn.Sequential(
        nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                           padding=2, output_padding=1, bias=False),
        nn.BatchNorm2d(out_dim),
        nn.ReLU())


class ConvDecoder(nn.Module):
    def __init__(self, latent_size=128):
        super().__init__()

        self.out_channels = 3
        dim = 64

        self.l1 = nn.Sequential(
            nn.Linear(latent_size, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU())

        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, self.out_channels, 5, 2,
                               padding=2, output_padding=1))

    def forward(self, input):
        x = self.l1(input)
        x = x.view(x.shape[0], -1, 4, 4)
        x = self.l2_5(x)
        return x, torch.sigmoid(x)

--------------------------------------------------------------------------------
/image/celeba_encoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import flow


def conv_ln_lrelu(in_dim, out_dim):
    return nn.Sequential(
        nn.Conv2d(in_dim, out_dim, 5, 2, 2),
        nn.InstanceNorm2d(out_dim, affine=True),
        nn.LeakyReLU(0.2))


class ConvEncoder(nn.Module):
    def __init__(self, latent_size, flow_depth=2, logprob=False):
        super().__init__()

        if logprob:
            self.encode_func = self.encode_logprob
        else:
            self.encode_func = self.encode

        dim = 64
        self.ls = nn.Sequential(
            nn.Conv2d(3, dim, 5, 2, 2), nn.LeakyReLU(0.2),
            conv_ln_lrelu(dim, dim * 2),
            conv_ln_lrelu(dim * 2, dim * 4),
            conv_ln_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, latent_size, 4))

        if flow_depth > 0:
            # IAF
            hidden_size = latent_size * 2
            flow_layers = [flow.InverseAutoregressiveFlow(
                latent_size, hidden_size, latent_size)
                for _ in range(flow_depth)]

            flow_layers.append(flow.Reverse(latent_size))
            self.q_z_flow = flow.FlowSequential(*flow_layers)
            self.enc_chunk = 3
        else:
            self.q_z_flow = None
            self.enc_chunk = 2

        fc_out_size = latent_size * self.enc_chunk
        self.fc = nn.Sequential(
            nn.Linear(latent_size, fc_out_size),
            nn.LayerNorm(fc_out_size),
            nn.LeakyReLU(0.2),
            nn.Linear(fc_out_size, fc_out_size),
        )

    def forward(self, input, k_samples=5):
        return self.encode_func(input, k_samples)

    def encode_logprob(self, input, k_samples=5):
        x = self.ls(input)
        x = x.view(input.shape[0], -1)
        fc_out = self.fc(x).chunk(self.enc_chunk, dim=1)
        mu, logvar = fc_out[:2]
        std = F.softplus(logvar)
        qz_x = Normal(mu, std)
        z = qz_x.rsample([k_samples])
        log_q_z = qz_x.log_prob(z)
        if self.q_z_flow:
            z, log_q_z_flow = self.q_z_flow(z, context=fc_out[2])
            log_q_z = (log_q_z + log_q_z_flow).sum(-1)
        else:
            log_q_z = log_q_z.sum(-1)
        return z, log_q_z

    def encode(self, input, _):
        x = self.ls(input)
        x = x.view(input.shape[0], -1)
        fc_out = self.fc(x).chunk(self.enc_chunk, dim=1)
        mu, logvar = fc_out[:2]
        std = F.softplus(logvar)
        qz_x = Normal(mu, std)
        z = qz_x.rsample()
        if self.q_z_flow:
            z, _ = self.q_z_flow(z, context=fc_out[2])
        return z

--------------------------------------------------------------------------------
/image/celeba_fid.py:
--------------------------------------------------------------------------------
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from PIL import Image
from fid import FID


class CelebAFID(FID):
    def __init__(self, batch_size=256, data_name='celeba',
                 workers=0, verbose=True):
        self.batch_size = batch_size
        self.workers = workers
        super().__init__(data_name, verbose)

    def complete_data(self):
        data = datasets.ImageFolder(
            'celeba',
            transforms.Compose([
                transforms.CenterCrop(108),
                transforms.Resize(size=64, interpolation=Image.BICUBIC),
                transforms.ToTensor(),
                # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)),
            ]))

        images = len(data)
        data_loader = DataLoader(
            data, batch_size=self.batch_size, num_workers=self.workers)

        return data_loader, images
--------------------------------------------------------------------------------
/image/celeba_pbigan.py:
--------------------------------------------------------------------------------
from collections import defaultdict
import torch
import torch.nn as nn
from torch.autograd import grad
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pathlib import Path
from datetime import datetime
import pprint
import argparse
from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
from celeba_decoder import ConvDecoder
from celeba_encoder import ConvEncoder, conv_ln_lrelu
from mmd import mmd
from utils import mkdir, make_scheduler
from visualize import Visualizer


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


class PBiGAN(nn.Module):
    def __init__(self, encoder, decoder, ae_loss='mse'):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.ae_loss = ae_loss

    def forward(self, x, mask, ae=True):
        z_T = self.encoder(x * mask)

        z_gen = torch.empty_like(z_T).normal_()
        x_gen_logit, x_gen = self.decoder(z_gen)

        x_logit, x_recon = self.decoder(z_T)

        recon_loss = 0
        if ae:
            if self.ae_loss == 'mse':
                recon_loss = F.mse_loss(
                    x_recon * mask, x * mask, reduction='none') * mask
                recon_loss = recon_loss.sum((1, 2, 3)).mean()
            elif self.ae_loss == 'l1':
                recon_loss = F.l1_loss(
                    x_recon * mask, x * mask, reduction='none') * mask
                recon_loss = recon_loss.sum((1, 2, 3)).mean()
            elif self.ae_loss == 'smooth_l1':
                recon_loss = F.smooth_l1_loss(
                    x_recon * mask, x * mask, reduction='none') * mask
                recon_loss = recon_loss.sum((1, 2, 3)).mean()
            elif self.ae_loss == 'bce':
                # Bernoulli noise
                # recon_loss: -log p(x|z)
                recon_loss = F.binary_cross_entropy_with_logits(
                    x_logit * mask, x * mask, reduction='none') * mask
                recon_loss = recon_loss.sum((1, 2, 3)).mean()

        return z_T, z_gen, x_recon, x_gen, recon_loss

    def impute(self, x, mask):
        self.eval()
        with torch.no_grad():
            z_T = self.encoder(x * mask)
            _, x_recon = self.decoder(z_T)
        self.train()
        return x_recon


class ConvCritic(nn.Module):
    def __init__(self, latent_size):
        super().__init__()

        dim = 64
        self.conv = nn.Sequential(
            nn.Conv2d(3, dim, 5, 2, 2), nn.LeakyReLU(0.2),
            conv_ln_lrelu(dim, dim * 2),
            conv_ln_lrelu(dim * 2, dim * 4),
            conv_ln_lrelu(dim * 4, dim * 8),
            nn.Conv2d(dim * 8, latent_size, 4))

        embed_size = 64

        self.z_fc = nn.Sequential(
            nn.Linear(latent_size, embed_size),
            nn.LayerNorm(embed_size),
            nn.LeakyReLU(0.2),
            nn.Linear(embed_size, embed_size),
        )

        self.x_fc = nn.Linear(latent_size, embed_size)

        self.xz_fc = nn.Sequential(
            nn.Linear(embed_size * 2, embed_size),
            nn.LayerNorm(embed_size),
            nn.LeakyReLU(0.2),
            nn.Linear(embed_size, 1),
        )

    def forward(self, input):
        x, z = input
        x = self.conv(x)
        x = x.view(x.shape[0], -1)
        x = self.x_fc(x)
        z = self.z_fc(z)
        xz = torch.cat((x, z), 1)
        xz = self.xz_fc(xz)
        return xz.view(-1)


class GradientPenalty:
    def __init__(self, critic, batch_size=64, gp_lambda=10):
        self.critic = critic
        self.gp_lambda = gp_lambda
        # Interpolation coefficient
        self.eps = torch.empty(batch_size, device=device)
        # For computing the gradient penalty
        self.ones = torch.ones(batch_size).to(device)

    def interpolate(self, real, fake):
        eps = self.eps.view([-1] + [1] * (len(real.shape) - 1))
        return (eps * real + (1 - eps) * fake).requires_grad_()

    def __call__(self, real, fake):
        real = [x.detach() for x in real]
        fake = [x.detach() for x in fake]
        self.eps.uniform_(0, 1)
        interp = [self.interpolate(a, b) for a, b in zip(real, fake)]
        grad_d = grad(self.critic(interp),
                      interp,
                      grad_outputs=self.ones,
                      create_graph=True)
        batch_size = real[0].shape[0]
        grad_d = torch.cat([g.view(batch_size, -1) for g in grad_d], 1)
        grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda
        return grad_penalty


def train_pbigan(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        mask_str = f'{args.mask}_{args.obs_prob}'
    elif args.mask == 'block':
        data = BlockMaskedCelebA(block_len=args.block_len)
        mask_str = f'{args.mask}_{args.block_len}'

    data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)
    mask_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True,
                             drop_last=True)

    test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True)

    decoder = ConvDecoder(args.latent)
    encoder = ConvEncoder(args.latent, args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device)

    critic = ConvCritic(args.latent).to(device)

    optimizer = optim.Adam(pbigan.parameters(), lr=args.lr, betas=(.5, .9))

    critic_optimizer = optim.Adam(
        critic.parameters(), lr=args.lr, betas=(.5, .9))

    grad_penalty = GradientPenalty(critic, args.batch_size)

    scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch)

    path = '{}_{}_{}'.format(
        args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str)
    output_dir = Path('results') / 'celeba-pbigan' / path
    mkdir(output_dir)
    print(output_dir)

    if args.save_interval > 0:
        model_dir = mkdir(output_dir / 'model')

    with (output_dir / 'args.txt').open('w') as f:
        print(pprint.pformat(vars(args)), file=f)

    vis = Visualizer(output_dir, loss_xlim=(0, args.epoch))

    # Python 3 iterators have no .next() method; use the built-in next().
    test_x, test_mask, index = next(iter(test_loader))
    test_x = test_x.to(device)
    test_mask = test_mask.to(device).float()
    bbox = None
    if data.mask_loc is not None:
        bbox = [data.mask_loc[idx] for idx in index]

    n_critic = 5
    critic_updates = 0
    ae_weight = 0

    for epoch in range(args.epoch):
        loss_breakdown = defaultdict(float)

        if epoch >= args.ae_start:
            ae_weight = args.ae

        for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader):
            x = x.to(device)
            mask = mask.to(device).float()
            mask_gen = mask_gen.to(device).float()

            if critic_updates < n_critic:
                z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False)

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                w_dist = real_score - fake_score
                D_loss = -w_dist + grad_penalty((x * mask, z_enc),
                                                (x_gen * mask_gen, z_gen))

                critic_optimizer.zero_grad()
                D_loss.backward()
                critic_optimizer.step()

                loss_breakdown['D'] += D_loss.item()

                critic_updates += 1
            else:
                critic_updates = 0

                # Update generators' parameters
                for p in critic.parameters():
                    p.requires_grad_(False)

                z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(
                    x, mask, ae=(args.ae > 0))

                real_score = critic((x * mask, z_enc)).mean()
                fake_score = critic((x_gen * mask_gen, z_gen)).mean()

                G_loss = real_score - fake_score

                ae_loss = ae_loss * ae_weight
                loss = G_loss + ae_loss

                mmd_loss = 0
                if args.mmd > 0:
                    mmd_loss = mmd(z_enc, z_gen)
                    loss += mmd_loss * args.mmd

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss_breakdown['G'] += G_loss.item()
                if torch.is_tensor(ae_loss):
                    loss_breakdown['AE'] += ae_loss.item()
                if torch.is_tensor(mmd_loss):
                    loss_breakdown['MMD'] += mmd_loss.item()
                loss_breakdown['total'] += loss.item()

                for p in critic.parameters():
                    p.requires_grad_(True)

        if scheduler:
            scheduler.step()

        vis.plot_loss(epoch, loss_breakdown)

        # Guard against --plot-interval 0 (documented as "disable plotting"),
        # which would otherwise raise ZeroDivisionError in the modulo.
        if args.plot_interval > 0 and epoch % args.plot_interval == 0:
            with torch.no_grad():
                pbigan.eval()
                z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask)
                pbigan.train()
            vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen)

        model_dict = {
            'pbigan': pbigan.state_dict(),
            'critic': critic.state_dict(),
            'history': vis.history,
            'epoch': epoch,
            'args': args,
        }
        torch.save(model_dict, str(output_dir / 'model.pth'))
        if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0:
            torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth'))

    print(output_dir)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--seed', type=int, default=3,
                        help='random seed')
    # training options
    parser.add_argument('--plot-interval', type=int, default=10,
                        help='plot interval. 0 to disable plotting.')
    parser.add_argument('--save-interval', type=int, default=0,
                        help='interval to save models. 0 to disable saving.')
    parser.add_argument('--mask', default='block',
                        help='missing data mask. (options: block, indep)')
    # option for block: set to 0 for variable size
    parser.add_argument('--block-len', type=int, default=32,
                        help='size of observed block. '
                             'Set to 0 to use variable size')
    # option for indep:
    parser.add_argument('--obs-prob', type=float, default=.2,
                        help='observed probability for independent dropout')

    parser.add_argument('--flow', type=int, default=2,
                        help='number of IAF layers')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='learning rate')
    parser.add_argument('--min-lr', type=float, default=5e-5,
                        help='min learning rate for LR scheduler. '
                             '-1 to disable annealing')

    parser.add_argument('--epoch', type=int, default=500,
                        help='number of training epochs')
    parser.add_argument('--batch-size', type=int, default=256,
                        help='batch size')
    parser.add_argument('--ae', type=float, default=.002,
                        help='autoencoding regularization strength')
    parser.add_argument('--ae-start', type=int, default=0,
                        help='start epoch of autoencoding regularization')
    parser.add_argument('--prefix', default='pbigan',
                        help='prefix of output directory')
    parser.add_argument('--latent', type=int, default=128,
                        help='dimension of latent variable')
    parser.add_argument('--aeloss', default='smooth_l1',
                        help='autoencoding loss. '
                             '(options: mse, bce, smooth_l1, l1)')
    # --mmd 0 to disable mmd regularization
    parser.add_argument('--mmd', type=float, default=0,
                        help='MMD strength for latent variable')

    args = parser.parse_args()

    train_pbigan(args)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/image/celeba_pbigan_fid.py:
--------------------------------------------------------------------------------
import argparse
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from fid import BaseSampler, BaseImputationSampler
from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
from celeba_fid import CelebAFID
from celeba_decoder import ConvDecoder
from celeba_encoder import ConvEncoder
from celeba_pbigan import PBiGAN


parser = argparse.ArgumentParser()
parser.add_argument('root_dir')
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--only', action='store_true')
args = parser.parse_args()


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


class Sampler(BaseSampler):
    def __init__(self, model, latent_size, images=60000, batch_size=256):
        super().__init__(images)
        self.model = model
        self.rand_z = torch.empty(batch_size, latent_size, device=device)

    def sample(self):
        self.rand_z.normal_()
        return self.model.decoder(self.rand_z)[1]


class ImputationSampler(BaseImputationSampler):
    def __init__(self, data_loader, model, batch_size=256):
        super().__init__(data_loader)
        self.model = model

    def impute(self, data, mask):
        z = self.model.encoder(data * mask)
        x_recon = self.model.decoder(z)[1]
        imputed_data = data * mask + x_recon * (1 - mask)
        return imputed_data
class Data:
    def __init__(self, args, batch_size):
        self.args = args
        self.batch_size = batch_size
        self.data_loader = None

    def gen_data(self):
        args = self.args
        if args.mask == 'indep':
            data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        elif args.mask == 'block':
            data = BlockMaskedCelebA(block_len=args.block_len)

        self.data_size = len(data)
        self.data_loader = DataLoader(data, batch_size=self.batch_size)

    def get_data(self):
        if self.data_loader is None:
            self.gen_data()
        return self.data_loader, self.data_size


def pretrained_misgan_fid(model_file, data_loader, data_size):
    model = torch.load(model_file, map_location='cpu')

    model_args = model['args']
    decoder = ConvDecoder(model_args.latent)
    encoder = ConvEncoder(model_args.latent, model_args.flow, logprob=False)
    pbigan = PBiGAN(encoder, decoder, model_args.aeloss).to(device)
    pbigan.load_state_dict(model['pbigan'])

    batch_size = args.batch_size

    pbigan.eval()
    with torch.no_grad():
        compute_fid = CelebAFID(batch_size=batch_size)
        sampler = Sampler(pbigan, model_args.latent, data_size, batch_size)
        gen_fid = compute_fid.fid(sampler, data_size)
        print('fid: {:.2f}'.format(gen_fid))

        imputation_sampler = ImputationSampler(data_loader, pbigan, batch_size)
        imp_fid = compute_fid.fid(imputation_sampler, data_size)
        print('impute fid: {:.2f}'.format(imp_fid))

    return gen_fid, imp_fid


def gen_fid_file(model_file, fid_file, imp_fid_file, data):
    if imp_fid_file.exists():
        print('skip')
        return

    fid, imp_fid = pretrained_misgan_fid(model_file, *data.get_data())

    with fid_file.open('w') as f:
        print(fid, file=f)

    if imp_fid is not None:
        with imp_fid_file.open('w') as f:
            print(imp_fid, file=f)


def main():
    root_dir = Path(args.root_dir)
    model_file = root_dir / 'model.pth'
    print(model_file)
    fid_file = root_dir / 'fid.txt'
    imp_fid_file = root_dir / 'impute-fid.txt'

    model = torch.load(model_file, map_location='cpu')
    data = Data(model['args'], args.batch_size)

    gen_fid_file(model_file, fid_file, imp_fid_file, data)

    if args.only:
        return

    model_dir = root_dir / 'model'
    for model_file in sorted(model_dir.glob('*.pth')):
        print(model_file)
        fid_file = model_dir / f'{model_file.stem}-fid.txt'
        imp_fid_file = model_dir / f'{model_file.stem}-impute-fid.txt'
        gen_fid_file(model_file, fid_file, imp_fid_file, data)


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/image/celeba_pvae.py:
--------------------------------------------------------------------------------
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
from torch.distributions.categorical import Categorical
from torch.utils.data import DataLoader
import math
import sys
import logging
from pathlib import Path
from datetime import datetime
import pprint
import argparse
from masked_celeba import BlockMaskedCelebA, IndepMaskedCelebA
from celeba_decoder import ConvDecoder
from celeba_encoder import ConvEncoder
from utils import mkdir, make_scheduler
from visualize import Visualizer


use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')


class PVAE(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x, mask, k_samples=5, kl_weight=1,
                ae_weight=1):
        z_T, log_q_z = self.encoder(x * mask, k_samples)

        pz = Normal(torch.zeros_like(z_T), torch.ones_like(z_T))
        log_p_z = pz.log_prob(z_T).sum(-1)
        # kl_loss: log q(z|x) - log p(z)
        kl_loss = log_q_z - log_p_z

        # Reshape z to accommodate modules with strict input shape requirements
        # such as convolutional layers.
        x_logit, x_recon = self.decoder(z_T.view(-1, *z_T.shape[2:]))
        expanded_mask = mask[None]
        masked_logit = x_logit.view(k_samples, *x.shape) * expanded_mask
        masked_x = (x * mask)[None].expand_as(masked_logit)
        # Bernoulli noise
        bce = F.binary_cross_entropy_with_logits(
            masked_logit, masked_x, reduction='none')
        recon_loss = (bce * expanded_mask).sum((2, 3, 4))

        # elbo = log p(x|z) + log p(z) - log q(z|x)
        elbo = -(recon_loss * ae_weight + kl_loss * kl_weight)

        # IWAE loss: -log E[p(x|z) p(z) / q(z|x)]
        # Here we ignore the constant shift of -log(k_samples)
        loss = -elbo.logsumexp(0).mean()

        x_recon = x_recon.view(-1, *x.shape)
        loss_breakdown = {
            'loss': loss.item(),
            'KL': kl_loss.mean().item(),
            'recon': recon_loss.mean().item(),
        }
        return loss, z_T, x_recon, elbo, loss_breakdown

    def impute(self, x, mask, k_samples=10):
        self.eval()
        with torch.no_grad():
            _, z, x_recon, elbo, _ = self(x, mask, k_samples)
            # sampling importance resampling
            is_idx = Categorical(logits=elbo.t()).sample()
            batch_idx = torch.arange(len(x))
            z = z[is_idx, batch_idx]
            x_recon = x_recon[is_idx, batch_idx]
        self.train()
        return x_recon


def train_pvae(args):
    torch.manual_seed(args.seed)

    if args.mask == 'indep':
        data = IndepMaskedCelebA(obs_prob=args.obs_prob)
        mask_str = f'{args.mask}_{args.obs_prob}'
    elif args.mask == 'block':
        data = BlockMaskedCelebA(block_len=args.block_len)
        mask_str = f'{args.mask}_{args.block_len}'
data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True, 91 | drop_last=True) 92 | 93 | test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True) 94 | 95 | decoder = ConvDecoder(args.latent) 96 | encoder = ConvEncoder(args.latent, args.flow, logprob=True) 97 | pvae = PVAE(encoder, decoder).to(device) 98 | 99 | optimizer = optim.Adam(pvae.parameters(), lr=args.lr) 100 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch) 101 | 102 | rand_z = torch.empty(args.batch_size, args.latent, device=device) 103 | 104 | path = '{}_{}_{}'.format( 105 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str) 106 | output_dir = Path('results') / 'celeba-pvae' / path 107 | mkdir(output_dir) 108 | print(output_dir) 109 | 110 | if args.save_interval > 0: 111 | model_dir = mkdir(output_dir / 'model') 112 | 113 | logging.basicConfig( 114 | level=logging.INFO, 115 | format='%(asctime)s %(message)s', 116 | datefmt='%Y-%m-%d %H:%M:%S', 117 | handlers=[ 118 | logging.FileHandler(output_dir / 'log.txt'), 119 | logging.StreamHandler(sys.stdout), 120 | ], 121 | ) 122 | 123 | with (output_dir / 'args.txt').open('w') as f: 124 | print(pprint.pformat(vars(args)), file=f) 125 | 126 | vis = Visualizer(output_dir, loss_xlim=(0, args.epoch)) 127 | 128 | test_x, test_mask, index = next(iter(test_loader)) 129 | test_x = test_x.to(device) 130 | test_mask = test_mask.to(device).float() 131 | bbox = None 132 | if data.mask_loc is not None: 133 | bbox = [data.mask_loc[idx] for idx in index] 134 | 135 | kl_center = (args.kl_on + args.kl_off) / 2 136 | kl_scale = 1 / min(args.kl_on - args.kl_off, 1) 137 | 138 | for epoch in range(args.epoch): 139 | if epoch >= args.kl_on: 140 | kl_weight = 1 141 | elif epoch < args.kl_off: 142 | kl_weight = 0 143 | else: 144 | kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale)) 145 | loss_breakdown = defaultdict(float) 146 | for x, mask, _ in data_loader: 147 | x = x.to(device) 148 | mask = 
mask.to(device).float() 149 | 150 | optimizer.zero_grad() 151 | loss, _, _, _, loss_info = pvae( 152 | x, mask, args.k, kl_weight, args.ae) 153 | loss.backward() 154 | optimizer.step() 155 | for name, val in loss_info.items(): 156 | loss_breakdown[name] += val 157 | 158 | if scheduler: 159 | scheduler.step() 160 | 161 | vis.plot_loss(epoch, loss_breakdown) 162 | 163 | if args.plot_interval > 0 and epoch % args.plot_interval == 0: 164 | x_recon = pvae.impute(test_x, test_mask, args.k) 165 | with torch.no_grad(): 166 | pvae.eval() 167 | rand_z.normal_() 168 | _, x_gen = decoder(rand_z) 169 | pvae.train() 170 | vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen) 171 | 172 | model_dict = { 173 | 'pvae': pvae.state_dict(), 174 | 'history': vis.history, 175 | 'epoch': epoch, 176 | 'args': args, 177 | } 178 | torch.save(model_dict, str(output_dir / 'model.pth')) 179 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0: 180 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth')) 181 | 182 | print(output_dir) 183 | 184 | 185 | def main(): 186 | parser = argparse.ArgumentParser() 187 | 188 | parser.add_argument('--seed', type=int, default=3, 189 | help='random seed') 190 | # training options 191 | parser.add_argument('--plot-interval', type=int, default=50, 192 | help='plot interval. 0 to disable plotting.') 193 | parser.add_argument('--save-interval', type=int, default=50, 194 | help='interval to save models. 0 to disable saving.') 195 | parser.add_argument('--mask', default='block', choices=['block', 'indep'], 196 | help='missing data mask. (options: block, indep)') 197 | # option for block: set to 0 for variable size 198 | parser.add_argument('--block-len', type=int, default=32, 199 | help='size of observed block. 
' 200 | 'Set to 0 to use variable size') 201 | # option for indep: 202 | parser.add_argument('--obs-prob', type=float, default=.2, 203 | help='observed probability for independent dropout') 204 | 205 | parser.add_argument('--flow', type=int, default=2, 206 | help='number of IAF layers') 207 | parser.add_argument('--lr', type=float, default=1e-4, 208 | help='learning rate') 209 | parser.add_argument('--min-lr', type=float, default=-1, 210 | help='min learning rate for LR scheduler. ' 211 | '-1 to disable annealing') 212 | 213 | parser.add_argument('--epoch', type=int, default=500, 214 | help='number of training epochs') 215 | parser.add_argument('--batch-size', type=int, default=256, 216 | help='batch size') 217 | parser.add_argument('--k', type=int, default=5, 218 | help='number of importance samples') 219 | parser.add_argument('--prefix', default='pvae', 220 | help='prefix of output directory') 221 | parser.add_argument('--latent', type=int, default=128, 222 | help='dimension of latent variable') 223 | parser.add_argument('--kl-off', type=int, default=10, 224 | help='epoch at which to start ramping the KL weight up from zero') 225 | # set --kl-on to 0 to use constant kl_weight = 1 226 | parser.add_argument('--kl-on', type=int, default=20, 227 | help='epoch from which the KL weight is fixed at 1') 228 | parser.add_argument('--ae', type=float, default=1, 229 | help='log-likelihood weight for ELBO') 230 | 231 | args = parser.parse_args() 232 | 233 | train_pvae(args) 234 | 235 | 236 | if __name__ == '__main__': 237 | main() 238 | -------------------------------------------------------------------------------- /image/fid.py: -------------------------------------------------------------------------------- 1 | """Calculates the Frechet Inception Distance (FID) to evaluate GANs 2 | 3 | **Code adapted from https://github.com/mseitzer/pytorch-fid** 4 | 5 | The FID metric calculates the distance between two distributions of images.
6 | Typically, we have summary statistics (mean & covariance matrix) of one 7 | of these distributions, while the 2nd distribution is given by a GAN. 8 | 9 | In this module, the summary statistics of the real data are computed once 10 | and cached as an .npz file under fid_stats/ (see the FID class below). 11 | 12 | 13 | The FID is calculated by assuming that X_1 and X_2 are the activations of 14 | the pool_3 layer of the inception net for generated samples and real world 15 | samples respectively. 16 | 17 | 18 | 19 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead 20 | of TensorFlow 21 | 22 | Copyright 2018 Institute of Bioinformatics, JKU Linz 23 | 24 | Licensed under the Apache License, Version 2.0 (the "License"); 25 | you may not use this file except in compliance with the License. 26 | You may obtain a copy of the License at 27 | 28 | http://www.apache.org/licenses/LICENSE-2.0 29 | 30 | Unless required by applicable law or agreed to in writing, software 31 | distributed under the License is distributed on an "AS IS" BASIS, 32 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 33 | See the License for the specific language governing permissions and 34 | limitations under the License. 35 | """ 36 | 37 | from pathlib import Path 38 | import torch 39 | import numpy as np 40 | from scipy import linalg 41 | import time 42 | import sys 43 | from inception import InceptionV3 44 | 45 | 46 | use_cuda = torch.cuda.is_available() 47 | device = torch.device('cuda' if use_cuda else 'cpu') 48 | 49 | FEATURE_DIM = 2048 50 | RESIZE = 299 51 | 52 | 53 | def get_activations(image_iterator, images, model, verbose=True): 54 | """Calculates the activations of the pool_3 layer for all images. 55 | 56 | Params: 57 | -- image_iterator 58 | : A generator that generates a batch of images at a time.
59 | -- images : Number of images that will be generated by 60 | image_iterator. 61 | -- model : Instance of inception model 62 | -- verbose : If set to True (and stdout is a terminal), the number 63 | of processed images is reported. 64 | Returns: 65 | -- A numpy array of shape (images, FEATURE_DIM) that contains the 66 | pool_3 activations of the images produced by 67 | image_iterator. 68 | """ 69 | model.eval() 70 | 71 | if not sys.stdout.isatty(): 72 | verbose = False 73 | 74 | pred_arr = np.empty((images, FEATURE_DIM)) 75 | end = 0 76 | t0 = time.time() 77 | 78 | for batch in image_iterator: 79 | if not isinstance(batch, torch.Tensor): 80 | batch = batch[0] 81 | start = end 82 | batch_size = batch.shape[0] 83 | end = start + batch_size 84 | 85 | with torch.no_grad(): 86 | batch = batch.to(device) 87 | pred = model(batch)[0] 88 | batch_feature = pred.cpu().numpy().reshape(batch_size, -1) 89 | pred_arr[start:end] = batch_feature 90 | 91 | if verbose: 92 | print('\rProcessed: {} time: {:.2f}'.format( 93 | end, time.time() - t0), end='', flush=True) 94 | 95 | assert end == images 96 | 97 | if verbose: 98 | print(' done') 99 | 100 | return pred_arr 101 | 102 | 103 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 104 | """Numpy implementation of the Frechet Distance. 105 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 106 | and X_2 ~ N(mu_2, C_2) is 107 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 108 | 109 | Stable version by Dougal J. Sutherland. 110 | 111 | Params: 112 | -- mu1 : The sample mean over activations of a layer of the 113 | inception net (as returned by get_activations) 114 | for generated samples. 115 | -- mu2 : The sample mean over activations, precalculated on a 116 | representative data set. 117 | -- sigma1: The covariance matrix over activations for generated samples.
118 | -- sigma2: The covariance matrix over activations, precalculated on a 119 | representative data set. 120 | 121 | Returns: 122 | -- : The Frechet Distance. 123 | """ 124 | 125 | mu1 = np.atleast_1d(mu1) 126 | mu2 = np.atleast_1d(mu2) 127 | 128 | sigma1 = np.atleast_2d(sigma1) 129 | sigma2 = np.atleast_2d(sigma2) 130 | 131 | assert mu1.shape == mu2.shape, \ 132 | 'Training and test mean vectors have different lengths' 133 | assert sigma1.shape == sigma2.shape, \ 134 | 'Training and test covariances have different dimensions' 135 | 136 | diff = mu1 - mu2 137 | 138 | # Product might be almost singular 139 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 140 | if not np.isfinite(covmean).all(): 141 | msg = ('fid calculation produces singular product; ' 142 | 'adding %s to diagonal of cov estimates') % eps 143 | print(msg) 144 | offset = np.eye(sigma1.shape[0]) * eps 145 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 146 | 147 | # Numerical error might give slight imaginary component 148 | if np.iscomplexobj(covmean): 149 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 150 | m = np.max(np.abs(covmean.imag)) 151 | raise ValueError('Imaginary component {}'.format(m)) 152 | covmean = covmean.real 153 | 154 | tr_covmean = np.trace(covmean) 155 | 156 | return (diff.dot(diff) + np.trace(sigma1) + 157 | np.trace(sigma2) - 2 * tr_covmean) 158 | 159 | 160 | def calculate_activation_statistics(image_iterator, images, model, 161 | verbose=False): 162 | """Calculation of the statistics used by the FID. 163 | Params: 164 | -- image_iterator 165 | : A generator that generates a batch of images at a time. 166 | -- images : Number of images that will be generated by 167 | image_iterator. 168 | -- model : Instance of inception model 169 | -- verbose : If set to True (and stdout is a terminal), the 170 | number of processed images is reported.
171 | Returns: 172 | -- mu : The mean over samples of the activations of the pool_3 layer of 173 | the inception model. 174 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 175 | the inception model. 176 | """ 177 | act = get_activations(image_iterator, images, model, verbose) 178 | mu = np.mean(act, axis=0) 179 | sigma = np.cov(act, rowvar=False) 180 | return mu, sigma 181 | 182 | 183 | class FID: 184 | def __init__(self, data_name, verbose=True): 185 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[FEATURE_DIM] 186 | model = InceptionV3([block_idx], RESIZE).to(device) 187 | self.verbose = verbose 188 | 189 | stats_dir = Path('fid_stats') 190 | stats_file = stats_dir / '{}_act_{}_{}.npz'.format( 191 | data_name, FEATURE_DIM, RESIZE) 192 | 193 | try: 194 | f = np.load(str(stats_file)) 195 | mu, sigma = f['mu'], f['sigma'] 196 | f.close() 197 | except FileNotFoundError: 198 | data_loader, images = self.complete_data() 199 | mu, sigma = calculate_activation_statistics( 200 | data_loader, images, model, verbose) 201 | stats_dir.mkdir(parents=True, exist_ok=True) 202 | np.savez(stats_file, mu=mu, sigma=sigma) 203 | 204 | self.model = model 205 | self.stats = mu, sigma 206 | 207 | def complete_data(self): 208 | raise NotImplementedError 209 | 210 | def fid(self, image_iterator, images): 211 | mu, sigma = calculate_activation_statistics( 212 | image_iterator, images, self.model, verbose=self.verbose) 213 | return calculate_frechet_distance(mu, sigma, *self.stats) 214 | 215 | 216 | class BaseSampler: 217 | def __init__(self, images): 218 | self.images = images 219 | 220 | def __iter__(self): 221 | self.n = 0 222 | return self 223 | 224 | def __next__(self): 225 | if self.n < self.images: 226 | batch = self.sample() 227 | batch_size = batch.shape[0] 228 | self.n += batch_size 229 | if self.n > self.images: 230 | return batch[:-(self.n - self.images)] 231 | return batch 232 | else: 233 | raise StopIteration 234 | 235 | def sample(self): 236 | raise 
NotImplementedError 237 | 238 | 239 | class BaseImputationSampler: 240 | def __init__(self, data_loader): 241 | self.data_loader = data_loader 242 | 243 | def __iter__(self): 244 | self.data_iter = iter(self.data_loader) 245 | return self 246 | 247 | def __next__(self): 248 | data, mask = next(self.data_iter)[:2] 249 | data = data.to(device) 250 | mask = mask.float().to(device) 251 | imputed_data = self.impute(data, mask) 252 | return mask * data + (1 - mask) * imputed_data 253 | 254 | def impute(self, data, mask): 255 | raise NotImplementedError 256 | -------------------------------------------------------------------------------- /image/flow.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/altosaar/variational-autoencoder""" 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class InverseAutoregressiveFlow(nn.Module): 9 | """Inverse Autoregressive Flows with LSTM-type update. One block. 
10 | 11 | Eq 11-14 of https://arxiv.org/abs/1606.04934 12 | """ 13 | def __init__(self, num_input, num_hidden, num_context): 14 | super().__init__() 15 | self.made = MADE(num_input=num_input, num_output=num_input * 2, 16 | num_hidden=num_hidden, num_context=num_context) 17 | # init such that sigmoid(s) is close to 1 for stability 18 | self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) 19 | self.sigmoid = nn.Sigmoid() 20 | self.log_sigmoid = nn.LogSigmoid() 21 | 22 | def forward(self, input, context=None): 23 | m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) 24 | s = s + self.sigmoid_arg_bias 25 | sigmoid = self.sigmoid(s) 26 | z = sigmoid * input + (1 - sigmoid) * m 27 | return z, -self.log_sigmoid(s) 28 | 29 | 30 | class FlowSequential(nn.Sequential): 31 | """Forward pass.""" 32 | 33 | def forward(self, input, context=None): 34 | total_log_prob = torch.zeros_like(input) 35 | for block in self._modules.values(): 36 | input, log_prob = block(input, context) 37 | total_log_prob += log_prob 38 | return input, total_log_prob 39 | 40 | 41 | class MaskedLinear(nn.Module): 42 | """Linear layer with some input-output connections masked.""" 43 | def __init__(self, in_features, out_features, mask, context_features=None, 44 | bias=True): 45 | super().__init__() 46 | self.linear = nn.Linear(in_features, out_features, bias) 47 | self.register_buffer("mask", mask) 48 | if context_features is not None: 49 | self.cond_linear = nn.Linear(context_features, out_features, 50 | bias=False) 51 | 52 | def forward(self, input, context=None): 53 | output = F.linear(input, self.mask * self.linear.weight, 54 | self.linear.bias) 55 | if context is None: 56 | return output 57 | else: 58 | return output + self.cond_linear(context) 59 | 60 | 61 | class MADE(nn.Module): 62 | """Implements MADE: Masked Autoencoder for Distribution Estimation. 
63 | 64 | Follows https://arxiv.org/abs/1502.03509 65 | 66 | This is used to build MAF: 67 | Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). 68 | """ 69 | def __init__(self, num_input, num_output, num_hidden, num_context): 70 | super().__init__() 71 | # m corresponds to m(k), the maximum degree of a node in the MADE paper 72 | self._m = [] 73 | self._masks = [] 74 | self._build_masks(num_input, num_output, num_hidden, num_layers=3) 75 | self._check_masks() 76 | modules = [] 77 | self.input_context_net = MaskedLinear( 78 | num_input, num_hidden, self._masks[0], num_context) 79 | modules.append(nn.ReLU()) 80 | modules.append(MaskedLinear( 81 | num_hidden, num_hidden, self._masks[1], context_features=None)) 82 | modules.append(nn.ReLU()) 83 | modules.append(MaskedLinear( 84 | num_hidden, num_output, self._masks[2], context_features=None)) 85 | self.net = nn.Sequential(*modules) 86 | 87 | def _build_masks(self, num_input, num_output, num_hidden, num_layers): 88 | """Build the masks according to Eq 12 and 13 in the MADE paper.""" 89 | rng = np.random.RandomState(0) 90 | # assign input units a number between 1 and D 91 | self._m.append(np.arange(1, num_input + 1)) 92 | for i in range(1, num_layers + 1): 93 | # randomly assign maximum number of input nodes to connect to 94 | if i == num_layers: 95 | # assign output layer units a number between 1 and D 96 | m = np.arange(1, num_input + 1) 97 | assert num_output % num_input == 0, ( 98 | "num_output must be multiple of num_input") 99 | self._m.append(np.hstack( 100 | [m for _ in range(num_output // num_input)])) 101 | else: 102 | # assign hidden layer units a number between 1 and D-1 103 | self._m.append(rng.randint(1, num_input, size=num_hidden)) 104 | # self._m.append( 105 | # np.arange(1, num_hidden + 1) % (num_input - 1) + 1) 106 | if i == num_layers: 107 | mask = self._m[i][None, :] > self._m[i - 1][:, None] 108 | else: 109 | # input to hidden & hidden to hidden 110 | mask = self._m[i][None, :] >= 
self._m[i - 1][:, None] 111 | # need to transpose for torch linear layer. 112 | # shape (num_output, num_input) 113 | self._masks.append(torch.from_numpy(mask.astype(np.float32).T)) 114 | 115 | def _check_masks(self): 116 | """Check that the connectivity matrix between layers is lower 117 | triangular.""" 118 | # (num_input, num_hidden) 119 | prev = self._masks[0].t() 120 | for i in range(1, len(self._masks)): 121 | # num_hidden is second axis 122 | prev = prev @ self._masks[i].t() 123 | final = prev.numpy() 124 | num_input = self._masks[0].shape[1] 125 | num_output = self._masks[-1].shape[0] 126 | assert final.shape == (num_input, num_output) 127 | if num_output == num_input: 128 | assert np.triu(final).all() == 0 129 | else: 130 | for submat in np.split( 131 | final, indices_or_sections=num_output // num_input, 132 | axis=1): 133 | assert np.triu(submat).all() == 0 134 | 135 | def forward(self, input, context=None): 136 | # first hidden layer receives input and context 137 | hidden = self.input_context_net(input, context) 138 | # rest of the network is conditioned on both input and context 139 | return self.net(hidden) 140 | 141 | 142 | class Reverse(nn.Module): 143 | """ An implementation of a reversing layer from 144 | Density estimation using Real NVP 145 | (https://arxiv.org/abs/1605.08803). 
146 | 147 | From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py 148 | """ 149 | 150 | def __init__(self, num_input): 151 | super(Reverse, self).__init__() 152 | self.perm = np.array(np.arange(0, num_input)[::-1]) 153 | self.inv_perm = np.argsort(self.perm) 154 | 155 | def forward(self, inputs, context=None, mode='forward'): 156 | if mode == "forward": 157 | return inputs[..., self.perm], torch.zeros_like(inputs) 158 | elif mode == "inverse": 159 | return inputs[..., self.inv_perm], torch.zeros_like(inputs) 160 | else: 161 | raise ValueError("Mode must be one of {forward, inverse}.") 162 | -------------------------------------------------------------------------------- /image/inception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code adapted from https://github.com/pytorch/vision 3 | """ 4 | 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torchvision import models 8 | 9 | 10 | class InceptionV3(nn.Module): 11 | """Pretrained InceptionV3 network returning feature maps""" 12 | 13 | # Index of default block of inception to return, 14 | # corresponds to output of final average pooling 15 | DEFAULT_BLOCK_INDEX = 3 16 | 17 | # Maps feature dimensionality to their output block indices 18 | BLOCK_INDEX_BY_DIM = { 19 | 64: 0, # First max pooling features 20 | 192: 1, # Second max pooling features 21 | 768: 2, # Pre-aux classifier features 22 | 2048: 3 # Final average pooling features 23 | } 24 | 25 | def __init__(self, 26 | output_blocks=[DEFAULT_BLOCK_INDEX], 27 | resize_input=299, # -1: do not resize 28 | normalize_input=True, 29 | requires_grad=False): 30 | """Build pretrained InceptionV3 31 | 32 | Parameters 33 | ---------- 34 | output_blocks : list of int 35 | Indices of blocks to return features of.
Possible values are: 36 | - 0: corresponds to output of first max pooling 37 | - 1: corresponds to output of second max pooling 38 | - 2: corresponds to output which is fed to aux classifier 39 | - 3: corresponds to output of final average pooling 40 | resize_input : int 41 | If positive, bilinearly resizes input to this width and height before 42 | feeding it to the model; -1 disables resizing. As the network without fully connected 43 | layers is fully convolutional, it should be able to handle inputs 44 | of arbitrary size, so resizing might not be strictly needed 45 | normalize_input : bool 46 | If true, normalizes the input to the statistics the pretrained 47 | Inception network expects 48 | requires_grad : bool 49 | If true, parameters of the model require gradient. Possibly useful 50 | for finetuning the network 51 | """ 52 | super(InceptionV3, self).__init__() 53 | 54 | self.resize_input = resize_input 55 | self.normalize_input = normalize_input 56 | self.output_blocks = sorted(output_blocks) 57 | self.last_needed_block = max(output_blocks) 58 | 59 | assert self.last_needed_block <= 3, \ 60 | 'Last possible output block index is 3' 61 | 62 | self.blocks = nn.ModuleList() 63 | 64 | inception = models.inception_v3(pretrained=True) 65 | 66 | # Block 0: input to maxpool1 67 | block0 = [ 68 | inception.Conv2d_1a_3x3, 69 | inception.Conv2d_2a_3x3, 70 | inception.Conv2d_2b_3x3, 71 | nn.MaxPool2d(kernel_size=3, stride=2) 72 | ] 73 | self.blocks.append(nn.Sequential(*block0)) 74 | 75 | # Block 1: maxpool1 to maxpool2 76 | if self.last_needed_block >= 1: 77 | block1 = [ 78 | inception.Conv2d_3b_1x1, 79 | inception.Conv2d_4a_3x3, 80 | nn.MaxPool2d(kernel_size=3, stride=2) 81 | ] 82 | self.blocks.append(nn.Sequential(*block1)) 83 | 84 | # Block 2: maxpool2 to aux classifier 85 | if self.last_needed_block >= 2: 86 | block2 = [ 87 | inception.Mixed_5b, 88 | inception.Mixed_5c, 89 | inception.Mixed_5d, 90 | inception.Mixed_6a, 91 | inception.Mixed_6b, 92 | inception.Mixed_6c, 93 |
inception.Mixed_6d, 94 | inception.Mixed_6e, 95 | ] 96 | self.blocks.append(nn.Sequential(*block2)) 97 | 98 | # Block 3: aux classifier to final avgpool 99 | if self.last_needed_block >= 3: 100 | block3 = [ 101 | inception.Mixed_7a, 102 | inception.Mixed_7b, 103 | inception.Mixed_7c, 104 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 105 | ] 106 | self.blocks.append(nn.Sequential(*block3)) 107 | 108 | for param in self.parameters(): 109 | param.requires_grad = requires_grad 110 | 111 | def forward(self, inp): 112 | """Get Inception feature maps 113 | 114 | Parameters 115 | ---------- 116 | inp : torch.autograd.Variable 117 | Input tensor of shape Bx3xHxW. Values are expected to be in 118 | range (0, 1) 119 | 120 | Returns 121 | ------- 122 | List of torch.autograd.Variable, corresponding to the selected output 123 | block, sorted ascending by index 124 | """ 125 | outp = [] 126 | x = inp 127 | 128 | if self.resize_input > 0: 129 | # size = 299 130 | x = F.interpolate(x, size=(self.resize_input, self.resize_input), 131 | mode='bilinear', align_corners=True) 132 | 133 | if self.normalize_input: 134 | x = x.clone() 135 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5 136 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5 137 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5 138 | 139 | for idx, block in enumerate(self.blocks): 140 | x = block(x) 141 | if idx in self.output_blocks: 142 | outp.append(x) 143 | 144 | if idx == self.last_needed_block: 145 | break 146 | 147 | return outp 148 | -------------------------------------------------------------------------------- /image/masked_celeba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | class MaskedCelebA(datasets.ImageFolder): 8 | def __init__(self, data_dir='celeba-data', image_size=64, random_seed=0): 9 | transform = 
transforms.Compose([ 10 | transforms.CenterCrop(108), 11 | transforms.Resize(size=image_size, interpolation=Image.BICUBIC), 12 | transforms.ToTensor(), 13 | # transforms.Normalize(mean=(.5, .5, .5), std=(.5, .5, .5)), 14 | ]) 15 | 16 | super().__init__(data_dir, transform) 17 | 18 | self.rnd = np.random.RandomState(random_seed) 19 | torch.manual_seed(random_seed) 20 | self.image_size = image_size 21 | self.cache_images() 22 | self.generate_masks() 23 | self.mask = torch.stack(self.mask) 24 | 25 | def cache_images(self): 26 | images = [] 27 | for i in range(len(self)): 28 | image, _ = super().__getitem__(i) 29 | images.append(image) 30 | self.images = torch.stack(images) 31 | 32 | def __getitem__(self, index): 33 | return self.images[index], self.mask[index], index 34 | 35 | def __len__(self): 36 | return super().__len__() 37 | 38 | 39 | class BlockMaskedCelebA(MaskedCelebA): 40 | def __init__(self, block_len=None, *args, **kwargs): 41 | self.block_len = block_len 42 | super().__init__(*args, **kwargs) 43 | 44 | def generate_masks(self): 45 | d0_len = d1_len = self.image_size 46 | d0_min_len = 12 47 | d0_max_len = d0_len - d0_min_len 48 | d1_min_len = 12 49 | d1_max_len = d1_len - d1_min_len 50 | 51 | n_masks = len(self) 52 | self.mask = [None] * n_masks 53 | self.mask_loc = [None] * n_masks 54 | for i in range(n_masks): 55 | if self.block_len == 0: 56 | d0_mask_len = self.rnd.randint(d0_min_len, d0_max_len) 57 | d1_mask_len = self.rnd.randint(d1_min_len, d1_max_len) 58 | else: 59 | d0_mask_len = d1_mask_len = self.block_len 60 | 61 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1) 62 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1) 63 | 64 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8) 65 | mask[d0_start:(d0_start + d0_mask_len), 66 | d1_start:(d1_start + d1_mask_len)] = 1 67 | self.mask[i] = mask[None] 68 | self.mask_loc[i] = d0_start, d1_start, d0_mask_len, d1_mask_len 69 | 70 | 71 | class IndepMaskedCelebA(MaskedCelebA): 72 | def 
__init__(self, obs_prob=.2, obs_prob_max=None, *args, **kwargs): 73 | self.prob = obs_prob 74 | self.prob_max = obs_prob_max 75 | self.mask_loc = None 76 | super().__init__(*args, **kwargs) 77 | 78 | def generate_masks(self): 79 | imsize = self.image_size 80 | prob = self.prob 81 | prob_max = self.prob_max 82 | n_masks = len(self) 83 | self.mask = [None] * n_masks 84 | for i in range(n_masks): 85 | if prob_max is None: 86 | p = prob 87 | else: 88 | p = self.rnd.uniform(prob, prob_max) 89 | self.mask[i] = torch.ByteTensor(1, imsize, imsize).bernoulli_(p) 90 | -------------------------------------------------------------------------------- /image/masked_mnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torchvision import datasets, transforms 4 | import numpy as np 5 | 6 | 7 | class MaskedMNIST(Dataset): 8 | def __init__(self, data_dir='mnist-data', image_size=28, random_seed=0): 9 | self.rnd = np.random.RandomState(random_seed) 10 | torch.manual_seed(random_seed) 11 | self.image_size = image_size 12 | if image_size == 28: 13 | self.data = datasets.MNIST( 14 | data_dir, train=True, download=True, 15 | transform=transforms.ToTensor()) 16 | else: 17 | self.data = datasets.MNIST( 18 | data_dir, train=True, download=True, 19 | transform=transforms.Compose([ 20 | transforms.Resize(image_size), transforms.ToTensor()])) 21 | self.generate_masks() 22 | 23 | def __getitem__(self, index): 24 | image, label = self.data[index] 25 | mask = self.mask[index] 26 | return image * mask.float(), mask[None], index 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def generate_masks(self): 32 | raise NotImplementedError 33 | 34 | 35 | class BlockMaskedMNIST(MaskedMNIST): 36 | def __init__(self, block_len=11, block_len_max=None, *args, **kwargs): 37 | self.block_len = block_len 38 | self.block_len_max = block_len_max 39 | super().__init__(*args, **kwargs) 40 | 41 | 
def generate_masks(self): 42 | d0_len = d1_len = self.image_size 43 | n_masks = len(self) 44 | self.mask = [None] * n_masks 45 | self.mask_loc = [None] * n_masks 46 | for i in range(n_masks): 47 | if self.block_len_max is None: 48 | d0_mask_len = d1_mask_len = self.block_len 49 | else: 50 | d0_mask_len = self.rnd.randint( 51 | self.block_len, self.block_len_max) 52 | d1_mask_len = self.rnd.randint( 53 | self.block_len, self.block_len_max) 54 | 55 | d0_start = self.rnd.randint(0, d0_len - d0_mask_len + 1) 56 | d1_start = self.rnd.randint(0, d1_len - d1_mask_len + 1) 57 | 58 | mask = torch.zeros((d0_len, d1_len), dtype=torch.uint8) 59 | mask[d0_start:(d0_start + d0_mask_len), 60 | d1_start:(d1_start + d1_mask_len)] = 1 61 | self.mask[i] = mask 62 | self.mask_loc[i] = d0_start, d1_start, d0_mask_len, d1_mask_len 63 | 64 | 65 | class IndepMaskedMNIST(MaskedMNIST): 66 | def __init__(self, obs_prob=.2, obs_prob_max=None, *args, **kwargs): 67 | self.prob = obs_prob 68 | self.prob_max = obs_prob_max 69 | self.mask_loc = None 70 | super().__init__(*args, **kwargs) 71 | 72 | def generate_masks(self): 73 | imsize = self.image_size 74 | n_masks = len(self) 75 | self.mask = [None] * n_masks 76 | for i in range(n_masks): 77 | if self.prob_max is None: 78 | p = self.prob 79 | else: 80 | p = self.rnd.uniform(self.prob, self.prob_max) 81 | self.mask[i] = torch.ByteTensor(imsize, imsize).bernoulli_(p) 82 | -------------------------------------------------------------------------------- /image/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mmd(x, y): 5 | n, dim = x.shape 6 | 7 | xx = (x**2).sum(1, keepdim=True) 8 | yy = (y**2).sum(1, keepdim=True) 9 | 10 | outer_xx = torch.mm(x, x.t()) 11 | outer_yy = torch.mm(y, y.t()) 12 | outer_xy = torch.mm(x, y.t()) 13 | 14 | diff_xx = xx + xx.t() - 2 * outer_xx 15 | diff_yy = yy + yy.t() - 2 * outer_yy 16 | diff_xy = xx + yy.t() - 2 * outer_xy 17 | 18 | C = 2. 
* dim 19 | k_xx = C / (C + diff_xx) 20 | k_yy = C / (C + diff_yy) 21 | k_xy = C / (C + diff_xy) 22 | 23 | mean_xx = (k_xx.sum() - k_xx.diag().sum()) / (n * (n - 1)) 24 | mean_yy = (k_yy.sum() - k_yy.diag().sum()) / (n * (n - 1)) 25 | mean_xy = k_xy.sum() / (n * n) 26 | 27 | return mean_xx + mean_yy - 2 * mean_xy 28 | -------------------------------------------------------------------------------- /image/mnist_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ConvDecoder(nn.Module): 6 | def __init__(self, latent_size=128): 7 | super().__init__() 8 | 9 | self.DIM = 64 10 | self.latent_size = latent_size 11 | 12 | self.preprocess = nn.Sequential( 13 | nn.Linear(latent_size, 4 * 4 * 4 * self.DIM), 14 | nn.ReLU(True), 15 | ) 16 | self.block1 = nn.Sequential( 17 | nn.ConvTranspose2d(4 * self.DIM, 2 * self.DIM, 5), 18 | nn.ReLU(True), 19 | ) 20 | self.block2 = nn.Sequential( 21 | nn.ConvTranspose2d(2 * self.DIM, self.DIM, 5), 22 | nn.ReLU(True), 23 | ) 24 | self.deconv_out = nn.ConvTranspose2d(self.DIM, 1, 8, stride=2) 25 | 26 | def forward(self, input): 27 | net = self.preprocess(input) 28 | net = net.view(-1, 4 * self.DIM, 4, 4) 29 | net = self.block1(net) 30 | net = net[:, :, :7, :7] 31 | net = self.block2(net) 32 | net = self.deconv_out(net) 33 | net = net.view(-1, 1, 28, 28) 34 | return net, torch.sigmoid(net) 35 | -------------------------------------------------------------------------------- /image/mnist_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torch.distributions.normal import Normal 4 | import flow 5 | 6 | 7 | class ConvEncoder(nn.Module): 8 | def __init__(self, latent_size, flow_depth=2, logprob=False): 9 | super().__init__() 10 | 11 | if logprob: 12 | self.encode_func = self.encode_logprob 13 | else: 14 | self.encode_func = self.encode 
15 | 16 | DIM = 64 17 | self.main = nn.Sequential( 18 | nn.Conv2d(1, DIM, 5, stride=2, padding=2), 19 | nn.ReLU(True), 20 | nn.Conv2d(DIM, 2 * DIM, 5, stride=2, padding=2), 21 | nn.ReLU(True), 22 | nn.Conv2d(2 * DIM, 4 * DIM, 5, stride=2, padding=2), 23 | nn.ReLU(True), 24 | ) 25 | 26 | if flow_depth > 0: 27 | # IAF 28 | hidden_size = latent_size * 2 29 | flow_layers = [flow.InverseAutoregressiveFlow( 30 | latent_size, hidden_size, latent_size) 31 | for _ in range(flow_depth)] 32 | 33 | flow_layers.append(flow.Reverse(latent_size)) 34 | self.q_z_flow = flow.FlowSequential(*flow_layers) 35 | self.enc_chunk = 3 36 | else: 37 | self.q_z_flow = None 38 | self.enc_chunk = 2 39 | 40 | fc_out_size = latent_size * self.enc_chunk 41 | conv_out_size = 4 * 4 * 4 * DIM 42 | self.fc = nn.Sequential( 43 | nn.Linear(conv_out_size, fc_out_size), 44 | nn.LayerNorm(fc_out_size), 45 | nn.LeakyReLU(0.2), 46 | nn.Linear(fc_out_size, fc_out_size), 47 | ) 48 | 49 | def forward(self, input, k_samples=5): 50 | return self.encode_func(input, k_samples) 51 | 52 | def encode_logprob(self, input, k_samples=5): 53 | x = self.main(input.view(-1, 1, 28, 28)) 54 | x = x.view(input.shape[0], -1) 55 | fc_out = self.fc(x).chunk(self.enc_chunk, dim=1) 56 | mu, logvar = fc_out[:2] 57 | std = F.softplus(logvar) 58 | qz_x = Normal(mu, std) 59 | z = qz_x.rsample([k_samples]) 60 | log_q_z = qz_x.log_prob(z) 61 | if self.q_z_flow: 62 | z, log_q_z_flow = self.q_z_flow(z, context=fc_out[2]) 63 | log_q_z = (log_q_z + log_q_z_flow).sum(-1) 64 | else: 65 | log_q_z = log_q_z.sum(-1) 66 | return z, log_q_z 67 | 68 | def encode(self, input, _): 69 | x = self.main(input.view(-1, 1, 28, 28)) 70 | x = x.view(input.shape[0], -1) 71 | fc_out = self.fc(x).chunk(self.enc_chunk, dim=1) 72 | mu, logvar = fc_out[:2] 73 | std = F.softplus(logvar) 74 | qz_x = Normal(mu, std) 75 | z = qz_x.rsample() 76 | if self.q_z_flow: 77 | z, _ = self.q_z_flow(z, context=fc_out[2]) 78 | return z 79 | 
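As a hedged illustration of what `encode_logprob` computes in the no-flow case (`flow_depth=0`), the pure-Python sketch below draws `k_samples` reparameterized latents z = mu + softplus(raw_scale) * eps and sums the per-dimension Normal log-densities into log q(z|x). It also includes the importance-weighted bound that `PVAE.forward` in `image/mnist_pvae.py` builds from such values (that code drops the constant -log(k_samples)). The helper names (`sample_with_logprob`, `iwae_bound`, etc.) are hypothetical, and plain lists stand in for PyTorch tensors; this is a sketch under those assumptions, not the repository's implementation.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + exp(x)), used here as the encoder's std map."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def normal_logpdf(z, mu, std):
    """Log-density of a univariate Normal(mu, std**2) evaluated at z."""
    return (-0.5 * math.log(2 * math.pi) - math.log(std)
            - 0.5 * ((z - mu) / std) ** 2)


def sample_with_logprob(mu, raw_scale, k_samples, rng):
    """Draw k reparameterized samples z = mu + softplus(raw_scale) * eps.

    Returns a list of (z, log_q) pairs, where log_q sums the per-dimension
    log-densities (the analogue of `log_q_z.sum(-1)` without an IAF flow).
    """
    std = [softplus(r) for r in raw_scale]
    samples = []
    for _ in range(k_samples):
        eps = [rng.gauss(0.0, 1.0) for _ in mu]
        z = [m + s * e for m, s, e in zip(mu, std, eps)]
        log_q = sum(normal_logpdf(zi, m, s) for zi, m, s in zip(z, mu, std))
        samples.append((z, log_q))
    return samples


def logsumexp(values):
    """Stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def iwae_bound(log_weights):
    """Importance-weighted bound log((1/k) * sum_k exp(log_w_k))."""
    return logsumexp(log_weights) - math.log(len(log_weights))


# Toy usage: log weights log p(z) - log q(z|x) for a 2-D latent with a
# standard-normal prior. No decoder term is included, so this only
# exercises the bookkeeping.
rng = random.Random(0)
samples = sample_with_logprob([0.0, 1.0], [0.5, -1.0], k_samples=5, rng=rng)
log_w = [sum(normal_logpdf(zi, 0.0, 1.0) for zi in z) - log_q
         for z, log_q in samples]
print(iwae_bound(log_w), sum(log_w) / len(log_w))
```

By Jensen's inequality the log of the mean weight is never below the mean of the log weights, which is why the importance-weighted objective is at least as tight as the plain ELBO for the same samples.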
-------------------------------------------------------------------------------- /image/mnist_pbigan.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import grad 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.utils.data import DataLoader 8 | import sys 9 | import logging 10 | from pathlib import Path 11 | from datetime import datetime 12 | import pprint 13 | import argparse 14 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 15 | from mnist_decoder import ConvDecoder 16 | from mnist_encoder import ConvEncoder 17 | from utils import mkdir, make_scheduler 18 | from visualize import Visualizer 19 | 20 | 21 | use_cuda = torch.cuda.is_available() 22 | device = torch.device('cuda' if use_cuda else 'cpu') 23 | 24 | 25 | class PBiGAN(nn.Module): 26 | def __init__(self, encoder, decoder, ae_loss='bce'): 27 | super().__init__() 28 | self.encoder = encoder 29 | self.decoder = decoder 30 | self.ae_loss = ae_loss 31 | 32 | def forward(self, x, mask, ae=True): 33 | z_T = self.encoder(x * mask) 34 | 35 | z_gen = torch.empty_like(z_T).normal_() 36 | x_gen_logit, x_gen = self.decoder(z_gen) 37 | 38 | x_logit, x_recon = self.decoder(z_T) 39 | 40 | recon_loss = None 41 | if ae: 42 | if self.ae_loss == 'mse': 43 | recon_loss = F.mse_loss( 44 | x_recon * mask, x * mask, reduction='none') * mask 45 | recon_loss = recon_loss.sum((1, 2, 3)).mean() 46 | elif self.ae_loss == 'l1': 47 | recon_loss = F.l1_loss( 48 | x_recon * mask, x * mask, reduction='none') * mask 49 | recon_loss = recon_loss.sum((1, 2, 3)).mean() 50 | elif self.ae_loss == 'smooth_l1': 51 | recon_loss = F.smooth_l1_loss( 52 | x_recon * mask, x * mask, reduction='none') * mask 53 | recon_loss = recon_loss.sum((1, 2, 3)).mean() 54 | elif self.ae_loss == 'bce': 55 | # Bernoulli noise 56 | # recon_loss: -log p(x|z) 57 | recon_loss = 
F.binary_cross_entropy_with_logits( 58 | x_logit * mask, x * mask, reduction='none') * mask 59 | recon_loss = recon_loss.sum((1, 2, 3)).mean() 60 | 61 | return z_T, z_gen, x_recon, x_gen, recon_loss 62 | 63 | 64 | class ConvCritic(nn.Module): 65 | def __init__(self, latent_size): 66 | super().__init__() 67 | 68 | self.DIM = 64 69 | self.main = nn.Sequential( 70 | nn.Conv2d(1, self.DIM, 5, stride=2, padding=2), 71 | nn.ReLU(True), 72 | nn.Conv2d(self.DIM, 2 * self.DIM, 5, stride=2, padding=2), 73 | nn.ReLU(True), 74 | nn.Conv2d(2 * self.DIM, 4 * self.DIM, 5, stride=2, padding=2), 75 | nn.ReLU(True), 76 | ) 77 | 78 | embed_size = 64 79 | 80 | self.z_fc = nn.Sequential( 81 | nn.Linear(latent_size, embed_size), 82 | nn.LayerNorm(embed_size), 83 | nn.LeakyReLU(0.2), 84 | nn.Linear(embed_size, embed_size), 85 | ) 86 | 87 | self.x_fc = nn.Linear(4 * 4 * 4 * self.DIM, embed_size) 88 | 89 | self.xz_fc = nn.Sequential( 90 | nn.Linear(embed_size * 2, embed_size), 91 | nn.LayerNorm(embed_size), 92 | nn.LeakyReLU(0.2), 93 | nn.Linear(embed_size, 1), 94 | ) 95 | 96 | def forward(self, input): 97 | x, z = input 98 | x = x.view(-1, 1, 28, 28) 99 | x = self.main(x) 100 | x = x.view(x.shape[0], -1) 101 | x = self.x_fc(x) 102 | z = self.z_fc(z) 103 | xz = torch.cat((x, z), 1) 104 | xz = self.xz_fc(xz) 105 | return xz.view(-1) 106 | 107 | 108 | class GradientPenalty: 109 | def __init__(self, critic, batch_size=64, gp_lambda=10): 110 | self.critic = critic 111 | self.gp_lambda = gp_lambda 112 | # Interpolation coefficient 113 | self.eps = torch.empty(batch_size, device=device) 114 | # For computing the gradient penalty 115 | self.ones = torch.ones(batch_size).to(device) 116 | 117 | def interpolate(self, real, fake): 118 | eps = self.eps.view([-1] + [1] * (len(real.shape) - 1)) 119 | return (eps * real + (1 - eps) * fake).requires_grad_() 120 | 121 | def __call__(self, real, fake): 122 | real = [x.detach() for x in real] 123 | fake = [x.detach() for x in fake] 124 | self.eps.uniform_(0, 
1) 125 | interp = [self.interpolate(a, b) for a, b in zip(real, fake)] 126 | grad_d = grad(self.critic(interp), 127 | interp, 128 | grad_outputs=self.ones, 129 | create_graph=True) 130 | batch_size = real[0].shape[0] 131 | grad_d = torch.cat([g.view(batch_size, -1) for g in grad_d], 1) 132 | grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda 133 | return grad_penalty 134 | 135 | 136 | def train_pbigan(args): 137 | torch.manual_seed(args.seed) 138 | 139 | if args.mask == 'indep': 140 | data = IndepMaskedMNIST(obs_prob=args.obs_prob, 141 | obs_prob_max=args.obs_prob_max) 142 | mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}' 143 | elif args.mask == 'block': 144 | data = BlockMaskedMNIST(block_len=args.block_len, 145 | block_len_max=args.block_len_max) 146 | mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}' 147 | 148 | data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True, 149 | drop_last=True) 150 | mask_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True, 151 | drop_last=True) 152 | 153 | # Track training progress on the first (unshuffled) batch of training data 154 | test_loader = DataLoader(data, batch_size=args.batch_size, drop_last=True) 155 | 156 | decoder = ConvDecoder(args.latent) 157 | encoder = ConvEncoder(args.latent, args.flow, logprob=False) 158 | pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device) 159 | 160 | critic = ConvCritic(args.latent).to(device) 161 | 162 | lrate = 1e-4 163 | optimizer = optim.Adam(pbigan.parameters(), lr=lrate, betas=(.5, .9)) 164 | 165 | critic_optimizer = optim.Adam( 166 | critic.parameters(), lr=lrate, betas=(.5, .9)) 167 | 168 | grad_penalty = GradientPenalty(critic, args.batch_size) 169 | 170 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch) 171 | 172 | path = '{}_{}_{}'.format( 173 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str) 174 | output_dir = Path('results') / 'mnist-pbigan' / path 175 |
mkdir(output_dir) 176 | print(output_dir) 177 | 178 | if args.save_interval > 0: 179 | model_dir = mkdir(output_dir / 'model') 180 | 181 | logging.basicConfig( 182 | level=logging.INFO, 183 | format='%(asctime)s %(message)s', 184 | datefmt='%Y-%m-%d %H:%M:%S', 185 | handlers=[ 186 | logging.FileHandler(output_dir / 'log.txt'), 187 | logging.StreamHandler(sys.stdout), 188 | ], 189 | ) 190 | 191 | with (output_dir / 'args.txt').open('w') as f: 192 | print(pprint.pformat(vars(args)), file=f) 193 | 194 | vis = Visualizer(output_dir) 195 | 196 | test_x, test_mask, index = next(iter(test_loader)) 197 | test_x = test_x.to(device) 198 | test_mask = test_mask.to(device).float() 199 | bbox = None 200 | if data.mask_loc is not None: 201 | bbox = [data.mask_loc[idx] for idx in index] 202 | 203 | n_critic = 5 204 | critic_updates = 0 205 | ae_weight = 0 206 | ae_flat = 100 207 | 208 | for epoch in range(args.epoch): 209 | loss_breakdown = defaultdict(float) 210 | 211 | if epoch > ae_flat: 212 | ae_weight = args.ae * (epoch - ae_flat) / (args.epoch - ae_flat) 213 | 214 | for (x, mask, _), (_, mask_gen, _) in zip(data_loader, mask_loader): 215 | x = x.to(device) 216 | mask = mask.to(device).float() 217 | mask_gen = mask_gen.to(device).float() 218 | 219 | z_enc, z_gen, x_rec, x_gen, _ = pbigan(x, mask, ae=False) 220 | 221 | real_score = critic((x * mask, z_enc)).mean() 222 | fake_score = critic((x_gen * mask_gen, z_gen)).mean() 223 | 224 | w_dist = real_score - fake_score 225 | D_loss = -w_dist + grad_penalty((x * mask, z_enc), 226 | (x_gen * mask_gen, z_gen)) 227 | 228 | critic_optimizer.zero_grad() 229 | D_loss.backward() 230 | critic_optimizer.step() 231 | 232 | loss_breakdown['D'] += D_loss.item() 233 | 234 | critic_updates += 1 235 | 236 | if critic_updates == n_critic: 237 | critic_updates = 0 238 | 239 | # Update generators' parameters 240 | for p in critic.parameters(): 241 | p.requires_grad_(False) 242 | 243 | z_enc, z_gen, x_rec, x_gen, ae_loss = pbigan(x, mask) 244 |
245 | real_score = critic((x * mask, z_enc)).mean() 246 | fake_score = critic((x_gen * mask_gen, z_gen)).mean() 247 | 248 | G_loss = real_score - fake_score 249 | 250 | ae_loss = ae_loss * ae_weight 251 | loss = G_loss + ae_loss 252 | 253 | optimizer.zero_grad() 254 | loss.backward() 255 | optimizer.step() 256 | 257 | loss_breakdown['G'] += G_loss.item() 258 | loss_breakdown['AE'] += ae_loss.item() 259 | loss_breakdown['total'] += loss.item() 260 | 261 | for p in critic.parameters(): 262 | p.requires_grad_(True) 263 | 264 | if scheduler: 265 | scheduler.step() 266 | 267 | vis.plot_loss(epoch, loss_breakdown) 268 | 269 | if epoch % args.plot_interval == 0: 270 | with torch.no_grad(): 271 | pbigan.eval() 272 | z, z_gen, x_rec, x_gen, ae_loss = pbigan(test_x, test_mask) 273 | pbigan.train() 274 | vis.plot(epoch, test_x, test_mask, bbox, x_rec, x_gen) 275 | 276 | model_dict = { 277 | 'pbigan': pbigan.state_dict(), 278 | 'critic': critic.state_dict(), 279 | 'history': vis.history, 280 | 'epoch': epoch, 281 | 'args': args, 282 | } 283 | torch.save(model_dict, str(output_dir / 'model.pth')) 284 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0: 285 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth')) 286 | 287 | print(output_dir) 288 | 289 | 290 | def main(): 291 | parser = argparse.ArgumentParser() 292 | 293 | parser.add_argument('--seed', type=int, default=3, 294 | help='random seed') 295 | # training options 296 | parser.add_argument('--plot-interval', type=int, default=50, 297 | help='plot interval. 0 to disable plotting.') 298 | parser.add_argument('--save-interval', type=int, default=0, 299 | help='interval to save models. 0 to disable saving.') 300 | parser.add_argument('--mask', default='block', 301 | help='missing data mask. 
(options: block, indep)') 302 | # option for block: set to 0 for variable size 303 | parser.add_argument('--block-len', type=int, default=12, 304 | help='size of observed block') 305 | parser.add_argument('--block-len-max', type=int, default=None, 306 | help='max size of observed block. ' 307 | 'Use fixed-size observed block if unspecified.') 308 | # option for indep: 309 | parser.add_argument('--obs-prob', type=float, default=.2, 310 | help='observed probability for independent dropout') 311 | parser.add_argument('--obs-prob-max', type=float, default=None, 312 | help='max observed probability for independent ' 313 | 'dropout. Use fixed probability if unspecified.') 314 | 315 | parser.add_argument('--flow', type=int, default=2, 316 | help='number of IAF layers') 317 | parser.add_argument('--lr', type=float, default=1e-3, 318 | help='learning rate') 319 | parser.add_argument('--min-lr', type=float, default=-1, 320 | help='min learning rate for LR scheduler. ' 321 | '-1 to disable annealing') 322 | 323 | parser.add_argument('--arch', default='conv', 324 | help='network architecture. (options: fc, conv)') 325 | parser.add_argument('--epoch', type=int, default=2000, 326 | help='number of training epochs') 327 | parser.add_argument('--batch-size', type=int, default=128, 328 | help='batch size') 329 | parser.add_argument('--ae', type=float, default=.1, 330 | help='autoencoding regularization strength') 331 | parser.add_argument('--prefix', default='pbigan', 332 | help='prefix of output directory') 333 | parser.add_argument('--latent', type=int, default=128, 334 | help='dimension of latent variable') 335 | parser.add_argument('--aeloss', default='bce', 336 | help='autoencoding loss. 
' 337 | '(options: mse, bce, smooth_l1, l1)') 338 | 339 | args = parser.parse_args() 340 | 341 | train_pbigan(args) 342 | 343 | 344 | if __name__ == '__main__': 345 | main() 346 | -------------------------------------------------------------------------------- /image/mnist_pvae.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | from torch.distributions.normal import Normal 7 | from torch.distributions.categorical import Categorical 8 | from torch.utils.data import DataLoader 9 | import math 10 | import sys 11 | import logging 12 | from pathlib import Path 13 | from datetime import datetime 14 | import pprint 15 | import argparse 16 | from masked_mnist import IndepMaskedMNIST, BlockMaskedMNIST 17 | from mnist_decoder import ConvDecoder 18 | from mnist_encoder import ConvEncoder 19 | from utils import mkdir, make_scheduler 20 | from visualize import Visualizer 21 | 22 | 23 | use_cuda = torch.cuda.is_available() 24 | device = torch.device('cuda' if use_cuda else 'cpu') 25 | 26 | 27 | class PVAE(nn.Module): 28 | def __init__(self, encoder, decoder): 29 | super().__init__() 30 | self.encoder = encoder 31 | self.decoder = decoder 32 | 33 | def forward(self, x, mask, k_samples=5, kl_weight=1): 34 | z_T, log_q_z = self.encoder(x * mask, k_samples) 35 | 36 | pz = Normal(torch.zeros_like(z_T), torch.ones_like(z_T)) 37 | log_p_z = pz.log_prob(z_T).sum(-1) 38 | # kl_loss: log q(z|x) - log p(z) 39 | kl_loss = log_q_z - log_p_z 40 | 41 | # Reshape z to accommodate modules with strict input shape requirements 42 | # such as convolutional layers. 
43 | x_logit, x_recon = self.decoder(z_T.view(-1, *z_T.shape[2:])) 44 | flat_mask = mask.view(x.shape[0], -1) 45 | flat_logit = x_logit.view(*z_T.shape[:2], -1) * flat_mask 46 | flat_x = (x * mask).view(1, x.shape[0], -1).expand_as(flat_logit) 47 | # Bernoulli noise 48 | # recon_loss: -log p(x|z) 49 | recon_loss = (F.binary_cross_entropy_with_logits( 50 | flat_logit, flat_x, reduction='none') * flat_mask).sum(-1) 51 | 52 | # elbo = log p(x|z) + log p(z) - log q(z|x) 53 | elbo = -(recon_loss + kl_loss * kl_weight) 54 | 55 | # IWAE loss: -log E[p(x|z) p(z) / q(z|x)] 56 | # Here we ignore the constant shift of -log(k_samples) 57 | loss = -elbo.logsumexp(0).mean() 58 | 59 | x_recon = x_recon.view(-1, *x.shape) 60 | loss_breakdown = { 61 | 'loss': loss.item(), 62 | 'KL': kl_loss.mean().item(), 63 | 'recon': recon_loss.mean().item(), 64 | } 65 | return loss, z_T, x_recon, elbo, loss_breakdown 66 | 67 | def impute(self, x, mask, k_samples=10): 68 | self.eval() 69 | with torch.no_grad(): 70 | _, z, x_recon, elbo, _ = self(x, mask, k_samples) 71 | # sampling importance resampling 72 | is_idx = Categorical(logits=elbo.t()).sample() 73 | batch_idx = torch.arange(len(x)) 74 | z = z[is_idx, batch_idx] 75 | x_recon = x_recon[is_idx, batch_idx] 76 | self.train() 77 | return x_recon 78 | 79 | 80 | def train_pvae(args): 81 | torch.manual_seed(args.seed) 82 | 83 | if args.mask == 'indep': 84 | data = IndepMaskedMNIST(obs_prob=args.obs_prob, 85 | obs_prob_max=args.obs_prob_max) 86 | mask_str = f'{args.mask}_{args.obs_prob}_{args.obs_prob_max}' 87 | elif args.mask == 'block': 88 | data = BlockMaskedMNIST(block_len=args.block_len, 89 | block_len_max=args.block_len_max) 90 | mask_str = f'{args.mask}_{args.block_len}_{args.block_len_max}' 91 | 92 | data_loader = DataLoader(data, batch_size=args.batch_size, shuffle=True, 93 | drop_last=True) 94 | 95 | # Track training progress on the first (unshuffled) batch of training data 96 | test_loader = DataLoader(data, batch_size=args.batch_size,
drop_last=True) 97 | 98 | decoder = ConvDecoder(args.latent) 99 | encoder = ConvEncoder(args.latent, args.flow, logprob=True) 100 | pvae = PVAE(encoder, decoder).to(device) 101 | 102 | optimizer = optim.Adam(pvae.parameters(), lr=args.lr) 103 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, args.epoch) 104 | 105 | rand_z = torch.empty(args.batch_size, args.latent, device=device) 106 | 107 | path = '{}_{}_{}'.format( 108 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), mask_str) 109 | output_dir = Path('results') / 'mnist-pvae' / path 110 | mkdir(output_dir) 111 | print(output_dir) 112 | 113 | if args.save_interval > 0: 114 | model_dir = mkdir(output_dir / 'model') 115 | 116 | logging.basicConfig( 117 | level=logging.INFO, 118 | format='%(asctime)s %(message)s', 119 | datefmt='%Y-%m-%d %H:%M:%S', 120 | handlers=[ 121 | logging.FileHandler(output_dir / 'log.txt'), 122 | logging.StreamHandler(sys.stdout), 123 | ], 124 | ) 125 | 126 | with (output_dir / 'args.txt').open('w') as f: 127 | print(pprint.pformat(vars(args)), file=f) 128 | 129 | vis = Visualizer(output_dir) 130 | 131 | test_x, test_mask, index = next(iter(test_loader)) 132 | test_x = test_x.to(device) 133 | test_mask = test_mask.to(device).float() 134 | bbox = None 135 | if data.mask_loc is not None: 136 | bbox = [data.mask_loc[idx] for idx in index] 137 | 138 | kl_center = (args.kl_on + args.kl_off) / 2 139 | kl_scale = 12 / max(args.kl_on - args.kl_off, 1) 140 | 141 | for epoch in range(args.epoch): 142 | if epoch >= args.kl_on: 143 | kl_weight = 1 144 | elif epoch < args.kl_off: 145 | kl_weight = 0 146 | else: 147 | kl_weight = 1 / (1 + math.exp(-(epoch - kl_center) * kl_scale)) 148 | loss_breakdown = defaultdict(float) 149 | for x, mask, _ in data_loader: 150 | x = x.to(device) 151 | mask = mask.to(device).float() 152 | 153 | optimizer.zero_grad() 154 | loss, _, _, _, loss_info = pvae( 155 | x, mask, args.k, kl_weight=kl_weight) 156 | loss.backward() 157 | optimizer.step() 158 | for
name, val in loss_info.items(): 159 | loss_breakdown[name] += val 160 | 161 | if scheduler: 162 | scheduler.step() 163 | 164 | vis.plot_loss(epoch, loss_breakdown) 165 | 166 | if epoch % args.plot_interval == 0: 167 | x_recon = pvae.impute(test_x, test_mask, args.k) 168 | with torch.no_grad(): 169 | pvae.eval() 170 | rand_z.normal_() 171 | _, x_gen = decoder(rand_z) 172 | pvae.train() 173 | vis.plot(epoch, test_x, test_mask, bbox, x_recon, x_gen) 174 | 175 | model_dict = { 176 | 'pvae': pvae.state_dict(), 177 | 'history': vis.history, 178 | 'epoch': epoch, 179 | 'args': args, 180 | } 181 | torch.save(model_dict, str(output_dir / 'model.pth')) 182 | if args.save_interval > 0 and (epoch + 1) % args.save_interval == 0: 183 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth')) 184 | 185 | print(output_dir) 186 | 187 | 188 | def main(): 189 | parser = argparse.ArgumentParser() 190 | 191 | parser.add_argument('--seed', type=int, default=3, 192 | help='random seed') 193 | # training options 194 | parser.add_argument('--plot-interval', type=int, default=50, 195 | help='plot interval. 0 to disable plotting.') 196 | parser.add_argument('--save-interval', type=int, default=50, 197 | help='interval to save models. 0 to disable saving.') 198 | parser.add_argument('--mask', default='block', 199 | help='missing data mask. (options: block, indep)') 200 | # option for block: set to 0 for variable size 201 | parser.add_argument('--block-len', type=int, default=12, 202 | help='size of observed block. ' 203 | 'Set to 0 to use variable size') 204 | parser.add_argument('--block-len-max', type=int, default=None, 205 | help='max size of observed block. 
' 206 | 'Use fixed-size observed block if unspecified.') 207 | # option for indep: 208 | parser.add_argument('--obs-prob', type=float, default=.2, 209 | help='observed probability for independent dropout') 210 | parser.add_argument('--obs-prob-max', type=float, default=None, 211 | help='max observed probability for independent ' 212 | 'dropout. Use fixed probability if unspecified.') 213 | 214 | parser.add_argument('--flow', type=int, default=2, 215 | help='number of IAF layers') 216 | parser.add_argument('--lr', type=float, default=1e-3, 217 | help='learning rate') 218 | parser.add_argument('--min-lr', type=float, default=-1, 219 | help='min learning rate for LR scheduler. ' 220 | '-1 to disable annealing') 221 | 222 | parser.add_argument('--epoch', type=int, default=4000, 223 | help='number of training epochs') 224 | parser.add_argument('--batch-size', type=int, default=128, 225 | help='batch size') 226 | parser.add_argument('--k', type=int, default=5, 227 | help='number of importance weights') 228 | parser.add_argument('--prefix', default='pvae', 229 | help='prefix of output directory') 230 | parser.add_argument('--latent', type=int, default=128, 231 | help='dimension of latent variable') 232 | parser.add_argument('--kl-off', type=int, default=200, 233 | help='epoch at which to start ramping the KL weight up from zero') 234 | # set --kl-on to 0 to use constant kl_weight = 1 235 | parser.add_argument('--kl-on', type=int, default=0, 236 | help='epoch from which the KL weight is fixed at 1') 237 | 238 | args = parser.parse_args() 239 | 240 | train_pvae(args) 241 | 242 | 243 | if __name__ == '__main__': 244 | main() 245 | -------------------------------------------------------------------------------- /image/utils.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | 3 | 4 | def count_parameters(model): 5 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 6 | 7 | 8 | def mkdir(path): 9 |
path.mkdir(parents=True, exist_ok=True) 10 | return path 11 | 12 | 13 | def make_scheduler(optimizer, lr, min_lr, epochs, steps=10): 14 | if min_lr < 0: 15 | return None 16 | step_size = epochs // steps 17 | gamma = (min_lr / lr)**(1 / steps) 18 | return optim.lr_scheduler.StepLR( 19 | optimizer, step_size=step_size, gamma=gamma) 20 | -------------------------------------------------------------------------------- /image/visualize.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import matplotlib.gridspec as gridspec 3 | from matplotlib.patches import Rectangle 4 | import seaborn as sns 5 | import logging 6 | import sys 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from utils import mkdir 10 | 11 | 12 | class Visualizer(object): 13 | def __init__(self, 14 | output_dir, 15 | loss_xlim=None, 16 | loss_ylim=None): 17 | self.output_dir = Path(output_dir) 18 | self.recons_dir = mkdir(self.output_dir / 'recons') 19 | self.gen_dir = mkdir(self.output_dir / 'gen') 20 | self.loss_xlim = loss_xlim 21 | self.loss_ylim = loss_ylim 22 | sns.set() 23 | self.rows, self.cols = 8, 16 24 | self.at_start = True 25 | 26 | self.history = defaultdict(list) 27 | logging.basicConfig( 28 | level=logging.INFO, 29 | format='%(asctime)s %(message)s', 30 | datefmt='%Y-%m-%d %H:%M:%S', 31 | handlers=[ 32 | logging.FileHandler(output_dir / 'log.txt'), 33 | logging.StreamHandler(sys.stdout), 34 | ], 35 | ) 36 | self.print_header = True 37 | 38 | def plot_subgrid(self, images, bbox=None, filename=None): 39 | rows, cols = 8, 16 40 | scale = .75 41 | fig, ax = plt.subplots(figsize=(cols * scale, rows * scale)) 42 | ax.set_axis_off() 43 | ax.set_xticks([]) 44 | ax.set_yticks([]) 45 | 46 | inner_grid = gridspec.GridSpec(rows, cols, fig, wspace=.05, hspace=.05, 47 | left=0, right=1, top=1, bottom=0) 48 | 49 | images = images[:(rows * cols)].cpu().numpy() 50 | if images.shape[1] == 1: # single channel 
51 | images = images.squeeze(1) 52 | cmap = 'binary_r' 53 | else: # 3 channels 54 | images = images.transpose((0, 2, 3, 1)) 55 | cmap = None 56 | 57 | for i, image in enumerate(images): 58 | ax = plt.Subplot(fig, inner_grid[i]) 59 | ax.set_axis_off() 60 | ax.set_xticks([]) 61 | ax.set_yticks([]) 62 | ax.set_aspect('equal') 63 | ax.imshow(image, interpolation='none', aspect='equal', 64 | cmap=cmap, vmin=0, vmax=1) 65 | 66 | if bbox is not None: 67 | d0, d1, d0_len, d1_len = bbox[i] 68 | ax.add_patch(Rectangle( 69 | (d1 - .5, d0 - .5), d1_len, d0_len, lw=1, 70 | edgecolor='red', fill=False)) 71 | fig.add_subplot(ax) 72 | 73 | if filename is not None: 74 | plt.savefig(str(filename)) 75 | plt.close(fig) 76 | 77 | def plot(self, epoch, x, mask, bbox, x_recon, x_gen): 78 | if self.at_start: 79 | self.plot_subgrid(x * mask + .5 * (1 - mask), bbox, 80 | self.recons_dir / f'groundtruth.png') 81 | self.at_start = False 82 | self.plot_subgrid(x * mask + x_recon * (1 - mask), bbox, 83 | self.recons_dir / f'{epoch:04d}.png') 84 | self.plot_subgrid(x_gen, None, self.gen_dir / f'{epoch:04d}.png') 85 | 86 | def plot_loss(self, epoch, losses): 87 | for name, val in losses.items(): 88 | self.history[name].append(val) 89 | 90 | fig, ax_trace = plt.subplots(figsize=(6, 4)) 91 | ax_trace.set_ylabel('loss') 92 | ax_trace.set_xlabel('epochs') 93 | if self.loss_xlim is not None: 94 | ax_trace.set_xlim(self.loss_xlim) 95 | if self.loss_ylim is not None: 96 | ax_trace.set_ylim(self.loss_ylim) 97 | for label, loss in self.history.items(): 98 | ax_trace.plot(loss, '-', label=label) 99 | if len(self.history) > 1: 100 | ax_trace.legend(ncol=len(self.history), loc='upper center') 101 | plt.tight_layout() 102 | plt.savefig(str(self.output_dir / 'loss.png'), dpi=300) 103 | plt.close(fig) 104 | 105 | if self.print_header: 106 | logging.info(' ' * 7 + ' '.join( 107 | f'{key:>12}' for key in sorted(losses))) 108 | self.print_header = False 109 | logging.info(f'[{epoch:4}] ' + ' '.join( 110 | 
f'{val:12.4f}' for _, val in sorted(losses.items()))) 111 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | numpy==1.17.0 3 | Pillow>=7.1.0 4 | scikit-learn==0.21.3 5 | scipy==1.4.1 6 | seaborn==0.9.0 7 | torch==1.1.0 8 | torch-spline-conv==1.1.0 9 | torchvision==0.3.0 10 | -------------------------------------------------------------------------------- /time-series/.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-language=Python 2 | -------------------------------------------------------------------------------- /time-series/ema.py: -------------------------------------------------------------------------------- 1 | class EMA: 2 | def __init__(self, model, decay, start_iter=0): 3 | self.model = model 4 | self.beta = 1 - decay 5 | self.start_iter = start_iter 6 | self.iter = 0 7 | self.shadow = None 8 | 9 | def state_dict(self): 10 | """Returns the EMA state as a dictionary for serialization.""" 11 | # NOTE: skip saving `model` 12 | return { 13 | 'beta': self.beta, 14 | 'start_iter': self.start_iter, 15 | 'iter': self.iter, 16 | 'shadow': self.shadow, 17 | } 18 | 19 | def update(self): 20 | if self.iter < self.start_iter: 21 | self.iter += 1 22 | else: 23 | if self.shadow is None: 24 | self.shadow = {} 25 | for name, param in self.model.named_parameters(): 26 | if param.requires_grad: 27 | self.shadow[name] = param.data.clone() 28 | for name, param in self.model.named_parameters(): 29 | if param.requires_grad: 30 | # p = p - (1 - decay) * (p - theta) 31 | # = decay * p + (1 - decay) * theta 32 | self.shadow[name].sub_( 33 | self.beta * (self.shadow[name] - param.data)) 34 | 35 | def apply(self): 36 | if self.shadow is None: 37 | return 38 | for name, param in self.model.named_parameters(): 39 | if param.requires_grad: 40 | self.shadow[name],
param.data = param.data, self.shadow[name] 41 | 42 | def restore(self): 43 | self.apply() 44 | -------------------------------------------------------------------------------- /time-series/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import roc_auc_score 4 | 5 | 6 | class Evaluator: 7 | def __init__(self, model, val_loader, test_loader, 8 | log_dir, eval_args={}): 9 | self.model = model 10 | self.val_loader = val_loader 11 | self.test_loader = test_loader 12 | self.log_dir = log_dir 13 | self.eval_args = eval_args 14 | 15 | self.test_auc_log, self.val_auc_log = [], [] 16 | self.best_auc, self.best_val_auc = [-float('inf')] * 2 17 | 18 | def evaluate(self, epoch): 19 | self.model.eval() 20 | with torch.no_grad(): 21 | val_auc = self.compute_auc(self.val_loader) 22 | print('val AUC:', val_auc) 23 | self.val_auc_log.append((epoch, val_auc)) 24 | 25 | test_auc = self.compute_auc(self.test_loader) 26 | print('test AUC:', test_auc) 27 | self.test_auc_log.append((epoch, test_auc)) 28 | self.model.train() 29 | 30 | if val_auc > self.best_auc: 31 | self.best_auc = val_auc 32 | torch.save({'model': self.model.state_dict()}, 33 | str(self.log_dir / 'best-model.pth')) 34 | 35 | with (self.log_dir / 'val_auc.txt').open('a') as f: 36 | print(epoch, val_auc, file=f) 37 | with (self.log_dir / 'test_auc.txt').open('a') as f: 38 | print(epoch, test_auc, file=f) 39 | 40 | def compute_auc(self, data_loader): 41 | y_true, y_score = [], [] 42 | for (val, idx, mask, y, _, cconv_graph) in data_loader: 43 | score = self.model.predict( 44 | val, idx, mask, cconv_graph, **self.eval_args) 45 | y_score.append(score.cpu().numpy()) 46 | y_true.append(y.cpu().numpy()) 47 | 48 | y_true = np.concatenate(y_true) 49 | y_score = np.concatenate(y_score) 50 | return roc_auc_score(y_true, y_score) 51 | -------------------------------------------------------------------------------- 
/time-series/figures/pbigan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/time-series/figures/pbigan.png -------------------------------------------------------------------------------- /time-series/figures/pvae.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/steveli/partial-encoder-decoder/402abdb97b1cdfcca2a30b05507cd97be560ee75/time-series/figures/pvae.png -------------------------------------------------------------------------------- /time-series/flow.py: -------------------------------------------------------------------------------- 1 | """Code adapted from https://github.com/altosaar/variational-autoencoder""" 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class InverseAutoregressiveFlow(nn.Module): 9 | """Inverse Autoregressive Flows with LSTM-type update. One block. 
10 | 11 | Eq 11-14 of https://arxiv.org/abs/1606.04934 12 | """ 13 | def __init__(self, num_input, num_hidden, num_context): 14 | super().__init__() 15 | self.made = MADE(num_input=num_input, num_output=num_input * 2, 16 | num_hidden=num_hidden, num_context=num_context) 17 | # init such that sigmoid(s) is close to 1 for stability 18 | self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) 19 | self.sigmoid = nn.Sigmoid() 20 | self.log_sigmoid = nn.LogSigmoid() 21 | 22 | def forward(self, input, context=None): 23 | m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) 24 | s = s + self.sigmoid_arg_bias 25 | sigmoid = self.sigmoid(s) 26 | z = sigmoid * input + (1 - sigmoid) * m 27 | return z, -self.log_sigmoid(s) 28 | 29 | 30 | class FlowSequential(nn.Sequential): 31 | """Forward pass.""" 32 | 33 | def forward(self, input, context=None): 34 | total_log_prob = torch.zeros_like(input) 35 | for block in self._modules.values(): 36 | input, log_prob = block(input, context) 37 | total_log_prob += log_prob 38 | return input, total_log_prob 39 | 40 | 41 | class MaskedLinear(nn.Module): 42 | """Linear layer with some input-output connections masked.""" 43 | def __init__(self, in_features, out_features, mask, context_features=None, 44 | bias=True): 45 | super().__init__() 46 | self.linear = nn.Linear(in_features, out_features, bias) 47 | self.register_buffer("mask", mask) 48 | if context_features is not None: 49 | self.cond_linear = nn.Linear(context_features, out_features, 50 | bias=False) 51 | 52 | def forward(self, input, context=None): 53 | output = F.linear(input, self.mask * self.linear.weight, 54 | self.linear.bias) 55 | if context is None: 56 | return output 57 | else: 58 | return output + self.cond_linear(context) 59 | 60 | 61 | class MADE(nn.Module): 62 | """Implements MADE: Masked Autoencoder for Distribution Estimation. 
63 | 64 | Follows https://arxiv.org/abs/1502.03509 65 | 66 | This is used to build MAF: 67 | Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). 68 | """ 69 | def __init__(self, num_input, num_output, num_hidden, num_context): 70 | super().__init__() 71 | # m corresponds to m(k), the maximum degree of a node in the MADE paper 72 | self._m = [] 73 | self._masks = [] 74 | self._build_masks(num_input, num_output, num_hidden, num_layers=3) 75 | self._check_masks() 76 | modules = [] 77 | self.input_context_net = MaskedLinear( 78 | num_input, num_hidden, self._masks[0], num_context) 79 | modules.append(nn.ReLU()) 80 | modules.append(MaskedLinear( 81 | num_hidden, num_hidden, self._masks[1], context_features=None)) 82 | modules.append(nn.ReLU()) 83 | modules.append(MaskedLinear( 84 | num_hidden, num_output, self._masks[2], context_features=None)) 85 | self.net = nn.Sequential(*modules) 86 | 87 | def _build_masks(self, num_input, num_output, num_hidden, num_layers): 88 | """Build the masks according to Eq 12 and 13 in the MADE paper.""" 89 | rng = np.random.RandomState(0) 90 | # assign input units a number between 1 and D 91 | self._m.append(np.arange(1, num_input + 1)) 92 | for i in range(1, num_layers + 1): 93 | # randomly assign maximum number of input nodes to connect to 94 | if i == num_layers: 95 | # assign output layer units a number between 1 and D 96 | m = np.arange(1, num_input + 1) 97 | assert num_output % num_input == 0, ( 98 | "num_output must be multiple of num_input") 99 | self._m.append(np.hstack( 100 | [m for _ in range(num_output // num_input)])) 101 | else: 102 | # assign hidden layer units a number between 1 and D-1 103 | self._m.append(rng.randint(1, num_input, size=num_hidden)) 104 | # self._m.append( 105 | # np.arange(1, num_hidden + 1) % (num_input - 1) + 1) 106 | if i == num_layers: 107 | mask = self._m[i][None, :] > self._m[i - 1][:, None] 108 | else: 109 | # input to hidden & hidden to hidden 110 | mask = self._m[i][None, :] >= 
self._m[i - 1][:, None] 111 | # need to transpose for torch linear layer. 112 | # shape (num_output, num_input) 113 | self._masks.append(torch.from_numpy(mask.astype(np.float32).T)) 114 | 115 | def _check_masks(self): 116 | """Check that the connectivity matrix between layers is lower 117 | triangular.""" 118 | # (num_input, num_hidden) 119 | prev = self._masks[0].t() 120 | for i in range(1, len(self._masks)): 121 | # num_hidden is second axis 122 | prev = prev @ self._masks[i].t() 123 | final = prev.numpy() 124 | num_input = self._masks[0].shape[1] 125 | num_output = self._masks[-1].shape[0] 126 | assert final.shape == (num_input, num_output) 127 | if num_output == num_input: 128 | assert np.triu(final).all() == 0 129 | else: 130 | for submat in np.split( 131 | final, indices_or_sections=num_output // num_input, 132 | axis=1): 133 | assert np.triu(submat).all() == 0 134 | 135 | def forward(self, input, context=None): 136 | # first hidden layer receives input and context 137 | hidden = self.input_context_net(input, context) 138 | # rest of the network is conditioned on both input and context 139 | return self.net(hidden) 140 | 141 | 142 | class Reverse(nn.Module): 143 | """ An implementation of a reversing layer from 144 | Density estimation using Real NVP 145 | (https://arxiv.org/abs/1605.08803). 
146 | 147 | From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py 148 | """ 149 | 150 | def __init__(self, num_input): 151 | super(Reverse, self).__init__() 152 | self.perm = np.array(np.arange(0, num_input)[::-1]) 153 | self.inv_perm = np.argsort(self.perm) 154 | 155 | def forward(self, inputs, context=None, mode='forward'): 156 | if mode == "forward": 157 | return inputs[:, :, self.perm], torch.zeros_like(inputs) 158 | elif mode == "inverse": 159 | return inputs[:, :, self.inv_perm], torch.zeros_like(inputs) 160 | else: 161 | raise ValueError("Mode must be one of {forward, inverse}.") 162 | -------------------------------------------------------------------------------- /time-series/gen_toy_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.distributions.exponential import Exponential 4 | import math 5 | 6 | 7 | class HomogeneousPoissonProcess: 8 | def __init__(self, rate=1): 9 | self.rate = rate 10 | self.exp = Exponential(rate) 11 | 12 | def sample(self, size, max_seq_len, max_time=math.inf): 13 | gaps = self.exp.sample((size, max_seq_len)) 14 | times = torch.cumsum(gaps, dim=1) 15 | masks = (times <= max_time).float() 16 | return times, masks 17 | 18 | 19 | def gen_data(n_samples=10000, seq_len=200, max_time=1, poisson_rate=50, 20 | obs_span_rate=.25, save_file=None): 21 | """Generates a 3-channel synthetic dataset. 22 | 23 | The observations fall within a window of size (max_time * obs_span_rate) 24 | placed at a random position (independently per channel) within [0, max_time]. 25 | 26 | Args: 27 | n_samples: 28 | Number of data cases. 29 | seq_len: 30 | Maximum number of observations in a channel. 31 | max_time: 32 | Length of time interval [0, max_time]. 33 | poisson_rate: 34 | Rate of homogeneous Poisson process. 35 | obs_span_rate: 36 | Fraction of the time span [0, max_time] covered by the 37 | contiguous window that the observations are restricted to.
38 | save_file: 39 | File name that the generated data is saved to. 40 | """ 41 | n_channels = 3 42 | time_unif = np.linspace(0, max_time, seq_len) 43 | time_unif_3ch = np.broadcast_to(time_unif, (n_channels, seq_len)) 44 | data_unif = np.empty((n_samples, n_channels, seq_len)) 45 | sparse_data, sparse_time, sparse_mask = [ 46 | np.empty((n_samples, n_channels, seq_len)) for _ in range(3)] 47 | tpp = HomogeneousPoissonProcess(rate=poisson_rate) 48 | 49 | def gen_time_series(offset1, offset2, t): 50 | t1 = t[0] + offset1 51 | t2 = t[2] + offset2 52 | t1_shift = t[1] + offset1 + 20 53 | data = np.empty((3, seq_len)) 54 | data[0] = np.sin(t1 * 20 + np.sin(t1 * 20)) * .8 55 | data[1] = -np.sin(t1_shift * 20 + np.sin(t1_shift * 20)) * .5 56 | data[2] = np.sin(t2 * 12) 57 | return data 58 | 59 | for i in range(n_samples): 60 | offset1 = np.random.normal(0, 10) 61 | offset2 = np.random.uniform(0, 10) 62 | 63 | # Noise-free evenly-sampled time series 64 | data_unif[i] = gen_time_series(offset1, offset2, time_unif_3ch) 65 | 66 | # Generate observations between [0, obs_span_rate]. 67 | times, masks = tpp.sample(3, seq_len, max_time=obs_span_rate) 68 | # Add independent random offset Unif(0, 1 - obs_span_rate) to each 69 | # channel so that all the observations will still be within [0, 1]. 70 | times += torch.rand((3, 1)) * (1 - obs_span_rate) 71 | # Scale time span from [0, 1] to [0, max_time]. 72 | times *= max_time 73 | # Set time entries corresponding to unobserved samples to time 0. 
74 | sparse_time[i] = times * masks 75 | sparse_mask[i] = masks 76 | sparse_data[i] = gen_time_series(offset1, offset2, times) 77 | 78 | # Add a small independent Gaussian noise to each channel 79 | sparse_data += np.random.normal(0, .01, sparse_data.shape) 80 | 81 | # Pack the data to minimize the padded entries 82 | compact_len = sparse_mask.astype(int).sum(axis=2).max() 83 | compact_data, compact_time, compact_mask = [ 84 | np.zeros((n_samples, 3, compact_len)) for _ in range(3)] 85 | for i in range(n_samples): 86 | for j in range(3): 87 | idx = sparse_mask[i, j] == 1 88 | n_obs = idx.sum() 89 | compact_data[i, j, :n_obs] = sparse_data[i, j, idx] 90 | compact_time[i, j, :n_obs] = sparse_time[i, j, idx] 91 | compact_mask[i, j, :n_obs] = sparse_mask[i, j, idx] 92 | 93 | if save_file: 94 | np.savez_compressed( 95 | save_file, 96 | time=compact_time, 97 | data=compact_data, 98 | mask=compact_mask, 99 | data_unif=data_unif, 100 | time_unif=time_unif, 101 | ) 102 | 103 | return compact_data, compact_time, compact_mask, data_unif, time_unif 104 | 105 | 106 | def main(): 107 | gen_data(n_samples=10000, seq_len=200, max_time=1, poisson_rate=50, 108 | obs_span_rate=.25, save_file='toy-data.npz') 109 | 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /time-series/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResLinearBlock(nn.Module): 7 | def __init__(self, in_size, out_size): 8 | super().__init__() 9 | self.linear = nn.Sequential( 10 | nn.Linear(in_size, out_size), 11 | nn.Dropout(), 12 | nn.LeakyReLU(.2), 13 | nn.Linear(out_size, out_size), 14 | nn.Dropout(), 15 | nn.LeakyReLU(.2), 16 | ) 17 | 18 | self.skip = nn.Sequential( 19 | nn.Linear(in_size, out_size), 20 | nn.LeakyReLU(.2), 21 | ) 22 | 23 | def forward(self, x): 24 | return 
self.linear(x) + self.skip(x) 25 | 26 | 27 | class Classifier(nn.Module): 28 | def __init__(self, in_size, layers=1): 29 | super().__init__() 30 | blocks = [] 31 | for _ in range(layers): 32 | blocks.append(ResLinearBlock(in_size, in_size)) 33 | # Final layer: plain linear map to a single logit 34 | blocks.append(nn.Linear(in_size, 1)) 35 | self.res_linear = nn.Sequential(*blocks) 36 | 37 | def forward(self, x): 38 | return self.res_linear(x) 39 | 40 | 41 | class GBlock(nn.Module): 42 | def __init__(self, in_channels, out_channels): 43 | super().__init__() 44 | self.activation = nn.ReLU(inplace=False) 45 | self.bn1 = nn.BatchNorm1d(in_channels) 46 | self.bn2 = nn.BatchNorm1d(out_channels) 47 | self.conv1 = nn.Conv1d( 48 | in_channels, out_channels, kernel_size=3, padding=1) 49 | self.conv2 = nn.Conv1d( 50 | out_channels, out_channels, kernel_size=3, padding=1) 51 | self.convx = None 52 | if in_channels != out_channels: 53 | self.convx = nn.Conv1d( 54 | in_channels, out_channels, kernel_size=1) 55 | 56 | def forward(self, x): 57 | h = self.activation(self.bn1(x)) 58 | h = F.interpolate(h, scale_factor=2) 59 | x = F.interpolate(x, scale_factor=2) 60 | if self.convx: 61 | x = self.convx(x) 62 | h = self.conv1(h) 63 | h = self.activation(self.bn2(h)) 64 | h = self.conv2(h) 65 | return h + x 66 | 67 | 68 | class GridDecoder(nn.Module): 69 | def __init__(self, dim_z, channels, start_len=16, squash=None): 70 | super().__init__() 71 | self.activation = nn.ReLU(inplace=False) 72 | self.start_len = start_len 73 | self.linear = nn.Linear(dim_z, channels[0] * start_len) 74 | self.blocks = nn.Sequential( 75 | *[GBlock(in_channels, channels[c + 1]) 76 | for c, in_channels in enumerate(channels[:-2])]) 77 | self.output = nn.Sequential( 78 | nn.BatchNorm1d(channels[-2]), 79 | self.activation, 80 | nn.Conv1d(channels[-2], channels[-1], kernel_size=3, padding=1), 81 | ) 82 | self.squash = squash 83 | 84 | def forward(self, z): 85 | h = self.linear(z) 86 | h = h.view(h.shape[0], -1,
self.start_len) 87 | h = self.blocks(h) 88 | h = self.output(h) 89 | if self.squash: 90 | h = self.squash(h) 91 | return h 92 | 93 | 94 | class DBlock(nn.Module): 95 | def __init__(self, in_channels, out_channels, downsample=True): 96 | super().__init__() 97 | self.activation = nn.ReLU(inplace=False) 98 | self.conv1 = nn.Conv1d( 99 | in_channels, out_channels, kernel_size=3, padding=1) 100 | self.conv2 = nn.Conv1d( 101 | out_channels, out_channels, kernel_size=3, padding=1) 102 | self.convx = None 103 | if in_channels != out_channels: 104 | self.convx = nn.Conv1d(in_channels, out_channels, kernel_size=1) 105 | self.downsample = None 106 | if downsample: 107 | self.downsample = nn.AvgPool1d(2) 108 | 109 | def shortcut(self, x): 110 | if self.convx: 111 | x = self.convx(x) 112 | if self.downsample: 113 | x = self.downsample(x) 114 | return x 115 | 116 | def forward(self, x): 117 | # pre-activation 118 | h = self.activation(x) 119 | h = self.conv1(h) 120 | h = self.conv2(self.activation(h)) 121 | if self.downsample: 122 | h = self.downsample(h) 123 | return h + self.shortcut(x) 124 | 125 | 126 | class GridEncoder(nn.Module): 127 | def __init__(self, channels, out_dim=1): 128 | super().__init__() 129 | self.activation = nn.ReLU(inplace=False) 130 | self.blocks = nn.Sequential( 131 | *[DBlock(in_channels, out_channels) 132 | for in_channels, out_channels 133 | in zip(channels[:-1], channels[1:])]) 134 | self.linear = nn.Linear(channels[-1], out_dim) 135 | 136 | def forward(self, x): 137 | h = x 138 | h = self.blocks(h) 139 | h = self.activation(h).sum(2) 140 | return self.linear(h) 141 | 142 | 143 | class Decoder(nn.Module): 144 | def __init__(self, grid_decoder, max_time=5, kernel_bw=None, dec_ref=128): 145 | super().__init__() 146 | if kernel_bw is None: 147 | self.kernel_bw = max_time / dec_ref * 3 148 | else: 149 | self.kernel_bw = kernel_bw 150 | # ref_times are the assigned time stamps for the evenly-spaced 151 | # generated sequences by conv1d. 
self.register_buffer('ref_times', torch.linspace(0, max_time, dec_ref)) 153 | self.ref_times = self.ref_times[:, None] 154 | self.grid_decoder = grid_decoder 155 | 156 | def forward(self, code, time, mask): 157 | """ 158 | Args: 159 | code: shape (batch_size, latent_size) 160 | time: shape (batch_size, channels, max_seq_len) 161 | mask: shape (batch_size, channels, max_seq_len) 162 | 163 | Returns: 164 | interpolated tensor of shape (batch_size, channels, max_seq_len) 165 | """ 166 | # shape of x: (batch_size, n_channels, dec_ref) 167 | x = self.grid_decoder(code) 168 | 169 | # t_diff shape: (batch_size, n_channels, dec_ref, max_seq_len) 170 | t_diff = time[:, :, None] - self.ref_times 171 | 172 | # Epanechnikov quadratic kernel: 173 | # K_\lambda(x_0, x) = relu(3/4 * (1 - (|x_0 - x| / \lambda)^2)) 174 | # shape of w: (batch_size, n_channels, dec_ref, max_seq_len) 175 | w = F.relu((1 - (t_diff / self.kernel_bw)**2) * .75) 176 | # Optional guard against division by zero (currently disabled): 177 | # normalizer = torch.clamp(w.sum(2), min=1e-6) 178 | # return ((x[:, :, :, None] * w).sum(2) * mask) / normalizer 179 | ks_x = ((x[:, :, :, None] * w).sum(2) * mask) / w.sum(2) 180 | return ks_x 181 | 182 | 183 | def gan_loss(real, fake, real_target, fake_target): 184 | real_score = sum(F.binary_cross_entropy_with_logits( 185 | r, r.new_tensor(real_target).expand_as(r)) for r in real) 186 | fake_score = sum(F.binary_cross_entropy_with_logits( 187 | f, f.new_tensor(fake_target).expand_as(f)) for f in fake) 188 | return real_score + fake_score 189 | -------------------------------------------------------------------------------- /time-series/mimic3_pbigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import spectral_norm 5 | import torch.optim as optim 6 | from torch.distributions.normal import Normal 7 | from torch.utils.data import DataLoader 8 | import numpy as np 9 | from
datetime import datetime 10 | import time 11 | from pathlib import Path 12 | import argparse 13 | from collections import defaultdict 14 | from spline_cconv import ContinuousConv1D 15 | import time_series 16 | from ema import EMA 17 | from utils import count_parameters, mkdir, make_scheduler 18 | from mmd import mmd 19 | from tracker import Tracker 20 | from evaluate import Evaluator 21 | from sn_layers import ( 22 | InvertibleLinearResNet, 23 | Classifier, 24 | GridDecoder, 25 | GridEncoder, 26 | Decoder, 27 | gan_loss, 28 | ) 29 | 30 | 31 | use_cuda = torch.cuda.is_available() 32 | device = torch.device('cuda' if use_cuda else 'cpu') 33 | 34 | 35 | class Encoder(nn.Module): 36 | def __init__(self, cconv, latent_size, channels, trans_layers=1): 37 | super().__init__() 38 | self.cconv = cconv 39 | self.grid_encoder = GridEncoder(channels, latent_size * 2) 40 | self.trans = InvertibleLinearResNet( 41 | latent_size, latent_size, trans_layers).to(device) 42 | 43 | def forward(self, cconv_graph, batch_size): 44 | x = self.cconv(*cconv_graph, batch_size) 45 | mu, logvar = self.grid_encoder(x).chunk(2, dim=1) 46 | if self.training: 47 | std = F.softplus(logvar) 48 | qz_x = Normal(mu, std) 49 | z_0 = qz_x.rsample() 50 | z_T = self.trans(z_0) 51 | else: 52 | z_T = self.trans(mu) 53 | return z_T 54 | 55 | 56 | class ConvCritic(nn.Module): 57 | def __init__(self, cconv, latent_size, channels, embed_size=32): 58 | super().__init__() 59 | 60 | self.cconv = cconv 61 | self.grid_critic = GridEncoder(channels, embed_size) 62 | 63 | self.z_dis = nn.Sequential( 64 | spectral_norm(nn.Linear(latent_size, embed_size)), 65 | nn.LeakyReLU(0.2), 66 | spectral_norm(nn.Linear(embed_size, embed_size)), 67 | ) 68 | 69 | self.x_linear = spectral_norm(nn.Linear(embed_size, 1)) 70 | 71 | self.xz_dis = nn.Sequential( 72 | spectral_norm(nn.Linear(embed_size * 2, embed_size)), 73 | nn.LeakyReLU(0.2), 74 | spectral_norm(nn.Linear(embed_size, 1)), 75 | ) 76 | 77 | def forward(self, cconv_graph, 
batch_size, z): 78 | x = self.cconv(*cconv_graph, batch_size) 79 | x = self.grid_critic(x) 80 | z = self.z_dis(z) 81 | xz = torch.cat((x, z), 1) 82 | xz = self.xz_dis(xz) 83 | x_out = self.x_linear(x).view(-1) 84 | xz_out = xz.view(-1) 85 | return xz_out, x_out 86 | 87 | 88 | class PBiGAN(nn.Module): 89 | def __init__(self, encoder, decoder, classifier, ae_loss='mse'): 90 | super().__init__() 91 | self.encoder = encoder 92 | self.decoder = decoder 93 | self.classifier = classifier 94 | self.ae_loss = ae_loss 95 | 96 | def forward(self, data, time, mask, y, cconv_graph, time_t, mask_t): 97 | batch_size = len(data) 98 | z_T = self.encoder(cconv_graph, batch_size) 99 | 100 | z_gen = torch.empty_like(z_T).normal_() 101 | x_gen = self.decoder(z_gen, time_t, mask_t) 102 | 103 | x_recon = self.decoder(z_T, time, mask) 104 | 105 | if self.ae_loss == 'mse': 106 | ae_loss = F.mse_loss(x_recon, data, reduction='none') * mask 107 | elif self.ae_loss == 'smooth_l1': 108 | ae_loss = F.smooth_l1_loss(x_recon, data, reduction='none') * mask 109 | 110 | ae_loss = ae_loss.sum((-1, -2)) 111 | 112 | y_logit = self.classifier(z_T).view(-1) 113 | 114 | # cls_loss: -log p(y|z) 115 | cls_loss = F.binary_cross_entropy_with_logits( 116 | y_logit, y.expand_as(y_logit), reduction='none') 117 | 118 | return z_T, x_recon, z_gen, x_gen, ae_loss.mean(), cls_loss.mean() 119 | 120 | def predict(self, data, time, mask, cconv_graph): 121 | batch_size = len(data) 122 | z_T = self.encoder(cconv_graph, batch_size) 123 | y_logit = self.classifier(z_T).view(-1) 124 | return y_logit 125 | 126 | 127 | def main(): 128 | parser = argparse.ArgumentParser() 129 | 130 | parser.add_argument('--data', default='mimic3.npz', 131 | help='data file') 132 | parser.add_argument('--seed', type=int, default=None, 133 | help='random seed. 
Randomly set if not specified.') 134 | 135 | # training options 136 | parser.add_argument('--nz', type=int, default=32, 137 | help='dimension of latent variable') 138 | parser.add_argument('--epoch', type=int, default=500, 139 | help='number of training epochs') 140 | parser.add_argument('--batch-size', type=int, default=64, 141 | help='batch size') 142 | # Use smaller test batch size to accommodate more importance samples 143 | parser.add_argument('--test-batch-size', type=int, default=32, 144 | help='batch size for validation and test set') 145 | parser.add_argument('--lr', type=float, default=2e-4, 146 | help='encoder/decoder learning rate') 147 | parser.add_argument('--dis-lr', type=float, default=3e-4, 148 | help='discriminator learning rate') 149 | parser.add_argument('--min-lr', type=float, default=1e-4, 150 | help='min encoder/decoder learning rate for LR ' 151 | 'scheduler. -1 to disable annealing') 152 | parser.add_argument('--min-dis-lr', type=float, default=1.5e-4, 153 | help='min discriminator learning rate for LR ' 154 | 'scheduler. -1 to disable annealing') 155 | parser.add_argument('--wd', type=float, default=1e-4, 156 | help='weight decay') 157 | parser.add_argument('--overlap', type=float, default=.5, 158 | help='kernel overlap') 159 | parser.add_argument('--cls', type=float, default=1, 160 | help='classification weight') 161 | parser.add_argument('--clsdep', type=int, default=1, 162 | help='number of layers for classifier') 163 | parser.add_argument('--eval-interval', type=int, default=1, 164 | help='AUC evaluation interval. ' 165 | '0 to disable evaluation.') 166 | parser.add_argument('--save-interval', type=int, default=0, 167 | help='interval to save models. 
0 to disable saving.') 168 | parser.add_argument('--prefix', default='pbigan', 169 | help='prefix of output directory') 170 | parser.add_argument('--comp', type=int, default=7, 171 | help='continuous convolution kernel size') 172 | parser.add_argument('--ae', type=float, default=1, 173 | help='autoencoding regularization strength') 174 | parser.add_argument('--aeloss', default='mse', 175 | help='autoencoding loss. (options: mse, smooth_l1)') 176 | parser.add_argument('--dec-ch', default='8-16-16', 177 | help='decoder architecture') 178 | parser.add_argument('--enc-ch', default='64-32-32-16', 179 | help='encoder architecture') 180 | parser.add_argument('--dis-ch', default=None, 181 | help='discriminator architecture. Use encoder ' 182 | 'architecture if unspecified.') 183 | parser.add_argument('--rescale', dest='rescale', action='store_const', 184 | const=True, default=True, 185 | help='if set, rescale time to [-1, 1]') 186 | parser.add_argument('--no-rescale', dest='rescale', action='store_const', 187 | const=False) 188 | parser.add_argument('--cconvnorm', dest='cconv_norm', 189 | action='store_const', const=True, default=True, 190 | help='if set, normalize continuous convolutional ' 191 | 'layer using mean pooling') 192 | parser.add_argument('--no-cconvnorm', dest='cconv_norm', 193 | action='store_const', const=False) 194 | parser.add_argument('--cconv-ref', type=int, default=98, 195 | help='number of evenly-spaced reference locations ' 196 | 'for continuous convolutional layer') 197 | parser.add_argument('--dec-ref', type=int, default=128, 198 | help='number of evenly-spaced reference locations ' 199 | 'for decoder') 200 | parser.add_argument('--trans', type=int, default=2, 201 | help='number of encoder layers') 202 | parser.add_argument('--ema', dest='ema', type=int, default=0, 203 | help='start epoch of exponential moving average ' 204 | '(EMA). 
-1 to disable EMA') 205 | parser.add_argument('--ema-decay', type=float, default=.9999, 206 | help='EMA decay') 207 | parser.add_argument('--mmd', type=float, default=1, 208 | help='MMD strength for latent variable') 209 | 210 | args = parser.parse_args() 211 | 212 | nz = args.nz 213 | 214 | epochs = args.epoch 215 | eval_interval = args.eval_interval 216 | save_interval = args.save_interval 217 | 218 | if args.seed is None: 219 | rnd = np.random.RandomState(None) 220 | random_seed = rnd.randint(np.iinfo(np.uint32).max) 221 | else: 222 | random_seed = args.seed 223 | rnd = np.random.RandomState(random_seed) 224 | np.random.seed(random_seed) 225 | torch.manual_seed(random_seed) 226 | 227 | max_time = 5 228 | cconv_ref = args.cconv_ref 229 | overlap = args.overlap 230 | train_dataset, val_dataset, test_dataset = time_series.split_data( 231 | args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale) 232 | 233 | train_loader = DataLoader( 234 | train_dataset, batch_size=args.batch_size, shuffle=True, 235 | drop_last=True, collate_fn=train_dataset.collate_fn) 236 | n_train_batch = len(train_loader) 237 | 238 | time_loader = DataLoader( 239 | train_dataset, batch_size=args.batch_size, shuffle=True, 240 | drop_last=True, collate_fn=train_dataset.collate_fn) 241 | 242 | val_loader = DataLoader( 243 | val_dataset, batch_size=args.test_batch_size, shuffle=False, 244 | collate_fn=val_dataset.collate_fn) 245 | 246 | test_loader = DataLoader( 247 | test_dataset, batch_size=args.test_batch_size, shuffle=False, 248 | collate_fn=test_dataset.collate_fn) 249 | 250 | in_channels, seq_len = train_dataset.data.shape[1:] 251 | 252 | if args.dis_ch is None: 253 | args.dis_ch = args.enc_ch 254 | 255 | dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels] 256 | enc_channels = [int(c) for c in args.enc_ch.split('-')] 257 | dis_channels = [int(c) for c in args.dis_ch.split('-')] 258 | 259 | out_channels = enc_channels[0] 260 | 261 | squash = torch.sigmoid 262 | 
if args.rescale: 263 | squash = torch.tanh 264 | 265 | dec_ch_up = 2**(len(dec_channels) - 2) 266 | assert args.dec_ref % dec_ch_up == 0, ( 267 | f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.') 268 | dec_len0 = args.dec_ref // dec_ch_up 269 | grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash) 270 | 271 | decoder = Decoder( 272 | grid_decoder, max_time=max_time, dec_ref=args.dec_ref).to(device) 273 | cconv = ContinuousConv1D( 274 | in_channels, out_channels, max_time, cconv_ref, overlap_rate=overlap, 275 | kernel_size=args.comp, norm=args.cconv_norm).to(device) 276 | encoder = Encoder(cconv, nz, enc_channels, args.trans).to(device) 277 | 278 | classifier = Classifier(nz, args.clsdep).to(device) 279 | 280 | pbigan = PBiGAN( 281 | encoder, decoder, classifier, ae_loss=args.aeloss).to(device) 282 | 283 | ema = None 284 | if args.ema >= 0: 285 | ema = EMA(pbigan, args.ema_decay, args.ema) 286 | 287 | critic_cconv = ContinuousConv1D( 288 | in_channels, out_channels, max_time, cconv_ref, overlap_rate=overlap, 289 | kernel_size=args.comp, norm=args.cconv_norm).to(device) 290 | critic_embed = 32 291 | critic = ConvCritic( 292 | critic_cconv, nz, dis_channels, critic_embed).to(device) 293 | 294 | optimizer = optim.Adam( 295 | pbigan.parameters(), lr=args.lr, 296 | betas=(0, .999), weight_decay=args.wd) 297 | critic_optimizer = optim.Adam( 298 | critic.parameters(), lr=args.dis_lr, 299 | betas=(0, .999), weight_decay=args.wd) 300 | 301 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs) 302 | dis_scheduler = make_scheduler( 303 | critic_optimizer, args.dis_lr, args.min_dis_lr, epochs) 304 | 305 | path = '{}_{}'.format( 306 | args.prefix, datetime.now().strftime('%m%d.%H%M%S')) 307 | 308 | output_dir = Path('results') / 'mimic3-pbigan' / path 309 | print(output_dir) 310 | log_dir = mkdir(output_dir / 'log') 311 | model_dir = mkdir(output_dir / 'model') 312 | 313 | start_epoch = 0 314 | 315 | with (log_dir / 'seed.txt').open('w') as f:
316 | print(random_seed, file=f) 317 | with (log_dir / 'gpu.txt').open('a') as f: 318 | print(torch.cuda.device_count(), start_epoch, file=f) 319 | with (log_dir / 'args.txt').open('w') as f: 320 | for key, val in sorted(vars(args).items()): 321 | print(f'{key}: {val}', file=f) 322 | with (log_dir / 'params.txt').open('w') as f: 323 | def print_params_count(module, name): 324 | try: # sum counts if module is a list 325 | params_count = sum(count_parameters(m) for m in module) 326 | except TypeError: 327 | params_count = count_parameters(module) 328 | print(f'{name} {params_count}', file=f) 329 | print_params_count(grid_decoder, 'grid_decoder') 330 | print_params_count(decoder, 'decoder') 331 | print_params_count(cconv, 'cconv') 332 | print_params_count(encoder, 'encoder') 333 | print_params_count(classifier, 'classifier') 334 | print_params_count(pbigan, 'pbigan') 335 | print_params_count(critic, 'critic') 336 | print_params_count([pbigan, critic], 'total') 337 | 338 | tracker = Tracker(log_dir, n_train_batch) 339 | evaluator = Evaluator(pbigan, val_loader, test_loader, log_dir) 340 | start = time.time() 341 | epoch_start = start 342 | 343 | batch_size = args.batch_size 344 | 345 | for epoch in range(start_epoch, epochs): 346 | loss_breakdown = defaultdict(float) 347 | epoch_start = time.time() 348 | 349 | if epoch >= 40: 350 | args.cls = 200 351 | 352 | for ((val, idx, mask, y, _, cconv_graph), 353 | (_, idx_t, mask_t, _, index, _)) in zip( 354 | train_loader, time_loader): 355 | 356 | z_enc, x_recon, z_gen, x_gen, ae_loss, cls_loss = pbigan( 357 | val, idx, mask, y, cconv_graph, idx_t, mask_t) 358 | 359 | cconv_graph_gen = train_dataset.make_graph( 360 | x_gen, idx_t, mask_t, index) 361 | 362 | # Don't need pbigan.requires_grad_(False); 363 | # critic takes as input only the detached tensors. 
364 | real = critic(cconv_graph, batch_size, z_enc.detach()) 365 | detached_graph = [[cat_y.detach() for cat_y in x] if i == 2 else x 366 | for i, x in enumerate(cconv_graph_gen)] 367 | fake = critic(detached_graph, batch_size, z_gen.detach()) 368 | 369 | D_loss = gan_loss(real, fake, 1, 0) 370 | 371 | critic_optimizer.zero_grad() 372 | D_loss.backward() 373 | critic_optimizer.step() 374 | 375 | for p in critic.parameters(): 376 | p.requires_grad_(False) 377 | real = critic(cconv_graph, batch_size, z_enc) 378 | fake = critic(cconv_graph_gen, batch_size, z_gen) 379 | 380 | G_loss = gan_loss(real, fake, 0, 1) 381 | 382 | mmd_loss = mmd(z_enc, z_gen) 383 | 384 | loss = (G_loss + ae_loss * args.ae + cls_loss * args.cls 385 | + mmd_loss * args.mmd) 386 | 387 | optimizer.zero_grad() 388 | loss.backward() 389 | optimizer.step() 390 | for p in critic.parameters(): 391 | p.requires_grad_(True) 392 | 393 | if ema: 394 | ema.update() 395 | 396 | loss_breakdown['D'] += D_loss.item() 397 | loss_breakdown['G'] += G_loss.item() 398 | loss_breakdown['AE'] += ae_loss.item() 399 | loss_breakdown['MMD'] += mmd_loss.item() 400 | loss_breakdown['CLS'] += cls_loss.item() 401 | loss_breakdown['total'] += loss.item() 402 | 403 | if scheduler: 404 | scheduler.step() 405 | if dis_scheduler: 406 | dis_scheduler.step() 407 | 408 | cur_time = time.time() 409 | tracker.log( 410 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start) 411 | 412 | if eval_interval > 0 and (epoch + 1) % eval_interval == 0: 413 | if ema: 414 | ema.apply() 415 | evaluator.evaluate(epoch) 416 | ema.restore() 417 | else: 418 | evaluator.evaluate(epoch) 419 | 420 | model_dict = { 421 | 'pbigan': pbigan.state_dict(), 422 | 'critic': critic.state_dict(), 423 | 'ema': ema.state_dict() if ema else None, 424 | 'epoch': epoch + 1, 425 | 'args': args, 426 | } 427 | torch.save(model_dict, str(log_dir / 'model.pth')) 428 | if save_interval > 0 and (epoch + 1) % save_interval == 0: 429 | torch.save(model_dict, 
str(model_dir / f'{epoch:04d}.pth')) 430 | 431 | print(output_dir) 432 | 433 | 434 | if __name__ == '__main__': 435 | main() 436 | -------------------------------------------------------------------------------- /time-series/mimic3_pvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | from torch.distributions.normal import Normal 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | from datetime import datetime 9 | import time 10 | from pathlib import Path 11 | import argparse 12 | import math 13 | from collections import defaultdict 14 | from spline_cconv import ContinuousConv1D 15 | import time_series 16 | import flow 17 | from ema import EMA 18 | from utils import count_parameters, mkdir, make_scheduler 19 | from tracker import Tracker 20 | from evaluate import Evaluator 21 | from layers import ( 22 | Classifier, 23 | GridDecoder, 24 | GridEncoder, 25 | Decoder, 26 | ) 27 | 28 | 29 | use_cuda = torch.cuda.is_available() 30 | device = torch.device('cuda' if use_cuda else 'cpu') 31 | 32 | 33 | class Encoder(nn.Module): 34 | def __init__(self, cconv, latent_size, channels, flow_depth=2): 35 | super().__init__() 36 | self.cconv = cconv 37 | 38 | if flow_depth > 0: 39 | hidden_size = latent_size * 2 40 | flow_layers = [flow.InverseAutoregressiveFlow( 41 | latent_size, hidden_size, latent_size) 42 | for _ in range(flow_depth)] 43 | 44 | flow_layers.append(flow.Reverse(latent_size)) 45 | self.q_z_flow = flow.FlowSequential(*flow_layers) 46 | self.enc_chunk = 3 47 | else: 48 | self.q_z_flow = None 49 | self.enc_chunk = 2 50 | 51 | self.grid_encoder = GridEncoder(channels, latent_size * self.enc_chunk) 52 | 53 | def forward(self, cconv_graph, batch_size, iw_samples=3): 54 | x = self.cconv(*cconv_graph, batch_size) 55 | grid_enc = self.grid_encoder(x).chunk(self.enc_chunk, dim=1) 56 | mu, logvar = 
grid_enc[:2] 57 | std = F.softplus(logvar) 58 | qz_x = Normal(mu, std) 59 | z_0 = qz_x.rsample([iw_samples]) 60 | log_q_z_0 = qz_x.log_prob(z_0) 61 | if self.q_z_flow: 62 | z_T, log_q_z_flow = self.q_z_flow(z_0, context=grid_enc[2]) 63 | log_q_z = (log_q_z_0 + log_q_z_flow).sum(-1) 64 | else: 65 | z_T, log_q_z = z_0, log_q_z_0.sum(-1) 66 | return z_T, log_q_z 67 | 68 | 69 | def masked_loss(loss_fn, pred, data, mask): 70 | # return (loss_fn(pred * mask, data * mask, 71 | # reduction='none') * mask).mean() 72 | # Expand data shape from (batch_size, d) to (iw_samples, batch_size, d) 73 | return loss_fn(pred, data.expand_as(pred), reduction='none') * mask 74 | 75 | 76 | class PVAE(nn.Module): 77 | def __init__(self, encoder, decoder, classifier, sigma=.2, cls_weight=100): 78 | super().__init__() 79 | self.encoder = encoder 80 | self.decoder = decoder 81 | self.classifier = classifier 82 | self.sigma = sigma 83 | self.cls_weight = cls_weight 84 | 85 | def forward(self, data, time, mask, y, cconv_graph, iw_samples=3, 86 | ts_lambda=1, kl_lambda=1): 87 | batch_size = len(data) 88 | z_T, log_q_z = self.encoder(cconv_graph, batch_size, iw_samples) 89 | 90 | pz = Normal(torch.zeros_like(z_T), torch.ones_like(z_T)) 91 | log_p_z = pz.log_prob(z_T).sum(-1) 92 | # kl_loss: log q(z|x) - log p(z) 93 | kl_loss = log_q_z - log_p_z 94 | 95 | var2 = 2 * self.sigma**2 96 | # half_log2pivar: log(2 * pi * sigma^2) / 2 97 | half_log2pivar = .5 * math.log(math.pi * var2) 98 | 99 | # Multivariate Gaussian log-likelihood: 100 | # -D/2 * log(2*pi*sigma^2) - 1/2 \sum_{i=1}^D (x_i - mu_i)^2 / sigma^2 101 | def neg_gaussian_logp(pred, data, mask=None): 102 | se = F.mse_loss(pred, data.expand_as(pred), reduction='none') 103 | if mask is None: 104 | return se / var2 + half_log2pivar 105 | return (se / var2 + half_log2pivar) * mask 106 | 107 | # Reshape z to accommodate modules with strict input shape 108 | # requirements such as convolutional layers. 
109 | # Expected shape of x_recon: (iw_samples * batch_size, C, L) 110 | z_flat = z_T.view(-1, *z_T.shape[2:]) 111 | x_recon = self.decoder( 112 | z_flat, 113 | time.repeat((iw_samples, 1, 1)), 114 | mask.repeat((iw_samples, 1, 1))) 115 | 116 | # Gaussian noise for time series 117 | # data shape: (batch_size, C, L) 118 | # x_recon shape: (iw_samples * batch_size, C, L) 119 | x_recon = x_recon.view(iw_samples, *data.shape) 120 | neg_logp = neg_gaussian_logp(x_recon, data, mask) 121 | # neg_logp: -log p(x|z) 122 | neg_logp = neg_logp.sum((-1, -2)) 123 | 124 | y_logit = self.classifier(z_flat).view(iw_samples, -1) 125 | 126 | # cls_loss: -log p(y|z) 127 | cls_loss = F.binary_cross_entropy_with_logits( 128 | y_logit, y.expand_as(y_logit), reduction='none') 129 | 130 | # elbo_x = log p(x|z) + log p(z) - log q(z|x) 131 | elbo_x = -(neg_logp * ts_lambda + kl_loss * kl_lambda) 132 | 133 | with torch.no_grad(): 134 | is_weight = F.softmax(elbo_x, 0) 135 | 136 | # IWAE loss: -log E[p(x|z) p(z) / q(z|x)] 137 | # Here we ignore the constant shift of +log(iw_samples) 138 | loss_x = -elbo_x.logsumexp(0).mean() 139 | loss_y = (is_weight * cls_loss).sum(0).mean() 140 | loss = loss_x + loss_y * self.cls_weight 141 | 142 | # For debugging 143 | x_se = masked_loss(F.mse_loss, x_recon, data, mask) 144 | mse = x_se.sum((-1, -2)) / mask.sum((-1, -2)).clamp(min=1) 145 | 146 | CE = (is_weight * cls_loss).sum(0).mean().item() 147 | loss_breakdown = { 148 | 'loss': loss.item(), 149 | 'reconst.': neg_logp.mean().item() * ts_lambda, 150 | 'MSE': mse.mean().item(), 151 | 'KL': kl_loss.mean().item() * kl_lambda, 152 | 'CE': CE, 153 | 'classif.': CE * self.cls_weight, 154 | } 155 | return loss, z_T, elbo_x, loss_breakdown 156 | 157 | def predict(self, data, time, mask, cconv_graph, iw_samples=50): 158 | dummy_y = data.new_zeros(len(data)) 159 | _, z, elbo, _ = self( 160 | data, time, mask, dummy_y, cconv_graph, iw_samples) 161 | z_flat = z.view(-1, *z.shape[2:]) 162 | pred_logit =
self.classifier(z_flat).view(iw_samples, -1) 163 | is_weight = F.softmax(elbo, 0) 164 | 165 | # Importance reweighted predictive probability 166 | # p(y|x) =~ E_{q_IW(z|x)}[p(y|z)] 167 | py_z = torch.sigmoid(pred_logit) 168 | expected_py_z = (is_weight * py_z).sum(0) 169 | return expected_py_z 170 | 171 | 172 | def main(): 173 | parser = argparse.ArgumentParser() 174 | 175 | parser.add_argument('--data', default='mimic3.npz', 176 | help='data file') 177 | parser.add_argument('--seed', type=int, default=None, 178 | help='random seed. Randomly set if not specified.') 179 | 180 | # training options 181 | parser.add_argument('--nz', type=int, default=32, 182 | help='dimension of latent variable') 183 | parser.add_argument('--epoch', type=int, default=200, 184 | help='number of training epochs') 185 | parser.add_argument('--batch-size', type=int, default=64, 186 | help='batch size') 187 | # Use smaller test batch size to accommodate more importance samples 188 | parser.add_argument('--test-batch-size', type=int, default=32, 189 | help='batch size for validation and test set') 190 | parser.add_argument('--train-k', type=int, default=8, 191 | help='number of importance weights for training') 192 | parser.add_argument('--test-k', type=int, default=50, 193 | help='number of importance weights for evaluation') 194 | parser.add_argument('--flow', type=int, default=2, 195 | help='number of IAF layers') 196 | parser.add_argument('--lr', type=float, default=2e-4, 197 | help='global learning rate') 198 | parser.add_argument('--enc-lr', type=float, default=1e-4, 199 | help='encoder learning rate') 200 | parser.add_argument('--dec-lr', type=float, default=1e-4, 201 | help='decoder learning rate') 202 | parser.add_argument('--min-lr', type=float, default=-1, 203 | help='min learning rate for LR scheduler. 
' 204 | '-1 to disable annealing') 205 | parser.add_argument('--wd', type=float, default=1e-3, 206 | help='weight decay') 207 | parser.add_argument('--overlap', type=float, default=.5, 208 | help='kernel overlap') 209 | parser.add_argument('--cls', type=float, default=200, 210 | help='classification weight') 211 | parser.add_argument('--clsdep', type=int, default=1, 212 | help='number of layers for classifier') 213 | parser.add_argument('--ts', type=float, default=1, 214 | help='log-likelihood weight for ELBO') 215 | parser.add_argument('--kl', type=float, default=.1, 216 | help='KL weight for ELBO') 217 | parser.add_argument('--eval-interval', type=int, default=1, 218 | help='AUC evaluation interval. ' 219 | '0 to disable evaluation.') 220 | parser.add_argument('--save-interval', type=int, default=0, 221 | help='interval to save models. 0 to disable saving.') 222 | parser.add_argument('--prefix', default='pvae', 223 | help='prefix of output directory') 224 | parser.add_argument('--comp', type=int, default=7, 225 | help='continuous convolution kernel size') 226 | parser.add_argument('--sigma', type=float, default=.2, 227 | help='standard deviation for Gaussian likelihood') 228 | parser.add_argument('--dec-ch', default='8-16-16', 229 | help='decoder architecture') 230 | parser.add_argument('--enc-ch', default='64-32-32-16', 231 | help='encoder architecture') 232 | parser.add_argument('--rescale', dest='rescale', action='store_const', 233 | const=True, default=True, 234 | help='if set, rescale time to [-1, 1]') 235 | parser.add_argument('--no-rescale', dest='rescale', action='store_const', 236 | const=False) 237 | parser.add_argument('--cconvnorm', dest='cconv_norm', 238 | action='store_const', const=True, default=True, 239 | help='if set, normalize continuous convolutional ' 240 | 'layer using mean pooling') 241 | parser.add_argument('--no-cconvnorm', dest='cconv_norm', 242 | action='store_const', const=False) 243 | parser.add_argument('--cconv-ref', type=int, 
default=98, 244 | help='number of evenly-spaced reference locations ' 245 | 'for continuous convolutional layer') 246 | parser.add_argument('--dec-ref', type=int, default=128, 247 | help='number of evenly-spaced reference locations ' 248 | 'for decoder') 249 | parser.add_argument('--ema', dest='ema', type=int, default=0, 250 | help='start epoch of exponential moving average ' 251 | '(EMA). -1 to disable EMA') 252 | parser.add_argument('--ema-decay', type=float, default=.9999, 253 | help='EMA decay') 254 | 255 | args = parser.parse_args() 256 | 257 | nz = args.nz 258 | 259 | epochs = args.epoch 260 | eval_interval = args.eval_interval 261 | save_interval = args.save_interval 262 | 263 | if args.seed is None: 264 | rnd = np.random.RandomState(None) 265 | random_seed = rnd.randint(np.iinfo(np.uint32).max) 266 | else: 267 | random_seed = args.seed 268 | rnd = np.random.RandomState(random_seed) 269 | np.random.seed(random_seed) 270 | torch.manual_seed(random_seed) 271 | 272 | max_time = 5 273 | cconv_ref = args.cconv_ref 274 | overlap = args.overlap 275 | train_dataset, val_dataset, test_dataset = time_series.split_data( 276 | args.data, rnd, max_time, cconv_ref, overlap, device, args.rescale) 277 | 278 | train_loader = DataLoader( 279 | train_dataset, batch_size=args.batch_size, shuffle=True, 280 | drop_last=True, collate_fn=train_dataset.collate_fn) 281 | n_train_batch = len(train_loader) 282 | 283 | val_loader = DataLoader( 284 | val_dataset, batch_size=args.test_batch_size, shuffle=False, 285 | collate_fn=val_dataset.collate_fn) 286 | 287 | test_loader = DataLoader( 288 | test_dataset, batch_size=args.test_batch_size, shuffle=False, 289 | collate_fn=test_dataset.collate_fn) 290 | 291 | in_channels, seq_len = train_dataset.data.shape[1:] 292 | 293 | dec_channels = [int(c) for c in args.dec_ch.split('-')] + [in_channels] 294 | enc_channels = [int(c) for c in args.enc_ch.split('-')] 295 | 296 | out_channels = enc_channels[0] 297 | 298 | squash = torch.sigmoid 299 | if 
args.rescale: 300 | squash = torch.tanh 301 | 302 | dec_ch_up = 2**(len(dec_channels) - 2) 303 | assert args.dec_ref % dec_ch_up == 0, ( 304 | f'--dec-ref={args.dec_ref} is not divisible by {dec_ch_up}.') 305 | dec_len0 = args.dec_ref // dec_ch_up 306 | grid_decoder = GridDecoder(nz, dec_channels, dec_len0, squash) 307 | 308 | decoder = Decoder( 309 | grid_decoder, max_time=max_time, dec_ref=args.dec_ref).to(device) 310 | 311 | cconv = ContinuousConv1D(in_channels, out_channels, max_time, cconv_ref, 312 | overlap_rate=overlap, kernel_size=args.comp, 313 | norm=args.cconv_norm).to(device) 314 | encoder = Encoder(cconv, nz, enc_channels, args.flow).to(device) 315 | 316 | classifier = Classifier(nz, args.clsdep).to(device) 317 | 318 | pvae = PVAE( 319 | encoder, decoder, classifier, args.sigma, args.cls).to(device) 320 | 321 | ema = None 322 | if args.ema >= 0: 323 | ema = EMA(pvae, args.ema_decay, args.ema) 324 | 325 | other_params = [param for name, param in pvae.named_parameters() 326 | if not (name.startswith('decoder.grid_decoder') 327 | or name.startswith('encoder.grid_encoder'))] 328 | params = [ 329 | {'params': decoder.grid_decoder.parameters(), 'lr': args.dec_lr}, 330 | {'params': encoder.grid_encoder.parameters(), 'lr': args.enc_lr}, 331 | {'params': other_params}, 332 | ] 333 | 334 | optimizer = optim.Adam( 335 | params, lr=args.lr, weight_decay=args.wd) 336 | 337 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs) 338 | 339 | path = '{}_{}'.format( 340 | args.prefix, datetime.now().strftime('%m%d.%H%M%S')) 341 | 342 | output_dir = Path('results') / 'mimic3-pvae' / path 343 | print(output_dir) 344 | log_dir = mkdir(output_dir / 'log') 345 | model_dir = mkdir(output_dir / 'model') 346 | 347 | start_epoch = 0 348 | 349 | with (log_dir / 'seed.txt').open('w') as f: 350 | print(random_seed, file=f) 351 | with (log_dir / 'gpu.txt').open('a') as f: 352 | print(torch.cuda.device_count(), start_epoch, file=f) 353 | with (log_dir /
'args.txt').open('w') as f: 354 | for key, val in sorted(vars(args).items()): 355 | print(f'{key}: {val}', file=f) 356 | with (log_dir / 'params.txt').open('w') as f: 357 | def print_params_count(module, name): 358 | try: # sum counts if module is a list 359 | params_count = sum(count_parameters(m) for m in module) 360 | except TypeError: 361 | params_count = count_parameters(module) 362 | print(f'{name} {params_count}', file=f) 363 | print_params_count(grid_decoder, 'grid_decoder') 364 | print_params_count(decoder, 'decoder') 365 | print_params_count(cconv, 'cconv') 366 | print_params_count(encoder, 'encoder') 367 | print_params_count(classifier, 'classifier') 368 | print_params_count(pvae, 'pvae') 369 | print_params_count(pvae, 'total') 370 | 371 | tracker = Tracker(log_dir, n_train_batch) 372 | evaluator = Evaluator(pvae, val_loader, test_loader, log_dir, 373 | eval_args={'iw_samples': args.test_k}) 374 | start = time.time() 375 | epoch_start = start 376 | 377 | for epoch in range(start_epoch, epochs): 378 | loss_breakdown = defaultdict(float) 379 | epoch_start = time.time() 380 | for (val, idx, mask, y, _, cconv_graph) in train_loader: 381 | optimizer.zero_grad() 382 | loss, _, _, loss_info = pvae( 383 | val, idx, mask, y, cconv_graph, args.train_k, args.ts, args.kl) 384 | loss.backward() 385 | optimizer.step() 386 | 387 | if ema: 388 | ema.update() 389 | 390 | for loss_name, loss_val in loss_info.items(): 391 | loss_breakdown[loss_name] += loss_val 392 | 393 | if scheduler: 394 | scheduler.step() 395 | 396 | cur_time = time.time() 397 | tracker.log( 398 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start) 399 | 400 | if eval_interval > 0 and (epoch + 1) % eval_interval == 0: 401 | if ema: 402 | ema.apply() 403 | evaluator.evaluate(epoch) 404 | ema.restore() 405 | else: 406 | evaluator.evaluate(epoch) 407 | 408 | model_dict = { 409 | 'pvae': pvae.state_dict(), 410 | 'ema': ema.state_dict() if ema else None, 411 | 'epoch': epoch + 1, 412 | 'args': 
args, 413 | } 414 | torch.save(model_dict, str(log_dir / 'model.pth')) 415 | if save_interval > 0 and (epoch + 1) % save_interval == 0: 416 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth')) 417 | 418 | print(output_dir) 419 | 420 | 421 | if __name__ == '__main__': 422 | main() 423 | -------------------------------------------------------------------------------- /time-series/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def mmd(x, y): 5 | n, dim = x.shape 6 | 7 | xx = (x**2).sum(1, keepdim=True) 8 | yy = (y**2).sum(1, keepdim=True) 9 | 10 | outer_xx = torch.mm(x, x.t()) 11 | outer_yy = torch.mm(y, y.t()) 12 | outer_xy = torch.mm(x, y.t()) 13 | 14 | diff_xx = xx + xx.t() - 2 * outer_xx 15 | diff_yy = yy + yy.t() - 2 * outer_yy 16 | diff_xy = xx + yy.t() - 2 * outer_xy 17 | 18 | C = 2. * dim 19 | k_xx = C / (C + diff_xx) 20 | k_yy = C / (C + diff_yy) 21 | k_xy = C / (C + diff_xy) 22 | 23 | mean_xx = (k_xx.sum() - k_xx.diag().sum()) / (n * (n - 1)) 24 | mean_yy = (k_yy.sum() - k_yy.diag().sum()) / (n * (n - 1)) 25 | mean_xy = k_xy.sum() / (n * n) 26 | 27 | return mean_xx + mean_yy - 2 * mean_xy 28 | -------------------------------------------------------------------------------- /time-series/sn_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import spectral_norm 4 | import torch.nn.functional as F 5 | 6 | 7 | class InvertibleLinearResNetBlock(nn.Module): 8 | def __init__(self, in_size, hid_size): 9 | super().__init__() 10 | self.res = nn.Sequential( 11 | nn.ELU(), 12 | spectral_norm(nn.Linear(in_size, hid_size)), 13 | nn.ELU(), 14 | spectral_norm(nn.Linear(hid_size, in_size)), 15 | ) 16 | 17 | def forward(self, x): 18 | return x + self.res(x) 19 | 20 | 21 | class InvertibleLinearResNet(nn.Module): 22 | def __init__(self, in_size, hid_size, layers=1): 23 | super().__init__() 24 
| self.res = nn.Sequential( 25 | *[InvertibleLinearResNetBlock(in_size, hid_size) 26 | for _ in range(layers)]) 27 | 28 | def forward(self, x): 29 | return self.res(x) 30 | 31 | 32 | class ResLinearBlock(nn.Module): 33 | def __init__(self, in_size, out_size): 34 | super().__init__() 35 | self.linear = nn.Sequential( 36 | spectral_norm(nn.Linear(in_size, out_size)), 37 | nn.Dropout(), 38 | nn.LeakyReLU(.2), 39 | spectral_norm(nn.Linear(out_size, out_size)), 40 | nn.Dropout(), 41 | nn.LeakyReLU(.2), 42 | ) 43 | 44 | self.skip = nn.Sequential( 45 | spectral_norm(nn.Linear(in_size, out_size)), 46 | nn.LeakyReLU(.2), 47 | ) 48 | 49 | def forward(self, x): 50 | return self.linear(x) + self.skip(x) 51 | 52 | 53 | class Classifier(nn.Module): 54 | def __init__(self, in_size, layers=1): 55 | super().__init__() 56 | blocks = [] 57 | for _ in range(layers): 58 | blocks.append(ResLinearBlock(in_size, in_size)) 59 | # No spectral normalization for the last layer 60 | blocks.append(nn.Linear(in_size, 1)) 61 | self.res_linear = nn.Sequential(*blocks) 62 | 63 | def forward(self, x): 64 | return self.res_linear(x) 65 | 66 | 67 | class GBlock(nn.Module): 68 | def __init__(self, in_channels, out_channels): 69 | super().__init__() 70 | self.activation = nn.ReLU(inplace=False) 71 | self.bn1 = nn.BatchNorm1d(in_channels) 72 | self.bn2 = nn.BatchNorm1d(out_channels) 73 | self.conv1 = spectral_norm(nn.Conv1d( 74 | in_channels, out_channels, kernel_size=3, padding=1)) 75 | self.conv2 = spectral_norm(nn.Conv1d( 76 | out_channels, out_channels, kernel_size=3, padding=1)) 77 | self.convx = None 78 | if in_channels != out_channels: 79 | self.convx = spectral_norm(nn.Conv1d( 80 | in_channels, out_channels, kernel_size=1)) 81 | 82 | def forward(self, x): 83 | h = self.activation(self.bn1(x)) 84 | h = F.interpolate(h, scale_factor=2) 85 | x = F.interpolate(x, scale_factor=2) 86 | if self.convx: 87 | x = self.convx(x) 88 | h = self.conv1(h) 89 | h = self.activation(self.bn2(h)) 90 | h = 
self.conv2(h) 91 | return h + x 92 | 93 | 94 | class GridDecoder(nn.Module): 95 | def __init__(self, dim_z, channels, start_len=16, squash=None): 96 | super().__init__() 97 | self.activation = nn.ReLU(inplace=False) 98 | self.start_len = start_len 99 | self.linear = spectral_norm(nn.Linear(dim_z, channels[0] * start_len)) 100 | self.blocks = nn.Sequential( 101 | *[GBlock(in_channels, channels[c + 1]) 102 | for c, in_channels in enumerate(channels[:-2])]) 103 | self.output = nn.Sequential( 104 | nn.BatchNorm1d(channels[-2]), 105 | self.activation, 106 | spectral_norm(nn.Conv1d( 107 | channels[-2], channels[-1], kernel_size=3, padding=1)), 108 | ) 109 | self.squash = squash 110 | 111 | def forward(self, z): 112 | h = self.linear(z) 113 | h = h.view(h.shape[0], -1, self.start_len) 114 | h = self.blocks(h) 115 | h = self.output(h) 116 | if self.squash: 117 | h = self.squash(h) 118 | return h 119 | 120 | 121 | class DBlock(nn.Module): 122 | def __init__(self, in_channels, out_channels, downsample=True): 123 | super().__init__() 124 | self.activation = nn.ReLU(inplace=False) 125 | self.conv1 = spectral_norm(nn.Conv1d( 126 | in_channels, out_channels, kernel_size=3, padding=1)) 127 | self.conv2 = spectral_norm(nn.Conv1d( 128 | out_channels, out_channels, kernel_size=3, padding=1)) 129 | self.convx = None 130 | if in_channels != out_channels: 131 | self.convx = spectral_norm(nn.Conv1d( 132 | in_channels, out_channels, kernel_size=1)) 133 | self.downsample = None 134 | if downsample: 135 | self.downsample = nn.AvgPool1d(2) 136 | 137 | def shortcut(self, x): 138 | if self.convx: 139 | x = self.convx(x) 140 | if self.downsample: 141 | x = self.downsample(x) 142 | return x 143 | 144 | def forward(self, x): 145 | # pre-activation 146 | h = self.activation(x) 147 | h = self.conv1(h) 148 | h = self.conv2(self.activation(h)) 149 | if self.downsample: 150 | h = self.downsample(h) 151 | return h + self.shortcut(x) 152 | 153 | 154 | class GridEncoder(nn.Module): 155 | def 
__init__(self, channels, out_dim=1): 156 | super().__init__() 157 | self.activation = nn.ReLU(inplace=False) 158 | self.blocks = nn.Sequential( 159 | *[DBlock(in_channels, out_channels) 160 | for in_channels, out_channels 161 | in zip(channels[:-1], channels[1:])]) 162 | self.linear = spectral_norm(nn.Linear(channels[-1], out_dim)) 163 | 164 | def forward(self, x): 165 | h = x 166 | h = self.blocks(h) 167 | h = self.activation(h).sum(2) 168 | return self.linear(h) 169 | 170 | 171 | class Decoder(nn.Module): 172 | def __init__(self, grid_decoder, max_time=5, kernel_bw=None, dec_ref=128): 173 | super().__init__() 174 | if kernel_bw is None: 175 | self.kernel_bw = max_time / dec_ref * 3 176 | else: 177 | self.kernel_bw = kernel_bw 178 | # ref_times are the time stamps assigned to the evenly-spaced 179 | # sequence generated by the conv1d grid decoder. 180 | self.register_buffer('ref_times', torch.linspace(0, max_time, dec_ref)) 181 | self.ref_times = self.ref_times[:, None] 182 | self.grid_decoder = grid_decoder 183 | 184 | def forward(self, code, time, mask): 185 | """ 186 | Args: 187 | code: shape (batch_size, latent_size) 188 | time: shape (batch_size, channels, max_seq_len) 189 | mask: shape (batch_size, channels, max_seq_len) 190 | 191 | Returns: 192 | interpolated tensor of shape (batch_size, channels, max_seq_len) 193 | """ 194 | # shape of x: (batch_size, n_channels, dec_ref) 195 | x = self.grid_decoder(code) 196 | 197 | # t_diff shape: (batch_size, n_channels, dec_ref, max_seq_len) 198 | t_diff = time[:, :, None] - self.ref_times 199 | 200 | # Epanechnikov quadratic kernel: 201 | # K_\lambda(x_0, x) = relu(3/4 * (1 - (|x_0 - x| / \lambda)^2)) 202 | # shape of w: (batch_size, n_channels, dec_ref, max_seq_len) 203 | w = F.relu((1 - (t_diff / self.kernel_bw)**2) * .75) 204 | # Alternative that avoids division by zero: 205 | # normalizer = torch.clamp(w.sum(2), min=1e-6) 206 | # return ((x[:, :, :, None] * w).sum(2) * mask) / normalizer 207 | ks_x = ((x[:, :, :, None] * w).sum(2) * mask) / w.sum(2)
return ks_x 209 | 210 | 211 | def gan_loss(real, fake, real_target, fake_target): 212 | real_score = sum(F.binary_cross_entropy_with_logits( 213 | r, r.new_tensor(real_target).expand_as(r)) for r in real) 214 | fake_score = sum(F.binary_cross_entropy_with_logits( 215 | f, f.new_tensor(fake_target).expand_as(f)) for f in fake) 216 | return real_score + fake_score 217 | -------------------------------------------------------------------------------- /time-series/spline_cconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.data.dataloader import default_collate 4 | from torch_spline_conv import SplineBasis, SplineWeighting 5 | import math 6 | 7 | 8 | def kernel_width(max_time, ref_size, overlap_rate): 9 | return max_time / (ref_size + overlap_rate - overlap_rate * ref_size) 10 | 11 | 12 | class ContinuousConv1D(nn.Module): 13 | def __init__(self, 14 | in_channels, 15 | out_channels=64, 16 | max_time=5, 17 | ref_size=98, 18 | overlap_rate=.5, 19 | kernel_size=5, 20 | norm=False, 21 | bias=True, 22 | spline_degree=1): 23 | super().__init__() 24 | 25 | self.in_channels = in_channels 26 | self.out_channels = out_channels 27 | self.ref_size = ref_size 28 | self.overlap_rate = overlap_rate 29 | self.kernel_width = kernel_width(max_time, ref_size, overlap_rate) 30 | self.spline_degree = spline_degree 31 | self.norm = norm 32 | 33 | margin = self.kernel_width / 2 34 | refs = torch.linspace(margin, max_time - margin, ref_size) 35 | self.register_buffer('refs', refs) 36 | 37 | kernel_size = torch.tensor([kernel_size], dtype=torch.long) 38 | self.register_buffer('kernel_size', kernel_size) 39 | 40 | is_open_spline = torch.tensor([1], dtype=torch.uint8) 41 | self.register_buffer('is_open_spline', is_open_spline) 42 | 43 | self.weight = nn.Parameter( 44 | torch.Tensor(kernel_size, in_channels, out_channels)) 45 | if bias: 46 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 47 
| else: 48 | self.register_parameter('bias', None) 49 | 50 | self.reset_parameters() 51 | 52 | def reset_parameters(self): 53 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 54 | if self.bias is not None: 55 | bound = 1 / math.sqrt(self.in_channels) 56 | nn.init.uniform_(self.bias, -bound, bound) 57 | 58 | def forward(self, pseudo, ref_idx, y, ref_deg, batch_size): 59 | conv_out = y[0].new_zeros( 60 | self.in_channels, self.ref_size * batch_size, self.out_channels) 61 | for c in range(self.in_channels): 62 | data = SplineBasis.apply(pseudo[c], self.kernel_size, 63 | self.is_open_spline, self.spline_degree) 64 | out = SplineWeighting.apply( 65 | y[c], self.weight[:, c].unsqueeze(1), *data) 66 | idx = ref_idx[c].expand_as(out) 67 | conv_out[c].scatter_add_(0, idx, out) 68 | if self.norm: 69 | conv_out[c].div_(ref_deg[c]) 70 | conv_out = conv_out.sum(0) 71 | conv_out = conv_out.view(batch_size, self.ref_size, self.out_channels) 72 | conv_out = conv_out.transpose(1, 2) 73 | if self.bias is not None: 74 | conv_out = conv_out + self.bias[:, None] 75 | return conv_out 76 | 77 | 78 | def gen_collate_fn(channels, max_time=5, ref_size=98, overlap_rate=.5, 79 | device=None): 80 | k_width = kernel_width(max_time, ref_size, overlap_rate) 81 | margin = k_width / 2 82 | refs_ = torch.linspace(margin, max_time - margin, ref_size) 83 | 84 | def collate_fn(batch): 85 | y0 = batch[0][0] 86 | refs = refs_.to(y0.device) 87 | 88 | pseudo = [[] for _ in range(channels)] 89 | cum_ref_idx = [[] for _ in range(channels)] 90 | concat_y = [[] for _ in range(channels)] 91 | deg = [[] for _ in range(channels)] 92 | 93 | for i, ts_info in enumerate(batch): 94 | y, t, m = ts_info[:3] 95 | for c in range(channels): 96 | tc = t[c][m[c] == 1] 97 | yc = y[c][m[c] == 1] 98 | dis = (tc - refs[:, None]) / k_width + .5 99 | mask = (dis <= 1) * (dis >= 0) 100 | ref_idx, t_idx = torch.nonzero(mask).t() 101 | # Pseudo coordinates in [0, 1] 102 | pseudo[c].append(dis[mask]) 103 | # Indices 
accumulated across mini-batch. Used for adding 104 | # convolution results to linearized padded tensor. 105 | cum_ref_idx[c].append(ref_idx + i * ref_size) 106 | concat_y[c].append(yc[t_idx]) 107 | deg[c].append(y0.new_zeros(ref_size).scatter_add_( 108 | 0, ref_idx, y0.new_ones(ref_idx.shape))) 109 | 110 | for c in range(channels): 111 | pseudo[c] = torch.cat(pseudo[c]).unsqueeze(1).to(device) 112 | cum_ref_idx[c] = torch.cat(cum_ref_idx[c]).unsqueeze(1).to(device) 113 | concat_y[c] = torch.cat(concat_y[c]).unsqueeze(1).to(device) 114 | # clamp(min=1) to avoid dividing by zero 115 | deg[c] = torch.cat(deg[c]).clamp(min=1).unsqueeze(1).to(device) 116 | 117 | converted_batch = [x.to(device) for x in default_collate(batch)] 118 | return converted_batch + [(pseudo, cum_ref_idx, concat_y, deg)] 119 | 120 | return collate_fn 121 | -------------------------------------------------------------------------------- /time-series/time_series.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from torch.utils.data.dataloader import default_collate 4 | import numpy as np 5 | from sklearn.model_selection import train_test_split 6 | 7 | 8 | def kernel_width(max_time, cconv_ref, overlap_rate): 9 | return max_time / (cconv_ref + overlap_rate - overlap_rate * cconv_ref) 10 | 11 | 12 | class TimeSeries(Dataset): 13 | def __init__(self, data, time, mask, label=None, 14 | max_time=5, cconv_ref=98, overlap_rate=.5, device=None): 15 | self.data = torch.tensor(data, dtype=torch.float) 16 | self.time = torch.tensor(time, dtype=torch.float) 17 | self.mask = torch.tensor(mask, dtype=torch.float) 18 | 19 | if label is None: 20 | TimeSeries.__getitem__ = lambda self, index: ( 21 | self.data[index], self.time[index], self.mask[index], index) 22 | else: 23 | self.label = torch.tensor(label, dtype=torch.float) 24 | TimeSeries.__getitem__ = lambda self, index: ( 25 | self.data[index], self.time[index], 
self.mask[index], 26 | self.label[index], index) 27 | 28 | self.data_len, self.channels = self.data.shape[:2] 29 | self.cconv_ref = cconv_ref 30 | self.device = device 31 | k_width = kernel_width(max_time, cconv_ref, overlap_rate) 32 | margin = k_width / 2 33 | refs = torch.linspace(margin, max_time - margin, cconv_ref) 34 | 35 | self.pseudo, self.deg, self.ref_idx, self.t_idx = [ 36 | [[None] * self.channels for _ in range(self.data_len)] 37 | for _ in range(4)] 38 | 39 | for i, (y, t, m) in enumerate(zip(self.data, self.time, self.mask)): 40 | for c in range(self.channels): 41 | tc = t[c][m[c] == 1] 42 | dis = (tc - refs[:, None]) / k_width + .5 43 | dmask = (dis <= 1) * (dis >= 0) 44 | self.ref_idx[i][c], self.t_idx[i][c] = torch.nonzero(dmask).t() 45 | # Pseudo coordinates in [0, 1] 46 | self.pseudo[i][c] = dis[dmask] 47 | cur_deg = torch.zeros(self.cconv_ref) 48 | cur_deg.scatter_add_(0, self.ref_idx[i][c], 49 | torch.ones(self.ref_idx[i][c].shape)) 50 | self.deg[i][c] = cur_deg.clamp(min=1) 51 | 52 | def __len__(self): 53 | return self.data_len 54 | 55 | def make_graph(self, data, time, mask, index): 56 | pseudo = [ 57 | torch.cat([self.pseudo[idx][c] for idx in index]) 58 | .to(self.device).unsqueeze_(1).requires_grad_(False) 59 | for c in range(self.channels)] 60 | 61 | # Indices accumulated across mini-batch. Used for adding 62 | # convolution results to linearized padded tensor. 
63 | cum_ref_idx = [ 64 | torch.cat([self.ref_idx[idx][c] + i * self.cconv_ref 65 | for i, idx in enumerate(index)]) 66 | .to(self.device).unsqueeze_(1).requires_grad_(False) 67 | for c in range(self.channels)] 68 | 69 | concat_y = [ 70 | torch.cat( 71 | [y[c][(m[c] == 1).requires_grad_(False)][self.t_idx[idx][c]] 72 | for y, m, idx in zip(data, mask, index)]) 73 | .to(self.device).unsqueeze_(1) 74 | for c in range(self.channels)] 75 | 76 | deg = [ 77 | torch.cat([self.deg[idx][c] for idx in index]) 78 | .to(self.device).unsqueeze_(1).requires_grad_(False) 79 | for c in range(self.channels)] 80 | 81 | return pseudo, cum_ref_idx, concat_y, deg 82 | 83 | def collate_fn(self, batch): 84 | batch = [x.to(self.device) for x in default_collate(batch)] 85 | # For labeled data, skip the label as the 4th entry. 86 | (data, time, mask), index = batch[:3], batch[-1] 87 | graph = self.make_graph(data, time, mask, index) 88 | return batch + [graph] 89 | 90 | 91 | def split_data(data_file, rnd, max_time, cconv_ref, overlap, device, 92 | rescale=False): 93 | raw_data = np.load(data_file) 94 | 95 | if len(raw_data) == 4: 96 | time_np = raw_data['time'] 97 | data_np = raw_data['data'] 98 | mask_np = raw_data['mask'] 99 | label_np = raw_data['label'].squeeze() 100 | 101 | (tv_time, test_time, tv_data, test_data, 102 | tv_mask, test_mask, tv_label, test_label) = train_test_split( 103 | time_np, data_np, mask_np, label_np, 104 | train_size=.8, stratify=label_np, random_state=rnd) 105 | 106 | (train_time, val_time, train_data, val_data, 107 | train_mask, val_mask, train_label, val_label) = train_test_split( 108 | tv_time, tv_data, tv_mask, tv_label, 109 | train_size=.8, stratify=tv_label, random_state=rnd) 110 | 111 | elif len(raw_data) == 8: 112 | tv_time = raw_data['train_time'] 113 | tv_data = raw_data['train_data'] 114 | tv_mask = raw_data['train_mask'] 115 | tv_label = raw_data['train_label'] 116 | 117 | test_time = raw_data['test_time'] 118 | test_data = raw_data['test_data'] 119 
| test_mask = raw_data['test_mask'] 120 | test_label = raw_data['test_label'] 121 | 122 | (train_time, val_time, train_data, val_data, 123 | train_mask, val_mask, train_label, val_label) = train_test_split( 124 | tv_time, tv_data, tv_mask, tv_label, 125 | train_size=.8, stratify=tv_label, random_state=rnd) 126 | 127 | elif len(raw_data) == 12: 128 | train_time = raw_data['train_time'] 129 | train_data = raw_data['train_data'] 130 | train_mask = raw_data['train_mask'] 131 | train_label = raw_data['train_label'] 132 | 133 | test_time = raw_data['test_time'] 134 | test_data = raw_data['test_data'] 135 | test_mask = raw_data['test_mask'] 136 | test_label = raw_data['test_label'] 137 | 138 | val_time = raw_data['val_time'] 139 | val_data = raw_data['val_data'] 140 | val_mask = raw_data['val_mask'] 141 | val_label = raw_data['val_label'] 142 | else: 143 | raise Exception('Invalid data') 144 | 145 | # Scale time 146 | train_time *= max_time 147 | test_time *= max_time 148 | val_time *= max_time 149 | 150 | # Rescale data from [0, 1] to [-1, 1] 151 | if rescale: 152 | train_data = 2 * train_data - 1 153 | val_data = 2 * val_data - 1 154 | test_data = 2 * test_data - 1 155 | 156 | train_dataset = TimeSeries( 157 | train_data, train_time, train_mask, train_label, max_time=max_time, 158 | cconv_ref=cconv_ref, overlap_rate=overlap, device=device) 159 | 160 | val_dataset = TimeSeries( 161 | val_data, val_time, val_mask, val_label, max_time=max_time, 162 | cconv_ref=cconv_ref, overlap_rate=overlap, device=device) 163 | 164 | test_dataset = TimeSeries( 165 | test_data, test_time, test_mask, test_label, max_time=max_time, 166 | cconv_ref=cconv_ref, overlap_rate=overlap, device=device) 167 | 168 | return train_dataset, val_dataset, test_dataset 169 | -------------------------------------------------------------------------------- /time-series/toy_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from 
torch.nn.utils import spectral_norm 3 | 4 | 5 | def dconv_bn_relu(in_dim, out_dim): 6 | return nn.Sequential( 7 | nn.ConvTranspose1d(in_dim, out_dim, 5, 2, 8 | padding=2, output_padding=1, bias=False), 9 | nn.BatchNorm1d(out_dim), 10 | nn.ReLU()) 11 | 12 | 13 | class SeqGeneratorDiscrete(nn.Module): 14 | def __init__(self, n_channels=3, latent_size=128, squash=None): 15 | super().__init__() 16 | 17 | self.l1 = nn.Sequential( 18 | nn.Linear(latent_size, 2048, bias=False), 19 | nn.BatchNorm1d(2048), 20 | nn.ReLU()) 21 | 22 | self.l2 = nn.Sequential( 23 | dconv_bn_relu(256, 128), 24 | dconv_bn_relu(128, 64), 25 | dconv_bn_relu(64, 32), 26 | nn.ConvTranspose1d(32, n_channels, 5, 2, 27 | padding=2, output_padding=1)) 28 | self.squash = squash 29 | 30 | def forward(self, z): 31 | h = self.l1(z) 32 | h = h.view(h.shape[0], -1, 8) 33 | h = self.l2(h) 34 | if self.squash: 35 | h = self.squash(h) 36 | return h 37 | 38 | 39 | def conv_ln_lrelu(in_dim, out_dim): 40 | return nn.Sequential( 41 | spectral_norm(nn.Conv1d(in_dim, out_dim, 5, 2, 2)), 42 | nn.LeakyReLU(0.2)) 43 | -------------------------------------------------------------------------------- /time-series/toy_pbigan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import spectral_norm 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | from datetime import datetime 9 | import time 10 | from pathlib import Path 11 | import argparse 12 | from collections import defaultdict 13 | from spline_cconv import ContinuousConv1D 14 | from time_series import TimeSeries 15 | from ema import EMA 16 | from mmd import mmd 17 | from tracker import Tracker 18 | from vis import Visualizer 19 | from gen_toy_data import gen_data 20 | from utils import Rescaler, mkdir, make_scheduler 21 | from layers import Decoder, gan_loss 22 | from toy_layers import 
SeqGeneratorDiscrete, conv_ln_lrelu 23 | 24 | 25 | use_cuda = torch.cuda.is_available() 26 | device = torch.device('cuda' if use_cuda else 'cpu') 27 | 28 | 29 | class Encoder(nn.Module): 30 | def __init__(self, cconv, latent_size, norm_trans=True): 31 | super().__init__() 32 | self.cconv = cconv 33 | self.ls = nn.Sequential( 34 | nn.LeakyReLU(0.2), 35 | conv_ln_lrelu(64, 128), 36 | conv_ln_lrelu(128, 256), 37 | # conv_ln_lrelu(256, 512), 38 | # conv_ln_lrelu(512, 64), 39 | conv_ln_lrelu(256, 32), 40 | ) 41 | conv_size = 416 42 | self.fc = nn.Sequential( 43 | spectral_norm(nn.Linear(conv_size, latent_size * 2)), 44 | nn.LeakyReLU(0.2), 45 | spectral_norm(nn.Linear(latent_size * 2, latent_size * 2)), 46 | ) 47 | self.norm_trans = norm_trans 48 | if norm_trans: 49 | self.fc2 = nn.Sequential( 50 | spectral_norm(nn.Linear(latent_size, latent_size)), 51 | nn.LeakyReLU(0.2), 52 | spectral_norm(nn.Linear(latent_size, latent_size)), 53 | ) 54 | 55 | def forward(self, cconv_graph, batch_size): 56 | x = self.cconv(*cconv_graph, batch_size) 57 | # after self.ls and flattening: (batch_size, 416) == conv_size 58 | x = self.ls(x).view(x.shape[0], -1) 59 | mu, logvar = self.fc(x).chunk(2, dim=1) 60 | if self.training: 61 | std = F.softplus(logvar) # 'logvar' holds a softplus-parameterized std here 62 | eps = torch.empty_like(std).normal_() 63 | # return mu + eps * std, mu, logvar, eps 64 | z = mu + eps * std 65 | else: 66 | z = mu 67 | if self.norm_trans: 68 | z = self.fc2(z) 69 | return z 70 | 71 | 72 | class GridCritic(nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | self.ls = nn.Sequential( 76 | nn.LeakyReLU(0.2), 77 | conv_ln_lrelu(64, 128), 78 | conv_ln_lrelu(128, 256), 79 | # conv_ln_lrelu(256, 512), 80 | # conv_ln_lrelu(512, 1) 81 | conv_ln_lrelu(256, 1), 82 | ) 83 | 84 | def forward(self, x): 85 | # output shape: (batch_size, 1, 13); ConvCritic squeezes it to embed_size 86 | return self.ls(x) 87 | 88 | 89 | class 
ConvCritic(nn.Module): 90 | def __init__(self, cconv, latent_size, embed_size=13): 91 | super().__init__() 92 | self.cconv = cconv 93 | self.grid_critic =
GridCritic() 94 | # self.x_dis = spectral_norm(nn.Linear(7, embed_size)) 95 | 96 | self.z_dis = nn.Sequential( 97 | spectral_norm(nn.Linear(latent_size, embed_size)), 98 | nn.LeakyReLU(0.2), 99 | spectral_norm(nn.Linear(embed_size, embed_size)), 100 | ) 101 | 102 | self.x_linear = spectral_norm(nn.Linear(embed_size, 1)) 103 | 104 | self.xz_dis = nn.Sequential( 105 | spectral_norm(nn.Linear(embed_size * 2, embed_size)), 106 | nn.LeakyReLU(0.2), 107 | spectral_norm(nn.Linear(embed_size, 1)), 108 | ) 109 | 110 | def forward(self, cconv_graph, batch_size, z): 111 | x = self.cconv(*cconv_graph, batch_size) 112 | x = self.grid_critic(x) 113 | x = x.squeeze(1) 114 | # x = self.x_dis(x) 115 | z = self.z_dis(z) 116 | xz = torch.cat((x, z), 1) 117 | xz = self.xz_dis(xz) 118 | x_out = self.x_linear(x).view(-1) 119 | xz_out = xz.view(-1) 120 | return xz_out, x_out 121 | 122 | 123 | class PBiGAN(nn.Module): 124 | def __init__(self, encoder, decoder, ae_loss='mse'): 125 | super().__init__() 126 | self.encoder = encoder 127 | self.decoder = decoder 128 | self.ae_loss = ae_loss 129 | 130 | def forward(self, data, time, mask, cconv_graph, time_t, mask_t): 131 | batch_size = len(data) 132 | z_T = self.encoder(cconv_graph, batch_size) 133 | 134 | z_gen = torch.empty_like(z_T).normal_() 135 | x_gen = self.decoder(z_gen, time_t, mask_t) 136 | 137 | x_recon = self.decoder(z_T, time, mask) 138 | 139 | if self.ae_loss == 'mse': 140 | ae_loss = F.mse_loss(x_recon, data, reduction='none') * mask 141 | elif self.ae_loss == 'smooth_l1': 142 | ae_loss = F.smooth_l1_loss(x_recon, data, reduction='none') * mask 143 | 144 | ae_loss = ae_loss.sum((-1, -2)) 145 | 146 | return z_T, x_recon, z_gen, x_gen, ae_loss.mean() 147 | 148 | 149 | def main(): 150 | parser = argparse.ArgumentParser() 151 | 152 | default_dataset = 'toy-data.npz' 153 | parser.add_argument('--data', default=default_dataset, 154 | help='data file') 155 | parser.add_argument('--seed', type=int, default=None, 156 | help='random seed. 
Randomly set if not specified.') 157 | 158 | # training options 159 | parser.add_argument('--nz', type=int, default=32, 160 | help='dimension of latent variable') 161 | parser.add_argument('--epoch', type=int, default=1000, 162 | help='number of training epochs') 163 | parser.add_argument('--batch-size', type=int, default=128, 164 | help='batch size') 165 | parser.add_argument('--lr', type=float, default=8e-5, 166 | help='encoder/decoder learning rate') 167 | parser.add_argument('--dis-lr', type=float, default=1e-4, 168 | help='discriminator learning rate') 169 | parser.add_argument('--min-lr', type=float, default=5e-5, 170 | help='min encoder/decoder learning rate for LR ' 171 | 'scheduler. -1 to disable annealing') 172 | parser.add_argument('--min-dis-lr', type=float, default=7e-5, 173 | help='min discriminator learning rate for LR ' 174 | 'scheduler. -1 to disable annealing') 175 | parser.add_argument('--wd', type=float, default=0, 176 | help='weight decay') 177 | parser.add_argument('--overlap', type=float, default=.5, 178 | help='kernel overlap') 179 | parser.add_argument('--no-norm-trans', action='store_true', 180 | help='if set, use Gaussian posterior without ' 181 | 'transformation') 182 | parser.add_argument('--plot-interval', type=int, default=1, 183 | help='plot interval. 0 to disable plotting.') 184 | parser.add_argument('--save-interval', type=int, default=0, 185 | help='interval to save models. 0 to disable saving.') 186 | parser.add_argument('--prefix', default='pbigan', 187 | help='prefix of output directory') 188 | parser.add_argument('--comp', type=int, default=7, 189 | help='continuous convolution kernel size') 190 | parser.add_argument('--ae', type=float, default=.2, 191 | help='autoencoding regularization strength') 192 | parser.add_argument('--aeloss', default='smooth_l1', choices=['mse', 'smooth_l1'], 193 | help='autoencoding loss.
(options: mse, smooth_l1)') 194 | parser.add_argument('--ema', dest='ema', type=int, default=-1, 195 | help='start epoch of exponential moving average ' 196 | '(EMA). -1 to disable EMA') 197 | parser.add_argument('--ema-decay', type=float, default=.9999, 198 | help='EMA decay') 199 | parser.add_argument('--mmd', type=float, default=1, 200 | help='MMD strength for latent variable') 201 | 202 | # squash only takes effect when rescale is on 203 | parser.add_argument('--squash', dest='squash', action='store_const', 204 | const=True, default=True, 205 | help='bound the generated time series value ' 206 | 'using tanh') 207 | parser.add_argument('--no-squash', dest='squash', action='store_const', 208 | const=False) 209 | 210 | # rescale each data channel to [-1, 1] 211 | parser.add_argument('--rescale', dest='rescale', action='store_const', 212 | const=True, default=True, 213 | help='if set, rescale each data channel to [-1, 1]') 214 | parser.add_argument('--no-rescale', dest='rescale', action='store_const', 215 | const=False) 216 | 217 | args = parser.parse_args() 218 | 219 | batch_size = args.batch_size 220 | nz = args.nz 221 | 222 | epochs = args.epoch 223 | plot_interval = args.plot_interval 224 | save_interval = args.save_interval 225 | 226 | try: 227 | npz = np.load(args.data) 228 | train_data = npz['data'] 229 | train_time = npz['time'] 230 | train_mask = npz['mask'] 231 | except FileNotFoundError: 232 | if args.data != default_dataset: 233 | raise 234 | # Generate the default toy dataset from scratch 235 | train_data, train_time, train_mask, _, _ = gen_data( 236 | n_samples=10000, seq_len=200, max_time=1, poisson_rate=50, 237 | obs_span_rate=.25, save_file=default_dataset) 238 | 239 | _, in_channels, seq_len = train_data.shape 240 | train_time *= train_mask 241 | 242 | if args.seed is None: 243 | rnd = np.random.RandomState(None) 244 | random_seed = rnd.randint(np.iinfo(np.uint32).max) 245 | else: 246 | random_seed = args.seed 247 | rnd = 
np.random.RandomState(random_seed) 248 |
np.random.seed(random_seed) 249 | torch.manual_seed(random_seed) 250 | 251 | # Scale time 252 | max_time = 5 253 | train_time *= max_time 254 | 255 | squash = None 256 | rescaler = None 257 | if args.rescale: 258 | rescaler = Rescaler(train_data) 259 | train_data = rescaler.rescale(train_data) 260 | if args.squash: 261 | squash = torch.tanh 262 | 263 | out_channels = 64 264 | cconv_ref = 98 265 | 266 | train_dataset = TimeSeries( 267 | train_data, train_time, train_mask, label=None, max_time=max_time, 268 | cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device) 269 | 270 | train_loader = DataLoader( 271 | train_dataset, batch_size=batch_size, shuffle=True, 272 | drop_last=True, collate_fn=train_dataset.collate_fn) 273 | n_train_batch = len(train_loader) 274 | 275 | time_loader = DataLoader( 276 | train_dataset, batch_size=batch_size, shuffle=True, 277 | drop_last=True, collate_fn=train_dataset.collate_fn) 278 | 279 | test_loader = DataLoader(train_dataset, batch_size=batch_size, 280 | collate_fn=train_dataset.collate_fn) 281 | 282 | grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash) 283 | decoder = Decoder(grid_decoder, max_time=max_time).to(device) 284 | 285 | cconv = ContinuousConv1D( 286 | in_channels, out_channels, max_time, cconv_ref, 287 | overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device) 288 | encoder = Encoder(cconv, nz, not args.no_norm_trans).to(device) 289 | 290 | pbigan = PBiGAN(encoder, decoder, args.aeloss).to(device) 291 | 292 | critic_cconv = ContinuousConv1D( 293 | in_channels, out_channels, max_time, cconv_ref, 294 | overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device) 295 | critic = ConvCritic(critic_cconv, nz).to(device) 296 | 297 | ema = None 298 | if args.ema >= 0: 299 | ema = EMA(pbigan, args.ema_decay, args.ema) 300 | 301 | optimizer = optim.Adam( 302 | pbigan.parameters(), lr=args.lr, weight_decay=args.wd) 303 | critic_optimizer = optim.Adam( 304 | critic.parameters(), 
lr=args.dis_lr, weight_decay=args.wd) 305 | 306 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs) 307 | dis_scheduler = make_scheduler( 308 | critic_optimizer, args.dis_lr, args.min_dis_lr, epochs) 309 | 310 | path = '{}_{}'.format( 311 | args.prefix, datetime.now().strftime('%m%d.%H%M%S')) 312 | 313 | output_dir = Path('results') / 'toy-pbigan' / path 314 | print(output_dir) 315 | log_dir = mkdir(output_dir / 'log') 316 | model_dir = mkdir(output_dir / 'model') 317 | 318 | start_epoch = 0 319 | 320 | with (log_dir / 'seed.txt').open('w') as f: 321 | print(random_seed, file=f) 322 | with (log_dir / 'gpu.txt').open('a') as f: 323 | print(torch.cuda.device_count(), start_epoch, file=f) 324 | with (log_dir / 'args.txt').open('w') as f: 325 | for key, val in sorted(vars(args).items()): 326 | print(f'{key}: {val}', file=f) 327 | 328 | tracker = Tracker(log_dir, n_train_batch) 329 | visualizer = Visualizer(encoder, decoder, batch_size, max_time, 330 | test_loader, rescaler, output_dir, device) 331 | start = time.time() 332 | epoch_start = start 333 | 334 | for epoch in range(start_epoch, epochs): 335 | loss_breakdown = defaultdict(float) 336 | 337 | for ((val, idx, mask, _, cconv_graph), 338 | (_, idx_t, mask_t, index, _)) in zip( 339 | train_loader, time_loader): 340 | 341 | z_enc, x_recon, z_gen, x_gen, ae_loss = pbigan( 342 | val, idx, mask, cconv_graph, idx_t, mask_t) 343 | 344 | cconv_graph_gen = train_dataset.make_graph( 345 | x_gen, idx_t, mask_t, index) 346 | 347 | real = critic(cconv_graph, batch_size, z_enc) 348 | fake = critic(cconv_graph_gen, batch_size, z_gen) 349 | 350 | D_loss = gan_loss(real, fake, 1, 0) 351 | 352 | critic_optimizer.zero_grad() 353 | D_loss.backward(retain_graph=True) 354 | critic_optimizer.step() 355 | 356 | G_loss = gan_loss(real, fake, 0, 1) 357 | 358 | mmd_loss = mmd(z_enc, z_gen) 359 | 360 | loss = G_loss + ae_loss * args.ae + mmd_loss * args.mmd 361 | 362 | optimizer.zero_grad() 363 | loss.backward() 364 | 
optimizer.step() 365 | 366 | if ema: 367 | ema.update() 368 | 369 | loss_breakdown['D'] += D_loss.item() 370 | loss_breakdown['G'] += G_loss.item() 371 | loss_breakdown['AE'] += ae_loss.item() 372 | loss_breakdown['MMD'] += mmd_loss.item() 373 | loss_breakdown['total'] += loss.item() 374 | 375 | if scheduler: 376 | scheduler.step() 377 | if dis_scheduler: 378 | dis_scheduler.step() 379 | 380 | cur_time = time.time() 381 | tracker.log( 382 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start) 383 | 384 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0: 385 | if ema: 386 | ema.apply() 387 | visualizer.plot(epoch) 388 | ema.restore() 389 | else: 390 | visualizer.plot(epoch) 391 | 392 | model_dict = { 393 | 'pbigan': pbigan.state_dict(), 394 | 'critic': critic.state_dict(), 395 | 'ema': ema.state_dict() if ema else None, 396 | 'epoch': epoch + 1, 397 | 'args': args, 398 | } 399 | torch.save(model_dict, str(log_dir / 'model.pth')) 400 | if save_interval > 0 and (epoch + 1) % save_interval == 0: 401 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth')) 402 | 403 | print(output_dir) 404 | 405 | 406 | if __name__ == '__main__': 407 | main() 408 | -------------------------------------------------------------------------------- /time-series/toy_pvae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import spectral_norm 5 | import torch.optim as optim 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | from datetime import datetime 9 | import time 10 | from pathlib import Path 11 | import argparse 12 | from collections import defaultdict 13 | from spline_cconv import ContinuousConv1D 14 | from time_series import TimeSeries 15 | from tracker import Tracker 16 | from vis import Visualizer 17 | from gen_toy_data import gen_data 18 | from utils import Rescaler, mkdir, make_scheduler 19 | from 
layers import Decoder 20 | from toy_layers import SeqGeneratorDiscrete, conv_ln_lrelu 21 | 22 | 23 | use_cuda = torch.cuda.is_available() 24 | device = torch.device('cuda' if use_cuda else 'cpu') 25 | 26 | 27 | class Encoder(nn.Module): 28 | def __init__(self, latent_size, cconv): 29 | super().__init__() 30 | self.cconv = cconv 31 | self.ls = nn.Sequential( 32 | nn.LeakyReLU(0.2), 33 | conv_ln_lrelu(64, 128), 34 | conv_ln_lrelu(128, 256), 35 | conv_ln_lrelu(256, 512), 36 | conv_ln_lrelu(512, 64), 37 | ) 38 | conv_size = 448 39 | self.fc = nn.Sequential( 40 | spectral_norm(nn.Linear(conv_size, latent_size * 2)), 41 | nn.LeakyReLU(0.2), 42 | spectral_norm(nn.Linear(latent_size * 2, latent_size * 2)), 43 | ) 44 | 45 | def forward(self, cconv_graph, batch_size): 46 | x = self.cconv(*cconv_graph, batch_size) 47 | # after self.ls and flattening: (batch_size, 448) == conv_size 48 | x = self.ls(x).view(x.shape[0], -1) 49 | mu, logvar = self.fc(x).chunk(2, dim=1) 50 | std = torch.exp(logvar * .5) 51 | eps = torch.empty_like(std).normal_() 52 | return mu + eps * std, mu, logvar, eps 53 | 54 | 55 | class PVAE(nn.Module): 56 | def __init__(self, encoder, decoder, sigma=.2): 57 | super().__init__() 58 | self.encoder = encoder 59 | self.decoder = decoder 60 | self.sigma = sigma 61 | 62 | def forward(self, data, time, mask, cconv_graph): 63 | batch_size = len(data) 64 | z, mu, logvar, eps = self.encoder(cconv_graph, batch_size) 65 | x_recon = self.decoder(z, time, mask) 66 | # Gaussian negative log-likelihood (fixed std sigma, up to a constant) on observed entries 67 | recon_loss = (1 / (2 * self.sigma**2) * F.mse_loss( 68 | x_recon * mask, data * mask, reduction='none') * mask).sum((1, 2)) 69 | kl_loss = .5 * (z**2 - logvar - eps**2).sum(1) # one-sample estimate of KL(q(z|x) || N(0, I)) 70 | loss = recon_loss.mean() + kl_loss.mean() 71 | return loss 72 | 73 | 74 | def main(): 75 | parser = argparse.ArgumentParser() 76 | 77 | default_dataset = 
'toy-data.npz' 78 | parser.add_argument('--data', default=default_dataset, 79 | help='data file') 80 | parser.add_argument('--seed', type=int, default=None, 81 | help='random seed.
Randomly set if not specified.') 82 | 83 | # training options 84 | parser.add_argument('--nz', type=int, default=32, 85 | help='dimension of latent variable') 86 | parser.add_argument('--epoch', type=int, default=1000, 87 | help='number of training epochs') 88 | parser.add_argument('--batch-size', type=int, default=128, 89 | help='batch size') 90 | parser.add_argument('--lr', type=float, default=1e-4, 91 | help='learning rate') 92 | parser.add_argument('--min-lr', type=float, default=5e-5, 93 | help='min learning rate for LR scheduler. ' 94 | '-1 to disable annealing') 95 | parser.add_argument('--plot-interval', type=int, default=10, 96 | help='plot interval. 0 to disable plotting.') 97 | parser.add_argument('--save-interval', type=int, default=0, 98 | help='interval to save models. 0 to disable saving.') 99 | parser.add_argument('--prefix', default='pvae', 100 | help='prefix of output directory') 101 | parser.add_argument('--comp', type=int, default=5, 102 | help='continuous convolution kernel size') 103 | parser.add_argument('--sigma', type=float, default=.2, 104 | help='standard deviation for Gaussian likelihood') 105 | parser.add_argument('--overlap', type=float, default=.5, 106 | help='kernel overlap') 107 | # squash only takes effect when rescale is on 108 | parser.add_argument('--squash', dest='squash', action='store_const', 109 | const=True, default=True, 110 | help='bound the generated time series value ' 111 | 'using tanh') 112 | parser.add_argument('--no-squash', dest='squash', action='store_const', 113 | const=False) 114 | 115 | # rescale each data channel to [-1, 1] 116 | parser.add_argument('--rescale', dest='rescale', action='store_const', 117 | const=True, default=True, 118 | help='if set, rescale each data channel to [-1, 1]') 119 | parser.add_argument('--no-rescale', dest='rescale', action='store_const', 120 | const=False) 121 | 122 | args = parser.parse_args() 123 | 124 | batch_size = args.batch_size 125 | nz = args.nz 126 | 127 | 
epochs = args.epoch 128 | plot_interval =
args.plot_interval 129 | save_interval = args.save_interval 130 | 131 | try: 132 | npz = np.load(args.data) 133 | train_data = npz['data'] 134 | train_time = npz['time'] 135 | train_mask = npz['mask'] 136 | except FileNotFoundError: 137 | if args.data != default_dataset: 138 | raise 139 | # Generate the default toy dataset from scratch 140 | train_data, train_time, train_mask, _, _ = gen_data( 141 | n_samples=10000, seq_len=200, max_time=1, poisson_rate=50, 142 | obs_span_rate=.25, save_file=default_dataset) 143 | 144 | _, in_channels, seq_len = train_data.shape 145 | train_time *= train_mask 146 | 147 | if args.seed is None: 148 | rnd = np.random.RandomState(None) 149 | random_seed = rnd.randint(np.iinfo(np.uint32).max) 150 | else: 151 | random_seed = args.seed 152 | rnd = np.random.RandomState(random_seed) 153 | np.random.seed(random_seed) 154 | torch.manual_seed(random_seed) 155 | 156 | # Scale time 157 | max_time = 5 158 | train_time *= max_time 159 | 160 | squash = None 161 | rescaler = None 162 | if args.rescale: 163 | rescaler = Rescaler(train_data) 164 | train_data = rescaler.rescale(train_data) 165 | if args.squash: 166 | squash = torch.tanh 167 | 168 | out_channels = 64 169 | cconv_ref = 98 170 | 171 | train_dataset = TimeSeries( 172 | train_data, train_time, train_mask, label=None, max_time=max_time, 173 | cconv_ref=cconv_ref, overlap_rate=args.overlap, device=device) 174 | 175 | train_loader = DataLoader( 176 | train_dataset, batch_size=batch_size, shuffle=True, 177 | drop_last=True, collate_fn=train_dataset.collate_fn) 178 | n_train_batch = len(train_loader) 179 | 180 | test_batch_size = 64 181 | test_loader = DataLoader(train_dataset, batch_size=test_batch_size, 182 | collate_fn=train_dataset.collate_fn) 183 | 184 | grid_decoder = SeqGeneratorDiscrete(in_channels, nz, squash) 185 | decoder = Decoder(grid_decoder, max_time=max_time).to(device) 186 | 187 | cconv = ContinuousConv1D( 188 | in_channels, out_channels, max_time, cconv_ref, 189 | 
overlap_rate=args.overlap, kernel_size=args.comp, norm=True).to(device) 190 | 191 | encoder = Encoder(nz, cconv).to(device) 192 | 193 | pvae = PVAE(encoder, decoder, sigma=args.sigma).to(device) 194 | 195 | optimizer = optim.Adam(pvae.parameters(), lr=args.lr) 196 | 197 | scheduler = make_scheduler(optimizer, args.lr, args.min_lr, epochs) 198 | 199 | path = '{}_{}_{}'.format( 200 | args.prefix, datetime.now().strftime('%m%d.%H%M%S'), 201 | '_'.join([f'lr_{args.lr:g}'])) 202 | 203 | output_dir = Path('results') / 'toy-pvae' / path 204 | print(output_dir) 205 | log_dir = mkdir(output_dir / 'log') 206 | model_dir = mkdir(output_dir / 'model') 207 | 208 | start_epoch = 0 209 | 210 | with (log_dir / 'seed.txt').open('w') as f: 211 | print(random_seed, file=f) 212 | with (log_dir / 'gpu.txt').open('a') as f: 213 | print(torch.cuda.device_count(), start_epoch, file=f) 214 | with (log_dir / 'args.txt').open('w') as f: 215 | for key, val in sorted(vars(args).items()): 216 | print(f'{key}: {val}', file=f) 217 | 218 | tracker = Tracker(log_dir, n_train_batch) 219 | visualizer = Visualizer(encoder, decoder, test_batch_size, max_time, 220 | test_loader, rescaler, output_dir, device) 221 | start = time.time() 222 | epoch_start = start 223 | 224 | for epoch in range(start_epoch, epochs): 225 | loss_breakdown = defaultdict(float) 226 | for val, idx, mask, _, cconv_graph in train_loader: 227 | optimizer.zero_grad() 228 | loss = pvae(val, idx, mask, cconv_graph) 229 | loss.backward() 230 | optimizer.step() 231 | loss_breakdown['loss'] += loss.item() 232 | 233 | if scheduler: 234 | scheduler.step() 235 | 236 | cur_time = time.time() 237 | tracker.log( 238 | epoch, loss_breakdown, cur_time - epoch_start, cur_time - start) 239 | 240 | if plot_interval > 0 and (epoch + 1) % plot_interval == 0: 241 | visualizer.plot(epoch) 242 | 243 | model_dict = { 244 | 'pvae': pvae.state_dict(), 245 | 'epoch': epoch + 1, 246 | 'args': args, 247 | } 248 | torch.save(model_dict, str(log_dir / 
'model.pth')) 249 | if save_interval > 0 and (epoch + 1) % save_interval == 0: 250 | torch.save(model_dict, str(model_dir / f'{epoch:04d}.pth')) 251 | 252 | print(output_dir) 253 | 254 | 255 | if __name__ == '__main__': 256 | main() 257 | -------------------------------------------------------------------------------- /time-series/tracker.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | import logging 4 | import sys 5 | 6 | 7 | class Tracker: 8 | def __init__(self, log_dir, n_train_batch): 9 | self.log_dir = log_dir 10 | self.n_train_batch = n_train_batch 11 | self.loss = defaultdict(list) 12 | 13 | logging.basicConfig( 14 | level=logging.INFO, 15 | format='%(asctime)s %(message)s', 16 | datefmt='%Y-%m-%d %H:%M:%S', 17 | handlers=[ 18 | logging.FileHandler(log_dir / 'log.txt'), 19 | logging.StreamHandler(sys.stdout), 20 | ], 21 | ) 22 | self.print_header = True 23 | 24 | def log(self, epoch, loss_breakdown, epoch_time, time_elapsed): 25 | for loss_name, loss_val in loss_breakdown.items(): 26 | self.loss[loss_name].append(loss_val / self.n_train_batch) 27 | 28 | if self.print_header: 29 | logging.info(' ' * 7 + ' '.join( 30 | f'{key:>12}' for key in sorted(self.loss))) 31 | self.print_header = False 32 | logging.info(f'[{epoch:4}] ' + ' '.join( 33 | f'{val[-1]:12.4f}' for _, val in sorted(self.loss.items()))) 34 | 35 | torch.save(self.loss, str(self.log_dir / 'log.pth')) 36 | 37 | with (self.log_dir / 'time.txt').open('a') as f: 38 | print(epoch, epoch_time, time_elapsed, file=f) 39 | -------------------------------------------------------------------------------- /time-series/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import numpy as np 4 | 5 | 6 | def to_numpy(v): 7 | if torch.is_tensor(v): 8 | return v.cpu().numpy() 9 | return v 10 | 11 | 12 | def 
count_parameters(model): 13 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 14 | 15 | 16 | class Rescaler: 17 | def __init__(self, data): 18 | channels = data.shape[1] 19 | ch_min = np.array([data[:, i].min() for i in range(channels)]) 20 | ch_max = np.array([data[:, i].max() for i in range(channels)]) 21 | self.ch_min, self.ch_max = ch_min[:, None], ch_max[:, None] 22 | 23 | def rescale(self, data): 24 | return 2 * (data - self.ch_min) / (self.ch_max - self.ch_min) - 1 25 | 26 | def unrescale(self, data): 27 | return .5 * (data + 1) * (self.ch_max - self.ch_min) + self.ch_min 28 | 29 | 30 | def mkdir(path): 31 | path.mkdir(parents=True, exist_ok=True) 32 | return path 33 | 34 | 35 | def make_scheduler(optimizer, lr, min_lr, epochs, steps=10): 36 | if min_lr < 0: 37 | return None 38 | step_size = max(1, epochs // steps) # StepLR needs step_size >= 1 39 | gamma = (min_lr / lr)**(1 / steps) 40 | return optim.lr_scheduler.StepLR( 41 | optimizer, step_size=step_size, gamma=gamma) 42 | -------------------------------------------------------------------------------- /time-series/vis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import matplotlib.gridspec as gridspec 4 | import seaborn as sns 5 | from utils import mkdir, to_numpy 6 | 7 | 8 | sns.set() 9 | 10 | sns.set_style('darkgrid', { 11 | 'axes.spines.left': True, 12 | 'axes.spines.bottom': True, 13 | 'axes.spines.right': True, 14 | 'axes.spines.top': True, 15 | 'patch.edgecolor': 'k', 16 | 'axes.edgecolor': '.1', 17 | 'axes.facecolor': '.95', 18 | }) 19 | 20 | 21 | def plot_samples(data_unif, time_unif, data=None, time=None, mask=None, 22 | rescaler=None, max_time=1, img_path=None, nrows=2, ncols=4): 23 | data_unif = to_numpy(data_unif) 24 | time_unif = to_numpy(time_unif) 25 | n_channels = data_unif.shape[1] 26 | if rescaler: 27 | data_unif = rescaler.unrescale(data_unif) 28 | 29 | if data is not 
None: 30 | data = to_numpy(data) 31 |
time = to_numpy(time) 32 | mask = to_numpy(mask) 33 | if rescaler: 34 | data = rescaler.unrescale(data) 35 | 36 | fig = plt.figure(figsize=(6 * ncols, 2 * n_channels * nrows)) 37 | gs = gridspec.GridSpec(nrows, ncols, wspace=.1, hspace=.1) 38 | 39 | for i in range(nrows * ncols): 40 | outer_ax = plt.subplot(gs[i]) 41 | outer_ax.set_xticks([]) 42 | outer_ax.set_yticks([]) 43 | 44 | inner_grid = gridspec.GridSpecFromSubplotSpec( 45 | n_channels, 1, subplot_spec=gs[i], hspace=0) 46 | for k in range(n_channels): 47 | ax = plt.Subplot(fig, inner_grid[k]) 48 | ax.plot(time_unif[i, k], data_unif[i, k], 49 | 'k-', alpha=.5, linewidth=.6) 50 | if data is not None: 51 | ax.scatter(time[i, k, mask[i, k] == 1], 52 | data[i, k, mask[i, k] == 1], c='r', 53 | s=30, 54 | linewidth=1.5, 55 | marker='x') 56 | ax.set_ylim(-1.2, 1.2) 57 | ax.set_xlim(0, max_time) 58 | ax.axes.xaxis.set_ticklabels([]) 59 | ax.axes.yaxis.set_ticklabels([]) 60 | fig.add_subplot(ax) 61 | 62 | if img_path: 63 | plt.savefig(img_path, bbox_inches='tight') 64 | plt.close(fig) 65 | 66 | 67 | class Visualizer: 68 | def __init__(self, encoder, decoder, batch_size, max_time, test_loader, 69 | rescaler, output_dir, device): 70 | self.encoder = encoder 71 | self.decoder = decoder 72 | self.batch_size = batch_size 73 | self.rescaler = rescaler 74 | (self.test_val, self.test_idx, self.test_mask, 75 | _, self.test_cconv_graph) = next(iter(test_loader)) 76 | in_channels = self.test_val.shape[1] 77 | self.max_time = max_time 78 | t = torch.linspace(0, max_time, 200, device=device) 79 | self.t = t.expand(batch_size, in_channels, len(t)).contiguous() 80 | self.t_mask = torch.ones_like(self.t) 81 | 82 | self.gen_data_dir = mkdir(output_dir / 'gen') 83 | self.imp_data_dir = mkdir(output_dir / 'imp') 84 | 85 | def plot(self, epoch): 86 | filename = f'{epoch:04d}.png' 87 | 88 | self.encoder.eval() 89 | self.decoder.eval() 90 | with torch.no_grad(): 91 | z = self.encoder(self.test_cconv_graph, self.batch_size) 92 | if not 
torch.is_tensor(z): # P-VAE encoder returns a tuple (z, mu, logvar, eps) 93 | z = z[0] 94 | imp_data = self.decoder(z, self.t, self.t_mask) 95 | plot_samples(imp_data, self.t, 96 | self.test_val, self.test_idx, self.test_mask, 97 | rescaler=self.rescaler, 98 | max_time=self.max_time, 99 | img_path=f'{self.imp_data_dir / filename}') 100 | 101 | data_noise = torch.empty_like(z).normal_() 102 | gen_data = self.decoder(data_noise, self.t, self.t_mask) 103 | plot_samples(gen_data, self.t, 104 | rescaler=self.rescaler, 105 | max_time=self.max_time, 106 | img_path=f'{self.gen_data_dir / filename}') 107 | self.decoder.train() 108 | self.encoder.train() 109 | --------------------------------------------------------------------------------
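Several sizes in the toy models are hard-coded: `conv_size = 416` and `embed_size=13` in toy_pbigan.py, `conv_size = 448` in toy_pvae.py, and the `view(h.shape[0], -1, 8)` reshape in `SeqGeneratorDiscrete`. They appear to follow from the standard output-length formulas for PyTorch's `nn.Conv1d` and `nn.ConvTranspose1d`. The sketch below recomputes them in plain Python. It assumes that `ContinuousConv1D` emits one feature per reference point, i.e. a sequence of length `cconv_ref = 98`. The matching numbers suggest this, but the code shown here does not state it. The helper names are hypothetical and are not part of the repository.

```python
def conv1d_out_len(n, kernel_size=5, stride=2, padding=2, dilation=1):
    """Output length of nn.Conv1d; defaults match conv_ln_lrelu (k=5, s=2, p=2)."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_transpose1d_out_len(n, kernel_size=5, stride=2, padding=2,
                             output_padding=1, dilation=1):
    """Output length of nn.ConvTranspose1d; defaults match dconv_bn_relu."""
    return ((n - 1) * stride - 2 * padding + dilation * (kernel_size - 1)
            + output_padding + 1)


def conv_stack_size(length, channels):
    """Apply one stride-2 conv per entry in `channels` (its output channels).

    Returns (flattened feature size, final sequence length).
    """
    for _ in channels:
        length = conv1d_out_len(length)
    return channels[-1] * length, length


cconv_ref = 98  # ASSUMPTION: length of the ContinuousConv1D output

# toy_pbigan.py Encoder: 64 -> 128 -> 256 -> 32 (three stride-2 convs)
pbigan_size, pbigan_len = conv_stack_size(cconv_ref, [128, 256, 32])
# toy_pvae.py Encoder: 64 -> 128 -> 256 -> 512 -> 64 (four stride-2 convs)
pvae_size, pvae_len = conv_stack_size(cconv_ref, [128, 256, 512, 64])
# GridCritic: 64 -> 128 -> 256 -> 1; after squeeze(1) its length is embed_size
critic_size, critic_len = conv_stack_size(cconv_ref, [128, 256, 1])

# SeqGeneratorDiscrete: 2048 features = 256 channels x length 8, followed by
# four ConvTranspose1d layers with stride 2 and output_padding=1
gen_len = 2048 // 256
for _ in range(4):
    gen_len = conv_transpose1d_out_len(gen_len)

print(pbigan_size, pvae_size, critic_len, gen_len)  # 416 448 13 128
```

Under these assumptions, each stride-2 convolution roughly halves the length (98, 49, 25, 13, 7), and each transposed convolution doubles it (8, 16, 32, 64, 128). If you change `cconv_ref` or the number of conv blocks, `conv_size` and `embed_size` would need to be updated to match.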