├── .gitignore ├── LICENSE ├── README.md ├── data └── Brain_Integ.mat ├── examples ├── Bayesian_Optimization.py ├── CostFunction.py ├── ModelAnalysis.py └── Run.py ├── requirements.txt ├── results ├── Gradients.gct ├── Gradients.rnk ├── Heatmap.pdf ├── KM.10p_CNVArm.pdf ├── KM.10q_CNVArm.pdf ├── KM.CDKN2A_CNV.pdf ├── KM.IDH1_Mut.pdf ├── KM.IDH2_Mut.pdf ├── KM.IGFBP2_Protein.pdf ├── KM.PTEN_CNV.pdf ├── KM.SMARCA4_Mut.pdf ├── KM.age_at_initial_pathologic_diagnosis_Clinical.pdf ├── KM.histological_type-Is-untreated primary (de novo) gbm_Clinical.pdf ├── PairedScatter.Feature.pdf ├── PairedScatter.Gradient.pdf ├── RankedBox.pdf ├── c_index_list.mat └── glioma_integ_model ├── setup.py └── survivalnet ├── __init__.py ├── analysis ├── FeatureAnalysis.py ├── PathwayAnalysis.py ├── ReadGMT.py ├── RiskCluster.py ├── RiskCohort.py ├── Visualization.py ├── WriteGCT.py ├── WriteRNK.py └── __init__.py ├── model ├── DropoutHiddenLayer.py ├── HiddenLayer.py ├── Model.py ├── RiskLayer.py ├── SparseDenoisingAutoencoder.py └── __init__.py ├── optimization ├── BFGS.py ├── EarlyStopping.py ├── GDLS.py ├── Optimization.py ├── SurvivalAnalysis.py └── __init__.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.out 2 | *.pyc 3 | *.lprof 4 | data/ 5 | results/ 6 | output/ 7 | build 8 | dist 9 | *.egg-info 10 | *.out 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 | 
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 | 
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 | 
176 | END OF TERMS AND CONDITIONS
177 | 
178 | 
179 | Copyright 2017 Emory University
180 | 
181 | Licensed under the Apache License, Version 2.0 (the "License");
182 | you may not use this file except in compliance with the License.
183 | You may obtain a copy of the License at
184 | 
185 | http://www.apache.org/licenses/LICENSE-2.0
186 | 
187 | Unless required by applicable law or agreed to in writing, software
188 | distributed under the License is distributed on an "AS IS" BASIS,
189 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
190 | See the License for the specific language governing permissions and
191 | limitations under the License.
192 | --------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SurvivalNet
2 | SurvivalNet is a package for building survival analysis models using deep learning. The SurvivalNet package has the following features:
3 | 
4 | * Training deep networks for time-to-event data using Cox partial likelihood
5 | * Automatic tuning of network architecture and learning hyper-parameters with Bayesian Optimization
6 | * Interpretation of trained networks using partial derivatives
7 | * Layer-wise unsupervised pre-training
8 | 
9 | A [short paper [1]](https://arxiv.org/abs/1609.08663) describing our approach of using Cox partial likelihood, presented at ICLR in May 2016, is available on arXiv. A [longer paper [2]](https://www.nature.com/articles/s41598-017-11817-6) describing the package and its applications was later published in Nature Scientific Reports.
10 | 
11 | ## References:
12 | [[1] Yousefi, Safoora, et al. "Learning Genomic Representations to Predict Clinical Outcomes in Cancer." arXiv preprint arXiv:1609.08663, May 2016.](https://arxiv.org/abs/1609.08663)
13 | 
14 | [[2] Yousefi, Safoora, et al. "Predicting clinical outcomes from large scale cancer genomic profiles with deep survival models." Nature Scientific Reports 7, Article number: 11707 (2017) doi:10.1038/s41598-017-11817-6](https://www.nature.com/articles/s41598-017-11817-6)
15 | 
16 | # Getting Started
17 | The **examples** folder provides scripts to:
18 | 
19 | * Train a neural network on your dataset using Bayesian Optimization (Run.py)
20 | * Set parameters for Bayesian Optimization (Bayesian_Optimization.py)
21 | * Define a cost function for use by Bayesian Optimization (CostFunction.py)
22 | * Interpret a trained model and analyze feature importance (ModelAnalysis.py)
23 | 
24 | Run.py demonstrates how you can provide the input to the train.py module. To get started, you need the following three numpy arrays:
25 | 
26 | * X: input data of size (number of patients, number of features). Patients must be sorted with respect to event or censoring times 'T'.
27 | * T: Time of event or time to last follow-up, appearing in increasing order and corresponding to the rows of 'X'. size: (number of patients, ).
28 | * O: Right-censoring status. A value of 1 means the event is observed (i.e. death or disease progression); a value of 0 indicates that the sample is censored. size: (number of patients, ).
29 | 
30 | After splitting the data into train, validation and test sets, feed the corresponding arrays to 'SurvivalAnalysis.calc\_at\_risk' to get the data that can be used to train the network.
31 | ```python
32 | train_set['X'], train_set['T'], train_set['O'], train_set['A'] = sa.calc_at_risk(X_train, T_train, O_train)
33 | test_set['X'], test_set['T'], test_set['O'], test_set['A'] = sa.calc_at_risk(X_test, T_test, O_test)
34 | ```
35 | The resulting dictionaries 'train\_set' and 'test\_set' can be directly fed to train.py.
36 | 
37 | The provided example scripts read data in .mat format. You can, however, convert your data from any format to numpy arrays and follow the above procedure to prepare it for the SurvivalNet package.
38 | 
39 | ## Installation Guide for Docker Image
40 | 
41 | A Docker image for SurvivalNet is provided for those who prefer not to build from source. This image contains an installation of SurvivalNet on a bare Ubuntu operating system, along with the sample data used in our *bioRxiv* paper. This helps users avoid installing the */bayesopt/* package and the other dependencies required by SurvivalNet.
42 | 
43 | The SurvivalNet Docker Image can either be downloaded [here](https://hub.docker.com/r/cancerdatascience/snet/), or can be pulled from Docker hub using the following command:
44 | 
45 | 	sudo docker pull cancerdatascience/snet:version1
46 | 
47 | Running this image on your local machine with the command
48 | 
49 | 	sudo docker run -it cancerdatascience/snet:version1 /bin/bash
50 | 
51 | launches a terminal within the image where users have access to the package installation.
52 | 
53 | Example python scripts used in generating our results for the full-length paper can be found in the folder
54 | 
55 | 	cd /SurvivalNet/examples/
56 | 
57 | These scripts provide examples of training and validating deep survival models. The main script
58 | 
59 | 	python Run.py
60 | 
61 | will perform Bayesian optimization to identify the optimal deep survival model configuration, printing step-by-step updates on the learning process to the terminal.
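Run.py also accepts command-line arguments, defined with argparse at the bottom of the script, for pointing it at your own data and output locations. For example, to enable Bayesian optimization on the bundled dataset:

	python Run.py --input_path ./data/Brain_Integ.mat --output_path ./results --bayes_opt --feature_key Integ_X --epochs 40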
62 | 
63 | The sample data file - ***Brain_Integ.mat*** is located inside the */SurvivalNet/data/* folder. By default, ***Run.py*** uses this data for learning.
64 | 
65 | 
66 | ### Using your own data to train networks
67 | 
68 | You can train a network using your own data by mounting a folder within the SurvivalNet Docker image. The command
69 | 
70 | 	sudo docker run -v /<hostmachine_data_path>/:/<container_data_path>/ -it cancerdatascience/snet:version1 /bin/bash
71 | 
72 | will pull and run the Docker image, and mount *hostmachine_data_path* on the host machine inside the container at *container_data_path*. Any files placed into the mounted folder on the host machine will appear in *container_data_path* on the image. Setting *container_data_path* to */SurvivalNet/data/* will place the image mount in the SurvivalNet data folder.
73 | 
74 | --------------------------------------------------------------------------------
/data/Brain_Integ.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/data/Brain_Integ.mat
--------------------------------------------------------------------------------
/examples/Bayesian_Optimization.py:
--------------------------------------------------------------------------------
1 | import bayesopt
2 | import numpy as np
3 | from time import clock
4 | from CostFunction import cost_func
5 | 
6 | def tune():
7 |     """Tunes hyperparameters of a feed-forward net using Bayesian Optimization.
8 | 
9 |     Returns:
10 |         mvalue: float. Best value of the cost function found using BayesOpt.
11 |         x_out: 1D array. Best hyper-parameters found.
12 |     """
13 |     params = {}
14 |     params['n_iterations'] = 50
15 |     params['n_iter_relearn'] = 1
16 |     params['n_init_samples'] = 2
17 | 
18 |     print "*** Model Selection with BayesOpt ***"
19 |     n = 6  # number of hyper-parameter dimensions
20 |     # params: #layers, width, dropout, nonlinearity, l1_rate, l2_rate
21 |     lb = np.array([1 , 10 , 0., 0., 0., 0.])
22 |     ub = np.array([10, 500, 1., 1., 0., 0.])
23 | 
24 |     start = clock()
25 |     mvalue, x_out, _ = bayesopt.optimize(cost_func, n, lb, ub, params)
26 | 
27 |     # Usage of BayesOpt with a discrete set of values for hyper-parameters:
28 | 
29 |     #layers = [1, 3, 5, 7, 9, 10]
30 |     #hsizes = [10, 50, 100, 150, 200, 300]
31 |     #drates = [0.0, .1, .3, .5, .7, .9]
32 |     #x_set = np.array([layers, hsizes, drates], dtype=float).transpose()
33 |     #mvalue, x_out, _ = bayesopt.optimize_discrete(cost_func, x_set, params)
34 | 
35 |     print "Result", mvalue, "at", x_out
36 |     print "Running time:", clock() - start, "seconds"
37 |     return mvalue, x_out
38 | 
39 | 
40 | if __name__ == '__main__':
41 |     tune()
42 | --------------------------------------------------------------------------------
/examples/CostFunction.py:
--------------------------------------------------------------------------------
1 | import os
2 | from survivalnet.train import train
3 | import numpy as np
4 | import scipy.io as sio
5 | from survivalnet.optimization import SurvivalAnalysis
6 | import theano
7 | import cPickle
8 | 
9 | LEARNING_RATE = 0.001
10 | EPOCHS = 40
11 | OPTIM = 'GDLS'
12 | 
13 | 
14 | def cost_func(params):
15 |     n_layers = int(params[0])
16 |     n_hidden = int(params[1])
17 |     do_rate = params[2]
18 |     nonlin = theano.tensor.nnet.relu if params[3] > .5 else np.tanh
19 |     lambda1 = params[4]
20 |     lambda2 = params[5]
21 | 
22 |     # Loads data sets saved by the Run.py module.
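    # Each pickled set is a dict with keys 'X', 'T', 'O', and 'A', as
    # written by Run.py using SurvivalAnalysis.calc_at_risk.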
23 |     with open('train_set', 'rb') as f:
24 |         train_set = cPickle.load(f)
25 |     with open('val_set', 'rb') as f:
26 |         val_set = cPickle.load(f)
27 | 
28 |     pretrain_config = None  # No pre-training.
29 |     pretrain_set = None
30 | 
31 |     finetune_config = {'ft_lr':LEARNING_RATE, 'ft_epochs':EPOCHS}
32 | 
33 |     # Prints experiment identifier.
34 |     print('nl{}-hs{}-dor{}_nonlin{}'.format(str(n_layers), str(n_hidden),
35 |                                             str(do_rate), str(nonlin)))
36 | 
37 |     _, _, val_costs, val_cindices, _, _, _, maxIter = train(pretrain_set,
38 |         train_set, val_set, pretrain_config, finetune_config, n_layers,
39 |         n_hidden, dropout_rate=do_rate, lambda1=lambda1, lambda2=lambda2,
40 |         non_lin=nonlin, optim=OPTIM, verbose=False, earlystp=False)
41 | 
42 |     if not val_costs or np.isnan(val_costs[-1]):
43 |         print 'Skipping due to NaN'
44 |         return 1
45 | 
46 |     return (1 - val_cindices[maxIter])
47 | 
48 | if __name__ == '__main__':
49 |     res = cost_func([1.0, 38.0, 0.3, 0.4, 0.00004, 0.00004])
50 |     print res
51 | --------------------------------------------------------------------------------
/examples/ModelAnalysis.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import scipy.io as sio
3 | import survivalnet as sn
4 | 
5 | # Integrated models.
6 | # Defines model/dataset pairs.
7 | ModelPaths = ['results/']
8 | Models = ['final_model']
9 | Data = ['data/Brain_Integ.mat']
10 | 
11 | # Loads datasets and performs feature analysis.
12 | for i, Path in enumerate(ModelPaths):
13 | 
14 |     # Loads normalized data.
15 |     X = sio.loadmat(Data[i])
16 | 
17 |     # Extracts relevant values.
18 |     Samples = X['Patients']
19 |     Normalized = X['Integ_X'].astype('float32')
20 |     Raw = X['Integ_X_raw'].astype('float32')
21 |     Symbols = X['Integ_Symbs']
22 |     Survival = X['Survival']
23 |     Censored = X['Censored']
24 | 
25 |     # Loads model.
26 |     f = open(Path + Models[i], 'rb')
27 |     Model = pickle.load(f)
28 |     f.close()
29 | 
30 |     sn.analysis.FeatureAnalysis(Model, Normalized, Raw, Symbols,
31 |                                 Survival, Censored,
32 |                                 Tau=5e-2, Path=Path)
33 | --------------------------------------------------------------------------------
/examples/Run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import Bayesian_Optimization as BayesOpt
4 | import os
5 | import scipy.io as sio
6 | from survivalnet.optimization import SurvivalAnalysis
7 | import numpy as np
8 | from survivalnet.train import train
9 | import theano
10 | import cPickle
11 | 
12 | N_SHUFFLES = 20
13 | 
14 | def Run(input_path, output_path, do_bayes_opt, feature_key, epochs):
15 |     """Runs the model selection and assessment of survivalnet.
16 | 
17 |     Arguments:
18 |         input_path: str. Path to dataset. The input dataset in this script is
19 |             expected to be a .mat file containing 'Survival' and 'Censored'
20 |             keys in addition to the feature_key.
21 |         output_path: str. Path to save the model and results.
22 |         do_bayes_opt: bool. Whether to do Bayesian optimization of hyperparams.
23 |         feature_key: str. Key to the input data in the .mat file.
24 |         epochs: int. Number of training epochs.
25 |     """
26 |     if not os.path.exists(output_path):
27 |         os.makedirs(output_path)
28 |     # Loading dataset. The model requires an n x p matrix of input data, an
29 |     # n x 1 array of time-to-event labels, and an n x 1 array of censoring status.
30 |     D = sio.loadmat(input_path)
31 |     T = np.asarray([t[0] for t in D['Survival']]).astype('float32')
32 |     # C is censoring status where 1 means incomplete follow-up. We change it to
33 |     # observed status where 1 means death.
34 |     O = 1 - np.asarray([c[0] for c in D['Censored']]).astype('int32')
35 |     X = D[feature_key].astype('float32')
36 | 
37 |     # Optimization algorithm.
38 |     opt = 'GDLS'
39 | 
40 |     # Pretraining settings.
41 |     # pretrain_config = {'pt_lr':0.01, 'pt_epochs':1000,
42 |     #                    'pt_batchsize':None,'corruption_level':.3}
43 |     pretrain_config = None  # No pre-training.
44 | 
45 |     # The results in the paper are averaged over 20 random assignments of
46 |     # samples to training/validation/testing sets.
47 |     cindex_results = []
48 |     avg_cost = 0
49 |     for i in range(N_SHUFFLES):
50 |         # Sets random generator seed for reproducibility.
51 |         prng = np.random.RandomState(i)
52 |         order = prng.permutation(np.arange(len(X)))
53 |         X = X[order]
54 |         O = O[order]
55 |         T = T[order]
56 | 
57 |         # Uses the entire dataset for pretraining.
58 |         pretrain_set = X
59 | 
60 |         # 'fold_size' denotes the number of samples used for testing. The same
61 |         # number of samples is used for model selection.
62 |         fold_size = int(20 * len(X) / 100)  # 20% of the dataset.
63 |         train_set = {}
64 |         test_set = {}
65 |         val_set = {}
66 | 
67 |         # Calculates the risk group for every patient i: patients whose time
68 |         # of death is greater than that of patient i.
69 |         sa = SurvivalAnalysis()
70 |         train_set['X'], train_set['T'], train_set['O'], train_set['A'] = sa.calc_at_risk(
71 |             X[2*fold_size:],
72 |             T[2*fold_size:],
73 |             O[2*fold_size:])
74 |         test_set['X'], test_set['T'], test_set['O'], test_set['A'] = sa.calc_at_risk(
75 |             X[:fold_size],
76 |             T[:fold_size],
77 |             O[:fold_size])
78 |         val_set['X'], val_set['T'], val_set['O'], val_set['A'] = sa.calc_at_risk(
79 |             X[fold_size:2*fold_size],
80 |             T[fold_size:2*fold_size],
81 |             O[fold_size:2*fold_size])
82 | 
83 |         # Writes data sets for the bayesopt cost function's use.
84 |         with file('train_set', 'wb') as f:
85 |             cPickle.dump(train_set, f, protocol=cPickle.HIGHEST_PROTOCOL)
86 |         with file('val_set', 'wb') as f:
87 |             cPickle.dump(val_set, f, protocol=cPickle.HIGHEST_PROTOCOL)
88 | 
89 |         if do_bayes_opt:
90 |             print '***Model Selection with BayesOpt for shuffle', str(i), '***'
91 |             _, bo_params = BayesOpt.tune()
92 |             n_layers = int(bo_params[0])
93 |             n_hidden = int(bo_params[1])
94 |             do_rate = bo_params[2]
95 |             nonlin = theano.tensor.nnet.relu if bo_params[3] > .5 else np.tanh
96 |             lambda1 = bo_params[4]
97 |             lambda2 = bo_params[5]
98 |         else:
99 |             n_layers = 1
100 |             n_hidden = 100
101 |             do_rate = 0.5
102 |             lambda1 = 0
103 |             lambda2 = 0
104 |             nonlin = np.tanh  # or nonlin = theano.tensor.nnet.relu
105 | 
106 |         # Prints experiment identifier.
107 |         expID = 'nl{}-hs{}-dor{}_nonlin{}_id{}'.format(
108 |             str(n_layers), str(n_hidden), str(do_rate), str(nonlin), str(i))
109 | 
110 |         finetune_config = {'ft_lr':0.0001, 'ft_epochs':epochs}
111 | 
112 |         print '*** Model Assessment ***'
113 |         _, train_cindices, _, test_cindices, _, _, model, _ = train(pretrain_set,
114 |             train_set, test_set, pretrain_config, finetune_config, n_layers,
115 |             n_hidden, dropout_rate=do_rate, lambda1=lambda1, lambda2=lambda2,
116 |             non_lin=nonlin, optim=opt, verbose=True, earlystp=False)
117 |         cindex_results.append(test_cindices[-1])
118 |         avg_cost += test_cindices[-1]
119 |         print expID, ' ', test_cindices[-1], 'average = ', avg_cost/(i+1)
120 |         print np.mean(cindex_results), np.std(cindex_results)
121 |         with file(os.path.join(output_path, 'final_model'), 'wb') as f:
122 |             cPickle.dump(model, f, protocol=cPickle.HIGHEST_PROTOCOL)
123 | 
124 |     outputFileName = os.path.join(output_path, 'c_index_list.mat')
125 |     sio.savemat(outputFileName, {'c_index':cindex_results})
126 | 
127 | 
128 | if __name__ == '__main__':
129 |     parser = argparse.ArgumentParser(prog='Run',
130 |         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
131 |         description='Script to train survival net')
132 |     parser.add_argument('-ip', '--input_path', dest='input_path',
133 |         default='./data/Brain_Integ.mat',
134 |         help='Path specifying location of dataset.')
135 |     parser.add_argument('-sp', '--output_path', dest='output_path',
136 |         default='./results',
137 |         help='Path specifying where to save output files.')
138 |     parser.add_argument('-bo', '--bayes_opt', dest='do_bayes_opt',
139 |         default=False, action='store_true',
140 |         help='Pass this flag if you want to do Bayesian Optimization.')
141 |     parser.add_argument('-key', '--feature_key', dest='feature_key',
142 |         default='Integ_X',
143 |         help='Name of input features in the .mat file.')
144 |     parser.add_argument('-i', '--epochs', dest='epochs', default=40, type=int,
145 |         help='Number of training epochs.')
146 |     args = parser.parse_args()
147 |     Run(args.input_path, args.output_path, args.do_bayes_opt, args.feature_key,
148 |         args.epochs)
149 | --------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==0.18.0
2 | setuptools==34.0.2
3 | numpy==1.12.0
4 | statsmodels==0.8.0
5 | matplotlib==1.4.3
6 | lifelines==0.8.0.0
7 | Theano==0.8.2
8 | bayesopt==0.3
9 | --------------------------------------------------------------------------------
/results/Gradients.rnk:
--------------------------------------------------------------------------------
1 | CDKN2A.2 -0.187542485821
2 | 10q -0.185430899186
3 | 10p -0.166688611442
4 | SMARCA4 -0.11768246575
5 | PTEN.2 -0.10810855071
6 | IDH1 -0.100705719788
7 | IDH2 -0.0984232785396
8 | KLF6 -0.086117416591
9 | IRF4 -0.0835822536739
10 | CIC -0.0835243607371
11 | 9p -0.0803702131304
12 | 14q -0.0753509631238
13 | NUTM1 -0.0695788314593
14 | CTNNB1 -0.0693273755484
15 | RB1.2 -0.0681229900515
16 | 11q -0.0675465394984
17 | CREB3L1 -0.0664043675196
18 | 4p -0.0659575107345
19 | TP53.1 -0.0618574962164
20 | RPTOR -0.0614917598652
21 | MYO5A -0.0587812742309
22 | SETD2.2 -0.0569695102164
23 | MYCN -0.056823080643
24 | GOPC -0.0547447442715
25 | RAF1.2 -0.0546085867224
26 | KCNJ15 -0.0536722137376
27 | FGFR3 -0.0532882563025
28 | 6p -0.0529441312198
29 | 15q -0.0527252581028
30 | histological_type-Is-oligoastrocytoma -0.0526956015903
31 | IRS4 -0.0519197211356
32 | 22q -0.0513859475089
33 | CHEK2.2 -0.049889048911 34 | NRAS.1 -0.0497220682987 35 | FMR1 -0.0492033814856 36 | NF2 -0.0488065054602 37 | OR52M1 -0.0483595401527 38 | ROS1 -0.0478521372002 39 | ATRX.2 -0.0470775318662 40 | PREX1 -0.0467222797643 41 | 13q -0.0456454082788 42 | AR -0.045362837292 43 | MAPK14 -0.0453341007901 44 | CD44 -0.0444032720664 45 | 11p -0.0441529614086 46 | ARID1A -0.0425594724317 47 | PDCD4 -0.0413781156097 48 | CASP7 -0.0412659738841 49 | MYC.1 -0.0410835516587 50 | PTEN.3 -0.0409528340002 51 | CHEK1.1 -0.0405558132493 52 | ETS1 -0.0399950105148 53 | NOTCH1.2 -0.0396912692841 54 | DNMT3A -0.0393115107931 55 | EGFR.2 -0.0386975792952 56 | FAM47C -0.0381907557447 57 | MYH8 -0.0372833328582 58 | MTOR.2 -0.037044432612 59 | TP53.2 -0.0366460306276 60 | CLDN7 -0.0362334963449 61 | GATA3 -0.035397765144 62 | BCL2L11 -0.035042378119 63 | NIPBL -0.0338433892768 64 | PRCC -0.0337502411342 65 | CDH2 -0.033484015434 66 | CNOT1 -0.0325454416173 67 | IRS1 -0.0322433320308 68 | STAT3 -0.0321522847668 69 | ACACA -0.0315812702064 70 | DOCK5 -0.0310752703951 71 | PRKAA1.2 -0.0307854569949 72 | DVL3 -0.0307127801302 73 | PIK3R1 PIK3R2 -0.0304800954875 74 | NKD2 -0.0303698467215 75 | ANXA7 -0.0302613514543 76 | BUB1B -0.0302251712098 77 | G6PD -0.0302110379173 78 | TSC2.1 -0.0299833481189 79 | TCF12 -0.0299722620503 80 | ARAF -0.0287642510745 81 | TMEM216 -0.0284519776379 82 | ERBB3.2 -0.0282615488448 83 | TSC1 -0.0280687132007 84 | BRCA2 -0.0279882335471 85 | RSPO3 -0.027623887811 86 | SETD2.1 -0.0275347065917 87 | STK19 -0.0274209006224 88 | ZNF512B -0.0273597099512 89 | CDH1 -0.0271553876544 90 | SYK -0.0268016601348 91 | AKT1 -0.0265600041421 92 | PTPN11 -0.0264553272383 93 | CMA1 -0.0263733157855 94 | MAF -0.0261665555897 95 | KMT2A -0.0258489893337 96 | TERT -0.0253177507779 97 | AKT1 AKT2 AKT3.1 -0.0252022844804 98 | EIF4E -0.0249303154093 99 | RAB11A RAB11B -0.024565919523 100 | BCOR -0.0244810430039 101 | ESR2 -0.0240901502237 102 | GAGE2A -0.0239967396832 103 | NF1.2 -0.0237646858297 104 | FOXM1 -0.0236512529769 105 | WRN -0.0234923487578 106 | DICER1 -0.0231773693384 107 | MAPK9 -0.0229482849677 108 | BECN1 -0.0225325831978 109 | SMAD1 -0.0223612064007 110 | HSPA1A -0.0215787898615 111 | histological_type-Is-oligodendroglioma -0.0214026308834 112 | MET.2 -0.0213547531637 113 | DDX6 -0.0210090931387 114 | KRAS -0.020612759051 115 | ATF7IP2 -0.0203539842404 116 | SRC.2 -0.0203099776671 117 | BCL2 -0.0198318031039 118 | EIF4G1 -0.0195952932904 119 | AKT1S1 -0.0195857423804 120 | TMPRSS6 -0.0190996817881 121 | NUP98 -0.0189631077891 122 | SOX4 -0.0185777984202 123 | RICTOR.1 -0.0185718018612 124 | SLC35A2 -0.0181374824525 125 | RPS6KB1.1 -0.0176681319649 126 | MAPK1 -0.0176555952616 127 | histological_type-Is-astrocytoma -0.0176166632207 128 | 4q -0.0172054222488 129 | ESR1.1 -0.0167170578267 130 | MAP2K1.2 -0.0167011202999 131 | BAD -0.0165787216191 132 | SMAD4 -0.0164417577977 133 | TNFRSF9 -0.016012950681 134 | PEA15.1 -0.0155118884021 135 | EIF4EBP1.3 -0.0154031438221 136 | LCK -0.0147835266618 137 | PAK1 -0.0145121246159 138 | CCNE2 -0.0138209493333 139 | CPEB4 -0.0136731863966 140 | ECT2L -0.0136604190607 141 | STK11 -0.0133016351359 142 | ERBB3.1 -0.0132675445662 143 | YBX1.1 -0.0130883911571 144 | POM121 -0.0130876923306 145 | ATM -0.0130002914731 146 | FUBP1 -0.0125129954423 147 | RASGRF2 -0.0122955076808 148 | PDK1.2 -0.0122314190712 149 | RPS6KA1.1 -0.0119425987007 150 | CDKN2A.1 -0.011388754162 151 | 6q -0.0112596718444 152 | MAP3K12 -0.0108393211599 153 | ACACA ACACB -0.0107371916982 
154 | TP53BP1 -0.0104384737413 155 | KIT -0.0102970795905 156 | KDR.1 -0.0099206412687 157 | TREML2 -0.00973157803254 158 | CCNB1 -0.00972273714504 159 | H3F3A -0.00914363230052 160 | CAV1 -0.00877638654943 161 | CDKN1B.3 -0.00875562180869 162 | TCL1A -0.00769609888831 163 | ZNF709 -0.00762602498711 164 | 9q -0.00705387159987 165 | CCNE1.2 -0.00703096942099 166 | RBM15 -0.00686863874208 167 | AKT1 AKT2 AKT3.2 -0.006367977825 168 | KRT15 -0.00636063655043 169 | ARID2.2 -0.00626530783854 170 | gender-Is-male -0.0062432542568 171 | MTOR.1 -0.0059520418655 172 | ARID2.1 -0.00560533376764 173 | CDKN1B.1 -0.0044948532877 174 | MS4A1 -0.00425257201776 175 | AKT1 AKT2 AKT3.3 -0.0041740481665 176 | MAP3K1 -0.00379522082934 177 | EMG1 -0.00350499581886 178 | TSHR -0.00295097745646 179 | PEA15.2 -0.00281812602198 180 | ITGA2 -0.00277979035575 181 | EPS8L1 -0.0022489228443 182 | ASNS -0.00197630938043 183 | PRG4 -0.00190579334035 184 | RAD51 -0.00175688256552 185 | RPL22 -0.00173252434693 186 | TRIP11 -0.00172679191654 187 | PRDX1 -0.0015624982061 188 | RAD50 -0.00155990136501 189 | NFKB1 -0.00152161203092 190 | RB1.1 -0.00149131683067 191 | SERPING1 -0.00122054985473 192 | INF2 -0.00115905179504 193 | CDH3 -0.00102812903854 194 | LGALS13 -0.000542956351458 195 | CHEK2.1 -0.000532782200337 196 | SRPX -0.000528212373312 197 | RPL5.1 -0.000489759158839 198 | MX2 -0.000124208945064 199 | PDHA1 -0.000104708658082 200 | NOTCH1.1 5.69738126053e-05 201 | GSK3A GSK3B.2 0.000307726402329 202 | TIGAR 0.000350679932822 203 | BCL11B 0.000354897582783 204 | MACC1 0.000718937864646 205 | MAPK1 MAPK3 0.000999696685516 206 | BIRC2 0.00104340639839 207 | RPS6.2 0.0012542765311 208 | BCL2L1 0.0015149477953 209 | EGFR.3 0.00219967237543 210 | MUC17 0.00227420391184 211 | HTRA2 0.00231714799016 212 | FAM126B 0.00245925547661 213 | EIF4EBP1.4 0.00252272457974 214 | FANCA 0.00252514631638 215 | GOLGA5 0.00273897658301 216 | EGFR.4 0.00278587920718 217 | TRERF1 0.00335538525895 218 | RICTOR.2 0.00336177807435 219 | BRAF.2 0.00396051129398 220 | TFRC.1 0.00417249383548 221 | ERBB2.1 0.00420372307572 222 | NRG1 0.00470414081731 223 | SEMG1 0.00502992867407 224 | YWHAE 0.00523565456226 225 | GRHL3 0.00551694617081 226 | ZNF41 0.00565949782728 227 | PALB2 0.00571699224886 228 | KRTAP5-3 0.00627526860094 229 | ZDHHC4 0.00644922015868 230 | MRE11A 0.00660396930201 231 | PRKAA1.1 0.00698781484307 232 | PIK3CA.2 0.00718400660702 233 | CDK4 0.00740020648384 234 | TP53.3 0.00760175781323 235 | YWHAB 0.00785592972224 236 | SLC26A3 0.00787376412295 237 | CREBZF 0.00805362565222 238 | INPP4B 0.00845822756688 239 | TNFAIP3 0.00881005873509 240 | MED9 0.00892883260845 241 | FASN 0.00904964630468 242 | NUP210L 0.00937691790844 243 | EEF2 0.00986605031286 244 | CDK1 0.0102054813348 245 | RPS6.3 0.010362959202 246 | PRKCA.2 0.0108579210125 247 | GAPDH 0.011279116762 248 | HSP90AA1 0.011380095907 249 | GSK3A GSK3B.1 0.0117469138677 250 | GSK3A GSK3B.3 0.011765281925 251 | MORN5 0.011946305216 252 | STMN1 0.0121285147664 253 | MYH11 0.012156321689 254 | ACADS 0.0124055794625 255 | PCNA 0.0128025147134 256 | RPS6.1 0.0131272537199 257 | TLR6 0.0135907786781 258 | 19p 0.0140939607517 259 | 18q 0.0142938894535 260 | STAT5A 0.0147685140281 261 | MYT1 0.0148918046398 262 | TFRC.2 0.0150750507971 263 | TPX2 0.0152438820118 264 | PARK7 0.0157550476821 265 | SRC.3 0.0168844645386 266 | PECAM1 0.0170725702742 267 | PRKCB 0.0171373471166 268 | STAG2 0.0175456284013 269 | EEF1A1 0.0177107963539 270 | NAP1L2 0.0182632088281 271 | CBFA2T3 0.0189133732784 
272 | BID 0.0194448990133 273 | RPL5.2 0.0195210572916 274 | SHC1 0.0195488416672 275 | PTPRK 0.0196662424554 276 | SDHA 0.0198931912982 277 | TRPV6 0.0200107090822 278 | KDR.2 0.0202665330569 279 | TYRP1 0.0207055593213 280 | 19q 0.0208823488038 281 | ZBTB20 0.0209595086235 282 | NDRG1 0.021429914824 283 | DLX6 0.0214455843407 284 | NRAS.2 0.0218849399159 285 | ATRX.1 0.0221263795055 286 | YAP1.2 0.0223457903385 287 | SMAD3 0.0229553285859 288 | FAM83D 0.0230161982691 289 | BAX 0.0230235160989 290 | PGR 0.0233142970367 291 | SLC6A3 0.0234873945053 292 | MDM2 0.0242933106381 293 | DDX5 0.0248924815353 294 | PLCG1 0.0254226352022 295 | CHEK1.2 0.0259893169876 296 | PRCP 0.0263816825925 297 | TNRC18 0.0270038761745 298 | PRKCD 0.027475904392 299 | EEF2K 0.0274883386151 300 | GFRA4 0.0276835465207 301 | AOX1 0.027947328855 302 | OAS2 0.0282771807102 303 | NEU2 0.028327321892 304 | 18p 0.0283547925446 305 | CDKN2C 0.0285976056688 306 | NUDT11 0.0287182894078 307 | ERBB2.2 0.0303218803399 308 | FAM123C 0.0305351344514 309 | ZMIZ1 0.0306096107524 310 | MAP2K1.1 0.030697280549 311 | MYB 0.0307362460071 312 | RFX4 0.0309800935521 313 | CDKN1B.2 0.0310239785415 314 | TMEM184A 0.0314110914208 315 | ARHGEF12 0.0317024052387 316 | LUM 0.0321180806378 317 | BAK1 0.0321634406651 318 | CCND2 0.0328714787281 319 | MET.1 0.0330642780296 320 | histological_type-Is-glioblastoma multiforme (gbm) 0.0332336220476 321 | HRAS 0.033558748385 322 | COL6A1 0.0341787407158 323 | QKI 0.0343212407489 324 | 20q 0.0343438947624 325 | ACVRL1 0.0345005749098 326 | WWTR1 0.0349680364658 327 | CD209 0.0351415266763 328 | G6PC 0.0357026080002 329 | 12q 0.0363682483727 330 | CDKN1A 0.0365016274485 331 | CBL 0.0367035099508 332 | RPS6KA1.2 0.0373039964508 333 | MYC.2 0.0374950184284 334 | EIF4EBP1.1 0.0375725190355 335 | PDK1.1 0.0377972255424 336 | SRC.1 0.0384266675013 337 | FN1 0.0384485638137 338 | YBX1.2 0.0390623865064 339 | RECQL4 0.0392755949773 340 | TSC2.2 0.0397107416564 341 | SQSTM1 0.0398340301558 342 | 20p 0.0398438952809 343 | CCNE1.1 0.0399548143038 344 | RAF1.1 0.0407836628826 345 | RPS6KB1.2 0.0413098591973 346 | EIF4EBP1.2 0.0415019149195 347 | KLK2 0.0421187770108 348 | XRCC1 0.0426242326805 349 | CDKN1B.4 0.0426307042391 350 | BRAF.1 0.0431735642907 351 | NF1.1 0.0438179315969 352 | CCND1 0.0453017507916 353 | DLC1 0.0468198483931 354 | SRSF1 0.0474304863792 355 | RBPJ 0.0474668674476 356 | YAP1.1 0.0495763659768 357 | CD1D 0.0512045098782 358 | CDK6 0.0512929188846 359 | PRKCA.1 0.0516877314074 360 | histological_type-Is-treated primary gbm 0.0519440635614 361 | FOXO3.1 0.0528561415364 362 | HEATR3 0.0545032770047 363 | 1q 0.0549741421708 364 | PDGFRA.2 0.0549889997306 365 | CARS 0.0559332875181 366 | 7q 0.0563507956147 367 | PPP2R1A 0.0565069326084 368 | ZNF292 0.0569193721814 369 | SOX2 0.056962687054 370 | C10orf76 0.0572922293857 371 | BAP1 0.0582055939539 372 | PIK3CA.1 0.0621980978787 373 | 1p 0.0629168177132 374 | PXN 0.0634873515814 375 | PTEN.1 0.0635950976533 376 | ESR1.2 0.0640195484648 377 | CYP11A1 0.0651325643179 378 | TGM2 0.0665064457675 379 | FIP1L1 0.0667880066599 380 | FOXO3.2 0.0678831989717 381 | 7p 0.0694452867823 382 | JUN 0.0704659297822 383 | REN 0.0706890712082 384 | RAB25 0.0707951699823 385 | YWHAZ 0.07102368948 386 | KRT13 0.0714862066696 387 | RB1.3 0.0743985958709 388 | PROKR2 0.0762762992568 389 | EGFR.1 0.0782844834438 390 | XRCC5 0.0790922707372 391 | KRT3 0.0812318444972 392 | PDGFRA.1 0.0825710434469 393 | ERRFI1 0.0879014786521 394 | PIK3R1 0.0903187027462 395 | 
radiation_therapy-Is-yes 0.0939297822349 396 | SERPINE1 0.0977994710848 397 | IGFBP2 0.116024386429 398 | histological_type-Is-untreated primary (de novo) gbm 0.151451944648 399 | age_at_initial_pathologic_diagnosis 0.179224246297 400 | -------------------------------------------------------------------------------- /results/Heatmap.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/Heatmap.pdf -------------------------------------------------------------------------------- /results/KM.10p_CNVArm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.10p_CNVArm.pdf -------------------------------------------------------------------------------- /results/KM.10q_CNVArm.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.10q_CNVArm.pdf -------------------------------------------------------------------------------- /results/KM.CDKN2A_CNV.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.CDKN2A_CNV.pdf -------------------------------------------------------------------------------- /results/KM.IDH1_Mut.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.IDH1_Mut.pdf -------------------------------------------------------------------------------- /results/KM.IDH2_Mut.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.IDH2_Mut.pdf -------------------------------------------------------------------------------- /results/KM.IGFBP2_Protein.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.IGFBP2_Protein.pdf -------------------------------------------------------------------------------- /results/KM.PTEN_CNV.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.PTEN_CNV.pdf -------------------------------------------------------------------------------- /results/KM.SMARCA4_Mut.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.SMARCA4_Mut.pdf -------------------------------------------------------------------------------- /results/KM.age_at_initial_pathologic_diagnosis_Clinical.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.age_at_initial_pathologic_diagnosis_Clinical.pdf 
-------------------------------------------------------------------------------- /results/KM.histological_type-Is-untreated primary (de novo) gbm_Clinical.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/KM.histological_type-Is-untreated primary (de novo) gbm_Clinical.pdf -------------------------------------------------------------------------------- /results/PairedScatter.Feature.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/PairedScatter.Feature.pdf -------------------------------------------------------------------------------- /results/PairedScatter.Gradient.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/PairedScatter.Gradient.pdf -------------------------------------------------------------------------------- /results/RankedBox.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/RankedBox.pdf -------------------------------------------------------------------------------- /results/c_index_list.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/c_index_list.mat -------------------------------------------------------------------------------- /results/glioma_integ_model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PathologyDataScience/SurvivalNet/83b99fffea15bca5f0017fa08bef61de9f165eea/results/glioma_integ_model -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup, find_packages 3 | except ImportError: 4 | from distutils.core import setup 5 | 6 | import os 7 | from pkg_resources import parse_requirements 8 | 9 | with open('README.md') as readme_file: 10 | readme = readme_file.read() 11 | 12 | with open('LICENSE') as f: 13 | license_str = f.read() 14 | 15 | try: 16 | with open('requirements.txt') as f: 17 | ireqs = parse_requirements(f.read()) 18 | except SyntaxError: 19 | raise 20 | requirements = [str(req) for req in ireqs] 21 | 22 | setup(name='survivalnet', 23 | version='0.1.0', 24 | description='Deep learning survival models', 25 | author='Emory University', 26 | author_email='lee.cooper@emory.edu', 27 | url='https://github.com/cooperlab/SurvivalNet', 28 | packages=find_packages(), 29 | package_dir={'survivalnet': 'survivalnet'}, 30 | include_package_data=True, 31 | install_requires=requirements, 32 | license=license_str, 33 | zip_safe=False, 34 | keywords='survivalnet', 35 | classifiers=[ 36 | 'Development Status :: 2 - Pre-Alpha', 37 | 'Environment :: Console', 38 | 'License :: OSI Approved :: Apache Software License', 39 | 'Operating System :: OS Independent', 40 | 'Programming Language :: Python :: 2', 41 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 42 | 'Topic :: Software 
Development :: Libraries :: Python Modules',
43 |     ],
44 | )
45 | --------------------------------------------------------------------------------
/survivalnet/__init__.py:
--------------------------------------------------------------------------------
1 | # sub-package optimization must be imported before model
2 | from . import optimization
3 | 
4 | # sub-package model must be imported before train
5 | from . import model
6 | 
7 | # must be imported before Run
8 | from .train import train
9 | 
10 | # sub-packages with no internal dependencies
11 | from . import analysis
12 | 
13 | # must be imported before Bayesian_Optimization
14 | #from .CostFunction import cost_func, aggr_st_cost_func, st_cost_func
15 | 
16 | #from .Bayesian_Optimization import tune
17 | 
18 | #from .Run import Run
19 | 
20 | # list out things that are available for public use
21 | __all__ = (
22 | 
23 |     # functions and classes of this package
24 |     'train',
25 | 
26 |     # sub-packages
27 |     'model',
28 |     'optimization',
29 |     'analysis',
30 | )
31 | --------------------------------------------------------------------------------
/survivalnet/analysis/FeatureAnalysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .RiskCohort import RiskCohort
3 | from .RiskCluster import RiskCluster
4 | from .Visualization import _SplitSymbols
5 | from .Visualization import _WrapSymbols
6 | from .Visualization import RankedBox
7 | from .Visualization import PairScatter
8 | from .Visualization import KMPlots
9 | from .WriteGCT import WriteGCT
10 | from .WriteRNK import WriteRNK
11 | 
12 | 
13 | def FeatureAnalysis(Model, Normalized, Raw, Symbols, Survival, Censored,
14 |                     NBox=10, NScatter=10, NKM=10, NCluster=100, Tau=0.05,
15 |                     Path=None):
16 |     """
17 |     Generate visualizations of risk profiles. Backpropagation is used to calculate the gradient of risk with respect to each input feature for every sample.
18 | 
19 |     Parameters:
20 |     -----------
21 |     Model : class
22 |         Model generated by finetuning.
23 | 
24 |     Normalized : array_like
25 |         Numpy array containing normalized feature values used in training /
26 |         finetuning. These are used to examine associations between feature values
27 |         and cluster assignments. Features are in columns and samples are in rows.
28 | 
29 |     Raw : array_like
30 |         Numpy array containing raw, unnormalized feature values. These are used to
31 |         examine associations between feature values and cluster assignments.
32 |         Features are in columns and samples are in rows.
33 | 
34 |     Symbols : array_like
35 |         List containing strings describing features. See Notes below for
36 |         restrictions on symbol names.
37 | 
38 |     Survival : array_like
39 |         Array containing death or last followup values.
40 | 
41 |     Censored : array_like
42 |         Array containing vital status at last followup. 1 (alive) or 0 (deceased).
43 | 
44 |     NBox, NScatter, NKM : scalar
45 |         Number of features to include when generating the ranked boxplot, paired
46 |         scatter plots, and Kaplan-Meier plots, respectively. Features are scored
47 |         by absolute mean gradient and the highest magnitude features are used. Default value = 10 for each.
48 | 
49 |     NCluster : scalar
50 |         Number of features to include when generating cluster analysis.
51 |         Features are scored by absolute mean gradient and the highest N magnitude
52 |         features will be used to generate the plot. Default value = 100.
53 | 
54 |     Tau : scalar
55 |         Threshold for statistical significance when examining cluster associations.
56 | 
57 |     Path : string
58 |         Path to store .pdf versions of plots generated.
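    Notes
    -----
    Symbol names are expected to end in an underscore-delimited feature-type
    suffix (e.g. '_Mut', '_CNV', '_CNVArm', '_Protein', '_Clinical'), as
    generated by the package tcgaintegrator. This suffix is parsed to
    determine the feature types used in the visualizations.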
59 |     """
60 | 
61 |     # wrap long symbols and remove leading and trailing whitespace
62 |     Corrected, Types = _SplitSymbols(Symbols)
63 |     Wrapped = _WrapSymbols(Corrected)
64 | 
65 |     # generate risk derivative profiles for cohort
66 |     print "Generating risk gradient profiles..."
67 |     Gradients = RiskCohort(Model, Normalized)
68 | 
69 |     # normalize risk derivative profiles
70 |     Gradients = Gradients / np.outer(np.linalg.norm(Gradients, axis=1),
71 |                                      np.ones((1, Gradients.shape[1])))
72 | 
73 |     # re-order symbols, raw features, gradients by mean gradient value, trim
74 |     Means = np.asarray(np.mean(Gradients, axis=0))
75 |     Order = np.argsort(-np.abs(Means))
76 |     cSymbols = [Wrapped[i] for i in Order]
77 |     cTypes = [Types[i] for i in Order]
78 |     cRaw = Raw[:, Order]
79 |     cGradients = Gradients[:, Order]
80 | 
81 |     # generate ranked box plot series
82 |     print "Generating risk gradient boxplot..."
83 |     RBFig = RankedBox(cGradients[:, 0:NBox],
84 |                       [cSymbols[i] for i in np.arange(NBox)],
85 |                       [cTypes[i] for i in np.arange(NBox)],
86 |                       XLabel='Model Features', YLabel='Risk Gradient')
87 | 
88 |     # generate paired scatter plot for gradients
89 |     print "Generating paired scatter gradient plots..."
90 |     PSGradFig = PairScatter(cGradients[:, 0:NScatter],
91 |                             [cSymbols[i] for i in np.arange(NScatter)],
92 |                             [cTypes[i] for i in np.arange(NScatter)])
93 | 
94 |     # generate paired scatter plot for features
95 |     print "Generating paired scatter feature plots..."
96 |     PSFeatFig = PairScatter(cRaw[:, 0:NScatter],
97 |                             [cSymbols[i] for i in np.arange(NScatter)],
98 |                             [cTypes[i] for i in np.arange(NScatter)])
99 | 
100 |     # generate cluster plot
101 |     print "Generating cluster analysis..."
102 |     CFig, Labels = RiskCluster(cGradients[:, 0:NCluster], cRaw[:, 0:NCluster],
103 |                                [cSymbols[i] for i in np.arange(NCluster)],
104 |                                [cTypes[i] for i in np.arange(NCluster)],
105 |                                Tau)
106 | 
107 |     # generate Kaplan-Meier plots for individual features
108 |     print "Generating Kaplan-Meier plots..."
109 |     KMFigs = KMPlots(cGradients[:, 0:NKM], cRaw[:, 0:NKM],
110 |                      [cSymbols[i] for i in np.arange(NKM)],
111 |                      [cTypes[i] for i in np.arange(NKM)],
112 |                      Survival, Censored)
113 | 
114 |     # save figures
115 |     print "Saving figures and outputs..."
116 |     if Path is not None:
117 |         # save standard figures
118 |         RBFig.savefig(Path + 'RankedBox.pdf')
119 |         PSGradFig.savefig(Path + 'PairedScatter.Gradient.pdf')
120 |         PSFeatFig.savefig(Path + 'PairedScatter.Feature.pdf')
121 |         CFig.savefig(Path + 'Heatmap.pdf')
122 |         for i, Figure in enumerate(KMFigs):
123 |             Figure.savefig(Path + 'KM.' + Symbols[Order[i]].strip() + '.pdf')
124 | 
125 |         # save tables
126 |         WriteRNK(Corrected, Means, Path + 'Gradients.rnk')
127 |         WriteGCT(Corrected, None, Gradients, Path + 'Gradients.gct')
128 | --------------------------------------------------------------------------------
/survivalnet/analysis/PathwayAnalysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .RiskCohort import RiskCohort
3 | from .Visualization import _SplitSymbols
4 | from .Visualization import _WrapSymbols
5 | from .Visualization import RankedBar
6 | from .Visualization import RankedBox
7 | from .Visualization import PairScatter
8 | from .Visualization import KMPlots
9 | # NOTE: this module calls SSGSEA below but does not define or import it; an
10 | # SSGSEA implementation must be made available before PathwayAnalysis is used.
11 | def PathwayAnalysis(Model, Normalized, Symbols, SetNames, Sets,
12 |                     Survival, Censored, Alpha=0, GSEANormalize=False,
13 |                     NPlot=10, Path=None):
14 |     """
15 |     Pathway-based analysis of model sensitivity to inputs. Used with models
16 |     trained on pure gene expression features. Transforms risk gradients of
17 |     features obtained from backpropagation into pathway enrichment scores.
18 | 
19 |     Parameters:
20 |     -----------
21 |     Model : class
22 |         Model class instance generated by finetuned training.
23 | 
24 |     Normalized : array_like
25 |         Numpy array containing normalized feature values used in training /
26 |         finetuning. These features will be mapped to gene pathways and should
27 |         populate the pathways densely, as with whole-exome RNA sequencing.
28 |         Features are in columns and samples are in rows.
29 | 
30 |     Symbols : array_like
31 |         List containing strings describing features. These symbols should be
32 |         harmonized with those used in the genesets for mapping between the two.
33 | 
34 |     SetNames : array_like
35 |         List of strings containing gene set names.
36 | 
37 |     Sets : array_like
38 |         List of lists with each containing the gene symbols for one gene set.
39 | 
40 |     Survival : array_like
41 |         Array containing death or last followup values.
42 | 
43 |     Censored : array_like
44 |         Array containing vital status at last followup. 1 (alive) or 0 (deceased).
45 | 
46 |     NPlot : scalar
47 |         Number of gene sets to include when generating plots. Gene sets are
48 |         scored by absolute enrichment score and the highest magnitude sets
49 |         will be used to generate the plots. Default value = 10.
50 | 
51 |     Path : string
52 |         Path to store .pdf versions of plots generated.
53 |     """
54 | 
55 |     # wrap long gene set names and remove leading and trailing whitespace
56 |     CorrectedSetNames, Types = _SplitSymbols(SetNames)
57 |     CorrectedSetNames = _WrapSymbols(CorrectedSetNames)
58 | 
59 |     # trim gene symbol names
60 |     CorrectedSymbols = [Symbol[0:str.rfind(str(Symbol), '_')]
61 |                         for Symbol in Symbols]
62 | 
63 |     # generate risk derivative profiles for cohort
64 |     print "Generating risk gradient profiles..."
65 |     Gradients = RiskCohort(Model, Normalized)
66 | 
67 |     # perform GSEA on the median feature risk profile
68 |     ES = np.squeeze(SSGSEA(np.median(Gradients, axis=0)[np.newaxis, :],
69 |                            CorrectedSymbols, Sets, Alpha, False))
70 | 
71 |     # re-order pathways by absolute enrichment score
72 |     Order = np.argsort(-np.abs(ES))
73 |     cSetNames = [CorrectedSetNames[i] for i in Order]
74 |     cTypes = [Types[i] for i in Order]
75 |     cES = ES[Order]
76 | 
77 |     # generate ranked bar plot
78 |     print "Generating median enrichment score bar plot..."
79 |     BarFig = RankedBar(cES[0:NPlot],
80 |                        [cSetNames[i] for i in np.arange(NPlot)],
81 |                        [cTypes[i] for i in np.arange(NPlot)],
82 |                        XLabel='Pathway',
83 |                        YLabel='Enrichment Score')
84 | 
85 |     # perform SSGSEA on feature risk gradient profiles
86 |     SSES = SSGSEA(Gradients, CorrectedSymbols, Sets, Alpha, GSEANormalize)
87 | 
88 |     # re-order single-sample enrichment scores
89 |     cSSES = SSES[:, Order]
90 | 
91 |     # generate ranked box plot series
92 |     print "Generating single-sample enrichment boxplot..."
93 |     BoxFig = RankedBox(cSSES[:, 0:NPlot],
94 |                        [cSetNames[i] for i in np.arange(NPlot)],
95 |                        [cTypes[i] for i in np.arange(NPlot)],
96 |                        XLabel='Pathway',
97 |                        YLabel='Single-Sample Enrichment Score')
98 | 
99 |     # generate paired scatter plot for gradients
100 |     print "Generating single-sample enrichment scatter plots..."
101 |     PSGradFig = PairScatter(cSSES[:, 0:NPlot],
102 |                             [cSetNames[i] for i in np.arange(NPlot)],
103 |                             [cTypes[i] for i in np.arange(NPlot)])
104 | 
105 |     # generate Kaplan-Meier plots for individual features
106 |     print "Generating single-sample enrichment Kaplan-Meier plots..."
--------------------------------------------------------------------------------
/survivalnet/analysis/ReadGMT.py:
--------------------------------------------------------------------------------
def ReadGMT(File):
    """
    Reads a Gene Matrix Transposed (GMT) text file defining gene sets to
    generate lists containing gene set names, descriptions and gene set gene
    symbols.
    Parameters
    ----------
    File : string
        Filename and path to a GMT file containing gene sets.
    Returns
    -------
    SetNames : array_like
        A list of strings containing the gene set names.
    Descriptions : array_like
        A list of strings containing the gene set descriptions.
    Genes : array_like
        A list of lists, each containing the gene symbols for one gene set.
    Notes
    -----
    Gene sets can be obtained from the Molecular Signatures Database (MSigDB)
    at http://software.broadinstitute.org/gsea/msigdb/.
    See Also
    --------
    SSGSEA
    """

    # initialize lists for set names, descriptions and genes
    SetNames = []
    Descriptions = []
    Genes = []

    # open gmt file
    with open(File, 'r') as gmt:
        for Line in gmt:

            # parse line into gene set name, description and gene members
            gs, desc, gn = _ParseLine(Line)

            # append gene set to outputs
            SetNames.append(gs)
            Descriptions.append(desc)
            Genes.append(gn)

    return SetNames, Descriptions, Genes


def _ParseLine(String):
    """
    Parses a GMT file line into gene set name, description and gene symbols.
    """

    # split string into delimited components
    Words = String.split()

    # extract gene set, description, genes
    GeneSet = Words[0]
    Description = Words[1]
    Genes = Words[2:]
    Genes.sort()

    return GeneSet, Description, Genes
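
# Usage sketch (illustration only; the .gmt filename below is hypothetical
# and must exist locally for this to run):
if __name__ == '__main__':
    SetNames, Descriptions, Genes = ReadGMT('c2.cp.v6.0.symbols.gmt')
    print '%s: %d genes' % (SetNames[0], len(Genes[0]))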
50 | """ 51 | 52 | # split string into delimited components 53 | Words = String.split() 54 | 55 | # extract gene set, link, genes 56 | GeneSet = Words[0] 57 | Description = Words[1] 58 | Genes = Words[2:] 59 | Genes.sort() 60 | 61 | return GeneSet, Description, Genes 62 | -------------------------------------------------------------------------------- /survivalnet/analysis/RiskCluster.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import scipy.cluster.hierarchy as sch 5 | import scipy.spatial.distance as dist 6 | from scipy.stats import chisquare 7 | from scipy.stats.mstats import kruskalwallis 8 | 9 | # heatmap layout constants 10 | WINDOW_HEIGHT = 30 11 | WINDOW_WIDTH = 30 12 | SPACING = 0.01 # spacing between plot elements 13 | TRACK = 0.01 # height of individual track elements 14 | FEATURE_W = 0.1 # width of left dendrogram (clustering of features) 15 | FEATURE_X = SPACING # horizontal offset for left dendrogram 16 | FEATURE_Y = SPACING # vertical offset for left dendrogram 17 | HEATMAP_X = FEATURE_X + FEATURE_W + SPACING # horizontal offset for heatmap 18 | HEATMAP_Y = FEATURE_Y # vertical offset for heatmap 19 | HEATMAP_W = 1 - FEATURE_W - 3 * SPACING # width of heatmap 20 | SAMPLE_H = 0.1 # height of top dendrogram (clustering of samples) 21 | SAMPLE_W = HEATMAP_W # width of top dendrogram 22 | SAMPLE_X = HEATMAP_X # horizontal offset of top dendrogram 23 | TRACK_W = HEATMAP_W # width of tracks 24 | TRACK_X = HEATMAP_X # horizontal offset of tracks 25 | 26 | 27 | def RiskCluster(Gradients, Raw, Symbols, Types, Tau=0.05): 28 | """ 29 | Generates a clustering and heatmap given risk profiles generated by 30 | Risk_Cohort. Analyzed features to identify and display cluster association 31 | mutations and copy number variations. 32 | 33 | Parameters 34 | ---------- 35 | Gradients : array_like 36 | Numpy array containing feature/sample gradients obtained by Risk_Cohort. 37 | Features are in columns and samples are in rows. 38 | 39 | Raw : array_like 40 | Numpy array containing raw, unnormalized feature values. These are used to 41 | examine associations between feature values and cluster assignments. 42 | Features are in columns and samples are in rows. 43 | 44 | Symbols : array_like 45 | List containing strings describing features. 46 | 47 | Types: array_like 48 | List containing strings describing feature types (e.g. CNV, Mut, Clinical). 49 | See notes on allowed values of Types below. 50 | 51 | Tau : scalar 52 | Threshold for statistical significance when examining cluster associations. 53 | 54 | Returns 55 | ------- 56 | Figure : figure handle 57 | Handle to figure used for saving image to disk i.e. 58 | Figure.savefig('heatmap.pdf') 59 | 60 | SampleIndices : array_like 61 | Cluster labels for the risk-gradient clustering of input samples. 62 | 63 | Notes 64 | ----- 65 | Types like 'Mut' and 'CNV' that are generated as suffixes to feature names 66 | by the package tcgaintegrator are required analysis. 
67 | 68 | See Also 69 | -------- 70 | RiskCohort, Visualize 71 | """ 72 | 73 | # copy data, re-order, normalize 74 | Normalized = Gradients.copy() 75 | Normalized = (Normalized - np.mean(Normalized, axis=0)) / \ 76 | np.std(Normalized, axis=0) 77 | 78 | # transpose so that samples are in columns 79 | Normalized = Normalized.transpose() 80 | 81 | # generate figure 82 | Figure = plt.figure(figsize=(WINDOW_WIDTH, WINDOW_HEIGHT)) 83 | 84 | # cluster samples and generate dendrogram 85 | SampleDist = dist.pdist(Normalized.T, 'correlation') 86 | SampleDist = dist.squareform(SampleDist) 87 | SampleLinkage = sch.linkage(SampleDist, method='average', 88 | metric='correlation') 89 | Labels = sch.fcluster(SampleLinkage, 0.7*max(SampleLinkage[:, 2]), 90 | 'distance') 91 | 92 | # cluster features and generate dendrogram 93 | FeatureDist = dist.pdist(Normalized, 'correlation') 94 | FeatureDist = dist.squareform(FeatureDist) 95 | FeatureLinkage = sch.linkage(FeatureDist, method='average', 96 | metric='correlation') 97 | 98 | # capture cluster associations 99 | Significant, SigTypes = ClusterAssociations(Raw, Symbols, Types, 100 | Labels, Tau) 101 | 102 | # calculate layout parameters 103 | TRACK_H = TRACK * len(Significant) # total height of tracks 104 | HEATMAP_H = 1 - SAMPLE_H - TRACK_H - 4 * SPACING 105 | TRACK_Y = HEATMAP_H + 2 * SPACING 106 | FEATURE_H = HEATMAP_H 107 | SAMPLE_Y = HEATMAP_H + TRACK_H + 3 * SPACING 108 | 109 | # layout and generate top dendrogram (samples) 110 | SampleHandle = Figure.add_axes([SAMPLE_X, SAMPLE_Y, SAMPLE_W, SAMPLE_H], 111 | frame_on=False) 112 | SampleDendrogram = sch.dendrogram(SampleLinkage) 113 | SampleHandle.set_xticks([]) 114 | SampleHandle.set_yticks([]) 115 | 116 | # define sample order 117 | SampleOrder = SampleDendrogram['leaves'] 118 | 119 | # layout and generate left dendrogram (features) 120 | FeatureHandle = Figure.add_axes([FEATURE_X, FEATURE_Y, 121 | FEATURE_W, FEATURE_H], 122 | frame_on=False) 123 | FeatureDendrogram = sch.dendrogram(FeatureLinkage, orientation='right') 124 | FeatureHandle.set_xticks([]) 125 | FeatureHandle.set_yticks([]) 126 | 127 | # reorder input matrices based on clustering and capture order 128 | Reordered = Normalized[:, SampleDendrogram['leaves']] 129 | Reordered = Reordered[FeatureDendrogram['leaves'], :] 130 | 131 | # layout and generate heatmap 132 | Heatmap = Figure.add_axes([HEATMAP_X, HEATMAP_Y, HEATMAP_W, HEATMAP_H], 133 | frame_on=False) 134 | Heatmap.matshow(Reordered, aspect='auto', origin='lower', 135 | cmap=plt.cm.bwr) 136 | Heatmap.set_xticks([]) 137 | Heatmap.set_yticks([]) 138 | 139 | # extract mutation values from raw features 140 | SigMut = [Significant[i] for i, tpe in enumerate(SigTypes) if tpe == "Mut"] 141 | Indices = [i for i, Symbol in enumerate(Symbols) if Symbol in set(SigMut)] 142 | Mutations = Raw[:, Indices] 143 | Mutations = Mutations[SampleOrder, :].T 144 | 145 | # extract CNV values from raw features 146 | SigCNV = [Significant[i] for i, tpe in enumerate(SigTypes) if tpe == "CNV"] 147 | Indices = [i for i, Symbol in enumerate(Symbols) if Symbol in set(SigCNV)] 148 | CNVs = Raw[:, Indices] 149 | CNVs = CNVs[SampleOrder, :].T 150 | 151 | # layout and generate mutation tracks 152 | gm = Figure.add_axes([TRACK_X, TRACK_Y + len(SigCNV)*TRACK, 153 | TRACK_W, TRACK_H - len(SigCNV)*TRACK], 154 | frame_on=False) 155 | cmap_g = mpl.colors.ListedColormap(['k', 'w']) 156 | gm.matshow(Mutations, aspect='auto', origin='lower', cmap=cmap_g) 157 | for i in range(len(SigMut)): 158 | gm.text(-SPACING, i / 
np.float(len(SigMut)) + 159 | 1/np.float(2*len(SigMut)), 160 | SigMut[i], fontsize=6, 161 | verticalalignment='center', 162 | horizontalalignment='right', 163 | transform=gm.transAxes) 164 | gm.set_xticks([]) 165 | gm.set_yticks([]) 166 | 167 | # layout and generate CNV tracks 168 | cnv = Figure.add_axes([TRACK_X, TRACK_Y, 169 | TRACK_W, TRACK_H - len(SigMut)*TRACK], 170 | frame_on=False) 171 | cnv.matshow(CNVs, aspect='auto', origin='lower', cmap=plt.cm.bwr, 172 | vmin=-2, vmax=2) 173 | for i in range(len(SigCNV)): 174 | cnv.text(-SPACING, i / np.float(len(SigCNV)) + 175 | 1/np.float(2*len(SigCNV)), 176 | SigCNV[i], fontsize=6, 177 | verticalalignment='center', 178 | horizontalalignment='right', 179 | transform=cnv.transAxes) 180 | cnv.set_xticks([]) 181 | cnv.set_yticks([]) 182 | 183 | # return cluster labels 184 | return Figure, Labels 185 | 186 | 187 | def ClusterAssociations(Raw, Symbols, Types, Labels, Tau=0.05): 188 | """ 189 | Examines associations between cluster assigments of samples and copy-number 190 | and mutation events. 191 | 192 | Parameters 193 | ---------- 194 | Raw : array_like 195 | Numpy array containing raw, unnormalized feature values. These are used to 196 | examine associations between feature values and cluster assignments. 197 | Features are in columns and samples are in rows. 198 | 199 | Symbols : array_like 200 | List containing strings describing features. See Notes below for 201 | restrictions on symbol names. 202 | 203 | Types: array_like 204 | List containing strings describing feature types (e.g. CNV, Mut, Clinical). 205 | See notes on allowed values of Types below. 206 | 207 | Labels : array_like 208 | Cluster labels for the samples in 'Raw'. 209 | 210 | Tau : scalar 211 | Threshold for statistical significance when examining cluster associations. 212 | 213 | Returns 214 | ------- 215 | Significant : array_like 216 | List of copy number and mutation features from 'Raw' that are significantly 217 | associated with the clustering 'Labels'. 218 | 219 | SigTypes : array_like 220 | List of types for significant features. 221 | 222 | Notes 223 | ----- 224 | Types like 'Mut' and 'CNV' that are generated as suffixes to feature names 225 | by the package tcgaintegrator are required analysis. 
226 | 227 | See Also 228 | -------- 229 | RiskCohort, RiskCluster 230 | """ 231 | 232 | # initialize list of symbols with significant associations and their types 233 | Significant = [] 234 | SigTypes = [] 235 | 236 | # identify mutations and CNVs 237 | Mutations = [index for index, tpe in enumerate(Types) if tpe == "Mut"] 238 | CNVs = [index for index, tpe in enumerate(Types) if tpe == "CNV"] 239 | 240 | # test mutation associations 241 | for i in np.arange(len(Mutations)): 242 | 243 | # build contingency table - expected and observed 244 | Observed = np.zeros((2, np.max(Labels))) 245 | for j in np.arange(1, np.max(Labels)+1): 246 | Observed[0, j-1] = np.sum(Raw[Labels == j, Mutations[i]] == 0) 247 | Observed[1, j-1] = np.sum(Raw[Labels == j, Mutations[i]] == 1) 248 | RowSum = np.sum(Observed, axis=0) 249 | ColSum = np.sum(Observed, axis=1) 250 | Expected = np.outer(ColSum, RowSum) / np.sum(Observed.flatten()) 251 | 252 | # perform test 253 | stat, p = chisquare(Observed, Expected, ddof=1, axis=None) 254 | if p < Tau: 255 | Significant.append(Symbols[Mutations[i]]) 256 | SigTypes.append(Types[Mutations[i]]) 257 | 258 | # copy number associations 259 | for i in np.arange(len(CNVs)): 260 | 261 | # separate out CNV values by cluster and perform test - hack for bad 262 | # interfact to scipy kruskalwallis 263 | if(np.max(Labels) == 2): 264 | CNV1 = Raw[Labels == 1, CNVs[i]] 265 | CNV2 = Raw[Labels == 2, CNVs[i]] 266 | stat, p = kruskalwallis(CNV1, CNV2) 267 | elif(np.max(Labels) == 3): 268 | CNV1 = Raw[Labels == 1, CNVs[i]] 269 | CNV2 = Raw[Labels == 2, CNVs[i]] 270 | CNV3 = Raw[Labels == 3, CNVs[i]] 271 | stat, p = kruskalwallis(CNV1, CNV2, CNV3) 272 | elif(np.max(Labels) == 4): 273 | CNV1 = Raw[Labels == 1, CNVs[i]] 274 | CNV2 = Raw[Labels == 2, CNVs[i]] 275 | CNV3 = Raw[Labels == 3, CNVs[i]] 276 | CNV4 = Raw[Labels == 4, CNVs[i]] 277 | stat, p = kruskalwallis(CNV1, CNV2, CNV3, CNV4) 278 | elif(np.max(Labels) == 5): 279 | CNV1 = Raw[Labels == 1, CNVs[i]] 280 | CNV2 = Raw[Labels == 2, CNVs[i]] 281 | CNV3 = Raw[Labels == 3, CNVs[i]] 282 | CNV4 = Raw[Labels == 4, CNVs[i]] 283 | CNV5 = Raw[Labels == 5, CNVs[i]] 284 | stat, p = kruskalwallis(CNV1, CNV2, CNV3, CNV4, CNV5) 285 | if p < Tau: 286 | Significant.append(Symbols[CNVs[i]]) 287 | SigTypes.append(Types[CNVs[i]]) 288 | 289 | # return names of features with significant associations 290 | return Significant, SigTypes 291 | -------------------------------------------------------------------------------- /survivalnet/analysis/RiskCohort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano as th 3 | import theano.tensor as T 4 | 5 | 6 | def RiskCohort(Model, Features): 7 | """ 8 | Generates partial derivative weights of features and risk for the given 9 | model and profiles. Generates mean and standard deviation of these partial 10 | derivative weights. 11 | 12 | Parameters: 13 | ---------- 14 | Model : theano deep learning model 15 | a theano Neuralnetwork model containing theano functions to feed the 16 | model with a profile of feature values. 17 | 18 | Profiles : array_like 19 | a matrix of profiles containing features value to be used as a input 20 | for feeding the model. 21 | 22 | Output : 23 | -------- 24 | Gradients: numpy matrix 25 | A [N*D] matrix contains feature weights.(N = number of profiles and D = 26 | number of features(dimention of input to the model )). each row contains 27 | feature weights of Correspondence profile in profiles matrix. 
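
# A small self-contained sketch (illustration only) of the mutation test used
# in ClusterAssociations: a chi-square test on the 2 x k contingency table of
# binary mutation calls against cluster labels, run on toy data.
if __name__ == '__main__':
    mutation = np.array([0, 0, 0, 1, 1, 1, 1, 0])  # toy binary calls
    labels = np.array([1, 1, 1, 1, 2, 2, 2, 2])    # two toy clusters
    Observed = np.zeros((2, np.max(labels)))
    for j in np.arange(1, np.max(labels)+1):
        Observed[0, j-1] = np.sum(mutation[labels == j] == 0)
        Observed[1, j-1] = np.sum(mutation[labels == j] == 1)
    Expected = np.outer(np.sum(Observed, axis=1),
                        np.sum(Observed, axis=0)) / np.sum(Observed)
    stat, p = chisquare(Observed, Expected, ddof=1, axis=None)
    print 'chi-square p = %f' % p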
--------------------------------------------------------------------------------
/survivalnet/analysis/RiskCohort.py:
--------------------------------------------------------------------------------
import numpy as np
import theano as th
import theano.tensor as T


def RiskCohort(Model, Features):
    """
    Generates the partial derivatives of risk with respect to each input
    feature, for every profile in a cohort, by backpropagation through the
    given model.

    Parameters:
    ----------
    Model : theano deep learning model
        A Theano neural network model providing the symbolic variables and
        risk output needed to differentiate risk with respect to its inputs.

    Features : array_like
        An N x D matrix of feature profiles used as inputs to the model, with
        samples in rows and features in columns.

    Output :
    --------
    Gradients : numpy matrix
        An N x D matrix of feature weights, where N is the number of profiles
        and D is the number of features (the dimension of the model input).
        Row i contains the feature weights of the corresponding profile in
        'Features'.
    """

    # initialize container for risk gradient profiles
    Gradients = np.zeros(Features.shape)

    # copy input to matrix for Theano
    Matrix = np.matrix(Features)

    # iterate through samples, calculating risk gradient profile for each
    for i in np.arange(Features.shape[0]):
        Gradients[i, :] = _RiskBackpropagate(Model, Matrix[i, :])

    return Gradients


def _RiskBackpropagate(Model, Features):
    """
    Generates partial derivatives of input features in a neural network model.
    These represent the rate of change of risk with respect to each input
    feature.

    Parameters:
    ----------
    Model : class
        A fine tuned model generated by training.

    Features : array_like
        A 1 x P numpy array of features corresponding to one sample.

    Output:
    ----------
    Gradient : array_like
        A 1 x P array of partial derivatives of risk with respect to each
        input feature.
    """

    # define partial derivative
    X = T.matrix('X')
    AtRisk = T.ivector('AtRisk')
    Observed = T.ivector('Observed')
    Is_train = T.scalar('Is_train', dtype='int32')
    masks = T.lmatrix('mask_' + str(0))
    partial_derivative = th.function(on_unused_input='ignore',
                                     inputs=[X, AtRisk, Observed, Is_train,
                                             masks],
                                     outputs=T.grad(Model.risk_layer.output[0],
                                                    Model.x),
                                     givens={Model.x: X, Model.o: AtRisk,
                                             Model.at_risk: Observed,
                                             Model.is_train: Is_train,
                                             Model.masks[0]: masks},
                                     name='partial_derivative')

    # define parameters for risk and calculate partial
    sample_O = np.array([0]).astype(np.int32)
    sample_T = np.array([0]).astype(np.int32)
    # create dummy masks for graph
    dummy_masks = np.ones((1, Model.n_hidden), dtype='int64')
    Gradient = partial_derivative(Features, sample_O, sample_T, 0, dummy_masks)

    return Gradient
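
# A quick numeric sanity check (illustration only) of the kind of gradient
# _RiskBackpropagate produces, using a toy differentiable risk function in
# place of a trained model so that no Theano session is required:
if __name__ == '__main__':
    def numeric_gradient(f, x, eps=1.0e-6):
        # central finite differences, one coordinate at a time
        g = np.zeros_like(x)
        for k in np.arange(x.size):
            step = np.zeros_like(x)
            step[k] = eps
            g[k] = (f(x + step) - f(x - step)) / (2.0 * eps)
        return g

    w = np.array([0.5, -1.0, 2.0])
    risk = lambda x: np.dot(np.tanh(x), w)  # toy risk model
    x = np.array([0.1, 0.2, -0.3])
    analytic = (1.0 - np.tanh(x) ** 2) * w  # d(risk)/dx by the chain rule
    assert np.allclose(analytic, numeric_gradient(risk, x), atol=1.0e-6)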
--------------------------------------------------------------------------------
/survivalnet/analysis/Visualization.py:
--------------------------------------------------------------------------------
from lifelines import KaplanMeierFitter
from lifelines.statistics import logrank_test
import matplotlib.pyplot as plt
import numpy as np
from statsmodels.nonparametric.smoothers_lowess import lowess
from textwrap import wrap

# define colors for positive risk (red) and negative risk (blue)
REDFACE = '#DE2D26'
BLUEFACE = '#3182BD'
REDEDGE = '#DE2D26'
BLUEEDGE = '#3182BD'
MEDIAN = '#000000'
WHISKER = '#AAAAAA'
POINTS = '#000000'
GRID = '#BBBBBB'

# layout constants for boxplot
BOX_HSPACE = 0.15
BOX_VSPACE = 0.4
BOX_FH = 5  # boxplot figure height
BOX_FW = 8  # boxplot figure width
JITTER = 0.08
BOX_FONT = 8

# layout constants for pairwise feature plot
PAIR_FW = 10
PAIR_SPACING = 0.1

# layout constants for survival plot
SURV_FW = 10
SURV_FH = 6
SURV_HSPACE = 0.1
SURV_VSPACE = 0.1
SURV_FONT = 8


def RankedBar(Profile, Symbols, Types, XLabel=None, YLabel=None):
    """
    Generates a bar plot of feature gradients or enrichment scores ranked by
    magnitude.

    Parameters:
    ----------
    Profile : array_like
        Numpy array containing a 1-dimensional profile of feature gradients
        or enrichment scores.

    Symbols : array_like
        List containing strings describing features in profile.

    Types : array_like
        List containing strings describing feature types (e.g. CNV, Mut,
        Clinical).

    XLabel : string
        Label for x axis. Default value = None

    YLabel : string
        Label for y axis. Default value = None

    Returns:
    --------
    Figure : figure handle
        Handle to figure used for saving image to disk i.e.
        Figure.savefig('heatmap.pdf')

    Notes:
    ------
    Features are displayed in the order they are provided. Any sorting should
    happen prior to calling.
    """

    # generate figure and add axes
    Figure = plt.figure(figsize=(BOX_FW, BOX_FH), facecolor='white')
    Axes = Figure.add_axes([BOX_HSPACE, BOX_VSPACE,
                            1-BOX_HSPACE, 1-BOX_VSPACE],
                           frame_on=False)
    Axes.set_axis_bgcolor('white')

    # generate bars
    Bars = Axes.bar(np.linspace(1, len(Profile), len(Profile)), Profile,
                    align='center')

    # modify bar styling
    for i, bar in enumerate(Bars):
        if Profile[i] <= 0:
            bar.set(color=BLUEEDGE, linewidth=2)
            bar.set(facecolor=BLUEFACE)
        else:
            bar.set(color=REDEDGE, linewidth=2)
            bar.set(facecolor=REDFACE)

    # set limits
    Axes.set_ylim(1.05 * Profile.min(), 1.05 * Profile.max())

    # format x axis
    if XLabel is not None:
        plt.xlabel(XLabel)
    plt.xticks(np.linspace(1, len(Profile), len(Profile)),
               [Symbols[i] + " _" + Types[i] for i in np.arange(len(Profile))],
               rotation='vertical', fontsize=BOX_FONT)
    Axes.set_xticks(np.linspace(1.5, len(Profile)-0.5,
                                len(Profile)-1), minor=True)
    Axes.xaxis.set_ticks_position('bottom')

    # format y axis
    if YLabel is not None:
        plt.ylabel(YLabel)
    Axes.yaxis.set_ticks_position('left')

    # add grid lines and zero line
    Axes.xaxis.grid(True, color=GRID, linestyle='-', which='minor')
    plt.plot([0, len(Profile)+0.5], [0, 0], color='black')

    return Figure
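
# Usage sketch (illustration only): render a ranked bar plot for a toy
# profile and save it to a hypothetical path. Assumes the matplotlib version
# this module targets (RankedBar uses the older set_axis_bgcolor API) and a
# writable working directory.
if __name__ == '__main__':
    Fig = RankedBar(np.array([0.8, -0.5, 0.3, -0.1]),
                    ['TP53', 'PTEN', 'EGFR', 'IDH1'],
                    ['Mut', 'CNV', 'mRNA', 'Mut'],
                    XLabel='Feature', YLabel='Score')
    Fig.savefig('RankedBar.example.pdf')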
150 | """ 151 | 152 | # generate figure and add axes 153 | Figure = plt.figure(figsize=(BOX_FW, BOX_FH), facecolor='white') 154 | Axes = Figure.add_axes([BOX_HSPACE, BOX_VSPACE, 155 | 1-BOX_HSPACE, 1-BOX_VSPACE], 156 | frame_on=False) 157 | Axes.set_axis_bgcolor('white') 158 | 159 | # generate boxplots 160 | Box = Axes.boxplot(Gradients, patch_artist=True, showfliers=False) 161 | 162 | # set global properties 163 | plt.setp(Box['medians'], color=MEDIAN, linewidth=1) 164 | plt.setp(Box['whiskers'], color=WHISKER, linewidth=1, linestyle='-') 165 | plt.setp(Box['caps'], color=WHISKER, linewidth=1) 166 | 167 | # modify box styling 168 | for i, box in enumerate(Box['boxes']): 169 | if np.mean(Gradients[:, i]) <= 0: 170 | box.set(color=BLUEEDGE, linewidth=2) 171 | box.set(facecolor=BLUEFACE) 172 | else: 173 | box.set(color=REDEDGE, linewidth=2) 174 | box.set(facecolor=REDFACE) 175 | 176 | # add jittered data overlays 177 | for i in np.arange(Gradients.shape[1]): 178 | plt.scatter(np.random.normal(i+1, JITTER, size=Gradients.shape[0]), 179 | Gradients[:, i], color=POINTS, alpha=0.2, 180 | marker='o', s=2, zorder=100) 181 | 182 | # set limits 183 | Axes.set_ylim(1.05 * Gradients.min(), 1.05 * Gradients.max()) 184 | 185 | # format x axis 186 | if XLabel is not None: 187 | plt.xlabel(XLabel) 188 | plt.xticks(np.linspace(1, Gradients.shape[1], Gradients.shape[1]), 189 | [Symbols[i] + " _" + Types[i] for i in 190 | np.arange(Gradients.shape[1])], 191 | rotation='vertical', fontsize=BOX_FONT) 192 | Axes.set_xticks(np.linspace(1.5, Gradients.shape[1]-0.5, 193 | Gradients.shape[1]-1), minor=True) 194 | Axes.xaxis.set_ticks_position('bottom') 195 | 196 | # format y axis 197 | if YLabel is not None: 198 | plt.ylabel(YLabel) 199 | Axes.yaxis.set_ticks_position('left') 200 | 201 | # add grid lines and zero line 202 | Axes.xaxis.grid(True, color=GRID, linestyle='-', which='minor') 203 | plt.plot([0, Gradients.shape[1]+0.5], [0, 0], color='black') 204 | 205 | return Figure 206 | 207 | 208 | def PairScatter(Gradients, Symbols, Types): 209 | """ 210 | Generates boxplot series of feature gradients ranked by absolute magnitude. 211 | 212 | Parameters: 213 | ---------- 214 | 215 | Gradients : array_like 216 | Numpy array containing feature/sample gradients obtained by RiskCohort. 217 | Features are in columns and samples are in rows. 218 | 219 | Symbols : array_like 220 | List containing strings describing features. 221 | 222 | Types: array_like 223 | List containing strings describing feature types (e.g. CNV, Mut, Clinical). 224 | 225 | Returns: 226 | -------- 227 | Figure : figure handle 228 | Handle to figure used for saving image to disk i.e. 229 | Figure.savefig('heatmap.pdf') 230 | 231 | Notes: 232 | ------ 233 | Features are displayed in the order they are provided. Any sorting should 234 | happen prior to calling. 
235 | """ 236 | 237 | # calculate means, standard deviations 238 | Means = np.asarray(np.mean(Gradients, axis=0)) 239 | Std = np.asarray(np.std(Gradients, axis=0)) 240 | 241 | # generate subplots 242 | Figure, Axes = plt.subplots(nrows=Gradients.shape[1], 243 | ncols=Gradients.shape[1], 244 | figsize=(PAIR_FW, PAIR_FW), 245 | facecolor='white') 246 | Figure.subplots_adjust(hspace=PAIR_SPACING, wspace=PAIR_SPACING, 247 | bottom=PAIR_SPACING) 248 | 249 | # remove axes and ticks 250 | for ax in Axes.flat: 251 | ax.xaxis.set_visible(False) 252 | ax.yaxis.set_visible(False) 253 | 254 | # generate scatter plots in lower triangular portion 255 | for i, j in zip(*np.triu_indices_from(Axes, k=1)): 256 | Axes[i, j].scatter((Gradients[:, j]-Means[j]) / Std[j], 257 | (Gradients[:, i]-Means[i]) / Std[i], 258 | color=POINTS, alpha=0.2, marker='o', s=2) 259 | Smooth = lowess((Gradients[:, j]-Means[j]) / Std[j], 260 | (Gradients[:, i]-Means[i]) / Std[i]) 261 | Axes[i, j].plot(Smooth[:, 1], Smooth[:, 0], color='red') 262 | 263 | # generate histograms on diagonal 264 | for i in np.arange(Gradients.shape[1]): 265 | if Means[i] <= 0: 266 | Axes[i, i].hist(Gradients[:, i], 267 | facecolor=BLUEFACE, 268 | alpha=0.8) 269 | else: 270 | Axes[i, i].hist(Gradients[:, i], 271 | facecolor=REDFACE, 272 | alpha=0.8) 273 | Axes[i, i].annotate(Symbols[i] + " _" + Types[i], (0, 0), 274 | xycoords='axes fraction', 275 | ha='right', va='top', 276 | rotation=45) 277 | 278 | # delete unused axes 279 | for i, j in zip(*np.tril_indices_from(Axes, k=-1)): 280 | Figure.delaxes(Axes[i, j]) 281 | 282 | return Figure 283 | 284 | 285 | def KMPlots(Gradients, Raw, Symbols, Types, Survival, Censored): 286 | 287 | """ 288 | Generates KM plots for individual features ranked by absolute magnitude. 289 | 290 | Parameters: 291 | ---------- 292 | 293 | Gradients : array_like 294 | Numpy array containing feature/sample gradients obtained by RiskCohort. 295 | Features are in columns and samples are in rows. 296 | 297 | Raw : array_like 298 | Numpy array containing raw, unnormalized feature values. These are used to 299 | examine associations between feature values and cluster assignments. 300 | Features are in columns and samples are in rows. 301 | 302 | Symbols : array_like 303 | List containing strings describing features. 304 | 305 | Types: array_like 306 | List containing strings describing feature types (e.g. CNV, Mut, Clinical). 307 | See notes on allowed values of Types below. 308 | 309 | Survival : array_like 310 | Array containing death or last followup values. 311 | 312 | Censored : array_like 313 | Array containing vital status at last followup. 1 (alive) or 0 (deceased). 314 | 315 | Returns 316 | ------- 317 | Figures : figure handle 318 | List containing handles to figures. 319 | 320 | Names : array_like 321 | List of feature names for figures in 'Figures' 322 | 323 | Notes 324 | ----- 325 | Types like 'Mut' and 'CNV' that are generated as suffixes to feature names 326 | by the package tcgaintegrator are required analysis. 327 | Note this uses feature values as opposed to back-propagated risk gradients. 328 | Features are displayed in the order they are provided. Any sorting should 329 | happen prior to calling. 
330 | """ 331 | 332 | # initialize list of figures and names 333 | Figures = [] 334 | 335 | # generate Kaplan Meier fitter 336 | kmf = KaplanMeierFitter() 337 | 338 | # generate KM plot for each feature 339 | for count, i in enumerate(np.arange(Gradients.shape[1])): 340 | 341 | # generate figure and axes 342 | Figures.append(plt.figure(figsize=(SURV_FW, SURV_FH), 343 | facecolor='white')) 344 | Axes = Figures[count].add_axes([SURV_HSPACE, SURV_VSPACE, 345 | 1-2*SURV_HSPACE, 1-2*SURV_VSPACE]) 346 | 347 | # initialize log-rank test result 348 | LogRank = None 349 | 350 | if Types[i] == 'Clinical': 351 | 352 | # get unique values to determine if binary or continuous 353 | Unique = np.unique(Raw[:, i]) 354 | 355 | # process based on variable type 356 | if Unique.size == 2: 357 | 358 | # extract and plot mutant and wild-type survival profiles 359 | if np.sum(Raw[:, i] == Unique[0]): 360 | kmf.fit(Survival[Raw[:, i] == Unique[0]], 361 | 1-Censored[Raw[:, i] == Unique[0]] == 1, 362 | label=Symbols[i] + str(Unique[0])) 363 | kmf.plot(ax=Axes, show_censors=True) 364 | if np.sum(Raw[:, i] == Unique[1]): 365 | kmf.fit(Survival[Raw[:, i] == Unique[1]], 366 | 1-Censored[Raw[:, i] == Unique[1]] == 1, 367 | label=Symbols[i] + str(Unique[1])) 368 | kmf.plot(ax=Axes, show_censors=True) 369 | if np.sum(Raw[:, i] == Unique[0]) & \ 370 | np.sum(Raw[:, i] == Unique[1]): 371 | LogRank = logrank_test(Survival[Raw[:, i] == Unique[0]], 372 | Survival[Raw[:, i] == Unique[1]], 373 | 1-Censored[Raw[:, i] == Unique[0]] 374 | == 1, 375 | 1-Censored[Raw[:, i] == Unique[1]] 376 | == 1) 377 | plt.ylim(0, 1) 378 | if LogRank is not None: 379 | plt.title('Logrank p=' + str(LogRank.p_value)) 380 | lg = plt.gca().get_legend() 381 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 382 | 383 | else: 384 | 385 | # determine median value 386 | Median = np.median(Raw[:, i]) 387 | 388 | # extract and altered and unaltered survival profiles 389 | if np.sum(Raw[:, i] > Median): 390 | kmf.fit(Survival[Raw[:, i] > Median], 391 | 1-Censored[Raw[:, i] > Median] == 1, 392 | label=Symbols[i] + " > " + str(Median)) 393 | kmf.plot(ax=Axes, show_censors=True) 394 | if np.sum(Raw[:, i] <= Median): 395 | kmf.fit(Survival[Raw[:, i] <= Median], 396 | 1-Censored[Raw[:, i] <= Median] == 1, 397 | label=Symbols[i] + " <= " + str(Median)) 398 | kmf.plot(ax=Axes, show_censors=True) 399 | if np.sum(Raw[:, i] > Median) & np.sum(Raw[:, i] <= Median): 400 | LogRank = logrank_test(Survival[Raw[:, i] > Median], 401 | Survival[Raw[:, i] <= Median], 402 | 1-Censored[Raw[:, i] > Median] 403 | == 1, 404 | 1-Censored[Raw[:, i] <= Median] 405 | == 1) 406 | plt.ylim(0, 1) 407 | if LogRank is not None: 408 | plt.title('Logrank p=' + str(LogRank.p_value)) 409 | lg = plt.gca().get_legend() 410 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 411 | 412 | elif Types[i] == 'Mut': 413 | 414 | # extract and plot mutant and wild-type survival profiles 415 | if np.sum(Raw[:, i] == 1): 416 | kmf.fit(Survival[Raw[:, i] == 1], 417 | 1-Censored[Raw[:, i] == 1] == 1, 418 | label=Symbols[i] + " Mutant") 419 | kmf.plot(ax=Axes, show_censors=True) 420 | if np.sum(Raw[:, i] == 0): 421 | kmf.fit(Survival[Raw[:, i] == 0], 422 | 1-Censored[Raw[:, i] == 0] == 1, 423 | label=Symbols[i] + " WT") 424 | kmf.plot(ax=Axes, show_censors=True) 425 | if np.sum(Raw[:, i] == 1) & np.sum(Raw[:, i] == 0): 426 | LogRank = logrank_test(Survival[Raw[:, i] == 0], 427 | Survival[Raw[:, i] == 1], 428 | 1-Censored[Raw[:, i] == 0] == 1, 429 | 1-Censored[Raw[:, i] == 1] == 1) 430 | plt.ylim(0, 1) 431 | lg = 
plt.gca().get_legend() 432 | if LogRank is not None: 433 | plt.title('Logrank p=' + str(LogRank.p_value)) 434 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 435 | 436 | elif Types[i] == 'CNV': 437 | 438 | # determine if alteration is amplification or deletion 439 | Amplified = np.mean(Raw[:, i]) > 0 440 | 441 | # extract and plot altered and unaltered survival profiles 442 | if Amplified: 443 | kmf.fit(Survival[Raw[:, i] > 0], 444 | 1-Censored[Raw[:, i] > 0] == 1, 445 | label=Symbols[i] + " " + Types[i] + " Amplified") 446 | kmf.plot(ax=Axes, show_censors=True) 447 | if(np.sum(Raw[:, i] <= 0)): 448 | kmf.fit(Survival[Raw[:, i] <= 0], 449 | 1-Censored[Raw[:, i] <= 0] == 1, 450 | label=Symbols[i] + " " + Types[i] + 451 | " not Amplified") 452 | kmf.plot(ax=Axes, show_censors=True) 453 | LogRank = logrank_test(Survival[Raw[:, i] > 0], 454 | Survival[Raw[:, i] <= 0], 455 | 1-Censored[Raw[:, i] > 0] == 1, 456 | 1-Censored[Raw[:, i] <= 0] == 1) 457 | else: 458 | kmf.fit(Survival[Raw[:, i] < 0], 459 | 1-Censored[Raw[:, i] < 0] == 1, 460 | label=Symbols[i] + " " + Types[i] + " Deleted") 461 | kmf.plot(ax=Axes, show_censors=True) 462 | if(np.sum(Raw[:, i] >= 0)): 463 | kmf.fit(Survival[Raw[:, i] >= 0], 464 | 1-Censored[Raw[:, i] >= 0] == 1, 465 | label=Symbols[i] + " " + Types[i] + " not Deleted") 466 | kmf.plot(ax=Axes, show_censors=True) 467 | LogRank = logrank_test(Survival[Raw[:, i] < 0], 468 | Survival[Raw[:, i] >= 0], 469 | 1-Censored[Raw[:, i] < 0] == 1, 470 | 1-Censored[Raw[:, i] >= 0] == 1) 471 | if LogRank is not None: 472 | plt.title('Logrank p=' + str(LogRank.p_value)) 473 | plt.ylim(0, 1) 474 | lg = plt.gca().get_legend() 475 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 476 | 477 | elif Types[i] == 'CNVArm': 478 | 479 | # determine if alteration is amplification or deletion 480 | Amplified = np.mean(Raw[:, i]) > 0 481 | 482 | # extract and plot altered and unaltered survival profiles 483 | if Amplified: 484 | if(np.sum(Raw[:, i] > 0.25)): 485 | kmf.fit(Survival[Raw[:, i] > 0.25], 486 | 1-Censored[Raw[:, i] > 0.25] == 1, 487 | label=Symbols[i] + " " + Types[i] + " Amplified") 488 | kmf.plot(ax=Axes, show_censors=True) 489 | if(np.sum(Raw[:, i] <= 0.25)): 490 | kmf.fit(Survival[Raw[:, i] <= 0.25], 491 | 1-Censored[Raw[:, i] <= 0.25] == 1, 492 | label=Symbols[i] + " " + Types[i] + 493 | " not Amplified") 494 | kmf.plot(ax=Axes, show_censors=True) 495 | if(np.sum(Raw[:, i] > 0.25) & np.sum(Raw[:, i] <= 0.25)): 496 | LogRank = logrank_test(Survival[Raw[:, i] > 0.25], 497 | Survival[Raw[:, i] <= 0.25], 498 | 1-Censored[Raw[:, i] > 0.25] == 1, 499 | 1-Censored[Raw[:, i] <= 0.25] == 1) 500 | else: 501 | if np.sum(Raw[:, i] < -0.25): 502 | kmf.fit(Survival[Raw[:, i] < -0.25], 503 | 1-Censored[Raw[:, i] < -0.25] == 1, 504 | label=Symbols[i] + " " + Types[i] + " Deleted") 505 | kmf.plot(ax=Axes, show_censors=True) 506 | if np.sum(Raw[:, i] >= -0.25): 507 | kmf.fit(Survival[Raw[:, i] >= -0.25], 508 | 1-Censored[Raw[:, i] >= -0.25] == 1, 509 | label=Symbols[i] + " " + Types[i] + " not Deleted") 510 | kmf.plot(ax=Axes, show_censors=True) 511 | if np.sum(Raw[:, i] < -0.25) & np.sum(Raw[:, i] >= -0.25): 512 | LogRank = logrank_test(Survival[Raw[:, i] < -0.25], 513 | Survival[Raw[:, i] >= -0.25], 514 | 1-Censored[Raw[:, i] < -0.25] == 1, 515 | 1-Censored[Raw[:, i] >= -0.25] == 1) 516 | plt.ylim(0, 1) 517 | lg = plt.gca().get_legend() 518 | if LogRank is not None: 519 | plt.title('Logrank p=' + str(LogRank.p_value)) 520 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 521 | 522 | elif (Types[i] == 
'Protein') or (Types[i] == 'mRNA'): 523 | 524 | # determine median expression 525 | Median = np.median(Raw[:, i]) 526 | 527 | # extract and altered and unaltered survival profiles 528 | if np.sum(Raw[:, i] > Median): 529 | kmf.fit(Survival[Raw[:, i] > Median], 530 | 1-Censored[Raw[:, i] > Median] == 1, 531 | label=Symbols[i] + " " + Types[i] + 532 | " Higher Expression") 533 | kmf.plot(ax=Axes, show_censors=True) 534 | if np.sum(Raw[:, i] <= Median): 535 | kmf.fit(Survival[Raw[:, i] <= Median], 536 | 1-Censored[Raw[:, i] <= Median] == 1, 537 | label=Symbols[i] + " " + Types[i] + 538 | " Lower Expression") 539 | kmf.plot(ax=Axes, show_censors=True) 540 | if np.sum(Raw[:, i] > Median) & np.sum(Raw[:, i] <= Median): 541 | LogRank = logrank_test(Survival[Raw[:, i] > Median], 542 | Survival[Raw[:, i] <= Median], 543 | 1-Censored[Raw[:, i] > Median] == 1, 544 | 1-Censored[Raw[:, i] <= Median] == 1) 545 | plt.ylim(0, 1) 546 | if LogRank is not None: 547 | plt.title('Logrank p=' + str(LogRank.p_value)) 548 | lg = plt.gca().get_legend() 549 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 550 | 551 | elif (Types[i] == 'PATHWAY'): 552 | 553 | # determine median expression 554 | Median = np.median(Raw[:, i]) 555 | 556 | # extract and altered and unaltered survival profiles 557 | if np.sum(Raw[:, i] > Median): 558 | kmf.fit(Survival[Raw[:, i] > Median], 559 | 1-Censored[Raw[:, i] > Median] == 1, 560 | label=Symbols[i] + " Higher Enrichment") 561 | kmf.plot(ax=Axes, show_censors=True) 562 | if np.sum(Raw[:, i] <= Median): 563 | kmf.fit(Survival[Raw[:, i] <= Median], 564 | 1-Censored[Raw[:, i] <= Median] == 1, 565 | label=Symbols[i] + " Lower Enrichment") 566 | kmf.plot(ax=Axes, show_censors=True) 567 | if np.sum(Raw[:, i] > Median) & np.sum(Raw[:, i] <= Median): 568 | LogRank = logrank_test(Survival[Raw[:, i] > Median], 569 | Survival[Raw[:, i] <= Median], 570 | 1-Censored[Raw[:, i] > Median] == 1, 571 | 1-Censored[Raw[:, i] <= Median] == 1) 572 | plt.ylim(0, 1) 573 | if LogRank is not None: 574 | plt.title('Logrank p=' + str(LogRank.p_value)) 575 | lg = plt.gca().get_legend() 576 | plt.setp(lg.get_texts(), fontsize=SURV_FONT) 577 | 578 | else: 579 | raise ValueError('Unrecognized feature type ' + '"' + 580 | Types[i] + '"') 581 | 582 | return Figures 583 | 584 | 585 | def _SplitSymbols(Symbols): 586 | """ 587 | Removes trailing and leading whitespace, separates feature types from 588 | feature names, enumerates duplicate symbol names 589 | """ 590 | 591 | # modify duplicate symbols where needed - append index to each instance 592 | Prefix = [Symbol[0:str.rfind(str(Symbol), '_')] for Symbol in Symbols] 593 | Types = [Symbol[str.rfind(str(Symbol), '_')+1:].strip() 594 | for Symbol in Symbols] 595 | 596 | # copy prefixes 597 | Corrected = Prefix[:] 598 | 599 | # append index to each duplicate instance 600 | for i in np.arange(len(Prefix)): 601 | if Prefix.count(Prefix[i]) > 1: 602 | Corrected[i] = Prefix[i] + '.' 
+ \ 603 | str(Prefix[0:i+1].count(Prefix[i])) 604 | else: 605 | Corrected[i] = Prefix[i] 606 | 607 | return Corrected, Types 608 | 609 | 610 | def _WrapSymbols(Symbols, Length=20): 611 | """ 612 | Wraps long labels 613 | """ 614 | 615 | # remove whitespace and wrap 616 | Corrected = ['\n'.join(wrap(Symbol.strip().replace('_', ' '), Length)) 617 | for Symbol in Symbols] 618 | 619 | return Corrected 620 | -------------------------------------------------------------------------------- /survivalnet/analysis/WriteGCT.py: -------------------------------------------------------------------------------- 1 | def WriteGCT(Genes, Samples, Scores, File): 2 | """ 3 | Writes a gene expression (GCT) file format defining the scores of features 4 | for individual samples. 5 | 6 | Parameters 7 | ---------- 8 | Genes : array_like 9 | An N-length list of gene symbols associated with the rows of 'Scores'. 10 | 11 | Samples : array_like 12 | A K-length list of sample identifiers associated with the columns of 13 | 'Scores'. If None samples will be enumerated. 14 | 15 | Scores : array_like 16 | An NxK-length numpy array containing the signed values associated with 17 | 'Genes' and Samples. 18 | 19 | File : string 20 | Filename and path to write the output .gct file to. 21 | 22 | Notes 23 | ----- 24 | This is typically used to form the input for a SSGSEA analysis. See 25 | http://www.broadinstitute.org/cancer/software/gsea/wiki/index.php for more 26 | details on file formats. 27 | 28 | See Also 29 | -------- 30 | FeatureAnalysis 31 | """ 32 | 33 | # open rnk file 34 | try: 35 | Gct = open(File, 'w') 36 | except IOError: 37 | print "Cannot create file ", File 38 | 39 | # write leading rows 40 | Gct.write('#1.2\n') 41 | Gct.write(str(Scores.shape[1]) + '\t' + str(Scores.shape[0]) + '\n') 42 | Gct.write("NAME\tDescription\t") 43 | if Samples is None: 44 | for i in range(Scores.shape[0]-1): 45 | Gct.write("Sample." + str(i+1) + '\t') 46 | Gct.write("Sample." + str(Scores.shape[0]) + '\n') 47 | else: 48 | for i, Sample in enumerate(Samples): 49 | if i < len(Samples)-1: 50 | Gct.write(Sample + '\t') 51 | else: 52 | Gct.write(Sample + '\n') 53 | 54 | # write contents to file 55 | for i, Symbol in enumerate(Genes): 56 | Gct.write(Symbol + '\t\t') 57 | for j in range(Scores.shape[0]-1): 58 | Gct.write(str(Scores[j, i]) + '\t') 59 | Gct.write(str(Scores[-1, i]) + '\n') 60 | 61 | # close file 62 | Gct.close() 63 | -------------------------------------------------------------------------------- /survivalnet/analysis/WriteRNK.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def WriteRNK(Genes, Scores, File): 5 | """ 6 | Writes a ranked-list (RNK) file format defining the ranks of features. 7 | Features are sorted based on the signed values provided in 'Scores'. 8 | 9 | Parameters 10 | ---------- 11 | Genes : array_like 12 | An N-length list of gene symbols associated with 'Scores'. 13 | 14 | Scores : array_like 15 | An N-length numpy array containing the signed values associated with 16 | 'Genes'. 17 | 18 | File : string 19 | Filename and path to write the output .rnk file to. 20 | 21 | Notes 22 | ----- 23 | This is typically used to form the input for a GSEAPreranked analysis. See 24 | http://www.broadinstitute.org/cancer/software/gsea/wiki/index.php for more 25 | details on file formats. 
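
# Usage sketch (illustration only): write a toy three-gene ranked list to a
# hypothetical output path.
if __name__ == '__main__':
    WriteRNK(['TP53', 'PTEN', 'EGFR'], np.array([0.7, -0.2, 0.1]),
             'Example.rnk')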
--------------------------------------------------------------------------------
/survivalnet/analysis/__init__.py:
--------------------------------------------------------------------------------
from .RiskCohort import RiskCohort

# must be imported after RiskCohort
from .FeatureAnalysis import FeatureAnalysis
from .PathwayAnalysis import PathwayAnalysis
from .Visualization import KMPlots
from .Visualization import PairScatter
from .Visualization import RankedBar
from .Visualization import RankedBox
from .WriteGCT import WriteGCT
from .WriteRNK import WriteRNK
from .ReadGMT import ReadGMT
from .RiskCluster import RiskCluster


# list functions and classes available for public use
__all__ = (
    'FeatureAnalysis',
    'KMPlots',
    'PairScatter',
    'PathwayAnalysis',
    'RankedBar',
    'RankedBox',
    'ReadGMT',
    'RiskCluster',
    'RiskCohort',
    'WriteGCT',
    'WriteRNK',
)
--------------------------------------------------------------------------------
/survivalnet/model/DropoutHiddenLayer.py:
--------------------------------------------------------------------------------
__docformat__ = 'restructuredtext en'

import theano
import theano.tensor as T
from .HiddenLayer import HiddenLayer
from theano.ifelse import ifelse
import numpy as np


class DropoutHiddenLayer(HiddenLayer):
    def __init__(self, rng, input, n_in, n_out, is_train,
                 activation, dropout_rate, mask=None, W=None, b=None):
        super(DropoutHiddenLayer, self).__init__(
            rng=rng, input=input, n_in=n_in, n_out=n_out, W=W, b=b,
            activation=activation)

        self.dropout_rate = dropout_rate
        self.srng = T.shared_randomstreams.RandomStreams(rng.randint(999999))
        self.mask = mask
        self.layer = self.output

        # Computes outputs for the train and test phases, applying dropout
        # when needed. At test time the output is scaled by
        # (1 - dropout_rate) instead of being masked.
        train_output = self.layer * T.cast(self.mask, theano.config.floatX)
        test_output = self.output * (1 - dropout_rate)
        self.output = ifelse(T.eq(is_train, 1), train_output, test_output)
        return
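
# A small numeric sketch (illustration only) of why the test-time output
# above is scaled by (1 - dropout_rate): a Bernoulli mask keeps each unit
# with probability 1 - p, so the scaled output matches the training-time
# expectation.
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    p = 0.1
    h = np.ones(100000)
    mask = rng.binomial(1, 1.0 - p, size=h.shape)
    assert abs((mask * h).mean() - (1.0 - p)) < 0.01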
--------------------------------------------------------------------------------
/survivalnet/model/HiddenLayer.py:
--------------------------------------------------------------------------------
__docformat__ = 'restructuredtext en'

import numpy
import theano
import theano.tensor as T


class HiddenLayer(object):
    def __init__(self, rng, input, n_in, n_out, W=None, b=None,
                 activation=T.tanh):
        """
        Typical hidden layer of a MLP: units are fully-connected and have a
        nonlinear activation function (tanh by default). The weight matrix W
        is of shape (n_in, n_out) and the bias vector b is of shape (n_out,).
        Hidden unit activation is given by: activation(dot(input,W) + b)
        :type rng: numpy.random.RandomState
        :param rng: a random number generator used to initialize weights
        :type input: theano.tensor.dmatrix
        :param input: a symbolic tensor of shape (n_examples, n_in)
        :type n_in: int
        :param n_in: dimensionality of input
        :type n_out: int
        :param n_out: number of hidden units
        :type activation: theano.Op or function
        :param activation: Non linearity to be applied in the hidden
                           layer
        """
        self.input = input
        # `W` is initialized with `W_values` which is uniformly sampled
        # from -sqrt(6./(n_in+n_hidden)) to sqrt(6./(n_in+n_hidden))
        # for the tanh activation function
        # the output of uniform is converted using asarray to dtype
        # theano.config.floatX so that the code is runnable on GPU
        # Note : optimal initialization of weights is dependent on the
        # activation function used (among other things).
        # For example, results presented in [Xavier10] suggest that you
        # should use 4 times larger initial weights for sigmoid
        # compared to tanh
        # We have no info for other functions, so we use the same as
        # tanh.
        if W is None:
            W_values = numpy.asarray(
                rng.uniform(
                    low=-numpy.sqrt(6. / (n_in + n_out)),
                    high=numpy.sqrt(6. / (n_in + n_out)),
                    size=(n_in, n_out)
                ),
                dtype=theano.config.floatX
            )
            if activation == T.nnet.sigmoid:
                W_values *= 4
            W = theano.shared(value=W_values, name='W', borrow=True)

        if b is None:
            b_values = numpy.zeros((n_out,), dtype=theano.config.floatX)
            b = theano.shared(value=b_values, name='b', borrow=True)

        self.W = W
        self.b = b

        lin_output = T.dot(input, self.W) + self.b
        self.output = (
            lin_output if activation is None
            else activation(lin_output)
        )
        # parameters of the model
        self.params = [self.W, self.b]

    def reset_weight(self, params):
        self.W.set_value(params[0])
        self.b.set_value(params[1])

    def reset_weight_by_rate(self, rate):
        if rate != 0:
            self.W.set_value(self.W.get_value() / rate)
            self.b.set_value(self.b.get_value() / rate)
--------------------------------------------------------------------------------
/survivalnet/model/Model.py:
--------------------------------------------------------------------------------
import numpy
import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from .RiskLayer import RiskLayer
from .HiddenLayer import HiddenLayer
from .DropoutHiddenLayer import DropoutHiddenLayer
from .SparseDenoisingAutoencoder import SparseDenoisingAutoencoder as dA
from survivalnet.optimization import Optimization as Opt


class Model(object):
    """This class is made to pretrain and fine-tune a variable number of
    layers."""
    def __init__(
        self,
        numpy_rng,
        theano_rng=None,
        n_ins=183,
        hidden_layers_sizes=[250, 250],
        n_outs=1,
        corruption_levels=[0.1, 0.1],
        dropout_rate=0.1,
        lambda1=0,
        lambda2=0,
        non_lin=None
    ):
        """
        :type numpy_rng: numpy.random.RandomState
        :param numpy_rng: numpy random number generator used to draw initial
                          weights
        :type theano_rng: theano.tensor.shared_randomstreams.RandomStreams
        :param theano_rng: Theano random generator; if None is given one is
                           generated based on a seed drawn from `rng`
        :type n_ins: int
        :param n_ins: dimension of the input to the Model

        :type hidden_layers_sizes: list of ints
        :param hidden_layers_sizes: sizes of intermediate layers.

        :type n_outs: int
        :param n_outs: dimension of the output of the network. Always 1 for a
                       regression problem.

        :type corruption_levels: list of float
        :param corruption_levels: amount of corruption to use for each layer

        :type dropout_rate: float
        :param dropout_rate: probability of dropping a hidden unit

        :type non_lin: function
        :param non_lin: nonlinear activation function used in all layers

        """
        # Initializes parameters.
        self.hidden_layers = []
        self.dA_layers = []
        self.params = []
        self.dropout_masks = []
        self.n_layers = len(hidden_layers_sizes)
        self.L1 = 0
        self.L2_sqr = 0
        self.n_hidden = hidden_layers_sizes[0]
        if not theano_rng:
            theano_rng = RandomStreams(numpy_rng.randint(2 ** 30))

        # Allocates symbolic variables for the data.
        self.x = T.matrix('x', dtype='float32')
        self.o = T.ivector('o')
        self.at_risk = T.ivector('at_risk')
        self.is_train = T.iscalar('is_train')
        self.masks = [T.lmatrix('mask_' + str(i))
                      for i in range(self.n_layers)]

        # Linear Cox regression with no hidden layers.
        if self.n_layers == 0:
            self.risk_layer = RiskLayer(input=self.x, n_in=n_ins,
                                        n_out=n_outs, rng=numpy_rng)
        else:
            # Constructs the intermediate layers.
            for i in xrange(self.n_layers):
                if i == 0:
                    input_size = n_ins
                    layer_input = self.x
                else:
                    input_size = hidden_layers_sizes[i - 1]
                    layer_input = self.hidden_layers[-1].output

                if dropout_rate > 0:
                    hidden_layer = DropoutHiddenLayer(
                        rng=numpy_rng,
                        input=layer_input,
                        n_in=input_size,
                        n_out=hidden_layers_sizes[i],
                        activation=non_lin,
                        dropout_rate=dropout_rate,
                        is_train=self.is_train,
                        mask=self.masks[i])
                else:
                    hidden_layer = HiddenLayer(rng=numpy_rng,
                                               input=layer_input,
                                               n_in=input_size,
                                               n_out=hidden_layers_sizes[i],
                                               activation=non_lin)

                # Adds the layer to the stack of layers.
                self.hidden_layers.append(hidden_layer)
                self.params.extend(hidden_layer.params)

                # Constructs an autoencoder that shares weights with this
                # layer.
                dA_layer = dA(numpy_rng=numpy_rng,
                              theano_rng=theano_rng,
                              input=layer_input,
                              n_visible=input_size,
                              n_hidden=hidden_layers_sizes[i],
                              W=hidden_layer.W,
                              bhid=hidden_layer.b,
                              non_lin=non_lin)
                self.dA_layers.append(dA_layer)

                self.L1 += abs(hidden_layer.W).sum()
                self.L2_sqr += (hidden_layer.W ** 2).sum()

            # Adds a risk prediction layer on top of the stack.
            self.risk_layer = RiskLayer(input=self.hidden_layers[-1].output,
                                        n_in=hidden_layers_sizes[-1],
                                        n_out=n_outs,
                                        rng=numpy_rng)

        self.L1 += abs(self.risk_layer.W).sum()
        self.L2_sqr += (self.risk_layer.W ** 2).sum()
        self.params.extend(self.risk_layer.params)
        self.regularizers = lambda1 * self.L1 + lambda2 * self.L2_sqr

    def pretraining_functions(self, pretrain_x, batch_size):
        index = T.lscalar('index')  # index to a minibatch
        corruption_level = T.scalar('corruption')  # % of corruption
        learning_rate = T.scalar('lr')  # learning rate

        if batch_size:
            # beginning of a batch, given `index`
            batch_begin = index * batch_size
            # ending of a batch given `index`
            batch_end = batch_begin + batch_size
            pretrain_x = pretrain_x[batch_begin: batch_end]

        pretrain_fns = []
        is_train = numpy.cast['int32'](0)  # value does not matter
        for dA_layer in self.dA_layers:
            # get the cost and the updates list
            cost, updates = dA_layer.get_cost_updates(corruption_level,
                                                      learning_rate)
            # compile the theano function
            fn = theano.function(
                on_unused_input='ignore',
                inputs=[
                    index,
                    theano.Param(corruption_level, default=0.2),
                    theano.Param(learning_rate, default=0.1)
                ],
                outputs=cost,
                updates=updates,
                givens={
                    self.x: pretrain_x,
                    self.is_train: is_train
                }
            )
            pretrain_fns.append(fn)

        return pretrain_fns

    def build_finetune_functions(self, learning_rate):

        is_train = T.iscalar('is_train')
        X = T.matrix('X', dtype='float32')
        at_risk = T.ivector('at_risk')
        observed = T.ivector('observed')
        opt = Opt()

        test = theano.function(
            on_unused_input='ignore',
            inputs=[X, observed, at_risk, is_train] + self.masks,
            outputs=[self.risk_layer.cost(self.o, self.at_risk),
                     self.risk_layer.output, self.risk_layer.input],
            givens={
                self.x: X,
                self.o: observed,
                self.at_risk: at_risk,
                self.is_train: is_train
            },
            name='test'
        )
        train = theano.function(
            on_unused_input='ignore',
            inputs=[X, observed, at_risk, is_train] + self.masks,
            outputs=[self.risk_layer.cost(self.o, self.at_risk),
                     self.risk_layer.output, self.risk_layer.input],
            updates=opt.SGD(
                self.risk_layer.cost(self.o, self.at_risk)-self.regularizers,
                self.params, learning_rate),
            givens={
                self.x: X,
                self.o: observed,
                self.at_risk: at_risk,
                self.is_train: is_train
            },
            name='train'
        )
        return test, train

    def reset_weight(self, params):
        for i in xrange(self.n_layers):
            self.hidden_layers[i].reset_weight((params[2*i], params[2*i+1]))
        self.risk_layer.reset_weight(params[-1])

    def reset_weight_by_rate(self, rate):
        for i in xrange(self.n_layers):
            self.hidden_layers[i].reset_weight_by_rate(rate)

    def update_layers(self):
        for l in self.hidden_layers:
            l.update_layer()
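
# Construction sketch (illustration only; assumes Theano is installed and the
# survivalnet package is importable, and builds the symbolic graph without
# training anything):
if __name__ == '__main__':
    model = Model(numpy_rng=numpy.random.RandomState(0),
                  n_ins=183,
                  hidden_layers_sizes=[250, 250],
                  n_outs=1,
                  corruption_levels=[0.1, 0.1],
                  dropout_rate=0.1,
                  non_lin=T.tanh)
    test_fn, train_fn = model.build_finetune_functions(learning_rate=1e-3)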
--------------------------------------------------------------------------------
/survivalnet/model/RiskLayer.py:
--------------------------------------------------------------------------------
import numpy
import theano
import theano.tensor as T
import theano.tensor.extra_ops as Te


class RiskLayer(object):
    def __init__(self, input, n_in, n_out, rng):
        # Randomly initializes the weights W as a matrix of shape
        # (n_in, n_out).
        self.W = theano.shared(
            value=numpy.asarray(
                rng.uniform(
                    low=-numpy.sqrt(6. / (n_in + n_out)),
                    high=numpy.sqrt(6. / (n_in + n_out)),
                    size=(n_in, n_out)
                ),
                dtype=theano.config.floatX
            ),
            name='W',
            borrow=True
        )

        self.input = input
        self.output = T.dot(self.input, self.W).flatten()
        self.params = [self.W]

    def cost(self, observed, at_risk):
        """Calculates the Cox negative log likelihood.

        Args:
            observed: 1D array. Event status; 0 means censored.
            at_risk: 1D array. Element i of this array indicates the index of
                the first patient in the at risk group of patient i, when
                patients are sorted by increasing time to event.
        Returns:
            Objective function to be maximized.
        """
        prediction = self.output
        # Subtracts maximum to facilitate computation.
        factorizedPred = prediction - prediction.max()
        exp = T.exp(factorizedPred)[::-1]
        # Calculates the reversed partial cumulative sum.
        partial_sum = Te.cumsum(exp)[::-1] + 1
        # Adds the subtracted maximum back.
        log_at_risk = T.log(partial_sum[at_risk]) + prediction.max()
        diff = prediction - log_at_risk
        cost = T.sum(T.dot(observed, diff))
        return cost

    def reset_weight(self, params):
        self.W.set_value(params)
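
# A numpy sketch (illustration only) of the same partial-likelihood
# computation, for readers who want to check `cost` outside of Theano.
# `at_risk[i]` indexes the first patient of patient i's risk group when
# patients are sorted by increasing time to event. The Theano version above
# additionally subtracts and restores the prediction maximum and adds 1 to
# the cumulative sum for numerical stability, and returns the quantity to be
# maximized rather than minimized.
if __name__ == '__main__':
    pred = numpy.array([0.2, -0.1, 0.5, 0.05])  # sorted by time to event
    observed = numpy.array([1, 0, 1, 1])        # 0 means censored
    at_risk = numpy.array([0, 1, 2, 3])
    # suffix sums of exp(pred): the risk-set denominators
    partial_sum = numpy.cumsum(numpy.exp(pred)[::-1])[::-1]
    log_at_risk = numpy.log(partial_sum[at_risk])
    objective = numpy.sum(observed * (pred - log_at_risk))
    print 'Cox log partial likelihood = %f' % objective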
45 |         :type numpy_rng: numpy.random.RandomState
46 |         :param numpy_rng: numpy random number generator used to generate weights
47 |         :type theano_rng: theano.tensor.shared_randomstreams.RandomStreams
48 |         :param theano_rng: Theano random generator; if None is given one is
49 |                            generated based on a seed drawn from `rng`
50 |         :type input: theano.tensor.TensorType
51 |         :param input: a symbolic description of the input or None for
52 |                       standalone dA
53 |         :type n_visible: int
54 |         :param n_visible: number of visible units
55 |         :type n_hidden: int
56 |         :param n_hidden: number of hidden units
57 |         :type W: theano.tensor.TensorType
58 |         :param W: Theano variable pointing to a set of weights that should be
59 |                   shared between the dA and another architecture; if dA should
60 |                   be standalone set this to None
61 |         :type bhid: theano.tensor.TensorType
62 |         :param bhid: Theano variable pointing to a set of bias values (for
63 |                      hidden units) that should be shared between the dA and
64 |                      another architecture; if dA should be standalone set this to None
65 |         :type bvis: theano.tensor.TensorType
66 |         :param bvis: Theano variable pointing to a set of bias values (for
67 |                      visible units) that should be shared between the dA and
68 |                      another architecture; if dA should be standalone set this to None
69 |         :type ce: boolean
70 |         :param ce: Boolean determining whether to use cross entropy or
71 |                    mean squared error for the cost
72 | 
73 |         """
74 |         self.non_lin = non_lin
75 |         self.n_visible = n_visible
76 |         self.n_hidden = n_hidden
77 |         self.ce = ce
78 |         # create a Theano random generator that gives symbolic random values
79 |         if not theano_rng:
80 |             theano_rng = RandomStreams(numpy_rng.randint(2 ** 30))
81 | 
82 |         # note : W' was written as `W_prime` and b' as `b_prime`
83 |         if not W:
84 |             # W is initialized with `initial_W`, which is uniformly sampled
85 |             # from -4*sqrt(6./(n_visible+n_hidden)) to
86 |             # 4*sqrt(6./(n_hidden+n_visible)); the output of uniform is
87 |             # converted using asarray to dtype
88 |             # theano.config.floatX so that the code is runnable on GPU
89 |             initial_W = numpy.asarray(
90 |                 numpy_rng.uniform(
91 |                     low=-4 * numpy.sqrt(6. / (n_hidden + n_visible)),
92 |                     high=4 * numpy.sqrt(6. / (n_hidden + n_visible)),
93 |                     size=(n_visible, n_hidden)
94 |                 ),
95 |                 dtype=theano.config.floatX
96 |             )
97 |             W = theano.shared(value=initial_W, name='W', borrow=True)
98 | 
99 |         if not bvis:
100 |             bvis = theano.shared(
101 |                 value=numpy.zeros(
102 |                     n_visible,
103 |                     dtype=theano.config.floatX
104 |                 ),
105 |                 borrow=True
106 |             )
107 | 
108 |         if not bhid:
109 |             bhid = theano.shared(
110 |                 value=numpy.zeros(
111 |                     n_hidden,
112 |                     dtype=theano.config.floatX
113 |                 ),
114 |                 name='b',
115 |                 borrow=True
116 |             )
117 | 
118 |         self.W = W
119 |         # b corresponds to the bias of the hidden units
120 |         self.b = bhid
121 |         # b_prime corresponds to the bias of the visible units
122 |         self.b_prime = bvis
123 |         # tied weights, therefore W_prime is W transpose
124 |         self.W_prime = self.W.T
125 |         self.theano_rng = theano_rng
126 |         # if no input is given, generate a variable representing the input
127 |         if input is None:
128 |             # we use a matrix because we expect a minibatch of several
129 |             # examples, each example being a row
130 |             self.x = T.dmatrix(name='input')
131 |         else:
132 |             self.x = input
133 | 
134 |         self.params = [self.W, self.b, self.b_prime]
135 | 
136 |     def get_corrupted_input(self, input, corruption_level):
137 |         """Keeps ``1 - corruption_level`` entries of the input unchanged and
138 |         zeroes out a randomly selected subset of size ``corruption_level``.
139 |         Note : the first argument of theano_rng.binomial is the shape (size) of
140 |               the random numbers it should produce,
141 |               the second argument is the number of trials, and
142 |               the third argument is the probability of success of any trial;
143 |               this will produce an array of 0s and 1s, where 1 has a
144 |               probability of 1 - ``corruption_level`` and 0 has a probability of
145 |               ``corruption_level``.
146 |               The binomial function returns the int64 data type by
147 |               default. int64 multiplied by the input
148 |               type (floatX) always returns float64. To keep all data
149 |               in floatX when floatX is float32, we set the dtype of
150 |               the binomial to floatX. As in our case the value of
151 |               the binomial is always 0 or 1, this doesn't change the
152 |               result. This is needed to allow the GPU to work
153 |               correctly, as it only supports float32 for now.
154 |         """
155 |         return self.theano_rng.binomial(size=input.shape, n=1,
156 |                                         p=1 - corruption_level,
157 |                                         dtype=theano.config.floatX) * input
158 | 
159 |     def get_hidden_values(self, input):
160 |         """ Computes the values of the hidden layer """
161 |         return self.non_lin(T.dot(input, self.W) + self.b)
162 | 
163 |     def get_reconstructed_input(self, hidden):
164 |         """Computes the reconstructed input given the values of the
165 |         hidden layer
166 |         """
167 |         return self.non_lin(T.dot(hidden, self.W_prime) + self.b_prime)
168 | 
169 |     def get_cost_updates(self, corruption_level, learning_rate):
170 |         """ This function computes the cost and the updates for one training
171 |         step of the dA """
172 | 
173 |         tilde_x = self.get_corrupted_input(self.x, corruption_level)
174 |         y = self.get_hidden_values(tilde_x)
175 |         z = self.get_reconstructed_input(y)
176 |         # note : we sum over the size of a datapoint; if we are using
177 |         #        minibatches, L will be a vector, with one entry per
178 |         #        example in the minibatch
179 |         if (self.ce):
180 |             L = - T.sum(self.x * T.log(z) + (1 - self.x) * T.log(1 - z), axis=1)
181 |         else:
182 |             L = T.sum((self.x - z) ** 2, axis=1)
183 |         # note : L is now a vector, where each element is the
184 |         # cross-entropy or mean squared error cost of the reconstruction of the
185 |         # corresponding example of the minibatch.
We need to 186 | # compute the average of all these to get the cost of 187 | # the minibatch 188 | cost = T.mean(L) 189 | 190 | # compute the gradients of the cost of the `dA` with respect 191 | # to its parameters 192 | gparams = T.grad(cost, self.params) 193 | # generate the list of updates 194 | updates = [ 195 | (param, param - learning_rate * gparam) 196 | for param, gparam in zip(self.params, gparams) 197 | ] 198 | 199 | return (cost, updates) 200 | -------------------------------------------------------------------------------- /survivalnet/model/__init__.py: -------------------------------------------------------------------------------- 1 | # import sub-packages to support nested calls 2 | from .HiddenLayer import HiddenLayer 3 | from .DropoutHiddenLayer import DropoutHiddenLayer 4 | from .RiskLayer import RiskLayer 5 | from .SparseDenoisingAutoencoder import SparseDenoisingAutoencoder 6 | from .Model import Model 7 | 8 | # list functions and classes available for public use 9 | __all__ = ( 10 | 'HiddenLayer', 11 | 'DropoutHiddenLayer', 12 | 'RiskLayer', 13 | 'SparseDenoisingAutoencoder', 14 | 'Model', 15 | ) 16 | -------------------------------------------------------------------------------- /survivalnet/optimization/BFGS.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import time 3 | import warnings 4 | import numpy 5 | import theano.tensor as T 6 | import scipy 7 | from scipy.optimize.optimize import _line_search_wolfe12, _LineSearchError 8 | 9 | 10 | class BFGS(object): 11 | 12 | def __init__(self, model, x, o, atrisk): 13 | self.cost = model.risk_layer.cost 14 | self.params = model.params 15 | is_tr = T.iscalar('is_train') 16 | self.theta_shape = sum([self.params[i].get_value().size for i in range(len(self.params))]) 17 | self.old_old_fval = None 18 | N = self.theta_shape 19 | self.H_t = numpy.eye(N, dtype=numpy.float32) 20 | 21 | self.theta = theano.shared(value=numpy.zeros(self.theta_shape, dtype=theano.config.floatX)) 22 | self.theta.set_value(numpy.concatenate([e.get_value().ravel() for e in 23 | self.params]), borrow = "true") 24 | 25 | self.gradient = theano.function(on_unused_input='ignore', 26 | inputs=[is_tr] + model.masks, 27 | outputs = T.grad(self.cost(o, atrisk), self.params), 28 | givens = {model.x:x, model.o:o, model.at_risk:atrisk, model.is_train:is_tr}, 29 | name='gradient') 30 | self.cost_func = theano.function(on_unused_input='ignore', 31 | inputs=[is_tr] + model.masks, 32 | outputs = self.cost(o, atrisk), 33 | givens = {model.x:x, model.o:o, model.at_risk:atrisk, model.is_train:is_tr}, 34 | name='cost_func') 35 | 36 | def f(self, theta_val): 37 | self.theta.set_value(theta_val) 38 | idx = 0 39 | for i in range(len(self.params)): 40 | p = self.theta.get_value()[idx:idx + self.params[i].get_value().size] 41 | p = p.reshape(self.params[i].get_value().shape) 42 | idx += self.params[i].get_value().size 43 | self.params[i].set_value(p) 44 | 45 | c = -self.cost_func(1, *self.masks) 46 | return c 47 | 48 | def fprime(self, theta_val): 49 | self.theta.set_value(theta_val) 50 | idx = 0 51 | for i in range(len(self.params)): 52 | p = self.theta.get_value()[idx:idx + self.params[i].get_value().size] 53 | p = p.reshape(self.params[i].get_value().shape) 54 | idx += self.params[i].get_value().size 55 | self.params[i].set_value(p) 56 | 57 | gs = self.gradient(1, *self.masks) 58 | gf = numpy.concatenate([g.ravel() for g in gs]) 59 | return -gf 60 | 61 | def bfgs_min(self, f, x0, fprime): 62 | self.theta_t = x0 63 
| self.old_fval = f(self.theta_t) 64 | self.gf_t = fprime(x0) 65 | self.rho_t = -numpy.dot(self.H_t, self.gf_t) 66 | 67 | try: 68 | self.eps_t, fc, gc, self.old_fval, self.old_old_fval, gf_next = \ 69 | _line_search_wolfe12(f, fprime, self.theta_t, self.rho_t, self.gf_t, 70 | self.old_fval, self.old_old_fval, amin=1e-100, amax=1e100) 71 | except _LineSearchError: 72 | print 'Line search failed to find a better solution.\n' 73 | theta_next = self.theta_t + self.gf_t * .0001 74 | return theta_next 75 | 76 | theta_next = self.theta_t + self.eps_t * self.rho_t 77 | 78 | delta_t = theta_next - self.theta_t 79 | self.theta_t = theta_next 80 | self.phi_t = gf_next - self.gf_t 81 | self.gf_t = gf_next 82 | denom = 1.0 / (numpy.dot(self.phi_t, delta_t)) 83 | 84 | ## Memory intensive computation based on Wright and Nocedal 'Numerical Optimization', 1999, pg. 198. 85 | #I = numpy.eye(len(x0), dtype=int) 86 | #A = I - self.phi_t[:, numpy.newaxis] * delta_t[numpy.newaxis, :] * denom 87 | ## Estimating H. 88 | #self.H_t[...] = numpy.dot(self.H_t, A) 89 | #A[...] = I - delta_t[:, numpy.newaxis] * self.phi_t[numpy.newaxis, :] * denom 90 | #self.H_t[...] = numpy.dot(A, self.H_t) + (denom * delta_t[:, numpy.newaxis] * 91 | # delta_t[numpy.newaxis, :]) 92 | #A = None 93 | 94 | # Fast memory friendly calculation after simplifiation of the above. 95 | Z = numpy.dot(self.H_t, self.phi_t) 96 | self.H_t -= denom * Z[:, numpy.newaxis] * delta_t[numpy.newaxis,:] 97 | self.H_t -= denom * delta_t[:, numpy.newaxis] * Z[numpy.newaxis, :] 98 | self.H_t += denom * denom * numpy.dot(self.phi_t, Z) * delta_t[:, numpy.newaxis] * delta_t[numpy.newaxis,:] 99 | return theta_next 100 | 101 | def BFGS(self, masks): 102 | self.masks = masks 103 | of = self.bfgs_min 104 | theta_val = of(f=self.f, x0=self.theta.get_value(), fprime=self.fprime) 105 | self.theta.set_value(theta_val) 106 | idx = 0 107 | for i in range(len(self.params)): 108 | p = self.theta.get_value()[idx:idx + self.params[i].get_value().size] 109 | p = p.reshape(self.params[i].get_value().shape) 110 | idx += self.params[i].get_value().size 111 | self.params[i].set_value(p) 112 | return 113 | -------------------------------------------------------------------------------- /survivalnet/optimization/EarlyStopping.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jun 5 22:35:05 2016 4 | 5 | @author: Safoora Yousefi 6 | """ 7 | import numpy as np 8 | def isOverfitting(results, interval=5, num_intervals = 3): 9 | flag = True 10 | end = len(results) 11 | maxIter = len(results)-1 12 | for i in range(num_intervals - 1): 13 | begin1 = end - (i + 1) * interval 14 | end1 = end - (i) * interval 15 | begin2 = end - (i + 2) * interval 16 | if np.mean(results[begin1:end1]) > np.mean(results[begin2:begin1]): 17 | flag = False 18 | if flag: 19 | maxIter = np.argmax(results[begin2:end1]) + end - interval * num_intervals 20 | return flag, maxIter 21 | -------------------------------------------------------------------------------- /survivalnet/optimization/GDLS.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import numpy 3 | import theano.tensor as T 4 | from scipy.optimize.optimize import _line_search_wolfe12, _LineSearchError 5 | 6 | 7 | class GDLS(object): 8 | 9 | def __init__(self, model, x, o, atrisk): 10 | self.cost = model.risk_layer.cost 11 | self.params = model.params 12 | is_tr = T.iscalar('is_train') 13 | self.stop = False 14 | 
self.theta_shape = sum([self.params[i].get_value().size for i in range(len(self.params))]) 15 | self.old_old_fval = None 16 | N = self.theta_shape 17 | self.theta = theano.shared(value=numpy.zeros(self.theta_shape, dtype=theano.config.floatX)) 18 | self.theta.set_value(numpy.concatenate([e.get_value().ravel() for e in self.params]), borrow = "true") 19 | 20 | self.gradient = theano.function(on_unused_input='ignore', 21 | inputs=[is_tr] + model.masks, 22 | outputs = T.grad(self.cost(o, atrisk) - model.L1 - model.L2_sqr, self.params), 23 | givens = {model.x:x, model.o:o, model.at_risk:atrisk, model.is_train:is_tr}, 24 | name='gradient') 25 | self.cost_func = theano.function(on_unused_input='ignore', 26 | inputs=[is_tr] + model.masks, 27 | outputs = self.cost(o, atrisk) - model.L1 - model.L2_sqr, 28 | givens = {model.x:x, model.o:o, model.at_risk:atrisk, model.is_train:is_tr}, 29 | name='cost_func') 30 | 31 | def f(self, theta_val): 32 | self.theta.set_value(theta_val) 33 | idx = 0 34 | for i in range(len(self.params)): 35 | p = self.theta.get_value()[idx:idx + self.params[i].get_value().size] 36 | p = p.reshape(self.params[i].get_value().shape) 37 | idx += self.params[i].get_value().size 38 | self.params[i].set_value(p) 39 | 40 | c = -self.cost_func(1, *self.masks) 41 | return c 42 | 43 | def fprime(self, theta_val): 44 | self.theta.set_value(theta_val) 45 | idx = 0 46 | for i in range(len(self.params)): 47 | p = self.theta.get_value()[idx:idx + self.params[i].get_value().size] 48 | p = p.reshape(self.params[i].get_value().shape) 49 | idx += self.params[i].get_value().size 50 | self.params[i].set_value(p) 51 | 52 | gs = self.gradient(1, *self.masks) 53 | gf = numpy.concatenate([g.ravel() for g in gs]) 54 | return -gf 55 | 56 | #Gradient Descent with line search 57 | def gd_ls(self, f, x0, fprime): 58 | self.theta_t = x0 59 | self.old_fval = f(self.theta_t) 60 | self.gf_t = fprime(x0) 61 | self.rho_t = -self.gf_t 62 | try: 63 | self.eps_t, fc, gc, self.old_fval, self.old_old_fval, gf_next = \ 64 | _line_search_wolfe12(f, fprime, self.theta_t, self.rho_t, self.gf_t, 65 | self.old_fval, self.old_old_fval, amin=1e-100, amax=1e100) 66 | except _LineSearchError: 67 | print 'Line search failed to find a better solution.\n' 68 | self.stop = True 69 | theta_next = self.theta_t + self.gf_t * .00001 70 | return theta_next 71 | theta_next = self.theta_t + self.eps_t * self.rho_t 72 | return theta_next 73 | 74 | def GDLS(self, masks): 75 | self.masks = masks 76 | of = self.gd_ls 77 | theta_val = of(f=self.f, x0=self.theta.get_value(), fprime=self.fprime) 78 | self.theta.set_value(theta_val) 79 | idx = 0 80 | for i in range(len(self.params)): 81 | p = self.theta.get_value()[idx:idx + self.params[i].get_value().size] 82 | p = p.reshape(self.params[i].get_value().shape) 83 | idx += self.params[i].get_value().size 84 | self.params[i].set_value(p) 85 | return 86 | -------------------------------------------------------------------------------- /survivalnet/optimization/Optimization.py: -------------------------------------------------------------------------------- 1 | import theano.tensor as T 2 | 3 | class Optimization(object): 4 | def SGD(self, cost, params, learning_rate): 5 | 6 | # compute the gradients with respect to the model parameters 7 | gparams = T.grad(cost, params) 8 | 9 | updates = [(param, param + gparam * learning_rate) 10 | for param, gparam in zip(params, gparams) 11 | ] 12 | 13 | return updates 14 | 15 | -------------------------------------------------------------------------------- 
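The ascent-style update in Optimization.SGD above (`param + gparam * learning_rate`) only makes sense because RiskLayer.cost returns the Cox partial log likelihood to be maximized, not a loss to be minimized. A minimal NumPy sketch of that objective follows, useful for checking the compiled Theano functions on small inputs; the function name is illustrative and not part of the package, and it assumes patients sorted by increasing time with `at_risk` indices as produced by SurvivalAnalysis.calc_at_risk in the next file.

    import numpy as np

    def cox_partial_log_likelihood(risk, observed, at_risk):
        # Mirrors RiskLayer.cost: subtract the maximum before exponentiating
        # to avoid overflow; the constant is added back after taking the log.
        shifted = risk - risk.max()
        # Reversed cumulative sum: entry i holds sum over j >= i of exp(shifted[j]).
        rev_cumsum = np.cumsum(np.exp(shifted)[::-1])[::-1] + 1  # '+ 1' mirrors the Theano code
        log_at_risk = np.log(rev_cumsum[at_risk]) + risk.max()
        # Only uncensored patients (observed == 1) contribute terms to the sum.
        return np.sum(observed * (risk - log_at_risk))

Reversing, cumulative-summing, and reversing again is what lets one vectorized pass produce every risk-set denominator at once.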
/survivalnet/optimization/SurvivalAnalysis.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | class SurvivalAnalysis(object):
5 |     """ This class contains methods used in survival analysis.
6 |     """
7 | 
8 |     def c_index(self, risk, T, C):
9 |         """Calculates the concordance index to evaluate model predictions.
10 | 
11 |         C-index calculates the fraction of all pairs of subjects whose predicted
12 |         survival times are correctly ordered, among all pairs that can actually
13 |         be ordered: either both are uncensored, or the uncensored time of
14 |         one is smaller than the censored survival time of the other.
15 | 
16 |         Parameters
17 |         ----------
18 |         risk: numpy.ndarray
19 |             m sized array of predicted risk (do not confuse with predicted survival time)
20 |         T: numpy.ndarray
21 |             m sized vector of time of death or last follow up
22 |         C: numpy.ndarray
23 |             m sized vector of censoring status (do not confuse with observed status)
24 | 
25 |         Returns
26 |         -------
27 |         A value between 0 and 1 indicating the concordance index.
28 |         """
29 |         n_orderable = 0.0
30 |         score = 0.0
31 |         for i in range(len(T)):
32 |             for j in range(i + 1, len(T)):
33 |                 if (C[i] == 0 and C[j] == 0):
34 |                     n_orderable = n_orderable + 1
35 |                     if (T[i] > T[j]):
36 |                         if (risk[j] > risk[i]):
37 |                             score = score + 1
38 |                     elif (T[j] > T[i]):
39 |                         if (risk[i] > risk[j]):
40 |                             score = score + 1
41 |                     else:
42 |                         if (risk[i] == risk[j]):
43 |                             score = score + 1
44 |                 elif (C[i] == 1 and C[j] == 0):
45 |                     if (T[i] >= T[j]):
46 |                         n_orderable = n_orderable + 1
47 |                         if (T[i] > T[j]):
48 |                             if (risk[j] > risk[i]):
49 |                                 score = score + 1
50 |                 elif (C[j] == 1 and C[i] == 0):
51 |                     if (T[j] >= T[i]):
52 |                         n_orderable = n_orderable + 1
53 |                         if (T[j] > T[i]):
54 |                             if (risk[i] > risk[j]):
55 |                                 score = score + 1
56 | 
57 |         # Fraction of concordant pairs among all orderable pairs.
58 |         return score / n_orderable
59 | 
60 |     def calc_at_risk(self, X, T, O):
61 |         """Calculates the at-risk group of all patients.
62 | 
63 |         For every patient i, this function returns the index of the first
64 |         patient who died after i, after sorting the patients w.r.t. time of death.
65 |         Refer to the definition of the Cox proportional hazards log likelihood
66 |         for details: https://goo.gl/k4TsEM
67 | 
68 |         Parameters
69 |         ----------
70 |         X: numpy.ndarray
71 |             m*n matrix of input data
72 |         T: numpy.ndarray
73 |             m sized vector of time of death
74 |         O: numpy.ndarray
75 |             m sized vector of observed status (1 - censoring status)
76 | 
77 |         Returns
78 |         -------
79 |         X: numpy.ndarray
80 |             m*n matrix of input data sorted w.r.t. time of death
81 |         T: numpy.ndarray
82 |             m sized sorted vector of time of death
83 |         O: numpy.ndarray
84 |             m sized vector of observed status sorted w.r.t. time of death
85 |         at_risk: numpy.ndarray
86 |             m sized vector of starting indices of risk groups
87 |         """
88 |         tmp = list(T)
89 |         T = np.asarray(tmp).astype('float64')
90 |         order = np.argsort(T)
91 |         sorted_T = T[order]
92 |         at_risk = np.searchsorted(sorted_T, sorted_T).astype('int32')  # first index of each time
93 |         T = np.asarray(sorted_T)
94 |         O = O[order]
95 |         X = X[order]
96 | 
97 |         return X, T, O, at_risk
98 | 
--------------------------------------------------------------------------------
/survivalnet/optimization/__init__.py:
--------------------------------------------------------------------------------
1 | # imported before Bayesian_Optimization
2 | from .BFGS import BFGS
3 | from .EarlyStopping import isOverfitting
4 | from .GDLS import GDLS
5 | from .Optimization import Optimization
6 | from .SurvivalAnalysis import SurvivalAnalysis
7 | 
8 | 
9 | # list functions and classes available for public use
10 | __all__ = (
11 |     'BFGS',
12 |     'isOverfitting',
13 |     'GDLS',
14 |     'Optimization',
15 |     'SurvivalAnalysis',
16 | )
17 | 
--------------------------------------------------------------------------------
/survivalnet/train.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import os
3 | import sys
4 | import theano
5 | import timeit
6 | 
7 | from model import Model
8 | from optimization import BFGS
9 | from optimization import GDLS
10 | from optimization import SurvivalAnalysis
11 | from optimization import isOverfitting
12 | 
13 | LEARNING_RATE_DECAY = 1
14 | 
15 | def train(pretrain_set, train_set, test_set, pretrain_config, finetune_config,
16 |           n_layers, n_hidden, dropout_rate, non_lin,
17 |           optim='GD', lambda1=0, lambda2=0, verbose=True, earlystp=True):
18 | 
19 |     """Creates and trains a feedforward neural network.
20 |     Arguments:
21 |         pretrain_set: dict. Contains pre-training data (nxp array). If None, no
22 |             pre-training is performed.
23 |         train_set: dict. Contains training data (nxp array), labels (nx1 array),
24 |             censoring status (nx1 array) and at risk indices (nx1 array).
25 |         test_set: dict. Contains testing data (nxp array), labels (nx1 array),
26 |             censoring status (nx1 array) and at risk indices (nx1 array).
27 |         pretrain_config: dict. Contains pre-training parameters.
28 |         finetune_config: dict. Contains finetuning parameters.
29 |         n_layers: int. Number of layers in neural network.
30 |         n_hidden: int. Number of hidden units in each layer.
31 |         dropout_rate: float. Probability of dropping units.
32 |         non_lin: theano.Op or function. Type of activation function. Linear if None.
33 |         optim: str. Optimization algorithm to use. One of 'GD', 'GDLS', and 'BFGS'.
34 |         lambda1: float. L1 regularization rate.
35 |         lambda2: float. L2 regularization rate.
36 |         verbose: bool. Whether to log progress to stderr.
37 |         earlystp: bool. Whether to use early stopping.
38 | 
39 |     Outputs:
40 |         train_costs: 1D array. Loss value on training data at each epoch.
41 |         train_cindices: 1D array. C-index values on training data at each epoch.
42 |         test_costs: 1D array. Loss value on testing data at each epoch.
43 |         test_cindices: 1D array. C-index values on testing data at each epoch.
44 |         train_risk: 1D array. Final predicted risks for all patients in training set.
45 |         test_risk: 1D array. Final predicted risks for all patients in test set.
46 |         model: Model. Final trained model.
47 |         max_iter: int. Number of training epochs. Equal to
48 |             finetune_config['ft_epochs'] or smaller if earlystp is True.
49 |     """
50 |     finetune_lr = theano.shared(numpy.asarray(finetune_config['ft_lr'],
51 |                                               dtype=theano.config.floatX))
52 | 
53 |     numpy_rng = numpy.random.RandomState(1111)
54 | 
55 |     # Construct the stacked denoising autoencoder and the corresponding
56 |     # supervised survival network.
57 |     model = Model(
58 |         numpy_rng=numpy_rng,
59 |         n_ins=train_set['X'].shape[1],
60 |         hidden_layers_sizes=[n_hidden] * n_layers,
61 |         n_outs=1,
62 |         dropout_rate=dropout_rate,
63 |         lambda1=lambda1,
64 |         lambda2=lambda2,
65 |         non_lin=non_lin)
66 | 
67 |     #########################
68 |     # PRETRAINING THE MODEL #
69 |     #########################
70 |     if pretrain_config is not None:
71 |         n_batches = train_set['X'].shape[0] / (pretrain_config['pt_batchsize'] or train_set['X'].shape[0])
72 | 
73 |         pretraining_fns = model.pretraining_functions(
74 |             pretrain_set,
75 |             pretrain_config['pt_batchsize'])
76 |         start_time = timeit.default_timer()
77 |         # de-noising level
78 |         corruption_levels = [pretrain_config['corruption_level']] * n_layers
79 |         for i in xrange(model.n_layers):  # Layer-wise pre-training.
80 |             # go through pretraining epochs
81 |             for epoch in xrange(pretrain_config['pt_epochs']):
82 |                 # go through the training set
83 |                 c = []
84 |                 for batch_index in xrange(n_batches):
85 |                     c.append(pretraining_fns[i](index=batch_index,
86 |                                                 corruption=corruption_levels[i],
87 |                                                 lr=pretrain_config['pt_lr']))
88 | 
89 |                 if verbose:
90 |                     print 'Pre-training layer {}, epoch {}, cost {}'.format(i, epoch, numpy.mean(c))
91 | 
92 |         end_time = timeit.default_timer()
93 |         if verbose:
94 |             print('Pretraining took {} minutes.'.format((end_time - start_time) / 60.))
95 | 
96 |     ########################
97 |     # FINETUNING THE MODEL #
98 |     ########################
99 |     test, train = model.build_finetune_functions(learning_rate=finetune_lr)
100 | 
101 |     train_cindices = []
102 |     test_cindices = []
103 |     train_costs = []
104 |     test_costs = []
105 | 
106 |     if optim == 'BFGS':
107 |         bfgs = BFGS(model, train_set['X'], train_set['O'], train_set['A'])
108 |     elif optim == 'GDLS':
109 |         gdls = GDLS(model, train_set['X'], train_set['O'], train_set['A'])
110 |     survivalAnalysis = SurvivalAnalysis()
111 | 
112 |     # Starts the training routine.
113 |     for epoch in range(finetune_config['ft_epochs']):
114 | 
115 |         # Creates masks for dropout during training.
116 |         train_masks = [
117 |             numpy_rng.binomial(n=1, p=1 - dropout_rate,
118 |                                size=(train_set['X'].shape[0], n_hidden))
119 |             for i in range(n_layers)]
120 | 
121 |         # Creates dummy masks for testing.
122 |         test_masks = [
123 |             numpy.ones((test_set['X'].shape[0], n_hidden), dtype='int64')
124 |             for i in range(n_layers)]
125 | 
126 |         # BFGS() and GDLS() update the model parameters themselves, so we only
127 |         # call test() to calculate cost, risk, and c-index on the training set.
128 |         if optim == 'BFGS':
129 |             bfgs.BFGS(train_masks)
130 |             train_cost, train_risk, train_features = test(
131 |                 train_set['X'], train_set['O'], train_set['A'], 1, *train_masks)
132 |         elif optim == 'GDLS':
133 |             gdls.GDLS(train_masks)
134 |             train_cost, train_risk, train_features = test(
135 |                 train_set['X'], train_set['O'], train_set['A'], 1, *train_masks)
136 |         # In case of GD, uses the train function to update the parameters and get
137 |         # training cost, risk, and c-index at the same time.
138 |         elif optim == 'GD':
139 |             train_cost, train_risk, train_features = train(
140 |                 train_set['X'], train_set['O'], train_set['A'], 1, *train_masks)
141 |         train_ci = survivalAnalysis.c_index(train_risk, train_set['T'], 1 - train_set['O'])
142 | 
143 |         # Calculates testing cost, risk, and c-index using the updated model.
144 |         test_cost, test_risk, _ = test(test_set['X'], test_set['O'], test_set['A'], 0, *test_masks)
145 |         test_ci = survivalAnalysis.c_index(test_risk, test_set['T'], 1 - test_set['O'])
146 | 
147 |         train_cindices.append(train_ci)
148 |         test_cindices.append(test_ci)
149 | 
150 |         train_costs.append(train_cost)
151 |         test_costs.append(test_cost)
152 |         if verbose:
153 |             print (('epoch = {}, trn_cost = {}, trn_ci = {}, tst_cost = {},'
154 |                     ' tst_ci = {}').format(epoch, train_cost, train_ci,
155 |                                            test_cost, test_ci))
156 |         if earlystp and epoch >= 15 and (epoch % 5 == 0):
157 |             if verbose:
158 |                 print 'Checking overfitting!'
159 |             check, max_iter = isOverfitting(numpy.asarray(test_cindices))
160 |             if check:
161 |                 print(('Training Stopped Due to Overfitting! cindex = {},'
162 |                        ' MaxIter = {}').format(test_cindices[max_iter], max_iter))
163 |                 break
164 |         else: max_iter = epoch
165 |         sys.stdout.flush()
166 |         decay_learning_rate = theano.function(
167 |             inputs=[], outputs=finetune_lr,
168 |             updates={finetune_lr: finetune_lr * LEARNING_RATE_DECAY})
169 |         decay_learning_rate()
170 |         epoch += 1
171 |         if numpy.isnan(test_cost): break
172 |     if verbose:
173 |         print 'C-index score after {} epochs is: {}'.format(max_iter, max(test_cindices))
174 |     return train_costs, train_cindices, test_costs, test_cindices, train_risk, test_risk, model, max_iter
175 | 
--------------------------------------------------------------------------------
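For reference, a hypothetical end-to-end driver for train() is sketched below. The synthetic data, layer sizes, and learning rate are illustrative placeholders only (examples/Run.py is the repository's actual entry point); the dictionary keys 'X', 'T', 'O', 'A' and the finetune_config keys follow the usage in train.py above, and pre-training is skipped by passing None for pretrain_set and pretrain_config.

    import numpy
    import theano.tensor as T
    from survivalnet.train import train
    from survivalnet.optimization import SurvivalAnalysis

    rng = numpy.random.RandomState(0)
    sa = SurvivalAnalysis()

    def make_set(m, p=50):
        # Synthetic features, follow-up times, and event indicators (1 = death observed).
        X = rng.randn(m, p).astype('float32')
        times = rng.exponential(1., m)
        observed = rng.binomial(1, .7, m).astype('int32')
        # Sort by time and attach at-risk group indices, as the Cox cost expects.
        X, times, observed, at_risk = sa.calc_at_risk(X, times, observed)
        return {'X': X, 'T': times, 'O': observed, 'A': at_risk}

    train_set, test_set = make_set(160), make_set(40)
    finetune_config = {'ft_lr': 1e-3, 'ft_epochs': 40}

    (train_costs, train_cindices, test_costs, test_cindices,
     train_risk, test_risk, model, max_iter) = train(
        None, train_set, test_set, None, finetune_config,
        n_layers=2, n_hidden=100, dropout_rate=.5, non_lin=T.tanh,
        optim='GD', lambda1=1e-4, lambda2=1e-4, verbose=False, earlystp=False)

Note that calc_at_risk is applied to each split separately, since the at-risk indices refer to positions within a single time-sorted array.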