├── assets
│   └── table3.png
├── ward2icu
│   ├── models
│   │   ├── __init__.py
│   │   ├── classifiers.py
│   │   ├── rcgan.py
│   │   ├── cnngan.py
│   │   └── rgan.py
│   ├── __init__.py
│   ├── utils.py
│   ├── samplers.py
│   ├── metrics.py
│   ├── layers.py
│   └── trainers.py
├── Makefile
├── tests
│   ├── test_utils.py
│   ├── test_samplers.py
│   ├── test_data.py
│   ├── test_models.py
│   └── test_trainers.py
├── .gitignore
├── README.md
└── run-experiment.py
/assets/table3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/3778/Ward2ICU/HEAD/assets/table3.png
--------------------------------------------------------------------------------
/ward2icu/models/__init__.py:
--------------------------------------------------------------------------------
from .classifiers import *
from .cnngan import *
from .rgan import *
from .rcgan import *
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: README.md

bin/gh-md-toc:
	mkdir -p bin
	wget https://raw.githubusercontent.com/ekalinin/github-markdown-toc/master/gh-md-toc
	chmod a+x gh-md-toc
	mv gh-md-toc bin/

README.md: bin/gh-md-toc
	./bin/gh-md-toc --insert README.md
	rm -f README.md.orig.* README.md.toc.*
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
from torch import Tensor
from ward2icu.utils import tile


def test_tile():
    y = Tensor([0, 1, 2])
    y_tiled = tile(y, 3)
    expected = Tensor([[0, 0, 0],
                       [1, 1, 1],
                       [2, 2, 2]])
    assert (y_tiled == expected).all()
--------------------------------------------------------------------------------
/ward2icu/__init__.py:
--------------------------------------------------------------------------------
import logging
import os
import numpy as np
import torch
from pathlib import Path


def get_project_root() -> Path:
    """Returns the project root folder."""
    return Path(__file__).parent.parent


def get_data_dir() -> Path:
    return get_project_root() / 'data'


def make_logger(file_: str = "NO_FILE") -> logging.Logger:
    log_level = getattr(logging, os.getenv("LOG_LEVEL", "INFO"))
    fmt = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
    logging.basicConfig(level=log_level, format=fmt)
    return logging.getLogger(file_.split("/")[-1])


def set_seeds():
    np.random.seed(3778)
    torch.manual_seed(3778)
--------------------------------------------------------------------------------
/ward2icu/utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from math import floor
from ward2icu import make_logger

logger = make_logger(__file__)


def tile(t, length):
    '''Creates an extra dimension on the tensor t and
    repeats it throughout.'''
    return t.view(-1, 1).repeat(1, length)


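# Standard PyTorch formula for the output length of a 1D convolution:
#   L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1)/stride + 1)
# Conv1d stores these hyperparameters as 1-tuples while pooling layers store
# plain ints, which is what _maybe_slice below accounts for.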
def calc_conv_output_length(conv_layer,
                            input_length):
    def _maybe_slice(x):
        return x[0] if isinstance(x, tuple) else x

    l = input_length
    p = _maybe_slice(conv_layer.padding)
    d = _maybe_slice(conv_layer.dilation)
    k = _maybe_slice(conv_layer.kernel_size)
    s = _maybe_slice(conv_layer.stride)
    return floor((l + 2*p - d*(k-1) - 1)/s + 1)


def flatten(l):
    return [item for sublist in l for item in sublist]


def numpy_to_cuda(*args):
    return [torch.from_numpy(a).cuda() for a in args]


def n_(tensor):
    return tensor.detach().cpu().numpy()
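
# NOTE: metrics.py imports train_test_split_tensor from this module (and
# run-experiment.py imports synthesis_df), but the definition is missing from
# this snapshot. Below is a minimal sketch, assuming only its call site in
# metrics.classify: a random split returning (X_train, y_train, X_test,
# y_test) for a given test fraction.
def train_test_split_tensor(X, y, test_frac):
    idx = np.random.permutation(len(y))
    n_test = int(test_frac*len(y))
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]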
--------------------------------------------------------------------------------
/tests/test_samplers.py:
--------------------------------------------------------------------------------
import torch
from ward2icu.samplers import BinaryBalancedSampler, IdentitySampler


def test_BinaryBalancedSampler():
    X = torch.eye(5)
    y = torch.Tensor([0, 0, 0, 1, 1])
    sampler = BinaryBalancedSampler(X, y)
    for _ in range(100):
        X_s, y_s = sampler.sample()

        assert (y_s == torch.Tensor([0, 0, 1, 1])).all()
        assert X_s.shape == (4, 5)
        assert any((X_s[0] == X[i]).all() for i in [0, 1, 2, 3])
        assert (X_s[-1] == X[-1]).all()

def test_BinaryBalancedSampler_batch():
    X = torch.eye(5)
    y = torch.Tensor([0, 0, 0, 1, 1])
    sampler = BinaryBalancedSampler(X, y, batch_size=2)
    for _ in range(100):
        X_s, y_s = sampler.sample()

        assert (y_s == torch.Tensor([0, 1])).all()
        assert X_s.shape == (2, 5)
        assert any((X_s[0] == X[i]).all() for i in [0, 1, 2, 3])
        assert any((X_s[-1] == X[i]).all() for i in [-1, -2])

def test_IdentitySampler():
    X = torch.ones(4, 2, 6)
    y = torch.Tensor([0, 0, 1, 1])
    sampler = IdentitySampler(X, y, tile=True)
    X_s, y_s = sampler.sample()

    assert X_s.shape == (4, 2, 6)
    assert y_s.shape == (4, 2)
    assert (y_s == torch.Tensor([[0, 0],
                                 [0, 0],
                                 [1, 1],
                                 [1, 1]])).all()
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# others
bin/
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
import numpy as np
import pytest
from ward2icu.data import TimeSeriesVitalSigns


dataset = TimeSeriesVitalSigns()

def test_TimeSeriesVitalSigns_inv_transforms():
    d = dataset
    X = d.X.astype(np.float128)
    y = d.y
    s = d.synthesis_df(X, y)
    assert np.isclose(d.inv_whiten(d.whiten(X)), X).all()
    assert np.isclose(d.inv_normalize(d.normalize(X)), X).all()
    assert np.isclose(d.inv_minmax(d.minmax(X)), X).all()
    assert np.isclose(d.inv_minmax_signals(d.minmax_signals(X)), X).all()

    columns = ['cat_vital_sign', 't', 'class']
    assert (s[columns] == d.df[columns]).all().all()
    assert np.isclose(s['value'], d.df['value']).all()

def test_TimeSeriesVitalSigns_transform_minmax():
    X = np.array([
        # patient 1
        [[36, 90], [40, 100], [40, 100]],
        # patient 2
        [[40, 80], [36, 90], [36, 90]],
        # patient 3
        [[40, 80], [36, 90], [36, 90]]
    ])

    e = np.array([
        # patient 1
        [[-1, 1], [1, 1], [1, 1]],
        # patient 2
        [[1, -1], [-1, -1], [-1, -1]],
        # patient 3
        [[1, -1], [-1, -1], [-1, -1]]
    ])

    assert (dataset.minmax(X, X) == e).all()
    assert X.shape == (3, 3, 2)

def test_TimeSeriesVitalSigns_transform_minmax_signals():
    X = np.array([
        # patient 1
        [[36, 90], [40, 100], [40, 100]],
        # patient 2
        [[40, 80], [36, 90], [36, 90]],
        # patient 3
        [[40, 80], [36, 90], [36, 90]],
        # patient 4
        [[40, 80], [36, 90], [36, 90]],
    ])

    # signal 1 = [36, 40, 40, 36] -> [-1, 1, 1, -1]
    # signal 2 = [90, 100, 80, 90] -> [ 0, 1, -1, 0]

    e = np.array([
        # patient 1
        [[-1, 0], [1, 1], [1, 1]],
        # patient 2
        [[1, -1], [-1, 0], [-1, 0]],
        # patient 3
        [[1, -1], [-1, 0], [-1, 0]],
        # patient 4
        [[1, -1], [-1, 0], [-1, 0]]
    ])

    assert (dataset.minmax_signals(X, X) == e).all()
    assert X.shape == (4, 3, 2)
    assert e.shape == (4, 3, 2)
--------------------------------------------------------------------------------
/ward2icu/samplers.py:
--------------------------------------------------------------------------------
import torch
from ward2icu.utils import tile as tile_func


class IdentitySampler:
    def __init__(self, X, y, tile=False):
        self.X = X
        self.y = tile_func(y, X.shape[1]) if tile else y
        self.tile = tile
        self.device = X.device

    def sample(self):
        return self.X, self.y

    def __len__(self):
        return self.y.shape[0]

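
# Draws class-balanced batches for binary targets: each call to sample()
# pairs batch_size/2 randomly chosen majority-class rows with batch_size/2
# minority-class rows (or the entire minority set when batch_size is None),
# optionally tiling labels across the sequence dimension.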
class BinaryBalancedSampler:
    def __init__(self, X, y, tile=False, batch_size=None):
        assert len(y.unique()) <= 2
        self.sequence_length = X.shape[1]
        self.batch_size = batch_size
        self.tile = tile
        self.device = X.device

        self.majority_class = y.mode().values
        self.minority_class = y[y != self.majority_class].mode().values

        self.majority_mask = (y == self.majority_class)
        self.minority_mask = (y == self.minority_class)

        self.majority_count = int(self.majority_mask.sum())
        self.minority_count = int(self.minority_mask.sum())

        if batch_size:
            assert batch_size/2 < self.minority_count

        self.X_majority = X[self.majority_mask]
        self.X_minority = X[self.minority_mask]

        self.y_majority = y[self.majority_mask]
        self.y_minority = y[self.minority_mask]
        self.should_sample_minority_class = self.batch_size is not None

    def sample(self):
        batch_size_half = (int(self.batch_size/2)
                           if self.batch_size is not None
                           else self.minority_count)

        idx_maj, idx_min = self._create_idxs(batch_size_half)

        X_majority_batch = self.X_majority[idx_maj]
        y_majority_batch = self.y_majority[idx_maj]

        X_minority_batch = (self.X_minority[idx_min]
                            if self.should_sample_minority_class
                            else self.X_minority)
        y_minority_batch = (self.y_minority[idx_min]
                            if self.should_sample_minority_class
                            else self.y_minority)

        X = torch.cat((X_majority_batch, X_minority_batch), dim=0)
        y = torch.cat((y_majority_batch, y_minority_batch), dim=0)

        y = tile_func(y, self.sequence_length) if self.tile else y
        return X, y

    def __len__(self):
        return int(self.batch_size or 2*self.minority_count)

    def _create_idxs(self, batch_size):
        idx_maj = torch.randperm(self.majority_count)[:batch_size]
        idx_min = (torch.randperm(self.minority_count)[:batch_size]
                   if self.should_sample_minority_class else None)
        return idx_maj, idx_min
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
---


# Ward2ICU


[![Paper](http://img.shields.io/badge/paper-arxiv.1910.00752-B31B1B.svg)](https://arxiv.org/abs/1910.00752)
[![3778 Research](http://img.shields.io/badge/3778-Research-4b44ce.svg)](https://research.3778.care/projects/privacy/)
[![3778 Survey](http://img.shields.io/badge/3778-Survey-4b44ce.svg)](https://forms.gle/e2asYSVaiuPUUCKu8)


* [Description](#description)
* [Models](#models)
  * [1D Conditional CNN GAN](#1d-conditional-cnn-gan)
  * [Recursive GAN (RGAN)](#recursive-gan-rgan)
  * [Recursive Conditional GAN (RCGAN)](#recursive-conditional-gan-rcgan)
  * [RNN Classifier](#rnn-classifier)
  * [1D-CNN Classifier](#1d-cnn-classifier)
* [Citation](#citation)


## Description
Ward2ICU: A Vital Signs Dataset of Inpatients from the General Ward

## Models

### 1D Conditional CNN GAN
[![Source code](https://img.shields.io/badge/code-PyTorch-009900.svg)](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/cnngan.py)

![Table 3](assets/table3.png)

### Recursive GAN (RGAN)
[![Source code](https://img.shields.io/badge/code-PyTorch-009900.svg)](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/rgan.py)
[![Paper](http://img.shields.io/badge/paper-arxiv.1706.02633-B31B1B.svg)](https://arxiv.org/abs/1706.02633)

Recursive GAN implementation (generator and discriminator) with RNN cells.

### Recursive Conditional GAN (RCGAN)
[![Source code](https://img.shields.io/badge/code-PyTorch-009900.svg)](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/rcgan.py)
[![Paper](http://img.shields.io/badge/paper-arxiv.1706.02633-B31B1B.svg)](https://arxiv.org/abs/1706.02633)

Recursive Conditional GAN implementation (generator and discriminator) with RNN cells.

### RNN Classifier
[![Source code](https://img.shields.io/badge/code-PyTorch-009900.svg)](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/classifiers.py)

A simple RNN for classification tasks. It consists of a recurrent layer (Elman RNN, LSTM or GRU) followed by a fully connected layer whose parameters are shared across the time domain (i.e. the second tensor dimension), producing one logit per time step; at evaluation time the per-step predictions are collapsed into a single label by majority vote.
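
A minimal usage sketch (shapes follow `tests/test_models.py`; the sizes here are illustrative):

```python
import torch
from ward2icu.models import BinaryRNNClassifier

model = BinaryRNNClassifier(sequence_length=10, input_size=2, hidden_size=20)
x = torch.randn(8, 10, 2)      # (batch, sequence_length, input_size)
logits = model(x)              # (batch, sequence_length): one logit per time step
probs = torch.sigmoid(logits)  # training uses BCEWithLogitsLoss, so the model emits raw logits
```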

### 1D-CNN Classifier
[![Source code](https://img.shields.io/badge/code-PyTorch-009900.svg)](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/classifiers.py)

Single-dimension convolutional network for classification. It consists of a sequence of `Conv1d` layers followed by `MaxPool1d` and two `Linear` layers with a `Sigmoid` output.
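
Again as a sketch (constructor arguments as exercised by the tests; values illustrative):

```python
import torch
from ward2icu.models import BinaryCNNClassifier

model = BinaryCNNClassifier(input_size=2, input_length=10, kernel_size=3)
x = torch.randn(8, 10, 2)  # channel-last input, permuted internally to (batch, channels, length)
p = model(x)               # (batch,): probability from the final Sigmoid
```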

## Citation
```
@article{severo2019ward2icu,
  title={Ward2ICU: A Vital Signs Dataset of Inpatients from the General Ward},
  author={Severo, Daniel and Amaro, Fl{\'a}vio and Hruschka Jr, Estevam R and Costa, Andr{\'e} Soares de Moura},
  journal={arXiv preprint arXiv:1910.00752},
  year={2019}
}
```
--------------------------------------------------------------------------------
/ward2icu/metrics.py:
--------------------------------------------------------------------------------
from ward2icu.samplers import BinaryBalancedSampler, IdentitySampler
from ward2icu.models import BinaryRNNClassifier
from torch import optim
from ward2icu.utils import (train_test_split_tensor,
                            numpy_to_cuda, tile)
from ward2icu.trainers import BinaryClassificationTrainer
from ward2icu import make_logger

logger = make_logger(__file__)


def mean_feature_error(X_real, X_synth):
    return (1 - X_synth.mean(axis=0)/X_real.mean(axis=0)).mean(axis=0).mean()


# TODO(dsevero) in theory, we should separate the
# training and testing datasets.
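# TSTR = "Train on Synthetic, Test on Real" (Esteban et al., arXiv:1706.02633):
# a classifier is trained only on generator output and evaluated on the real
# data, so its test score measures how much task-relevant structure the
# synthetic series preserve.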
def tstr(X_synth, y_synth, X_real, y_real, epochs=3_000, batch_size=None):

    logger.info('Running TSTR')
    logger.info(f'Synthetic size: {len(y_synth)}')
    logger.info(f'Real size: {len(y_real)}')
    logger.info(f'Class distribution: '
                f'[real {y_real.float().mean()}]'
                f'[synthetic {y_synth.mean()}]')

    sequence_length = X_real.shape[1]
    sequence_size = X_real.shape[2]
    X_train, y_train = X_synth, y_synth
    X_test, y_test = X_real, y_real

    X_train, y_train = numpy_to_cuda(X_train, y_train)
    y_test, y_train = y_test.float(), y_train.float()

    sampler_train = BinaryBalancedSampler(X_train, y_train,
                                          tile=True, batch_size=batch_size)
    sampler_test = IdentitySampler(X_test, y_test, tile=True)

    model = BinaryRNNClassifier(sequence_length=sequence_length,
                                input_size=sequence_size,
                                dropout=0.5,
                                hidden_size=100).cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=0.001)
    trainer = BinaryClassificationTrainer(model,
                                          optimizer,
                                          sampler_train,
                                          sampler_test,
                                          metrics_prepend='tstr_')
    trainer.train(epochs)
    return trainer


def classify(X, y, epochs=3_000, batch_size=None):
    sequence_length = X.shape[1]
    sequence_size = X.shape[2]

    (X_train, y_train,
     X_test, y_test) = train_test_split_tensor(X, y, 0.3)
    (X_train, y_train,
     X_test, y_test) = numpy_to_cuda(X_train, y_train,
                                     X_test, y_test)
    y_test, y_train = y_test.float(), y_train.float()

    sampler_train = BinaryBalancedSampler(X_train, y_train,
                                          tile=True, batch_size=batch_size)
    sampler_test = IdentitySampler(X_test, y_test, tile=True)

    model = BinaryRNNClassifier(sequence_length=sequence_length,
                                input_size=sequence_size,
                                dropout=0.8,
                                hidden_size=100).cuda()
    optimizer = optim.Adam(model.parameters(),
                           lr=0.001)
    trainer = BinaryClassificationTrainer(model,
                                          optimizer,
                                          sampler_train,
                                          sampler_test)
    logger.info(model)
    logger.info(f'Test size: {len(y_test)}')
    logger.info(f'Train size: {len(y_train)}')

    trainer.train(epochs)
    return trainer
--------------------------------------------------------------------------------
/ward2icu/layers.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from ward2icu.utils import calc_conv_output_length, tile


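# Factory for a batch-first recurrent layer: `rnn_type` selects nn.RNN,
# nn.LSTM or nn.GRU; `nonlinearity` only applies to the plain Elman RNN, and
# hidden_size defaults to input_size when not given.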
def rnn_layer(input_size,
              hidden_size=None,
              num_layers=1,
              dropout=0.5,
              rnn_type='rnn',
              nonlinearity='relu'):

    # Set hidden_size to input_size if not specified
    hidden_size = hidden_size or input_size

    rnn_types = {
        'rnn': nn.RNN,
        'lstm': nn.LSTM,
        'gru': nn.GRU
    }

    rnn_kwargs = dict(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        batch_first=True,
        dropout=dropout
    )

    if rnn_type == 'rnn':
        rnn_kwargs['nonlinearity'] = nonlinearity

    return rnn_types[rnn_type](**rnn_kwargs)


class Conv1dLayers(nn.Module):
    def __init__(self,
                 input_size,
                 input_length,
                 kernel_size,
                 dropout_prob=0.5,
                 n_layers=1,
                 step_up=1,
                 **kwargs):
        self.input_size = input_size
        self.input_length = input_length
        self.kernel_size = kernel_size
        self.dropout_prob = dropout_prob
        self.n_layers = n_layers
        self.step_up = step_up
        super(Conv1dLayers, self).__init__()

        self.output_sizes = list()
        self.output_lengths = list()

        in_channels = input_size
        out_channels = in_channels
        output_length = input_length
        layers = list()
        for depth in range(n_layers):
            layers += [nn.Conv1d(in_channels=in_channels,
                                 out_channels=out_channels,
                                 kernel_size=kernel_size,
                                 **kwargs),
                       nn.Dropout(dropout_prob),
                       nn.LeakyReLU(0.2)]
            output_length = calc_conv_output_length(layers[-3],
                                                    output_length)
            self.output_sizes += [out_channels]
            self.output_lengths += [output_length]

            # next layer
            in_channels = out_channels
            out_channels = step_up*in_channels

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class Permutation(nn.Module):
    def __init__(self, *dims):
        self.dims = dims
        super(Permutation, self).__init__()

    def forward(self, x):
        return x.permute(*self.dims)


class View(nn.Module):
    def __init__(self, *dims):
        self.dims = dims
        super(View, self).__init__()

    def forward(self, x):
        return x.view(*self.dims)


class AppendEmbedding(nn.Module):
    def __init__(self, embedding_layer, dim=-1, tile=True):

        super(AppendEmbedding, self).__init__()
        self.embedding_layer = embedding_layer
        self.dim = dim
        self.tile = tile

    def forward(self, x):
        y = self.labels_pointer
        if self.tile:
            y = tile(y, x.size(1))
        y_emb = self.embedding_layer(y.type(torch.LongTensor).to(y.device))
        return torch.cat((x, y_emb), dim=self.dim)
--------------------------------------------------------------------------------
/ward2icu/models/classifiers.py:
--------------------------------------------------------------------------------
'''
Reference: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from ward2icu.layers import rnn_layer, Conv1dLayers
from ward2icu.utils import calc_conv_output_length, flatten


class BinaryRNNClassifier(nn.Module):
    def __init__(self, sequence_length, **kwargs):
        """Recursive NN for binary classification with linear
        time-collapsing layers.

        Layers:
            RNN (with activation, multiple layers):
                input: (batch_size, sequence_length, input_size)
                output: (batch_size, sequence_length, hidden_size)

            Linear (no activation):
                input: (batch_size, sequence_length, hidden_size)
                output: (batch_size, sequence_length, 1)

        Notes:
            This model builds upon ward2icu.layers.rnn_layer. See its docs
            for more information.

        Args:
            sequence_length (int): Number of points in the sequence.
            kwargs: Keyword arguments passed on to ward2icu.layers.rnn_layer.
        """
        super(BinaryRNNClassifier, self).__init__()

        # Default hidden_size to input_size.
        hidden_size = kwargs.get('hidden_size', kwargs['input_size'])

        # See ward2icu.layers.rnn_layer for more information.
        self.rnn = rnn_layer(**kwargs)

        # Parameters are shared across time steps.
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        x, _ = self.rnn(x)
        return self.linear(x).squeeze()


class BinaryCNNClassifier(nn.Module):
    def __init__(self,
                 input_size,
                 input_length,
                 kernel_size,
                 n_layers=1,
                 step_up=1,
                 dropout_prob=0.5,
                 channel_last=True,
                 pool_size=None,
                 **kwargs):
        """1D CNN for binary classification."""
        # defaults
        pool_size = pool_size or kernel_size

        self.input_length = input_length
        self.input_size = input_size
        self.kernel_size = kernel_size
        self.channel_last = channel_last
        self.step_up = step_up
        self.pool_size = pool_size
        self.n_layers = n_layers
        super(BinaryCNNClassifier, self).__init__()

        conv = Conv1dLayers(input_size,
                            input_length,
                            kernel_size,
                            n_layers=n_layers,
                            step_up=step_up,
                            dropout_prob=dropout_prob,
                            **kwargs)

        maxpool = nn.MaxPool1d(pool_size)
        flatten = nn.Flatten()
        flatten_output_size = (
            conv.output_sizes[-1]*
            calc_conv_output_length(maxpool,
                                    conv.output_lengths[-1])
        )

        self.layers = nn.Sequential(conv,
                                    maxpool,
                                    flatten,
                                    nn.Linear(flatten_output_size, flatten_output_size),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(flatten_output_size, 1),
                                    nn.Sigmoid())

    def forward(self, x):
        if self.channel_last:
            x = x.permute(0, 2, 1)
        return self.layers(x).squeeze()
--------------------------------------------------------------------------------
/ward2icu/trainers.py:
--------------------------------------------------------------------------------
"""
References:
    - https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py
"""
import mlflow
import torch
import numpy as np
import torchgan
from torch.nn import BCELoss, BCEWithLogitsLoss
from ward2icu import make_logger
from sklearn.metrics import balanced_accuracy_score, matthews_corrcoef

logger = make_logger(__file__)


class BinaryClassificationTrainer(object):
    def __init__(self,
                 model,
                 optimizer,
                 sampler_train=None,
                 sampler_test=None,
                 log_to_mlflow=True,
                 loss_function=BCEWithLogitsLoss(),
                 metrics_prepend=''):
        self.optimizer = optimizer
        self.sampler_train = sampler_train
        self.sampler_test = sampler_test
        self.loss_function = loss_function
        self.model = model
        self.log_to_mlflow = log_to_mlflow
        self.tiled = sampler_train.tile
        self.metrics_prepend = metrics_prepend

    def train(self, epochs, evaluate_interval=100):
        for epoch in range(epochs):
            self.optimizer.zero_grad()
            X_train, y_train = self.sampler_train.sample()

            logits_train = self.model(X_train)
            loss = self.loss_function(logits_train, y_train)
            loss.backward()
            self.optimizer.step()

            if epoch % evaluate_interval == 0:
                with torch.no_grad():
                    X_test, y_test = self.sampler_test.sample()
                    logger.debug(f'[Train sizes {X_train.shape} {y_train.shape}]')
                    logger.debug(f'[Test sizes {X_test.shape} {y_test.shape}]')
                    metrics = self.evaluate(X_test, y_test,
                                            X_train, y_train)
                    msg = f'[epoch {epoch}]'
                    msg += ''.join(f'[{m} {np.round(v,4)}]'
                                   for m, v in metrics.items()
                                   if m.endswith('balanced_accuracy') or
                                      m.endswith('matthews'))
                    logger.info(msg)
                    if self.log_to_mlflow:
                        mlflow.log_metrics(metrics, step=epoch)

    def evaluate(self, X_test, y_test, X_train, y_train):
        def _calculate(X, y, name):
            logits = self.model(X)
            probs = torch.sigmoid(logits)
            y_pred = probs.round()
            y_true = y
            return self.calculate_metrics(y_true, y_pred, logits, probs, name)

        mp = self.metrics_prepend
        return {**_calculate(X_test, y_test, f'{mp}test'),
                **_calculate(X_train, y_train, f'{mp}train')}
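
    # When the samplers tile labels, y_true and y_pred arrive with shape
    # (batch, sequence_length); the per-step predictions are collapsed to one
    # label per series via mode() (a majority vote across time) before the
    # sklearn metrics are computed.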
    def calculate_metrics(self, y_true, y_pred, logits, probs, name=''):
        y_true_ = (y_true[:, 0] if self.tiled else y_true).cpu()
        y_pred_ = (y_pred.mode().values if self.tiled else y_pred).cpu()

        mask_0 = (y_true_ == 0)
        mask_1 = (y_true_ == 1)

        hits = (y_true_ == y_pred_).float()
        bas = balanced_accuracy_score(y_true_, y_pred_)
        matthews = matthews_corrcoef(y_true_, y_pred_)

        return {f'{name}_accuracy': hits.mean().item(),
                f'{name}_balanced_accuracy': bas,
                f'{name}_accuracy_0': hits[mask_0].mean().item(),
                f'{name}_accuracy_1': hits[mask_1].mean().item(),
                f'{name}_loss': self.loss_function(logits, y_true).item(),
                f'{name}_loss_0': self.loss_function(logits[mask_0],
                                                     y_true[mask_0]).item(),
                f'{name}_loss_1': self.loss_function(logits[mask_1],
                                                     y_true[mask_1]).item(),
                f'{name}_matthews': matthews}
--------------------------------------------------------------------------------
/run-experiment.py:
--------------------------------------------------------------------------------
import mlflow
import tempfile
import click
import torch
import torch.nn as nn
import numpy as np
from ward2icu.data import TimeSeriesVitalSigns
from ward2icu.logs import log_avg_loss, log_avg_grad, log_model, log_df
from ward2icu.models import CNNCGANGenerator, CNNCGANDiscriminator
from ward2icu.utils import synthesis_df, train_test_split_tensor, numpy_to_cuda, tile
from ward2icu.metrics import mean_feature_error, classify, tstr
from ward2icu.samplers import BinaryBalancedSampler, IdentitySampler
from ward2icu.trainers import BinaryClassificationTrainer, MinMaxBinaryCGANTrainer, SequenceTrainer
from torch import optim
from torch.utils.data import DataLoader
from torchgan import losses
from torchgan.metrics import ClassifierScore
from torchgan.trainer import Trainer
from slugify import slugify
from dmpy.datascience import DataPond
from ward2icu import make_logger

logger = make_logger(__file__)
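
# Example invocation (option values are illustrative only; --batch_size -1
# makes the sampler fall back to batch_size=None, i.e. the full balanced set):
#
#   python run-experiment.py --lr 0.0002 --epochs 3000 --ncritic 5 \
#       --batch_size -1 --dataset_transform minmax_signals --signals 5 \
#       --gen_dropout 0.5 --noise_size 20 --hidden_size 50 --flag dev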

@click.command()
@click.option("--lr", type=float)
@click.option("--epochs", type=int)
@click.option("--ncritic", type=int)
@click.option("--batch_size", type=int)
@click.option("--dataset_transform", type=str)
@click.option("--signals", type=int)
@click.option("--gen_dropout", type=float)
@click.option("--noise_size", type=int)
@click.option("--hidden_size", type=int)
@click.option("--flag", type=str)
def cli(**opt):
    main(**opt)

def main(**opt):
    logger.info(opt)
    batch_size = opt['batch_size'] if opt['batch_size'] != -1 else None

    dataset = TimeSeriesVitalSigns(transform=opt['dataset_transform'],
                                   vital_signs=opt['signals'])
    X = torch.from_numpy(dataset.X).cuda()
    y = torch.from_numpy(dataset.y).long().cuda()
    sampler = BinaryBalancedSampler(X, y, batch_size=batch_size)

    network = {
        'generator': {
            'name': CNNCGANGenerator,
            'args': {
                'output_size': opt['signals'],
                'dropout': opt['gen_dropout'],
                'noise_size': opt['noise_size'],
                'hidden_size': opt['hidden_size']
            },
            'optimizer': {
                'name': optim.RMSprop,
                'args': {
                    'lr': opt['lr']
                }
            }
        },
        'discriminator': {
            'name': CNNCGANDiscriminator,
            'args': {
                'input_size': opt['signals'],
                'hidden_size': opt['hidden_size']
            },
            'optimizer': {
                'name': optim.RMSprop,
                'args': {
                    'lr': opt['lr']
                }
            }
        }
    }

    wasserstein_losses = [losses.WassersteinGeneratorLoss(),
                          losses.WassersteinDiscriminatorLoss(),
                          losses.WassersteinGradientPenalty()]

    logger.info(network)

    trainer = SequenceTrainer(models=network,
                              recon=None,
                              ncritic=opt['ncritic'],
                              losses_list=wasserstein_losses,
                              epochs=opt['epochs'],
                              retain_checkpoints=1,
                              checkpoints=f"{MODEL_DIR}/",
                              mlflow_interval=50,
                              device=DEVICE)

    trainer(sampler=sampler)
    trainer.log_to_mlflow()
    logger.info(trainer.generator)
    logger.info(trainer.discriminator)

    df_synth, X_synth, y_synth = synthesis_df(trainer.generator, dataset)

    logger.info(df_synth.sample(10))
    logger.info(df_synth.groupby('cat_vital_sign')['value'].nunique()
                        .div(df_synth.groupby('cat_vital_sign').size()))
    X_real = X.detach().cpu().numpy()
    mfe = np.abs(mean_feature_error(X_real, X_synth))
    logger.info(f'Mean feature error: {mfe}')

    mlflow.set_tag('flag', opt['flag'])
    log_df(df_synth, 'synthetic/vital_signs')
    mlflow.log_metric('mean_feature_error', mfe)

    trainer_class = classify(X_synth, y_synth, epochs=2_000, batch_size=batch_size)
    trainer_tstr = tstr(X_synth, y_synth, X, y, epochs=3_000, batch_size=batch_size)
    log_model(trainer_class.model, 'models/classifier')
    log_model(trainer_tstr.model, 'models/tstr')

if __name__ == '__main__':
    with mlflow.start_run():
        with tempfile.TemporaryDirectory() as MODEL_DIR:
            if torch.cuda.is_available():
                DEVICE = torch.device("cuda")
                torch.backends.cudnn.deterministic = True
            else:
                DEVICE = torch.device("cpu")
            logger.info(f'Running on device {DEVICE}')
            cli()
--------------------------------------------------------------------------------
/ward2icu/models/rcgan.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from ward2icu.models import RGANGenerator, RGANDiscriminator
from ward2icu.utils import tile


class RCGANGenerator(RGANGenerator):
    def __init__(self,
                 sequence_length,
                 output_size,
                 num_classes,
                 noise_size,
                 prob_classes=None,
                 label_embedding_size=None,
                 **kwargs):
        """Recursive Conditional GAN (Generator) implementation with RNN cells.

        Notes:
            This model builds upon ward2icu.models.rgan. See the docs of the
            parent classes for more information.

        Args (check parent class for extra arguments):
            sequence_length (int): Number of points in the sequence.
                Defined by the real dataset.
            output_size (int): Size of output (usually the last tensor dimension).
                Defined by the real dataset.
            num_classes (int): Number of classes in the dataset.
            noise_size (int): Size of noise used to generate fake data.
            label_embedding_size (int, optional): Size of embedding dimensions.
                Defaults to num_classes.
        """
        # Defaults
        label_embedding_size = label_embedding_size or num_classes
        if prob_classes is None:
            prob_classes = torch.ones(num_classes)

        # TODO(dsevero): this could cause problems
        kwargs['input_size'] = label_embedding_size + noise_size
        kwargs['noise_size'] = noise_size
        kwargs['label_type'] = "generated"
        kwargs['sequence_length'] = sequence_length
        kwargs['output_size'] = output_size

        super(RCGANGenerator, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.label_embeddings = nn.Embedding(num_classes, label_embedding_size)
        self.label_embedding_size = label_embedding_size
        self.prob_classes = torch.Tensor(prob_classes)

        # Weights are already initialized in the parent class.

    def forward(self, z, y):
        # y must be tiled so that labels repeat across the sequence dimension:
        # y_tiled[:, i] == y_tiled[:, j] for all i and j
        # shape: (batch_size, sequence_length)
        y_tiled = tile(y, self.sequence_length)

        # shape: (batch_size, sequence_length, label_embedding_size)
        y_emb = self.label_embeddings(y_tiled.type(torch.LongTensor).to(y.device))

        # shape: (batch_size, sequence_length, noise_size)
        z = z.view(-1, self.sequence_length, self.noise_size)

        # shape: (batch_size, sequence_length, noise_size + label_embedding_size)
        z_cond = torch.cat((z, y_emb), dim=2)

        # shape: (batch_size, sequence_length, output_size)
        return super(RCGANGenerator, self).forward(z_cond, reshape=False)

    def sampler(self, sample_size, device='cpu'):
        return [
            torch.randn(sample_size, self.encoding_dims, device=device),
            torch.multinomial(self.prob_classes, sample_size,
                              replacement=True).to(device)
        ]


class RCGANDiscriminator(RGANDiscriminator):
    def __init__(self,
                 sequence_length,
                 num_classes,
                 input_size,
                 label_embedding_size=None,
                 **kwargs):
        """Recursive Conditional GAN (Discriminator) implementation with RNN cells.

        Notes:
            This model builds upon ward2icu.models.rgan. See the docs of the
            parent class for more information.

        Args (check parent class for extra arguments):
            sequence_length (int): Number of points in the sequence.
            num_classes (int): Number of classes in the dataset.
            input_size (int): Size of input (usually the last tensor dimension).
            label_embedding_size (int, optional): Size of embedding dimensions.
                Defaults to num_classes.
        """

        # Defaults
        label_embedding_size = label_embedding_size or num_classes
        kwargs['input_size'] = label_embedding_size + input_size
        kwargs['label_type'] = "required"
        kwargs['sequence_length'] = sequence_length

        super(RCGANDiscriminator, self).__init__(**kwargs)
        self.num_classes = num_classes
        self.label_embeddings = nn.Embedding(num_classes, label_embedding_size)
        self.label_embedding_size = label_embedding_size

        # Initialize all weights.
        self._weight_initializer()

    def forward(self, x, y):
        # y must be tiled so that labels repeat across the sequence dimension:
        # y_tiled[:, i] == y_tiled[:, j] for all i and j
        # shape: (batch_size, sequence_length)
        y_tiled = tile(y, self.sequence_length)

        # shape: (batch_size, sequence_length, label_embedding_size)
        y_emb = self.label_embeddings(y_tiled.type(torch.LongTensor).to(y.device))

        # shape: (batch_size, sequence_length, label_embedding_size + input_size)
        x_cond = torch.cat((x, y_emb), dim=2)
        return super(RCGANDiscriminator, self).forward(x_cond)
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import pytest
from torchgan.losses import MinimaxDiscriminatorLoss
from ward2icu.models import (RGANGenerator,
                             RGANDiscriminator,
                             RCGANGenerator,
                             RCGANDiscriminator,
                             BinaryRNNClassifier,
                             BinaryCNNClassifier,
                             CNNCGANGenerator,
                             CNNCGANDiscriminator,
                             FCCMLPGANGenerator,
                             FCCMLPGANDiscriminator)

np.random.seed(3778)
torch.manual_seed(3778)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
DEVICE = 'cpu'

@pytest.fixture
def RGAN_setup():
    batch_size = 3
    seq_length = 10

    # Generator
    gen_opt = dict(output_size=5,
                   num_layers=1,
                   dropout=0,
                   noise_size=100,
                   sequence_length=seq_length)

    # Discriminator
    dsc_opt = dict(input_size=gen_opt['output_size'],
                   sequence_length=seq_length,
                   hidden_size=None,
                   num_layers=1,
                   dropout=0)

    gen = RGANGenerator(**gen_opt)
    dsc = RGANDiscriminator(**dsc_opt)
    return gen, dsc, batch_size, seq_length


def test_RGAN_forward(RGAN_setup):
    gen, dsc, batch_size, seq_length = RGAN_setup
    z, = gen.sampler(batch_size, DEVICE)
    assert tuple(z.shape) == (batch_size, seq_length*gen.noise_size)

    x = gen.forward(z)
    assert tuple(x.shape) == (batch_size, seq_length, gen.output_size)

    y = dsc.forward(x)
    assert tuple(y.shape) == (batch_size, seq_length, 1)


def test_RCGAN():
    batch_size = 3
    seq_length = 20
    label_emb_size = None
    prob_classes = [0, 0, 0, 0, 1]
    num_classes = len(prob_classes)

    # Generator
    gen_opt = dict(num_classes=num_classes,
                   prob_classes=prob_classes,
                   label_embedding_size=label_emb_size,
                   output_size=5,
                   num_layers=1,
                   dropout=0,
                   noise_size=10,
                   hidden_size=100,
                   sequence_length=seq_length)

    # Discriminator
    dsc_opt = dict(input_size=gen_opt['output_size'],
                   sequence_length=seq_length,
                   num_classes=num_classes,
                   label_embedding_size=label_emb_size,
                   num_layers=1,
                   hidden_size=100,
                   dropout=0)

    gen = RCGANGenerator(**gen_opt)
    dsc = RCGANDiscriminator(**dsc_opt)

    z, y = gen.sampler(batch_size, DEVICE)
    assert tuple(z.shape) == (batch_size, seq_length*gen.noise_size)
    assert tuple(y.shape) == (batch_size,)

    x = gen.forward(z, y)
    assert tuple(x.shape) == (batch_size, seq_length, gen.output_size)

    p = dsc.forward(x, y)
    assert tuple(p.shape) == (batch_size, seq_length, 1)

    # assert prob_classes
    assert (y == 4).all()


def test_BinaryRNNClassifier():
    seq_length = 10
    batch_size = 8
    hidden_size = 20
    input_size = 2

    model = BinaryRNNClassifier(sequence_length=seq_length,
                                input_size=input_size,
                                hidden_size=hidden_size)

    x = torch.randn(batch_size, seq_length, input_size)
    p = model.forward(x)
    assert tuple(p.shape) == (batch_size, seq_length)


def test_BinaryCNNClassifier():
    input_size = 2
    input_length = 10
    batch_size = 8

    model = BinaryCNNClassifier(input_size,
                                input_length,
                                kernel_size=3)

    x = torch.randn(batch_size, input_length, input_size)
    assert tuple(x.shape) == (batch_size, input_length, input_size)

    p = model.forward(x)
    assert tuple(p.shape) == (batch_size,)


def test_FCCMLPGAN():
    sequence_length = 20
    sequence_size = 5
    hidden_size = 10
    prob_classes = [0, 0, 0, 0, 1]
    num_classes = len(prob_classes)
    batch_size = 8

    # Generator
    gen_opt = dict(sequence_length=sequence_length,
                   sequence_size=sequence_size,
                   num_classes=num_classes,
                   hidden_size=hidden_size,
                   prob_classes=prob_classes)

    # Discriminator
    dsc_opt = dict(sequence_length=sequence_length,
                   num_classes=num_classes,
                   sequence_size=gen_opt['sequence_size'],
                   hidden_size=hidden_size)

    gen = FCCMLPGANGenerator(**gen_opt)
    dsc = FCCMLPGANDiscriminator(**dsc_opt)

    z, y = gen.sampler(batch_size, DEVICE)
    assert tuple(z.shape) == (batch_size, sequence_length*gen.noise_size)
    assert tuple(y.shape) == (batch_size,)

    x = gen.forward(z, y)
    assert tuple(x.shape) == (batch_size, sequence_length, sequence_size)

    p = dsc.forward(x, y)
    assert tuple(p.shape) == (batch_size, sequence_length)

    # assert prob_classes
    assert (y == 4).all()
--------------------------------------------------------------------------------
/ward2icu/models/cnngan.py:
--------------------------------------------------------------------------------
'''
Reference: https://arxiv.org/abs/1806.01875
'''

import torch
import torch.nn as nn
from torch.nn import (Linear,
                      Conv1d,
                      MaxPool1d,
                      AvgPool1d,
                      Upsample,
                      ReplicationPad1d,
                      LeakyReLU,
                      Flatten,
                      Dropout)
from torchgan.models import Generator, Discriminator
from ward2icu.utils import tile
from ward2icu.layers import Conv1dLayers, View, AppendEmbedding
from ward2icu import set_seeds


class CNNCGANGenerator(Generator):
    def __init__(self,
                 output_size,
                 dropout=0,
                 noise_size=20,
                 channel_last=True,
                 hidden_size=50,
                 **kwargs):
        # defaults
        self.initial_length = 5
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.encoding_dims = noise_size
        self.num_classes = 2
        self.label_embedding_size = self.num_classes
        self.prob_classes = torch.ones(self.num_classes)
        self.dropout = dropout
        self.label_type = 'generated'
        self.channel_last = channel_last
        self._labels = None

        # Set kwargs (might override the attributes above)
        for key, value in kwargs.items():
            setattr(self, key, value)

        super(CNNCGANGenerator, self).__init__(self.encoding_dims,
                                               self.label_type)
        self.label_embeddings = nn.Embedding(self.num_classes,
                                             self.label_embedding_size)

        Conv1d_ = lambda k: Conv1d(self.hidden_size, self.hidden_size, k)

        ## Build CNN layers
        # (batch_size, channels, sequence_length)
        layers_input_size = self.encoding_dims*(1 + self.label_embedding_size)
        self.layers = nn.Sequential(
            Linear(layers_input_size, self.initial_length*self.hidden_size),
            View(-1, self.hidden_size, self.initial_length),
            LeakyReLU(0.2),
            Dropout(dropout),
            # output size: (-1, 50, 5)

            Upsample(scale_factor=2, mode='linear', align_corners=True),
            # output size: (-1, 50, 10)

            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            Dropout(dropout),
            # output size: (-1, 50, 10)

            Upsample(scale_factor=2, mode='linear', align_corners=True),
            # output size: (-1, 50, 20)

            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            Dropout(dropout),
            # output size: (-1, 50, 20)

            Conv1d(self.hidden_size, self.output_size, 1)
            # output size: (-1, 5, 20)
        )

        # Initialize all weights.
        self._weight_initializer()

    def forward(self, z, y):
        y_tiled = tile(y, z.size(-1))
        y_emb = self.label_embeddings(y_tiled
                                      .type(torch.LongTensor)
                                      .to(y.device))
        z = torch.cat((z.unsqueeze(-1), y_emb), dim=-1)
        z = z.flatten(1)
        x = self.layers(z)
        return x.permute(0, 2, 1) if self.channel_last else x

    def sampler(self, sample_size, device='cpu'):
        return [
            torch.randn(sample_size, self.encoding_dims, device=device),
            torch.multinomial(self.prob_classes, sample_size,
                              replacement=True).to(device)
        ]

class CNNCGANDiscriminator(Discriminator):
    def __init__(self,
                 input_size,
                 channel_last=True,
                 hidden_size=50,
                 **kwargs):

        # Defaults
        self.input_length = 20
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_classes = 2
        self.label_embedding_size = self.num_classes
        self.prob_classes = torch.ones(self.num_classes)
        self.dropout = 0.5
        self.label_type = 'required'
        self.channel_last = channel_last

        # Set kwargs (might override the attributes above)
        for key, value in kwargs.items():
            setattr(self, key, value)

        super(CNNCGANDiscriminator, self).__init__(self.input_size,
                                                   self.label_type)

        self.label_embeddings = nn.Embedding(self.num_classes,
                                             self.label_embedding_size)

        Conv1d_ = lambda k: Conv1d(hidden_size, hidden_size, k)
        layers_input_size = self.input_size + self.label_embedding_size
        layers_output_size = 5*hidden_size

        ## Build CNN layers
        self.layers = nn.Sequential(
            Conv1d(layers_input_size, hidden_size, 1),
            LeakyReLU(0.2),
            # output size: (-1, 50, 20)

            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            AvgPool1d(2, 2),
            # output size: (-1, 50, 10)

            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            Conv1d_(3),
            ReplicationPad1d(1),
            LeakyReLU(0.2),
            AvgPool1d(2, 2),
            # output size: (-1, 50, 5)

            Flatten(),
            Linear(layers_output_size, 1)
            # output size: (-1, 1)
        )

        # Initialize all weights.
        self._weight_initializer()

    def forward(self, x, y):
        if self.channel_last:
            x = x.permute(0, 2, 1)
        y_tiled = tile(y, x.size(-1))
        y_emb = self.label_embeddings(y_tiled
                                      .type(torch.LongTensor)
                                      .to(y.device))
        y_emb = y_emb.permute(0, 2, 1)

        x = torch.cat((x, y_emb), dim=1)
        x = self.layers(x)
        return x.squeeze()
--------------------------------------------------------------------------------
/ward2icu/models/rgan.py:
--------------------------------------------------------------------------------
'''
Reference: https://arxiv.org/abs/1706.02633
'''

import torch
import torch.nn as nn
from torchgan.models import Generator, Discriminator
from ward2icu.layers import rnn_layer


class RGANGenerator(Generator):
    def __init__(self,
                 sequence_length,
                 output_size,
                 hidden_size=None,
                 noise_size=None,
                 num_layers=1,
                 dropout=0,
                 rnn_nonlinearity='tanh',
                 rnn_type='rnn',
                 input_size=None,
                 last_layer=None,
                 **kwargs):
        """Recursive GAN (Generator) implementation with RNN cells.

        Layers:
            RNN (with activation, multiple layers):
                input: (batch_size, sequence_length, noise_size)
                output: (batch_size, sequence_length, hidden_size)

            Linear (no activation, weights shared between time steps):
                input: (batch_size, sequence_length, hidden_size)
                output: (batch_size, sequence_length, output_size)

            last_layer (optional):
                input: (batch_size, sequence_length, output_size)

        Args:
            sequence_length (int): Number of points in the sequence.
                Defined by the real dataset.
            output_size (int): Size of output (usually the last tensor dimension).
                Defined by the real dataset.
            hidden_size (int, optional): Size of RNN output.
                Defaults to output_size.
            noise_size (int, optional): Size of noise used to generate fake data.
                Defaults to output_size.
            num_layers (int, optional): Number of stacked RNNs.
            dropout (float, optional): Dropout probability for the RNN layers.
            rnn_nonlinearity (str, optional): Non-linearity of the RNN. Must be
                either 'tanh' or 'relu'. Only valid if rnn_type == 'rnn'.
            rnn_type (str, optional): Type of RNN layer. Valid values are 'lstm',
                'gru' and 'rnn', the latter being the default.
            input_size (int, optional): Input size of the RNN, defaults to noise_size.
            last_layer (Module, optional): Last layer of the generator.
        """

        # Defaults
        noise_size = noise_size or output_size
        input_size = input_size or noise_size
        hidden_size = hidden_size or output_size

        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.noise_size = noise_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.rnn_nonlinearity = rnn_nonlinearity
        self.rnn_type = rnn_type
        self.input_size = input_size
        self.label_type = "none"

        # Set kwargs (might override the attributes above)
        for key, value in kwargs.items():
            setattr(self, key, value)

        # Total size of z that will be sampled. Later, in the forward
        # method, we resize to (batch_size, sequence_length, noise_size).
        # TODO: Any resizing of z is valid as long as the total size
        #       remains sequence_length*noise_size. How does this affect
        #       the performance of the RNN?
        self.encoding_dims = sequence_length*noise_size

        super(RGANGenerator, self).__init__(self.encoding_dims,
                                            self.label_type)

        # Build RNN layer
        self.rnn = rnn_layer(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             dropout=dropout,
                             rnn_type=rnn_type,
                             nonlinearity=rnn_nonlinearity)
        self.dropout = nn.Dropout(dropout)
        # self.batchnorm = nn.BatchNorm1d(hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)
        self.last_layer = last_layer

        # Initialize all weights.
        # nn.init.xavier_normal_(self.rnn)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, z, reshape=True):
        if reshape:
            z = z.view(-1, self.sequence_length, self.noise_size)
        y, _ = self.rnn(z)
        y = self.dropout(y)
        # y = self.batchnorm(y.permute(0, 2, 1)).permute(0, 2, 1)
        y = self.linear(y)
        return y if self.last_layer is None else self.last_layer(y)

class RGANDiscriminator(Discriminator):
    def __init__(self,
                 sequence_length,
                 input_size,
                 hidden_size=None,
                 num_layers=1,
                 dropout=0,
                 rnn_nonlinearity='tanh',
                 rnn_type='rnn',
                 last_layer=None,
                 **kwargs):
        """Recursive GAN (Discriminator) implementation with RNN cells.

        Layers:
            RNN (with activation, multiple layers):
                input: (batch_size, sequence_length, input_size)
                output: (batch_size, sequence_length, hidden_size)

            Linear (no activation, weights shared between time steps):
                input: (batch_size, sequence_length, hidden_size)
                output: (batch_size, sequence_length, 1)

            last_layer (optional):
                input: (batch_size, sequence_length, 1)

        Args:
            sequence_length (int): Number of points in the sequence.
            input_size (int): Size of input (usually the last tensor dimension).
            hidden_size (int, optional): Size of hidden layers in the RNN.
                If None, defaults to input_size.
            num_layers (int, optional): Number of stacked RNNs.
            dropout (float, optional): Dropout probability for the RNN layers.
            rnn_nonlinearity (str, optional): Non-linearity of the RNN. Must be
                either 'tanh' or 'relu'. Only valid if rnn_type == 'rnn'.
            rnn_type (str, optional): Type of RNN layer. Valid values are 'lstm',
                'gru' and 'rnn', the latter being the default.
            last_layer (Module, optional): Last layer of the discriminator.
        """

        # TODO: Insert non-linearities between Linear layers.
        # TODO: Add BatchNorm and Dropout as in https://arxiv.org/abs/1905.05928v1

        # Set hidden_size to input_size if not specified
        hidden_size = hidden_size or input_size

        self.input_size = input_size
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.label_type = "none"

        # Set kwargs (might override the attributes above)
        for key, value in kwargs.items():
            setattr(self, key, value)

        super(RGANDiscriminator, self).__init__(self.input_size,
                                                self.label_type)

        # Build RNN layer
        self.rnn = rnn_layer(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             dropout=dropout,
                             rnn_type=rnn_type,
                             nonlinearity=rnn_nonlinearity)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(hidden_size, 1)
        self.last_layer = last_layer

        # Initialize all weights.
        # nn.init.xavier_normal_(self.rnn)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, x):
        y, _ = self.rnn(x)
        y = self.dropout(y)
        y = self.linear(y)
        return y if self.last_layer is None else self.last_layer(y)
--------------------------------------------------------------------------------
/tests/test_trainers.py:
--------------------------------------------------------------------------------
import pytest
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader, Dataset
from ward2icu.trainers import (BinaryClassificationTrainer,
                               MinMaxBinaryCGANTrainer)
from ward2icu.samplers import IdentitySampler
from ward2icu.models import RCGANGenerator, RCGANDiscriminator
from itertools import chain

@pytest.fixture
def binaryclass_trainer():
    batch_size = 8
    input_size = 3
    num_classes = 2

    X_train = torch.randn(batch_size, input_size, device='cpu').float()
    y_train = torch.randint(0, num_classes, (batch_size,), device='cpu').float()

    X_test = torch.randn(batch_size, input_size, device='cpu').float()
    y_test = torch.randint(0, num_classes, (batch_size,), device='cpu').float()

    sampler_train = IdentitySampler(X_train, y_train)
    sampler_test = IdentitySampler(X_test, y_test)

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(input_size, 1)
        def forward(self, x):
            return self.linear(x).squeeze()

    model = Model().cpu()
    optimizer = SGD(model.parameters(), lr=100)
    binaryclass_trainer = BinaryClassificationTrainer(optimizer=optimizer,
                                                      sampler_train=sampler_train,
                                                      sampler_test=sampler_test,
                                                      model=model)
    return binaryclass_trainer

@pytest.fixture
def binaryclass_tiled_trainer():
    batch_size = 8
    input_size = 3
    num_classes = 2

    X_train = torch.randn(batch_size, input_size, device='cpu').float()
    y_train = torch.randint(0, num_classes, (batch_size,), device='cpu').float()

    X_test = torch.randn(batch_size, input_size, device='cpu').float()
    y_test = torch.randint(0, num_classes, (batch_size,), device='cpu').float()

    sampler_train = IdentitySampler(X_train, y_train, tile=True)
    sampler_test = IdentitySampler(X_test, y_test, tile=True)

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(input_size, 1)
        def forward(self, x):
            return self.linear(x).squeeze()

    model = Model().cpu()
    optimizer = SGD(model.parameters(), lr=100)
    binaryclass_trainer_tiled = BinaryClassificationTrainer(optimizer=optimizer,
                                                            sampler_train=sampler_train,
                                                            sampler_test=sampler_test,
                                                            model=model)
    return binaryclass_trainer_tiled


@pytest.fixture
def minmaxgan_trainer():
    data_size = 10
    input_length = 20
    num_classes = 2
    noise_size = 10
    output_size = 5
    input_size = output_size
    encoding_dims = input_length*input_size

    X = torch.randn(data_size, input_length, output_size)
    y = torch.randint(0, num_classes, (data_size,))
    sampler = IdentitySampler(X, y)

    generator = RCGANGenerator(input_length, output_size,
                               num_classes, noise_size)
    discriminator = RCGANDiscriminator(input_length, num_classes, input_size)
    optimizer_gen = SGD(generator.parameters(), lr=0.1)
    optimizer_dsc = SGD(discriminator.parameters(), lr=0.1)
    trainer = MinMaxBinaryCGANTrainer(generator,
                                      discriminator,
                                      optimizer_gen,
                                      optimizer_dsc,
                                      sampler)
    return trainer

def test_MinMaxGANTrainer_train(minmaxgan_trainer):
    def parameters():
        yield from minmaxgan_trainer.generator.parameters()
        yield from minmaxgan_trainer.discriminator.parameters()

    old_params = [p.detach().clone() for p in parameters()]
    minmaxgan_trainer.train(epochs=1)
    new_params = [p.detach().clone() for p in parameters()]

    # assert parameters are being trained
    assert all([(old != new).all()
                for old, new in zip(old_params, new_params)])


def test_MinMaxGANTrainer_calc_dsc_loss(minmaxgan_trainer):
    dsc_logits_real = torch.Tensor([[[0],
                                     [0],
                                     [0]],

                                    [[1],
                                     [1],
                                     [0]]])

    dsc_logits_fake = torch.Tensor([[[0],
                                     [0],
                                     [0]],

                                    [[1],
                                     [1],
                                     [0]]])

    loss_real = minmaxgan_trainer.calc_dsc_real_loss(dsc_logits_real)
    loss_fake = minmaxgan_trainer.calc_dsc_fake_loss(dsc_logits_fake)
    loss = loss_real + loss_fake
    e = -0.5*(np.log(1 - torch.sigmoid(dsc_logits_real)).mean() +
              np.log(torch.sigmoid(dsc_logits_fake)).mean())
    assert np.isclose(e, loss.item())


def test_MinMaxGANTrainer_calc_gen_loss(minmaxgan_trainer):
    dsc_logits = torch.Tensor([[[0],
                                [0],
                                [0]],

                               [[1],
                                [1],
                                [0]]])

    loss = minmaxgan_trainer.calc_gen_loss(dsc_logits)
    e = -np.log(torch.sigmoid(dsc_logits)).mean()
    assert np.isclose(e, loss.item())


def test_MinMaxGANTrainer_calculate_metrics(minmaxgan_trainer):
    loss_gen = torch.Tensor([0])
    loss_dsc = torch.Tensor([1])
    metrics = minmaxgan_trainer.calculate_metrics(loss_gen, loss_dsc)
    assert metrics['generator_loss'] == 0
    assert metrics['discriminator_loss'] == 1
    assert np.isnan(metrics['generator_mean_abs_grad'])
    assert np.isnan(metrics['discriminator_mean_abs_grad'])


def test_BinaryClassificationTrainer_train(binaryclass_trainer):
    old_params = [p.detach().clone()
                  for p in binaryclass_trainer.model.parameters()]
    binaryclass_trainer.train(epochs=1)
    new_params = [p.detach().clone()
                  for p in binaryclass_trainer.model.parameters()]

    # assert parameters are being trained
    assert all([(old != new).all() for old, new in zip(old_params, new_params)])


def test_BinaryClassificationTrainer_calculate_metrics(binaryclass_trainer):
    logits = torch.Tensor([-9, -8, 8, 9, -1, 8])
    probs = torch.sigmoid(logits)
    y_true = torch.Tensor([ 0, 0, 0, 0, 1, 1])
    y_pred = torch.Tensor([ 0, 0, 1, 1, 0, 1])

    assert (y_pred == probs.round()).all()

    r = binaryclass_trainer.calculate_metrics(y_true, y_pred,
                                              logits, probs)

    loss = -np.log(probs[y_true==1]).sum() - np.log(1 - probs[y_true==0]).sum()
    loss /= probs.shape[0]
    assert np.isclose(r['_accuracy'], 3/6)
    assert np.isclose(r['_balanced_accuracy'], (2/4)*(2/6) + (1/2)*(4/6))  # 0.5
    assert np.isclose(r['_accuracy_0'], 2/4)
    assert np.isclose(r['_accuracy_1'], 1/2)
    assert np.isclose(r['_loss'], loss)
    assert np.isclose(r['_loss_0'], -np.log(1 - probs[y_true==0]).mean())
    assert np.isclose(r['_loss_1'], -np.log(probs[y_true==1]).mean())
    assert np.isclose(r['_matthews'], 0)


def test_BinaryClassificationTrainer_calculate_metrics_tiled(binaryclass_tiled_trainer):
    logits = torch.Tensor([[-9, -8, 8, 9, -1, 8],
                           [-9, -8, 8, 9, -1, 8]]).T
    probs = torch.sigmoid(logits)
    y_true = torch.Tensor([[ 0, 0, 0, 0, 1, 1],
                           [ 0, 0, 0, 0, 1, 1]]).T
    y_pred = torch.Tensor([[ 0, 0, 1, 1, 0, 1],
                           [ 0, 0, 1, 1, 0, 1]]).T

    assert (y_pred == probs.round()).all()

    r = binaryclass_tiled_trainer.calculate_metrics(y_true, y_pred,
                                                    logits, probs)

    loss = 0.5*(-np.log(probs[y_true==1]).sum() - np.log(1 - probs[y_true==0]).sum())
    loss /= probs.shape[0]
    assert np.isclose(r['_accuracy'], 3/6)
    assert np.isclose(r['_balanced_accuracy'], (2/4)*(2/6) + (1/2)*(4/6))  # 0.5
    assert np.isclose(r['_accuracy_0'], 2/4)
    assert np.isclose(r['_accuracy_1'], 1/2)
    assert np.isclose(r['_loss'], loss)
    assert np.isclose(r['_loss_0'], -np.log(1 - probs[y_true==0]).mean())
    assert np.isclose(r['_loss_1'], -np.log(probs[y_true==1]).mean())
    assert np.isclose(r['_matthews'], 0)
--------------------------------------------------------------------------------