├── .github
│   └── workflows
│       └── ci.yml
├── CHANGELOG.md
├── LICENSE
├── README.md
├── auton_survival
│   ├── __init__.py
│   ├── datasets.py
│   ├── datasets
│   │   ├── framingham.csv
│   │   ├── pbc2.csv
│   │   ├── support2.csv
│   │   └── synthetic_dataset.csv
│   ├── estimators.py
│   ├── experiments.py
│   ├── metrics.py
│   ├── models
│   │   ├── cmhe
│   │   │   ├── __init__.py
│   │   │   ├── cmhe_torch.py
│   │   │   └── cmhe_utilities.py
│   │   ├── cph
│   │   │   ├── __init__.py
│   │   │   ├── dcph_torch.py
│   │   │   └── dcph_utilities.py
│   │   ├── dcm
│   │   │   ├── __init__.py
│   │   │   ├── dcm_torch.py
│   │   │   └── dcm_utilities.py
│   │   └── dsm
│   │       ├── __init__.py
│   │       ├── datasets.py
│   │       ├── datasets
│   │       │   ├── framingham.csv
│   │       │   ├── pbc2.csv
│   │       │   └── support2.csv
│   │       ├── dsm_torch.py
│   │       ├── losses.py
│   │       └── utilities.py
│   ├── phenotyping.py
│   ├── preprocessing.py
│   ├── reporting.py
│   └── utils.py
├── docs
│   ├── dashboard.html
│   ├── datasets.html
│   ├── estimators.html
│   ├── experiments.html
│   ├── index.html
│   ├── metrics.html
│   ├── models
│   │   ├── .DS_Store
│   │   ├── cmhe
│   │   │   ├── cmhe_torch.html
│   │   │   ├── cmhe_utilities.html
│   │   │   └── index.html
│   │   ├── cph
│   │   │   ├── dcph_torch.html
│   │   │   ├── dcph_utilities.html
│   │   │   └── index.html
│   │   ├── dcm
│   │   │   ├── dcm_torch.html
│   │   │   ├── dcm_utilities.html
│   │   │   └── index.html
│   │   ├── dsm
│   │   │   ├── datasets.html
│   │   │   ├── dsm_torch.html
│   │   │   ├── index.html
│   │   │   ├── losses.html
│   │   │   └── utilities.html
│   │   └── index.html
│   ├── phenotyping.html
│   ├── preprocessing.html
│   ├── reporting.html
│   └── utils.html
├── examples
│   ├── CV Survival Regression on SUPPORT Dataset.ipynb
│   ├── DCM on SUPPORT Dataset copy.ipynb
│   ├── DSM on SUPPORT Dataset.ipynb
│   ├── Demo of CMHE on Synthetic Data.ipynb
│   ├── Phenotyping Censored Time-to-Events.ipynb
│   ├── RDSM on PBC Dataset.ipynb
│   ├── Survival Regression with Auton-Survival.ipynb
│   ├── cmhe_demo_utils.py
│   ├── estimators_demo_utils.py
│   └── matplotlibrc
├── pyproject.toml
└── tests
    ├── __init__.py
    └── test_dsm.py

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | 
3 | on:
4 |   push:
5 |     branches: [ master ]
6 |   pull_request:
7 |     branches: [ master ]
8 | 
9 | 
10 | jobs:
11 |   Quality:
12 |     runs-on: ubuntu-latest
13 |     strategy:
14 |       fail-fast: true
15 |       matrix:
16 |         python-version: ["3.8", "3.9", "3.10"]
17 | 
18 |     steps:
19 |       - uses: actions/checkout@v3
20 | 
21 |       - uses: actions/setup-python@v3
22 |         with:
23 |           python-version: ${{matrix.python-version}}
24 | 
25 |       - name: Install Python Poetry
26 |         uses: abatilo/actions-poetry@v2.3.0
27 | 
28 |       - name: Configure poetry
29 |         shell: bash
30 |         run: python -m poetry config virtualenvs.in-project true
31 | 
32 |       - name: View poetry version
33 |         run: poetry --version
34 | 
35 |       - name: Install dependencies
36 |         run: |
37 |           python -m poetry install
38 | 
39 |       - name: Test
40 |         run: poetry run python3 -m unittest discover
41 | 
42 |   Release:
43 |     needs: Quality
44 |     if: |
45 |       github.repository == 'autonlab/auton-survival' &&
46 |       github.event_name == 'push' &&
47 |       github.ref == 'refs/heads/master' &&
48 |       !contains ( github.event.head_commit.message, 'chore(release)' )
49 | 
50 |     runs-on: ubuntu-latest
51 |     concurrency: release
52 |     permissions:
53 |       id-token: write
54 |       contents: write
55 | 
56 |     steps:
57 |       - uses: actions/setup-python@v3
58 |         with:
59 |           python-version: 3.8
60 | 
61 |       - uses: actions/checkout@v3
62 |         with:
63 |           fetch-depth: 0
64 | 
65 |       - name: Check release status
66 |         id: release-status
67 |         env:
68 |           GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
69 |         run: |
70 |           pip install python-semantic-release
71 |           if semantic-release --noop --strict version
72 |           then
73 |             echo "Releasing new version."
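          # NOTE (descriptive comment, not part of the original workflow):
          # `--noop` makes semantic-release a dry run with no side effects, and
          # `--strict` makes the command exit non-zero when no release-worthy
          # commits are found, which routes execution into the skip branch below.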
74 | else 75 | echo "Skipping release steps." 76 | fi 77 | 78 | - if: steps.release-status.outputs.released == 'true' 79 | name: Release to GitHub 80 | id: github-release 81 | env: 82 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 83 | run: | 84 | semantic-release version 85 | git fetch --tags 86 | for file in ./dist/** 87 | do gh release upload "${{steps.release-status.outputs.tag}}" $file 88 | done 89 | 90 | - if: steps.release-status.outputs.released == 'true' 91 | name: Release to Test PyPI 92 | id: test-pypi-release 93 | env: 94 | TEST_PYPI_TOKEN: ${{ secrets.TEST_PYPI_TOKEN }} 95 | run: | 96 | poetry config repositories.test-pypi https://test.pypi.org/legacy/ 97 | poetry config pypi-token.test-pypi $TEST_PYPI_TOKEN 98 | poetry publish -r test-pypi -u __token__ 99 | 100 | - if: steps.release-status.outputs.released == 'true' 101 | name: Release to PyPI 102 | id: pypi-release 103 | env: 104 | PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} 105 | run: | 106 | poetry config pypi-token.pypi $PYPI_TOKEN 107 | poetry publish 108 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Carnegie Mellon University Auton Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /auton_survival/models/cmhe/cmhe_torch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # MIT License 3 | 4 | # Copyright (c) 2020 Carnegie Mellon University, Auton Lab 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | 
24 | import torch
25 | from auton_survival.models.dsm.dsm_torch import create_representation
26 | 
27 | class DeepCMHETorch(torch.nn.Module):
28 |   """PyTorch model definition of the Cox Mixture with Heterogenous Effects Model.
29 | 
30 |   Cox Mixtures with Heterogenous Effects assumes the survival function of
31 |   the individual to be a mixture of K Cox models, with base survival rates
32 |   that are independent of the treatment effect. Conditioned on each
33 |   subgroup Z=k, the PH assumptions are assumed to hold and the baseline
34 |   hazard rate is determined non-parametrically using a spline-interpolated
35 |   Breslow's estimator.
36 | 
37 |   """
38 | 
39 |   def _init_dcmhe_layers(self, lastdim):
40 | 
41 | 
42 |     self.expert = IdentifiableLinear(lastdim, self.k, bias=False)
43 |     self.z_gate = IdentifiableLinear(lastdim, self.k, bias=False)
44 |     self.phi_gate = IdentifiableLinear(lastdim, self.g, bias=False)
45 |     # self.expert = torch.nn.Linear(lastdim, self.k, bias=False)
46 |     # self.z_gate = torch.nn.Linear(lastdim, self.k, bias=False)
47 |     # self.phi_gate = torch.nn.Linear(lastdim, self.g, bias=False)
48 |     self.omega = torch.nn.Parameter(torch.rand(self.g)-0.5)
49 | 
50 |   def __init__(self, k, g, inputdim, layers=None, gamma=100,
51 |                smoothing_factor=1e-4, gate_l2_penalty=1e-4,
52 |                optimizer='Adam'):
53 | 
54 |     super(DeepCMHETorch, self).__init__()
55 | 
56 |     assert isinstance(k, int)
57 | 
58 |     if layers is None: layers = []
59 | 
60 |     self.optimizer = optimizer
61 | 
62 |     self.k = k  # Base Physiology groups
63 |     self.g = g  # Treatment Effect groups
64 | 
65 |     self.gamma = gamma
66 |     self.smoothing_factor = smoothing_factor
67 | 
68 |     if len(layers) == 0: lastdim = inputdim
69 |     else: lastdim = layers[-1]
70 | 
71 |     self._init_dcmhe_layers(lastdim)
72 | 
73 |     self.gate_l2_penalty = gate_l2_penalty
74 | 
75 |     self.embedding = create_representation(inputdim, layers, 'Tanh')
76 | 
77 | 
78 |   def forward(self, x, a):
79 | 
80 |     x = self.embedding(x)
81 |     a = 2*(a-0.5)
82 | 
83 |     log_hrs = torch.clamp(self.expert(x),
84 |                           min=-self.gamma,
85 |                           max=self.gamma)
86 | 
87 |     logp_z_gate = torch.nn.LogSoftmax(dim=1)(self.z_gate(x))
88 |     logp_phi_gate = torch.nn.LogSoftmax(dim=1)(self.phi_gate(x))
89 | 
90 |     logp_jointlatent_gate = torch.zeros(len(x), self.k, self.g)
91 | 
92 |     for i in range(self.k):
93 |       for j in range(self.g):
94 |         logp_jointlatent_gate[:, i, j] = logp_z_gate[:, i] + logp_phi_gate[:, j]
95 | 
96 |     logp_joint_hrs = torch.zeros(len(x), self.k, self.g)
97 | 
98 |     for i in range(self.k):
99 |       for j in range(self.g):
100 |         logp_joint_hrs[:, i, j] = log_hrs[:, i] + (j!=2)*a*self.omega[j]
101 | 
102 |     return logp_jointlatent_gate, logp_joint_hrs
103 | 
104 | class IdentifiableLinear(torch.nn.Module):
105 | 
106 |   """
107 |   Softmax and LogSoftmax with K classes in pytorch are over-specified and lead to
108 |   issues of mis-identifiability. This class is a wrapper for linear layers that
109 |   are correctly specified with K-1 columns. The output of this layer for the Kth
110 |   class is all zeros. This allows direct application of pytorch.nn.LogSoftmax
111 |   and pytorch.nn.Softmax.
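
  As a minimal illustration, with `out_features=3` the layer learns two
  linear scores and appends a constant zero score for the third class:

  >>> layer = IdentifiableLinear(in_features=5, out_features=3)
  >>> out = layer(torch.randn(4, 5))
  >>> out.shape                    # (batch, K); last column is always zero
  torch.Size([4, 3])
  >>> bool((out[:, -1] == 0).all())
  True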
112 |   """
113 | 
114 |   def __init__(self, in_features, out_features, bias=True):
115 | 
116 |     super(IdentifiableLinear, self).__init__()
117 | 
118 |     assert out_features > 0, "Output features must be greater than 0"
119 | 
120 |     self.out_features = out_features
121 |     self.in_features = in_features
122 |     self.linear = torch.nn.Linear(in_features, max(out_features-1, 1), bias=bias)
123 | 
124 |   @property
125 |   def weight(self):
126 |     return self.linear.weight
127 | 
128 |   def forward(self, x):
129 |     if self.out_features == 1:
130 |       return self.linear(x).reshape(-1, 1)
131 |     else:
132 |       zeros = torch.zeros(len(x), 1, device=x.device)
133 |       return torch.cat([self.linear(x), zeros], dim=1)
--------------------------------------------------------------------------------
/auton_survival/models/cmhe/cmhe_utilities.py:
--------------------------------------------------------------------------------
1 | 
2 | # coding=utf-8
3 | # MIT License
4 | 
5 | # Copyright (c) 2020 Carnegie Mellon University, Auton Lab
6 | 
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | 
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | 
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
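
"""Utility functions for training Cox Mixtures with Heterogenous Effects
(CMHE) with the Expectation-Maximization algorithm.

Overview of the functions defined below: `e_step` computes per-sample
log-likelihoods under the current spline-smoothed Breslow baseline survival
estimates, `m_step` takes a gradient step on the `q_function` (plus an L2
penalty on the gate weights), and `fit_breslow` re-estimates one Breslow
baseline per latent subgroup.
"""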
24 | 25 | import torch 26 | import numpy as np 27 | 28 | from scipy.interpolate import UnivariateSpline 29 | from sksurv.linear_model.coxph import BreslowEstimator 30 | 31 | from tqdm import tqdm 32 | 33 | def randargmax(b,**kw): 34 | """ a random tie-breaking argmax""" 35 | return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw) 36 | 37 | def partial_ll_loss(lrisks, tb, eb, eps=1e-2): 38 | 39 | tb = tb + eps*np.random.random(len(tb)) 40 | sindex = np.argsort(-tb) 41 | 42 | tb = tb[sindex] 43 | eb = eb[sindex] 44 | 45 | lrisks = lrisks[sindex] # lrisks = tf.gather(lrisks, sindex) 46 | 47 | lrisksdenom = torch.logcumsumexp(lrisks, dim = 0) 48 | plls = lrisks - lrisksdenom 49 | pll = plls[eb == 1] 50 | 51 | pll = torch.sum(pll) # pll = tf.reduce_sum(pll) 52 | 53 | return -pll 54 | 55 | def fit_spline(t, surv, smoothing_factor=1e-4): 56 | return UnivariateSpline(t, surv, s=smoothing_factor, ext=3) 57 | 58 | def smooth_bl_survival(breslow, smoothing_factor): 59 | 60 | blsurvival = breslow.baseline_survival_ 61 | x, y = blsurvival.x, blsurvival.y 62 | return fit_spline(x, y, smoothing_factor=smoothing_factor) 63 | 64 | def get_probability_(lrisks, ts, spl): 65 | risks = np.exp(lrisks) 66 | s0ts = (-risks)*(spl(ts)**(risks-1)) 67 | return s0ts * spl.derivative()(ts) 68 | 69 | def get_survival_(lrisks, ts, spl): 70 | risks = np.exp(lrisks) 71 | return spl(ts)**risks 72 | 73 | def get_probability(lrisks, breslow_splines, t): 74 | psurv = [] 75 | for i in range(lrisks.shape[1]): 76 | p = get_probability_(lrisks[:, i], t, breslow_splines[i]) 77 | psurv.append(p) 78 | psurv = np.array(psurv).T 79 | return psurv 80 | 81 | def get_survival(lrisks, breslow_splines, t): 82 | psurv = [] 83 | for i in range(lrisks.shape[1]): 84 | p = get_survival_(lrisks[:, i], t, breslow_splines[i]) 85 | psurv.append(p) 86 | psurv = np.array(psurv).T 87 | return psurv 88 | 89 | def get_posteriors(probs): 90 | probs_ = probs+1e-8 91 | return probs-torch.logsumexp(probs, dim=1).reshape(-1,1) 92 | 93 | def get_hard_z(gates_prob): 94 | return torch.argmax(gates_prob, dim=1) 95 | 96 | def sample_hard_z(gates_prob): 97 | return torch.multinomial(gates_prob.exp(), num_samples=1)[:, 0] 98 | 99 | def repair_probs(probs): 100 | probs[torch.isnan(probs)] = -20 101 | probs[probs<-20] = -20 102 | return probs 103 | 104 | def get_likelihood(model, breslow_splines, x, t, e, a): 105 | 106 | # Function requires numpy/torch 107 | 108 | gates, lrisks = model(x, a=a) 109 | lrisks = lrisks.numpy() 110 | e, t = e.numpy(), t.numpy() 111 | 112 | probs = [] 113 | 114 | for i in range(model.g): 115 | 116 | survivals = get_survival(lrisks[:, :, i], breslow_splines, t) 117 | probability = get_probability(lrisks[:, :, i], breslow_splines, t) 118 | 119 | event_probs = np.array([survivals, probability]) 120 | event_probs = event_probs[e.astype('int'), range(len(e)), :] 121 | probs.append(np.log(event_probs)) 122 | 123 | probs = np.array(probs).transpose(1, 2, 0) 124 | event_probs = gates+probs 125 | 126 | return event_probs 127 | 128 | def q_function(model, x, t, e, a, log_likelihoods, typ='soft'): 129 | 130 | z_posteriors = repair_probs( 131 | get_posteriors( 132 | torch.logsumexp(log_likelihoods, dim=2))) 133 | zeta_posteriors = repair_probs( 134 | get_posteriors( 135 | torch.logsumexp(log_likelihoods, dim=1))) 136 | 137 | if typ == 'hard': 138 | z = get_hard_z(z_posteriors) 139 | zeta = get_hard_z(zeta_posteriors) 140 | else: 141 | z = sample_hard_z(z_posteriors) 142 | zeta = sample_hard_z(zeta_posteriors) 143 | 144 | gates, lrisks = 
model(x, a=a) 145 | 146 | loss = 0 147 | for i in range(model.k): 148 | lrisks_ = lrisks[:, i, :][range(len(zeta)), zeta] 149 | loss += partial_ll_loss(lrisks_[z == i], t[z == i], e[z == i]) 150 | 151 | #log_smax_loss = -torch.nn.LogSoftmax(dim=1)(gates) # tf.nn.log_softmax(gates) 152 | 153 | posteriors = repair_probs( 154 | get_posteriors( 155 | log_likelihoods.reshape(-1, model.k*model.g))).exp() 156 | 157 | gate_loss = posteriors*gates.reshape(-1, model.k*model.g) 158 | gate_loss = -torch.sum(gate_loss) 159 | loss+=gate_loss 160 | 161 | return loss 162 | 163 | def e_step(model, breslow_splines, x, t, e, a): 164 | 165 | # TODO: Do this in `Log Space` 166 | # If Breslow splines are not available, like in the first 167 | # iteration of learning, we randomly compute posteriors. 168 | if breslow_splines is None: log_likelihoods = torch.rand(len(x), model.k, model.g) 169 | else: log_likelihoods = get_likelihood(model, breslow_splines, x, t, e, a) 170 | 171 | return log_likelihoods 172 | 173 | def m_step(model, optimizer, x, t, e, a, log_likelihoods, typ='soft'): 174 | 175 | optimizer.zero_grad() 176 | loss = q_function(model, x, t, e, a, log_likelihoods, typ) 177 | gate_regularization_loss = (model.phi_gate.weight**2).sum() 178 | gate_regularization_loss += (model.z_gate.weight**2).sum() 179 | loss += (model.gate_l2_penalty)*gate_regularization_loss 180 | loss.backward() 181 | optimizer.step() 182 | 183 | return float(loss) 184 | 185 | def fit_breslow(model, x, t, e, a, log_likelihoods=None, smoothing_factor=1e-4, typ='soft'): 186 | 187 | gates, lrisks = model(x, a=a) 188 | 189 | lrisks = lrisks.numpy() 190 | 191 | e = e.numpy() 192 | t = t.numpy() 193 | 194 | if log_likelihoods is None: 195 | z_posteriors = torch.logsumexp(gates, dim=2) 196 | zeta_posteriors = torch.logsumexp(gates, dim=1) 197 | else: 198 | z_posteriors = repair_probs(get_posteriors(torch.logsumexp(log_likelihoods, dim=2))) 199 | zeta_posteriors = repair_probs(get_posteriors(torch.logsumexp(log_likelihoods, dim=1))) 200 | 201 | if typ == 'soft': 202 | z = sample_hard_z(z_posteriors) 203 | zeta = sample_hard_z(zeta_posteriors) 204 | else: 205 | z = get_hard_z(z_posteriors) 206 | zeta = get_hard_z(zeta_posteriors) 207 | 208 | breslow_splines = {} 209 | for i in range(model.k): 210 | breslowk = BreslowEstimator().fit(lrisks[:, i, :][range(len(zeta)), zeta][z==i], e[z==i], t[z==i]) 211 | breslow_splines[i] = smooth_bl_survival(breslowk, smoothing_factor=smoothing_factor) 212 | 213 | return breslow_splines 214 | 215 | 216 | def train_step(model, x, t, e, a, breslow_splines, optimizer, 217 | bs=256, seed=100, typ='soft', use_posteriors=False, 218 | update_splines_after=10, smoothing_factor=1e-4): 219 | 220 | from sklearn.utils import shuffle 221 | 222 | x, t, e, a = shuffle(x, t, e, a, random_state=seed) 223 | 224 | n = x.shape[0] 225 | batches = (n // bs) + 1 226 | 227 | epoch_loss = 0 228 | for i in range(batches): 229 | 230 | xb = x[i*bs:(i+1)*bs] 231 | tb = t[i*bs:(i+1)*bs] 232 | eb = e[i*bs:(i+1)*bs] 233 | ab = a[i*bs:(i+1)*bs] 234 | 235 | # E-Step !!! 
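    # EM alternation per mini-batch: the E-step below runs under
    # `torch.no_grad()` and, holding the current Breslow splines fixed,
    # computes the log-likelihood of every (z, phi) latent assignment;
    # the M-step then takes one gradient step on the q-function with
    # those posteriors treated as constants.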
236 | # e_step_start = time.time() 237 | with torch.no_grad(): 238 | log_likelihoods = e_step(model, breslow_splines, xb, tb, eb, ab) 239 | 240 | torch.enable_grad() 241 | loss = m_step(model, optimizer, xb, tb, eb, ab, log_likelihoods, typ=typ) 242 | epoch_loss += loss 243 | 244 | with torch.no_grad(): 245 | if i%update_splines_after == 0: 246 | if use_posteriors: 247 | log_likelihoods = e_step(model, breslow_splines, x, t, e, a) 248 | breslow_splines = fit_breslow(model, x, t, e, a, 249 | log_likelihoods=log_likelihoods, 250 | typ='soft', 251 | smoothing_factor=smoothing_factor) 252 | else: 253 | breslow_splines = fit_breslow(model, x, t, e, a, 254 | log_likelihoods=None, 255 | typ='soft', 256 | smoothing_factor=smoothing_factor) 257 | # print(f'Duration of Breslow spline estimation: {time.time() - estimate_breslow_start}') 258 | # except Exception as exce: 259 | # print("Exception!!!:", exce) 260 | # logging.warning("Couldn't fit splines, reusing from previous epoch") 261 | #print (epoch_loss/n) 262 | return breslow_splines 263 | 264 | 265 | def test_step(model, x, t, e, a, breslow_splines, loss='q', typ='soft'): 266 | 267 | if loss == 'q': 268 | with torch.no_grad(): 269 | posteriors = e_step(model, breslow_splines, x, t, e, a) 270 | loss = q_function(model, x, t, e, a, posteriors, typ=typ) 271 | 272 | return float(loss/x.shape[0]) 273 | 274 | 275 | def train_cmhe(model, train_data, val_data, epochs=50, 276 | patience=2, vloss='q', bs=256, typ='soft', lr=1e-3, 277 | use_posteriors=False, debug=False, 278 | return_losses=False, update_splines_after=10, 279 | smoothing_factor=1e-4, random_seed=0): 280 | 281 | torch.manual_seed(random_seed) 282 | np.random.seed(random_seed) 283 | 284 | if val_data is None: val_data = train_data 285 | 286 | xt, tt, et, at = train_data 287 | xv, tv, ev, av = val_data 288 | 289 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 290 | 291 | valc = np.inf 292 | patience_ = 0 293 | 294 | breslow_splines = None 295 | 296 | losses = [] 297 | 298 | for epoch in tqdm(range(epochs)): 299 | 300 | # train_step_start = time.time() 301 | breslow_splines = train_step(model, xt, tt, et, at, breslow_splines, 302 | optimizer, bs=bs, seed=epoch, typ=typ, 303 | use_posteriors=use_posteriors, 304 | update_splines_after=update_splines_after, 305 | smoothing_factor=smoothing_factor) 306 | # print(f'Duration of train-step: {time.time() - train_step_start}') 307 | # test_step_start = time.time() 308 | valcn = test_step(model, xv, tv, ev, av, breslow_splines, 309 | loss=vloss, typ=typ) 310 | # print(f'Duration of test-step: {time.time() - test_step_start}') 311 | 312 | losses.append(valcn) 313 | 314 | if epoch % 1 == 0: 315 | if debug: print(patience_, epoch, valcn) 316 | 317 | if valcn > valc: 318 | patience_ += 1 319 | else: 320 | patience_ = 0 321 | 322 | if patience_ == patience: 323 | if return_losses: return (model, breslow_splines), losses 324 | else: return (model, breslow_splines) 325 | 326 | valc = valcn 327 | 328 | if return_losses: return (model, breslow_splines), losses 329 | else: return (model, breslow_splines) 330 | 331 | def predict_survival(model, x, a, t): 332 | 333 | if isinstance(t, (int, float)): t = [t] 334 | 335 | model, breslow_splines = model 336 | 337 | gates, lrisks = model(x, a=a) 338 | 339 | lrisks = lrisks.detach().numpy() 340 | gates = gates.exp().reshape(-1, model.k*model.g).detach().numpy() 341 | 342 | predictions = [] 343 | for t_ in t: 344 | expert_outputs = [] 345 | for i in range(model.g): 346 | expert_output = get_survival(lrisks[:, :, 
i], breslow_splines, t_)
347 |       expert_outputs.append(expert_output)
348 |     expert_outputs = np.array(expert_outputs).transpose(1, 2, 0).reshape(-1, model.k*model.g)
349 | 
350 |     predictions.append((gates*expert_outputs).sum(axis=1))
351 |   return np.array(predictions).T
352 | 
353 | def predict_latent_z(model, x):
354 | 
355 |   model, _ = model
356 |   x = model.embedding(x)
357 | 
358 |   z_gate_probs = torch.nn.Softmax(dim=1)(model.z_gate(x)).detach().numpy()
359 | 
360 |   return z_gate_probs
361 | 
362 | def predict_latent_phi(model, x):
363 | 
364 |   model, _ = model
365 |   x = model.embedding(x)
366 | 
367 |   p_phi_gate = torch.nn.Softmax(dim=1)(model.phi_gate(x)).detach().numpy()
368 | 
369 |   return p_phi_gate
370 | 
--------------------------------------------------------------------------------
/auton_survival/models/cph/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # MIT License
3 | 
4 | # Copyright (c) 2020 Carnegie Mellon University, Auton Lab
5 | 
6 | # Permission is hereby granted, free of charge, to any person obtaining a copy
7 | # of this software and associated documentation files (the "Software"), to deal
8 | # in the Software without restriction, including without limitation the rights
9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | # copies of the Software, and to permit persons to whom the Software is
11 | # furnished to do so, subject to the following conditions:
12 | 
13 | # The above copyright notice and this permission notice shall be included in all
14 | # copies or substantial portions of the Software.
15 | 
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | 
24 | r""" Deep Cox Proportional Hazards Model"""
25 | 
26 | import torch
27 | import numpy as np
28 | 
29 | from .dcph_torch import DeepCoxPHTorch, DeepRecurrentCoxPHTorch
30 | from .dcph_utilities import train_dcph, predict_survival
31 | 
32 | from auton_survival.utils import _dataframe_to_array
33 | from auton_survival.models.dsm.utilities import _get_padded_features
34 | from auton_survival.models.dsm.utilities import _get_padded_targets
35 | 
36 | 
37 | class DeepCoxPH:
38 |   """A Deep Cox Proportional Hazards model.
39 | 
40 |   This is the main interface to a Deep Cox Proportional Hazards model.
41 |   A model is instantiated with an appropriate set of hyperparameters and
42 |   fit on numpy arrays consisting of the features, event/censoring times
43 |   and the event/censoring indicators.
44 | 
45 |   For full details on Deep Cox Proportional Hazards, refer to [1], [2].
46 | 
47 |   References
48 |   ----------
49 |   [1] DeepSurv: personalized
50 |   treatment recommender system using a Cox proportional hazards
51 |   deep neural network. BMC medical research methodology (2018)
52 | 
53 |   [2]
54 |   A neural network model for survival data. Statistics in medicine (1995)
55 | 
56 |   Parameters
57 |   ----------
60 |   layers: list
61 |       A list of integers consisting of the number of neurons in each
62 |       hidden layer.
63 | random_seed: int 64 | Controls the reproducibility of called functions. 65 | Example 66 | ------- 67 | >>> from auton_survival import DeepCoxPH 68 | >>> model = DeepCoxPH() 69 | >>> model.fit(x, t, e) 70 | 71 | """ 72 | 73 | def __init__(self, layers=None, random_seed=0): 74 | 75 | self.layers = layers 76 | self.fitted = False 77 | self.random_seed = random_seed 78 | 79 | def __call__(self): 80 | if self.fitted: 81 | print("A fitted instance of the Deep Cox PH model") 82 | else: 83 | print("An unfitted instance of the Deep Cox PH model") 84 | 85 | print("Hidden Layers:", self.layers) 86 | 87 | def _preprocess_test_data(self, x): 88 | x = _dataframe_to_array(x) 89 | return torch.from_numpy(x).float() 90 | 91 | def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed): 92 | 93 | x = _dataframe_to_array(x) 94 | t = _dataframe_to_array(t) 95 | e = _dataframe_to_array(e) 96 | 97 | idx = list(range(x.shape[0])) 98 | 99 | np.random.seed(random_seed) 100 | np.random.shuffle(idx) 101 | 102 | x_train, t_train, e_train = x[idx], t[idx], e[idx] 103 | 104 | x_train = torch.from_numpy(x_train).float() 105 | t_train = torch.from_numpy(t_train).float() 106 | e_train = torch.from_numpy(e_train).float() 107 | 108 | if val_data is None: 109 | 110 | vsize = int(vsize*x_train.shape[0]) 111 | x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] 112 | 113 | x_train = x_train[:-vsize] 114 | t_train = t_train[:-vsize] 115 | e_train = e_train[:-vsize] 116 | 117 | else: 118 | 119 | x_val, t_val, e_val = val_data 120 | 121 | x_val = _dataframe_to_array(x_val) 122 | t_val = _dataframe_to_array(t_val) 123 | e_val = _dataframe_to_array(e_val) 124 | 125 | x_val = torch.from_numpy(x_val).float() 126 | t_val = torch.from_numpy(t_val).float() 127 | e_val = torch.from_numpy(e_val).float() 128 | 129 | return (x_train, t_train, e_train, x_val, t_val, e_val) 130 | 131 | def _gen_torch_model(self, inputdim, optimizer): 132 | """Helper function to return a torch model.""" 133 | # Add random seed to get the same results like in dcm __init__.py 134 | np.random.seed(self.random_seed) 135 | torch.manual_seed(self.random_seed) 136 | 137 | return DeepCoxPHTorch(inputdim, layers=self.layers, 138 | optimizer=optimizer) 139 | 140 | def fit(self, x, t, e, vsize=0.15, val_data=None, 141 | iters=1, learning_rate=1e-3, batch_size=100, 142 | optimizer="Adam"): 143 | 144 | r"""This method is used to train an instance of the DSM model. 145 | 146 | Parameters 147 | ---------- 148 | x: np.ndarray 149 | A numpy array of the input features, \( x \). 150 | t: np.ndarray 151 | A numpy array of the event/censoring times, \( t \). 152 | e: np.ndarray 153 | A numpy array of the event/censoring indicators, \( \delta \). 154 | \( \delta = 1 \) means the event took place. 155 | vsize: float 156 | Amount of data to set aside as the validation set. 157 | val_data: tuple 158 | A tuple of the validation dataset. If passed vsize is ignored. 159 | iters: int 160 | The maximum number of training iterations on the training dataset. 161 | learning_rate: float 162 | The learning rate for the `Adam` optimizer. 163 | batch_size: int 164 | learning is performed on mini-batches of input data. this parameter 165 | specifies the size of each mini-batch. 166 | optimizer: str 167 | The choice of the gradient based optimization method. One of 168 | 'Adam', 'RMSProp' or 'SGD'. 
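
    Example
    -------
    >>> # Illustrative call; x, t and e are assumed to be numpy arrays of
    >>> # features, event/censoring times and event indicators.
    >>> model = DeepCoxPH(layers=[100])
    >>> model = model.fit(x, t, e, iters=50, batch_size=128)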
169 | 170 | """ 171 | 172 | processed_data = self._preprocess_training_data(x, t, e, 173 | vsize, val_data, 174 | self.random_seed) 175 | 176 | x_train, t_train, e_train, x_val, t_val, e_val = processed_data 177 | 178 | #Todo: Change this somehow. The base design shouldn't depend on child 179 | 180 | inputdim = x_train.shape[-1] 181 | 182 | model = self._gen_torch_model(inputdim, optimizer) 183 | 184 | model, _ = train_dcph(model, 185 | (x_train, t_train, e_train), 186 | (x_val, t_val, e_val), 187 | epochs=iters, 188 | lr=learning_rate, 189 | bs=batch_size, 190 | return_losses=True, 191 | random_seed=self.random_seed) 192 | 193 | self.torch_model = (model[0].eval(), model[1]) 194 | self.fitted = True 195 | 196 | return self 197 | 198 | def predict_risk(self, x, t=None): 199 | 200 | if self.fitted: 201 | return 1-self.predict_survival(x, t) 202 | else: 203 | raise Exception("The model has not been fitted yet. Please fit the " + 204 | "model using the `fit` method on some training data " + 205 | "before calling `predict_risk`.") 206 | 207 | def predict_survival(self, x, t=None): 208 | r"""Returns the estimated survival probability at time \( t \), 209 | \( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \). 210 | 211 | Parameters 212 | ---------- 213 | x: np.ndarray 214 | A numpy array of the input features, \( x \). 215 | t: list or float 216 | a list or float of the times at which survival probability is 217 | to be computed 218 | Returns: 219 | np.array: numpy array of the survival probabilites at each time in t. 220 | 221 | """ 222 | if not self.fitted: 223 | raise Exception("The model has not been fitted yet. Please fit the " + 224 | "model using the `fit` method on some training data " + 225 | "before calling `predict_survival`.") 226 | 227 | x = self._preprocess_test_data(x) 228 | 229 | if t is not None: 230 | if not isinstance(t, list): 231 | t = [t] 232 | 233 | scores = predict_survival(self.torch_model, x, t) 234 | return scores 235 | 236 | 237 | class DeepRecurrentCoxPH(DeepCoxPH): 238 | r"""A deep recurrent Cox PH model. 239 | 240 | This model is based on the paper: 241 | Leveraging 242 | Deep Representations of Radiology Reports in Survival Analysis for 243 | Predicting Heart Failure Patient Mortality. NAACL (2021) 244 | 245 | Parameters 246 | ---------- 247 | k: int 248 | The number of underlying Cox distributions. 249 | layers: list 250 | A list of integers consisting of the number of neurons in each 251 | hidden layer. 252 | random_seed: int 253 | Controls the reproducibility of called functions. 
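  hidden: int
      The number of neurons in the hidden state of the recurrent embedding.
  typ: str
      The type of recurrent cell to use. One of 'LSTM', 'RNN' or 'GRU'.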
254 | Example 255 | ------- 256 | >>> from dsm.contrib import DeepRecurrentCoxPH 257 | >>> model = DeepRecurrentCoxPH() 258 | >>> model.fit(x, t, e) 259 | 260 | """ 261 | 262 | def __init__(self, layers=None, hidden=None, typ="LSTM", random_seed=0): 263 | 264 | super(DeepRecurrentCoxPH, self).__init__(layers=layers) 265 | 266 | self.typ = typ 267 | self.hidden = hidden 268 | self.random_seed = random_seed 269 | 270 | def __call__(self): 271 | if self.fitted: 272 | print("A fitted instance of the Recurrent Deep Cox PH model") 273 | else: 274 | print("An unfitted instance of the Recurrent Deep Cox PH model") 275 | 276 | print("Hidden Layers:", self.layers) 277 | 278 | def _gen_torch_model(self, inputdim, optimizer): 279 | """Helper function to return a torch model.""" 280 | 281 | np.random.seed(self.random_seed) 282 | torch.manual_seed(self.random_seed) 283 | 284 | return DeepRecurrentCoxPHTorch(inputdim, layers=self.layers, 285 | hidden=self.hidden, 286 | optimizer=optimizer, typ=self.typ) 287 | 288 | def _preprocess_test_data(self, x): 289 | if isinstance(x, pd.DataFrame): 290 | x = x.values 291 | return torch.from_numpy(_get_padded_features(x)).float() 292 | 293 | def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed): 294 | """RNNs require different preprocessing for variable length sequences""" 295 | 296 | x = _dataframe_to_array(x) 297 | t = _dataframe_to_array(t) 298 | e = _dataframe_to_array(e) 299 | 300 | idx = list(range(x.shape[0])) 301 | np.random.seed(random_seed) 302 | np.random.shuffle(idx) 303 | 304 | x = _get_padded_features(x) 305 | t = _get_padded_targets(t) 306 | e = _get_padded_targets(e) 307 | 308 | x_train, t_train, e_train = x[idx], t[idx], e[idx] 309 | 310 | x_train = torch.from_numpy(x_train).float() 311 | t_train = torch.from_numpy(t_train).float() 312 | e_train = torch.from_numpy(e_train).float() 313 | 314 | if val_data is None: 315 | 316 | vsize = int(vsize*x_train.shape[0]) 317 | 318 | x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] 319 | 320 | x_train = x_train[:-vsize] 321 | t_train = t_train[:-vsize] 322 | e_train = e_train[:-vsize] 323 | 324 | else: 325 | 326 | x_val, t_val, e_val = val_data 327 | 328 | x_val = _dataframe_to_array(x_val) 329 | t_val = _dataframe_to_array(t_val) 330 | e_val = _dataframe_to_array(e_val) 331 | 332 | x_val = _get_padded_features(x_val) 333 | t_val = _get_padded_features(t_val) 334 | e_val = _get_padded_features(e_val) 335 | 336 | x_val = torch.from_numpy(x_val).float() 337 | t_val = torch.from_numpy(t_val).float() 338 | e_val = torch.from_numpy(e_val).float() 339 | 340 | return (x_train, t_train, e_train, x_val, t_val, e_val) 341 | -------------------------------------------------------------------------------- /auton_survival/models/cph/dcph_torch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from auton_survival.models.dsm.dsm_torch import create_representation 5 | 6 | 7 | class DeepCoxPHTorch(nn.Module): 8 | 9 | def _init_coxph_layers(self, lastdim): 10 | self.expert = nn.Linear(lastdim, 1, bias=False) 11 | 12 | def __init__(self, inputdim, layers=None, optimizer='Adam'): 13 | 14 | super(DeepCoxPHTorch, self).__init__() 15 | 16 | self.optimizer = optimizer 17 | 18 | if layers is None: layers = [] 19 | self.layers = layers 20 | 21 | if len(layers) == 0: lastdim = inputdim 22 | else: lastdim = layers[-1] 23 | 24 | self._init_coxph_layers(lastdim) 25 | self.embedding = 
create_representation(inputdim, layers, 'ReLU6') 26 | 27 | def forward(self, x): 28 | 29 | return self.expert(self.embedding(x)) 30 | 31 | class DeepRecurrentCoxPHTorch(DeepCoxPHTorch): 32 | 33 | def __init__(self, inputdim, typ='LSTM', layers=1, 34 | hidden=None, optimizer='Adam'): 35 | 36 | super(DeepCoxPHTorch, self).__init__() 37 | 38 | self.typ = typ 39 | self.layers = layers 40 | self.hidden = hidden 41 | self.optimizer = optimizer 42 | 43 | self._init_coxph_layers(hidden) 44 | 45 | if self.typ == 'LSTM': 46 | self.embedding = nn.LSTM(inputdim, hidden, layers, 47 | bias=False, batch_first=True) 48 | if self.typ == 'RNN': 49 | self.embedding = nn.RNN(inputdim, hidden, layers, 50 | bias=False, batch_first=True, 51 | nonlinearity='relu') 52 | if self.typ == 'GRU': 53 | self.embedding = nn.GRU(inputdim, hidden, layers, 54 | bias=False, batch_first=True) 55 | 56 | def forward(self, x): 57 | 58 | x = x.detach().clone() 59 | inputmask = ~torch.isnan(x[:, :, 0]).reshape(-1) 60 | x[torch.isnan(x)] = 0 61 | 62 | xrep, _ = self.embedding(x) 63 | xrep = xrep.contiguous().view(-1, self.hidden) 64 | xrep = xrep[inputmask] 65 | xrep = nn.ReLU6()(xrep) 66 | 67 | dim = xrep.shape[0] 68 | 69 | return self.expert(xrep.view(dim, -1)) 70 | -------------------------------------------------------------------------------- /auton_survival/models/cph/dcph_utilities.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from sksurv.linear_model.coxph import BreslowEstimator 6 | 7 | from sklearn.utils import shuffle 8 | 9 | from tqdm import tqdm 10 | 11 | from auton_survival.models.dsm.utilities import get_optimizer, _reshape_tensor_with_nans 12 | 13 | from copy import deepcopy 14 | 15 | def randargmax(b,**kw): 16 | """ a random tie-breaking argmax""" 17 | return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw) 18 | 19 | def partial_ll_loss(lrisks, tb, eb, eps=1e-3): 20 | 21 | tb = tb + eps*np.random.random(len(tb)) 22 | sindex = np.argsort(-tb) 23 | 24 | tb = tb[sindex] 25 | eb = eb[sindex] 26 | 27 | lrisks = lrisks[sindex] 28 | lrisksdenom = torch.logcumsumexp(lrisks, dim = 0) 29 | 30 | plls = lrisks - lrisksdenom 31 | pll = plls[eb == 1] 32 | 33 | pll = torch.sum(pll) 34 | 35 | return -pll 36 | 37 | def fit_breslow(model, x, t, e): 38 | return BreslowEstimator().fit(model(x).detach().cpu().numpy(), 39 | e.numpy(), t.numpy()) 40 | 41 | def train_step(model, x, t, e, optimizer, bs=256, seed=100): 42 | 43 | x, t, e = shuffle(x, t, e, random_state=seed) 44 | 45 | n = x.shape[0] 46 | 47 | batches = (n // bs) + 1 48 | 49 | epoch_loss = 0 50 | 51 | for i in range(batches): 52 | 53 | xb = x[i*bs:(i+1)*bs] 54 | tb = t[i*bs:(i+1)*bs] 55 | eb = e[i*bs:(i+1)*bs] 56 | 57 | # Training Step 58 | torch.enable_grad() 59 | optimizer.zero_grad() 60 | loss = partial_ll_loss(model(xb), 61 | _reshape_tensor_with_nans(tb), 62 | _reshape_tensor_with_nans(eb)) 63 | loss.backward() 64 | optimizer.step() 65 | 66 | epoch_loss += float(loss) 67 | 68 | return epoch_loss/n 69 | 70 | def test_step(model, x, t, e): 71 | 72 | with torch.no_grad(): 73 | loss = float(partial_ll_loss(model(x), t, e)) 74 | 75 | return loss/x.shape[0] 76 | 77 | 78 | def train_dcph(model, train_data, val_data, epochs=50, 79 | patience=3, bs=256, lr=1e-3, debug=False, 80 | random_seed=0, return_losses=False): 81 | 82 | torch.manual_seed(random_seed) 83 | np.random.seed(random_seed) 84 | 85 | if val_data is None: 86 | val_data = train_data 87 | 88 | xt, tt, 
et = train_data 89 | xv, tv, ev = val_data 90 | 91 | tt_ = _reshape_tensor_with_nans(tt) 92 | et_ = _reshape_tensor_with_nans(et) 93 | tv_ = _reshape_tensor_with_nans(tv) 94 | ev_ = _reshape_tensor_with_nans(ev) 95 | 96 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 97 | optimizer = get_optimizer(model, lr) 98 | 99 | valc = np.inf 100 | patience_ = 0 101 | 102 | breslow_spline = None 103 | 104 | losses = [] 105 | dics = [] 106 | 107 | for epoch in tqdm(range(epochs)): 108 | 109 | # train_step_start = time.time() 110 | _ = train_step(model, xt, tt, et, optimizer, bs, seed=epoch) 111 | # print(f'Duration of train-step: {time.time() - train_step_start}') 112 | # test_step_start = time.time() 113 | valcn = test_step(model, xv, tv_, ev_) 114 | # print(f'Duration of test-step: {time.time() - test_step_start}') 115 | 116 | losses.append(float(valcn)) 117 | 118 | dics.append(deepcopy(model.state_dict())) 119 | 120 | if epoch % 1 == 0: 121 | if debug: print(patience_, epoch, valcn) 122 | 123 | if valcn > valc: 124 | patience_ += 1 125 | else: 126 | patience_ = 0 127 | 128 | if patience_ == patience: 129 | 130 | minm = np.argmin(losses) 131 | model.load_state_dict(dics[minm]) 132 | 133 | breslow_spline = fit_breslow(model, xt, tt_, et_) 134 | 135 | if return_losses: 136 | return (model, breslow_spline), losses 137 | else: 138 | return (model, breslow_spline) 139 | 140 | valc = valcn 141 | 142 | minm = np.argmin(losses) 143 | model.load_state_dict(dics[minm]) 144 | 145 | breslow_spline = fit_breslow(model, xt, tt_, et_) 146 | 147 | if return_losses: 148 | return (model, breslow_spline), losses 149 | else: 150 | return (model, breslow_spline) 151 | 152 | def predict_survival(model, x, t=None): 153 | 154 | if isinstance(t, (int, float)): t = [t] 155 | 156 | model, breslow_spline = model 157 | lrisks = model(x).detach().cpu().numpy() 158 | 159 | unique_times = breslow_spline.baseline_survival_.x 160 | 161 | raw_predictions = breslow_spline.get_survival_function(lrisks) 162 | raw_predictions = np.array([pred.y for pred in raw_predictions]) 163 | 164 | predictions = pd.DataFrame(data=raw_predictions, columns=unique_times) 165 | 166 | if t is None: 167 | return predictions 168 | else: 169 | return __interpolate_missing_times(predictions.T, t) 170 | #return np.array(predictions).T 171 | 172 | def __interpolate_missing_times(survival_predictions, times): 173 | 174 | nans = np.full(survival_predictions.shape[1], np.nan) 175 | not_in_index = list(set(times) - set(survival_predictions.index)) 176 | 177 | for idx in not_in_index: 178 | survival_predictions.loc[idx] = nans 179 | return survival_predictions.sort_index(axis=0).interpolate(method='bfill').T[times].values 180 | -------------------------------------------------------------------------------- /auton_survival/models/dcm/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # MIT License 3 | 4 | # Copyright (c) 2020 Carnegie Mellon University, Auton Lab 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission 
notice shall be included in all
14 | # copies or substantial portions of the Software.
15 | 
16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | # SOFTWARE.
23 | 
24 | r"""
25 | 
26 | Deep Cox Mixtures
27 | ------------------
28 | 
29 | The Cox Mixture assumes the survival function
30 | of the individual to be a mixture of K Cox models. Conditioned on each
31 | subgroup \( Z=k \), the PH assumptions are assumed to hold and the baseline
32 | hazard rate is determined non-parametrically using a spline-interpolated
33 | Breslow's estimator.
34 | 
35 | For full details on Deep Cox Mixture, refer to the paper [1].
36 | 
37 | References
38 | ----------
39 | [1] Deep Cox Mixtures
40 | for Survival Regression. Machine Learning in Health Conference (2021)
41 | 
42 | ```
43 | @article{nagpal2021dcm,
44 | title={Deep Cox mixtures for survival regression},
45 | author={Nagpal, Chirag and Yadlowsky, Steve and Rostamzadeh, Negar and Heller, Katherine},
46 | journal={arXiv preprint arXiv:2101.06536},
47 | year={2021}
48 | }
49 | ```
50 | 
51 | """
52 | 
53 | import torch
54 | import numpy as np
55 | 
56 | from .dcm_torch import DeepCoxMixturesTorch
57 | from .dcm_utilities import train_dcm, predict_survival, predict_latent_z
58 | 
59 | from auton_survival.utils import _dataframe_to_array
60 | 
61 | 
62 | class DeepCoxMixtures:
63 |   """A Deep Cox Mixture model.
64 | 
65 |   This is the main interface to a Deep Cox Mixture model.
66 |   A model is instantiated with an appropriate set of hyperparameters and
67 |   fit on numpy arrays consisting of the features, event/censoring times
68 |   and the event/censoring indicators.
69 | 
70 |   For full details on Deep Cox Mixture, refer to the paper [1].
71 | 
72 |   References
73 |   ----------
74 |   [1] Deep Cox Mixtures
75 |   for Survival Regression. Machine Learning in Health Conference (2021)
76 | 
77 |   Parameters
78 |   ----------
79 |   k: int
80 |       The number of underlying Cox distributions.
81 |   layers: list
82 |       A list of integers consisting of the number of neurons in each
83 |       hidden layer.
84 |   random_seed: int
85 |       Controls the reproducibility of called functions.
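  gamma: float
      Bound on the magnitude of the log hazard ratios produced by the
      experts; outputs are clamped (or Tanh-scaled when
      `use_activation=True`) to lie in [-gamma, gamma].
  smoothing_factor: float
      Smoothing applied when fitting splines to the Breslow baseline
      survival estimates.
  use_activation: bool
      Whether to squash the expert outputs with a scaled Tanh activation
      instead of clamping them.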
86 | 87 | Example 88 | ------- 89 | >>> from auton_survival.models.dcm import DeepCoxMixtures 90 | >>> model = DeepCoxMixtures() 91 | >>> model.fit(x, t, e) 92 | 93 | """ 94 | 95 | def __init__(self, k=3, layers=None, gamma=10, 96 | smoothing_factor=1e-4, use_activation=False, 97 | random_seed=0): 98 | 99 | self.k = k 100 | self.layers = layers 101 | self.fitted = False 102 | self.gamma = gamma 103 | self.smoothing_factor = smoothing_factor 104 | self.use_activation = use_activation 105 | self.random_seed = random_seed 106 | 107 | def __call__(self): 108 | if self.fitted: 109 | print("A fitted instance of the Deep Cox Mixtures model") 110 | else: 111 | print("An unfitted instance of the Deep Cox Mixtures model") 112 | 113 | print("Number of underlying cox distributions (k):", self.k) 114 | print("Hidden Layers:", self.layers) 115 | 116 | def _preprocess_test_data(self, x): 117 | x = _dataframe_to_array(x) 118 | return torch.from_numpy(x).float() 119 | 120 | def _preprocess_training_data(self, x, t, e, vsize, val_data, random_seed): 121 | 122 | x = _dataframe_to_array(x) 123 | t = _dataframe_to_array(t) 124 | e = _dataframe_to_array(e) 125 | 126 | idx = list(range(x.shape[0])) 127 | np.random.seed(random_seed) 128 | np.random.shuffle(idx) 129 | x_train, t_train, e_train = x[idx], t[idx], e[idx] 130 | 131 | x_train = torch.from_numpy(x_train).float() 132 | t_train = torch.from_numpy(t_train).float() 133 | e_train = torch.from_numpy(e_train).float() 134 | 135 | if val_data is None: 136 | 137 | vsize = int(vsize*x_train.shape[0]) 138 | x_val, t_val, e_val = x_train[-vsize:], t_train[-vsize:], e_train[-vsize:] 139 | 140 | x_train = x_train[:-vsize] 141 | t_train = t_train[:-vsize] 142 | e_train = e_train[:-vsize] 143 | 144 | else: 145 | 146 | x_val, t_val, e_val = val_data 147 | 148 | x_val = _dataframe_to_array(x_val) 149 | t_val = _dataframe_to_array(t_val) 150 | e_val = _dataframe_to_array(e_val) 151 | 152 | x_val = torch.from_numpy(x_val).float() 153 | t_val = torch.from_numpy(t_val).float() 154 | e_val = torch.from_numpy(e_val).float() 155 | 156 | return (x_train, t_train, e_train, x_val, t_val, e_val) 157 | 158 | def _gen_torch_model(self, inputdim, optimizer): 159 | """Helper function to return a torch model.""" 160 | 161 | np.random.seed(self.random_seed) 162 | torch.manual_seed(self.random_seed) 163 | 164 | return DeepCoxMixturesTorch(inputdim, 165 | k=self.k, 166 | gamma=self.gamma, 167 | use_activation=self.use_activation, 168 | layers=self.layers, 169 | optimizer=optimizer) 170 | 171 | def fit(self, x, t, e, vsize=0.15, val_data=None, 172 | iters=1, learning_rate=1e-3, batch_size=100, 173 | optimizer="Adam"): 174 | 175 | r"""This method is used to train an instance of the DSM model. 176 | 177 | Parameters 178 | ---------- 179 | x: np.ndarray 180 | A numpy array of the input features, \( x \). 181 | t: np.ndarray 182 | A numpy array of the event/censoring times, \( t \). 183 | e: np.ndarray 184 | A numpy array of the event/censoring indicators, \( \delta \). 185 | \( \delta = 1 \) means the event took place. 186 | vsize: float 187 | Amount of data to set aside as the validation set. 188 | val_data: tuple 189 | A tuple of the validation dataset. If passed vsize is ignored. 190 | iters: int 191 | The maximum number of training iterations on the training dataset. 192 | learning_rate: float 193 | The learning rate for the `Adam` optimizer. 194 | batch_size: int 195 | learning is performed on mini-batches of input data. this parameter 196 | specifies the size of each mini-batch. 
197 |     optimizer: str
198 |         The choice of the gradient based optimization method. One of
199 |         'Adam', 'RMSProp' or 'SGD'.
200 | 
201 |     """
202 | 
203 |     processed_data = self._preprocess_training_data(x, t, e,
204 |                                                     vsize, val_data,
205 |                                                     self.random_seed)
206 |     x_train, t_train, e_train, x_val, t_val, e_val = processed_data
207 | 
208 |     #Todo: Change this somehow. The base design shouldn't depend on child
209 | 
210 |     inputdim = x_train.shape[-1]
211 | 
212 |     model = self._gen_torch_model(inputdim, optimizer)
213 | 
214 |     model, _ = train_dcm(model,
215 |                          (x_train, t_train, e_train),
216 |                          (x_val, t_val, e_val),
217 |                          epochs=iters,
218 |                          lr=learning_rate,
219 |                          bs=batch_size,
220 |                          return_losses=True,
221 |                          smoothing_factor=self.smoothing_factor,
222 |                          use_posteriors=True,
223 |                          random_seed=self.random_seed)
224 | 
225 |     self.torch_model = (model[0].eval(), model[1])
226 |     self.fitted = True
227 | 
228 |     return self
229 | 
230 | 
231 |   def predict_survival(self, x, t):
232 |     r"""Returns the estimated survival probability at time \( t \),
233 |     \( \widehat{\mathbb{P}}(T > t|X) \) for some input data \( x \).
234 | 
235 |     Parameters
236 |     ----------
237 |     x: np.ndarray
238 |         A numpy array of the input features, \( x \).
239 |     t: list or float
240 |         a list or float of the times at which survival probability is
241 |         to be computed
242 |     Returns:
243 |       np.array: numpy array of the survival probabilities at each time in t.
244 | 
245 |     """
246 |     x = self._preprocess_test_data(x)
247 |     if not isinstance(t, list):
248 |       t = [t]
249 |     if self.fitted:
250 |       scores = predict_survival(self.torch_model, x, t)
251 |       return scores
252 |     else:
253 |       raise Exception("The model has not been fitted yet. Please fit the " +
254 |                       "model using the `fit` method on some training data " +
255 |                       "before calling `predict_survival`.")
256 | 
257 |   def predict_latent_z(self, x):
258 | 
259 |     x = self._preprocess_test_data(x)
260 | 
261 |     if self.fitted:
262 |       scores = predict_latent_z(self.torch_model, x)
263 |       return scores
264 |     else:
265 |       raise Exception("The model has not been fitted yet. Please fit the " +
266 |                       "model using the `fit` method on some training data " +
267 |                       "before calling `predict_latent_z`.")
268 | 
--------------------------------------------------------------------------------
/auton_survival/models/dcm/dcm_torch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | import numpy as np
5 | 
6 | from auton_survival.models.dsm.dsm_torch import create_representation
7 | 
8 | class DeepCoxMixturesTorch(nn.Module):
9 |   """PyTorch model definition of the Deep Cox Mixture Survival Model.
10 | 
11 |   The Cox Mixture assumes the survival function
12 |   of the individual to be a mixture of K Cox models. Conditioned on each
13 |   subgroup Z=k, the PH assumptions are assumed to hold and the baseline
14 |   hazard rate is determined non-parametrically using a spline-interpolated
15 |   Breslow's estimator.
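
  The forward pass returns a tuple `(log_gate_prob, log_hazard_ratios)`:
  the log of the mixing probabilities over the K experts, and the clamped
  (or Tanh-scaled) log hazard ratio of each expert.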
16 | """ 17 | 18 | def _init_dcm_layers(self, lastdim): 19 | 20 | self.gate = torch.nn.Linear(lastdim, self.k, bias=False) 21 | self.expert = torch.nn.Linear(lastdim, self.k, bias=False) 22 | 23 | def __init__(self, inputdim, k, gamma=1, use_activation=False, layers=None, optimizer='Adam'): 24 | 25 | super(DeepCoxMixturesTorch, self).__init__() 26 | 27 | if not isinstance(k, int): 28 | raise ValueError(f'k must be int, but supplied k is {type(k)}') 29 | 30 | self.k = k 31 | self.optimizer = optimizer 32 | 33 | if layers is None: layers = [] 34 | self.layers = layers 35 | 36 | if len(layers) == 0: lastdim = inputdim 37 | else: lastdim = layers[-1] 38 | 39 | self._init_dcm_layers(lastdim) 40 | self.embedding = create_representation(inputdim, layers, 'ReLU6') 41 | self.gamma = gamma 42 | self.use_activation = use_activation 43 | 44 | def forward(self, x): 45 | 46 | gamma = self.gamma 47 | 48 | x = self.embedding(x) 49 | if self.use_activation: 50 | log_hazard_ratios = gamma*torch.nn.Tanh()(self.expert(x)) 51 | else: 52 | log_hazard_ratios = torch.clamp(self.expert(x), min=-gamma, max=gamma) 53 | log_gate_prob = torch.nn.LogSoftmax(dim=1)(self.gate(x)) 54 | 55 | return log_gate_prob, log_hazard_ratios 56 | -------------------------------------------------------------------------------- /auton_survival/models/dcm/dcm_utilities.py: -------------------------------------------------------------------------------- 1 | 2 | import logging 3 | from matplotlib.pyplot import get 4 | 5 | import torch 6 | import numpy as np 7 | 8 | from scipy.interpolate import UnivariateSpline 9 | from sksurv.linear_model.coxph import BreslowEstimator 10 | 11 | from sklearn.utils import shuffle 12 | 13 | 14 | from tqdm import tqdm 15 | 16 | 17 | from auton_survival.models.dsm.utilities import get_optimizer 18 | 19 | def randargmax(b,**kw): 20 | """ a random tie-breaking argmax""" 21 | return np.argmax(np.random.random(b.shape) * (b==b.max()), **kw) 22 | 23 | def partial_ll_loss(lrisks, tb, eb, eps=1e-2): 24 | 25 | tb = tb + eps*np.random.random(len(tb)) 26 | sindex = np.argsort(-tb) 27 | 28 | tb = tb[sindex] 29 | eb = eb[sindex] 30 | 31 | lrisks = lrisks[sindex] # lrisks = tf.gather(lrisks, sindex) 32 | # lrisksdenom = tf.math.cumulative_logsumexp(lrisks) 33 | lrisksdenom = torch.logcumsumexp(lrisks, dim = 0) 34 | 35 | plls = lrisks - lrisksdenom 36 | pll = plls[eb == 1] 37 | 38 | pll = torch.sum(pll) # pll = tf.reduce_sum(pll) 39 | 40 | return -pll 41 | 42 | def fit_spline(t, surv, s=1e-4): 43 | return UnivariateSpline(t, surv, s=s, ext=3, k=1) 44 | 45 | def smooth_bl_survival(breslow, smoothing_factor): 46 | 47 | blsurvival = breslow.baseline_survival_ 48 | x, y = blsurvival.x, blsurvival.y 49 | return fit_spline(x, y, s=smoothing_factor) 50 | 51 | def get_probability_(lrisks, ts, spl): 52 | risks = np.exp(lrisks) 53 | s0ts = (-risks)*(spl(ts)**(risks-1)) 54 | return s0ts * spl.derivative()(ts) 55 | 56 | def get_survival_(lrisks, ts, spl): 57 | risks = np.exp(lrisks) 58 | return spl(ts)**risks 59 | 60 | def get_probability(lrisks, breslow_splines, t): 61 | psurv = [] 62 | for i in range(lrisks.shape[1]): 63 | p = get_probability_(lrisks[:, i], t, breslow_splines[i]) 64 | psurv.append(p) 65 | psurv = np.array(psurv).T 66 | return psurv 67 | 68 | def get_survival(lrisks, breslow_splines, t): 69 | psurv = [] 70 | for i in range(lrisks.shape[1]): 71 | p = get_survival_(lrisks[:, i], t, breslow_splines[i]) 72 | psurv.append(p) 73 | psurv = np.array(psurv).T 74 | return psurv 75 | 76 | def get_posteriors(probs): 77 | #probs_ 
= probs+1e-8 78 | return probs-torch.logsumexp(probs, dim=1).reshape(-1,1) 79 | 80 | def get_hard_z(gates_prob): 81 | return torch.argmax(gates_prob, dim=1) 82 | 83 | def sample_hard_z(gates_prob): 84 | return torch.multinomial(gates_prob.exp(), num_samples=1)[:, 0] 85 | 86 | def repair_probs(probs): 87 | probs[torch.isnan(probs)] = -10 88 | probs[probs<-10] = -10 89 | return probs 90 | 91 | def get_likelihood(model, breslow_splines, x, t, e): 92 | 93 | # Function requires numpy/torch 94 | 95 | gates, lrisks = model(x) 96 | lrisks = lrisks.numpy() 97 | e, t = e.numpy(), t.numpy() 98 | 99 | survivals = get_survival(lrisks, breslow_splines, t) 100 | probability = get_probability(lrisks, breslow_splines, t) 101 | 102 | event_probs = np.array([survivals, probability]) 103 | event_probs = event_probs[e.astype('int'), range(len(e)), :] 104 | #event_probs[event_probs<1e-10] = 1e-10 105 | probs = gates+np.log(event_probs) 106 | # else: 107 | # gates_prob = torch.nn.Softmax(dim = 1)(gates) 108 | # probs = gates_prob*event_probs 109 | return probs 110 | 111 | def q_function(model, x, t, e, posteriors, typ='soft'): 112 | 113 | if typ == 'hard': z = get_hard_z(posteriors) 114 | else: z = sample_hard_z(posteriors) 115 | 116 | gates, lrisks = model(x) 117 | 118 | k = model.k 119 | 120 | loss = 0 121 | for i in range(k): 122 | lrisks_ = lrisks[z == i][:, i] 123 | loss += partial_ll_loss(lrisks_, t[z == i], e[z == i]) 124 | 125 | #log_smax_loss = -torch.nn.LogSoftmax(dim=1)(gates) # tf.nn.log_softmax(gates) 126 | 127 | gate_loss = posteriors.exp()*gates 128 | gate_loss = -torch.sum(gate_loss) 129 | loss+=gate_loss 130 | 131 | return loss 132 | 133 | def e_step(model, breslow_splines, x, t, e): 134 | 135 | # TODO: Do this in `Log Space` 136 | if breslow_splines is None: 137 | # If Breslow splines are not available, like in the first 138 | # iteration of learning, we randomly compute posteriors. 139 | posteriors = get_posteriors(torch.rand(len(x), model.k)) 140 | pass 141 | else: 142 | probs = get_likelihood(model, breslow_splines, x, t, e) 143 | posteriors = get_posteriors(repair_probs(probs)) 144 | 145 | return posteriors 146 | 147 | def m_step(model, optimizer, x, t, e, posteriors, typ='soft'): 148 | 149 | optimizer.zero_grad() 150 | loss = q_function(model, x, t, e, posteriors, typ) 151 | loss.backward() 152 | optimizer.step() 153 | 154 | return float(loss) 155 | 156 | def fit_breslow(model, x, t, e, posteriors=None, 157 | smoothing_factor=1e-4, typ='soft'): 158 | 159 | # TODO: Make Breslow in Torch !!! 
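  # One Breslow baseline-survival estimator is fit per mixture component,
  # using only the samples currently assigned to that component (hard
  # argmax or sampled assignments z), and each baseline is then smoothed
  # with a univariate spline so it can be evaluated at arbitrary times.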
160 | 161 | gates, lrisks = model(x) 162 | 163 | lrisks = lrisks.numpy() 164 | 165 | e = e.numpy() 166 | t = t.numpy() 167 | 168 | if posteriors is None: z_probs = gates 169 | else: z_probs = posteriors 170 | 171 | if typ == 'soft': z = sample_hard_z(z_probs) 172 | else: z = get_hard_z(z_probs) 173 | 174 | breslow_splines = {} 175 | for i in range(model.k): 176 | breslowk = BreslowEstimator().fit(lrisks[:, i][z==i], e[z==i], t[z==i]) 177 | breslow_splines[i] = smooth_bl_survival(breslowk, 178 | smoothing_factor=smoothing_factor) 179 | 180 | return breslow_splines 181 | 182 | 183 | def train_step(model, x, t, e, breslow_splines, optimizer, 184 | bs=256, seed=100, typ='soft', use_posteriors=False, 185 | update_splines_after=10, smoothing_factor=1e-4): 186 | 187 | x, t, e = shuffle(x, t, e, random_state=seed) 188 | 189 | n = x.shape[0] 190 | 191 | batches = (n // bs) + 1 192 | 193 | epoch_loss = 0 194 | for i in range(batches): 195 | 196 | xb = x[i*bs:(i+1)*bs] 197 | tb = t[i*bs:(i+1)*bs] 198 | eb = e[i*bs:(i+1)*bs] 199 | #ab = a[i*bs:(i+1)*bs] 200 | 201 | # E-Step !!! 202 | # e_step_start = time.time() 203 | with torch.no_grad(): 204 | posteriors = e_step(model, breslow_splines, xb, tb, eb) 205 | 206 | torch.enable_grad() 207 | loss = m_step(model, optimizer, xb, tb, eb, posteriors, typ=typ) 208 | 209 | with torch.no_grad(): 210 | try: 211 | if i%update_splines_after == 0: 212 | if use_posteriors: 213 | 214 | posteriors = e_step(model, breslow_splines, x, t, e) 215 | breslow_splines = fit_breslow(model, x, t, e, 216 | posteriors=posteriors, 217 | typ='soft', 218 | smoothing_factor=smoothing_factor) 219 | else: 220 | breslow_splines = fit_breslow(model, x, t, e, 221 | posteriors=None, 222 | typ='soft', 223 | smoothing_factor=smoothing_factor) 224 | # print(f'Duration of Breslow spline estimation: {time.time() - estimate_breslow_start}') 225 | except Exception as exce: 226 | print("Exception!!!:", exce) 227 | logging.warning("Couldn't fit splines, reusing from previous epoch") 228 | epoch_loss += loss 229 | #print (epoch_loss/n) 230 | return breslow_splines 231 | 232 | 233 | def test_step(model, x, t, e, breslow_splines, loss='q', typ='soft'): 234 | 235 | if loss == 'q': 236 | with torch.no_grad(): 237 | posteriors = e_step(model, breslow_splines, x, t, e) 238 | loss = q_function(model, x, t, e, posteriors, typ=typ) 239 | 240 | return float(loss/x.shape[0]) 241 | 242 | 243 | def train_dcm(model, train_data, val_data, epochs=50, 244 | patience=3, vloss='q', bs=256, typ='soft', lr=1e-3, 245 | use_posteriors=True, debug=False, random_seed=0, 246 | return_losses=False, update_splines_after=10, 247 | smoothing_factor=1e-2): 248 | 249 | torch.manual_seed(random_seed) 250 | np.random.seed(random_seed) 251 | 252 | if val_data is None: 253 | val_data = train_data 254 | 255 | xt, tt, et = train_data 256 | xv, tv, ev = val_data 257 | 258 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 259 | optimizer = get_optimizer(model, lr) 260 | 261 | valc = np.inf 262 | patience_ = 0 263 | 264 | breslow_splines = None 265 | 266 | losses = [] 267 | 268 | for epoch in tqdm(range(epochs)): 269 | 270 | # train_step_start = time.time() 271 | breslow_splines = train_step(model, xt, tt, et, breslow_splines, 272 | optimizer, bs=bs, seed=epoch, typ=typ, 273 | use_posteriors=use_posteriors, 274 | update_splines_after=update_splines_after, 275 | smoothing_factor=smoothing_factor) 276 | # print(f'Duration of train-step: {time.time() - train_step_start}') 277 | # test_step_start = time.time() 278 | valcn = 
test_step(model, xv, tv, ev, breslow_splines, loss=vloss, typ=typ) 279 | # print(f'Duration of test-step: {time.time() - test_step_start}') 280 | 281 | losses.append(valcn) 282 | 283 | if epoch % 1 == 0: 284 | if debug: print(patience_, epoch, valcn) 285 | 286 | if valcn > valc: patience_ += 1 287 | else: patience_ = 0 288 | 289 | if patience_ == patience: 290 | if return_losses: return (model, breslow_splines), losses 291 | else: return (model, breslow_splines) 292 | 293 | valc = valcn 294 | 295 | if return_losses: return (model, breslow_splines), losses 296 | else: return (model, breslow_splines) 297 | 298 | 299 | def predict_survival(model, x, t): 300 | 301 | if isinstance(t, int) or isinstance(t, float): t = [t] 302 | 303 | model, breslow_splines = model 304 | gates, lrisks = model(x) 305 | 306 | lrisks = lrisks.detach().numpy() 307 | gate_probs = torch.exp(gates).detach().numpy() 308 | 309 | predictions = [] 310 | 311 | for t_ in t: 312 | expert_output = get_survival(lrisks, breslow_splines, t_) 313 | predictions.append((gate_probs*expert_output).sum(axis=1)) 314 | 315 | return np.array(predictions).T 316 | 317 | def predict_latent_z(model, x): 318 | 319 | model, _ = model 320 | gates, _ = model(x) 321 | 322 | gate_probs = torch.exp(gates).detach().numpy() 323 | 324 | return gate_probs 325 | -------------------------------------------------------------------------------- /auton_survival/models/dsm/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # MIT License 3 | 4 | # Copyright (c) 2020 Carnegie Mellon University, Auton Lab 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | 24 | 25 | """Utility functions to load standard datasets to train and evaluate the 26 | Deep Survival Machines models. 
27 | """ 28 | 29 | 30 | import io 31 | import pkgutil 32 | 33 | import pandas as pd 34 | import numpy as np 35 | 36 | from sklearn.impute import SimpleImputer 37 | from sklearn.preprocessing import StandardScaler 38 | 39 | import torchvision 40 | 41 | def increase_censoring(e, t, p): 42 | 43 | uncens = np.where(e == 1)[0] 44 | mask = np.random.choice([False, True], len(uncens), p=[1-p, p]) 45 | toswitch = uncens[mask] 46 | 47 | e[toswitch] = 0 48 | t_ = t[toswitch] 49 | 50 | newt = [] 51 | for t__ in t_: 52 | newt.append(np.random.uniform(1, t__)) 53 | t[toswitch] = newt 54 | 55 | return e, t 56 | 57 | def _load_framingham_dataset(sequential): 58 | """Helper function to load and preprocess the Framingham dataset. 59 | 60 | The Framingham Dataset is a subset of 4,434 participants of the well known, 61 | ongoing Framingham Heart study [1] for studying epidemiology for 62 | hypertensive and arteriosclerotic cardiovascular disease. It is a popular 63 | dataset for longitudinal survival analysis with time dependent covariates. 64 | 65 | Parameters 66 | ---------- 67 | sequential: bool 68 | If True returns a list of np.arrays for each individual. 69 | else, returns collapsed results for each time step. To train 70 | recurrent neural models you would typically use True. 71 | 72 | References 73 | ---------- 74 | [1] Dawber, Thomas R., Gilcin F. Meadors, and Felix E. Moore Jr. 75 | "Epidemiological approaches to heart disease: the Framingham Study." 76 | American Journal of Public Health and the Nations Health 41.3 (1951). 77 | 78 | """ 79 | 80 | data = pkgutil.get_data(__name__, 'datasets/framingham.csv') 81 | data = pd.read_csv(io.BytesIO(data)) 82 | 83 | dat_cat = data[['SEX', 'CURSMOKE', 'DIABETES', 'BPMEDS', 84 | 'educ', 'PREVCHD', 'PREVAP', 'PREVMI', 85 | 'PREVSTRK', 'PREVHYP']] 86 | dat_num = data[['TOTCHOL', 'AGE', 'SYSBP', 'DIABP', 87 | 'CIGPDAY', 'BMI', 'HEARTRTE', 'GLUCOSE']] 88 | 89 | x1 = pd.get_dummies(dat_cat).values 90 | x2 = dat_num.values 91 | x = np.hstack([x1, x2]) 92 | 93 | time = (data['TIMEDTH'] - data['TIME']).values 94 | event = data['DEATH'].values 95 | 96 | x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x) 97 | x_ = StandardScaler().fit_transform(x) 98 | 99 | if not sequential: 100 | return x_, time, event 101 | else: 102 | x, t, e = [], [], [] 103 | for id_ in sorted(list(set(data['RANDID']))): 104 | x.append(x_[data['RANDID'] == id_]) 105 | t.append(time[data['RANDID'] == id_]) 106 | e.append(event[data['RANDID'] == id_]) 107 | return x, t, e 108 | 109 | def _load_pbc_dataset(sequential): 110 | """Helper function to load and preprocess the PBC dataset 111 | 112 | The Primary biliary cirrhosis (PBC) Dataset [1] is well known 113 | dataset for evaluating survival analysis models with time 114 | dependent covariates. 115 | 116 | Parameters 117 | ---------- 118 | sequential: bool 119 | If True returns a list of np.arrays for each individual. 120 | else, returns collapsed results for each time step. To train 121 | recurrent neural models you would typically use True. 122 | 123 | References 124 | ---------- 125 | [1] Fleming, Thomas R., and David P. Harrington. Counting processes and 126 | survival analysis. Vol. 169. John Wiley & Sons, 2011. 
127 | 128 | """ 129 | 130 | data = pkgutil.get_data(__name__, 'datasets/pbc2.csv') 131 | data = pd.read_csv(io.BytesIO(data)) 132 | 133 | data['histologic'] = data['histologic'].astype(str) 134 | dat_cat = data[['drug', 'sex', 'ascites', 'hepatomegaly', 135 | 'spiders', 'edema', 'histologic']] 136 | dat_num = data[['serBilir', 'serChol', 'albumin', 'alkaline', 137 | 'SGOT', 'platelets', 'prothrombin']] 138 | age = data['age'] + data['years'] 139 | 140 | x1 = pd.get_dummies(dat_cat).values 141 | x2 = dat_num.values 142 | x3 = age.values.reshape(-1, 1) 143 | x = np.hstack([x1, x2, x3]) 144 | 145 | time = (data['years'] - data['year']).values 146 | event = data['status2'].values 147 | 148 | x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x) 149 | x_ = StandardScaler().fit_transform(x) 150 | 151 | if not sequential: 152 | return x_, time, event 153 | else: 154 | x, t, e = [], [], [] 155 | for id_ in sorted(list(set(data['id']))): 156 | x.append(x_[data['id'] == id_]) 157 | t.append(time[data['id'] == id_]) 158 | e.append(event[data['id'] == id_]) 159 | return x, t, e 160 | 161 | def _load_support_dataset(): 162 | """Helper function to load and preprocess the SUPPORT dataset. 163 | 164 | The SUPPORT Dataset comes from the Vanderbilt University study 165 | to estimate survival for seriously ill hospitalized adults [1]. 166 | 167 | Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc. 168 | for the original datasource. 169 | 170 | References 171 | ---------- 172 | [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic 173 | model: Objective estimates of survival for seriously ill hospitalized 174 | adults. Annals of Internal Medicine 122:191-203. 175 | 176 | """ 177 | 178 | data = pkgutil.get_data(__name__, 'datasets/support2.csv') 179 | data = pd.read_csv(io.BytesIO(data)) 180 | x1 = data[['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', 'temp', 181 | 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', 'glucose', 'bun', 182 | 'urine', 'adlp', 'adls']] 183 | 184 | catfeats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca'] 185 | x2 = pd.get_dummies(data[catfeats]) 186 | 187 | x = np.concatenate([x1, x2], axis=1) 188 | t = data['d.time'].values 189 | e = data['death'].values 190 | 191 | x = SimpleImputer(missing_values=np.nan, strategy='mean').fit_transform(x) 192 | x = StandardScaler().fit_transform(x) 193 | 194 | remove = ~np.isnan(t) 195 | return x[remove], t[remove], e[remove] 196 | 197 | def _load_mnist(): 198 | """Helper function to load and preprocess the MNIST dataset. 199 | 200 | The MNIST database of handwritten digits, available from this page, has a 201 | training set of 60,000 examples, and a test set of 10,000 examples. 202 | It is a good database for people who want to try learning techniques and 203 | pattern recognition methods on real-world data while spending minimal 204 | efforts on preprocessing and formatting [1]. 205 | 206 | Please refer to http://yann.lecun.com/exdb/mnist/. 207 | for the original datasource. 208 | 209 | References 210 | ---------- 211 | [1]: LeCun, Y. (1998). The MNIST database of handwritten digits. 212 | http://yann.lecun.com/exdb/mnist/. 
213 | 214 | """ 215 | 216 | 217 | train = torchvision.datasets.MNIST(root='datasets/', 218 | train=True, download=True) 219 | x = train.data.numpy() 220 | x = np.expand_dims(x, 1).astype(float) 221 | t = train.targets.numpy().astype(float) + 1 222 | 223 | e, t = increase_censoring(np.ones(t.shape), t, p=.5) 224 | 225 | return x, t, e 226 | 227 | def load_dataset(dataset='SUPPORT', **kwargs): 228 | """Helper function to load datasets to test Survival Analysis models. 229 | 230 | Currently implemented datasets include: 231 | 232 | **SUPPORT**: This dataset comes from the Vanderbilt University study 233 | to estimate survival for seriously ill hospitalized adults [1]. 234 | (Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc. 235 | for the original datasource.) 236 | 237 | **PBC**: The Primary biliary cirrhosis dataset [2] is well known 238 | dataset for evaluating survival analysis models with time 239 | dependent covariates. 240 | 241 | **FRAMINGHAM**: This dataset is a subset of 4,434 participants of the well 242 | known, ongoing Framingham Heart study [3] for studying epidemiology for 243 | hypertensive and arteriosclerotic cardiovascular disease. It is a popular 244 | dataset for longitudinal survival analysis with time dependent covariates. 245 | 246 | References 247 | ----------- 248 | 249 | [1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic 250 | model: Objective estimates of survival for seriously ill hospitalized 251 | adults. Annals of Internal Medicine 122:191-203. 252 | 253 | [2] Fleming, Thomas R., and David P. Harrington. Counting processes and 254 | survival analysis. Vol. 169. John Wiley & Sons, 2011. 255 | 256 | [3] Dawber, Thomas R., Gilcin F. Meadors, and Felix E. Moore Jr. 257 | "Epidemiological approaches to heart disease: the Framingham Study." 258 | American Journal of Public Health and the Nations Health 41.3 (1951). 259 | 260 | Parameters 261 | ---------- 262 | dataset: str 263 | The choice of dataset to load. Currently implemented is 'SUPPORT', 264 | 'PBC' and 'FRAMINGHAM'. 265 | **kwargs: dict 266 | Dataset specific keyword arguments. 267 | 268 | Returns 269 | ---------- 270 | tuple: (np.ndarray, np.ndarray, np.ndarray) 271 | A tuple of the form of (x, t, e) where x, t, e are the input covariates, 272 | event times and the censoring indicators respectively. 
273 | 274 | """ 275 | sequential = kwargs.get('sequential', False) 276 | 277 | if dataset == 'SUPPORT': 278 | return _load_support_dataset() 279 | if dataset == 'PBC': 280 | return _load_pbc_dataset(sequential) 281 | if dataset == 'FRAMINGHAM': 282 | return _load_framingham_dataset(sequential) 283 | if dataset == 'MNIST': 284 | return _load_mnist() 285 | else: 286 | raise NotImplementedError('Dataset '+dataset+' not implemented.') 287 | -------------------------------------------------------------------------------- /auton_survival/models/dsm/utilities.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # MIT License 3 | 4 | # Copyright (c) 2020 Carnegie Mellon University, Auton Lab 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | 24 | 25 | """Utility functions to train the Deep Survival Machines models""" 26 | 27 | from .dsm_torch import DeepSurvivalMachinesTorch 28 | from .losses import unconditional_loss, conditional_loss 29 | 30 | from sklearn.utils import shuffle 31 | 32 | from tqdm import tqdm 33 | from copy import deepcopy 34 | 35 | import torch 36 | import numpy as np 37 | 38 | import gc 39 | import logging 40 | 41 | 42 | def get_optimizer(model, lr): 43 | 44 | if model.optimizer == 'Adam': 45 | return torch.optim.Adam(model.parameters(), lr=lr) 46 | elif model.optimizer == 'SGD': 47 | return torch.optim.SGD(model.parameters(), lr=lr) 48 | elif model.optimizer == 'RMSProp': 49 | return torch.optim.RMSprop(model.parameters(), lr=lr) 50 | else: 51 | raise NotImplementedError('Optimizer '+model.optimizer+ 52 | ' is not implemented') 53 | 54 | def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, 55 | n_iter=10000, lr=1e-2, thres=1e-4): 56 | 57 | premodel = DeepSurvivalMachinesTorch(1, 1, 58 | dist=model.dist, 59 | risks=model.risks, 60 | optimizer=model.optimizer) 61 | premodel.double() 62 | 63 | optimizer = get_optimizer(premodel, lr) 64 | 65 | oldcost = float('inf') 66 | patience = 0 67 | costs = [] 68 | for _ in tqdm(range(n_iter)): 69 | 70 | optimizer.zero_grad() 71 | loss = 0 72 | for r in range(model.risks): 73 | loss += unconditional_loss(premodel, t_train, e_train, str(r+1)) 74 | loss.backward() 75 | optimizer.step() 76 | 77 | valid_loss = 0 78 | for r in range(model.risks): 79 | valid_loss += unconditional_loss(premodel, t_valid, e_valid, str(r+1)) 80 | valid_loss = valid_loss.detach().cpu().numpy() 81 | costs.append(valid_loss) 82 | #print(valid_loss) 83 | if np.abs(costs[-1] - oldcost) < thres: 84 | patience += 1 85 | if patience == 3: 86 | break 87 | oldcost = costs[-1] 88 | 89 | return premodel 90 | 91 | def _reshape_tensor_with_nans(data): 92 | """Helper function to unroll padded RNN inputs.""" 93 | data = data.reshape(-1) 94 | return data[~torch.isnan(data)] 95 | 96 | def _get_padded_features(x): 97 | """Helper function to pad variable length RNN inputs with nans.""" 98 | d = max([len(x_) for x_ in x]) 99 | padx = [] 100 | for i in range(len(x)): 101 | pads = np.nan*np.ones((d - len(x[i]),) + x[i].shape[1:]) 102 | padx.append(np.concatenate([x[i], pads])) 103 | return np.array(padx) 104 | 105 | def _get_padded_targets(t): 106 | """Helper function to pad variable length RNN inputs with nans.""" 107 | d = max([len(t_) for t_ in t]) 108 | padt = [] 109 | for i in range(len(t)): 110 | pads = np.nan*np.ones(d - len(t[i])) 111 | padt.append(np.concatenate([t[i], pads])) 112 | return np.array(padt)[:, :, np.newaxis] 113 | 114 | def train_dsm(model, 115 | x_train, t_train, e_train, 116 | x_valid, t_valid, e_valid, 117 | n_iter=10000, lr=1e-3, elbo=True, 118 | bs=100, random_seed=0): 119 | """Function to train the torch instance of the model.""" 120 | 121 | torch.manual_seed(random_seed) 122 | np.random.seed(random_seed) 123 | 124 | logging.info('Pretraining the Underlying Distributions...') 125 | # For padded variable length sequences we first unroll the input and 126 | # mask out the padded nans. 
127 | t_train_ = _reshape_tensor_with_nans(t_train) 128 | e_train_ = _reshape_tensor_with_nans(e_train) 129 | t_valid_ = _reshape_tensor_with_nans(t_valid) 130 | e_valid_ = _reshape_tensor_with_nans(e_valid) 131 | 132 | premodel = pretrain_dsm(model, 133 | t_train_, 134 | e_train_, 135 | t_valid_, 136 | e_valid_, 137 | n_iter=10000, 138 | lr=1e-2, 139 | thres=1e-4) 140 | 141 | for r in range(model.risks): 142 | model.shape[str(r+1)].data.fill_(float(premodel.shape[str(r+1)])) 143 | model.scale[str(r+1)].data.fill_(float(premodel.scale[str(r+1)])) 144 | 145 | model.double() 146 | optimizer = get_optimizer(model, lr) 147 | 148 | patience = 0 149 | oldcost = float('inf') 150 | 151 | nbatches = int(x_train.shape[0]/bs)+1 152 | 153 | dics = [] 154 | costs = [] 155 | i = 0 156 | for i in tqdm(range(n_iter)): 157 | 158 | x_train, t_train, e_train = shuffle(x_train, t_train, e_train, random_state=i) 159 | 160 | for j in range(nbatches): 161 | 162 | xb = x_train[j*bs:(j+1)*bs] 163 | tb = t_train[j*bs:(j+1)*bs] 164 | eb = e_train[j*bs:(j+1)*bs] 165 | 166 | if xb.shape[0] == 0: 167 | continue 168 | 169 | optimizer.zero_grad() 170 | loss = 0 171 | for r in range(model.risks): 172 | loss += conditional_loss(model, 173 | xb, 174 | _reshape_tensor_with_nans(tb), 175 | _reshape_tensor_with_nans(eb), 176 | elbo=elbo, 177 | risk=str(r+1)) 178 | #print ("Train Loss:", float(loss)) 179 | loss.backward() 180 | optimizer.step() 181 | 182 | valid_loss = 0 183 | for r in range(model.risks): 184 | valid_loss += conditional_loss(model, 185 | x_valid, 186 | t_valid_, 187 | e_valid_, 188 | elbo=False, 189 | risk=str(r+1)) 190 | 191 | valid_loss = valid_loss.detach().cpu().numpy() 192 | costs.append(float(valid_loss)) 193 | dics.append(deepcopy(model.state_dict())) 194 | 195 | if costs[-1] >= oldcost: 196 | if patience == 2: 197 | minm = np.argmin(costs) 198 | model.load_state_dict(dics[minm]) 199 | 200 | del dics 201 | gc.collect() 202 | 203 | return model, i 204 | else: 205 | patience += 1 206 | else: 207 | patience = 0 208 | 209 | oldcost = costs[-1] 210 | 211 | minm = np.argmin(costs) 212 | model.load_state_dict(dics[minm]) 213 | 214 | del dics 215 | gc.collect() 216 | 217 | return model, i 218 | 219 | -------------------------------------------------------------------------------- /auton_survival/reporting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from lifelines import KaplanMeierFitter, NelsonAalenFitter 5 | 6 | from lifelines import KaplanMeierFitter 7 | from lifelines.plotting import add_at_risk_counts 8 | 9 | 10 | def plot_kaplanmeier(outcomes, groups=None, plot_counts=False, **kwargs): 11 | 12 | """Plot a Kaplan-Meier Survival Estimator stratified by groups. 13 | 14 | Parameters 15 | ---------- 16 | outcomes: pandas.DataFrame 17 | a pandas dataframe containing the survival outcomes. The index of the 18 | dataframe should be the same as the index of the features dataframe. 19 | Should contain a column named 'time' that contains the survival time and 20 | a column named 'event' that contains the censoring status. 21 | \( \delta_i = 1 \) if the event is observed. 22 | groups: pandas.Series 23 | a pandas series containing the groups to stratify the Kaplan-Meier 24 | estimates by. 25 | plot_counts: bool 26 | if True, plot the number of at risk and censored individuals in each group. 
27 | 28 | """ 29 | 30 | if groups is None: 31 | groups = np.array([1]*len(outcomes)) 32 | 33 | curves = {} 34 | 35 | from matplotlib import pyplot as plt 36 | 37 | ax = plt.subplot(111) 38 | 39 | for group in sorted(set(groups)): 40 | if pd.isna(group): continue 41 | 42 | curves[group] = KaplanMeierFitter().fit(outcomes[groups==group]['time'], 43 | outcomes[groups==group]['event']) 44 | ax = curves[group].plot(label=group, ax=ax, **kwargs) 45 | 46 | if plot_counts: 47 | add_at_risk_counts(iter([curves[group] for group in curves]), ax=ax) 48 | 49 | return ax 50 | 51 | 52 | def plot_nelsonaalen(outcomes, groups=None, **kwargs): 53 | 54 | """Plot a Nelson-Aalen Survival Estimator stratified by groups. 55 | 56 | Parameters 57 | ---------- 58 | outcomes: pandas.DataFrame 59 | a pandas dataframe containing the survival outcomes. The index of the 60 | dataframe should be the same as the index of the features dataframe. 61 | Should contain a column named 'time' that contains the survival time and 62 | a column named 'event' that contains the censoring status. 63 | \( \delta_i = 1 \) if the event is observed. 64 | groups: pandas.Series 65 | a pandas series containing the groups to stratify the Nelson-Aalen 66 | estimates by. 67 | 68 | """ 69 | 70 | if groups is None: 71 | groups = np.array([1]*len(outcomes)) 72 | 73 | for group in sorted(set(groups)): 74 | if pd.isna(group): continue 75 | 76 | print('Group:', group) 77 | 78 | NelsonAalenFitter().fit(outcomes[groups==group]['time'], 79 | outcomes[groups==group]['event']).plot(label=group, 80 | **kwargs) 81 | 82 | 83 | -------------------------------------------------------------------------------- /auton_survival/utils.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import pandas as pd 3 | 4 | def _get_method_kwargs(method, kwargs): 5 | 6 | assert isinstance(kwargs, dict) 7 | 8 | params = inspect.signature(method).parameters.items() 9 | params = set([param[0] for param in params]) - set(['self']) 10 | 11 | method_params = params&set(kwargs.keys()) 12 | method_kwargs = {k: kwargs[k] for k in method_params} 13 | 14 | return method_kwargs 15 | 16 | def _dataframe_to_array(data): 17 | if isinstance(data, (pd.Series, pd.DataFrame)): 18 | return data.to_numpy() 19 | else: 20 | return data 21 | 22 | # TaR: Code alternative to lifelines.utils.qth_percentile 23 | from scipy.optimize import fsolve 24 | def interp_x(y, x, thres): 25 | # NOTE: the '<' comparisons on the next lines were swallowed when this file was extracted; reconstructed assuming interp_x brackets the crossing of `thres` with the nearest points on either side. 26 | if len(y[y<thres]) == 0: return x[-1]  # fall-back value is assumed 27 | 28 | x1, y1 = x[y<thres][-1], y[y<thres][-1] 29 | x2, y2 = x[y>thres][0], y[y>thres][0] 30 | root = fsolve(func, x0=x1, args=(thres, x1, y1, x2, y2))[0] 31 | return root 32 | def func(x, y, x1, y1, x2, y2): 33 | return y1 + (x-x1)*((y2-y1)/(x2-x1)) - y 34 | -------------------------------------------------------------------------------- /docs/dashboard.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /docs/datasets.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.datasets API documentation 8 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 20 | 21 | 22 | 23 | 24 | 25 |
26 |
27 |
28 |

Module auton_survival.datasets

29 |
30 |
31 |

Utility functions to load standard datasets to train and evaluate the 32 | Deep Survival Machines models.

33 | 34 |
35 |
36 |
37 |
38 |
39 |
40 |

Functions

41 |
42 |
43 | def increase_censoring(e, t, p, random_seed=0) 44 |
45 |
46 |
47 | 48 |
49 |
50 | def load_support() 51 |
52 |
53 |

Helper function to load and preprocess the SUPPORT dataset. 54 | The SUPPORT Dataset comes from the Vanderbilt University study 55 | to estimate survival for seriously ill hospitalized adults [1]. 56 | Please refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc 57 | for the original datasource.

58 |

References

59 |

[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic 60 | model: Objective estimates of survival for seriously ill hospitalized 61 | adults. Annals of Internal Medicine 122:191-203.

62 | 63 |
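A minimal usage sketch, mirroring the call made in the CV Survival Regression notebook later in this dump (the returned outcomes frame carries 'time' and 'event' columns):

    from auton_survival import datasets

    # outcomes: pandas.DataFrame with 'time'/'event'; features: raw covariates
    outcomes, features = datasets.load_support()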
64 |
65 | def load_synthetic_cf_phenotyping() 66 |
67 |
68 |
69 | 70 |
71 |
72 | def load_dataset(dataset='SUPPORT', **kwargs) 73 |
74 |
75 |

Helper function to load datasets to test Survival Analysis models. 76 | Currently implemented datasets include:

77 |

SUPPORT: This dataset comes from the Vanderbilt University study 78 | to estimate survival for seriously ill hospitalized adults [1]. 79 | (Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc 80 | for the original datasource.)

81 |

PBC: The Primary biliary cirrhosis dataset [2] is a well-known 82 | dataset for evaluating survival analysis models with time 83 | dependent covariates.

84 |

FRAMINGHAM: This dataset is a subset of 4,434 participants of the well 85 | known, ongoing Framingham Heart study [3] for studying the epidemiology of 86 | hypertensive and arteriosclerotic cardiovascular disease. It is a popular 87 | dataset for longitudinal survival analysis with time dependent covariates.

88 |

SYNTHETIC: This is a non-linear censored dataset for counterfactual 89 | time-to-event phenotyping. Introduced in [4], the dataset is generated 90 | such that the treatment effect is heterogenous conditioned on the covariates.

91 |

References

92 |

[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic 93 | model: Objective estimates of survival for seriously ill hospitalized 94 | adults. Annals of Internal Medicine 122:191-203.

95 |

[2] Fleming, Thomas R., and David P. Harrington. Counting processes and 96 | survival analysis. Vol. 169. John Wiley & Sons, 2011.

97 |

[3] Dawber, Thomas R., Gilcin F. Meadors, and Felix E. Moore Jr. 98 | "Epidemiological approaches to heart disease: the Framingham Study." 99 | American Journal of Public Health and the Nations Health 41.3 (1951).

100 |

[4] Nagpal, C., Goswami M., Dufendach K., and Artur Dubrawski. 101 | "Counterfactual phenotyping for censored Time-to-Events" (2022).

102 |

Parameters

103 |
104 |
dataset : str
105 |
The choice of dataset to load. Currently implemented are 'SUPPORT', 106 | 'PBC', 'FRAMINGHAM' and 'SYNTHETIC'.
107 |
**kwargs : dict
108 |
Dataset specific keyword arguments.
109 |
110 |

Returns

111 |
112 |
tuple : (np.ndarray, np.ndarray, np.ndarray)
113 |
A tuple of the form (x, t, e), where x 114 | are the input covariates, t the event times and 115 | e the censoring indicators.
116 |
117 | 118 |
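A short usage sketch of the documented call, using the same (x, t, e) convention that the notebooks later in this dump rely on:

    from auton_survival.datasets import load_dataset

    # x: covariates, t: event/censoring times, e: censoring indicators
    x, t, e = load_dataset('SUPPORT')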
119 |
120 |
121 |
122 |
123 |
124 | 146 |
147 | 150 | 151 | -------------------------------------------------------------------------------- /docs/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonlab/auton-survival/5dde465f7223601717abddc1d075e837707c403b/docs/models/.DS_Store -------------------------------------------------------------------------------- /docs/models/cmhe/cmhe_torch.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.cmhe.cmhe_torch API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.models.cmhe.cmhe_torch

28 |
29 |
30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |

Classes

40 |
41 |
42 | class DeepCMHETorch 43 | (k, g, inputdim, layers=None, gamma=100, smoothing_factor=0.0001, gate_l2_penalty=0.0001, optimizer='Adam') 44 |
45 |
46 |

PyTorch model definition of the Cox Mixture with Heterogenous Effects Model.

47 |

Cox Mixtures with Heterogenous Effects assumes that the base survival rates 48 | are independent of the treatment effect, and models the survival function 49 | of the individual as a mixture of K Cox models. Conditioned on each 50 | subgroup Z=k, the PH assumptions are assumed to hold and the baseline 51 | hazard rate is determined non-parametrically using a spline-interpolated 52 | Breslow's estimator.

53 |

Initializes internal Module state, shared by both nn.Module and ScriptModule.

54 | 55 |
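A hypothetical instantiation per the constructor signature above; the sizes are illustrative, not library defaults:

    from auton_survival.models.cmhe.cmhe_torch import DeepCMHETorch

    # k latent base-survival subgroups, g treatment-effect phenotypes
    model = DeepCMHETorch(k=3, g=2, inputdim=10, layers=[64])
    # forward takes covariates and a treatment indicator: model(x, a)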

Class variables

56 |
57 |
var dump_patches : bool
58 |
59 |
60 |
61 |
var training : bool
62 |
63 |
64 |
65 |
66 |

Methods

67 |
68 |
69 | def forward(self, x, a) ‑> Callable[..., Any] 70 |
71 |
72 |
73 | 74 |
75 |
76 |
77 |
78 | class IdentifiableLinear 79 | (in_features, out_features, bias=True) 80 |
81 |
82 |

Softmax and LogSoftmax with K classes in pytorch are over-specified and lead to 83 | issues of mis-identifiability. This class is a wrapper for linear layers that 84 | are correctly specified with K-1 columns. The output of this layer for the Kth 85 | class is all zeros. This allows direct application of pytorch.nn.LogSoftmax 86 | and pytorch.nn.Softmax.

87 |

Initializes internal Module state, shared by both nn.Module and ScriptModule.

88 | 89 |
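A minimal sketch of the K-1 column idea described above; names and shapes are illustrative, not the library's exact implementation:

    import torch

    class KMinusOneLinear(torch.nn.Module):
      """Linear layer with K-1 free columns; the K-th logit is pinned to 0."""
      def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features-1, bias=bias)
      def forward(self, x):
        logits = self.linear(x)
        pad = torch.zeros(len(x), 1, dtype=logits.dtype, device=logits.device)
        # appending a fixed zero column keeps nn.LogSoftmax(dim=1) identifiable
        return torch.cat([logits, pad], dim=1)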

Class variables

90 |
91 |
var dump_patches : bool
92 |
93 |
94 |
95 |
var training : bool
96 |
97 |
98 |
99 |
100 |

Instance variables

101 |
102 |
var weight
103 |
104 |
105 | 106 |
107 |
108 |

Methods

109 |
110 |
111 | def forward(self, x) ‑> Callable[..., Any] 112 |
113 |
114 |
115 | 116 |
117 |
118 |
119 |
120 |
121 |
122 | 157 |
158 | 161 | 162 | -------------------------------------------------------------------------------- /docs/models/cph/dcph_torch.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.cph.dcph_torch API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.models.cph.dcph_torch

28 |
29 |
30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |

Classes

40 |
41 |
42 | class DeepCoxPHTorch 43 | (inputdim, layers=None, optimizer='Adam') 44 |
45 |
46 |

Base class for all neural network modules.

47 |

Your models should also subclass this class.

48 |

Modules can also contain other Modules, allowing to nest them in 49 | a tree structure. You can assign the submodules as regular attributes::

50 |
import torch.nn as nn
 51 | import torch.nn.functional as F
 52 | 
 53 | class Model(nn.Module):
 54 |     def __init__(self):
 55 |         super(Model, self).__init__()
 56 |         self.conv1 = nn.Conv2d(1, 20, 5)
 57 |         self.conv2 = nn.Conv2d(20, 20, 5)
 58 | 
 59 |     def forward(self, x):
 60 |         x = F.relu(self.conv1(x))
 61 |         return F.relu(self.conv2(x))
 62 | 
63 |

Submodules assigned in this way will be registered, and will have their 64 | parameters converted too when you call :meth:to, etc.

65 |

Initializes internal Module state, shared by both nn.Module and ScriptModule.

66 | 67 |

Class variables

68 |
69 |
var dump_patches : bool
70 |
71 |
72 |
73 |
var training : bool
74 |
75 |
76 |
77 |
78 |

Methods

79 |
80 |
81 | def forward(self, x) ‑> Callable[..., Any] 82 |
83 |
84 |
85 | 86 |
87 |
88 |
89 |
90 | class DeepRecurrentCoxPHTorch 91 | (inputdim, typ='LSTM', layers=1, hidden=None, optimizer='Adam') 92 |
93 |
94 |

Base class for all neural network modules.

95 |

Your models should also subclass this class.

96 |

Modules can also contain other Modules, allowing to nest them in 97 | a tree structure. You can assign the submodules as regular attributes::

98 |
import torch.nn as nn
 99 | import torch.nn.functional as F
100 | 
101 | class Model(nn.Module):
102 |     def __init__(self):
103 |         super(Model, self).__init__()
104 |         self.conv1 = nn.Conv2d(1, 20, 5)
105 |         self.conv2 = nn.Conv2d(20, 20, 5)
106 | 
107 |     def forward(self, x):
108 |         x = F.relu(self.conv1(x))
109 |         return F.relu(self.conv2(x))
110 | 
111 |

Submodules assigned in this way will be registered, and will have their 112 | parameters converted too when you call :meth:to, etc.

113 |

Initializes internal Module state, shared by both nn.Module and ScriptModule.

114 | 115 |
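A hypothetical instantiation per the constructor signature above; the argument values are illustrative, and inputs are assumed to be nan-padded sequence tensors as produced by _get_padded_features in dsm/utilities.py earlier in this dump:

    from auton_survival.models.cph.dcph_torch import DeepRecurrentCoxPHTorch

    # 16 covariates per time step, a 2-layer GRU with 50 hidden units
    model = DeepRecurrentCoxPHTorch(inputdim=16, typ='GRU', layers=2, hidden=50)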

Class variables

116 |
117 |
var dump_patches : bool
118 |
119 |
120 |
121 |
var training : bool
122 |
123 |
124 |
125 |
126 |

Methods

127 |
128 |
129 | def forward(self, x) ‑> Callable[..., Any] 130 |
131 |
132 |
133 | 134 |
135 |
136 |
137 |
138 |
139 |
140 | 174 |
175 | 178 | 179 | -------------------------------------------------------------------------------- /docs/models/cph/dcph_utilities.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.cph.dcph_utilities API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.models.cph.dcph_utilities

28 |
29 |
30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 |

Functions

38 |
39 |
40 | def randargmax(b, **kw) 41 |
42 |
43 |

a random tie-breaking argmax

44 | 45 |
46 |
47 | def partial_ll_loss(lrisks, tb, eb, eps=0.001) 48 |
49 |
50 |
51 | 52 |
53 |
54 | def fit_breslow(model, x, t, e) 55 |
56 |
57 |
58 | 59 |
60 |
61 | def train_step(model, x, t, e, optimizer, bs=256, seed=100) 62 |
63 |
64 |
65 | 66 |
67 |
68 | def test_step(model, x, t, e) 69 |
70 |
71 |
72 | 73 |
74 |
75 | def train_dcph(model, train_data, val_data, epochs=50, patience=3, bs=256, lr=0.001, debug=False, random_seed=0, return_losses=False) 76 |
77 |
78 |
79 | 80 |
81 |
82 | def predict_survival(model, x, t=None) 83 |
84 |
85 |
86 | 87 |
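A hedged sketch of wiring these functions together, based only on the signatures listed above; the (x, t, e) tuple convention and the structure of the fitted handle returned by train_dcph are assumptions carried over from the DCM utilities earlier in this dump:

    import torch
    from auton_survival.models.cph.dcph_torch import DeepCoxPHTorch
    from auton_survival.models.cph.dcph_utilities import train_dcph, predict_survival

    x = torch.randn(128, 10).double()    # toy covariates
    t = (100*torch.rand(128)).double()   # toy event/censoring times
    e = (torch.rand(128) > 0.3).double() # toy event indicators

    model = DeepCoxPHTorch(inputdim=10, layers=[100]).double()
    fitted = train_dcph(model, (x[:96], t[:96], e[:96]),
                        (x[96:], t[96:], e[96:]), epochs=5, bs=32)
    surv = predict_survival(fitted, x[96:], t=[50.])  # assumed fitted handle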
88 |
89 |
90 |
91 |
92 |
93 | 118 |
119 | 122 | 123 | -------------------------------------------------------------------------------- /docs/models/dcm/dcm_torch.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.dcm.dcm_torch API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.models.dcm.dcm_torch

28 |
29 |
30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |

Classes

40 |
41 |
42 | class DeepCoxMixturesTorch 43 | (inputdim, k, gamma=1, use_activation=False, layers=None, optimizer='Adam') 44 |
45 |
46 |

PyTorch model definition of the Deep Cox Mixture Survival Model.

47 |

The Cox Mixture assumes the survival function 48 | of the individual to be a mixture of K Cox models. Conditioned on each 49 | subgroup Z=k, the PH assumptions are assumed to hold and the baseline 50 | hazard rate is determined non-parametrically using a spline-interpolated 51 | Breslow's estimator.

52 |

Initializes internal Module state, shared by both nn.Module and ScriptModule.

53 | 54 |
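The forward pass is shown in full in auton_survival/models/dcm/dcm_torch.py earlier in this dump; condensed, it returns log gate probabilities and bounded log hazard ratios:

    import torch

    def dcm_forward(model, x):
      x = model.embedding(x)                    # shared representation
      if model.use_activation:                  # bound via gamma * tanh
        log_hrs = model.gamma * torch.nn.Tanh()(model.expert(x))
      else:                                     # bound by clamping to [-gamma, gamma]
        log_hrs = torch.clamp(model.expert(x), min=-model.gamma, max=model.gamma)
      log_gate_prob = torch.nn.LogSoftmax(dim=1)(model.gate(x))
      return log_gate_prob, log_hrs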

Class variables

55 |
56 |
var dump_patches : bool
57 |
58 |
59 |
60 |
var training : bool
61 |
62 |
63 |
64 |
65 |

Methods

66 |
67 |
68 | def forward(self, x) ‑> Callable[..., Any] 69 |
70 |
71 |
72 | 73 |
74 |
75 |
76 |
77 |
78 |
79 | 105 |
106 | 109 | 110 | -------------------------------------------------------------------------------- /docs/models/dsm/datasets.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.dsm.datasets API documentation 8 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 20 | 21 | 22 | 23 | 24 | 25 |
26 |
27 |
28 |

Module auton_survival.models.dsm.datasets

29 |
30 |
31 |

Utility functions to load standard datasets to train and evaluate the 32 | Deep Survival Machines models.

33 | 34 |
35 |
36 |
37 |
38 |
39 |
40 |

Functions

41 |
42 |
43 | def increase_censoring(e, t, p) 44 |
45 |
46 |
47 | 48 |
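Per the implementation in dsm/datasets.py earlier in this dump, this converts a random fraction p of observed events into censored observations, drawing the new censoring time uniformly before the event; for example:

    import numpy as np
    from auton_survival.models.dsm.datasets import increase_censoring

    e = np.ones(1000)                      # all events initially observed
    t = np.random.uniform(10, 100, 1000)   # toy event times
    e, t = increase_censoring(e, t, p=0.5) # ~50% artificially censored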
49 |
50 | def load_dataset(dataset='SUPPORT', **kwargs) 51 |
52 |
53 |

Helper function to load datasets to test Survival Analysis models.

54 |

Currently implemented datasets include:

55 |

SUPPORT: This dataset comes from the Vanderbilt University study 56 | to estimate survival for seriously ill hospitalized adults [1]. 57 | (Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc 58 | for the original datasource.)

59 |

PBC: The Primary biliary cirrhosis dataset [2] is a well-known 60 | dataset for evaluating survival analysis models with time 61 | dependent covariates.

62 |

FRAMINGHAM: This dataset is a subset of 4,434 participants of the well 63 | known, ongoing Framingham Heart study [3] for studying the epidemiology of 64 | hypertensive and arteriosclerotic cardiovascular disease. It is a popular 65 | dataset for longitudinal survival analysis with time dependent covariates.

66 |

References

67 |

[1]: Knaus WA, Harrell FE, Lynn J et al. (1995): The SUPPORT prognostic 68 | model: Objective estimates of survival for seriously ill hospitalized 69 | adults. Annals of Internal Medicine 122:191-203.

70 |

[2] Fleming, Thomas R., and David P. Harrington. Counting processes and 71 | survival analysis. Vol. 169. John Wiley & Sons, 2011.

72 |

[3] Dawber, Thomas R., Gilcin F. Meadors, and Felix E. Moore Jr. 73 | "Epidemiological approaches to heart disease: the Framingham Study." 74 | American Journal of Public Health and the Nations Health 41.3 (1951).

75 |

Parameters

76 |
77 |
dataset : str
78 |
The choice of dataset to load. Currently implemented are 'SUPPORT', 79 | 'PBC' and 'FRAMINGHAM'.
80 |
**kwargs : dict
81 |
Dataset specific keyword arguments.
82 |
83 |

Returns

84 |
85 |
tuple : (np.ndarray, np.ndarray, np.ndarray)
86 |
A tuple of the form (x, t, e), where x, t, e are the input covariates, 87 | event times and the censoring indicators respectively.
88 |
89 | 90 |
91 |
92 |
93 |
94 |
95 |
96 | 116 |
117 | 120 | 121 | -------------------------------------------------------------------------------- /docs/models/dsm/losses.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.dsm.losses API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.models.dsm.losses

28 |
29 |
30 |

Loss function definitions for the Deep Survival Machines model

31 |

In this module we define the various losses for the censored and uncensored 32 | instances of data corresponding to Weibull and LogNormal distributions. 33 | These losses are optimized when training DSM.

34 |
35 |

TODO

36 |

Use torch.distributions

37 |
38 |
39 |

Warning

40 |

NOT DESIGNED TO BE CALLED DIRECTLY!!!

41 |
42 | 43 |
44 |
45 |
46 |
47 |
48 |
49 |

Functions

50 |
51 |
52 | def unconditional_loss(model, t, e, risk='1') 53 |
54 |
55 |
56 | 57 |
58 |
59 | def conditional_loss(model, x, t, e, elbo=True, risk='1') 60 |
61 |
62 |
63 | 64 |
65 |
66 | def predict_mean(model, x, risk='1') 67 |
68 |
69 |
70 | 71 |
72 |
73 | def predict_pdf(model, x, t_horizon, risk='1') 74 |
75 |
76 |
77 | 78 |
79 |
80 | def predict_cdf(model, x, t_horizon, risk='1') 81 |
82 |
83 |
84 | 85 |
86 |
87 |
88 |
89 |
90 |
91 | 114 |
115 | 118 | 119 | -------------------------------------------------------------------------------- /docs/models/dsm/utilities.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.models.dsm.utilities API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.models.dsm.utilities

28 |
29 |
30 |

Utility functions to train the Deep Survival Machines models

31 | 32 |
33 |
34 |
35 |
36 |
37 |
38 |

Functions

39 |
40 |
41 | def get_optimizer(model, lr) 42 |
43 |
44 |
45 | 46 |
47 |
48 | def pretrain_dsm(model, t_train, e_train, t_valid, e_valid, n_iter=10000, lr=0.01, thres=0.0001) 49 |
50 |
51 |
52 | 53 |
54 |
55 | def train_dsm(model, x_train, t_train, e_train, x_valid, t_valid, e_valid, n_iter=10000, lr=0.001, elbo=True, bs=100, random_seed=0) 56 |
57 |
58 |

Function to train the torch instance of the model.

59 | 60 |
61 |
62 |
63 |
64 |
65 |
66 | 87 |
88 | 91 | 92 | -------------------------------------------------------------------------------- /docs/reporting.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.reporting API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.reporting

28 |
29 |
30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 |

Functions

38 |
39 |
40 | def plot_kaplanmeier(outcomes, groups=None, plot_counts=False, **kwargs) 41 |
42 |
43 |

Plot a Kaplan-Meier Survival Estimator stratified by groups.

44 |

Parameters

45 |
46 |
outcomes : pandas.DataFrame
47 |
 
48 |
a pandas dataframe containing the survival outcomes. The index of the
49 |
dataframe should be the same as the index of the features dataframe.
50 |
Should contain a column named 'time' that contains the survival time and
51 |
a column named 'event' that contains the censoring status.
52 |
\( \delta_i = 1 \) if the event is observed.
53 |
groups : pandas.Series
54 |
 
55 |
a pandas series containing the groups to stratify the Kaplan-Meier
56 |
estimates by.
57 |
plot_counts : bool
58 |
 
59 |
60 |

if True, plot the number of at risk and censored individuals in each group.

61 | 62 |
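A short usage sketch based on the signature and docstring above; the SUPPORT loader is assumed to provide a 'sex' column to stratify on:

    from matplotlib import pyplot as plt
    from auton_survival import datasets
    from auton_survival.reporting import plot_kaplanmeier

    outcomes, features = datasets.load_support()
    ax = plot_kaplanmeier(outcomes, groups=features['sex'], plot_counts=True)
    plt.show()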
63 |
64 | def plot_nelsonaalen(outcomes, groups=None, **kwargs) 65 |
66 |
67 |

Plot a Nelson-Aalen Survival Estimator stratified by groups.

68 |

Parameters

69 |
70 |
outcomes : pandas.DataFrame
71 |
 
72 |
a pandas dataframe containing the survival outcomes. The index of the
73 |
dataframe should be the same as the index of the features dataframe.
74 |
Should contain a column named 'time' that contains the survival time and
75 |
a column named 'event' that contains the censoring status.
76 |
\( \delta_i = 1 \) if the event is observed.
77 |
groups : pandas.Series
78 |
 
79 |
80 |

a pandas series containing the groups to stratify the Nelson-Aalen 81 | estimates by.

82 | 83 |
84 |
85 |
86 |
87 |
88 |
89 | 109 |
110 | 113 | 114 | -------------------------------------------------------------------------------- /docs/utils.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | auton_survival.utils API documentation 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 19 | 20 | 21 | 22 | 23 | 24 |
25 |
26 |
27 |

Module auton_survival.utils

28 |
29 |
30 | 31 |
32 |
33 |
34 |
35 |
36 |
37 |

Functions

38 |
39 |
40 | def interp_x(y, x, thres) 41 |
42 |
43 |
44 | 45 |
46 |
47 | def func(x, y, x1, y1, x2, y2) 48 |
49 |
50 |
51 | 52 |
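A hedged sketch of the intended use, finding where a curve crosses a threshold by linear interpolation (the reconstruction of interp_x in utils.py earlier in this dump is itself an assumption, so the exact crossing convention may differ):

    import numpy as np
    from auton_survival.utils import interp_x

    times = np.array([0., 1., 2., 3.])
    cuminc = np.array([0.0, 0.2, 0.6, 0.8])       # toy cumulative incidence curve
    t_cross = interp_x(cuminc, times, thres=0.5)  # ~1.75 under the reconstruction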
53 |
54 |
55 |
56 |
57 |
58 | 78 |
79 | 82 | 83 | -------------------------------------------------------------------------------- /examples/CV Survival Regression on SUPPORT Dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# `auton-survival` Cross Validation Survival Regression\n", 8 | "\n", 9 | "`auton-survival` offers a simple to use API to train Survival Regression Models that performs cross validation model selection by minimizing integrated brier score. In this notebook we demonstrate the use of `auton-survival` to train survival models on the *SUPPORT* dataset in cross validation fashion." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import sys\n", 19 | "\n", 20 | "sys.path.append('../')\n", 21 | "from auton_survival import datasets\n", 22 | "outcomes, features = datasets.load_support()" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "from auton_survival.preprocessing import Preprocessor\n", 32 | "\n", 33 | "cat_feats = ['sex', 'dzgroup', 'dzclass', 'income', 'race', 'ca']\n", 34 | "num_feats = ['age', 'num.co', 'meanbp', 'wblc', 'hrt', 'resp', \n", 35 | " 'temp', 'pafi', 'alb', 'bili', 'crea', 'sod', 'ph', \n", 36 | " 'glucose', 'bun', 'urine', 'adlp', 'adls']\n", 37 | "\n", 38 | "# Data should be processed in a fold-independent manner when performing cross-validation. \n", 39 | "# For simplicity in this demo, we process the dataset in a non-independent manner.\n", 40 | "preprocessor = Preprocessor(cat_feat_strat='ignore', num_feat_strat= 'mean') \n", 41 | "x = preprocessor.fit_transform(features, cat_feats=cat_feats, num_feats=num_feats,\n", 42 | " one_hot=True, fill_value=-1)" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": null, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "import numpy as np\n", 52 | "times = np.quantile(outcomes.time[outcomes.event==1], [0.25, 0.5, 0.75]).tolist()" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "from auton_survival.experiments import SurvivalRegressionCV\n", 62 | "\n", 63 | "param_grid = {'k' : [3],\n", 64 | " 'distribution' : ['Weibull'],\n", 65 | " 'learning_rate' : [1e-4, 1e-3],\n", 66 | " 'layers' : [[100]]}\n", 67 | "\n", 68 | "experiment = SurvivalRegressionCV(model='dsm', num_folds=3, hyperparam_grid=param_grid, random_seed=0)\n", 69 | "model = experiment.fit(x, outcomes, times, metric='brs')" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "print(experiment.folds)\n", 79 | "model" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": null, 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "out_risk = model.predict_risk(x, times)\n", 89 | "out_survival = model.predict_survival(x, times)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "from auton_survival.metrics import survival_regression_metric\n", 99 | "\n", 100 | "for fold in set(experiment.folds):\n", 101 | " print(survival_regression_metric('brs', outcomes[experiment.folds==fold], \n", 102 | " out_survival[experiment.folds==fold], \n", 103 
| " times=times))" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "from auton_survival.metrics import survival_regression_metric\n", 113 | "\n", 114 | "for fold in set(experiment.folds):\n", 115 | " print(survival_regression_metric('ctd', outcomes[experiment.folds==fold], \n", 116 | " out_survival[experiment.folds==fold], \n", 117 | " times=times))" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [ 126 | "for fold in set(experiment.folds):\n", 127 | " for time in times:\n", 128 | " print(time)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [] 137 | } 138 | ], 139 | "metadata": { 140 | "interpreter": { 141 | "hash": "c22fbbe4c37d04364aa4e785d7cd9729f94ca3cb878d5f955e35b0076c04a2d7" 142 | }, 143 | "kernelspec": { 144 | "display_name": "Python 3", 145 | "language": "python", 146 | "name": "python3" 147 | }, 148 | "language_info": { 149 | "codemirror_mode": { 150 | "name": "ipython", 151 | "version": 3 152 | }, 153 | "file_extension": ".py", 154 | "mimetype": "text/x-python", 155 | "name": "python", 156 | "nbconvert_exporter": "python", 157 | "pygments_lexer": "ipython3", 158 | "version": "3.9.7" 159 | } 160 | }, 161 | "nbformat": 4, 162 | "nbformat_minor": 4 163 | } 164 | -------------------------------------------------------------------------------- /examples/DCM on SUPPORT Dataset copy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# DSM on SUPPORT Dataset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "The SUPPORT dataset comes from the Vanderbilt University study\n", 15 | "to estimate survival for seriously ill hospitalized adults.\n", 16 | "(Refer to http://biostat.mc.vanderbilt.edu/wiki/Main/SupportDesc.\n", 17 | "for the original datasource.)\n", 18 | "\n", 19 | "In this notebook, we will apply Deep Survival Machines for survival prediction on the SUPPORT data." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "### Load the SUPPORT Dataset\n", 27 | "\n", 28 | "The package includes helper functions to load the dataset.\n", 29 | "\n", 30 | "X represents an np.array of features (covariates),\n", 31 | "T is the event/censoring times and,\n", 32 | "E is the censoring indicator." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from dsm import datasets\n", 42 | "x, t, e = datasets.load_dataset('SUPPORT')" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "### Compute horizons at which we evaluate the performance of DSM\n", 50 | "\n", 51 | "Survival predictions are issued at certain time horizons. Here we will evaluate the performance\n", 52 | "of DSM to issue predictions at the 25th, 50th and 75th event time quantile as is standard practice in Survival Analysis." 
53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "import numpy as np\n", 62 | "horizons = [0.25, 0.5, 0.75]\n", 63 | "times = np.quantile(t[e==1], horizons).tolist()" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "metadata": {}, 69 | "source": [ 70 | "### Splitting the data into train, test and validation sets\n", 71 | "\n", 72 | "We will train DSM on 70% of the Data, use a Validation set of 10% for Model Selection and report performance on the remaining 20% held out test set." 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": {}, 79 | "outputs": [], 80 | "source": [ 81 | "n = len(x)\n", 82 | "\n", 83 | "tr_size = int(n*0.70)\n", 84 | "vl_size = int(n*0.10)\n", 85 | "te_size = int(n*0.20)\n", 86 | "\n", 87 | "x_train, x_test, x_val = x[:tr_size], x[-te_size:], x[tr_size:tr_size+vl_size]\n", 88 | "t_train, t_test, t_val = t[:tr_size], t[-te_size:], t[tr_size:tr_size+vl_size]\n", 89 | "e_train, e_test, e_val = e[:tr_size], e[-te_size:], e[tr_size:tr_size+vl_size]" 90 | ] 91 | }, 92 | { 93 | "cell_type": "markdown", 94 | "metadata": {}, 95 | "source": [ 96 | "### Setting the parameter grid\n", 97 | "\n", 98 | "Lets set up the parameter grid to tune hyper-parameters. We will tune the number of underlying survival distributions, \n", 99 | "($K$), the distribution choices (Log-Normal or Weibull), the learning rate for the Adam optimizer between $1\\times10^{-3}$ and $1\\times10^{-4}$ and the number of hidden layers between $0, 1$ and $2$." 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "from sklearn.model_selection import ParameterGrid" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "param_grid = {'k' : [3, 4, 6],\n", 118 | " 'distribution' : ['LogNormal', 'Weibull'],\n", 119 | " 'learning_rate' : [ 1e-4, 1e-3],\n", 120 | " 'layers' : [ [], [100], [100, 100] ]\n", 121 | " }\n", 122 | "params = ParameterGrid(param_grid)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### Model Training and Selection" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "from dsm import DeepSurvivalMachines" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": { 145 | "scrolled": true 146 | }, 147 | "outputs": [], 148 | "source": [ 149 | "models = []\n", 150 | "for param in params:\n", 151 | " model = DeepSurvivalMachines(k = param['k'],\n", 152 | " distribution = param['distribution'],\n", 153 | " layers = param['layers'])\n", 154 | " # The fit method is called to train the model\n", 155 | " model.fit(x_train, t_train, e_train, iters = 100, learning_rate = param['learning_rate'])\n", 156 | " models.append([[model.compute_nll(x_val, t_val, e_val), model]])\n", 157 | "best_model = min(models)\n", 158 | "model = best_model[0][1]" 159 | ] 160 | }, 161 | { 162 | "cell_type": "markdown", 163 | "metadata": {}, 164 | "source": [ 165 | "### Inference" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "out_risk = model.predict_risk(x_test, times)\n", 175 | "out_survival = 
model.predict_survival(x_test, times)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "metadata": {}, 181 | "source": [ 182 | "### Evaluation\n", 183 | "\n", 184 | "We evaluate the performance of DSM in its discriminative ability (Time Dependent Concordance Index and Cumulative Dynamic AUC) as well as Brier Score." 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": null, 199 | "metadata": {}, 200 | "outputs": [], 201 | "source": [ 202 | "cis = []\n", 203 | "brs = []\n", 204 | "\n", 205 | "et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],\n", 206 | " dtype = [('e', bool), ('t', float)])\n", 207 | "et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],\n", 208 | " dtype = [('e', bool), ('t', float)])\n", 209 | "et_val = np.array([(e_val[i], t_val[i]) for i in range(len(e_val))],\n", 210 | " dtype = [('e', bool), ('t', float)])\n", 211 | "\n", 212 | "for i, _ in enumerate(times):\n", 213 | " cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])\n", 214 | "brs.append(brier_score(et_train, et_test, out_survival, times)[1])\n", 215 | "roc_auc = []\n", 216 | "for i, _ in enumerate(times):\n", 217 | " roc_auc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0])\n", 218 | "for horizon in enumerate(horizons):\n", 219 | " print(f\"For {horizon[1]} quantile,\")\n", 220 | " print(\"TD Concordance Index:\", cis[horizon[0]])\n", 221 | " print(\"Brier Score:\", brs[0][horizon[0]])\n", 222 | " print(\"ROC AUC \", roc_auc[horizon[0]][0], \"\\n\")" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.3" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 4 247 | } 248 | -------------------------------------------------------------------------------- /examples/RDSM on PBC Dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Recurrent DSM on PBC Dataset" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "The longitudinal PBC dataset comes from the Mayo Clinic trial in primary biliary cirrhosis (PBC) of the liver conducted between 1974 and 1984 (Refer to https://stat.ethz.ch/R-manual/R-devel/library/survival/html/pbc.html)\n", 15 | "\n", 16 | "In this notebook, we will apply Recurrent Deep Survival Machines for survival prediction on the PBC data." 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "### Load the PBC Dataset\n", 24 | "\n", 25 | "The package includes helper functions to load the dataset.\n", 26 | "\n", 27 | "X represents an np.array of features (covariates),\n", 28 | "T is the event/censoring times and,\n", 29 | "E is the censoring indicator." 
30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 14, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "from dsm import datasets\n", 39 | "x, t, e = datasets.load_dataset('PBC', sequential = True)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "### Compute horizons at which we evaluate the performance of RDSM\n", 47 | "\n", 48 | "Survival predictions are issued at certain time horizons. Here we will evaluate the performance\n", 49 | "of RDSM at issuing predictions at the 25th, 50th and 75th event time quantiles, as is standard practice in survival analysis." 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 27, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "import numpy as np\n", 59 | "horizons = [0.25, 0.5, 0.75]\n", 60 | "times = np.quantile([t_[-1] for t_, e_ in zip(t, e) if e_[-1] == 1], horizons).tolist()" 61 | ] 62 | }, 63 | { 64 | "cell_type": "markdown", 65 | "metadata": {}, 66 | "source": [ 67 | "### Splitting the data into train, test and validation sets\n", 68 | "\n", 69 | "We will train RDSM on 70% of the data, use a validation set of 10% for model selection, and report performance on the remaining 20% held-out test set." 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 36, 75 | "metadata": {}, 76 | "outputs": [], 77 | "source": [ 78 | "n = len(x)\n", 79 | "\n", 80 | "tr_size = int(n*0.70)\n", 81 | "vl_size = int(n*0.10)\n", 82 | "te_size = int(n*0.20)\n", 83 | "\n", 84 | "x_train, x_test, x_val = np.array(x[:tr_size], dtype = object), np.array(x[-te_size:], dtype = object), np.array(x[tr_size:tr_size+vl_size], dtype = object)\n", 85 | "t_train, t_test, t_val = np.array(t[:tr_size], dtype = object), np.array(t[-te_size:], dtype = object), np.array(t[tr_size:tr_size+vl_size], dtype = object)\n", 86 | "e_train, e_test, e_val = np.array(e[:tr_size], dtype = object), np.array(e[-te_size:], dtype = object), np.array(e[tr_size:tr_size+vl_size], dtype = object)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "metadata": {}, 92 | "source": [ 93 | "### Setting the parameter grid\n", 94 | "\n", 95 | "Let's set up the parameter grid to tune hyper-parameters. We will tune the number of underlying survival distributions\n", 96 | "($K$), the distribution choice (Log-Normal or Weibull), the learning rate for the Adam optimizer ($1\times10^{-3}$ or $1\times10^{-4}$), the number of hidden nodes per layer ($50$ or $100$), the number of layers ($3$, $2$ or $1$) and the type of recurrent cell (LSTM, GRU or RNN)." 
97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 31, 102 | "metadata": {}, 103 | "outputs": [], 104 | "source": [ 105 | "from sklearn.model_selection import ParameterGrid" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 39, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "param_grid = {'k' : [3, 4, 6],\n", 115 | " 'distribution' : ['LogNormal', 'Weibull'],\n", 116 | " 'learning_rate' : [1e-4, 1e-3],\n", 117 | " 'hidden': [50, 100],\n", 118 | " 'layers': [3, 2, 1],\n", 119 | " 'typ': ['LSTM', 'GRU', 'RNN'],\n", 120 | " }\n", 121 | "params = ParameterGrid(param_grid)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "### Model Training and Selection" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 33, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "from dsm import DeepRecurrentSurvivalMachines" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 41, 143 | "metadata": { 144 | "scrolled": true 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "models = []\n", 149 | "for param in params:\n", 150 | " model = DeepRecurrentSurvivalMachines(k = param['k'],\n", 151 | " distribution = param['distribution'],\n", 152 | " hidden = param['hidden'], \n", 153 | " typ = param['typ'],\n", 154 | " layers = param['layers'])\n", 155 | " # The fit method is called to train the model\n", 156 | " model.fit(x_train, t_train, e_train, iters = 1, learning_rate = param['learning_rate'])\n", 157 | " models.append((model.compute_nll(x_val, t_val, e_val), model))\n", 158 | "\n", 159 | "best_model = min(models, key = lambda m: m[0])\n", 160 | "model = best_model[1]" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "### Inference" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 42, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "out_risk = model.predict_risk(x_test, times)\n", 177 | "out_survival = model.predict_survival(x_test, times)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "### Evaluation\n", 185 | "\n", 186 | "We evaluate the performance of RDSM in terms of its discriminative ability (Time-Dependent Concordance Index and Cumulative Dynamic AUC) as well as its calibration (Brier Score) on the concatenated temporal data." 
187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 43, 192 | "metadata": {}, 193 | "outputs": [], 194 | "source": [ 195 | "from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc" 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 57, 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "For 0.25 quantile,\n", 208 | "TD Concordance Index: 0.5748031496062992\n", 209 | "Brier Score: 0.0040254261016212795\n", 210 | "ROC AUC 0.5770750988142292 \n", 211 | "\n", 212 | "For 0.5 quantile,\n", 213 | "TD Concordance Index: 0.8037750594183785\n", 214 | "Brier Score: 0.012524285322743573\n", 215 | "ROC AUC 0.8130810214146464 \n", 216 | "\n", 217 | "For 0.75 quantile,\n", 218 | "TD Concordance Index: 0.8507809756261016\n", 219 | "Brier Score: 0.03105328491896606\n", 220 | "ROC AUC 0.8674491502503145 \n", 221 | "\n" 222 | ] 223 | } 224 | ], 225 | "source": [ 226 | "cis = []\n", 227 | "brs = []\n", 228 | "\n", 229 | "et_train = np.array([(e_train[i][j], t_train[i][j]) for i in range(len(e_train)) for j in range(len(e_train[i]))],\n", 230 | " dtype = [('e', bool), ('t', float)])\n", 231 | "et_test = np.array([(e_test[i][j], t_test[i][j]) for i in range(len(e_test)) for j in range(len(e_test[i]))],\n", 232 | " dtype = [('e', bool), ('t', float)])\n", 233 | "et_val = np.array([(e_val[i][j], t_val[i][j]) for i in range(len(e_val)) for j in range(len(e_val[i]))],\n", 234 | " dtype = [('e', bool), ('t', float)])\n", 235 | "\n", 236 | "for i, _ in enumerate(times):\n", 237 | " cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])\n", 238 | "brs.append(brier_score(et_train, et_test, out_survival, times)[1])\n", 239 | "roc_auc = []\n", 240 | "for i, _ in enumerate(times):\n", 241 | " roc_auc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0])\n", 242 | "for i, horizon in enumerate(horizons):\n", 243 | " print(f\"For {horizon} quantile,\")\n", 244 | " print(\"TD Concordance Index:\", cis[i])\n", 245 | " print(\"Brier Score:\", brs[0][i])\n", 246 | " print(\"ROC AUC \", roc_auc[i][0], \"\\n\")" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": null, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [] 255 | } 256 | ], 257 | "metadata": { 258 | "kernelspec": { 259 | "display_name": "Python 3", 260 | "language": "python", 261 | "name": "python3" 262 | }, 263 | "language_info": { 264 | "codemirror_mode": { 265 | "name": "ipython", 266 | "version": 3 267 | }, 268 | "file_extension": ".py", 269 | "mimetype": "text/x-python", 270 | "name": "python", 271 | "nbconvert_exporter": "python", 272 | "pygments_lexer": "ipython3", 273 | "version": "3.7.9" 274 | } 275 | }, 276 | "nbformat": 4, 277 | "nbformat_minor": 4 278 | } 279 | -------------------------------------------------------------------------------- /examples/cmhe_demo_utils.py: -------------------------------------------------------------------------------- 1 | ### Utility functions to find the maximum treatment effect phenotype and mean differential survival 2 | import sys 3 | sys.path.append('../auton_survival/') 4 | 5 | import torch 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.metrics import roc_curve, roc_auc_score 9 | from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score 10 | from sksurv.util import Surv 11 | import seaborn as sns 12 | 
import matplotlib.pyplot as plt 13 | from matplotlib import cm 14 | from matplotlib.colors import ListedColormap, LinearSegmentedColormap 15 | 16 | def plot_synthetic_data(outcomes, features, interventions): 17 | import matplotlib.pyplot as plt 18 | plt.rcParams['text.latex.preamble'] = r"\usepackage{bm} \usepackage{amsmath}" 19 | fs = 48 # Font size 20 | s = 65 # Size of the marker 21 | lim = 2.25 22 | 23 | fig, (ax1, ax2) = plt.subplots(ncols=2, nrows=1, figsize=(16,8), sharey=True) 24 | color_maps = { 25 | 0: LinearSegmentedColormap.from_list("z1", colors=['black', 'C0']), 26 | 1: LinearSegmentedColormap.from_list("z1", colors=['black', 'r']), 27 | 2: LinearSegmentedColormap.from_list("z1", colors=['black', 'r']) 28 | } 29 | for cmap in color_maps: color_maps[cmap].set_gamma(0.4) 30 | 31 | # Data 32 | X1, X2, X3, X4 = features.X1.to_numpy(), features.X2.to_numpy(), features.X3.to_numpy(), features.X4.to_numpy() 33 | 34 | # First sub-plot X1 vs X2 35 | for z in set(outcomes.Z): 36 | mask = (outcomes.Z == z) 37 | sns.kdeplot(ax=ax1, x=X1[mask], y=X2[mask], 38 | fill=False, levels=10, thresh=0.3, 39 | cmap=color_maps[z]) 40 | 41 | ax1.tick_params(axis="both", labelsize=21) 42 | ax1.set_xlabel( r'$\mathcal{X}_1 \longrightarrow$', fontsize=fs) 43 | ax1.set_ylabel( r'$\mathcal{X}_2 \longrightarrow$', fontsize=fs) 44 | ax1.text(-2,0.5, s=r'$\mathcal{Z}_1$', color='C0', fontsize=fs, 45 | bbox=dict(lw=2, boxstyle="round", ec='C0', fc=(.95, .95, .95))) 46 | ax1.text(1,1.75, s=r'$\mathcal{Z}_2$', color='C2', fontsize=fs, 47 | bbox=dict(lw=2, boxstyle="round", ec='C2', fc=(.95, .95, .95))) 48 | ax1.text(1,-1.75, s=r'$\mathcal{Z}_3$', color='C3', fontsize=fs, 49 | bbox=dict(lw=2, boxstyle="round", ec='C3', fc=(.95, .95, .95))) 50 | ax1.set_xlim(-lim, lim) 51 | ax1.set_ylim(-lim, lim) 52 | 53 | # Second sub-plot X3 vs X4 54 | R = (np.abs(X3) + np.abs(X4))<=2 55 | ax2.scatter(X3[R], X4[R], s=s, c='white', marker='X',ec='C0') 56 | ax2.scatter(X3[~R], X4[~R], s=s, c='white', marker='o', ec='C3') 57 | 58 | grid = np.meshgrid([2, 1, 0, -1, -2], [2, 1, 0, -1, -2]) 59 | ax1.scatter(grid[0].ravel(), grid[1].ravel(), color='grey', marker='+', zorder=-500, s=50) 60 | 61 | ax2.set_xlabel(r'$\mathcal{X}_3 \longrightarrow$', fontsize=fs) 62 | ax2.set_ylabel(r'$\mathcal{X}_4 \longrightarrow$', fontsize=fs) 63 | ax2.text(-1.25,.25, s=r'$\phi_1$', color='C0', fontsize=fs, 64 | bbox=dict( lw=2, boxstyle="round", ec='C0', fc=(.95, .95, .95))) 65 | ax2.text(1,-1.75, s=r'$\phi_2$', color='C3', fontsize=fs, 66 | bbox=dict( lw=2, boxstyle="round", ec='C3', fc=(.95, .95, .95))) 67 | ax2.tick_params(axis="both", labelsize=21) 68 | ax2.set_xlim(-lim, lim) 69 | ax2.set_ylim(-lim, lim) 70 | 71 | plt.show() 72 | 73 | def factual_evaluate(train_data, test_data, horizons, predictions): 74 | """ 75 | Function to evaluate the Concordance Indices and Integrated Brier Score 76 | """ 77 | y_train = Surv.from_arrays(train_data[2], train_data[1]) 78 | y_test = Surv.from_arrays(test_data[2], test_data[1]) 79 | 80 | y_train_t_max = np.max([row[1] for row in y_train]) 81 | y_test_t_vidx = (np.array([row[1] for row in y_test]) < y_train_t_max) 82 | 83 | y_test = y_test[y_test_t_vidx] 84 | predictions = predictions[y_test_t_vidx] 85 | 86 | # Time-dependent concordance index at each horizon 87 | cis = [] 88 | for i, horizon in enumerate(horizons): 89 | cis.append(concordance_index_ipcw(y_train, y_test, 1-predictions[:, i], horizon)[0]) 90 | 91 | # Integrated Brier Score over the evaluation horizons 92 | ibs = integrated_brier_score(y_train, y_test, predictions, horizons) 93 | 94 | return cis, ibs 95 | 96 | def find_max_treatment_effect_phenotype(g, zeta_probs, factual_outcomes): 97 | """ 98 | Find the group with the maximum treatment effect phenotype 99 | """ 100 | mean_differential_survival = np.zeros(zeta_probs.shape[1]) # Area under treatment term 101 | outcomes_train, interventions_train = factual_outcomes 102 | 103 | for gr in range(g): # For each treatment phenotype group 104 | # Probability of belonging to the gr-th treatment phenotype 105 | zeta_probs_g = zeta_probs[:, gr] 106 | 107 | # Consider only individuals in the top quartile of phenotype membership 108 | z_mask = zeta_probs_g > np.quantile(zeta_probs_g, 0.75) 109 | 110 | mean_differential_survival[gr] = find_mean_differential_survival( 111 | outcomes_train.loc[z_mask], interventions_train.loc[z_mask]) 112 | 113 | return np.nanargmax(mean_differential_survival) 114 | 115 | def find_mean_differential_survival(outcomes, interventions): 116 | """ 117 | Given outcomes and interventions, compute the mean differential survival time, i.e. the difference in restricted mean survival between the treated and control arms 118 | """ 119 | from lifelines import KaplanMeierFitter 120 | 121 | treated_km = KaplanMeierFitter().fit(outcomes['uncensored time treated'].values, np.ones(len(outcomes)).astype(bool)) 122 | control_km = KaplanMeierFitter().fit(outcomes['uncensored time control'].values, np.ones(len(outcomes)).astype(bool)) 123 | 124 | unique_times = treated_km.survival_function_.index.values.tolist() + control_km.survival_function_.index.values.tolist() 125 | unique_times = np.unique(unique_times) 126 | 127 | treated_km = treated_km.predict(unique_times, interpolate=True) 128 | control_km = control_km.predict(unique_times, interpolate=True) 129 | 130 | mean_differential_survival = np.trapz(y=(treated_km.values - control_km.values), 131 | x=unique_times) 132 | 133 | return mean_differential_survival 134 | 135 | def plot_phenotypes_roc(outcomes, zeta_probs): 136 | from matplotlib import pyplot as plt 137 | plt.rcParams['text.latex.preamble'] = r"\usepackage{bm} \usepackage{amsmath}" 138 | 139 | zeta = outcomes['Zeta'] 140 | 141 | y_true = zeta == 0 142 | 143 | fpr, tpr, thresholds = roc_curve(y_true, zeta_probs) 144 | auc = roc_auc_score(y_true, zeta_probs) 145 | 146 | plt.figure(figsize=(5,5)) 147 | 148 | plt.plot(fpr, tpr, label="AUC: "+str(round(auc, 3)), c='darkblue') 149 | plt.plot(np.linspace(0,1,100), np.linspace(0,1,100), ls='--', color='k') 150 | 151 | plt.xticks(fontsize=24) 152 | plt.yticks(fontsize=24) 153 | plt.legend(fontsize=24, loc='upper left') 154 | 155 | plt.xlabel('FPR', fontsize=36) 156 | plt.ylabel('TPR', fontsize=36) 157 | plt.xscale('log') 158 | plt.show() -------------------------------------------------------------------------------- /examples/estimators_demo_utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pylab as plt 2 | import matplotlib.gridspec as gridspec 3 | 4 | def plot_performance_metrics(results, times): 5 | """Plot Brier Score, ROC-AUC, and time-dependent concordance index 6 | for survival model evaluation. 7 | 8 | Parameters 9 | ----------- 10 | results : dict 11 | Python dict with key as the evaluation metric 12 | times : float or list 13 | A float or list of the times at which to compute 14 | the survival probability. 
15 | 16 | Returns 17 | ----------- 18 | matplotlib subplots 19 | 20 | """ 21 | 22 | colors = ['blue', 'purple', 'orange', 'green'] 23 | gs = gridspec.GridSpec(1, len(results), wspace=0.3) 24 | 25 | for fi, result in enumerate(results.keys()): 26 | val = results[result] 27 | x = [str(round(t, 1)) for t in times] 28 | ax = plt.subplot(gs[0, fi]) # row 0, col fi 29 | ax.set_xlabel('Time') 30 | ax.set_ylabel(result) 31 | ax.set_ylim(0, 1) 32 | ax.bar(x, val, color=colors[fi]) 33 | plt.xticks(rotation=30) 34 | plt.show() 35 | -------------------------------------------------------------------------------- /examples/matplotlibrc: -------------------------------------------------------------------------------- 1 | ### Font 2 | font.family : Serif 3 | font.size : 16.0 4 | 5 | ### Lines 6 | lines.linewidth : 0.6 7 | lines.antialiased : True 8 | 9 | ### Axes settings 10 | axes.facecolor : fafafa 11 | axes.edgecolor : black 12 | axes.linewidth : 0.6 13 | axes.labelsize : 12.0 14 | axes.axisbelow : True 15 | 16 | ### Ticks 17 | xtick.major.size : 5 # major tick size in points 18 | xtick.color : black # color of the tick labels 19 | xtick.labelsize : 10.0 # fontsize of the tick labels 20 | xtick.direction : out # direction: in or out 21 | ytick.major.size : 5 # major tick size in points 22 | ytick.color : black # color of the tick labels 23 | ytick.labelsize : 10.0 # fontsize of the tick labels 24 | ytick.direction : out # direction: in or out 25 | 26 | ### Grid settings 27 | axes.grid : True 28 | grid.alpha : 0.4 29 | grid.linewidth : 0.5 30 | 31 | ### Legend 32 | legend.fancybox : True 33 | legend.fontsize : 10.0 34 | legend.facecolor : fdfdfd 35 | 36 | ### Figure 37 | figure.figsize : 10.0, 4.0 38 | figure.facecolor : white 39 | figure.edgecolor : black 40 | 41 | ### Bar plots 42 | hatch.linewidth : 0.1 43 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "auton-survival" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Chirag Nagpal "] 6 | readme = "README.md" 7 | packages = [{include = "auton_survival"}] 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.8" 11 | torch = "^1.13" 12 | numpy = "^1.24" 13 | pandas = "^1.5" 14 | tqdm = "^4.66" 15 | scikit-learn = "^1.2" 16 | torchvision = "^0.14" 17 | scikit-survival = "^0.21" 18 | lifelines = "^0.26" 19 | 20 | [tool.semantic_release] 21 | version_toml = [ 22 | "pyproject.toml:tool.poetry.version" 23 | ] 24 | branch = "master" 25 | commit_author = "github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>" 26 | upload_to_PyPI = false 27 | upload_to_release = true 28 | build_command = "pip install poetry && poetry build" 29 | 30 | [tool.semantic_release.commit_parser_options] 31 | allowed_tags = [ 32 | "build", 33 | "chore", 34 | "ci", 35 | "docs", 36 | "feat", 37 | "fix", 38 | "perf", 39 | "style", 40 | "refactor", 41 | "test" 42 | ] 43 | minor_tags = ["feat"] 44 | patch_tags = ["fix", "perf", "refactor", "build", "style"] 45 | 46 | [build-system] 47 | requires = [ 48 | "poetry-core" 49 | ] 50 | build-backend = "poetry.core.masonry.api" 51 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autonlab/auton-survival/5dde465f7223601717abddc1d075e837707c403b/tests/__init__.py
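A usage sketch for plot_performance_metrics from examples/estimators_demo_utils.py above. This is an editor's illustration rather than a file in this repo: it assumes the `times`, `cis`, `brs` and `roc_auc` variables computed in the notebooks' evaluation cells, and that it is run from the examples directory so the import resolves.

from estimators_demo_utils import plot_performance_metrics

# One value per evaluation horizon for each metric
results = {'Brier Score': brs[0],
           'Concordance Index': cis,
           'ROC AUC': [auc[0] for auc in roc_auc]}
plot_performance_metrics(results, times)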
-------------------------------------------------------------------------------- /tests/test_dsm.py: -------------------------------------------------------------------------------- 1 | """This module contains test functions to 2 | test the accuracy of Deep Survival Machines 3 | models on certain standard datasets. 4 | """ 5 | import unittest 6 | 7 | from auton_survival.models.dsm import DeepSurvivalMachines 8 | from auton_survival.models.dsm.dsm_torch import DeepSurvivalMachinesTorch 9 | from auton_survival.models.dsm import datasets 10 | 11 | import numpy as np 12 | 13 | class TestDSM(unittest.TestCase): 14 | """Base Class for all test functions""" 15 | def test_support_dataset(self): 16 | """Test function to load and test the SUPPORT dataset. 17 | """ 18 | 19 | x, t, e = datasets.load_dataset('SUPPORT') 20 | t_median = np.median(t[e==1]) 21 | 22 | self.assertIsInstance(x, np.ndarray) 23 | self.assertIsInstance(t, np.ndarray) 24 | self.assertIsInstance(e, np.ndarray) 25 | 26 | self.assertEqual(x.shape, (9105, 44)) 27 | self.assertEqual(t.shape, (9105,)) 28 | self.assertEqual(e.shape, (9105,)) 29 | 30 | model = DeepSurvivalMachines() 31 | self.assertIsInstance(model, DeepSurvivalMachines) 32 | model.fit(x, t, e, iters=10) 33 | self.assertIsInstance(model.torch_model, 34 | DeepSurvivalMachinesTorch) 35 | risk_score = model.predict_risk(x, t_median) 36 | survival_probability = model.predict_survival(x, t_median) 37 | np.testing.assert_allclose(risk_score + survival_probability, 1.0) 38 | 39 | def test_pbc_dataset(self): 40 | """Test function to load and test the PBC dataset. 41 | """ 42 | 43 | x, t, e = datasets.load_dataset('PBC') 44 | t_median = np.median(t[e==1]) 45 | 46 | self.assertIsInstance(x, np.ndarray) 47 | self.assertIsInstance(t, np.ndarray) 48 | self.assertIsInstance(e, np.ndarray) 49 | 50 | self.assertEqual(x.shape, (1945, 25)) 51 | self.assertEqual(t.shape, (1945,)) 52 | self.assertEqual(e.shape, (1945,)) 53 | 54 | model = DeepSurvivalMachines() 55 | self.assertIsInstance(model, DeepSurvivalMachines) 56 | model.fit(x, t, e, iters=10) 57 | self.assertIsInstance(model.torch_model, 58 | DeepSurvivalMachinesTorch) 59 | risk_score = model.predict_risk(x, t_median) 60 | survival_probability = model.predict_survival(x, t_median) 61 | np.testing.assert_allclose(risk_score + survival_probability, 1.0) 62 | 63 | def test_framingham_dataset(self): 64 | """Test function to load and test the Framingham dataset. 65 | """ 66 | x, t, e = datasets.load_dataset('FRAMINGHAM') 67 | t_median = np.median(t) 68 | 69 | self.assertIsInstance(x, np.ndarray) 70 | self.assertIsInstance(t, np.ndarray) 71 | self.assertIsInstance(e, np.ndarray) 72 | 73 | self.assertEqual(x.shape, (11627, 18)) 74 | self.assertEqual(t.shape, (11627,)) 75 | self.assertEqual(e.shape, (11627,)) 76 | 77 | model = DeepSurvivalMachines() 78 | self.assertIsInstance(model, DeepSurvivalMachines) 79 | model.fit(x, t, e, iters=10) 80 | self.assertIsInstance(model.torch_model, 81 | DeepSurvivalMachinesTorch) 82 | risk_score = model.predict_risk(x, t_median) 83 | survival_probability = model.predict_survival(x, t_median) 84 | np.testing.assert_allclose(risk_score + survival_probability, 1.0) 85 | --------------------------------------------------------------------------------
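The tests above pin down the contract the notebooks rely on: predict_risk returns the complement of predict_survival at a given time. The sketch below strings the pieces together end to end. It is an editor's illustration rather than a repo file; it uses only APIs that appear above, and the particular hyper-parameter values (k=3, Weibull, one hidden layer of 100 units) are arbitrary choices from the notebook grids.

import numpy as np
from auton_survival.models.dsm import DeepSurvivalMachines, datasets

# x: covariates, t: event/censoring times, e: event indicators
x, t, e = datasets.load_dataset('SUPPORT')

model = DeepSurvivalMachines(k=3, distribution='Weibull', layers=[100])
model.fit(x, t, e, iters=10, learning_rate=1e-3)

# Risk and survival probability at the median uncensored event time sum to one
t_median = np.median(t[e == 1])
risk = model.predict_risk(x, t_median)
survival = model.predict_survival(x, t_median)
np.testing.assert_allclose(risk + survival, 1.0)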