├── assets
│   └── table3.png
├── ward2icu
│   ├── models
│   │   ├── __init__.py
│   │   ├── classifiers.py
│   │   ├── rcgan.py
│   │   ├── cnngan.py
│   │   └── rgan.py
│   ├── __init__.py
│   ├── utils.py
│   ├── samplers.py
│   ├── metrics.py
│   ├── layers.py
│   └── trainers.py
├── Makefile
├── tests
│   ├── test_utils.py
│   ├── test_samplers.py
│   ├── test_data.py
│   ├── test_models.py
│   └── test_trainers.py
├── .gitignore
├── README.md
└── run-experiment.py
/assets/table3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/3778/Ward2ICU/HEAD/assets/table3.png
--------------------------------------------------------------------------------
/ward2icu/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifiers import *
2 | from .cnngan import *
3 | from .rgan import *
4 | from .rcgan import *
5 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: README.md
2 |
3 | bin/gh-md-toc:
4 | mkdir -p bin
5 | wget https://raw.githubusercontent.com/ekalinin/github-markdown-toc/master/gh-md-toc
6 | chmod a+x gh-md-toc
7 | mv gh-md-toc bin/
8 |
9 | README.md: bin/gh-md-toc
10 | ./bin/gh-md-toc --insert README.md
11 | rm -f README.md.orig.* README.md.toc.*
12 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 | from ward2icu.utils import tile
3 |
4 |
5 | def test_tile():
6 | y = Tensor([0, 1, 2])
7 | y_tiled = tile(y, 3)
8 | expected = Tensor([[0, 0, 0],
9 | [1, 1, 1],
10 | [2, 2, 2]])
11 | assert (y_tiled == expected).all()
12 |
--------------------------------------------------------------------------------
/ward2icu/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | import torch
5 | from pathlib import Path
6 |
7 |
8 | def get_project_root() -> Path:
9 | """Returns project root folder."""
10 | return Path(__file__).parent.parent
11 |
12 |
13 | def get_data_dir() -> Path:
14 | return get_project_root() / 'data'
15 |
16 |
17 | def make_logger(file_: str = "NO_FILE") -> logging.Logger:
18 | log_level = getattr(logging, os.getenv("LOG_LEVEL", "INFO"))
19 | fmt = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
20 | logging.basicConfig(level=log_level, format=fmt)
21 | return logging.getLogger(file_.split("/")[-1])
22 |
23 |
24 | def set_seeds():
25 | np.random.seed(3778)
26 | torch.manual_seed(3778)
27 |
--------------------------------------------------------------------------------
/ward2icu/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from math import floor
4 | from ward2icu import make_logger
5 |
6 | logger = make_logger(__file__)
7 |
8 |
9 | def tile(t, length):
10 | ''' Creates an extra dimension on the tensor t and
11 | repeats it throughout.'''
12 | return t.view(-1, 1).repeat(1, length)
13 |
14 |
15 | def calc_conv_output_length(conv_layer,
16 | input_length):
17 | def _maybe_slice(x):
18 | return x[0] if isinstance(x, tuple) else x
19 |
20 | l = input_length
21 |     p = _maybe_slice(conv_layer.padding)
22 |     d = _maybe_slice(conv_layer.dilation)
23 |     k = _maybe_slice(conv_layer.kernel_size)
24 |     s = _maybe_slice(conv_layer.stride)
25 | return floor((l + 2*p - d*(k-1)-1)/s + 1)
26 |
27 |
28 | def flatten(l):
29 | return [item for sublist in l for item in sublist]
30 |
31 |
32 | def numpy_to_cuda(*args):
33 | return [torch.from_numpy(a).cuda() for a in args]
34 |
35 |
36 | def n_(tensor):
37 | return tensor.detach().cpu().numpy()
38 |
--------------------------------------------------------------------------------
/tests/test_samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ward2icu.samplers import BinaryBalancedSampler, IdentitySampler
3 |
4 |
5 | def test_BinaryBalancedSampler():
6 | X = torch.eye(5)
7 | y = torch.Tensor([0, 0, 0, 1, 1])
8 | sampler = BinaryBalancedSampler(X, y)
9 | for _ in range(100):
10 | X_s, y_s = sampler.sample()
11 |
12 | assert (y_s == torch.Tensor([0, 0, 1, 1])).all()
13 | assert X_s.shape == (4, 5)
14 | assert any((X_s[0] == X[i]).all() for i in [0, 1, 2, 3])
15 | assert (X_s[-1] == X[-1]).all()
16 |
17 | def test_BinaryBalancedSampler_batch():
18 | X = torch.eye(5)
19 | y = torch.Tensor([0, 0, 0, 1, 1])
20 | sampler = BinaryBalancedSampler(X, y, batch_size=2)
21 | for _ in range(100):
22 | X_s, y_s = sampler.sample()
23 |
24 | assert (y_s == torch.Tensor([0, 1])).all()
25 | assert X_s.shape == (2, 5)
26 | assert any((X_s[0] == X[i]).all() for i in [0, 1, 2, 3])
27 | assert any((X_s[-1] == X[i]).all() for i in [-1, -2])
28 |
29 | def test_IdentitySampler():
30 | X = torch.ones(4, 2, 6)
31 | y = torch.Tensor([0, 0, 1, 1])
32 | sampler = IdentitySampler(X, y, tile=True)
33 | X_s, y_s = sampler.sample()
34 |
35 | assert X_s.shape == (4, 2, 6)
36 | assert y_s.shape == (4, 2)
37 | assert (y_s == torch.Tensor([[0, 0],
38 | [0, 0],
39 | [1, 1],
40 | [1, 1]])).all()
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # others
107 | bin/
108 |
--------------------------------------------------------------------------------
/tests/test_data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | from ward2icu.data import TimeSeriesVitalSigns
4 |
5 |
6 | dataset = TimeSeriesVitalSigns()
7 |
8 | def test_TimeSeriesVitalSigns_inv_transforms():
9 | d = dataset
10 | X = d.X.astype(np.float128)
11 | y = d.y
12 | s = d.synthesis_df(X, y)
13 | assert np.isclose(d.inv_whiten(d.whiten(X)), X).all()
14 | assert np.isclose(d.inv_normalize(d.normalize(X)), X).all()
15 | assert np.isclose(d.inv_minmax(d.minmax(X)), X).all()
16 | assert np.isclose(d.inv_minmax_signals(d.minmax_signals(X)), X).all()
17 |
18 | columns = ['cat_vital_sign', 't', 'class']
19 | assert (s[columns] == d.df[columns]).all().all()
20 | assert np.isclose(s['value'], d.df['value']).all()
21 |
22 | def test_TimeSeriesVitalSigns_transform_minmax():
23 | X = np.array([
24 | # patient 1
25 | [[36, 90], [40, 100], [40, 100]],
26 | # patient 2
27 | [[40, 80], [36, 90], [36, 90]],
28 | # patient 3
29 | [[40, 80], [36, 90], [36, 90]]
30 | ])
31 |
32 | e = np.array([
33 | # patient 1
34 | [[-1, 1], [1, 1], [1, 1]],
35 | # patient 2
36 | [[1, -1], [-1, -1], [-1, -1]],
37 | # patient 3
38 | [[1, -1], [-1, -1], [-1, -1]]
39 | ])
40 |
41 | assert (dataset.minmax(X, X) == e).all()
42 | assert X.shape == (3, 3, 2)
43 |
44 | def test_TimeSeriesVitalSigns_transform_minmax_signals():
45 | X = np.array([
46 | # patient 1
47 | [[36, 90], [40, 100], [40, 100]],
48 | # patient 2
49 | [[40, 80], [36, 90], [36, 90]],
50 |         # patient 3
51 | [[40, 80], [36, 90], [36, 90]],
52 |         # patient 4
53 | [[40, 80], [36, 90], [36, 90]],
54 | ])
55 |
56 | # signal 1 = [36, 40, 40, 36] -> [-1, 1, 1, -1]
57 | # signal 2 = [90, 100, 80, 90] -> [ 0, 1, -1, 0]
58 |
59 | e = np.array([
60 | # patient 1
61 | [[-1, 0], [1, 1], [1, 1]],
62 | # patient 2
63 | [[1, -1], [-1, 0], [-1, 0]],
64 |         # patient 3
65 | [[1, -1], [-1, 0], [-1, 0]],
66 |         # patient 4
67 | [[1, -1], [-1, 0], [-1, 0]]
68 | ])
69 |
70 | assert (dataset.minmax_signals(X, X) == e).all()
71 | assert X.shape == (4, 3, 2)
72 | assert e.shape == (4, 3, 2)
73 |
--------------------------------------------------------------------------------
/ward2icu/samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ward2icu.utils import tile as tile_func
3 |
4 |
5 | class IdentitySampler:
6 | def __init__(self, X, y, tile=False):
7 | self.X = X
8 | self.y = tile_func(y, X.shape[1]) if tile else y
9 | self.tile = tile
10 | self.device = X.device
11 |
12 | def sample(self):
13 | return self.X, self.y
14 |
15 | def __len__(self):
16 | return self.y.shape[0]
17 |
18 |
19 | class BinaryBalancedSampler:
20 | def __init__(self, X, y, tile=False, batch_size=None):
21 | assert len(y.unique()) <= 2
22 | self.sequence_length = X.shape[1]
23 | self.batch_size = batch_size
24 | self.tile = tile
25 | self.device = X.device
26 |
27 | self.majority_class = y.mode().values
28 | self.minority_class = y[y != self.majority_class].mode().values
29 |
30 | self.majority_mask = (y == self.majority_class)
31 | self.minority_mask = (y == self.minority_class)
32 |
33 | self.majority_count = int(self.majority_mask.sum())
34 | self.minority_count = int(self.minority_mask.sum())
35 |
36 | if batch_size:
37 |             assert batch_size/2 <= self.minority_count
38 |
39 | self.X_majority = X[self.majority_mask]
40 | self.X_minority = X[self.minority_mask]
41 |
42 | self.y_majority = y[self.majority_mask]
43 | self.y_minority = y[self.minority_mask]
44 | self.should_sample_minority_class = self.batch_size is not None
45 |
46 | def sample(self):
47 | batch_size_half = (int(self.batch_size/2)
48 | if self.batch_size is not None
49 | else self.minority_count)
50 |
51 | idx_maj, idx_min = self._create_idxs(batch_size_half)
52 |
53 | X_majority_batch = self.X_majority[idx_maj]
54 | y_majority_batch = self.y_majority[idx_maj]
55 |
56 | X_minority_batch = (self.X_minority[idx_min]
57 | if self.should_sample_minority_class
58 | else self.X_minority)
59 | y_minority_batch = (self.y_minority[idx_min]
60 | if self.should_sample_minority_class
61 | else self.y_minority)
62 |
63 | X = torch.cat((X_majority_batch, X_minority_batch), dim=0)
64 | y = torch.cat((y_majority_batch, y_minority_batch), dim=0)
65 |
66 | y = tile_func(y, self.sequence_length) if self.tile else y
67 | return X, y
68 |
69 | def __len__(self):
70 | return int(self.batch_size or 2*self.minority_count)
71 |
72 | def _create_idxs(self, batch_size):
73 | idx_maj = torch.randperm(self.majority_count)[:batch_size]
74 | idx_min = (torch.randperm(self.minority_count)[:batch_size]
75 | if self.should_sample_minority_class else None)
76 | return idx_maj, idx_min
77 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Ward2ICU
5 |
6 |
7 | [arXiv:1910.00752](https://arxiv.org/abs/1910.00752)
8 | [3778 Research](https://research.3778.care/projects/privacy/)
9 | [Data request form](https://forms.gle/e2asYSVaiuPUUCKu8)
10 |
11 |
12 |
13 | <!--ts-->
14 | * [Description](#description)
15 | * [Models](#models)
16 | * [1D Conditional CNN GAN](#1d-conditional-cnn-gan)
17 | * [Recursive GAN (RGAN)](#recursive-gan-rgan)
18 | * [Recursive Conditional GAN (RCGAN)](#recursive-conditional-gan-rcgan)
19 | * [RNN Classifier](#rnn-classifier)
20 | * [1D-CNN Classifier](#1d-cnn-classifier)
21 | * [Citation](#citation)
22 | <!--te-->
23 |
24 |
25 |
26 |
27 | ## Description
28 | Ward2ICU: A Vital Signs Dataset of Inpatients from the General Ward
29 |
30 | ## Models
31 |
32 | ### 1D Conditional CNN GAN
33 | [Code](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/cnngan.py)
34 |
35 | 
36 |
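A minimal usage sketch of the conditional generator/discriminator pair (shapes follow `ward2icu/models/cnngan.py`; the generator defaults to two classes and is hard-wired to length-20 sequences):

```python
import torch
from ward2icu.models import CNNCGANGenerator, CNNCGANDiscriminator

gen = CNNCGANGenerator(output_size=5, noise_size=20, hidden_size=50)
dsc = CNNCGANDiscriminator(input_size=5, hidden_size=50)

z, y = gen.sampler(sample_size=8)  # noise (8, 20) and binary labels (8,)
x = gen(z, y)                      # synthetic sequences, shape (8, 20, 5)
logits = dsc(x, y)                 # real/fake logits, shape (8,)
```
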
37 | ### Recursive GAN (RGAN)
38 | [Code](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/rgan.py)
39 | [arXiv:1706.02633](https://arxiv.org/abs/1706.02633)
40 |
41 | Recursive GAN (Generator) implementation with RNN cells.
42 |
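A minimal forward pass, mirroring `tests/test_models.py`:

```python
import torch
from ward2icu.models import RGANGenerator, RGANDiscriminator

gen = RGANGenerator(sequence_length=10, output_size=5, noise_size=100)
dsc = RGANDiscriminator(sequence_length=10, input_size=5)

z, = gen.sampler(3, 'cpu')  # flat noise, shape (3, 10*100)
x = gen(z)                  # generated sequences, shape (3, 10, 5)
p = dsc(x)                  # per-time-step logits, shape (3, 10, 1)
```
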
43 | ### Recursive Conditional GAN (RCGAN)
44 | [Code](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/rcgan.py)
45 | [arXiv:1706.02633](https://arxiv.org/abs/1706.02633)
46 |
47 | Recursive Conditional GAN (Generator) implementation with RNN cells.
48 |
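The conditional variant takes a class label alongside the noise; a minimal forward pass, mirroring `tests/test_models.py`:

```python
import torch
from ward2icu.models import RCGANGenerator, RCGANDiscriminator

gen = RCGANGenerator(sequence_length=20, output_size=5,
                     num_classes=2, noise_size=10, hidden_size=100)
dsc = RCGANDiscriminator(sequence_length=20, num_classes=2,
                         input_size=5, hidden_size=100)

z, y = gen.sampler(3, 'cpu')  # noise (3, 20*10) and class labels (3,)
x = gen(z, y)                 # label-conditioned sequences, shape (3, 20, 5)
p = dsc(x, y)                 # per-time-step logits, shape (3, 20, 1)
```
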
49 | ### RNN Classifier
50 | [Code](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/classifiers.py)
51 |
52 | A simple RNN for classification tasks. It consists of a recurrent layer (Elman RNN, LSTM or GRU) followed by a fully connected layer that shares parameters across the time domain (i.e. the second tensor dimension), producing one logit per time step; the sigmoid is applied downstream (`BCEWithLogitsLoss` in training, `torch.sigmoid` at evaluation).
53 |
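A minimal sketch (shapes as in `tests/test_models.py`):

```python
import torch
from ward2icu.models import BinaryRNNClassifier

model = BinaryRNNClassifier(sequence_length=10, input_size=2, hidden_size=20)

x = torch.randn(8, 10, 2)      # (batch, sequence_length, input_size)
logits = model(x)              # (8, 10): one logit per time step
probs = torch.sigmoid(logits)  # sigmoid is applied outside the model
```
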
54 | ### 1D-CNN Classifier
55 | [Code](https://github.com/3778/Ward2ICU/blob/master/ward2icu/models/classifiers.py)
56 |
57 | Single-dimension convolutional network for classification. Consists of a sequence of `Conv1d` followed by `MaxPool1d` and `Linear` with a `Sigmoid` output.
58 |
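A minimal sketch, again following `tests/test_models.py` (inputs are channel-last by default):

```python
import torch
from ward2icu.models import BinaryCNNClassifier

model = BinaryCNNClassifier(input_size=2, input_length=10, kernel_size=3)

x = torch.randn(8, 10, 2)  # (batch, input_length, input_size)
p = model(x)               # (8,): one probability per example
```
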
59 | ## Citation
60 | ```
61 | @article{severo2019ward2icu,
62 | title={Ward2ICU: A Vital Signs Dataset of Inpatients from the General Ward},
63 | author={Severo, Daniel and Amaro, Fl{\'a}vio and Hruschka Jr, Estevam R and Costa, Andr{\'e} Soares de Moura},
64 | journal={arXiv preprint arXiv:1910.00752},
65 | year={2019}
66 | }
67 | ```
68 |
--------------------------------------------------------------------------------
/ward2icu/metrics.py:
--------------------------------------------------------------------------------
1 | from ward2icu.samplers import BinaryBalancedSampler, IdentitySampler
2 | from ward2icu.models import BinaryRNNClassifier
3 | from torch import optim
4 | from ward2icu.utils import (train_test_split_tensor,
5 | numpy_to_cuda, tile)
6 | from ward2icu.trainers import BinaryClassificationTrainer
7 | from ward2icu import make_logger
8 |
9 | logger = make_logger(__file__)
10 |
11 |
12 | def mean_feature_error(X_real, X_synth):
13 | return (1 - X_synth.mean(axis=0)/X_real.mean(axis=0)).mean(axis=0).mean()
14 |
15 |
16 | # TSTR: Train on Synthetic, Test on Real (arXiv:1706.02633).
17 | # TODO(dsevero) in theory, we should separate the training and testing datasets.
18 | def tstr(X_synth, y_synth, X_real, y_real, epochs=3_000, batch_size=None):
19 |
20 | logger.info('Running TSTR')
21 | logger.info(f'Synthetic size: {len(y_synth)}')
22 | logger.info(f'Real size: {len(y_real)}')
23 | logger.info(f'Class distribution: '
24 | f'[real {y_real.float().mean()}]'
25 | f'[synthetic {y_synth.mean()}]')
26 |
27 | sequence_length = X_real.shape[1]
28 | sequence_size = X_real.shape[2]
29 | X_train, y_train = X_synth, y_synth
30 | X_test, y_test = X_real, y_real
31 |
32 | X_train, y_train = numpy_to_cuda(X_train, y_train)
33 | y_test, y_train = y_test.float(), y_train.float()
34 |
35 | sampler_train = BinaryBalancedSampler(X_train, y_train,
36 | tile=True, batch_size=batch_size)
37 | sampler_test = IdentitySampler(X_test, y_test, tile=True)
38 |
39 | model = BinaryRNNClassifier(sequence_length=sequence_length,
40 | input_size=sequence_size,
41 | dropout=0.5,
42 | hidden_size=100).cuda()
43 | optimizer = optim.Adam(model.parameters(),
44 | lr=0.001)
45 | trainer = BinaryClassificationTrainer(model,
46 | optimizer,
47 | sampler_train,
48 | sampler_test,
49 | metrics_prepend='tstr_')
50 | trainer.train(epochs)
51 | return trainer
52 |
53 |
54 | def classify(X, y, epochs=3_000, batch_size=None):
55 | sequence_length = X.shape[1]
56 | sequence_size = X.shape[2]
57 |
58 | (X_train, y_train,
59 | X_test, y_test) = train_test_split_tensor(X, y, 0.3)
60 | (X_train, y_train,
61 | X_test, y_test) = numpy_to_cuda(X_train, y_train,
62 | X_test, y_test)
63 | y_test, y_train = y_test.float(), y_train.float()
64 |
65 | sampler_train = BinaryBalancedSampler(X_train, y_train,
66 | tile=True, batch_size=batch_size)
67 | sampler_test = IdentitySampler(X_test, y_test, tile=True)
68 |
69 | model = BinaryRNNClassifier(sequence_length=sequence_length,
70 | input_size=sequence_size,
71 | dropout=0.8,
72 | hidden_size=100).cuda()
73 | optimizer = optim.Adam(model.parameters(),
74 | lr=0.001)
75 | trainer = BinaryClassificationTrainer(model,
76 | optimizer,
77 | sampler_train,
78 | sampler_test)
79 | logger.info(model)
80 | logger.info(f'Test size: {len(y_test)}')
81 | logger.info(f'Train size: {len(y_train)}')
82 |
83 | trainer.train(epochs)
84 | return trainer
85 |
--------------------------------------------------------------------------------
/ward2icu/layers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from ward2icu.utils import calc_conv_output_length, tile
4 |
5 |
6 | def rnn_layer(input_size,
7 | hidden_size=None,
8 | num_layers=1,
9 | dropout=0.5,
10 | rnn_type='rnn',
11 | nonlinearity='relu'):
12 |
13 | # Set hidden_size to input_size if not specified
14 | hidden_size = hidden_size or input_size
15 |
16 | rnn_types = {
17 | 'rnn': nn.RNN,
18 | 'lstm': nn.LSTM,
19 | 'gru': nn.GRU
20 | }
21 |
22 | rnn_kwargs = dict(
23 | input_size=input_size,
24 | hidden_size=hidden_size,
25 | num_layers=num_layers,
26 | batch_first=True,
27 | dropout=dropout
28 | )
29 |
30 | if rnn_type == 'rnn':
31 | rnn_kwargs['nonlinearity'] = nonlinearity
32 |
33 | return rnn_types[rnn_type](**rnn_kwargs)
34 |
35 |
36 | class Conv1dLayers(nn.Module):
37 | def __init__(self,
38 | input_size,
39 | input_length,
40 | kernel_size,
41 | dropout_prob=0.5,
42 | n_layers=1,
43 | step_up=1,
44 | **kwargs):
45 | self.input_size = input_size
46 | self.input_length = input_length
47 | self.kernel_size = kernel_size
48 | self.dropout_prob = dropout_prob
49 | self.n_layers = n_layers
50 | self.step_up = step_up
51 | super(Conv1dLayers, self).__init__()
52 |
53 | self.output_sizes = list()
54 | self.output_lengths = list()
55 |
56 | in_channels = input_size
57 | out_channels = in_channels
58 | output_length = input_length
59 | layers = list()
60 | for depth in range(n_layers):
61 | layers += [nn.Conv1d(in_channels=in_channels,
62 | out_channels=out_channels,
63 | kernel_size=kernel_size,
64 | **kwargs),
65 | nn.Dropout(dropout_prob),
66 | nn.LeakyReLU(0.2)]
67 | output_length = calc_conv_output_length(layers[-3],
68 | output_length)
69 | self.output_sizes += [out_channels]
70 | self.output_lengths += [output_length]
71 |
72 | # next layer
73 | in_channels = out_channels
74 | out_channels = step_up*in_channels
75 |
76 | self.layers = nn.Sequential(*layers)
77 |
78 | def forward(self, x):
79 | return self.layers(x)
80 |
81 |
82 | class Permutation(nn.Module):
83 | def __init__(self, *dims):
84 | self.dims = dims
85 | super(Permutation, self).__init__()
86 |
87 | def forward(self, x):
88 | return x.permute(*self.dims)
89 |
90 |
91 | class View(nn.Module):
92 | def __init__(self, *dims):
93 | self.dims = dims
94 | super(View, self).__init__()
95 |
96 | def forward(self, x):
97 | return x.view(*self.dims)
98 |
99 |
100 | class AppendEmbedding(nn.Module):
101 | def __init__(self, embedding_layer, dim=-1, tile=True):
102 |         super(AppendEmbedding, self).__init__()
103 |         self.embedding_layer = embedding_layer
104 |         self.dim = dim
105 |         self.tile = tile
106 |         self.labels_pointer = None  # must be set externally before forward()
107 |
108 | def forward(self, x):
109 | y = self.labels_pointer
110 | if self.tile:
111 | y = tile(y, x.size(1))
112 | y_emb = self.embedding_layer(y.type(torch.LongTensor).to(y.device))
113 | return torch.cat((x, y_emb), dim=self.dim)
114 |
--------------------------------------------------------------------------------
/ward2icu/models/classifiers.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from ward2icu.layers import rnn_layer, Conv1dLayers
9 | from ward2icu.utils import calc_conv_output_length, flatten
10 |
11 |
12 | class BinaryRNNClassifier(nn.Module):
13 | def __init__(self, sequence_length, **kwargs):
14 |         """Recurrent NN for binary classification with a linear
15 |         output layer shared across time steps.
16 |
17 | Layers:
18 | RNN (with activation, multiple layers):
19 | input: (batch_size, sequence_length, input_size)
20 | output: (batch_size, sequence_length, hidden_size)
21 |
22 | Linear (no activation):
23 | input: (batch_size, sequence_length, hidden_size)
24 | output: (batch_size, sequence_length, 1)
25 |
26 | Notes:
27 | This model adds upon ward2icu.layers.rnn_layer. See docs
28 | for more information.
29 |
30 | Args:
31 | sequence_length (int): Number of points in the sequence.
32 | kwargs: Keyword arguments passed on to ward2icu.layers.rnn_layer
33 | """
34 | super(BinaryRNNClassifier, self).__init__()
35 |
36 | # Default hidden_size to input_size.
37 |         hidden_size = kwargs.get('hidden_size') or kwargs['input_size']
38 |
39 | # See ward2icu.layers.rnn_layer for more information.
40 | self.rnn = rnn_layer(**kwargs)
41 |
42 | # Parameters are shared across time-steps.
43 | self.linear = nn.Linear(hidden_size, 1)
44 |
45 | def forward(self, x):
46 | x, _ = self.rnn(x)
47 | return self.linear(x).squeeze()
48 |
49 |
50 | class BinaryCNNClassifier(nn.Module):
51 | def __init__(self,
52 | input_size,
53 | input_length,
54 | kernel_size,
55 | n_layers=1,
56 | step_up=1,
57 | dropout_prob=0.5,
58 | channel_last=True,
59 | pool_size=None,
60 | **kwargs):
61 | """ 1D CNN for binary classification.
62 | """
63 | # defaults
64 | pool_size = pool_size or kernel_size
65 |
66 | self.input_length = input_length
67 | self.input_size = input_size
68 | self.kernel_size = kernel_size
69 | self.channel_last = channel_last
70 | self.step_up = step_up
71 | self.pool_size = pool_size
72 | self.n_layers = n_layers
73 | super(BinaryCNNClassifier, self).__init__()
74 |
75 | conv = Conv1dLayers(input_size,
76 | input_length,
77 | kernel_size,
78 | n_layers=n_layers,
79 | step_up=step_up,
80 | dropout_prob=dropout_prob,
81 | **kwargs)
82 |
83 | maxpool = nn.MaxPool1d(pool_size)
84 | flatten = nn.Flatten()
85 | flatten_output_size = (
86 | conv.output_sizes[-1]*
87 | calc_conv_output_length(maxpool,
88 | conv.output_lengths[-1])
89 | )
90 |
91 | self.layers = nn.Sequential(conv,
92 | maxpool,
93 | flatten,
94 | nn.Linear(flatten_output_size, flatten_output_size),
95 | nn.LeakyReLU(0.2),
96 | nn.Linear(flatten_output_size, 1),
97 | nn.Sigmoid())
98 |
99 | def forward(self, x):
100 | if self.channel_last:
101 | x = x.permute(0, 2, 1)
102 | return self.layers(x).squeeze()
103 |
--------------------------------------------------------------------------------
/ward2icu/trainers.py:
--------------------------------------------------------------------------------
1 | """
2 | References:
3 | - https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py
4 | """
5 | import mlflow
6 | import torch
7 | import numpy as np
8 | from torch.nn import BCEWithLogitsLoss
9 | from ward2icu import make_logger
10 | from sklearn.metrics import balanced_accuracy_score, matthews_corrcoef
11 |
12 | logger = make_logger(__file__)
13 |
14 |
15 | class BinaryClassificationTrainer(object):
16 | def __init__(self,
17 | model,
18 | optimizer,
19 | sampler_train=None,
20 | sampler_test=None,
21 | log_to_mlflow=True,
22 | loss_function=BCEWithLogitsLoss(),
23 | metrics_prepend=''):
24 | self.optimizer = optimizer
25 | self.sampler_train = sampler_train
26 | self.sampler_test = sampler_test
27 | self.loss_function = loss_function
28 | self.model = model
29 | self.log_to_mlflow = log_to_mlflow
30 | self.tiled = sampler_train.tile
31 | self.metrics_prepend = metrics_prepend
32 |
33 | def train(self, epochs, evaluate_interval=100):
34 | for epoch in range(epochs):
35 | self.optimizer.zero_grad()
36 | X_train, y_train = self.sampler_train.sample()
37 |
38 | logits_train = self.model(X_train)
39 | loss = self.loss_function(logits_train, y_train)
40 | loss.backward()
41 | self.optimizer.step()
42 |
43 | if epoch % evaluate_interval == 0:
44 | with torch.no_grad():
45 | X_test, y_test = self.sampler_test.sample()
46 | logger.debug(f'[Train sizes {X_train.shape} {y_train.shape}]')
47 | logger.debug(f'[Test sizes {X_test.shape} {y_test.shape}]')
48 | metrics = self.evaluate(X_test, y_test,
49 | X_train, y_train)
50 | msg = f'[epoch {epoch}]'
51 | msg += ''.join(f'[{m} {np.round(v,4)}]'
52 | for m, v in metrics.items()
53 | if m.endswith('balanced_accuracy') or
54 |                                m.endswith('matthews'))
55 | logger.info(msg)
56 | if self.log_to_mlflow:
57 | mlflow.log_metrics(metrics, step=epoch)
58 |
59 | def evaluate(self, X_test, y_test, X_train, y_train):
60 | def _calculate(X, y, name):
61 | logits = self.model(X)
62 | probs = torch.sigmoid(logits)
63 | y_pred = probs.round()
64 | y_true = y
65 | return self.calculate_metrics(y_true, y_pred, logits, probs, name)
66 |
67 | mp = self.metrics_prepend
68 | return {**_calculate(X_test, y_test, f'{mp}test'),
69 | **_calculate(X_train, y_train, f'{mp}train')}
70 |
71 | def calculate_metrics(self, y_true, y_pred, logits, probs, name=''):
72 | y_true_ = (y_true[:, 0] if self.tiled else y_true).cpu()
73 | y_pred_ = (y_pred.mode().values if self.tiled else y_pred).cpu()
74 |
75 | mask_0 = (y_true_ == 0)
76 | mask_1 = (y_true_ == 1)
77 |
78 | hits = (y_true_ == y_pred_).float()
79 | bas = balanced_accuracy_score(y_true_, y_pred_)
80 | matthews = matthews_corrcoef(y_true_, y_pred_)
81 |
82 | return {f'{name}_accuracy': hits.mean().item(),
83 | f'{name}_balanced_accuracy': bas,
84 | f'{name}_accuracy_0': hits[mask_0].mean().item(),
85 | f'{name}_accuracy_1': hits[mask_1].mean().item(),
86 | f'{name}_loss': self.loss_function(logits, y_true).item(),
87 | f'{name}_loss_0': self.loss_function(logits[mask_0],
88 | y_true[mask_0]).item(),
89 | f'{name}_loss_1': self.loss_function(logits[mask_1],
90 | y_true[mask_1]).item(),
91 | f'{name}_matthews': matthews}
92 |
--------------------------------------------------------------------------------
/run-experiment.py:
--------------------------------------------------------------------------------
1 | import mlflow
2 | import tempfile
3 | import click
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from ward2icu.data import TimeSeriesVitalSigns
8 | from ward2icu.logs import log_avg_loss, log_avg_grad, log_model, log_df
9 | from ward2icu.models import CNNCGANGenerator, CNNCGANDiscriminator
10 | from ward2icu.utils import synthesis_df, train_test_split_tensor, numpy_to_cuda, tile
11 | from ward2icu.metrics import mean_feature_error, classify, tstr
12 | from ward2icu.samplers import BinaryBalancedSampler, IdentitySampler
13 | from ward2icu.trainers import BinaryClassificationTrainer, MinMaxBinaryCGANTrainer, SequenceTrainer
14 | from torch import optim
15 | from torch.utils.data import DataLoader
16 | from torchgan import losses
17 | from torchgan.metrics import ClassifierScore
18 | from torchgan.trainer import Trainer
19 | from slugify import slugify
21 | from ward2icu import make_logger
22 |
23 | logger = make_logger(__file__)
24 |
25 | @click.command()
26 | @click.option("--lr", type=float)
27 | @click.option("--epochs", type=int)
28 | @click.option("--ncritic", type=int)
29 | @click.option("--batch_size", type=int)
30 | @click.option("--dataset_transform", type=str)
31 | @click.option("--signals", type=int)
32 | @click.option("--gen_dropout", type=float)
33 | @click.option("--noise_size", type=int)
34 | @click.option("--hidden_size", type=int)
35 | @click.option("--flag", type=str)
36 | def cli(**opt):
37 | main(**opt)
38 |
39 | def main(**opt):
40 | logger.info(opt)
41 | batch_size = opt['batch_size'] if opt['batch_size'] != -1 else None
42 |
43 | dataset = TimeSeriesVitalSigns(transform=opt['dataset_transform'],
44 | vital_signs=opt['signals'])
45 | X = torch.from_numpy(dataset.X).cuda()
46 | y = torch.from_numpy(dataset.y).long().cuda()
47 | sampler = BinaryBalancedSampler(X, y, batch_size=batch_size)
48 |
49 | network = {
50 | 'generator': {
51 | 'name': CNNCGANGenerator,
52 | 'args': {
53 | 'output_size': opt['signals'],
54 | 'dropout': opt['gen_dropout'],
55 | 'noise_size': opt['noise_size'],
56 | 'hidden_size': opt['hidden_size']
57 | },
58 | 'optimizer': {
59 | 'name': optim.RMSprop,
60 | 'args': {
61 | 'lr': opt['lr']
62 | }
63 | }
64 | },
65 | 'discriminator': {
66 | 'name': CNNCGANDiscriminator,
67 | 'args': {
68 | 'input_size': opt['signals'],
69 | 'hidden_size': opt['hidden_size']
70 | },
71 | 'optimizer': {
72 | 'name': optim.RMSprop,
73 | 'args': {
74 | 'lr': opt['lr']
75 | }
76 | }
77 | }
78 | }
79 |
80 | wasserstein_losses = [losses.WassersteinGeneratorLoss(),
81 | losses.WassersteinDiscriminatorLoss(),
82 | losses.WassersteinGradientPenalty()]
83 |
84 | logger.info(network)
85 |
86 | trainer = SequenceTrainer(models=network,
87 | recon=None,
88 | ncritic=opt['ncritic'],
89 | losses_list=wasserstein_losses,
90 | epochs=opt['epochs'],
91 | retain_checkpoints=1,
92 | checkpoints=f"{MODEL_DIR}/",
93 | mlflow_interval=50,
94 | device=DEVICE)
95 |
96 | trainer(sampler=sampler)
97 | trainer.log_to_mlflow()
98 | logger.info(trainer.generator)
99 | logger.info(trainer.discriminator)
100 |
101 | df_synth, X_synth, y_synth = synthesis_df(trainer.generator, dataset)
102 |
103 | logger.info(df_synth.sample(10))
104 | logger.info(df_synth.groupby('cat_vital_sign')['value'].nunique()
105 | .div(df_synth.groupby('cat_vital_sign').size()))
106 | X_real = X.detach().cpu().numpy()
107 | mfe = np.abs(mean_feature_error(X_real, X_synth))
108 | logger.info(f'Mean feature error: {mfe}')
109 |
110 | mlflow.set_tag('flag', opt['flag'])
111 | log_df(df_synth, 'synthetic/vital_signs')
112 | mlflow.log_metric('mean_feature_error', mfe)
113 |
114 | trainer_class = classify(X_synth, y_synth, epochs=2_000, batch_size=batch_size)
115 | trainer_tstr = tstr(X_synth, y_synth, X, y, epochs=3_000, batch_size=batch_size)
116 | log_model(trainer_class.model, 'models/classifier')
117 | log_model(trainer_tstr.model, 'models/tstr')
118 |
119 | if __name__ == '__main__':
120 | with mlflow.start_run():
121 | with tempfile.TemporaryDirectory() as MODEL_DIR:
122 | if torch.cuda.is_available():
123 | DEVICE = torch.device("cuda")
124 | torch.backends.cudnn.deterministic = True
125 | else:
126 | DEVICE = torch.device("cpu")
127 | logger.info(f'Running on device {DEVICE}')
128 | cli()
129 |
--------------------------------------------------------------------------------
/ward2icu/models/rcgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ward2icu.models import RGANGenerator, RGANDiscriminator
4 | from ward2icu.utils import tile
5 |
6 |
7 | class RCGANGenerator(RGANGenerator):
8 | def __init__(self,
9 | sequence_length,
10 | output_size,
11 | num_classes,
12 | noise_size,
13 | prob_classes=None,
14 | label_embedding_size=None,
15 | **kwargs):
16 | """Recursive Conditional GAN (Generator) implementation with RNN cells.
17 |
18 | Notes:
19 | This model adds upon ward2icu.models.rgan. See docs of parent classes
20 | for more information.
21 |
22 | Args (check parent class for extra arguments):
23 | sequence_length (int): Number of points in the sequence.
24 | Defined by the real dataset.
25 | output_size (int): Size of output (usually the last tensor dimension).
26 | Defined by the real dataset.
27 | num_classes (int): Number of classes in the dataset.
28 | noise_size (int): Size of noise used to generate fake data.
29 | label_embedding_size (int, optional): Size of embedding dimensions.
30 | Defaults to num_classes.
31 | """
32 | # Defaults
33 | label_embedding_size = label_embedding_size or num_classes
34 | if prob_classes is None:
35 | prob_classes = torch.ones(num_classes)
36 |
37 | #TODO(dsevero): this could cause problems
38 | kwargs['input_size'] = label_embedding_size + noise_size
39 | kwargs['noise_size'] = noise_size
40 | kwargs['label_type'] = "generated"
41 | kwargs['sequence_length'] = sequence_length
42 | kwargs['output_size'] = output_size
43 |
44 | super(RCGANGenerator, self).__init__(**kwargs)
45 | self.num_classes = num_classes
46 | self.label_embeddings = nn.Embedding(num_classes, label_embedding_size)
47 | self.label_embedding_size = label_embedding_size
48 | self.prob_classes = torch.Tensor(prob_classes)
49 |
50 | # Initialize all weights.
51 | # Already initialized in parent class
52 |
53 |     def forward(self, z, y):
54 |         # y must be tiled so that labels repeat across the sequence dimension:
55 |         # y_tiled[:, i] == y_tiled[:, j] for all i and j
56 |         # shape: (batch_size, sequence_length)
57 |         y_tiled = tile(y, self.sequence_length)
58 |
59 |         # shape: (batch_size, sequence_length, label_embedding_size)
60 |         y_emb = self.label_embeddings(y_tiled.type(torch.LongTensor).to(y.device))
61 |
62 |         # shape: (batch_size, sequence_length, noise_size)
63 |         z = z.view(-1, self.sequence_length, self.noise_size)
64 |
65 |         # shape: (batch_size, sequence_length, noise_size + label_embedding_size)
66 |         z_cond = torch.cat((z, y_emb), dim=2)
67 |
68 |         # shape: (batch_size, sequence_length, output_size)
69 |         return super(RCGANGenerator, self).forward(z_cond, reshape=False)
70 |
71 | def sampler(self, sample_size, device='cpu'):
72 | return [
73 | torch.randn(sample_size, self.encoding_dims, device=device),
74 | torch.multinomial(self.prob_classes, sample_size,
75 | replacement=True).to(device)
76 | ]
77 |
78 |
79 | class RCGANDiscriminator(RGANDiscriminator):
80 | def __init__(self,
81 | sequence_length,
82 | num_classes,
83 | input_size,
84 | label_embedding_size=None,
85 | **kwargs):
86 | """Recursive Conditional GAN (Discriminator) implementation with RNN cells.
87 |
88 | Notes:
89 | This model adds upon ward2icu.models.rgan. See docs of parent class.
90 | for more information.
91 |
92 | Args (check parent class for extra arguments):
93 | sequence_length (int): Number of points in the sequence.
94 | num_classes (int): Number of classes in the dataset.
95 | input_size (int): Size of input (usually the last tensor dimension).
96 | label_embedding_size (int, optional): Size of embedding dimensions.
97 | Defaults to num_classes.
98 | """
99 |
100 | # Defaults
101 | label_embedding_size = label_embedding_size or num_classes
102 | kwargs['input_size'] = label_embedding_size + input_size
103 | kwargs['label_type'] = "required"
104 | kwargs['sequence_length'] = sequence_length
105 |
106 | super(RCGANDiscriminator, self).__init__(**kwargs)
107 | self.num_classes = num_classes
108 | self.label_embeddings = nn.Embedding(num_classes, label_embedding_size)
109 | self.label_embedding_size = label_embedding_size
110 |
111 | # Initialize all weights.
112 | self._weight_initializer()
113 |
114 |     def forward(self, x, y):
115 |         # y must be tiled so that labels repeat across the sequence dimension:
116 |         # y_tiled[:, i] == y_tiled[:, j] for all i and j
117 |         # shape: (batch_size, sequence_length)
118 |         y_tiled = tile(y, self.sequence_length)
119 |
120 |         # shape: (batch_size, sequence_length, label_embedding_size)
121 |         y_emb = self.label_embeddings(y_tiled.type(torch.LongTensor).to(y.device))
122 |
123 |         # shape: (batch_size, sequence_length, input_size + label_embedding_size)
124 |         x_cond = torch.cat((x, y_emb), dim=2)
125 |         return super(RCGANDiscriminator, self).forward(x_cond)
126 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pytest
4 | from torchgan.losses import MinimaxDiscriminatorLoss
5 | from ward2icu.models import (RGANGenerator,
6 | RGANDiscriminator,
7 | RCGANGenerator,
8 | RCGANDiscriminator,
9 | BinaryRNNClassifier,
10 | BinaryCNNClassifier,
11 | CNNCGANGenerator,
12 | CNNCGANDiscriminator,
13 | FCCMLPGANGenerator,
14 | FCCMLPGANDiscriminator)
15 |
16 | np.random.seed(3778)
17 | torch.manual_seed(3778)
18 | torch.backends.cudnn.deterministic = True
19 | torch.backends.cudnn.benchmark = False
20 | DEVICE='cpu'
21 |
22 | @pytest.fixture
23 | def RGAN_setup():
24 | batch_size = 3
25 | seq_length = 10
26 |
27 | # Generator
28 | gen_opt = dict(output_size=5,
29 | num_layers=1,
30 | dropout=0,
31 | noise_size=100,
32 | sequence_length=seq_length)
33 |
34 | # Discriminator
35 | dsc_opt = dict(input_size=gen_opt['output_size'],
36 | sequence_length=seq_length,
37 | hidden_size=None,
38 | num_layers=1,
39 | dropout=0)
40 |
41 | gen = RGANGenerator(**gen_opt)
42 | dsc = RGANDiscriminator(**dsc_opt)
43 | return gen, dsc, batch_size, seq_length
44 |
45 |
46 | def test_RGAN_forward(RGAN_setup):
47 | gen, dsc, batch_size, seq_length = RGAN_setup
48 | z, = gen.sampler(batch_size, DEVICE)
49 | assert tuple(z.shape) == (batch_size, seq_length*gen.noise_size)
50 |
51 | x = gen.forward(z)
52 | assert tuple(x.shape) == (batch_size, seq_length, gen.output_size)
53 |
54 | y = dsc.forward(x)
55 | assert tuple(y.shape) == (batch_size, seq_length, 1)
56 |
57 |
58 |
59 | def test_RCGAN():
60 | batch_size = 3
61 | seq_length = 20
62 | label_emb_size = None
63 | prob_classes = [0, 0, 0, 0, 1]
64 | num_classes = len(prob_classes)
65 |
66 | # Generator
67 | gen_opt = dict(num_classes=num_classes,
68 | prob_classes=prob_classes,
69 | label_embedding_size=label_emb_size,
70 | output_size=5,
71 | num_layers=1,
72 | dropout=0,
73 | noise_size=10,
74 | hidden_size=100,
75 | sequence_length=seq_length)
76 |
77 | # Discriminator
78 | dsc_opt = dict(input_size=gen_opt['output_size'],
79 | sequence_length=seq_length,
80 | num_classes=num_classes,
81 | label_embedding_size=label_emb_size,
82 | num_layers=1,
83 | hidden_size=100,
84 | dropout=0)
85 |
86 | gen = RCGANGenerator(**gen_opt)
87 | dsc = RCGANDiscriminator(**dsc_opt)
88 |
89 | z, y = gen.sampler(batch_size, DEVICE)
90 | assert tuple(z.shape) == (batch_size, seq_length*gen.noise_size)
91 | assert tuple(y.shape) == (batch_size,)
92 |
93 | x = gen.forward(z, y)
94 | assert tuple(x.shape) == (batch_size, seq_length, gen.output_size)
95 |
96 | p = dsc.forward(x, y)
97 | assert tuple(p.shape) == (batch_size, seq_length, 1)
98 |
99 | # assert prob_classes
100 | assert (y == 4).all()
101 |
102 |
103 | def test_BinaryRNNClassifier():
104 | seq_length = 10
105 | batch_size = 8
106 | hidden_size = 20
107 | input_size = 2
108 |
109 | model = BinaryRNNClassifier(sequence_length=seq_length,
110 | input_size=input_size,
111 | hidden_size=hidden_size)
112 |
113 | x = torch.randn(batch_size, seq_length, input_size)
114 | p = model.forward(x)
115 | assert tuple(p.shape) == (batch_size, seq_length)
116 |
117 |
118 | def test_BinaryCNNClassifier():
119 | input_size = 2
120 | input_length = 10
121 | batch_size = 8
122 |
123 | model = BinaryCNNClassifier(input_size,
124 | input_length,
125 | kernel_size=3)
126 |
127 | x = torch.randn(batch_size, input_length, input_size)
128 | assert tuple(x.shape) == (batch_size, input_length, input_size)
129 |
130 | p = model.forward(x)
131 | assert tuple(p.shape) == (batch_size,)
132 |
133 |
134 | def test_FCCMLPGAN():
135 | sequence_length = 20
136 | sequence_size = 5
137 | hidden_size = 10
138 | prob_classes = [0, 0, 0, 0, 1]
139 | num_classes = len(prob_classes)
140 | batch_size = 8
141 |
142 | # Generator
143 | gen_opt = dict(sequence_length=sequence_length,
144 | sequence_size=sequence_size,
145 | num_classes=num_classes,
146 | hidden_size=hidden_size,
147 | prob_classes=prob_classes)
148 |
149 | # Discriminator
150 | dsc_opt = dict(sequence_length=sequence_length,
151 | num_classes=num_classes,
152 | sequence_size=gen_opt['sequence_size'],
153 | hidden_size=hidden_size)
154 |
155 | gen = FCCMLPGANGenerator(**gen_opt)
156 | dsc = FCCMLPGANDiscriminator(**dsc_opt)
157 |
158 | z, y = gen.sampler(batch_size, DEVICE)
159 | assert tuple(z.shape) == (batch_size, sequence_length*gen.noise_size)
160 | assert tuple(y.shape) == (batch_size,)
161 |
162 | x = gen.forward(z, y)
163 | assert tuple(x.shape) == (batch_size, sequence_length, sequence_size)
164 |
165 | p = dsc.forward(x, y)
166 | assert tuple(p.shape) == (batch_size, sequence_length)
167 |
168 | # assert prob_classes
169 | assert (y == 4).all()
170 |
--------------------------------------------------------------------------------
/ward2icu/models/cnngan.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://arxiv.org/abs/1806.01875
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import (Linear,
8 | Conv1d,
9 | MaxPool1d,
10 | AvgPool1d,
11 | Upsample,
12 | ReplicationPad1d,
13 | LeakyReLU,
14 | Flatten,
15 | Dropout)
16 | from torchgan.models import Generator, Discriminator
17 | from ward2icu.utils import tile
18 | from ward2icu.layers import Conv1dLayers, View, AppendEmbedding
19 | from ward2icu import set_seeds
20 |
21 |
22 | class CNNCGANGenerator(Generator):
23 | def __init__(self,
24 | output_size,
25 | dropout=0,
26 | noise_size=20,
27 | channel_last=True,
28 | hidden_size=50,
29 | **kwargs):
30 | # defaults
31 | self.initial_length = 5
32 | self.output_size = output_size
33 | self.hidden_size = hidden_size
34 | self.encoding_dims = noise_size
35 | self.num_classes = 2
36 | self.label_embedding_size = self.num_classes
37 | self.prob_classes = torch.ones(self.num_classes)
38 | self.dropout = dropout
39 | self.label_type = 'generated'
40 | self.channel_last = channel_last
41 | self._labels = None
42 |
43 |         # Set kwargs (may override the attributes above)
44 | for key, value in kwargs.items():
45 | setattr(self, key, value)
46 |
47 | super(CNNCGANGenerator, self).__init__(self.encoding_dims,
48 | self.label_type)
49 | self.label_embeddings = nn.Embedding(self.num_classes,
50 | self.label_embedding_size)
51 |
52 | Conv1d_ = lambda k: Conv1d(self.hidden_size, self.hidden_size, k)
53 | ## Build CNN layer
54 | # (batch_size, channels, sequence length)
55 |
56 | layers_input_size = self.encoding_dims*(1 + self.label_embedding_size)
57 | self.layers = nn.Sequential(
58 | Linear(layers_input_size, self.initial_length*self.hidden_size),
59 | View(-1, self.hidden_size, self.initial_length),
60 | LeakyReLU(0.2),
61 | Dropout(dropout),
62 | # output size: (-1, 50, 5)
63 |
64 | Upsample(scale_factor=2, mode='linear', align_corners=True),
65 | # output size: (-1, 50, 10)
66 |
67 | Conv1d_(3),
68 | ReplicationPad1d(1),
69 | LeakyReLU(0.2),
70 | Conv1d_(3),
71 | ReplicationPad1d(1),
72 | LeakyReLU(0.2),
73 | Dropout(dropout),
74 | # output size: (-1, 50, 10)
75 |
76 | Upsample(scale_factor=2, mode='linear', align_corners=True),
77 | # output size: (-1, 50, 20)
78 |
79 | Conv1d_(3),
80 | ReplicationPad1d(1),
81 | LeakyReLU(0.2),
82 | Conv1d_(3),
83 | ReplicationPad1d(1),
84 | LeakyReLU(0.2),
85 | Dropout(dropout),
86 | # output size: (-1, 50, 20)
87 |
88 | Conv1d(self.hidden_size, self.output_size, 1)
89 |             # output size: (-1, output_size, 20)
90 | )
91 |
92 | # Initialize all weights.
93 | self._weight_initializer()
94 |
95 | def forward(self, z, y):
96 | y_tiled = tile(y, z.size(-1))
97 | y_emb = self.label_embeddings(y_tiled
98 | .type(torch.LongTensor)
99 | .to(y.device))
100 | z = torch.cat((z.unsqueeze(-1), y_emb), dim=-1)
101 | z = z.flatten(1)
102 | x = self.layers(z)
103 | return x.permute(0, 2, 1) if self.channel_last else x
104 |
105 | def sampler(self, sample_size, device='cpu'):
106 | return [
107 | torch.randn(sample_size, self.encoding_dims, device=device),
108 | torch.multinomial(self.prob_classes, sample_size,
109 | replacement=True).to(device)
110 | ]
111 |
112 |
113 | class CNNCGANDiscriminator(Discriminator):
114 | def __init__(self,
115 | input_size,
116 | channel_last=True,
117 | hidden_size=50,
118 | **kwargs):
119 |
120 | # Defaults
121 | self.input_length = 20
122 | self.input_size = input_size
123 | self.hidden_size = hidden_size
124 | self.num_classes = 2
125 | self.label_embedding_size = self.num_classes
126 | self.prob_classes = torch.ones(self.num_classes)
127 | self.dropout = 0.5
128 | self.label_type = 'required'
129 | self.channel_last = channel_last
130 |
131 |
132 |         # Set kwargs (may override the attributes above)
133 | for key, value in kwargs.items():
134 | setattr(self, key, value)
135 |
136 | super(CNNCGANDiscriminator, self).__init__(self.input_size,
137 | self.label_type)
138 |
139 | # Build CNN layer
140 | self.label_embeddings = nn.Embedding(self.num_classes,
141 | self.label_embedding_size)
142 |
143 | Conv1d_ = lambda k: Conv1d(hidden_size, hidden_size, k)
144 | layers_input_size = self.input_size + self.label_embedding_size
145 | layers_output_size = 5*hidden_size
146 | ## Build CNN layer
147 | self.layers = nn.Sequential(
148 | Conv1d(layers_input_size, hidden_size, 1),
149 | LeakyReLU(0.2),
150 | # output size: (-1, 50, 20)
151 |
152 | Conv1d_(3),
153 | ReplicationPad1d(1),
154 | LeakyReLU(0.2),
155 | Conv1d_(3),
156 | ReplicationPad1d(1),
157 | LeakyReLU(0.2),
158 | AvgPool1d(2, 2),
159 | # output size: (-1, 50, 10)
160 |
161 | Conv1d_(3),
162 | ReplicationPad1d(1),
163 | LeakyReLU(0.2),
164 | Conv1d_(3),
165 | ReplicationPad1d(1),
166 | LeakyReLU(0.2),
167 | AvgPool1d(2, 2),
168 | # output size: (-1, 50, 5)
169 |
170 | Flatten(),
171 | Linear(layers_output_size, 1)
172 | # output size: (-1, 1)
173 | )
174 |
175 | # Initialize all weights.
176 | self._weight_initializer()
177 |
178 | def forward(self, x, y):
179 | if self.channel_last:
180 | x = x.permute(0, 2, 1)
181 | y_tiled = tile(y, x.size(-1))
182 | y_emb = self.label_embeddings(y_tiled
183 | .type(torch.LongTensor)
184 | .to(y.device))
185 | y_emb = y_emb.permute(0, 2, 1)
186 |
187 | x = torch.cat((x, y_emb), dim=1)
188 | x = self.layers(x)
189 | return x.squeeze()
190 |
--------------------------------------------------------------------------------
/ward2icu/models/rgan.py:
--------------------------------------------------------------------------------
1 | '''
2 | Reference: https://arxiv.org/abs/1706.02633
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchgan.models import Generator, Discriminator
8 | from ward2icu.layers import rnn_layer
9 |
10 |
11 | class RGANGenerator(Generator):
12 | def __init__(self,
13 | sequence_length,
14 | output_size,
15 | hidden_size=None,
16 | noise_size=None,
17 | num_layers=1,
18 | dropout=0,
19 | rnn_nonlinearity='tanh',
20 | rnn_type='rnn',
21 | input_size=None,
22 | last_layer=None,
23 | **kwargs):
24 | """Recursive GAN (Generator) implementation with RNN cells.
25 |
26 | Layers:
27 | RNN (with activation, multiple layers):
28 | input: (batch_size, sequence_length, noise_size)
29 | output: (batch_size, sequence_length, hidden_size)
30 |
31 | Linear (no activation, weights shared between time steps):
32 | input: (batch_size, sequence_length, hidden_size)
33 | output: (batch_size, sequence_length, output_size)
34 |
35 | last_layer (optional)
36 | input: (batch_size, sequence_length, output_size)
37 |
38 | Args:
39 | sequence_length (int): Number of points in the sequence.
40 | Defined by the real dataset.
41 | output_size (int): Size of output (usually the last tensor dimension).
42 | Defined by the real dataset.
43 | hidden_size (int, optional): Size of RNN output.
44 | Defaults to output_size.
45 | noise_size (int, optional): Size of noise used to generate fake data.
46 | Defaults to output_size.
47 | num_layers (int, optional): Number of stacked RNNs in rnn.
48 | dropout (float, optional): Dropout probability for rnn layers.
49 | rnn_nonlinearity (str, optional): Non-linearity of the RNN. Must be
50 | either 'tanh' or 'relu'. Only valid if rnn_type == 'rnn'.
51 | rnn_type (str, optional): Type of RNN layer. Valid values are 'lstm',
52 | 'gru' and 'rnn', the latter being the default.
53 | input_size (int, optional): Input size of RNN, defaults to noise_size.
54 |             last_layer (Module, optional): Last layer of the generator.
55 | """
56 |
57 |
58 | # Defaults
59 | noise_size = noise_size or output_size
60 | input_size = input_size or noise_size
61 | hidden_size = hidden_size or output_size
62 |
63 | self.sequence_length = sequence_length
64 | self.hidden_size = hidden_size
65 | self.output_size = output_size
66 | self.noise_size = noise_size
67 | self.num_layers = num_layers
68 | self.dropout = dropout
69 | self.rnn_nonlinearity = rnn_nonlinearity
70 | self.rnn_type = rnn_type
71 | self.input_size = input_size
72 | self.label_type = "none"
73 |
74 |         # Set kwargs (may override the attributes above)
75 | for key, value in kwargs.items():
76 | setattr(self, key, value)
77 |
78 | # Total size of z that will be sampled. Later, in the forward
79 | # method, we resize to (batch_size, sequence_length, noise_size).
80 | # TODO: Any resizing of z is valid as long as the total size
81 | # remains sequence_length*noise_size. How does this affect
82 | # the performance of the RNN?
83 | self.encoding_dims = sequence_length*noise_size
84 |
85 | super(RGANGenerator, self).__init__(self.encoding_dims,
86 | self.label_type)
87 |
88 | # Build RNN layer
89 | self.rnn = rnn_layer(input_size=input_size,
90 | hidden_size=hidden_size,
91 | num_layers=num_layers,
92 | dropout=dropout,
93 | rnn_type=rnn_type,
94 | nonlinearity=rnn_nonlinearity)
95 | self.dropout = nn.Dropout(dropout)
96 | # self.batchnorm = nn.BatchNorm1d(hidden_size)
97 | self.linear = nn.Linear(hidden_size, output_size)
98 | self.last_layer = last_layer
99 |
100 | # Initialize all weights.
101 | # nn.init.xavier_normal_(self.rnn)
102 | nn.init.xavier_normal_(self.linear.weight)
103 |
104 | def forward(self, z, reshape=True):
105 | if reshape:
106 | z = z.view(-1, self.sequence_length, self.noise_size)
107 | y, _ = self.rnn(z)
108 | y = self.dropout(y)
109 | # y = self.batchnorm(y.permute(0, 2, 1)).permute(0, 2, 1)
110 | y = self.linear(y)
111 | return y if self.last_layer is None else self.last_layer(y)
112 |
113 |
114 | class RGANDiscriminator(Discriminator):
115 | def __init__(self,
116 | sequence_length,
117 | input_size,
118 | hidden_size=None,
119 | num_layers=1,
120 | dropout=0,
121 | rnn_nonlinearity='tanh',
122 | rnn_type='rnn',
123 | last_layer=None,
124 | **kwargs):
125 | """Recursive GAN (Discriminator) implementation with RNN cells.
126 |
127 | Layers:
128 | RNN (with activation, multiple layers):
129 | input: (batch_size, sequence_length, input_size)
130 | output: (batch_size, sequence_length, hidden_size)
131 |
132 | Linear (no activation, weights shared between time steps):
133 | input: (batch_size, sequence_length, hidden_size)
134 | output: (batch_size, sequence_length, 1)
135 |
136 | last_layer (optional)
137 | input: (batch_size, sequence_length, 1)
138 |
139 | Args:
140 | sequence_length (int): Number of points in the sequence.
141 | input_size (int): Size of input (usually the last tensor dimension).
142 | hidden_size (int, optional): Size of hidden layers in rnn.
143 | If None, defaults to input_size.
144 | num_layers (int, optional): Number of stacked RNNs in rnn.
145 | dropout (float, optional): Dropout probability for rnn layers.
146 | rnn_nonlinearity (str, optional): Non-linearity of the RNN. Must be
147 | either 'tanh' or 'relu'. Only valid if rnn_type == 'rnn'.
148 | rnn_type (str, optional): Type of RNN layer. Valid values are 'lstm',
149 | 'gru' and 'rnn', the latter being the default.
150 | last_layer (Module, optional): Last layer of the discriminator.
151 | """
152 |
153 | # TODO: Insert non-linearities between Linear layers.
154 | # TODO: Add BatchNorm and Dropout as in https://arxiv.org/abs/1905.05928v1
155 |
156 | # Set hidden_size to input_size if not specified
157 | hidden_size = hidden_size or input_size
158 |
159 | self.input_size = input_size
160 | self.sequence_length = sequence_length
161 | self.hidden_size = hidden_size
162 | self.num_layers = num_layers
163 | self.dropout = dropout
164 | self.label_type = "none"
165 |
166 |         # Set kwargs (may override the attributes above)
167 | for key, value in kwargs.items():
168 | setattr(self, key, value)
169 |
170 | super(RGANDiscriminator, self).__init__(self.input_size,
171 | self.label_type)
172 |
173 | # Build RNN layer
174 | self.rnn = rnn_layer(input_size=input_size,
175 | hidden_size=hidden_size,
176 | num_layers=num_layers,
177 | dropout=dropout,
178 | rnn_type=rnn_type,
179 | nonlinearity=rnn_nonlinearity)
180 | self.dropout = nn.Dropout(dropout)
181 | self.linear = nn.Linear(hidden_size, 1)
182 | self.last_layer = last_layer
183 |
184 | # Initialize all weights.
185 | # nn.init.xavier_normal_(self.rnn)
186 | nn.init.xavier_normal_(self.linear.weight)
187 |
188 | def forward(self, x):
189 | y, _ = self.rnn(x)
190 | y = self.dropout(y)
191 | y = self.linear(y)
192 | return y if self.last_layer is None else self.last_layer(y)
193 |
--------------------------------------------------------------------------------
/tests/test_trainers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import torch.nn.functional as F
6 | from torch.optim import SGD
7 | from torch.utils.data import DataLoader, Dataset
8 | from ward2icu.trainers import (BinaryClassificationTrainer,
9 |                                MinMaxBinaryCGANTrainer)
10 | from ward2icu.samplers import IdentitySampler
11 | from ward2icu.models import RCGANGenerator, RCGANDiscriminator
12 | from itertools import chain
13 |
14 | @pytest.fixture
15 | def binaryclass_trainer():
16 | batch_size = 8
17 | input_size = 3
18 | num_classes = 2
19 |
20 | X_train = torch.randn(batch_size, input_size, device='cpu').float()
21 | y_train = torch.randint(0, num_classes, (batch_size,), device='cpu').float()
22 |
23 | X_test = torch.randn(batch_size, input_size, device='cpu').float()
24 | y_test = torch.randint(0, num_classes, (batch_size,), device='cpu').float()
25 |
26 | sampler_train = IdentitySampler(X_train, y_train)
27 | sampler_test = IdentitySampler(X_test, y_test)
28 |
29 | class Model(nn.Module):
30 | def __init__(self):
31 | super(Model, self).__init__()
32 | self.linear = nn.Linear(input_size, 1)
33 | def forward(self, x):
34 | return self.linear(x).squeeze()
35 |
36 | model = Model().cpu()
37 | optimizer = SGD(model.parameters(), lr=100)
38 | binaryclass_trainer = BinaryClassificationTrainer(optimizer=optimizer,
39 | sampler_train=sampler_train,
40 | sampler_test=sampler_test,
41 | model=model)
42 | return binaryclass_trainer
43 |
44 | @pytest.fixture
45 | def binaryclass_tiled_trainer():
46 | batch_size = 8
47 | input_size = 3
48 | num_classes = 2
49 |
50 | X_train = torch.randn(batch_size, input_size, device='cpu').float()
51 | y_train = torch.randint(0, num_classes, (batch_size,), device='cpu').float()
52 |
53 | X_test = torch.randn(batch_size, input_size, device='cpu').float()
54 | y_test = torch.randint(0, num_classes, (batch_size,), device='cpu').float()
55 |
56 | sampler_train = IdentitySampler(X_train, y_train, tile=True)
57 | sampler_test = IdentitySampler(X_test, y_test, tile=True)
58 |
59 | class Model(nn.Module):
60 | def __init__(self):
61 | super(Model, self).__init__()
62 | self.linear = nn.Linear(input_size, 1)
63 | def forward(self, x):
64 | return self.linear(x).squeeze()
65 |
66 | model = Model().cpu()
67 | optimizer = SGD(model.parameters(), lr=100)
68 | binaryclass_trainer_tiled = BinaryClassificationTrainer(optimizer=optimizer,
69 | sampler_train=sampler_train,
70 | sampler_test=sampler_test,
71 | model=model)
72 | return binaryclass_trainer_tiled
73 |
74 |
75 | @pytest.fixture
76 | def minmaxgan_trainer():
77 | data_size = 10
78 | input_length = 20
79 | num_classes = 2
80 | noise_size = 10
81 | output_size = 5
82 | input_size = output_size
83 | encoding_dims = input_length*input_size
84 |
85 | X = torch.randn(data_size, input_length, output_size)
86 | y = torch.randint(0, num_classes, (data_size,))
87 | sampler = IdentitySampler(X, y)
88 |
89 | generator = RCGANGenerator(input_length, output_size,
90 | num_classes, noise_size)
91 | discriminator = RCGANDiscriminator(input_length, num_classes, input_size)
92 | optimizer_gen = SGD(generator.parameters(), lr=0.1)
93 | optimizer_dsc = SGD(discriminator.parameters(), lr=0.1)
94 | trainer = MinMaxBinaryCGANTrainer(generator,
95 | discriminator,
96 | optimizer_gen,
97 | optimizer_dsc,
98 | sampler)
99 | return trainer
100 |
101 | def test_MinMaxGANTrainer_train(minmaxgan_trainer):
102 | def parameters():
103 | yield from minmaxgan_trainer.generator.parameters()
104 | yield from minmaxgan_trainer.discriminator.parameters()
105 |
106 | old_params = [p.detach().clone() for p in parameters()]
107 | minmaxgan_trainer.train(epochs=1)
108 | new_params = [p.detach().clone() for p in parameters()]
109 |
110 | # assert parameters are being trained
111 | assert all([(old != new).all()
112 | for old, new in zip(old_params, new_params)])
113 |
114 |
115 | def test_MinMaxGANTrainer_calc_dsc_loss(minmaxgan_trainer):
116 | dsc_logits_real = torch.Tensor([[[0],
117 | [0],
118 | [0]],
119 |
120 | [[1],
121 | [1],
122 | [0]]])
123 |
124 | dsc_logits_fake = torch.Tensor([[[0],
125 | [0],
126 | [0]],
127 |
128 | [[1],
129 | [1],
130 | [0]]])
131 |
132 | loss_real = minmaxgan_trainer.calc_dsc_real_loss(dsc_logits_real)
133 | loss_fake = minmaxgan_trainer.calc_dsc_fake_loss(dsc_logits_fake)
134 | loss = loss_real + loss_fake
135 |     e = -0.5*(np.log(1 - torch.sigmoid(dsc_logits_real)).mean() +
136 |               np.log(torch.sigmoid(dsc_logits_fake)).mean())
137 | assert np.isclose(e, loss.item())
138 |
139 |
140 | def test_MinMaxGANTrainer_calc_gen_loss(minmaxgan_trainer):
141 | dsc_logits = torch.Tensor([[[0],
142 | [0],
143 | [0]],
144 |
145 | [[1],
146 | [1],
147 | [0]]])
148 |
149 | loss = minmaxgan_trainer.calc_gen_loss(dsc_logits)
150 | e = -np.log(torch.sigmoid(dsc_logits)).mean()
151 | assert np.isclose(e, loss.item())
152 |
153 |
154 | def test_MinMaxGANTrainer_calculate_metrics(minmaxgan_trainer):
155 | loss_gen = torch.Tensor([0])
156 | loss_dsc = torch.Tensor([1])
157 | metrics = minmaxgan_trainer.calculate_metrics(loss_gen, loss_dsc)
158 | assert metrics['generator_loss'] == 0
159 | assert metrics['discriminator_loss'] == 1
160 | assert np.isnan(metrics['generator_mean_abs_grad'])
161 | assert np.isnan(metrics['discriminator_mean_abs_grad'])
162 |
163 |
164 | def test_BinaryClassificationTrainer_train(binaryclass_trainer):
165 | old_params = [p.detach().clone()
166 | for p in binaryclass_trainer.model.parameters()]
167 | binaryclass_trainer.train(epochs=1)
168 | new_params = [p.detach().clone()
169 | for p in binaryclass_trainer.model.parameters()]
170 |
171 | # assert parameters are being trained
172 | assert all([(old != new).all() for old, new in zip(old_params, new_params)])
173 |
174 |
175 | def test_BinaryClassificationTrainer_calculate_metrics(binaryclass_trainer):
176 | logits = torch.Tensor([-9, -8, 8, 9, -1, 8])
177 | probs = torch.sigmoid(logits)
178 | y_true = torch.Tensor([ 0, 0, 0, 0, 1, 1])
179 | y_pred = torch.Tensor([ 0, 0, 1, 1, 0, 1])
180 |
181 | assert (y_pred == probs.round()).all()
182 |
183 | r = binaryclass_trainer.calculate_metrics(y_true, y_pred,
184 | logits, probs)
185 |
186 | loss = -np.log(probs[y_true==1]).sum() - np.log(1 - probs[y_true==0]).sum()
187 | loss /= probs.shape[0]
188 | assert np.isclose(r['_accuracy'], 3/6)
189 | assert np.isclose(r['_balanced_accuracy'], (2/4)*(2/6) + (1/2)*(4/6)) # 0.5
190 | assert np.isclose(r['_accuracy_0'], 2/4)
191 | assert np.isclose(r['_accuracy_1'], 1/2)
192 | assert np.isclose(r['_loss'], loss)
193 | assert np.isclose(r['_loss_0'], -np.log(1 - probs[y_true==0]).mean())
194 | assert np.isclose(r['_loss_1'], -np.log(probs[y_true==1]).mean())
195 | assert np.isclose(r['_matthews'], 0)
196 |
197 |
198 | def test_BinaryClassificationTrainer_calculate_metrics_tiled(binaryclass_tiled_trainer):
199 | logits = torch.Tensor([[-9, -8, 8, 9, -1, 8],
200 | [-9, -8, 8, 9, -1, 8]]).T
201 | probs = torch.sigmoid(logits)
202 | y_true = torch.Tensor([[ 0, 0, 0, 0, 1, 1],
203 | [ 0, 0, 0, 0, 1, 1]]).T
204 | y_pred = torch.Tensor([[ 0, 0, 1, 1, 0, 1],
205 | [ 0, 0, 1, 1, 0, 1]]).T
206 |
207 | assert (y_pred == probs.round()).all()
208 |
209 | r = binaryclass_tiled_trainer.calculate_metrics(y_true, y_pred,
210 | logits, probs)
211 |
212 | loss = 0.5*(-np.log(probs[y_true==1]).sum() - np.log(1 - probs[y_true==0]).sum())
213 | loss /= probs.shape[0]
214 | assert np.isclose(r['_accuracy'], 3/6)
215 | assert np.isclose(r['_balanced_accuracy'], (2/4)*(2/6) + (1/2)*(4/6)) # 0.5
216 | assert np.isclose(r['_accuracy_0'], 2/4)
217 | assert np.isclose(r['_accuracy_1'], 1/2)
218 | assert np.isclose(r['_loss'], loss)
219 | assert np.isclose(r['_loss_0'], -np.log(1 - probs[y_true==0]).mean())
220 | assert np.isclose(r['_loss_1'], -np.log(probs[y_true==1]).mean())
221 | assert np.isclose(r['_matthews'], 0)
222 |
--------------------------------------------------------------------------------