├── images
│   ├── intro.png
│   ├── models.png
│   └── objective.png
├── requirements.txt
├── LICENSE
├── README.md
├── .gitignore
├── data.py
├── utils.py
├── models.py
└── p2e_wgan_gp.py


/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/khuongav/P2E-WGAN-ecg-ppg-reconstruction/HEAD/images/intro.png
--------------------------------------------------------------------------------


/images/models.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/khuongav/P2E-WGAN-ecg-ppg-reconstruction/HEAD/images/models.png
--------------------------------------------------------------------------------


/images/objective.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/khuongav/P2E-WGAN-ecg-ppg-reconstruction/HEAD/images/objective.png
--------------------------------------------------------------------------------


/requirements.txt:
--------------------------------------------------------------------------------
argparse
joblib
matplotlib
numpy
scipy
tensorboardX
torch-summary
torch>=1.6.0
torchvision>=0.7.0
--------------------------------------------------------------------------------


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Khuong (Anthony) Vo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------


/README.md:
--------------------------------------------------------------------------------
# Codebase for "P2E-WGAN: ECG Waveform Synthesis from PPG with Conditional Wasserstein Generative Adversarial Networks"

Paper link: https://dl.acm.org/doi/10.1145/3412841.3441979

![](images/intro.png)

## Model

#### ECG feature-based WGAN-GP loss function:

![](images/objective.png)

#### End-to-end 1D convolutional network architectures:

![](images/models.png)

## Setup

To install the dependencies, run in your terminal:
```sh
pip install -r requirements.txt
```

A sampled dataset with ECG feature indices can be downloaded at [\[link\]](https://drive.google.com/file/d/1lLTerHpAx0w3Xg2QxZCuI6wAxpuC0TCH/view?usp=sharing).
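After extracting the dataset, training can be launched from the repository root via `p2e_wgan_gp.py`. A minimal invocation might look as follows (the flags and their defaults come from the script's argument parser; `data/mimic/` is the parser's default dataset location, so adjust it to wherever you extracted the archive):

```sh
# Train P2E-WGAN on the sampled dataset; every flag is optional and
# falls back to the defaults defined in p2e_wgan_gp.py.
python p2e_wgan_gp.py --experiment_name p2e_wgan_gp_mimic --dataset_prefix data/mimic/
```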

## Usage

The code is structured as follows:
- `data.py` contains functions to transform and feed the data to the model;
- `models.py` defines the deep neural network architectures;
- `utils.py` has utilities to benchmark the model and calculate the gradient penalty;
- `p2e_wgan_gp.py` is the main entry point for training and evaluation (supports running on multiple GPUs);
  - the `--dataset_prefix` flag sets the directory containing the .npy files
  - the `--peaks_only` flag restricts the reconstruction loss to precisely the main ECG features only, for data augmentation purposes

## Citation

If you find this code helpful in any way, please cite our paper:

    @inproceedings{vo2021p2e,
      title={P2E-WGAN: ECG waveform synthesis from PPG with conditional wasserstein generative adversarial networks},
      author={Vo, Khuong and Naeini, Emad Kasaeyan and Naderi, Amir and Jilani, Daniel and Rahmani, Amir M and Dutt, Nikil and Cao, Hung},
      booktitle={Proceedings of the 36th Annual ACM Symposium on Applied Computing},
      pages={1030--1036},
      year={2021}
    }

## Acknowledgments

The implementation of the WGAN-GP model is based on this repository: https://github.com/eriklindernoren/PyTorch-GAN
--------------------------------------------------------------------------------


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------


/data.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class BioData(Dataset):
    def __init__(self, from_data_path, to_data_path, transformer, ecg_peaks_data_path):

        with open(from_data_path, 'rb') as f:
            self.from_dataset = np.load(f)
        print(from_data_path, self.from_dataset.shape)

        with open(to_data_path, 'rb') as f:
            self.to_dataset = np.load(f)
        print(to_data_path, self.to_dataset.shape)

        if ecg_peaks_data_path:
            self.from_ppg = True
            with open(ecg_peaks_data_path['opeaks'], 'rb') as f:
                self.opeaks_dataset = np.load(f)
            print(ecg_peaks_data_path['opeaks'], self.opeaks_dataset.shape)

            with open(ecg_peaks_data_path['rpeaks'], 'rb') as f:
                self.rpeaks_dataset = np.load(f)
            print(ecg_peaks_data_path['rpeaks'], self.rpeaks_dataset.shape)
        else:
            self.from_ppg = False

        self.transformer = transformer

    def __len__(self):
        return self.from_dataset.shape[0]

    def __getitem__(self, idx):
        X = self.from_dataset[idx][np.newaxis, :]
        y = self.to_dataset[idx][np.newaxis, :]
        X = self.transformer(X)
        y = self.transformer(y)

        if self.from_ppg:
            opeaks = self.opeaks_dataset[idx][np.newaxis, :]
            rpeaks = self.rpeaks_dataset[idx][np.newaxis, :]
            opeaks = self.transformer(opeaks)
            rpeaks = self.transformer(rpeaks)
            return X.float(), y.float(), opeaks.float(), rpeaks.float()

        return X.float(), y.float()


class NP_to_Tensor(object):
    def __call__(self, sample):
        return torch.tensor(sample)


def get_bio_data(from_data_path, to_data_path, ecg_peaks_data_path=None):
    transformer = transforms.Compose([NP_to_Tensor()])
    return BioData(from_data_path, to_data_path, transformer, ecg_peaks_data_path)


def get_data_loader(dataset_prefix, batch_size, from_ppg, shuffle_training=True, shuffle_testing=False):
    train_ppg_data_path = dataset_prefix + 'ppg_train.npy'
    eval_ppg_data_path = dataset_prefix + 'ppg_eval.npy'

    train_ecg_data_path = dataset_prefix + 'ecg_train.npy'
    eval_ecg_data_path = dataset_prefix + 'ecg_eval.npy'

    train_ecg_peaks_data_path = {'opeaks': dataset_prefix + 'ecg_opeaks_train.npy',
                                 'rpeaks': dataset_prefix + 'ecg_rpeaks_train.npy'}

    eval_ecg_peaks_data_path = {'opeaks': dataset_prefix + 'ecg_opeaks_eval.npy',
                                'rpeaks': dataset_prefix + 'ecg_rpeaks_eval.npy'}

    if from_ppg:
        train_data = get_bio_data(
            train_ppg_data_path, train_ecg_data_path, train_ecg_peaks_data_path)
        eval_data = get_bio_data(
            eval_ppg_data_path, eval_ecg_data_path, eval_ecg_peaks_data_path)
    else:
        train_data = get_bio_data(train_ecg_data_path, train_ppg_data_path)
        eval_data = get_bio_data(eval_ecg_data_path, eval_ppg_data_path)

    train_data_loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=shuffle_training, num_workers=4, pin_memory=True)

    eval_data_loader = DataLoader(
        eval_data, batch_size=15, shuffle=shuffle_testing, num_workers=4, pin_memory=True)

    return train_data_loader, eval_data_loader
--------------------------------------------------------------------------------


/utils.py:
--------------------------------------------------------------------------------
import matplotlib
matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported

from joblib import Parallel, delayed
import torch
import torch.autograd as autograd
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

plt.rcParams['figure.figsize'] = 17, 15

# import similaritymeasures


def compute_gradient_penalty(D, real_samples, fake_samples, real_A, patch, device):
    """Calculates the gradient penalty loss for WGAN GP"""
    alpha = torch.rand((real_samples.size(0), 1, 1)).to(device)
    interpolates = (alpha * real_samples + ((1 - alpha)
                    * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates, real_A)
    fake = torch.full(
        (real_samples.shape[0], *patch), 1, dtype=torch.float, device=device)

    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


mean_conv = torch.nn.Conv1d(in_channels=1, out_channels=1, kernel_size=4,
                            bias=False, padding_mode='replicate', padding=1)
mean_conv.weight.data = torch.full_like(
    mean_conv.weight.data, 0.25)
pad = torch.nn.ReplicationPad1d((0, 1))


def smoother(fake, device):
    fake = mean_conv.to(device)(fake)
    fake = pad.to(device)(fake)
    return fake


def sample_images(experiment_name, val_dataloader, generator, steps, device):
    """Saves generated signals from the validation set"""

    generator.eval()

    current_img_dir = "sample_signals/%s/%s.png" % (experiment_name, steps)

    signals = next(iter(val_dataloader))
    real_A = signals[0].to(device)
    real_B = signals[1].to(device)
    fake_B = generator(real_A)
    fake_B = smoother(fake_B, device)

    real_A = torch.squeeze(real_A).cpu().detach().numpy()
    real_B = torch.squeeze(real_B).cpu().detach().numpy()
    fake_B = torch.squeeze(fake_B).cpu().detach().numpy()

    fig, axes = plt.subplots(real_A.shape[0], 3)

    axes[0][0].set_title('Real A')
    axes[0][1].set_title('Real B')
    axes[0][2].set_title('Fake B')

    for idx, signal in enumerate(real_A):
        axes[idx][0].plot(real_A[idx], color='c')
        axes[idx][1].plot(real_B[idx], color='m')
        axes[idx][2].plot(fake_B[idx], color='y')

    fig.canvas.draw()
    fig.savefig(current_img_dir)
    plt.close(fig)


def eval_rmse_p(signal_a, signal_b):
    rmse = np.sqrt(((signal_a - signal_b) ** 2).mean())
    p = stats.pearsonr(signal_a, signal_b)[0]

    return rmse, p


def evaluate_generated_signal_quality(val_dataloader, generator, writer, steps, device):
    generator.eval()

    all_signal_ = []
    all_generated_signal_ = []

    for _, batch in enumerate(val_dataloader):
        real_A = batch[0].to(device)

        real_B = batch[1].to(device)
        real_B = torch.squeeze(real_B).cpu().detach().numpy()

        fake_B = generator(real_A)
        fake_B = smoother(fake_B, device)
        fake_B = torch.squeeze(fake_B).cpu().detach().numpy()

        all_signal_.append(real_B)
        all_generated_signal_.append(fake_B)

    all_signal = np.vstack(all_signal_)
    all_generated_signal = np.vstack(all_generated_signal_)

    rmse_p_pairs = Parallel(n_jobs=4)(delayed(eval_rmse_p)(
        signal_a, signal_b) for signal_a, signal_b in zip(all_signal, all_generated_signal))
    res = list(zip(*rmse_p_pairs))

    rmse_mean, rmse_std = np.mean(res[0]), np.std(res[0])
    p_mean, p_std = np.mean(res[1]), np.std(res[1])

    print('\nepoch: ', steps)
    print('rmse_mean:', rmse_mean, ', rmse_std:', rmse_std)
    print('p_mean:', p_mean, ', p_std:', p_std)
    # fdists = Parallel(n_jobs=32)(delayed(similaritymeasures.frechet_dist)(sig, all_signal[i]) for i, sig in enumerate(all_generated_signal))
    # print('frechet distance: ', np.mean(fdists), np.std(fdists))

    if writer:
        writer.add_scalars('losses4', {'rms_error': rmse_mean}, steps)
        writer.add_scalars('losses4', {'p_mean': p_mean}, steps)
--------------------------------------------------------------------------------


/models.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from torchsummary import summary


# ----------
# U-NET
# ----------

class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, ksize=4, stride=2, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.Conv1d(in_size, out_size, kernel_size=ksize,
                            stride=stride, bias=False, padding_mode='replicate')]
        if normalize:
            layers.append(nn.InstanceNorm1d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size, ksize=4, stride=2, output_padding=0, dropout=0.0):
        super(UNetUp, self).__init__()
        layers = [
            nn.ConvTranspose1d(in_size, out_size, kernel_size=ksize,
                               stride=stride, output_padding=output_padding, bias=False),
            nn.InstanceNorm1d(out_size),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(dropout))

        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        x = self.model(x)
        x = torch.cat((x, skip_input), 1)

        return x


# ----------
# Generator
# ----------

class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super(GeneratorUNet, self).__init__()

        self.down1 = UNetDown(in_channels, 128, normalize=False)
        self.down2 = UNetDown(128, 256)
        self.down3 = UNetDown(256, 512, dropout=0.5)
        self.down4 = UNetDown(512, 512, dropout=0.5, normalize=False)

        self.up1 = UNetUp(512, 512, output_padding=1, dropout=0.5)
        self.up2 = UNetUp(1024, 256, output_padding=0)
        self.up3 = UNetUp(512, 128, output_padding=0)

        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.ConstantPad1d((1, 1), 0),
            nn.Conv1d(256, out_channels, 4, padding=2,
                      padding_mode='replicate'),
            nn.Tanh(),
        )

    def forward(self, x):
        # U-Net generator with skip connections from encoder to decoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        u1 = self.up1(d4, d3)
        u2 = self.up2(u1, d2)
        u3 = self.up3(u2, d1)

        return self.final(u3)


# --------------
# Discriminator
# --------------

class Discriminator(nn.Module):
    def __init__(self, in_channels=1):
        super(Discriminator, self).__init__()

        def discriminator_block(in_filters, out_filters, ksize=6, stride=3, normalization=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv1d(in_filters, out_filters, ksize,
                                stride=stride, padding_mode='replicate')]
            if normalization:
                layers.append(nn.InstanceNorm1d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(in_channels * 2, 128, normalization=False),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.Conv1d(512, 1, 4, bias=False, padding_mode='replicate')
        )

    def forward(self, signal_A, signal_B):
        # Concatenate signals and condition signals by channels to produce input
        signal_input = torch.cat((signal_A, signal_B), 1)
        return self.model(signal_input)


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm1d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# generator = GeneratorUNet().to(device)
# summary(generator, input_size=(1, 375))
# discriminator = Discriminator().to(device)
# summary(discriminator, input_size=[(1, 375), (1, 375)])
--------------------------------------------------------------------------------


/p2e_wgan_gp.py:
--------------------------------------------------------------------------------
import argparse
import time
import datetime
import os
import sys
import numpy as np
from tensorboardX import SummaryWriter
import torch
import torch.optim as optim
from models import weights_init_normal, GeneratorUNet, Discriminator
from data import get_data_loader
from utils import compute_gradient_penalty, smoother, sample_images, evaluate_generated_signal_quality

torch.backends.cudnn.benchmark = True


def str2bool(v):
    # argparse's type=bool treats any non-empty string (including "False")
    # as True, so boolean flags are parsed from truthy/falsy strings instead
    return str(v).lower() in ('true', '1', 'yes')


parser = argparse.ArgumentParser()
parser.add_argument("--experiment_name", type=str,
                    default="p2e_wgan_gp_mimic", help="name of the experiment")
parser.add_argument("--dataset_prefix", type=str,
                    default="data/mimic/", help="path to the train and valid dataset")
parser.add_argument("--epoch", type=int, default=0,
                    help="epoch to start training from")
parser.add_argument("--shuffle_training", type=str2bool,
                    default=True, help="shuffle training")
parser.add_argument("--shuffle_testing", type=str2bool,
                    default=False, help="shuffle testing")
parser.add_argument("--is_eval", type=str2bool,
                    default=False, help="evaluation mode")
parser.add_argument("--from_ppg", type=str2bool, default=True,
                    help="reconstruct from ppg")
parser.add_argument("--peaks_only", type=str2bool, default=False,
                    help="L2 loss on peaks only")
parser.add_argument("--n_epochs", type=int, default=10000,
                    help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=192,
                    help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002,
                    help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5,
                    help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999,
                    help="adam: decay of second order momentum of gradient")
parser.add_argument("--lambda_gp", type=float, default=10,
                    help="loss weight for gradient penalty")
parser.add_argument("--ncritic", type=int, default=3,
                    help="number of iterations of the critic per generator iteration")
parser.add_argument("--n_cpu", type=int, default=4,
                    help="number of cpu threads to use during batch generation")
parser.add_argument("--signal_length", type=int,
                    default=375, help="size of the signal")
parser.add_argument("--checkpoint_interval", type=int,
                    default=30, help="interval between model checkpoints")

args, unknown = parser.parse_known_args()
print(args)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cuda = torch.cuda.is_available()

# Weighting for L2 loss
if args.peaks_only:
    lambda_sample = 20
    rpeak_weight = 4
else:
    lambda_sample = 50

# Loss functions
if args.peaks_only:
    criterion_samplewise = torch.nn.MSELoss(reduction='sum')
else:
    criterion_samplewise = torch.nn.MSELoss()

# Output size of the discriminator (PatchGAN)
patch = (1, 9)

# Load data
dataloader, val_dataloader = get_data_loader(args.dataset_prefix, args.batch_size, from_ppg=args.from_ppg,
                                             shuffle_training=args.shuffle_training,
                                             shuffle_testing=args.shuffle_testing)

# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()

if cuda:
    if torch.cuda.device_count() > 2:
        generator = torch.nn.DataParallel(
            generator, device_ids=[0, 1, 2]).to(device)
        discriminator = torch.nn.DataParallel(
            discriminator, device_ids=[0, 1, 2]).to(device)
    else:
        generator = generator.to(device)
        discriminator = discriminator.to(device)

    criterion_samplewise.to(device)

# Optimizers
optimizer_G = torch.optim.Adam(
    generator.parameters(), lr=args.lr, betas=(args.b1, args.b2))
optimizer_D = torch.optim.Adam(
    discriminator.parameters(), lr=args.lr, betas=(args.b1, args.b2))

if args.epoch != 0:
    # Load pretrained models
    pretrained_path = "saved_models/%s/multi_models_%d.pth" % (
        args.experiment_name, args.epoch)
    checkpoint = torch.load(pretrained_path)
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    optimizer_G.load_state_dict(checkpoint['optimizer_G_state_dict'])
    optimizer_D.load_state_dict(checkpoint['optimizer_D_state_dict'])

    if args.is_eval:
        sample_images(args.experiment_name, val_dataloader,
                      generator, args.epoch, device)
        evaluate_generated_signal_quality(
            val_dataloader, generator, None, args.epoch, device)
        sys.exit()
else:
    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

os.makedirs("saved_models/%s" % args.experiment_name, exist_ok=True)
os.makedirs("sample_signals/%s" % args.experiment_name, exist_ok=True)
os.makedirs("logs/%s" % args.experiment_name, exist_ok=True)
writer = SummaryWriter("logs/%s" % args.experiment_name)

# ----------
#  Training
# ----------

prev_time = time.time()
for epoch in range(args.epoch + 1, args.n_epochs):

    for i, batch in enumerate(dataloader):
        # Model inputs
        real_A = batch[0].to(device)
        real_B = batch[1].to(device)

        if i % args.ncritic == 0:

            # ------------------
            #  Train Generators
            # ------------------
            generator.train()
            for p in generator.parameters():
                p.grad = None

            # GAN loss
            fake_B = generator(real_A)

            # Sample-wise loss
            if args.from_ppg and args.peaks_only:
                opeaks = batch[2].to(device)
                rpeaks = batch[3].to(device)
                fake_B_masked_opeaks = fake_B * (opeaks != 0)
                fake_B_masked_rpeaks = fake_B * (rpeaks != 0)
                opeak_count = torch.sum(opeaks != 0)
                rpeak_count = torch.sum(rpeaks != 0)

                loss_sample_opeaks = criterion_samplewise(
                    fake_B_masked_opeaks, opeaks)
                loss_sample_rpeaks = criterion_samplewise(
                    fake_B_masked_rpeaks, rpeaks)
                loss_sample = loss_sample_opeaks / opeak_count + \
                    rpeak_weight * loss_sample_rpeaks / rpeak_count
            else:
                loss_sample = criterion_samplewise(fake_B, real_B)

            # Smooth the output with moving averages
            if args.from_ppg:
                fake_B = smoother(fake_B, device)

            pred_fake = discriminator(fake_B, real_A)
            loss_GAN = -torch.mean(pred_fake)

            # Total loss
            loss_G = loss_GAN + lambda_sample * loss_sample

            loss_G.backward()
            optimizer_G.step()

        else:
            fake_B = generator(real_A)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Zero discriminator gradients
        for p in discriminator.parameters():
            p.grad = None
        # Real signals
        real_validity = discriminator(real_B, real_A)
        # Fake signals
        fake_validity = discriminator(fake_B.detach(), real_A)
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(
            discriminator, real_B, fake_B.detach(), real_A, patch, device)
        # Adversarial loss
        loss_D0 = -torch.mean(real_validity) + \
            torch.mean(fake_validity) + args.lambda_gp * gradient_penalty
        loss_D = loss_D0

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        # Determine approximate time left
        batches_done = epoch * len(dataloader) + i
        batches_left = args.n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(
            seconds=batches_left * (time.time() - prev_time))
        prev_time = time.time()

        # Print log
        sys.stdout.write(
            "\r[Epoch %d/%d] [Batch %d/%d] [D0 loss: %f] [D loss: %f] [G loss: %f, sample: %f, adv: %f] ETA: %s"
            % (epoch, args.n_epochs, i, len(dataloader),
               loss_D0.item(), loss_D.item(), loss_G.item(), loss_sample.item(), loss_GAN.item(),
               time_left)
        )

        writer.add_scalars('losses', {'g_loss': loss_G.item()}, batches_done)
        writer.add_scalars('losses', {'d_loss': loss_D.item()}, batches_done)
        writer.add_scalars(
            'losses2', {'d_loss0': loss_D0.item()}, batches_done)
        writer.add_scalars(
            'losses2', {'gan_loss': loss_GAN.item()}, batches_done)
        writer.add_scalars(
            'losses3', {'sample_loss': loss_sample.item()}, batches_done)

    if args.checkpoint_interval != -1 and epoch % args.checkpoint_interval == 0:
        # Save model checkpoints
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'optimizer_G_state_dict': optimizer_G.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
        }, "saved_models/%s/multi_models_%d.pth" % (args.experiment_name, epoch))
        sample_images(args.experiment_name, val_dataloader,
                      generator, epoch, device)
        evaluate_generated_signal_quality(
            val_dataloader, generator, writer, epoch, device)
--------------------------------------------------------------------------------