├── .DS_Store ├── README.md ├── mean_and_var_pool_arch.png ├── run_all_fits.py ├── run_viz_top_patches.py ├── scripts ├── .DS_Store ├── agg_val_tune.py ├── baseline_cox.py ├── eval.py ├── redo_save_preds.py ├── run_tune.py ├── train.py └── visualize.py ├── setup.py ├── tcga_scripts ├── .DS_Store ├── aggregate_surv_cv_folds.py ├── download_tcga_clinical_data.py ├── make_discr_suvr_splits.py ├── make_subtype_clf_splits.py ├── make_surv_yaml.py ├── stat_sig_c_index_cutoff.py ├── viz_top_patches-extremes_only.py └── viz_top_patches.py └── var_pool ├── .DS_Store ├── __init__.py ├── file_utils.py ├── gpu_utils.py ├── mhist ├── .DS_Store ├── __init__.py ├── clinical_data_porpoise.py ├── get_model.py ├── get_model_from_args.py ├── get_model_with_switch.py ├── patch_gcn_arch.py ├── tcga_agg_slides_to_patient_level.py └── tcga_clinical_data.py ├── nn ├── .DS_Store ├── ComparablePairSampler.py ├── CoxLoss.py ├── NLLSurvLoss.py ├── SurvRankingLoss.py ├── __init__.py ├── arch │ ├── .DS_Store │ ├── AttnMIL.py │ ├── AttnMIL_utils.py │ ├── GlobalPoolMIL.py │ ├── PatchGCN.py │ ├── SumMIL.py │ ├── VarPool.py │ ├── VarPool_switch.py │ ├── __init__.py │ └── utils.py ├── datasets │ ├── .DS_Store │ ├── BagDatasets.py │ ├── GraphDatasets.py │ ├── __init__.py │ └── fixed_bag_size.py ├── seeds.py ├── stream_evaler.py ├── train │ ├── .DS_Store │ ├── EarlyStopper.py │ ├── GradAccum.py │ ├── __init__.py │ ├── loops.py │ └── tests │ │ ├── .DS_Store │ │ ├── __init__.py │ │ ├── test_GradAccum.py │ │ └── utils_grad_accum.py ├── tune_utils.py └── utils.py ├── processing ├── .DS_Store ├── __init__.py ├── clf_utils.py ├── data_split.py └── discr_surv_utils.py ├── script_utils.py ├── utils.py └── viz ├── .DS_Store ├── __init__.py ├── top_attn.py ├── utils.py └── var_pool_extremes.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Incorporating intratumoral heterogeneity into weakly-supervised deep learning models via variance pooling 2 | 3 | **Carmichael, I.**\*, **Song, A.H.**\*, Chen, R.J., Williamson, D.F.K., Chen, T.Y., Mahmood, F. [Incorporating intratumoral heterogeneity into weakly-supervised deep learning models via variance pooling](https://arxiv.org/pdf/2206.08885.pdf). The International Conference on Medical Image Computing and Computer Assisted Intervention (MICCAI), 2022 4 | 5 | 6 | **Abstract**: Supervised learning tasks such as cancer survival prediction 7 | from gigapixel whole slide images (WSIs) are a critical challenge in computational pathology that requires modeling complex features of the tumor microenvironment. These learning tasks are often solved with deep 8 | multi-instance learning (MIL) models that do not explicitly capture intratumoral heterogeneity. We develop a novel variance pooling architecture that enables a MIL model to incorporate intratumoral heterogeneity 9 | into its predictions. Two interpretability tools based on “representative 10 | patches” are illustrated to probe the biological signals captured by these 11 | models. An empirical study with 4,479 gigapixel WSIs from the Cancer 12 | Genome Atlas shows that adding variance pooling onto MIL frameworks 13 | improves survival prediction performance for five cancer types. 
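In brief, variance pooling projects the (attention-weighted) patch embeddings of a slide onto a small set of learned directions and summarizes the spread along each direction, so the model sees a measure of intratumoral heterogeneity in addition to the usual mean-pooled embedding (see the architecture figure below). The snippet below is a minimal, self-contained sketch of that computation; it is only an illustration and not the repository's `VarPool` module (`var_pool/nn/arch/VarPool.py`). The class name, the uniform default attention weights, and the `log` activation (the default passed by `run_all_fits.py` via `--var_act_func log`) are assumptions made for the example.

```python
import torch
import torch.nn as nn


class VarPoolSketch(nn.Module):
    """Illustrative variance pooling over a bag of patch embeddings.

    Not the repository's VarPool module; just a sketch of the idea.
    """

    def __init__(self, embed_dim, n_var_pools=10, eps=1e-5):
        super().__init__()
        # learnable projection directions, one pooled variance per direction
        self.proj = nn.Linear(embed_dim, n_var_pools, bias=False)
        self.eps = eps

    def forward(self, bag, attn=None):
        # bag: (n_instances, embed_dim) patch embeddings for one slide
        # attn: optional (n_instances,) attention weights summing to 1
        if attn is None:
            attn = torch.full((bag.shape[0],), 1.0 / bag.shape[0],
                              device=bag.device)
        scores = self.proj(bag)                         # (n_instances, n_var_pools)
        mean = (attn[:, None] * scores).sum(dim=0)      # weighted mean per direction
        var = (attn[:, None] * (scores - mean) ** 2).sum(dim=0)
        # a concave activation (log here) tames the scale of the variances
        return torch.log(var + self.eps)


# usage sketch: 500 patches with 1024-dim ResNet50 features -> 10 variance features
feats = torch.randn(500, 1024)
var_feats = VarPoolSketch(embed_dim=1024)(feats)   # shape (10,)
```

In the full architectures the vector of per-direction variances is combined with the attention-pooled mean embedding before the prediction head, as sketched in the figure below.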
14 | 
15 | ![varpool](mean_and_var_pool_arch.png) 
16 | 
17 | # Updates 
18 | Please follow this GitHub for more updates. 
19 | 
20 | # Setup 
21 | 
22 | 
23 | # 1. Downloading TCGA Data 
24 | To download diagnostic WSIs (formatted as .svs files), please refer to the [NIH Genomic Data Commons Data Portal](https://portal.gdc.cancer.gov/). WSIs for each cancer type can be downloaded using the [GDC Data Transfer Tool](https://docs.gdc.cancer.gov/Data_Transfer_Tool/Users_Guide/Data_Download_and_Upload/). 
25 | 
26 | # 2. Processing Whole Slide Images 
27 | To process the WSI data we used the publicly available [CLAM WSI-analysis toolbox](https://github.com/mahmoodlab/CLAM). First, the tissue regions in each biopsy slide are segmented. Non-overlapping 256 x 256 patches are then extracted from the segmented tissue regions at the desired magnification. A pretrained, truncated ResNet50 is subsequently used to encode each raw image patch into a 1024-dimensional feature vector. Using the CLAM toolbox, the features for each WSI are saved as a torch tensor of size N x 1024, where N is the number of patches from that WSI (which varies from slide to slide). Please refer to [CLAM](https://github.com/mahmoodlab/CLAM) for examples of tissue segmentation and feature extraction. 
28 | The extracted features then serve as input (in a .pt file) to the network. The following folder structure is assumed for the extracted feature vectors: 
29 | ```bash 
30 | DATA_ROOT_DIR/ 
31 |     └──TCGA_BLCA/ 
32 |         ├── slide_1.pt 
33 |         ├── slide_2.pt 
34 |         └── ... 
35 |     └──TCGA_BRCA/ 
36 |         ├── slide_1.pt 
37 |         ├── slide_2.pt 
38 |         └── ... 
39 |     ... 
40 | ``` 
41 | `DATA_ROOT_DIR` is the base directory for all datasets / cancer types (e.g. the directory to your SSD). Within `DATA_ROOT_DIR`, each folder contains the .pt files for that dataset / cancer type. 
42 | 
43 | # 3. Run experiments 
44 | 
45 | 1. This step downloads the necessary clinical data file (slide names, clinical endpoints): 
46 | 
47 | ```bash 
48 | python tcga_scripts/download_tcga_clinical_data.py --save_dir OUTPUT_DIR/clinical_data --merge_coadread_gbmlgg 
49 | ``` 
50 | 
51 | 2. The next step is to run the variance pooling experiments. Currently the available options are: 
52 | - **task**: the loss function, either **rank_surv** (ranking loss) or **cox_surv** (Cox proportional hazards loss) 
53 | - **arch_kind**: the MIL architecture, one of **amil** (attention-based MIL), **deepsets** (average pooling MIL), or **patch_gcn** (graph convolutional network MIL) 
54 | ```bash 
55 | python run_all_fits.py --feats_top_dir DATA_ROOT_DIR --output_dir OUTPUT_DIR --task rank_surv --arch_kind amil --cuda 0 
56 | ``` 
57 | 
58 | 
59 | 3. This command helps you visualize the most important patches along the variance projection directions: 
60 | ```bash 61 | python run_viz_top_patches.py 62 | ``` 63 | 64 | # License & Usage 65 | If you find our work useful in your research, please consider citing our paper at: 66 | ``` 67 | @inproceedings{carmichael2022incorporating, 68 | title={Incorporating intratumoral heterogeneity into weakly-supervised deep learning models via variance pooling}, 69 | author={Carmichael, Iain and Song, Andrew H and Chen, Richard J and Williamson, Drew FK and Chen, Tiffany Y and Mahmood, Faisal}, 70 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 71 | pages={387--397}, 72 | year={2022}, 73 | organization={Springer} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /mean_and_var_pool_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/mean_and_var_pool_arch.png -------------------------------------------------------------------------------- /run_all_fits.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from multiprocessing import Pool 4 | import torch 5 | from itertools import product 6 | import yaml 7 | import pandas as pd 8 | 9 | parser = argparse.\ 10 | ArgumentParser(description='Runs a single fit experiment in parallel over a set of subtypes and dataset seeds.') 11 | 12 | parser.add_argument('--feats_top_dir', 13 | type=str, help='Where the features are saved.') 14 | 15 | parser.add_argument('--output_dir', 16 | type=str, help='Where the output should be saved.') 17 | 18 | 19 | parser.add_argument('--task', type=str, help='Which loss task.', 20 | choices=['rank_surv', 'cox_surv', 'discr_surv']) 21 | 22 | parser.add_argument('--arch_kind', type=str, 23 | choices=['amil', 'deepsets', 'amil_gcn', 24 | 'patch_gcn', 'all'], 25 | help='Which architecture kind.') 26 | 27 | parser.add_argument('--cuda', default=0, type=int, 28 | choices=[0, 1, 2], 29 | help='To manually parallelize the process') 30 | 31 | args = parser.parse_args() 32 | 33 | endpoint = 'pfi' 34 | subtypes = ['blca', 'brca', 'coadread', 'gbmlgg', 'ucec'] 35 | dataset_seeds = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 36 | 37 | 38 | ################################ 39 | # Setup high level directories # 40 | ################################ 41 | 42 | n_devices = torch.cuda.device_count() 43 | 44 | if args.arch_kind in ['amil_gcn', 'patch_gcn']: 45 | if args.task == 'discr_surv': 46 | train_params = '--seed 1 --dropout '\ 47 | '--n_epochs 30 --lr 2e-4 --batch_size 1 --fixed_bag_size q75 '\ 48 | '--n_var_pools 10 --var_act_func log --mode graph '\ 49 | '--grad_accum 32 --imbal_how resample' 50 | else: 51 | train_params = '--seed 1 --dropout '\ 52 | '--n_epochs 30 --lr 1e-4 --batch_size 32 --fixed_bag_size q75 '\ 53 | '--n_var_pools 10 --var_act_func log --mode graph '\ 54 | '--grad_accum 1 --imbal_how resample' 55 | else: 56 | train_params = '--seed 1 --dropout '\ 57 | '--n_epochs 30 --lr 2e-4 --batch_size 32 --fixed_bag_size q75 '\ 58 | '--n_var_pools 10 --var_act_func log --imbal_how resample '\ 59 | '--grad_accum 1' 60 | 61 | if args.arch_kind == 'amil': 62 | archs2run = ['amil_nn', 'amil_var_nn'] 63 | elif args.arch_kind == 'deepsets': 64 | archs2run = ['sum_mil', 'sum_var_mil'] 65 | elif args.arch_kind == 'amil_gcn': 66 | archs2run = ['amil_gcn_varpool', 'amil_gcn'] 67 | elif args.arch_kind == 'patch_gcn': 68 | archs2run = ['patchGCN', 
'patchGCN_varpool'] 
69 | elif args.arch_kind == 'all': 
70 |     archs2run = ['amil_nn', 'amil_var_nn', 'sum_mil', 'sum_var_mil'] 
71 | 
72 | 
73 | #################################### 
74 | # Create setup for each experiment # 
75 | #################################### 
76 | train_commands = [] 
77 | for subtype in subtypes: 
78 |     for dataset in dataset_seeds: 
79 | 
80 |         ######################## 
81 |         # Make survival splits # 
82 |         ######################## 
83 | 
84 |         tcga_clincal_fpath = os.path.join(args.output_dir, 
85 |                                           'clinical_data', 
86 |                                           'TCGA-CDR-union-gbmlgg-coadread.xlsx') 
87 | 
88 |         surv_respose_dir = os.path.join(args.output_dir, 'surv_response', 
89 |                                         '{}-{}'.format(subtype, 
90 |                                                        endpoint), 
91 |                                         'dataset_{}'.format(dataset)) 
92 | 
93 |         if args.arch_kind in ['amil_gcn', 'patch_gcn']: 
94 |             feats_dir = os.path.join(args.feats_top_dir, subtype, 'graph') 
95 |         else: 
96 |             feats_dir = os.path.join(args.feats_top_dir, subtype) 
97 | 
98 |         make_splits_kws = {'feats_dir': feats_dir, 
99 |                            'tcga_clincal_fpath': tcga_clincal_fpath, 
100 |                            'surv_respose_dir': surv_respose_dir, 
101 |                            'subtype': subtype, 
102 |                            'endpoint': endpoint, 
103 |                            'dataset': dataset 
104 |                            } 
105 | 
106 |         make_splits_command = 'python tcga_scripts/make_discr_suvr_splits.py '\ 
107 |             '--tcga_clincal_fpath {tcga_clincal_fpath} '\ 
108 |             '--feats_dir {feats_dir} --save_dir {surv_respose_dir} '\ 
109 |             '--subtype {subtype} --endpoint {endpoint} '\ 
110 |             '--prop_train .7 --seed {dataset} --n_bins 4 --no_test_split'.\ 
111 |             format(**make_splits_kws) 
112 | 
113 |         os.system(make_splits_command) 
114 | 
115 |         ####################### 
116 |         # Make task yaml file # 
117 |         ####################### 
118 | 
119 |         y_fpath = os.path.join(surv_respose_dir, 'discr_survival.csv') 
120 | 
121 |         task_fpath = os.path.join(args.output_dir, 'surv_yaml', 
122 |                                   '{}-{}-ds_{}-{}.yaml'. 
123 |                                   format(subtype, endpoint, 
124 |                                          dataset, args.task), 
125 |                                   ) 
126 | 
127 |         train_dir = os.path.join(args.output_dir, 'surv_train_out', 
128 |                                  '{}-{}'.format(subtype, endpoint), 
129 |                                  'dataset_{}'.format(dataset), 
130 |                                  args.task) 
131 | 
132 |         make_yaml_kws = {'feats_dir': feats_dir, 
133 |                          'train_dir': train_dir, 
134 |                          'y_fpath': y_fpath, 
135 |                          'task_fpath': task_fpath, 
136 |                          'task': args.task 
137 |                          } 
138 | 
139 |         make_yaml_command = 'python tcga_scripts/make_surv_yaml.py '\ 
140 |             '--fpath {task_fpath} --task {task} --y_fpath {y_fpath} '\ 
141 |             '--feats_dir {feats_dir} --train_dir {train_dir}'.\ 
142 |             format(**make_yaml_kws) 
143 | 
144 |         os.system(make_yaml_command) 
145 | 
146 |         ############################ 
147 |         # Compute stat sig c-index # 
148 |         ############################ 
149 |         stat_sig_dir = os.path.join(args.output_dir, 'c_index_stat_sig') 
150 | 
151 |         stat_sig_kws = {'response_fpath': y_fpath, 
152 |                         'stat_sig_dir': stat_sig_dir, 
153 |                         'save_stub': '{}-{}-dataset_{}'.
154 | format(subtype, endpoint, dataset)} 155 | 156 | stat_sig_command = 'python tcga_scripts/stat_sig_c_index_cutoff.py '\ 157 | '--response_fpath {response_fpath} '\ 158 | '--save_dir {stat_sig_dir} --save_stub {save_stub}'.\ 159 | format(**stat_sig_kws) 160 | 161 | os.system(stat_sig_command) 162 | 163 | ############################# 164 | # Make run train.py command # 165 | ############################# 166 | for arch in archs2run: 167 | 168 | run_train_kws = {'task_fpath': task_fpath, 169 | 'cuda': args.cuda, 170 | 'name': arch, 171 | 'arch': arch, 172 | 'train_params': train_params} 173 | 174 | run_train_command = 'CUDA_VISIBLE_DEVICES={cuda} '\ 175 | 'python scripts/train.py '\ 176 | '--task_fpath {task_fpath} --name {name} '\ 177 | '--arch {arch} {train_params}'.\ 178 | format(**run_train_kws) 179 | 180 | train_commands.append(run_train_command) 181 | 182 | # pool = Pool(processes=n_devices) 183 | pool = Pool(processes=1) # For manual GPU allocation 184 | pool.starmap(os.system, list(zip(train_commands))) 185 | pool.close() 186 | pool.join() 187 | 188 | ############################# 189 | # Aggregate results for val # 190 | ############################# 191 | save_dir = os.path.join(args.output_dir, 'fit_results') 192 | os.makedirs(save_dir, exist_ok=True) 193 | missing = [] 194 | for subtype in subtypes: 195 | 196 | results = [] 197 | for dataset, arch in product(dataset_seeds, archs2run): 198 | 199 | # load results for one experiment 200 | res_fpath = os.path.join(args.output_dir, 'surv_train_out', 201 | '{}-{}'.format(subtype, endpoint), 202 | 'dataset_{}'.format(dataset), 203 | args.task, arch, 204 | 'results.yaml') 205 | 206 | if os.path.exists(res_fpath): 207 | with open(res_fpath) as file: 208 | res = yaml.safe_load(file) 209 | res['dataset'] = dataset 210 | res['arch'] = arch 211 | results.append(res) 212 | 213 | else: 214 | missing.append({'subtype': subtype, 'dataset': dataset, 215 | 'arch': arch}) 216 | 217 | # Save results for this subtype 218 | results_fpath = os.path.join(save_dir, 219 | 'results-{}-{}-{}_{}_val.csv'. 
220 | format(subtype, endpoint, args.task, 221 | args.arch_kind)) 222 | results = pd.DataFrame(results) 223 | results.to_csv(results_fpath, index=False) 224 | 225 | missing_fpath = os.path.join(save_dir, 'missing_results_val.csv') 226 | missing = pd.DataFrame(missing) 227 | missing.to_csv(missing_fpath, index=False) 228 | -------------------------------------------------------------------------------- /run_viz_top_patches.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import pandas as pd 5 | import matplotlib.pyplot as plt 6 | 7 | from var_pool.utils import format_command_line_args 8 | from var_pool.file_utils import join_and_make 9 | from var_pool.viz.utils import save_fig 10 | 11 | 12 | parser = argparse.\ 13 | ArgumentParser(description='Runs a single fit experiment in parallel over a set of subtypes and dataset seeds.') 14 | 15 | parser.add_argument('--feats_top_dir', 16 | type=str, help='Where the features are saved.') 17 | 18 | parser.add_argument('--output_dir', 19 | type=str, help='Where the output should be saved.') 20 | args = parser.parse_args() 21 | 22 | device = None 23 | 24 | endpoint = 'pfi' 25 | task = 'rank_surv' 26 | 27 | base_arch = 'amil_nn' 28 | var_arch = 'amil_var_nn' 29 | 30 | n_patients = 10 31 | 32 | ############### 33 | # Setup paths # 34 | ############### 35 | 36 | for subtype, dataset in zip(['brca', 'blca', 'coadread', 'gbmlgg', 'ucec'], 37 | [10, 20, 10, 10, 80]): 38 | # these are the seeds with the best validation errors 39 | # lets use these for visualization 40 | 41 | wsi_dir = os.path.join(args.top_data_dir, 'wsi/tcga', subtype) 42 | 43 | feat_h5_dir = os.path.join(args.top_data_dir, 'mil-h5_files', subtype, 44 | 'resnet50_trunc_h5_patch_features') 45 | 46 | results_dir = os.path.join(args.output_dir, 'surv_train_out', 47 | '{}-{}'.format(subtype, endpoint), 48 | 'dataset_{}'.format(dataset), 49 | task) 50 | 51 | autogen_fpath = os.path.join(args.output_dir, 'autogen', 52 | 'process_list_autogen-{}.csv'.format(subtype)) 53 | 54 | # Paths for saving 55 | top_save_dir = join_and_make(args.output_dir, 'viz', 56 | '{}-{}-ds_{}'.format(subtype, endpoint, dataset)) 57 | 58 | ############################################## 59 | # Make visualizations for base and var archs # 60 | ############################################## 61 | 62 | for i in range(2): 63 | if i == 0: 64 | arch = var_arch 65 | else: 66 | arch = base_arch 67 | 68 | ####################### 69 | # Paths for this arch # 70 | ####################### 71 | 72 | save_dir = join_and_make(top_save_dir, arch) 73 | 74 | checkpoint_fpath = os.path.join(results_dir, arch, 'checkpoints', 75 | 's_checkpoint.pt') 76 | 77 | train_preds_fpath = os.path.join(results_dir, arch, 'train_preds.npz') 78 | val_preds_fpath = os.path.join(results_dir, arch, 'val_preds.npz') 79 | 80 | y_fpath = os.path.join(save_dir, 'response.csv') 81 | y_fig_fpath = os.path.join(save_dir, 'risk_preds.png') 82 | 83 | #################################### 84 | # Get highest/lowest risk patients # 85 | #################################### 86 | 87 | train_preds = np.load(train_preds_fpath) 88 | val_preds = np.load(val_preds_fpath) 89 | z = np.concatenate([train_preds['z'], val_preds['z']]) 90 | y_true = np.vstack([train_preds['y_true'], val_preds['y_true']]) 91 | sample_ids = np.concatenate([train_preds['sample_ids'], 92 | val_preds['sample_ids']]) 93 | 94 | pred_risk = pd.Series(z, index=sample_ids, name='pred_risk') 95 | pred_risk = 
pred_risk.sort_values(ascending=False) # highest risk first 96 | 97 | # use the var arch's predictions for determine 98 | # the higest/lowest risk patients 99 | if i == 0: 100 | high_risk = pred_risk.index.values[0:n_patients] 101 | low_risk = pred_risk.index.values[-n_patients:] 102 | 103 | ############################# 104 | # Save survival predictions # 105 | ############################# 106 | 107 | y_df = pd.DataFrame(y_true, index=sample_ids, 108 | columns=['censor', 'survival_time']) 109 | y_df['censor'] = y_df['censor'].astype(bool) 110 | y_df = y_df.loc[pred_risk.index] 111 | y_df['pred_risk'] = pred_risk 112 | y_df.to_csv(y_fpath) 113 | 114 | # Plot predictions 115 | plt.figure(figsize=(8, 8)) 116 | plt.scatter(y_df.query("censor")['pred_risk'], 117 | y_df.query("censor")['survival_time'], 118 | marker='o', 119 | color='lightcoral', 120 | label='censored') 121 | 122 | plt.scatter(y_df.query("not censor")['pred_risk'], 123 | y_df.query("not censor")['survival_time'], 124 | marker='x', 125 | color='red', 126 | label='observed') 127 | plt.legend() 128 | plt.xlabel("Predicted risk") 129 | plt.ylabel("Survival time") 130 | save_fig(y_fig_fpath) 131 | 132 | ########################## 133 | # Format command and run # 134 | ########################## 135 | 136 | command_args = {'autogen_fpath': autogen_fpath, 137 | 'checkpoint_fpath': checkpoint_fpath, 138 | 'wsi_dir': wsi_dir, 139 | 'feat_h5_dir': feat_h5_dir, 140 | 'high_risk': high_risk, 141 | 'low_risk': low_risk, 142 | 'save_dir': save_dir} 143 | 144 | model_args = {'arch': arch, 145 | 'n_var_pools': 10, 146 | 'var_act_func': 'log', 147 | } 148 | 149 | model_flags = ['dropout'] 150 | 151 | command = 'python tcga_scripts/viz_top_patches.py' 152 | arg_str = format_command_line_args(kws={**command_args, **model_args}, 153 | flags=model_flags) 154 | 155 | command += ' ' + arg_str 156 | # print(command) 157 | 158 | if device is not None: 159 | command = 'CUDA_VISIBLE_DEVICES={} '.format(device) + command 160 | 161 | os.system(command) 162 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/agg_val_tune.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | import yaml 4 | import argparse 5 | import pandas as pd 6 | 7 | 8 | parser = argparse.\ 9 | ArgumentParser(description='Aggregates validation set tuning results.') 10 | 11 | parser.add_argument('--tune_dir', type=str, 12 | help='Directory containing tune folders named tune_1, tune_2, ...') 13 | 14 | parser.add_argument('--metric', type=str, 15 | help='Name of the metric to use to pick the best model.') 16 | 17 | parser.add_argument('--small_good', 18 | action='store_true', default=False, 19 | help='By default we want to maximize the tuning metric; this flag says we want to minimize the metric.') 20 | 21 | parser.add_argument('--save_dir', type=str, 22 | help='Directory where to save the aggregated tuning results.') 23 | 24 | args = parser.parse_args() 25 | 26 | # setup fpaths 27 | os.makedirs(args.save_dir, exist_ok=True) 28 | tune_results_fpath = os.path.join(args.save_dir, 'tune_results.csv') 29 | missing_results_fpath = os.path.join(args.save_dir, 'missing_results.csv') 30 | 31 | 
summary_fpath = os.path.join(args.save_dir, 'tune_summary.txt') 32 | 33 | ####################### 34 | # Load tuning results # 35 | ####################### 36 | 37 | # get all folders named tune_IDX 38 | tune_folders = glob(os.path.join(args.tune_dir, 'tune_*/')) 39 | 40 | if len(tune_folders) == 0: 41 | raise RuntimeError("No tuning folders found") 42 | else: 43 | print("Found {} tuning folders".format(len(tune_folders))) 44 | 45 | tune_results = [] 46 | tune_params = {} 47 | missing_results = [] 48 | for folder_path in tune_folders: 49 | # path/tune_IDX/ -> tune_IDX 50 | folder_name = folder_path.split('/')[-2] 51 | tune_idx = int(folder_name.split('_')[1]) 52 | 53 | # load tune param configs 54 | tune_params_fpath = os.path.join(folder_path, 'tune_params.yaml') 55 | with open(tune_params_fpath, 'r') as file: 56 | this_params = yaml.safe_load(file) 57 | this_params['tune_idx'] = tune_idx # add tune index 58 | 59 | # try loading the results 60 | results_fpath = os.path.join(folder_path, 'results.yaml') 61 | if os.path.exists(results_fpath): 62 | 63 | # load results and tune params 64 | with open(results_fpath, 'r') as file: 65 | this_res = yaml.safe_load(file) 66 | 67 | this_res.update(this_params) # add tuning params 68 | 69 | tune_results.append(this_res) 70 | tune_params[tune_idx] = this_params 71 | 72 | else: 73 | # missing results 74 | missing_results.append(this_params) 75 | 76 | 77 | # format to pandas and save to dis 78 | tune_results = pd.DataFrame(tune_results).\ 79 | set_index('tune_idx').sort_index() 80 | tune_results.to_csv(tune_results_fpath) 81 | 82 | if len(missing_results) > 0: 83 | missing_results = pd.DataFrame(missing_results).\ 84 | set_index('tune_idx').sort_index() 85 | missing_results.to_csv(missing_results_fpath) 86 | 87 | #################### 88 | # Pick best metric # 89 | #################### 90 | 91 | # pull out the metric we want to use to select the tuning parameter 92 | scores = tune_results[args.metric].copy() 93 | if args.small_good: 94 | best_tune_idx = scores.idxmin() 95 | else: 96 | best_tune_idx = scores.idxmax() 97 | 98 | best_tune_params = tune_params[best_tune_idx] 99 | best_score = scores.loc[best_tune_idx] 100 | 101 | scores_agg = scores.agg(['mean', 'std', 'min', 'max', 'median']) 102 | 103 | 104 | # sort 105 | scores.sort_values(ascending=False, inplace=True) 106 | tune_results.sort_values(by=args.metric, ascending=False, inplace=True) 107 | 108 | ###################### 109 | # print/write output # 110 | ###################### 111 | 112 | output = '' 113 | output += "Best tune {} score: {}\n".format(args.metric, best_score) 114 | output += "Best tune idx: {}\n".format(best_tune_idx) 115 | output += "Best tune params\n" 116 | output += "{}\n\n".format(best_tune_params) 117 | output += '{} agg\n'.format(args.metric) 118 | output += '{}\n\n'.format(scores_agg) 119 | output += 'all {} values\n'.format(args.metric) 120 | output += '{}\n\n'.format(scores) 121 | output += '{}\n\n'.format(tune_results.to_string(float_format='%1.6f')) 122 | output += 'Missing: \n{}'.format(missing_results) 123 | 124 | print(output) 125 | with open(summary_fpath, 'w') as f: 126 | f.write(output) 127 | -------------------------------------------------------------------------------- /scripts/baseline_cox.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple idea from Richard. 
Extract mean embeddings from WSI and train a linear Cox model. 
3 | """ 
4 | import pandas as pd 
5 | import os 
6 | import argparse 
7 | from time import time 
8 | import numpy as np 
9 | 
10 | import torch 
11 | 
12 | from sksurv.linear_model import CoxPHSurvivalAnalysis 
13 | from sksurv.util import Surv 
14 | 
15 | from var_pool.nn.datasets.BagDatasets import BagDataset 
16 | from var_pool.script_utils import parse_mil_task_yaml 
17 | from var_pool.file_utils import find_fpaths, join_and_make 
18 | from var_pool.processing.discr_surv_utils import dict_split_discr_surv_df 
19 | 
20 | parser = argparse.\ 
21 |     ArgumentParser(description='Trains a linear Cox PH baseline on mean-pooled WSI embeddings.') 
22 | 
23 | parser.add_argument('--task_fpath', type=str, 
24 |                     # default to task.yaml file saved in current directory 
25 |                     default='yaml/luad_vs_lusc.yaml', 
26 |                     help='Filepath to .yaml file containing the information ' 
27 |                     'for this task. It should include entries for: \n' 
28 |                     'feats_dir: directory containing WSI bag features as .h5 files.\n' 
29 |                     'y_fpath: csv file containing the response labels for each bag and the train/val/test splits. See code for how the csv file should be formatted.\n' 
30 |                     'task: a string indicating which kind of task we are solving. Should be one of "clf", "surv_cov" or "surv_discr"\n' 
31 |                     'train_dir: directory where training results are saved, e.g. model checkpoints and logging information.') 
32 | 
33 | parser.add_argument('--alpha', type=float, default=10, 
34 |                     help="Ridge penalty hyperparameter for the Cox PH model.") 
35 | 
36 | args = parser.parse_args() 
37 | 
38 | start_time = time() 
39 | 
40 | # load task info from yaml file 
41 | feat_dir, y_fpath, train_dir, task = parse_mil_task_yaml(fpath=args.task_fpath) 
42 | 
43 | # TODO: get this working below 
44 | # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
45 | device = torch.device("cpu") 
46 | # print("Training with device {} ({} cuda devices available)". 
47 | #       format(device, torch.cuda.device_count())) 
48 | 
49 | ################################## 
50 | # Setup paths for saving results # 
51 | ################################## 
52 | train_dir = os.path.join(train_dir, 'baseline_cox') 
53 | 
54 | # training logs, model checkpoints, final results file 
55 | log_dir = join_and_make(train_dir, 'log') 
56 | checkpoint_dir = join_and_make(train_dir, 'checkpoints') 
57 | summary_fpath = os.path.join(train_dir, 'summary.txt') 
58 | results_fpath = os.path.join(train_dir, 'results.yaml') 
59 | train_preds_fpath = os.path.join(train_dir, 'train_preds') 
60 | val_preds_fpath = os.path.join(train_dir, 'val_preds') 
61 | 
62 | 
63 | ####################################################### 
64 | # load response data along with train/val/test splits # 
65 | ####################################################### 
66 | 
67 | # make sure this is formatted correctly e.g.
see dict_split functions below 68 | y_df = pd.read_csv(y_fpath) 69 | y_split = dict_split_discr_surv_df(y_df, 70 | time_bin_col='time_bin', 71 | censor_col='censorship', 72 | split_col='split', 73 | time_col='survival_time', 74 | index_col='sample_id') 75 | 76 | y_train = y_split['train'] 77 | y_val = y_split['val'] 78 | n_time_bins = len(set(y_train['time_bin'].values)) 79 | 80 | 81 | # samples in train/test sets 82 | train_samples = y_train.index.values 83 | val_samples = y_val.index.values 84 | 85 | ###################### 86 | # setup for datasets # 87 | ###################### 88 | 89 | # file paths containing features features 90 | train_fpaths = find_fpaths(folders=feat_dir, ext=['h5', 'pt'], 91 | names=train_samples) 92 | val_fpaths = find_fpaths(folders=feat_dir, ext=['h5', 'pt'], 93 | names=val_samples) 94 | 95 | if len(train_fpaths) == 0: 96 | raise RuntimeError("No training files found in {}".format(feat_dir)) 97 | if len(val_fpaths) == 0: 98 | # TODO: maybe warn? 99 | raise RuntimeError("No validation files found in {}".format(feat_dir)) 100 | 101 | dataset_kws = {} 102 | loader_kws = {'num_workers': 4} if device.type == "cuda" else {} 103 | 104 | # the datasets will only use fpaths with a corresponding index in y 105 | dataset_train = BagDataset(fpaths=train_fpaths, y=y_train, task=task, 106 | **dataset_kws) 107 | dataset_val = BagDataset(fpaths=val_fpaths, y=y_val, task=task, 108 | **dataset_kws) 109 | 110 | ####################### 111 | # Train linear Cox PH # 112 | ####################### 113 | print("======================") 114 | print("Training linear Cox PH") 115 | mean_embeddings_train = [] 116 | out_train = [] 117 | for idx in range(len(dataset_train)): 118 | bag, y = dataset_train[idx] 119 | bag_embedding = torch.mean(bag, dim=0).reshape(1, -1) 120 | mean_embeddings_train.append(bag_embedding) 121 | if task == 'discr_surv': 122 | out_train.append((y[1], y[2])) # (censor status, time of event) 123 | else: 124 | out_train.append((y[0], y[1])) 125 | 126 | mean_embeddings_train = torch.cat(mean_embeddings_train, dim=0) 127 | out_train = np.array(out_train) 128 | out_train = Surv.from_arrays(event=out_train[:, 0], time=out_train[:, 1]) 129 | 130 | # Feed meean_embeddings into training cox model 131 | ph = CoxPHSurvivalAnalysis(alpha=args.alpha) 132 | ph.fit(mean_embeddings_train, out_train) 133 | 134 | train_c_index = ph.score(mean_embeddings_train, out_train) 135 | print("Train c-index {}".format(train_c_index)) 136 | 137 | print("======================") 138 | print("Validating Linear Cox PH") 139 | # Check with val data 140 | mean_embeddings_val = [] 141 | out_val = [] 142 | for idx in range(len(dataset_val)): 143 | bag, y = dataset_val[idx] 144 | bag_embedding = torch.mean(bag, dim=0).reshape(1, -1) 145 | mean_embeddings_val.append(bag_embedding) 146 | if task == 'discr_surv': 147 | out_val.append((y[1], y[2])) # (censor status, time of event) 148 | else: 149 | out_val.append((y[0], y[1])) 150 | 151 | mean_embeddings_val = torch.cat(mean_embeddings_val, dim=0) 152 | out_val = np.array(out_val) 153 | out_val = Surv.from_arrays(event=out_val[:, 0], time=out_val[:, 1]) 154 | 155 | val_c_index = ph.score(mean_embeddings_val, out_val) 156 | print("Val c-index {}".format(val_c_index)) 157 | -------------------------------------------------------------------------------- /scripts/run_tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import os 4 | from functools import partial 5 | from 
multiprocessing import Pool 6 | from sklearn.model_selection import ParameterGrid 7 | import numpy as np 8 | import torch 9 | 10 | from var_pool.nn.tune_utils import run_train 11 | from var_pool.script_utils import parse_mil_task_yaml 12 | 13 | parser = argparse.\ 14 | ArgumentParser(description='Runs a grid serach over tuning parameters using a validation set.') 15 | 16 | parser.add_argument('--task_fpath', type=str, 17 | default='yaml/task.yaml', 18 | help='Filepath to .yaml file containing the information ' 19 | 'for this task.') 20 | 21 | 22 | parser.add_argument('--skip_if_exists', 23 | action='store_true', default=False, 24 | help='Skip running a tuning parameter setting if the results file alreay exists. Useful if tuning crashed halfway through.') 25 | 26 | 27 | # architecture 28 | parser.add_argument('--arch', default='amil_slim', type=str, 29 | choices=['amil_slim', 'amil_nn', 'amil_var_nn', 'sum_mil', 'sum_var_mil', 'patchGCN', 'patchGCN_varpool'], 30 | help="Which neural network architecture to use.") 31 | 32 | parser.add_argument('--attn_latent_dim', default=256, type=int, nargs="+", 33 | help='Dimension of the attention latent space.') 34 | 35 | parser.add_argument('--mode', default='patch', type=str, 36 | choices=['patch', 'graph'], 37 | help='Framework paradigm. patch for independent path assumption. graph for contextual-aware algorithms') 38 | 39 | # For everything but amil slim 40 | parser.add_argument('--head_n_hidden_layers', default=1, type=int, nargs="+", 41 | help='Number of hidden layers in the head network (excluding the final output layer).') 42 | 43 | 44 | parser.add_argument('--head_hidden_dim', default=256, type=int, nargs="+", 45 | help='Dimension of the head hidden layers.') 46 | 47 | # optimization 48 | parser.add_argument('--n_epochs', default=100, type=int, 49 | help='Number of training epochs. Note this defaults to 100.') 50 | 51 | parser.add_argument('--grad_accum', default=1, type=int, 52 | help='Number of gradient accumulation steps.') 53 | 54 | 55 | parser.add_argument('--no_early_stopping', 56 | action='store_true', default=False, 57 | help='Do not use early stopping.') 58 | 59 | parser.add_argument('--es_monitor', default='loss', type=str, 60 | choices=['loss', 'metric'], 61 | help='Should early stopping monitor the validation loss or another metric (e.g. c-index, auc).') 62 | 63 | parser.add_argument('--stop_patience', default=20, type=int, 64 | help='Number of patience steps for early stopping.') 65 | 66 | parser.add_argument('--plateau_patience', default=5, type=int, 67 | help='Number of patience steps for ReduceLROnPlateau learning rate scheduler.') 68 | 69 | parser.add_argument('--plateau_factor', default=np.sqrt(.1), type=float, 70 | help='The factor argument for ReduceLROnPlateau.') 71 | 72 | 73 | parser.add_argument('--seed', type=int, default=1, nargs="+", 74 | help='The random seed for training e.g. used to initialized network weights, order of shuffle, etc. Providing multiiple seeds lets you do multiple network initializations.') 75 | 76 | 77 | # Tuning params 78 | parser.add_argument('--batch_size', default=1, type=int, nargs="+", 79 | help='Batch size for training. If this is > 1 then you must set fixed_bag_size.') 80 | 81 | 82 | parser.add_argument('--fixed_bag_size', default=None, nargs="+", 83 | help='Fix the number of instances in each bag for training. E.g. we randomly sample a subset of instances from each bag. This can be used to speed up the training loop. To use a batch size larger than one you must set this value. 
By passing in fixed_bag_size=max this will automatically set fixed_bag_size to be the largest bag size of the training set.') 84 | 85 | 86 | parser.add_argument('--n_var_pools', default=10, 87 | type=int, nargs="+", 88 | help='Number of projection vectors for variance pooling.') 89 | 90 | parser.add_argument('--var_act_func', 91 | default='log', 92 | type=str, nargs="+", 93 | help='Choice of activation functions for var pooling.') 94 | 95 | parser.add_argument('--lr', default=2e-4, 96 | type=float, nargs="+", 97 | help='The learning rates to try.') 98 | 99 | parser.add_argument('--separate_attn', default=0, 100 | type=int, nargs="+", choices=[0, 1], 101 | help='Separate attn flag. 0 means no separate attn, 1 means yes separate attn. If 0, 1 are provided then this will be tuned over') 102 | 103 | parser.add_argument('--final_nonlin', default='relu', 104 | type=str, nargs='+', 105 | help='Final nonlinearity') 106 | 107 | # For parallelization 108 | parser.add_argument('--num_workers', default=3, 109 | type=int, 110 | help='Number of parallel workers for CV') 111 | 112 | args = parser.parse_args() 113 | 114 | ############################################### 115 | # Setup parameter setting to pass to train.py # 116 | ############################################### 117 | 118 | # fixed parameters that are not tuned over 119 | fixed_params = {'task_fpath': args.task_fpath, 120 | 121 | 'arch': args.arch, 122 | 123 | 'imbal_how': 'resample', 124 | 125 | 'mode': args.mode, 126 | 127 | 'n_epochs': args.n_epochs, 128 | 'grad_accum': args.grad_accum, 129 | 130 | 'lr_scheduler': 'plateau', 131 | 'plateau_factor': args.plateau_factor, 132 | 'plateau_patience': args.plateau_patience, 133 | 'stop_patience': args.stop_patience, 134 | 'es_monitor': args.es_monitor 135 | } 136 | 137 | 138 | fixed_flags = ['dropout'] 139 | if not args.no_early_stopping: 140 | fixed_flags.append('early_stopping') 141 | 142 | # setup parameters to be tuned over 143 | tunable_params = ['lr', 'attn_latent_dim', 144 | 'fixed_bag_size', 'batch_size', 145 | 'seed' # multiple inits, not actually a "tune param"! 
146 | ] # every arch gets these 147 | tunable_flags = [] # every arch gets these 148 | if args.arch == 'amil_slim': 149 | pass 150 | 151 | elif args.arch == 'amil_nn': 152 | tunable_params.extend(['head_n_hidden_layers', 'head_hidden_dim']) 153 | 154 | elif args.arch == 'patchGCN': 155 | tunable_params.extend(['head_n_hidden_layers', 'head_hidden_dim']) 156 | 157 | elif args.arch == 'amil_var_nn': 158 | tunable_params.extend(['head_n_hidden_layers', 'head_hidden_dim', 159 | 'var_act_func', 'n_var_pools']) 160 | 161 | tunable_flags.append('separate_attn') 162 | 163 | elif args.arch == 'sum_var_mil': 164 | tunable_params.extend(['head_n_hidden_layers', 'head_hidden_dim', 165 | 'var_act_func', 'n_var_pools']) 166 | 167 | tunable_flags.append('separate_attn') 168 | 169 | else: 170 | pass 171 | 172 | ################################ 173 | # Process tunable params/flags # 174 | ################################ 175 | 176 | # pull out tuning parameter grids from argparse 177 | param_grid = {} 178 | for param_name in tunable_params: 179 | # pull out param settings from args 180 | param_settings = args.__dict__[param_name] 181 | 182 | # if this is None then don't add it 183 | if param_settings is None: 184 | continue 185 | 186 | if isinstance(param_settings, list): 187 | # if multiple parameters were provided we tune over them 188 | param_grid[param_name] = param_settings 189 | else: 190 | # otherwise these are just fixed parameters 191 | fixed_params[param_name] = param_settings 192 | 193 | # pull out tunable flags from argparse 194 | # format them too bool 195 | for flag_name in tunable_flags: 196 | # pull out param settings from args 197 | flag_settings = args.__dict__[flag_name] 198 | 199 | if isinstance(flag_settings, list): 200 | # if multiple parameters were provided we tune over them 201 | param_grid[flag_name] = [bool(int(f)) for f in flag_settings] 202 | else: 203 | # otherwise these are just fixed parameters 204 | flag_on = bool(int(flag_settings)) 205 | if flag_on: 206 | fixed_flags.append(flag_name) 207 | 208 | 209 | ############################# 210 | # Run setup tuning proceses # 211 | ############################# 212 | 213 | # count availabe cuda devices 214 | n_devices = torch.cuda.device_count() 215 | 216 | # load task info from yaml file 217 | feat_dir, y_fpath, train_dir, task = parse_mil_task_yaml(fpath=args.task_fpath) 218 | 219 | # get path to train.py -- should be in same directory as this script 220 | script_dir = pathlib.Path(__file__).parent.resolve() 221 | train_script_fpath = os.path.join(script_dir, 'train.py') 222 | 223 | 224 | pool = Pool(processes=args.num_workers) 225 | 226 | # set the fixed arguments for run_train 227 | # this is callable(tune_params, tune_idx, device) 228 | run_func = partial(run_train, 229 | script_fpath=train_script_fpath, 230 | train_dir=train_dir, 231 | fixed_params=fixed_params, 232 | fixed_flags=fixed_flags, 233 | skip_if_results_exist=args.skip_if_exists) 234 | 235 | # list of dicts of tune param settings 236 | tune_seq = list(ParameterGrid(param_grid)) 237 | n_tune = len(tune_seq) 238 | tune_idxs = range(n_tune) 239 | 240 | print("Tuning over {} settings".format(n_tune)) 241 | 242 | # which device each task goes to 243 | device_list = [idx % n_devices for idx in range(n_tune)] 244 | 245 | # each entry is (tune_params, tune_idx, device) 246 | run_args = list(zip(tune_seq, tune_idxs, device_list)) 247 | 248 | ############## 249 | # Run tuning # 250 | ############## 251 | 252 | pool.starmap(run_func, run_args) 253 | 254 | pool.close() 255 | 
pool.join() 256 | -------------------------------------------------------------------------------- /scripts/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | 6 | from var_pool.script_utils import parse_mil_task_yaml 7 | from var_pool.mhist.get_model_from_args import get_model 8 | from var_pool.gpu_utils import assign_free_gpus 9 | from var_pool.file_utils import join_and_make 10 | from var_pool.nn.datasets.VisualDatasets import VisualDataset 11 | 12 | from var_pool.visualization.vis_utils import get_top_patches 13 | 14 | parser = argparse.\ 15 | ArgumentParser(description='Creates several visulaizations') 16 | 17 | parser.add_argument('--task_fpath', type=str, 18 | # default to task.yaml file saved in current directory 19 | default='yaml/luad_vs_lusc.yaml', 20 | help='Filepath to .yaml file containing the information ' 21 | 'for this task. It should include entries for: \n' 22 | 'feats_dir = directory containing WSI bag features as .h5 files.\n' 23 | 'y_fpath: csv file containing the response labels for each bag and the train/val/test spilts. See code for how the csv file should be formatted.\n' 24 | 'task: a string indicated which kind of task we are solving. Should be one of "clf", "surv_cov" or "surv_discr"\n' 25 | 'train_dir: directory to where training results are saved e.g. model checkpoints and logging information.') 26 | 27 | parser.add_argument('--save_dir', type=str, default=None, 28 | help='(Optional) Directory where to save the results. If not provided, will use the original train directory from the yaml file.') 29 | 30 | parser.add_argument('--name', type=str, default=None, 31 | help='(Optional) Name of this experiment; the results will be saved in train_dir/name/. 
If name=time, will name the directory after the current date ane time.') 32 | 33 | parser.add_argument('--arch', type=str, default='amil_slim', 34 | choices=['amil_slim', 'amil_nn', 'amil_var_nn', 'sum_mil', 'sum_var_mil', 'patchGCN', 'patchGCN_varpool'], 35 | help="Which neural network architecture to use.\n" 36 | "'amil_slim' just does attention mean pooling with a final linear layer.\n" 37 | "'amil_nn' does attention mean pooling with an additional neural network layers applied to the instance embeddings and the mean pooled output.") 38 | 39 | 40 | # architecture 41 | parser.add_argument('--final_nonlin', default='relu', type=str, 42 | choices=['relu', 'tanh', 'identity'], 43 | help='Choice of final nonlinearity for architecture.') 44 | 45 | parser.add_argument('--attn_latent_dim', default=256, type=int, 46 | help='Dimension of the attention latent space.') 47 | 48 | parser.add_argument('--freeze_var_epochs', 49 | default=None, type=int, 50 | help='Freeze the variance pooling weights for an initial number of epochs.') 51 | 52 | # For everything but amil slim 53 | parser.add_argument('--head_n_hidden_layers', default=1, type=int, 54 | help='Number of hidden layers in the head network (excluding the final output layer).') 55 | 56 | 57 | parser.add_argument('--head_hidden_dim', default=256, type=int, 58 | help='Dimension of the head hidden layers.') 59 | 60 | parser.add_argument('--dropout', 61 | action='store_true', default=False, 62 | help='Use dropout (p=0.25 by default).') 63 | 64 | 65 | # For var pool 66 | parser.add_argument('--n_var_pools', default=10, type=int, 67 | help='Number of projection vectors for variance pooling.') 68 | 69 | parser.add_argument('--var_act_func', default='sigmoid', type=str, 70 | choices=['sqrt', 'log', 'sigmoid', 'identity'], 71 | help='Activation function for var pooling.') 72 | 73 | parser.add_argument('--separate_attn', 74 | action='store_true', default=False, 75 | help='Use separate attention branches for the mean and variance pools.') 76 | 77 | # For rank loss 78 | parser.add_argument('--rank_loss_phi', default='sigmoid', type=str, 79 | choices=['sigmoid', 'relu'], 80 | help='The phi function for rank loss.') 81 | 82 | # For visualization 83 | parser.add_argument('--n_proj', default=5, type=int, 84 | help='Number of projections to visualize') 85 | 86 | parser.add_argument('--n_patches', default=5, type=int, 87 | help='Number of patches to visualize') 88 | 89 | parser.add_argument('--slide_num', default=0, type=int, 90 | help='Slide number') 91 | 92 | parser.add_argument('--image_dir', type=str, default=None, 93 | help="Where the svs images are saved") 94 | 95 | args = parser.parse_args() 96 | 97 | # load task info from yaml file 98 | feat_dir, y_fpath, train_dir, task = parse_mil_task_yaml(fpath=args.task_fpath) 99 | 100 | ######################## 101 | # Identify/Assign gpus # 102 | ######################## 103 | device = assign_free_gpus() 104 | 105 | # Create folders/paths 106 | # where to load the trained results 107 | train_dir = os.path.join(train_dir, args.name) \ 108 | if args.name is not None else train_dir 109 | checkpoint_dir = join_and_make(train_dir, 'checkpoints') 110 | checkpoint_fpath = os.path.join(checkpoint_dir, 's_checkpoint.pt') 111 | 112 | save_dir = args.save_dir if args.save_dir is not None else train_dir 113 | # Directory to images 114 | assert args.image_dir is not None, "You need to supply .svs images folder" 115 | image_dir = args.image_dir 116 | 117 | ####################### 118 | # Assume single image # 119 | 
####################### 120 | dataset_kws = {} 121 | loader_kws = {'num_workers': 4} if device.type == "cuda" else {} 122 | 123 | dataset = VisualDataset(feat_dir, image_dir) 124 | n_bag_feats = dataset.get_feat_dim() 125 | slide_coords, slide_feats = dataset[args.slide_num] 126 | slide_feats = slide_feats.unsqueeze(0) # (1 x n_instances x feat_dim) 127 | 128 | print("Loading the model...") 129 | if task == 'discr_surv': 130 | out_dim = 4 131 | elif task == 'cox_surv': 132 | out_dim = 1 133 | elif task == 'rank_surv': 134 | out_dim = 1 135 | model = get_model(args=args, n_bag_feats=n_bag_feats, out_dim=out_dim) 136 | state_dict = torch.load(checkpoint_fpath) 137 | 138 | # Hack 139 | # state_dict['head.1.bias'] = state_dict.pop('head.2.bias') 140 | # state_dict['head.1.weight'] = state_dict.pop('head.2.weight') 141 | 142 | # Load checkpoints 143 | model.load_state_dict(state_dict) 144 | model.to(device) 145 | 146 | print("Forward pass through the model...") 147 | with torch.no_grad(): 148 | slide_feats = slide_feats.to(device) 149 | z = model(slide_feats) 150 | 151 | print("Query the patches...") 152 | # Obtain scores & projection vectors 153 | if hasattr(model, 'var_pool'): 154 | var_pooled_feats = model.get_variance(slide_feats, normalize=False) 155 | # Identify top proj vectors by variance for now 156 | varpool_val, varpool_idx = torch.topk(var_pooled_feats, 157 | args.n_proj, 158 | largest=True) 159 | 160 | highest_list = [] 161 | lowest_list = [] 162 | # For each projection direction, get highest-scoring patches 163 | for idx, val in list(zip(varpool_idx.squeeze(), varpool_val.squeeze())): 164 | var_proj_vec = model.var_pool.get_projection_vector(idx) 165 | highest_patches = get_top_patches(model.encode(slide_feats), 166 | var_proj_vec, 167 | n_patches=args.n_patches) 168 | 169 | lowest_patches = get_top_patches(model.encode(slide_feats), 170 | var_proj_vec, 171 | n_patches=args.n_patches, 172 | largest=False) 173 | 174 | highest_list.append(highest_patches) 175 | lowest_list.append(lowest_patches) 176 | 177 | ############## 178 | # Load image # 179 | ############## 180 | print("Reading images...") 181 | highest_img_list = [] 182 | for idx, _, _ in highest_list: 183 | img = dataset.read_patch(args.slide_num, slide_coords[idx]) 184 | highest_img_list.append(img) 185 | 186 | lowest_img_list = [] 187 | for idx, _, _ in lowest_list: 188 | img = dataset.read_patch(args.slide_num, slide_coords[idx]) 189 | lowest_img_list.append(img) 190 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | version = None 5 | with open(os.path.join('var_pool', '__init__.py'), 'r') as fid: 6 | for line in (line.strip() for line in fid): 7 | if line.startswith('__version__'): 8 | version = line.split('=')[1].strip().strip('\'') 9 | break 10 | if version is None: 11 | raise RuntimeError('Could not determine version') 12 | 13 | 14 | install_requires = [] # ['pycox'] 15 | 16 | 17 | setup(name='var_pool', 18 | version=version, 19 | description='Reproduce experiments for variance pooling MICCAI paper.', 20 | author='Iain Carmichael', 21 | author_email='icarmichael@bwh.harvard.edu', 22 | license='MIT', 23 | packages=find_packages(), 24 | install_requires=install_requires, 25 | test_suite='nose.collector', 26 | tests_require=['nose'], 27 | zip_safe=False) 28 | 
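The visualization code above (`scripts/visualize.py`, which relies on a `get_top_patches` helper whose implementation is not included in this dump) selects "representative patches" by ranking a slide's encoded patches along one variance-pooling projection direction. The function below is a minimal sketch of that kind of scoring; the function name, the centering of the scores, and the return signature are assumptions for illustration and may differ from the repository's implementation.

```python
import torch


def top_patches_by_projection(embeddings, proj_vec, n_patches=5, largest=True):
    """Rank the patches of one slide along a single variance-pooling direction.

    embeddings: (n_instances, embed_dim) encoded patch features
    proj_vec:   (embed_dim,) projection vector, e.g. from a trained var pool
    Returns the indices and centered projection scores of the n_patches most
    extreme patches in the requested direction.
    """
    scores = embeddings @ proj_vec          # (n_instances,) projection of each patch
    centered = scores - scores.mean()       # deviation from the bag average
    vals, idxs = torch.topk(centered, k=n_patches, largest=largest)
    return idxs, vals
```

Calling this twice, once with `largest=True` and once with `largest=False`, gives the two extremes along a heterogeneity direction; `scripts/visualize.py` then reads the corresponding patch images back from the slide via `dataset.read_patch`.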
-------------------------------------------------------------------------------- /tcga_scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/tcga_scripts/.DS_Store -------------------------------------------------------------------------------- /tcga_scripts/aggregate_surv_cv_folds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import pandas as pd 4 | from glob import glob 5 | import argparse 6 | 7 | 8 | raise NotImplementedError("Need to rewrite for new folder naming strategies.") 9 | 10 | parser = argparse.\ 11 | ArgumentParser(description='Aggregate CV folds to get summary statistics/') 12 | 13 | parser.add_argument('--folder', type=str, help="Folder containing the fold results as results.") 14 | 15 | args = parser.parse_args() 16 | 17 | # Read in results for each fold 18 | fold_folders = glob(os.path.join(args.tune_dir, 'fold_*/')) 19 | 20 | results = [] 21 | for folder_path in fold_folders: 22 | 23 | # path/fold_IDX/ -> fold_IDX 24 | folder_name = folder_path.split('/')[-2] 25 | fold_idx = int(folder_name.split('_')[1]) 26 | 27 | # load results fpath 28 | results_fpath = os.path.join(folder_path, 'results.yaml') 29 | with open(results_fpath, 'r') as file: 30 | res = yaml.safe_load(file) 31 | res['fold'] = fold_idx 32 | 33 | results.append(res) 34 | 35 | # format results to pandas 36 | results = pd.DataFrame(results) 37 | results = results.set_index('fold') 38 | results = results[['val_c_index', 'val_loss', 39 | 'train_c_index', 'train_loss', 40 | 'runtime']] # reorder columns 41 | 42 | # compute fold summary statistis 43 | results_summary = results.agg(['mean', 'std', 'min', 'max', 'median']).T 44 | 45 | # write output string 46 | output = '' 47 | output += '{} cross-validation results\n'.format(args.subtype) 48 | output += 'Validation c-index avg = {:1.2f}, std={:1.2f}'.\ 49 | format(100 * results_summary.loc['val_c_index', 'mean'], 50 | 100 * results_summary.loc['val_c_index', 'std']) 51 | output += '\n\n\n' 52 | 53 | output += 'Individual fold results' 54 | output += str(results) 55 | output += '\n\nFold summary statistics' 56 | output += str(results_summary) 57 | 58 | print(output) 59 | 60 | save_fpath = os.path.join(args.folder, 'results_agg.txt') 61 | with open(save_fpath, 'w') as f: 62 | f.write(output) 63 | -------------------------------------------------------------------------------- /tcga_scripts/download_tcga_clinical_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import urllib.request 3 | import os 4 | 5 | parser = argparse.\ 6 | ArgumentParser(description='Downloads the TCGA clinical data from (Liu et al, 2018). 
See https://www.cell.com/cell/fulltext/S0092-8674(18)30229-0.') 7 | 8 | parser.add_argument('--save_dir', 9 | type=str, 10 | help="Where to save the results.") 11 | 12 | parser.add_argument('--overwrite', 13 | action='store_true', default=False, 14 | help="Whether or not to overwrite the file if it exists.") 15 | 16 | parser.add_argument('--merge_coadread_gbmlgg', 17 | action='store_true', default=False, 18 | help="Merge GBM-LGG and COAD-READ.") 19 | 20 | 21 | args = parser.parse_args() 22 | 23 | 24 | url = 'https://www.cell.com/cms/10.1016/j.cell.2018.02.052/attachment/bbf46a06-1fb0-417a-a259-fd47591180e4/mmc1' 25 | 26 | 27 | fname = 'TCGA-CDR.xlsx' # 'mmc1.xlsx' 28 | os.makedirs(args.save_dir, exist_ok=True) 29 | fpath = os.path.join(args.save_dir, fname) 30 | if not os.path.exists(fpath) or args.overwrite: 31 | # TODO: this isn't working... 32 | urllib.request.urlretrieve(url=url, filename=fpath) 33 | 34 | if args.merge_coadread_gbmlgg: 35 | raise NotImplementedError("TODO: add -- we did this manually... lets do it programatically") 36 | -------------------------------------------------------------------------------- /tcga_scripts/make_discr_suvr_splits.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessing for discrete survival model 3 | 1) make train/val/test split for slides, keeping patients in the same group 4 | 2) save slide survival response data frame. This has columns: ['sample_id', 'time_bin', 'survival_time', 'censorship'] 5 | """ 6 | import os 7 | import argparse 8 | import numpy as np 9 | from joblib import dump 10 | 11 | from var_pool.processing.data_split import train_test_val_split 12 | from var_pool.mhist.tcga_clinical_data import load_patient_data_from_cell_df,\ 13 | restrct_patients_to_avail_slides, broadcast_patient_to_slide 14 | from var_pool.processing.discr_surv_utils import get_discrete_surv_bins 15 | from var_pool.file_utils import find_fpaths 16 | from var_pool.file_utils import get_file_names 17 | 18 | parser = argparse.\ 19 | ArgumentParser(description='Makes the response data frame discritized survial model.') 20 | 21 | parser.add_argument('--tcga_clincal_fpath', type=str, required=True, 22 | help='Where the TCGA clincal data frame is saved.' 23 | 'This should be Table S1 from (Liu et al, 2018)') 24 | 25 | parser.add_argument('--feats_dir', type=str, required=True, 26 | help='Directory containing the slide features as .pt or .h5 files. Used to subset patients who we have data form.') 27 | 28 | parser.add_argument('--save_dir', type=str, required=True, 29 | help='Where the response data frame should be stored.') 30 | 31 | parser.add_argument('--subtype', type=str, default='brca', 32 | help='Which cancer subtype.') 33 | 34 | parser.add_argument('--endpoint', type=str, default='pfi', 35 | choices=['os', 'pfi', 'dfi', 'dss'], 36 | help="Survival endpoint." 37 | "os = Overall Survival," 38 | "pfi = Progression Free Interval," 39 | "dfi = disease free interval," 40 | "dss = Disease Specific Survival.") 41 | 42 | parser.add_argument('--bag_level', type=str, default='patient', 43 | choices=['patient', 'slide'], 44 | help='Are the bags at the slide level or patient level; for the latter we concatenated the bags for each slide for a patient.') 45 | 46 | 47 | parser.add_argument('--prop_train', type=float, default=0.7, 48 | help='The proportion of samples that go in the training dataset. 
The remaining samples are split evenly between the validation and test sets.') 49 | 50 | parser.add_argument('--no_test_split', 51 | action='store_true', default=False, 52 | help='Only make train/val splits, not test split.') 53 | 54 | parser.add_argument('--seed', type=int, default=1, 55 | help='The random seed for splitting the data.') 56 | 57 | parser.add_argument('--n_bins', type=int, default=4, 58 | help='Number of bins for survival time.') 59 | 60 | args = parser.parse_args() 61 | 62 | 63 | os.makedirs(args.save_dir, exist_ok=True) 64 | response_fpath = os.path.join(args.save_dir, 'discr_survival.csv') 65 | 66 | other_data_fpath = os.path.join(args.save_dir, 'other_data') 67 | 68 | 69 | ################################# 70 | # Load and process patient data # 71 | ################################# 72 | 73 | # column names with survival 74 | event_col = args.endpoint.upper() 75 | time_col = event_col + '.time' 76 | censor_col = 'censorship' 77 | patient_id_name = 'patient_id' 78 | 79 | # load patient clincal data 80 | patient_clincal_df, col_types = \ 81 | load_patient_data_from_cell_df(fpath=args.tcga_clincal_fpath, 82 | subtype=args.subtype, 83 | verbose=False) 84 | print("Loaded clinical data for {} patients from {}". 85 | format(patient_clincal_df.shape[0], args.subtype)) 86 | 87 | # drop patients who are missing this survival data 88 | nan_mask = patient_clincal_df[event_col].isna() | \ 89 | patient_clincal_df[time_col].isna() 90 | patient_clincal_df = patient_clincal_df[~nan_mask] 91 | print("Dropping {} patients with missing survival data, left with {} patients". 92 | format(nan_mask.sum(), patient_clincal_df.shape[0])) 93 | 94 | # slides we have data for 95 | slide_fpaths = find_fpaths(folders=args.feats_dir, ext=['h5', 'pt']) 96 | avail_slide_names = get_file_names(slide_fpaths) 97 | 98 | # restrict patients to those we have slides for 99 | patient_clincal_df, slides_with_clincal_data, match_info = \ 100 | restrct_patients_to_avail_slides(patient_df=patient_clincal_df, 101 | avail_slides=avail_slide_names, 102 | verbose=True) 103 | 104 | print("\n{} patients with both clincal data and slides". 105 | format(patient_clincal_df.shape[0])) 106 | 107 | # create survival response df 108 | patient_surv_df = patient_clincal_df[[time_col, event_col]].copy() 109 | patient_surv_df[event_col] = patient_surv_df[event_col].astype(bool) 110 | patient_surv_df['censorship'] = ~patient_surv_df[event_col] 111 | 112 | print('{}% of patients censored'. 
113 | format(100 * patient_surv_df[censor_col].mean())) 114 | 115 | ##################### 116 | # Compute time bins # 117 | ##################### 118 | 119 | # compute discrete surival response data 120 | time_bin_idx, bins = \ 121 | get_discrete_surv_bins(patient_surv_df, 122 | n_bins=args.n_bins, 123 | time_col=time_col, 124 | censor_col=censor_col) 125 | 126 | patient_surv_df['time_bin'] = np.array(time_bin_idx).astype(int) 127 | 128 | ###################################### 129 | # split patients into train/val/test # 130 | ###################################### 131 | 132 | # stratify on time bin X censorship status 133 | # TODO: double check this is what we want to do 134 | time_bin_X_censor = patient_surv_df['time_bin'].astype(str)\ 135 | + '_X_' \ 136 | + patient_surv_df[censor_col].astype(str) 137 | 138 | if args.no_test_split: 139 | val_size = 1 - args.prop_train 140 | test_size = 0 141 | else: 142 | val_size = (1 - args.prop_train) / 2 143 | test_size = (1 - args.prop_train) / 2 144 | 145 | train_idxs, val_idxs, test_idxs = \ 146 | train_test_val_split(n_samples=patient_surv_df.shape[0], 147 | train_size=args.prop_train, 148 | val_size=val_size, 149 | test_size=test_size, 150 | shuffle=True, 151 | random_state=args.seed, 152 | stratify=time_bin_X_censor) 153 | 154 | patient_ids = patient_surv_df.index.values 155 | patient_surv_df['split'] = None 156 | patient_surv_df.loc[patient_ids[train_idxs], 'split'] = 'train' 157 | patient_surv_df.loc[patient_ids[val_idxs], 'split'] = 'val' 158 | patient_surv_df.loc[patient_ids[test_idxs], 'split'] = 'test' 159 | 160 | # Format to standardized names 161 | cols2keep = [time_col, censor_col, 'time_bin', 'split'] 162 | patient_surv_df = patient_surv_df[cols2keep] 163 | patient_surv_df.rename(columns={time_col: 'survival_time', 164 | censor_col: 'censorship' 165 | }, 166 | inplace=True) 167 | 168 | 169 | ############# 170 | # Save data # 171 | ############# 172 | 173 | if args.bag_level == 'patient': 174 | 175 | # standardize name 176 | patient_surv_df.index.name = 'sample_id' 177 | 178 | patient_surv_df.to_csv(response_fpath) 179 | 180 | dump({'time_bins': bins, 181 | 'patient_clincal_df': patient_clincal_df, 182 | 'col_types': col_types, 183 | 'slides_with_clincal_data': slides_with_clincal_data, 184 | 'match_info': match_info}, 185 | filename=other_data_fpath) 186 | 187 | 188 | elif args.bag_level == 'slide': 189 | 190 | ########################### 191 | # Create slide level data # 192 | ########################### 193 | 194 | slide_surv_df =\ 195 | broadcast_patient_to_slide(slide_names=slides_with_clincal_data, 196 | patient_df=patient_surv_df) 197 | 198 | # standardize names 199 | slide_surv_df.index.name = 'sample_id' 200 | slide_surv_df.rename(columns={patient_surv_df.index.name: patient_id_name}) 201 | 202 | slide_surv_df.to_csv(response_fpath) 203 | 204 | dump({'patient_surv_df': patient_surv_df, 205 | 'time_bins': bins, 206 | 'patient_clincal_df': patient_clincal_df, 207 | 'col_types': col_types, 208 | 'slides_with_clincal_data': slides_with_clincal_data, 209 | 'match_info': match_info}, 210 | filename=other_data_fpath) 211 | -------------------------------------------------------------------------------- /tcga_scripts/make_subtype_clf_splits.py: -------------------------------------------------------------------------------- 1 | """ 2 | Preprocessing for classification task between two or more cancer subtypes; make train/val/test splits and save response data frame. 
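Example invocation (illustrative; paths and subtype names are placeholders):

    python tcga_scripts/make_subtype_clf_splits.py \
        --folders /path/to/feats/tcga_luad /path/to/feats/tcga_lusc \
        --subtypes luad lusc \
        --save_dir /path/to/splits/luad_vs_lusc

This writes a response data frame named clf_luad_lusc.csv (columns: sample_id, label, split) to --save_dir.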
3 | """ 4 | import os 5 | from glob import glob 6 | import pandas as pd 7 | import argparse 8 | from sklearn.model_selection import train_test_split 9 | 10 | from var_pool.file_utils import get_file_names 11 | 12 | parser = argparse.\ 13 | ArgumentParser(description='Makes the response data frame for a toy cancer subtype classification task.') 14 | 15 | parser.add_argument('--folders', nargs='+', required=True, 16 | help='Folders containing the subtype features, one folder per subtype.') 17 | 18 | parser.add_argument('--subtypes', nargs='+', required=True, 19 | help='Names of the subtypes.') 20 | 21 | parser.add_argument('--save_dir', type=str, required=True, 22 | help='Where the response data frame should be stored.') 23 | 24 | args = parser.parse_args() 25 | 26 | assert len(args.folders) == len(args.subtypes), \ 27 | "Make sure to provide the same number of names as folders!"\ 28 | "{} folder arguments and {} subtypes arguments provided".\ 29 | format(len(args.folders), len(args.subtypes)) 30 | 31 | os.makedirs(args.save_dir, exist_ok=True) 32 | 33 | exts = ['pt', 'h5'] 34 | 35 | # make subtype response for available files 36 | y = [] 37 | for subtype, folder in zip(args.subtypes, args.folders): 38 | 39 | # pull out all file paths in the subtype folder matching this extension 40 | fpaths = [] 41 | for ext in exts: 42 | fpaths_this_ext = glob(os.path.join(folder, '*.{}'.format(ext))) 43 | fpaths.extend(fpaths_this_ext) 44 | fnames = get_file_names(fpaths) 45 | 46 | print("{} files found for {}".format(len(fpaths), subtype)) 47 | y.append(pd.Series(subtype, index=fnames, name='label')) 48 | 49 | y = pd.concat(y) 50 | y.index.name = 'sample_id' 51 | y = pd.DataFrame(y) 52 | 53 | # make train test splits 54 | train_idxs, val_idxs = train_test_split(y.index, 55 | train_size=.8, 56 | shuffle=True, 57 | random_state=1, 58 | stratify=y) 59 | 60 | y['split'] = None 61 | y.loc[train_idxs, 'split'] = 'train' 62 | y.loc[val_idxs, 'split'] = 'val' 63 | 64 | 65 | # Save to disk 66 | os.makedirs(args.save_dir, exist_ok=True) 67 | fpath = os.path.join(args.save_dir, 'clf_{}.csv'. 
68 | format('_'.join(args.subtypes))) 69 | 70 | y.to_csv(fpath) 71 | -------------------------------------------------------------------------------- /tcga_scripts/make_surv_yaml.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | 5 | parser = argparse.\ 6 | ArgumentParser(description='Makes a yaml file for a survival prediction task.') 7 | 8 | 9 | parser.add_argument('--fpath', type=str, required=True, 10 | help='File path for the name of this yaml file.') 11 | 12 | 13 | parser.add_argument('--feats_dir', type=str, required=True, 14 | help='Directory where the features are stored') 15 | 16 | 17 | parser.add_argument('--y_fpath', type=str, required=True, 18 | help='File path to response data frame.') 19 | 20 | parser.add_argument('--train_dir', type=str, required=True, 21 | help='Directory where the training results are stored.') 22 | 23 | parser.add_argument('--task', type=str, required=True, 24 | help="Task") 25 | 26 | 27 | args = parser.parse_args() 28 | 29 | data = {'feats_dir': args.feats_dir, 30 | 'y_fpath': args.y_fpath, 31 | 'task': args.task, 32 | 'train_dir': args.train_dir 33 | } 34 | 35 | 36 | folder = os.path.dirname(args.fpath) 37 | os.makedirs(folder, exist_ok=True) 38 | 39 | with open(args.fpath, 'w') as file: 40 | yaml.dump(data, file) 41 | -------------------------------------------------------------------------------- /tcga_scripts/stat_sig_c_index_cutoff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import argparse 4 | 5 | from var_pool.processing.discr_surv_utils import get_perm_c_index_quantile 6 | 7 | parser = argparse.\ 8 | ArgumentParser(description='Computes the cutoff for statistically significant c-index for a given task.') 9 | 10 | parser.add_argument('--response_fpath', type=str, 11 | help="Path to response file.") 12 | 13 | parser.add_argument('--subtype', type=str, 14 | choices=['brca', 'gbmlgg', 'blca', 'ucec', 'luad'], 15 | help="Which cancer subtype.") 16 | 17 | parser.add_argument('--endpoint', type=str, 18 | choices=['os', 'pfi', 'dfi', 'dss'], 19 | help="Which survival endpoint.") 20 | 21 | parser.add_argument('--save_dir', type=str, default=None, 22 | help="Folder where to save the cutoff data." 
23 | "If not provided, will just print out results.") 24 | 25 | parser.add_argument('--save_stub', type=str, default=None, 26 | help="Name stub for the saved file.") 27 | 28 | args = parser.parse_args() 29 | 30 | if args.save_dir is not None: 31 | os.makedirs(args.save_dir, exist_ok=True) 32 | if args.save_stub is not None: 33 | fname = '{}-c_index_stat_sig.csv'.format(args.save_stub) 34 | else: 35 | fname = 'c_index_stat_sig.csv' 36 | save_fpath = os.path.join(args.save_dir, fname) 37 | 38 | df = pd.read_csv(args.response_fpath) 39 | results = [] 40 | for kind in ['train', 'val', 'test']: 41 | kind_df = df.query("split=='{}'".format(kind)) 42 | 43 | # if no samples in this split then dont bother 44 | if kind_df.shape[0] == 0: 45 | continue 46 | 47 | event = ~kind_df['censorship'].values.astype(bool) 48 | time = kind_df['survival_time'].values 49 | 50 | ci = get_perm_c_index_quantile(event=event, time=time, 51 | n_perm=1000, q=[.5, .95, .99]) 52 | 53 | results.append({'split': kind, 54 | 'null_ci_95': ci[1], 55 | 'null_ci_99': ci[2]}) 56 | 57 | results = pd.DataFrame(results) 58 | 59 | print("C-index statistical significance cutoffs for {}".\ 60 | format(args.response_fpath)) 61 | print(results) 62 | 63 | if args.save_dir is not None: 64 | results.to_csv(save_fpath) 65 | -------------------------------------------------------------------------------- /tcga_scripts/viz_top_patches-extremes_only.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import chain 3 | import os 4 | import numpy as np # TODO: this fixes weird import error with torch 5 | import torch 6 | 7 | 8 | from var_pool.mhist.get_model_from_args import get_model 9 | from var_pool.mhist.tcga_agg_slides_to_patient_level import \ 10 | tcga_agg_slides_to_patient_level 11 | from var_pool.file_utils import find_fpaths, get_file_names 12 | from var_pool.viz.top_attn import viz_top_attn_patches 13 | from var_pool.viz.var_pool_extremes import viz_top_var_proj_patches 14 | from var_pool.nn.arch.VarPool import AttnMeanAndVarPoolMIL 15 | 16 | parser = argparse.\ 17 | ArgumentParser(description='Visualizes the top attended patches') 18 | 19 | parser.add_argument('--autogen_fpath', type=str, 20 | help='Path to autgen file.') 21 | 22 | parser.add_argument('--checkpoint_fpath', type=str, 23 | help='Path to model checkpoint.') 24 | 25 | parser.add_argument('--wsi_dir', type=str, 26 | help='Directory containing WSIs.') 27 | 28 | parser.add_argument('--feat_h5_dir', type=str, 29 | help='Directory containing feature hdf5 files.') 30 | 31 | parser.add_argument('--high_risk', type=str, nargs='+', 32 | help='List of high risk patients.') 33 | 34 | parser.add_argument('--low_risk', type=str, nargs='+', 35 | help='List of low risk patients.') 36 | 37 | parser.add_argument('--save_dir', type=str, 38 | help='Where to save the images.') 39 | 40 | 41 | ################################# 42 | # Specify architecture of model # 43 | ################################# 44 | parser.add_argument('--arch', type=str, default='amil_slim', 45 | choices=['amil_slim', 'amil_nn', 'amil_var_nn', 'sum_mil', 'sum_var_mil', 'patchGCN', 'patchGCN_varpool'], 46 | help="Which neural network architecture to use.\n" 47 | "'amil_slim' just does attention mean pooling with a final linear layer.\n" 48 | "'amil_nn' does attention mean pooling with an additional neural network layers applied to the instance embeddings and the mean pooled output.") 49 | 50 | parser.add_argument('--final_nonlin', default='relu', 
type=str, 51 | choices=['relu', 'tanh', 'identity'], 52 | help='Choice of final nonlinearity for architecture.') 53 | 54 | parser.add_argument('--attn_latent_dim', default=256, type=int, 55 | help='Dimension of the attention latent space.') 56 | 57 | parser.add_argument('--head_n_hidden_layers', default=1, type=int, 58 | help='Number of hidden layers in the head network (excluding the final output layer).') 59 | 60 | 61 | parser.add_argument('--head_hidden_dim', default=256, type=int, 62 | help='Dimension of the head hidden layers.') 63 | 64 | parser.add_argument('--dropout', 65 | action='store_true', default=False, 66 | help='Use dropout (p=0.25 by default).') 67 | 68 | 69 | # For var pool 70 | parser.add_argument('--n_var_pools', default=10, type=int, 71 | help='Number of projection vectors for variance pooling.') 72 | 73 | parser.add_argument('--var_act_func', default='sigmoid', type=str, 74 | choices=['sqrt', 'log', 'sigmoid', 'identity'], 75 | help='Activation function for var pooling.') 76 | 77 | parser.add_argument('--separate_attn', 78 | action='store_true', default=False, 79 | help='Use separate attention branches for the mean and variance pools.') 80 | 81 | args = parser.parse_args() 82 | 83 | print(args) 84 | 85 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 86 | print("Using device: ", device) 87 | 88 | n_top_patches = 20 89 | 90 | ############## 91 | # Load stuff # 92 | ############## 93 | 94 | n_bag_feats = 1024 95 | out_dim = 1 # TODO: allow to specify for different tasks 96 | 97 | 98 | # Load model 99 | model = get_model(args=args, n_bag_feats=n_bag_feats, out_dim=out_dim) 100 | 101 | 102 | state_dict = torch.load(args.checkpoint_fpath, 103 | map_location=device) 104 | model.load_state_dict(state_dict) 105 | model.to(device) 106 | model.eval() 107 | 108 | 109 | # aggregate feature fpaths per patient 110 | feat_fpaths = find_fpaths(folders=args.feat_h5_dir, ext=['h5']) 111 | feat_fpaths = tcga_agg_slides_to_patient_level(feat_fpaths) 112 | assert len(feat_fpaths) > 0, "No feature files found" 113 | 114 | patient_iter = chain(zip(['low_risk'] * len(args.low_risk), args.low_risk), 115 | zip(['high_risk'] * len(args.high_risk), args.high_risk) 116 | ) 117 | 118 | for risk, patient_id in patient_iter: 119 | 120 | # get slide names for this patient 121 | h5_fpath = feat_fpaths[patient_id] 122 | slide_names = get_file_names(h5_fpath) 123 | wsi_fpath = [os.path.join(args.wsi_dir, name + '.svs') 124 | for name in slide_names] 125 | 126 | ######################### 127 | # Top atteneded patches # 128 | ######################### 129 | 130 | save_fpath = os.path.join(args.save_dir, 'top_attn', risk, 131 | patient_id + '.png') 132 | 133 | viz_top_attn_patches(model=model, 134 | wsi_fpath=wsi_fpath, 135 | h5_fpath=h5_fpath, 136 | autogen_fpath=args.autogen_fpath, 137 | device=device, 138 | n_top_patches=n_top_patches, 139 | save_fpath=save_fpath) 140 | 141 | ######################## 142 | # Variance projections # 143 | ######################## 144 | 145 | if not isinstance(model, AttnMeanAndVarPoolMIL): 146 | continue 147 | 148 | save_dir = os.path.join(args.save_dir, 'var_proj_extremes', risk) 149 | 150 | viz_top_var_proj_patches(model=model, 151 | wsi_fpath=wsi_fpath, 152 | h5_fpath=h5_fpath, 153 | autogen_fpath=args.autogen_fpath, 154 | device=device, 155 | with_attn=True, 156 | n_top_patches=n_top_patches, 157 | save_dir=save_dir, 158 | fname_stub=patient_id) 159 | 160 | save_dir = os.path.join(args.save_dir, 'var_proj_extremes-no_attn', risk) 161 | 162 | 
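    # Repeat the variance-projection visualization without attention weighting
    # so the top patches can be compared with and without it; these results are
    # saved under the '-no_attn' directory.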
viz_top_var_proj_patches(model=model, 163 | wsi_fpath=wsi_fpath, 164 | h5_fpath=h5_fpath, 165 | autogen_fpath=args.autogen_fpath, 166 | device=device, 167 | with_attn=False, 168 | n_top_patches=n_top_patches, 169 | save_dir=save_dir, 170 | fname_stub=patient_id) 171 | -------------------------------------------------------------------------------- /tcga_scripts/viz_top_patches.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import chain 3 | import os 4 | import numpy as np # this fixes weird import error with torch 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from var_pool.mhist.get_model_from_args import get_model 9 | from var_pool.mhist.tcga_agg_slides_to_patient_level import \ 10 | tcga_agg_slides_to_patient_level 11 | from var_pool.file_utils import find_fpaths, get_file_names 12 | from var_pool.viz.top_attn import viz_top_attn_patches 13 | from var_pool.viz.var_pool_extremes import viz_var_proj_patches_quantiles 14 | from var_pool.nn.arch.VarPool import AttnMeanAndVarPoolMIL 15 | 16 | parser = argparse.\ 17 | ArgumentParser(description='Visualizes the top attended patches') 18 | 19 | parser.add_argument('--autogen_fpath', type=str, 20 | help='Path to autgen file.') 21 | 22 | parser.add_argument('--checkpoint_fpath', type=str, 23 | help='Path to model checkpoint.') 24 | 25 | parser.add_argument('--wsi_dir', type=str, 26 | help='Directory containing WSIs.') 27 | 28 | parser.add_argument('--feat_h5_dir', type=str, 29 | help='Directory containing feature hdf5 files.') 30 | 31 | parser.add_argument('--high_risk', type=str, nargs='+', 32 | help='List of high risk patients.') 33 | 34 | parser.add_argument('--low_risk', type=str, nargs='+', 35 | help='List of low risk patients.') 36 | 37 | parser.add_argument('--save_dir', type=str, 38 | help='Where to save the images.') 39 | 40 | 41 | ################################# 42 | # Specify architecture of model # 43 | ################################# 44 | parser.add_argument('--arch', type=str, default='amil_slim', 45 | choices=['amil_slim', 'amil_nn', 'amil_var_nn', 'sum_mil', 'sum_var_mil', 'patchGCN', 'patchGCN_varpool'], 46 | help="Which neural network architecture to use.\n" 47 | "'amil_slim' just does attention mean pooling with a final linear layer.\n" 48 | "'amil_nn' does attention mean pooling with an additional neural network layers applied to the instance embeddings and the mean pooled output.") 49 | 50 | parser.add_argument('--final_nonlin', default='relu', type=str, 51 | choices=['relu', 'tanh', 'identity'], 52 | help='Choice of final nonlinearity for architecture.') 53 | 54 | parser.add_argument('--attn_latent_dim', default=256, type=int, 55 | help='Dimension of the attention latent space.') 56 | 57 | parser.add_argument('--head_n_hidden_layers', default=1, type=int, 58 | help='Number of hidden layers in the head network (excluding the final output layer).') 59 | 60 | 61 | parser.add_argument('--head_hidden_dim', default=256, type=int, 62 | help='Dimension of the head hidden layers.') 63 | 64 | parser.add_argument('--dropout', 65 | action='store_true', default=False, 66 | help='Use dropout (p=0.25 by default).') 67 | 68 | 69 | # For var pool 70 | parser.add_argument('--n_var_pools', default=10, type=int, 71 | help='Number of projection vectors for variance pooling.') 72 | 73 | parser.add_argument('--var_act_func', default='sigmoid', type=str, 74 | choices=['sqrt', 'log', 'sigmoid', 'identity'], 75 | help='Activation function for var pooling.') 76 | 77 | 
parser.add_argument('--separate_attn', 78 | action='store_true', default=False, 79 | help='Use separate attention branches for the mean and variance pools.') 80 | 81 | args = parser.parse_args() 82 | 83 | print(args) 84 | 85 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 86 | print("Using device: ", device) 87 | 88 | n_top_patches = 10 89 | 90 | ############## 91 | # Load stuff # 92 | ############## 93 | 94 | n_bag_feats = 1024 95 | out_dim = 1 # TODO: allow to specify for different tasks 96 | 97 | 98 | # Load model 99 | model = get_model(args=args, n_bag_feats=n_bag_feats, out_dim=out_dim) 100 | 101 | 102 | state_dict = torch.load(args.checkpoint_fpath, 103 | map_location=device) 104 | model.load_state_dict(state_dict) 105 | model.to(device) 106 | model.eval() 107 | 108 | 109 | # aggregate feature fpaths per patient 110 | feat_fpaths = find_fpaths(folders=args.feat_h5_dir, ext=['h5']) 111 | feat_fpaths = tcga_agg_slides_to_patient_level(feat_fpaths) 112 | assert len(feat_fpaths) > 0, "No feature files found" 113 | 114 | patient_iter = chain(zip(['low_risk'] * len(args.low_risk), args.low_risk), 115 | zip(['high_risk'] * len(args.high_risk), args.high_risk) 116 | ) 117 | 118 | for risk, patient_id in tqdm(list(patient_iter)): 119 | 120 | # get slide names for this patient 121 | h5_fpath = feat_fpaths[patient_id] 122 | slide_names = get_file_names(h5_fpath) 123 | wsi_fpath = [os.path.join(args.wsi_dir, name + '.svs') 124 | for name in slide_names] 125 | 126 | ######################### 127 | # Top atteneded patches # 128 | ######################### 129 | 130 | save_fpath = os.path.join(args.save_dir, 'top_attn', risk, 131 | patient_id + '.png') 132 | 133 | # TODO: uncomment 134 | # viz_top_attn_patches(model=model, 135 | # wsi_fpath=wsi_fpath, 136 | # h5_fpath=h5_fpath, 137 | # autogen_fpath=args.autogen_fpath, 138 | # device=device, 139 | # n_top_patches=n_top_patches, 140 | # save_fpath=save_fpath) 141 | 142 | ######################## 143 | # Variance projections # 144 | ######################## 145 | 146 | if not isinstance(model, AttnMeanAndVarPoolMIL): 147 | continue 148 | 149 | # quantiles = [0, 25, 50, 75, 100] 150 | quantiles = [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100] 151 | 152 | save_dir = os.path.join(args.save_dir, 'var_proj_extremes', risk) 153 | 154 | viz_var_proj_patches_quantiles(model=model, 155 | wsi_fpath=wsi_fpath, 156 | h5_fpath=h5_fpath, 157 | autogen_fpath=args.autogen_fpath, 158 | device=device, 159 | with_attn=True, 160 | quantiles=quantiles, 161 | n_to_show=n_top_patches, 162 | save_dir=save_dir, 163 | name=patient_id) 164 | 165 | save_dir = os.path.join(args.save_dir, 'var_proj_extremes-no_attn', risk) 166 | 167 | viz_var_proj_patches_quantiles(model=model, 168 | wsi_fpath=wsi_fpath, 169 | h5_fpath=h5_fpath, 170 | autogen_fpath=args.autogen_fpath, 171 | device=device, 172 | with_attn=False, 173 | quantiles=quantiles, 174 | n_to_show=n_top_patches, 175 | save_dir=save_dir, 176 | name=patient_id) 177 | -------------------------------------------------------------------------------- /var_pool/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/.DS_Store -------------------------------------------------------------------------------- /var_pool/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.0.0' 2 | 
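The scripts above repeatedly lean on two small helpers from var_pool/file_utils.py (listed next) to discover feature files and recover slide names from them. A minimal usage sketch, assuming a hypothetical folder that contains slide_1.pt and slide_2.h5:

```python
from var_pool.file_utils import find_fpaths, get_file_names

# collect all .pt / .h5 feature files in the folder, sorted by file name
fpaths = find_fpaths("/path/to/feats/TCGA_BLCA", ext=["pt", "h5"])
# e.g. ["/path/to/feats/TCGA_BLCA/slide_1.pt", "/path/to/feats/TCGA_BLCA/slide_2.h5"]

# the same files with directories and extensions stripped
print(get_file_names(fpaths))  # e.g. ["slide_1", "slide_2"]
```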
-------------------------------------------------------------------------------- /var_pool/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | import pathlib 4 | 5 | 6 | def join_and_make(a, *p): 7 | """ 8 | new_dir = os.path.join(a, *p) 9 | os.makedirs(new_dir, exist_ok=True) 10 | """ 11 | new_dir = os.path.join(a, *p) 12 | os.makedirs(new_dir, exist_ok=True) 13 | return new_dir 14 | 15 | 16 | def find_fpaths(folders, ext=['h5', 'pt'], names=None): 17 | """ 18 | Gets all file paths from a folder or set of folders. 19 | 20 | Parameters 21 | ---------- 22 | folders: str, list of str 23 | The folder or folders to check. 24 | 25 | ext: str or list of str 26 | Only get files ending in this extension. 27 | 28 | names: None, str 29 | (Optional) Subset of files to include in this dataset. 30 | 31 | Output 32 | ------ 33 | fpaths: list of str 34 | The discovered file names ordered alphabetically by file name. 35 | """ 36 | if isinstance(ext, str): 37 | ext = [ext] 38 | 39 | # format to list of str 40 | fpaths = [] 41 | if isinstance(folders, str): 42 | folders = [folders] 43 | 44 | # find all available files with given extension 45 | for fd in folders: 46 | for e in ext: 47 | # find files in each feature directory 48 | fps = glob(os.path.join(fd, '*.{}'.format(e))) 49 | fpaths.extend(fps) 50 | 51 | # maybe subset to user specified file names 52 | if names is not None: 53 | fpaths = check_guest_list(restr=names, avail_fpaths=fpaths, 54 | drop_ext=True) 55 | 56 | # sort files alphabetically 57 | fpaths = sorted(fpaths, key=lambda p: os.path.basename(p)) 58 | 59 | return fpaths 60 | 61 | 62 | def check_guest_list(restr, avail_fpaths, drop_ext=False): 63 | """ 64 | Returns a list of file paths from available files given a subset of file names to restrict ourselves to. 65 | 66 | Parameters 67 | ---------- 68 | restr: list of str 69 | Names of the files we want to restrict ourselves to i.e. a guest list. 70 | 71 | avail_fpaths: list of str 72 | List of available file paths. 73 | 74 | drop_ext: bool 75 | Whether or not to drop the extension from the file paths in avail_fpaths. 76 | 77 | Output 78 | ------ 79 | fpaths: list of str 80 | A subset of avail_fpaths who were on the guest list. 81 | """ 82 | # get file names in case a full path was provided 83 | restr_names = set([os.path.basename(p) for p in restr]) 84 | 85 | fpaths_ret = [] 86 | for fpath in avail_fpaths: 87 | # pull out name of this available file 88 | if drop_ext: 89 | name = pathlib.Path(fpath).stem 90 | else: 91 | name = os.path.basename(fpath) 92 | 93 | # check if name is on the guest list 94 | if name in restr_names: 95 | fpaths_ret.append(fpath) 96 | 97 | return fpaths_ret 98 | 99 | 100 | def safe_drop_suffix(s, suffix): 101 | """ 102 | Drops a suffix if a string ends in the suffix. 103 | """ 104 | if s.endswith(suffix): 105 | return s[:-len(suffix)] 106 | else: 107 | return s 108 | 109 | 110 | def get_file_names(fpaths): 111 | """ 112 | Gets the file names from a list of file paths without their extensions. 113 | 114 | Parameters 115 | ---------- 116 | fpaths: list of str, str 117 | The file paths. 118 | 119 | Output 120 | ------ 121 | fnames: list of str, str 122 | The file names without their extension. 
123 | """ 124 | if isinstance(fpaths, str): 125 | return pathlib.Path(fpaths).stem 126 | else: 127 | return [pathlib.Path(fpath).stem for fpath in fpaths] 128 | -------------------------------------------------------------------------------- /var_pool/gpu_utils.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import torch 4 | import time 5 | 6 | 7 | def assign_free_gpus(max_gpus=3): 8 | """ 9 | Identify least-utilized GPU and return the corresponding ID 10 | 11 | Parameters 12 | ---------- 13 | max_gpus (int, optional): Max GPUs is the maximum number of gpus to assign. 14 | Defaults to 3. 15 | """ 16 | 17 | # Get the list of GPUs via nvidia-smi 18 | smi_query_result = subprocess.check_output('nvidia-smi -q -d Memory | grep -A4 GPU', shell=True) 19 | # Extract the usage information 20 | gpu_info = smi_query_result.decode('utf-8').split('\n') 21 | gpu_info = list(filter(lambda info: 'Used' in info, gpu_info)) 22 | gpu_info = np.array([int(x.split(':')[1].replace('MiB', '').strip()) for x in gpu_info]) # Remove garbage 23 | 24 | # Delay by random amount to prevent choosing same gpus 25 | t_sleep = np.random.uniform(0, 30) 26 | print("Sleeping for ", t_sleep) 27 | time.sleep(t_sleep) 28 | 29 | # Tie breaking 30 | indices = np.where(gpu_info == gpu_info.min())[0] 31 | gpu_id = np.random.choice(indices) 32 | # if len(indices) == 1: 33 | # gpu_id = gpu_info.min() 34 | # else: 35 | # gpu_id = np.random.choice(np.arange(max_gpus, dtype=int)) 36 | 37 | device = torch.device("cuda:{}".format(gpu_id) if torch.cuda.is_available() else "cpu") 38 | print("----------------------------------") 39 | print("Identified GPU with minimal usage ", np.where(gpu_info == gpu_info.min())[0]) 40 | print("Training with device {}".format(device)) 41 | print("----------------------------------") 42 | return device 43 | -------------------------------------------------------------------------------- /var_pool/mhist/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/mhist/.DS_Store -------------------------------------------------------------------------------- /var_pool/mhist/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/mhist/__init__.py -------------------------------------------------------------------------------- /var_pool/mhist/clinical_data_porpoise.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import os 3 | import numpy as np 4 | 5 | from var_pool.file_utils import safe_drop_suffix 6 | 7 | avail_subtypes = ['blca', 'brca', 'coadread', 'gbmlgg', 'hnsc', 8 | 'kirc', 'kirp', 'lihc', 'luad', 9 | 'lusc', 'paad', 'skcm', 'stad', 'ucec'] 10 | 11 | 12 | def download_tcga_clinical_data(subtype, save_dir=None, verbose=True): 13 | """ 14 | Downloads TCGA clinical data from the PORPOSE github repo https://github.com/mahmoodlab/PORPOISE/tree/master/dataset_csv/ 15 | 16 | Parameters 17 | ---------- 18 | subtype: str 19 | Which cancer subtype e.g. ['blca', 'brca', ...]. 20 | 21 | save_dir: None, str 22 | (Optional) Directory where to save the csv file. 
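    verbose: bool
        If True, print a short summary of the downloaded table (its shape and
        the number of unique slides and case ids).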
23 | 24 | Output 25 | ------ 26 | df: pd.DataFrame 27 | """ 28 | # download csv file 29 | url = 'https://raw.githubusercontent.com/mahmoodlab/PORPOISE/master/dataset_csv/tcga_{}_all.csv'.format(subtype) 30 | df = pd.read_csv(url) 31 | df = df.rename(columns={'Unnamed: 0': 'case_id'}) 32 | 33 | # maybe save to disk 34 | if save_dir is not None: 35 | os.makedirs(save_dir, exist_ok=True) 36 | fpath = os.path.join(save_dir, 'tcga_{}_all.csv'.format(subtype)) 37 | df.to_csv(fpath, index=False) 38 | 39 | if verbose: 40 | n_slides = len(np.unique(df['slide_id'])) 41 | n_cases = len(np.unique(df['case_id'])) 42 | 43 | print('Downloading clinical data for {}'.format(subtype)) 44 | print("Clinical shape {} with {} unique slides and {} unique case ids".\ 45 | format(df.shape, n_slides, n_cases)) 46 | 47 | return df 48 | 49 | 50 | def load_clinical_data(save_dir, subtype, verbose=True): 51 | """ 52 | Loads TCGA clinical data. 53 | 54 | Parameters 55 | ---------- 56 | save_dir: None, str 57 | Directory where to save the csv file. 58 | 59 | subtype: str 60 | Which cancer type e.g. ['blca', 'brca', ...]. 61 | 62 | Output 63 | ------ 64 | df: pd.DataFrame 65 | The clinical data file with slide_id as the index. 66 | """ 67 | # load file 68 | fpath = os.path.join(save_dir, 'tcga_{}_all.csv'.format(subtype)) 69 | df = pd.read_csv(fpath) 70 | 71 | # make slide id the index 72 | n_slides = len(np.unique(df['slide_id'])) 73 | assert n_slides == df.shape[0] 74 | 75 | # make slide id the index and drop file extension from slide ids 76 | df = df.set_index('slide_id') 77 | df.index = [safe_drop_suffix(s=i, suffix='.svs') for i in df.index] 78 | 79 | if verbose: 80 | n_cases = len(np.unique(df['case_id'])) 81 | 82 | print('Clinical data for {}: shape {} with '\ 83 | '{} unique slides and {} unique case ids'.\ 84 | format(subtype, df.shape, n_slides, n_cases)) 85 | 86 | return df 87 | -------------------------------------------------------------------------------- /var_pool/mhist/get_model_from_args.py: -------------------------------------------------------------------------------- 1 | from warnings import warn 2 | 3 | from var_pool.mhist.get_model import ( 4 | get_model_with_nn_layers, 5 | get_model_slim, 6 | get_model_varpool, 7 | get_model_summil, 8 | get_model_summil_varpool, 9 | get_model_patchGCN, 10 | get_model_patchGCN_varpool, 11 | get_model_MIL_Graph_FC, 12 | get_model_MIL_Graph_FC_varpool 13 | ) 14 | 15 | # from var_pool.mhist.get_model_with_switch import get_model_varpool_with_switch 16 | 17 | 18 | def get_model(args, n_bag_feats, out_dim): 19 | 20 | # load model object 21 | if args.arch == "amil_nn": 22 | return get_model_with_nn_layers( 23 | n_bag_feats=n_bag_feats, 24 | out_dim=out_dim, 25 | final_nonlin=args.final_nonlin, 26 | dropout=args.dropout, 27 | head_n_hidden_layers=args.head_n_hidden_layers, 28 | attn_latent_dim=args.attn_latent_dim, 29 | head_hidden_dim=args.head_hidden_dim, 30 | ) 31 | 32 | elif args.arch == "amil_slim": 33 | return get_model_slim( 34 | n_bag_feats=n_bag_feats, 35 | out_dim=out_dim, 36 | final_nonlin=args.final_nonlin, 37 | dropout=args.dropout, 38 | attn_latent_dim=args.attn_latent_dim, 39 | ) 40 | 41 | elif args.arch == "sum_mil": 42 | return get_model_summil( 43 | n_bag_feats=n_bag_feats, 44 | out_dim=out_dim, 45 | final_nonlin=args.final_nonlin, 46 | dropout=args.dropout, 47 | head_n_hidden_layers=args.head_n_hidden_layers, 48 | attn_latent_dim=args.attn_latent_dim, 49 | head_hidden_dim=args.head_hidden_dim, 50 | ) 51 | 52 | elif args.arch == "sum_var_mil": 53 | if 
args.freeze_var_epochs is not None and args.freeze_var_epochs > 0: 54 | warn("Var freezing not implemented for sum_var_mil") 55 | 56 | return get_model_summil_varpool( 57 | n_bag_feats=n_bag_feats, 58 | out_dim=out_dim, 59 | n_var_pools=args.n_var_pools, 60 | var_act_func=args.var_act_func, 61 | dropout=args.dropout, 62 | head_n_hidden_layers=args.head_n_hidden_layers, 63 | attn_latent_dim=args.attn_latent_dim, 64 | head_hidden_dim=args.head_hidden_dim, 65 | ) 66 | 67 | elif args.arch == "amil_var_nn": 68 | # Turn off switching for now 69 | return get_model_varpool( 70 | n_bag_feats=n_bag_feats, 71 | out_dim=out_dim, 72 | final_nonlin=args.final_nonlin, 73 | n_var_pools=args.n_var_pools, 74 | var_act_func=args.var_act_func, 75 | separate_attn=args.separate_attn, 76 | dropout=args.dropout, 77 | head_n_hidden_layers=args.head_n_hidden_layers, 78 | attn_latent_dim=args.attn_latent_dim, 79 | head_hidden_dim=args.head_hidden_dim, 80 | ) 81 | # return get_model_varpool_with_switch( 82 | # n_bag_feats=n_bag_feats, 83 | # out_dim=out_dim, 84 | # final_nonlin=args.final_nonlin, 85 | # n_var_pools=args.n_var_pools, 86 | # var_act_func=args.var_act_func, 87 | # separate_attn=args.separate_attn, 88 | # dropout=args.dropout, 89 | # head_n_hidden_layers=args.head_n_hidden_layers, 90 | # attn_latent_dim=args.attn_latent_dim, 91 | # head_hidden_dim=args.head_hidden_dim, 92 | # ) 93 | 94 | elif args.arch == "patchGCN": 95 | assert args.mode == "graph", "Patch GCN needs to be on graph mode" 96 | return get_model_patchGCN( 97 | n_bag_feats=n_bag_feats, 98 | out_dim=out_dim, 99 | final_nonlin=args.final_nonlin, 100 | dropout=args.dropout, 101 | ) 102 | 103 | elif args.arch == "patchGCN_varpool": 104 | if args.freeze_var_epochs is not None and args.freeze_var_epochs > 0: 105 | warn("Var freezing not implemented for patchGCN_varpool") 106 | 107 | assert args.mode == "graph", "Patch GCN needs to be on graph mode" 108 | return get_model_patchGCN_varpool( 109 | n_bag_feats=n_bag_feats, 110 | out_dim=out_dim, 111 | n_var_pools=args.n_var_pools, 112 | var_act_func=args.var_act_func, 113 | final_nonlin=args.final_nonlin, 114 | dropout=args.dropout, 115 | ) 116 | 117 | elif args.arch == 'amil_gcn': 118 | assert args.mode == "graph", "AMIL GCN needs to be on graph mode" 119 | return get_model_MIL_Graph_FC( 120 | n_bag_feats=n_bag_feats, 121 | out_dim=out_dim, 122 | final_nonlin=args.final_nonlin, 123 | dropout=args.dropout, 124 | ) 125 | 126 | elif args.arch == 'amil_gcn_varpool': 127 | assert args.mode == "graph", "AMIL GCN needs to be on graph mode" 128 | return get_model_MIL_Graph_FC_varpool( 129 | n_bag_feats=n_bag_feats, 130 | out_dim=out_dim, 131 | n_var_pools=args.n_var_pools, 132 | var_act_func=args.var_act_func, 133 | final_nonlin=args.final_nonlin, 134 | dropout=args.dropout, 135 | ) 136 | 137 | else: 138 | raise NotImplementedError("Not implemented for {}!".format(args.arch)) 139 | -------------------------------------------------------------------------------- /var_pool/mhist/get_model_with_switch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from var_pool.nn.arch.VarPool_switch import AttnMeanAndVarPoolMIL_with_switch as AMVPool_with_switch 4 | 5 | from var_pool.mhist.get_model import _get_head_network 6 | 7 | 8 | def get_switch_parts(n_bag_feats, encoder_dim, out_dim, n_var_pools, 9 | head_hidden_dim, head_n_hidden_layers, 10 | final_nonlin, dropout): 11 | 12 | ############################### 13 | # First layer of head network # 14 | 
############################### 15 | if head_n_hidden_layers == 0: 16 | head_hidden_dim = out_dim 17 | 18 | mean_neck = nn.Linear(encoder_dim, head_hidden_dim) 19 | var_neck = nn.Linear(n_var_pools, head_hidden_dim) 20 | 21 | ######################## 22 | # Rest of head network # 23 | ######################## 24 | 25 | if head_n_hidden_layers >= 1: 26 | 27 | # head transforms the concatenation of the mean pool and variance pool 28 | head = _get_head_network(encoder_dim=head_hidden_dim, 29 | out_dim=out_dim, 30 | dropout=dropout, 31 | head_n_hidden_layers=head_n_hidden_layers - 1, 32 | head_hidden_dim=head_hidden_dim, 33 | final_nonlin=final_nonlin, 34 | as_list=True) 35 | 36 | prefix = [nn.ReLU()] 37 | if dropout: 38 | prefix.append(nn.Dropout(0.25)) 39 | head = prefix + head 40 | 41 | head = nn.Sequential(*head) 42 | else: 43 | head = None 44 | 45 | return head, mean_neck, var_neck 46 | 47 | 48 | def get_model_varpool_with_switch(n_bag_feats, 49 | out_dim, 50 | n_var_pools, 51 | var_act_func, 52 | separate_attn, 53 | dropout, 54 | final_nonlin='relu', 55 | head_n_hidden_layers=1, 56 | attn_latent_dim=256, 57 | head_hidden_dim=256): 58 | """ 59 | Create a Varpool + attention pool model. 60 | 61 | 62 | Parameters 63 | ---------- 64 | n_bag_feats: int 65 | 66 | out_dim: int 67 | 68 | n_var_pools: int 69 | 70 | var_act_func: str 71 | 72 | separate_attn: bool 73 | 74 | dropout: bool 75 | 76 | final_nonlin: str 77 | 78 | head_n_hidden_layers: int 79 | 80 | attn_latent_dim: int 81 | 82 | head_hidden_dim: int 83 | 84 | Output 85 | ------ 86 | nn.Module 87 | """ 88 | 89 | encoder_dim = 512 90 | 91 | # additional MLP encoder for the instance features 92 | instance_encoder = [nn.Linear(n_bag_feats, encoder_dim), nn.ReLU()] 93 | if dropout: 94 | instance_encoder.append(nn.Dropout(p=0.25)) 95 | instance_encoder = nn.Sequential(*instance_encoder) 96 | 97 | head, mean_neck, var_neck = \ 98 | get_switch_parts(n_bag_feats=n_bag_feats, 99 | encoder_dim=encoder_dim, 100 | out_dim=out_dim, 101 | n_var_pools=n_var_pools, 102 | head_hidden_dim=head_hidden_dim, 103 | head_n_hidden_layers=head_n_hidden_layers, 104 | final_nonlin=final_nonlin, 105 | dropout=dropout) 106 | 107 | return AMVPool_with_switch(encoder_dim=encoder_dim, 108 | encoder=instance_encoder, 109 | mean_neck=mean_neck, 110 | var_neck=var_neck, 111 | head=head, 112 | n_attn_latent=attn_latent_dim, 113 | gated=True, 114 | separate_attn=bool(separate_attn), 115 | n_var_pools=n_var_pools, 116 | act_func=var_act_func, 117 | dropout=dropout) 118 | -------------------------------------------------------------------------------- /var_pool/mhist/patch_gcn_arch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modules from https://github.com/mahmoodlab/Patch-GCN/ 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | class MIL_Attention_FC_surv(nn.Module): 10 | def __init__(self, size_arg="small", dropout=0.25, n_classes=4): 11 | r""" 12 | Attention MIL Implementation 13 | Args: 14 | size_arg (str): Size of NN architecture (Choices: small or large) 15 | dropout (float): Dropout rate 16 | n_classes (int): Output shape of NN 17 | """ 18 | super(MIL_Attention_FC_surv, self).__init__() 19 | self.size_dict_path = {"small": [1024, 512, 256], 20 | "big": [1024, 512, 384]} 21 | 22 | ### Deep Sets Architecture Construction 23 | size = self.size_dict_path[size_arg] 24 | fc = [nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(dropout)] 25 | attention_net = Attn_Net_Gated( 26 
| L=size[1], D=size[2], dropout=dropout, n_classes=1 27 | ) 28 | fc.append(attention_net) 29 | self.attention_net = nn.Sequential(*fc) 30 | self.rho = nn.Sequential( 31 | *[nn.Linear(size[1], size[2]), nn.ReLU(), nn.Dropout(dropout)] 32 | ) 33 | 34 | self.classifier = nn.Linear(size[2], n_classes) 35 | 36 | def relocate(self): 37 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 38 | if torch.cuda.device_count() >= 1: 39 | device_ids = list(range(torch.cuda.device_count())) 40 | self.attention_net = nn.DataParallel( 41 | self.attention_net, device_ids=device_ids 42 | ).to("cuda:0") 43 | 44 | self.rho = self.rho.to(device) 45 | self.classifier = self.classifier.to(device) 46 | 47 | def forward(self, bag): 48 | 49 | A, h_path = self.attention_net(bag) 50 | A = torch.transpose(A, 1, 0) 51 | A_raw = A 52 | A = F.softmax(A, dim=1) 53 | h_path = torch.mm(A, h_path) 54 | h_path = self.rho(h_path).squeeze() 55 | h = h_path # [256] vector 56 | 57 | logits = self.classifier(h).unsqueeze(0) # logits needs to be a [1 x 4] vector 58 | Y_hat = torch.topk(logits, 1, dim=1)[1] 59 | hazards = torch.sigmoid(logits) 60 | S = torch.cumprod(1 - hazards, dim=1) 61 | 62 | return hazards, S, Y_hat, None, None 63 | 64 | 65 | class Attn_Net_Gated(nn.Module): 66 | def __init__(self, L=1024, D=256, dropout=False, n_classes=1): 67 | r""" 68 | Attention Network with Sigmoid Gating (3 fc layers) 69 | args: 70 | L (int): input feature dimension 71 | D (int): hidden layer dimension 72 | dropout (bool): whether to apply dropout (p = 0.25) 73 | n_classes (int): number of classes 74 | """ 75 | super(Attn_Net_Gated, self).__init__() 76 | self.attention_a = [nn.Linear(L, D), nn.Tanh()] 77 | 78 | self.attention_b = [nn.Linear(L, D), nn.Sigmoid()] 79 | if dropout: 80 | self.attention_a.append(nn.Dropout(0.25)) 81 | self.attention_b.append(nn.Dropout(0.25)) 82 | 83 | self.attention_a = nn.Sequential(*self.attention_a) 84 | self.attention_b = nn.Sequential(*self.attention_b) 85 | self.attention_c = nn.Linear(D, n_classes) 86 | 87 | def forward(self, x): 88 | a = self.attention_a(x) 89 | b = self.attention_b(x) 90 | A = a.mul(b) 91 | A = self.attention_c(A) # N x n_classes 92 | return A, x 93 | -------------------------------------------------------------------------------- /var_pool/mhist/tcga_agg_slides_to_patient_level.py: -------------------------------------------------------------------------------- 1 | from var_pool.mhist.tcga_clinical_data import get_participant_from_tcga_barcode 2 | from var_pool.file_utils import get_file_names 3 | 4 | 5 | def tcga_agg_slides_to_patient_level(fpaths, names=None): 6 | """ 7 | Aggregates a list of slides to the patient level. 8 | 9 | Parameters 10 | ---------- 11 | fpaths: list of str 12 | A list of slide file paths. 13 | 14 | names: None, list of str 15 | A list of patient names to subset to. 16 | 17 | Output 18 | ------ 19 | patient2fpaths: dict of lists 20 | The file paths for the slides of each patient. 21 | """ 22 | 23 | names = set(names) if names is not None else None 24 | 25 | patient2fpaths = {} 26 | for slide_fpath in fpaths: 27 | 28 | # get patient name from slide file path 29 | slide_fname = get_file_names(slide_fpath) 30 | patient_id = get_participant_from_tcga_barcode(slide_fname) 31 | 32 | # maybe skip this patient. 
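        # `names`, when provided, is matched against the TCGA participant
        # barcode (e.g. 'TCGA-3C-AALI'), not the full slide barcode.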
33 | if names is not None and patient_id not in names: 34 | continue 35 | 36 | # add this slide file path to this patient 37 | if patient_id in patient2fpaths.keys(): 38 | patient2fpaths[patient_id].append(slide_fpath) 39 | else: 40 | patient2fpaths[patient_id] = [slide_fpath] 41 | 42 | return patient2fpaths 43 | -------------------------------------------------------------------------------- /var_pool/mhist/tcga_clinical_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def get_participant_from_tcga_barcode(barcode): 6 | """ 7 | Takes a TCGA barcode and returns the participant part 8 | 9 | TCGA-3C-AALI-01Z-00-DX1.F6E9A5DF-D8FB-45CF-B4BD-C6B76294C291 -> TCGA-3C-AALI 10 | 11 | https://docs.gdc.cancer.gov/Encyclopedia/pages/TCGA_Barcode/ 12 | 13 | Parameters 14 | ---------- 15 | barcode: str 16 | The barcode. 17 | 18 | Output 19 | ------ 20 | participant: str 21 | The first three -s from the barcode. 22 | """ 23 | return '-'.join(barcode.split('-')[0:3]) 24 | 25 | 26 | def load_patient_data_from_cell_df(fpath, subtype=None, verbose=True): 27 | """ 28 | Download clinical data file from (Liu et al, 2018) from 29 | https://www.cell.com/cms/10.1016/j.cell.2018.02.052/attachment/bbf46a06-1fb0-417a-a259-fd47591180e4/mmc1 30 | 31 | it should be named mmc1.xlsx 32 | 33 | Paper can be found at 34 | https://www.cell.com/cell/fulltext/S0092-8674(18)30229-0 35 | 36 | 37 | Parameters 38 | ---------- 39 | fpath: str 40 | File path to data file from Table S1 (Liu et al, 2018). 41 | 42 | subtype: None, str 43 | (Optional) Subtype to subset to. 44 | 45 | Output 46 | ------ 47 | patient_clinical_df, cols 48 | 49 | patient_clinical_df: pd.DataFrame, (n_patients, n_features) 50 | 51 | cols: dict of lists 52 | Datatypes for each column 53 | 54 | References 55 | ---------- 56 | Liu, J., Lichtenberg, T., Hoadley, K.A., Poisson, L.M., Lazar, A.J., Cherniack, A.D., Kovatich, A.J., Benz, C.C., Levine, D.A., Lee, A.V. and Omberg, L., 2018. An integrated TCGA pan-cancer clinical data resource to drive high-quality survival outcome analytics. Cell, 173(2), pp.400-416. 57 | """ 58 | 59 | cell_df = pd.read_excel(fpath, sheet_name='TCGA-CDR', index_col=0) 60 | 61 | if verbose: 62 | print("{} patients with TCGA clinical data".format(cell_df.shape[0])) 63 | 64 | # set index 65 | # bcr_patient_barcode should be unique for each person 66 | assert cell_df.shape[0] == len(np.unique(cell_df['bcr_patient_barcode'])) 67 | cell_df = cell_df.set_index('bcr_patient_barcode') 68 | 69 | if subtype is not None: 70 | # subset to this subype 71 | cell_df = cell_df.query("type=='{}'".format(subtype.upper())) 72 | 73 | if verbose: 74 | print("{} patients with TCGA clinical data in {}". 
75 | format(cell_df.shape[0], subtype)) 76 | 77 | # subset out columns we want 78 | cols = {} 79 | cols['survival_event'] = ['OS', 'DSS', 'DFI', 'PFI'] 80 | cols['survival_time'] = [c + '.time' for c in cols['survival_event']] 81 | cols['cts_feats'] = ['age_at_initial_pathologic_diagnosis'] 82 | cols['ordinal_feats'] = ['ajcc_pathologic_tumor_stage'] 83 | cols['cat_feats'] = ['gender', 'race', 'histological_type'] 84 | 85 | if subtype is None: 86 | cols['other'] = ['type'] 87 | else: 88 | cols['other'] = [] 89 | 90 | cols_we_want = np.concatenate(list(cols.values())) 91 | patient_clinical_df = cell_df[cols_we_want] 92 | 93 | # Handle NaNs 94 | feat_maps = {'ajcc_pathologic_tumor_stage': {'[Discrepancy]': 'NaN', 95 | '[Not Available]': 'NaN'}, 96 | 97 | 'race': {'[Not Available]': 'NaN', 98 | '[Not Evaluated]': 'NaN', 99 | }, 100 | 101 | 'histological_type': {'[Not Available]': 'NaN'} 102 | } 103 | 104 | patient_clinical_df.replace(to_replace=feat_maps, inplace=True) 105 | 106 | return patient_clinical_df, cols 107 | 108 | 109 | def restrct_patients_to_avail_slides(patient_df, avail_slides, verbose=True): 110 | """ 111 | Restricts a patient df to those patients with available slides. Also prints out summary. 112 | 113 | Parameters 114 | ---------- 115 | patient_df: pd.DataFrame, (n_patients, n_features) 116 | The patient level data frame indexed by 'bcr_patient_barcode' 117 | 118 | avail_slides: list of str 119 | The available slides. 120 | 121 | Output 122 | ------ 123 | patient_df, slides_with_patients, match_info 124 | 125 | patient_df: pd.DataFrame, (n_patients_avail, n_features) 126 | The patient df with available patients. 127 | 128 | slides_with_patients: list of str 129 | 130 | match_info: dict 131 | """ 132 | # find patietns we have both dataset for 133 | patient_ids = [] 134 | patient_ids_missing_slides = [] 135 | 136 | slide_patients = [] 137 | patient_slides = {} 138 | for slide in avail_slides: 139 | patient = get_participant_from_tcga_barcode(slide) 140 | slide_patients.append(patient) 141 | 142 | # add this slide to patient 143 | if patient in patient_slides: 144 | patient_slides[patient].append(slide) 145 | else: 146 | patient_slides[patient] = [slide] 147 | 148 | slide_patients_set = set(slide_patients) 149 | 150 | # got through patient clincial data 151 | slides_with_patients = [] 152 | for patient_id in patient_df.index: 153 | 154 | if patient_id in slide_patients_set: 155 | # this patient has slides 156 | patient_ids.append(patient_id) 157 | slides_with_patients.extend(patient_slides[patient_id]) 158 | else: 159 | # this patient has not slides 160 | patient_ids_missing_slides.append(patient_id) 161 | 162 | patient_ids_slides_no_clincal = slide_patients_set.\ 163 | difference(patient_df.index.values) 164 | 165 | if verbose: 166 | print("{} patients have both cinical data and slides". 167 | format(len(patient_ids))) 168 | print("{} total participants have slides". 169 | format(len(slide_patients_set))) 170 | print("{} patients have cinical data, but no slides". 171 | format(len(patient_ids_missing_slides))) 172 | print("{} patients have slides, but no clincal data". 173 | format(len(patient_ids_slides_no_clincal))) 174 | print("{} slides have patient data". 
175 | format(len(slides_with_patients))) 176 | 177 | match_info = {'slides_and_clinical': patient_ids, 178 | 'all_slide_patients': slide_patients, 179 | 'clincal_no_slide': patient_ids_missing_slides, 180 | 'slide_no_clinical': patient_ids_slides_no_clincal 181 | } 182 | 183 | return patient_df.loc[patient_ids, :], slides_with_patients, match_info 184 | 185 | 186 | def broadcast_patient_to_slide(slide_names, patient_df): 187 | """ 188 | Broadcasts patient level information to slide level information. 189 | 190 | Parameters 191 | ---------- 192 | slide_names: list of str 193 | The names of the slides. 194 | 195 | patient_df: pd.DataFrame, (n_total_patients, n_features) 196 | The patient data. 197 | 198 | Output 199 | ------ 200 | slide_df 201 | 202 | slide_df: pd.DataFrame, (n_slides_with_data, n_features) 203 | """ 204 | 205 | # add patient id as a colum 206 | assert len(patient_df.index.name) > 0 207 | patient_df = patient_df.copy() 208 | patient_df[patient_df.index.name] = patient_df.index.values 209 | 210 | # broadcast patient level data to slide level data 211 | slide_df = {} 212 | for slide_name in slide_names: 213 | # get participant id from slid ename 214 | participant_id = get_participant_from_tcga_barcode(slide_name) 215 | 216 | # get the patient info for this participant 217 | if participant_id in patient_df.index: 218 | slide_df[slide_name] = patient_df.loc[participant_id] 219 | 220 | slide_df = pd.DataFrame(slide_df).T 221 | 222 | return slide_df 223 | -------------------------------------------------------------------------------- /var_pool/nn/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/.DS_Store -------------------------------------------------------------------------------- /var_pool/nn/ComparablePairSampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler 2 | from itertools import combinations 3 | import numpy as np 4 | 5 | # TODO: add shuffle 6 | class ComparablePairSampler(Sampler): 7 | """ 8 | Iterates over comparable pairs. 9 | 10 | Parameters 11 | ---------- 12 | times: array-like, (n_samples, ) 13 | The survival times. 14 | 15 | censor: array-like, (n_samples, ) 16 | The censor indicators. 17 | 18 | """ 19 | def __init__(self, times, censor): 20 | events = ~censor.astype(bool) 21 | self.pairs = get_comparable_pairs(times=times, events=events) 22 | 23 | def __iter__(self): 24 | return iter(self.pairs) 25 | 26 | def __len__(self): 27 | return len(self.pairs) 28 | 29 | 30 | def get_comparable_pairs(times, events): 31 | """ 32 | Gets the comparable pairs. 33 | 34 | Parmeters 35 | --------- 36 | times: array-like, (n_samples, ) 37 | The survival times. 38 | 39 | events: array-like, (n_samples, ) 40 | The event indicators. 41 | 42 | Output 43 | ------ 44 | pairs: list, (n_comparable, 2) 45 | The indices of the comparable pairs where the first entry is the more risky sample (smaller survival). 
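    Example (illustrative)
    ----------------------
    With times = [2, 5, 3] and events = [True, True, False] the comparable
    pairs are [[0, 1], [0, 2]]; in each pair the first index is the riskier
    sample (observed event at an earlier time). The censored sample 2 is
    comparable to sample 0, whose event precedes its censoring time, but not
    to sample 1.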
46 | """ 47 | times = np.array(times).reshape(-1) 48 | events = np.array(events).reshape(-1) 49 | 50 | events = events.astype(bool) 51 | n_samples = times.shape[0] 52 | 53 | pairs = [] 54 | for (idx_a, idx_b) in combinations(range(n_samples), 2): 55 | time_a, event_a = times[idx_a], events[idx_a] 56 | time_b, event_b = times[idx_b], events[idx_b] 57 | 58 | if time_a < time_b and event_a: 59 | # a and b are comparable, a is more risky 60 | pairs.append([idx_a, idx_b]) 61 | 62 | elif time_b < time_a and event_b: 63 | # a and b are comparable, b is more risky 64 | pairs.append([idx_b, idx_a]) 65 | 66 | return pairs 67 | -------------------------------------------------------------------------------- /var_pool/nn/CoxLoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | from itertools import combinations 5 | 6 | 7 | class CoxLoss_alternative(nn.Module): 8 | """ 9 | Implements the Cox proportional hazards loss for deep learning models. See Equation (4) of (Katzman et al, 2018) without the L2 term. 10 | 11 | Parameters 12 | ---------- 13 | reduction: str 14 | Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum']. 15 | 16 | References 17 | ---------- 18 | Katzman, J.L., Shaham, U., Cloninger, A., Bates, J., Jiang, T. and Kluger, Y., 2018. DeepSurv: personalized treatment recommender system using a Cox proportional hazards deep neural network. BMC medical research methodology, 18(1), pp.1-12. 19 | """ 20 | def __init__(self, reduction='mean'): 21 | super().__init__() 22 | assert reduction in ['sum', 'mean'] 23 | self.reduction = reduction 24 | 25 | def forward(self, z, c_t): 26 | """ 27 | Parameters 28 | ---------- 29 | z: (batch_size, 1) 30 | The predicted log risk scores i.e. h(x) in (Katzman et al, 2018). 31 | 32 | c_t: (batch_size, 2) 33 | first element: censorship 34 | second element: survival time 35 | """ 36 | 37 | censor = c_t[:, 0].bool() 38 | events = ~censor 39 | times = c_t[:, 1] 40 | 41 | z = z.reshape(-1) 42 | exp_z = torch.exp(z) 43 | 44 | batch_size = z.shape[0] 45 | 46 | ############################################################### 47 | # determine risk set for each observation with observed event # 48 | ############################################################### 49 | 50 | event_risk_sets = {} # risk set for everyone with observed event 51 | for (idx_a, idx_b) in combinations(range(batch_size), 2): 52 | time_a, event_a = times[idx_a], events[idx_a] 53 | time_b, event_b = times[idx_b], events[idx_b] 54 | 55 | # event_idx = experienced event and definitely died 56 | # before the still_alive_idx 57 | event_idx = None 58 | still_alive_idx = None 59 | if time_a <= time_b and event_a: 60 | event_idx = idx_a 61 | still_alive_idx = idx_b 62 | 63 | elif time_b <= time_a and event_b: 64 | event_idx = idx_b 65 | still_alive_idx = idx_a 66 | 67 | # risk_sets[event_idx] = list of idxs in risk set for event_idx 68 | if event_idx is not None: 69 | if event_idx not in event_risk_sets.keys(): 70 | event_risk_sets[event_idx] = [still_alive_idx] 71 | else: 72 | event_risk_sets[event_idx].append(still_alive_idx) 73 | 74 | ############################################################ 75 | # compute loss terms for observations with observed events # 76 | ############################################################ 77 | 78 | # if there are no comparable pairs then just return zero 79 | if len(event_risk_sets) == 0: 80 | # TODO: perhaps return None?
81 | return torch.zeros(1, requires_grad=True) 82 | 83 | # Compute each term in the sum in Equation (4) of (Katzman et al, 2018) 84 | summands = [] 85 | for event_idx, risk_set in event_risk_sets.items(): 86 | 87 | sum_exp_risk_set = exp_z[event_idx] + torch.sum(torch.stack([exp_z[r_idx] for r_idx in risk_set]))  # the risk set includes the event sample itself 88 | summand = z[event_idx] - torch.log(sum_exp_risk_set) 89 | 90 | summands.append(summand) 91 | 92 | ########## 93 | # Output # 94 | ########## 95 | 96 | if self.reduction == 'mean': 97 | return -torch.mean(torch.stack(summands)) 98 | if self.reduction == 'sum': 99 | return -torch.sum(torch.stack(summands)) 100 | 101 | 102 | # This may have some issues 103 | # TODO: remove 104 | class CoxLoss(nn.Module): 105 | """ 106 | Implements the Cox PH loss. Borrowed from Richard's Pathomic Fusion code 107 | 108 | Parameters 109 | ---------- 110 | 111 | reduction: str 112 | Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum']. 113 | """ 114 | def __init__(self, reduction='mean'): 115 | super().__init__() 116 | assert reduction in ['sum', 'mean'] 117 | self.reduction = reduction 118 | 119 | def forward(self, z, c_t): 120 | """ 121 | Parameters 122 | ---------- 123 | z: (batch_size, 1) 124 | The predicted risk scores. 125 | 126 | c_t: (batch_size, 2) 127 | first element: censorship 128 | second element: survival time 129 | """ 130 | # assert z.shape[1] == 1, "The network output doesn't fit cox model" 131 | 132 | hazards = z 133 | censor = c_t[:, 0].bool() 134 | events = ~censor 135 | survtime = c_t[:, 1] 136 | 137 | batch_size = hazards.shape[0] 138 | R_mat = np.zeros([batch_size, batch_size], dtype=int) 139 | for i in range(batch_size): 140 | for j in range(batch_size): 141 | R_mat[i, j] = survtime[j] >= survtime[i] 142 | 143 | # convert to torch and put on same device as hazards 144 | R_mat = hazards.new(R_mat) 145 | 146 | # R_mat = torch.FloatTensor(R_mat) 147 | R_mat = R_mat.float() 148 | theta = hazards.reshape(-1) 149 | exp_theta = torch.exp(theta) 150 | 151 | summands = theta - torch.log(torch.sum(exp_theta*R_mat, dim=1)) 152 | summands = summands * events 153 | 154 | if self.reduction == 'mean': 155 | return -torch.mean(summands) 156 | if self.reduction == 'sum': 157 | return -torch.sum(summands) 158 | -------------------------------------------------------------------------------- /var_pool/nn/NLLSurvLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class NLLSurvLoss(nn.Module): 6 | """ 7 | The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020). 8 | 9 | Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py 10 | 11 | Parameters 12 | ---------- 13 | alpha: float 14 | Down-weights the censored term: the loss is (1 - alpha) * (censored + uncensored) + alpha * uncensored, i.e. the censored term is scaled by (1 - alpha). 15 | 16 | eps: float 17 | Numerical constant; lower bound to avoid taking logs of tiny numbers. 18 | 19 | reduction: str 20 | Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum'] 21 | """ 22 | def __init__(self, alpha=0.0, eps=1e-7, reduction='mean'): 23 | super().__init__() 24 | self.alpha = alpha 25 | self.eps = eps 26 | self.reduction = reduction 27 | 28 | def __call__(self, h, y_c): 29 | """ 30 | Parameters 31 | ---------- 32 | h: (n_batches, n_classes) 33 | The neural network output discrete survival predictions such that hazards = sigmoid(h). 34 | 35 | y_c: (n_batches, 2) or (n_batches, 3) 36 | The true time bin label (first column) and censorship indicator (second column).
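            Here the first column is the integer time-bin index and the second
            is the censorship indicator (1 = censored, 0 = event observed); any
            additional columns are ignored by this loss.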
37 | """ 38 | y_true = y_c[:, 0].unsqueeze(1) 39 | c = y_c[:, 1].unsqueeze(1) 40 | 41 | return nll_loss(h=h, y_true=y_true, c=c, 42 | alpha=self.alpha, eps=self.eps, 43 | reduction=self.reduction) 44 | 45 | 46 | # TODO: document better and clean up 47 | def nll_loss(h, y_true, c, alpha=0.0, eps=1e-7, reduction='mean'): 48 | """ 49 | The negative log-likelihood loss function for the discrete time to event model (Zadeh and Schmid, 2020). 50 | 51 | Code borrowed from https://github.com/mahmoodlab/Patch-GCN/blob/master/utils/utils.py 52 | 53 | Parameters 54 | ---------- 55 | h: (n_batches, n_classes) 56 | The neural network output discrete survival predictions such that hazards = sigmoid(h). 57 | 58 | y_true: (n_batches, 1) 59 | The true time bin index label. 60 | 61 | c: (n_batches, 1) 62 | The censoring status indicator. 63 | 64 | alpha: float 65 | TODO: document 66 | 67 | eps: float 68 | Numerical constant; lower bound to avoid taking logs of tiny numbers. 69 | 70 | reduction: str 71 | Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum'] 72 | 73 | References 74 | ---------- 75 | Zadeh, S.G. and Schmid, M., 2020. Bias in cross-entropy-based training of deep survival networks. IEEE transactions on pattern analysis and machine intelligence. 76 | """ 77 | 78 | # print("h shape", h.shape) 79 | 80 | # make sure these are ints 81 | y_true = y_true.type(torch.int64) 82 | c = c.type(torch.int64) 83 | 84 | hazards = torch.sigmoid(h) 85 | # print("hazards shape", hazards.shape) 86 | 87 | S = torch.cumprod(1 - hazards, dim=1) 88 | # print("S.shape", S.shape, S) 89 | 90 | S_padded = torch.cat([torch.ones_like(c), S], 1) 91 | # S(-1) = 0, all patients are alive from (-inf, 0) by definition 92 | # after padding, S(0) = S[1], S(1) = S[2], etc, h(0) = h[0] 93 | # hazards[y] = hazards(1) 94 | # S[1] = S(1) 95 | # TODO: document and check 96 | 97 | # print("S_padded.shape", S_padded.shape, S_padded) 98 | 99 | 100 | # TODO: document/better naming 101 | s_prev = torch.gather(S_padded, dim=1, index=y_true).clamp(min=eps) 102 | h_this = torch.gather(hazards, dim=1, index=y_true).clamp(min=eps) 103 | s_this = torch.gather(S_padded, dim=1, index=y_true+1).clamp(min=eps) 104 | # print('s_prev.s_prev', s_prev.shape, s_prev) 105 | # print('h_this.shape', h_this.shape, h_this) 106 | # print('s_this.shape', s_this.shape, s_this) 107 | 108 | uncensored_loss = -(1 - c) * (torch.log(s_prev) + torch.log(h_this)) 109 | censored_loss = - c * torch.log(s_this) 110 | # print('uncensored_loss.shape', uncensored_loss.shape) 111 | # print('censored_loss.shape', censored_loss.shape) 112 | 113 | neg_l = censored_loss + uncensored_loss 114 | if alpha is not None: 115 | loss = (1 - alpha) * neg_l + alpha * uncensored_loss 116 | 117 | if reduction == 'mean': 118 | loss = loss.mean() 119 | elif reduction == 'sum': 120 | loss = loss.sum() 121 | else: 122 | raise ValueError("Bad input for reduction: {}".format(reduction)) 123 | 124 | return loss 125 | -------------------------------------------------------------------------------- /var_pool/nn/SurvRankingLoss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | from itertools import combinations 5 | 6 | 7 | class SurvRankingLoss(nn.Module): 8 | """ 9 | Implements the surivival ranking loss which approximates the negaive c-index; see Section 3.2 of (Luck et al, 2018) -- but be careful of the typo in their c-index formula. 
10 | 11 | The c-index for risk scores z_1, ..., z_n is given by 12 | 13 | c_index = sum_{(a, b) are comparable} 1(z_a > z_b) 14 | 15 | where (a, b) are comparable if and only if a's event is observed and a has a strictly lower survival time than b. This ignores ties. 16 | 17 | We replace the indicator with a continous approximation 18 | 19 | 1(z_a - z_b > 0 ) ~= phi(z_a - z_b) 20 | 21 | e.g. where phi(r) is a Relu or sigmoid function. 22 | 23 | The loss function we want to minimize is then 24 | 25 | - sum_{(a, b) are comparable} phi(z_a - z_b) 26 | 27 | where z_a, z_b are the risk scores output by the network. 28 | 29 | Parameters 30 | ---------- 31 | phi: str 32 | Which indicator approximation to use. Must be one of ['relu', 'sigmoid']. 33 | 34 | reduction: str 35 | Do we sum or average the loss function over the batches. Must be one of ['mean', 'sum'] 36 | 37 | References 38 | ---------- 39 | Luck, M., Sylvain, T., Cohen, J.P., Cardinal, H., Lodi, A. and Bengio, Y., 2018. Learning to rank for censored survival data. arXiv preprint arXiv:1806.01984. 40 | """ 41 | 42 | def __init__(self, phi='sigmoid', reduction='mean'): 43 | super().__init__() 44 | 45 | assert phi in ['sigmoid', 'relu'] 46 | assert reduction in ['mean', 'sum'] 47 | self.phi = phi 48 | self.reduction = reduction 49 | 50 | def forward(self, z, c_t): 51 | """ 52 | Parameters 53 | ---------- 54 | z: (batch_size, 1) 55 | The predicted risk scores. 56 | 57 | c_t: (batch_size, 2) 58 | first element: censorship 59 | second element: survival time 60 | """ 61 | batch_size = z.shape[0] 62 | 63 | if batch_size == 1: 64 | raise NotImplementedError("Batch size must be at least 2") 65 | 66 | censorship, times = c_t[:, 0], c_t[:, 1] 67 | events = 1 - censorship 68 | 69 | ############################## 70 | # determine comparable pairs # 71 | ############################## 72 | Z_more_risky = [] 73 | Z_less_risky = [] 74 | for (idx_a, idx_b) in combinations(range(batch_size), 2): 75 | time_a, event_a = times[idx_a], events[idx_a] 76 | time_b, event_b = times[idx_b], events[idx_b] 77 | 78 | if time_a < time_b and event_a: 79 | # a and b are comparable, a is more risky 80 | Z_more_risky.append(z[idx_a]) 81 | Z_less_risky.append(z[idx_b]) 82 | 83 | elif time_b < time_a and event_b: 84 | # a and b are comparable, b is more risky 85 | Z_more_risky.append(z[idx_b]) 86 | Z_less_risky.append(z[idx_a]) 87 | 88 | # if there are no comparable pairs then just return zero 89 | if len(Z_less_risky) == 0: 90 | # TODO: perhaps return None? 
91 | return torch.zeros(1, requires_grad=True) 92 | 93 | Z_more_risky = torch.stack(Z_more_risky) 94 | Z_less_risky = torch.stack(Z_less_risky) 95 | 96 | # compute approximate c indices 97 | r = Z_more_risky - Z_less_risky 98 | if self.phi == 'sigmoid': 99 | approx_c_indices = torch.sigmoid(r) 100 | 101 | elif self.phi == 'relu': 102 | approx_c_indices = torch.relu(r) 103 | 104 | # negative mean/sum of c-indices 105 | if self.reduction == 'mean': 106 | return - approx_c_indices.mean() 107 | if self.reduction == 'sum': 108 | return - approx_c_indices.sum() 109 | -------------------------------------------------------------------------------- /var_pool/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/__init__.py -------------------------------------------------------------------------------- /var_pool/nn/arch/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/arch/.DS_Store -------------------------------------------------------------------------------- /var_pool/nn/arch/AttnMIL_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from var_pool.nn.arch.AttnMIL import Attn, GatendAttn 6 | 7 | 8 | class AttnMILMixin: 9 | """ 10 | Mixin for attention MIL modules 11 | 12 | Parameters 13 | ---------- 14 | enc_and_attend: nn.Module -> attn_scores, 15 | 16 | """ 17 | 18 | def compute_bag_feats_and_attn_scores(self, bag): 19 | """ 20 | Computes the instance encodings and attention scores 21 | 22 | Parameters 23 | ---------- 24 | if bag is a Tensor 25 | bag: shape (n_batches, n_instances, *instance_dims) 26 | The instance bag features. 27 | 28 | if bag is a tuple/list 29 | bag: tuple, shape (2, ) 30 | If a fixed bag size is provided then we input a tuple where the first element is the bag (as above) and the second element is the list of non_pad_size for each bag indicating how many instances in each bag are real instances and not padding. 31 | 32 | Ouput 33 | ----- 34 | attn_scores, bag_feats 35 | 36 | attn_scores: list or torch.Tensor, (batch_size, n_instances, 1) 37 | The attention scores; a list if there are separate attention brancehs. 38 | 39 | bag_feats: (batch_size, n_insatnaces, n_bag_feats) 40 | """ 41 | ################ 42 | # Format input # 43 | ################ 44 | if isinstance(bag, (list, tuple)): 45 | # split input = bag, non_pad_size 46 | bag, non_pad_size = bag 47 | n_batches, n_instances = bag.shape[0:2] 48 | 49 | else: 50 | n_batches, n_instances = bag.shape[0:2] 51 | 52 | assert bag.ndim >= 3, "Make sure to include first batch dimension" 53 | 54 | # flatten batches x instances so the attention scores 55 | # can be easily computed in parallel. 
This allows us to 56 | # wrap enc_and_attend a nn.DataParallel() to compute all instances 57 | # in parallel 58 | # (n_batches, n_instances, *instance_dims) -> 59 | # (n_batches * n_instances, *instance_dims) 60 | x = torch.flatten(bag, start_dim=0, end_dim=1) 61 | 62 | ##################### 63 | # Encode and attend # 64 | ##################### 65 | 66 | attn_scores, bag_feats = self.enc_and_attend(x) 67 | # (n_batches * n_instances, 1), (n_batches * n_instances, encode_dim) 68 | n_bag_feats = bag_feats.shape[1] 69 | 70 | ################# 71 | # Format output # 72 | ################# 73 | 74 | # unflatten 75 | bag_feats = bag_feats.view(n_batches, n_instances, n_bag_feats) 76 | 77 | if isinstance(attn_scores, list): 78 | # if there are multiple attention scores 79 | attn_scores = [a_s.view(n_batches, n_instances, 1) 80 | for a_s in attn_scores] 81 | else: 82 | attn_scores = attn_scores.view(n_batches, n_instances, 1) 83 | 84 | return attn_scores, bag_feats 85 | 86 | def computed_norm_attn(self, attn_scores, 87 | is_padded_bag, non_pad_size): 88 | """ 89 | Parameters 90 | ---------- 91 | attn_scores: torch.Tensor (batch_size, n_instances, 1) 92 | The attention scores tensor. 93 | 94 | is_padded_bag: bool 95 | Whether or not the bag was padded. 96 | 97 | non_pad_size: array-like, shape (batch_size, ) or None 98 | 99 | Output 100 | ------ 101 | norm_attn: (n_batches, n_instances, 1) 102 | """ 103 | 104 | batch_size, n_instances = attn_scores.shape[0], attn_scores.shape[1] 105 | 106 | if not is_padded_bag: 107 | # attn = softmax over instances 108 | norm_attn = F.softmax(attn_scores, dim=1) 109 | # (n_batches, n_instances, 1) 110 | 111 | else: 112 | # make the attention scores for padding instances 113 | # very big negative numbers so they effectively get ignored 114 | # this code is borrowed from 115 | # https://github.com/KatherLab/HIA/blob/main/models/model_Attmil.py 116 | 117 | # compute mask of true instances i.e. the non-pad elements 118 | instance_idxs = torch.arange(n_instances).repeat(batch_size, 1) 119 | instance_idxs = instance_idxs.to(attn_scores.device) 120 | 121 | true_instance_mask = instance_idxs < non_pad_size.unsqueeze(-1) 122 | true_instance_mask = true_instance_mask.unsqueeze(-1) 123 | 124 | # array with very large negative number = tiny attention 125 | big_neg = torch.full_like(attn_scores, -1e10) 126 | attn_scores_ignore_pad = torch.where(true_instance_mask, 127 | attn_scores, 128 | big_neg) 129 | 130 | # attn = softmax over instances 131 | norm_attn = F.softmax(attn_scores_ignore_pad, dim=1) 132 | 133 | return norm_attn 134 | 135 | def get_pad_info(self, bag): 136 | """ 137 | Parameters 138 | ---------- 139 | bag: torch.Tensor or list/tuple 140 | 141 | Output 142 | ------ 143 | is_padded_bag, non_pad_size 144 | 145 | is_padded_bag: bool 146 | Whether or not this is a padded bag. 147 | 148 | non_pad_size: None, torch.Tensor (batch_size) 149 | """ 150 | if isinstance(bag, (list, tuple)): 151 | # split input = bag, non_pad_size 152 | bag, non_pad_size = bag 153 | 154 | n_instances = bag.shape[1] 155 | 156 | if all(non_pad_size == n_instances): 157 | # if the non-pad size of all the bags is equal to the number of 158 | # instances in each bag they we did not add any padding 159 | # e.g. 
each bag was subset 160 | is_padded_bag = False 161 | # this allows us to skip the uncessary masking code below 162 | else: 163 | is_padded_bag = True 164 | 165 | else: 166 | is_padded_bag = False 167 | non_pad_size = None 168 | 169 | return is_padded_bag, non_pad_size 170 | 171 | 172 | def get_attn_module(encoder_dim, n_attn_latent, dropout, gated): 173 | """ 174 | Gets the attention module 175 | """ 176 | if gated: 177 | return Attn(n_in=encoder_dim, 178 | n_latent=n_attn_latent, 179 | dropout=dropout) 180 | else: 181 | return GatendAttn(n_in=encoder_dim, 182 | n_latent=n_attn_latent, 183 | dropout=dropout) 184 | 185 | 186 | class EncodeAndMultipleAttend(nn.Module): 187 | """ 188 | An encoder that feeds into multiple parallel attention branches. 189 | 190 | Parameters 191 | ---------- 192 | encoder: nn.Module, None 193 | The encoder that each samples is passed into. 194 | 195 | attns: list of nn.Module 196 | The attention branches. 197 | """ 198 | 199 | def __init__(self, encoder, attns): 200 | super().__init__() 201 | 202 | self.encoder = encoder 203 | self.attns = nn.ModuleList(attns) 204 | 205 | def forward(self, x): 206 | """ 207 | Parameters 208 | ---------- 209 | x: torch.Tensor, (n_batches, n_featues) 210 | 211 | Output 212 | ------ 213 | attn_scores, instance_encodings 214 | 215 | attn_scores: list len(attns) 216 | The attention scores applied to each encoder. 217 | """ 218 | if self.encoder is not None: 219 | x = self.encoder(x) 220 | 221 | # attention modules output (attn_scores, x) 222 | attn_scores = [attn(x)[0] for attn in self.attns] 223 | 224 | return attn_scores, x 225 | -------------------------------------------------------------------------------- /var_pool/nn/arch/GlobalPoolMIL.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class GlobalPoolMil(nn.Module): 6 | """ 7 | Global pool the bags (e.g. mean pool) the on linear layer. 8 | 9 | Parameters 10 | ---------- 11 | in_feats: int 12 | Feature input dimension. 13 | 14 | out_dim: int 15 | Output dimension. 16 | 17 | pool: str 18 | Which pooling operation to apply to each bag feature. 19 | """ 20 | def __init__(self, in_feats, out_dim, pool='mean'): 21 | super().__init__() 22 | 23 | self.head = nn.Linear(in_feats, out_dim) 24 | 25 | assert pool in ['mean'] 26 | self.pool = pool 27 | 28 | def forward(self, bags): 29 | if self.pool == 'mean': 30 | feats = torch.mean(bags, axis=1) 31 | return self.head(feats) 32 | 33 | -------------------------------------------------------------------------------- /var_pool/nn/arch/VarPool_switch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Variance Pooling + Attention based multiple instance learning architecture. 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | from var_pool.nn.arch.AttnMIL_utils import AttnMILMixin, get_attn_module,\ 8 | EncodeAndMultipleAttend 9 | from var_pool.nn.arch.VarPool import VarPool 10 | 11 | 12 | class AttnMeanAndVarPoolMIL_with_switch(AttnMILMixin, nn.Module): 13 | """ 14 | Attention mean and variance pooling architecture. 15 | 16 | Parameters 17 | ---------- 18 | 19 | encoder_dim: int 20 | Dimension of the encoder features. This is either the dimension output by the instance encoder (if there is one) or it is the dimension of the input feature (if there is no encoder). 21 | 22 | encoder: None, nn.Module 23 | (Optional) The bag instance encoding network. 24 | 25 | mean_neck: None, nn.Module. 
26 | (Optional) Function to apply to the mean pool embedding. Should output the same shape as var_neck which is then input to head. 27 | 28 | var_neck: None, nn.Module. 29 | (Optional) Function to apply to the var pool embedding. Should output the same shape as mean_neck which is then input to head. 30 | 31 | head: nn.Module,None 32 | (Optional) The output of mean_neck and var_neck are added together then send to the head network. 33 | 34 | n_attn_latent: int, None 35 | Number of latent dimension for the attention layer. If None, will default to (n_in + 1) // 2. 36 | 37 | gated: bool 38 | Use the gated attention mechanism. 39 | 40 | separate_attn: bool 41 | WHether or not we want to use separate attention branches for the mean and variance pooling. 42 | 43 | n_var_pools: int 44 | Number of variance pooling projections. 45 | 46 | act_func: str 47 | The activation function to apply to variance pooling. Must be one of ['sqrt', 'log', 'sigmoid']. 48 | 49 | log_eps: float 50 | Epsilon value for log(epsilon + ) var pool activation function. 51 | 52 | dropout: bool, float 53 | Whether or not to use dropout in the attention mechanism. If True, will default to p=0.25. 54 | 55 | References 56 | ---------- 57 | Ilse, M., Tomczak, J. and Welling, M., 2018, July. Attention-based deep multiple instance learning. In International conference on machine learning (pp. 2127-2136). PMLR. 58 | """ 59 | def __init__(self, encoder_dim, encoder=None, 60 | mean_neck=None, var_neck=None, head=None, 61 | n_attn_latent=None, gated=True, 62 | separate_attn=False, n_var_pools=100, 63 | act_func='sqrt', log_eps=0.01, 64 | dropout=False): 65 | super().__init__() 66 | 67 | ########################### 68 | # Setup encode and attend # 69 | ########################### 70 | self.separate_attn = bool(separate_attn) 71 | self.encoder = encoder 72 | if self.separate_attn: 73 | attns = [get_attn_module(encoder_dim=encoder_dim, 74 | n_attn_latent=n_attn_latent, 75 | dropout=dropout, 76 | gated=gated) 77 | for _ in range(2)] 78 | 79 | self.enc_and_attend = EncodeAndMultipleAttend(encoder=self.encoder, 80 | attns=attns) 81 | 82 | else: 83 | attention = get_attn_module(encoder_dim=encoder_dim, 84 | n_attn_latent=n_attn_latent, 85 | dropout=dropout, 86 | gated=gated) 87 | 88 | if encoder is not None: 89 | self.enc_and_attend = nn.Sequential(self.encoder, attention) 90 | else: 91 | self.enc_and_attend = attention 92 | 93 | #################### 94 | # Variance pooling # 95 | #################### 96 | self.var_pool = VarPool(encoder_dim=encoder_dim, 97 | n_var_pools=n_var_pools, 98 | log_eps=log_eps, 99 | act_func=act_func) 100 | 101 | self._apply_var_pool = True 102 | 103 | ################ 104 | # Head network # 105 | ################ 106 | 107 | self.mean_neck = mean_neck 108 | self.var_neck = var_neck 109 | self.head = head 110 | 111 | def var_pool_off(self): 112 | """ 113 | Turns variance pooling off. 114 | """ 115 | self._apply_var_pool = False 116 | return self 117 | 118 | def var_pool_on(self): 119 | """ 120 | Turns variance pooling on. 121 | """ 122 | self._apply_var_pool = True 123 | return self 124 | 125 | def get_variance(self, bag, normalize=True): 126 | """ 127 | Get variance contribution. 
Essentially truncated forward pass to obtain var_pooled_bag_feats 128 | 129 | Parameters 130 | ---------- 131 | normalize: Bool 132 | Normalize each variance by norm of the projection vector 133 | 134 | Output 135 | ------ 136 | var_pooled_bag_feats (n_batches, var_pool) 137 | """ 138 | 139 | # instance encodings and attention scores 140 | attn_scores, bag_feats = self.compute_bag_feats_and_attn_scores(bag) 141 | 142 | is_padded_bag, non_pad_size = self.get_pad_info(bag) 143 | 144 | # normalize attetion 145 | if self.separate_attn: 146 | var_attn_scores = attn_scores[1] 147 | 148 | var_attn = self.computed_norm_attn(attn_scores=var_attn_scores, 149 | is_padded_bag=is_padded_bag, 150 | non_pad_size=non_pad_size) 151 | 152 | else: 153 | _attn = self.computed_norm_attn(attn_scores=attn_scores, 154 | is_padded_bag=is_padded_bag, 155 | non_pad_size=non_pad_size) 156 | 157 | var_attn = _attn 158 | 159 | var_pooled_bag_feats = self.var_pool(bag_feats, var_attn) 160 | 161 | if normalize: 162 | norm = torch.norm(self.var_pool.var_projections.weight.data, dim=1) 163 | var_pooled_bag_feats /= norm**2 # squared since it's variance 164 | 165 | return var_pooled_bag_feats 166 | 167 | def encode(self, x): 168 | return self.encoder(x) 169 | 170 | def forward(self, bag): 171 | 172 | ################################### 173 | # Instance encoding and attention # 174 | ################################### 175 | 176 | # instance encodings and attention scores 177 | attn_scores, bag_feats = self.compute_bag_feats_and_attn_scores(bag) 178 | 179 | is_padded_bag, non_pad_size = self.get_pad_info(bag) 180 | 181 | # normalize attetion 182 | if self.separate_attn: 183 | mean_attn_scores = attn_scores[0] 184 | var_attn_scores = attn_scores[1] 185 | 186 | mean_attn = self.computed_norm_attn(attn_scores=mean_attn_scores, 187 | is_padded_bag=is_padded_bag, 188 | non_pad_size=non_pad_size) 189 | 190 | if self._apply_var_pool: 191 | var_attn = self.computed_norm_attn(attn_scores=var_attn_scores, 192 | is_padded_bag=is_padded_bag, 193 | non_pad_size=non_pad_size) 194 | 195 | else: 196 | var_attn = None 197 | 198 | else: 199 | _attn = self.computed_norm_attn(attn_scores=attn_scores, 200 | is_padded_bag=is_padded_bag, 201 | non_pad_size=non_pad_size) 202 | 203 | mean_attn = _attn 204 | var_attn = _attn 205 | 206 | ##################### 207 | # Attention pooling # 208 | ##################### 209 | 210 | mean_pooled_feats = (bag_feats * mean_attn).sum(1) 211 | # (n_batches, n_instances, encode_dim) -> (n_batches, encoder_dim) 212 | if self.mean_neck is not None: 213 | mean_pooled_feats = self.mean_neck(mean_pooled_feats) 214 | # (n_batches, encoder_dim) -> (n_batches, cat_head_dim) 215 | 216 | if self._apply_var_pool: 217 | 218 | var_pooled_feats = self.var_pool(bag_feats, var_attn) 219 | # (n_batches, n_instances, encode_dim) -> (n_batches, n_var_pools) 220 | if self.var_neck is not None: 221 | var_pooled_feats = self.var_neck(var_pooled_feats) 222 | # (n_batches, n_var_pools) -> (n_batches, cat_head_dim) 223 | 224 | ################################ 225 | # get output from head network # 226 | ################################ 227 | 228 | if self._apply_var_pool: 229 | head_input = mean_pooled_feats + var_pooled_feats 230 | else: 231 | head_input = mean_pooled_feats 232 | # (n_batches, cat_head_dim) 233 | 234 | if self.head is not None: 235 | return self.head(head_input) 236 | # (n_batches, out_dim) 237 | else: 238 | return head_input 239 | -------------------------------------------------------------------------------- 
/var_pool/nn/arch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/arch/__init__.py
--------------------------------------------------------------------------------
/var_pool/nn/arch/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | def mlp_from_tuple(dims, act='relu'):
5 |     """
6 |     Creates a multi-layer perceptron.
7 | 
8 |     Parameters
9 |     ----------
10 |     dims: list of ints
11 |         The dimensions of the layers including the input and output layer i.e. there are len(dims) - 1 total layers
12 | 
13 |     act: str
14 |         Activation function. Must be one of ['relu'].
15 | 
16 |     Output
17 |     ------
18 |     net: nn.Module
19 |         The MLP network.
20 |     """
21 | 
22 |     net = []
23 |     n_layers = len(dims) - 1
24 |     for layer in range(n_layers):
25 | 
26 |         net.append(nn.Linear(dims[layer], dims[layer+1]))
27 | 
28 |         # add Relus after all but the last layer
29 |         if layer < n_layers - 1:
30 |             if act == 'relu':
31 |                 a = nn.ReLU()
32 |             else:
33 |                 raise NotImplementedError
34 |             net.append(a)
35 | 
36 |     return nn.Sequential(*net)
37 | 
--------------------------------------------------------------------------------
/var_pool/nn/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/datasets/.DS_Store
--------------------------------------------------------------------------------
/var_pool/nn/datasets/GraphDatasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset for graph neural networks. Assumes the Data object has already been created.
3 | """
4 | import torch
5 | from torch_geometric.data import Data
6 | 
7 | 
8 | class GraphDataset:
9 |     """
10 |     Graph Dataset.
11 | 
12 |     Parameters
13 |     ----------
14 |     fpaths: list of str
15 |         list of graph .pt files
16 |     y: Dataframe
17 |         Dataframe holding clinical information
18 |     task: str
19 |         rank_surv, cox_surv, or discr_surv
20 |     arch: str
21 |         Graph architecture to decide edge_spatial vs. edge_latent
22 | 
23 |     Outputs
24 |     -------
25 |     graph_list: list of torch_geometric.data.Data
26 | 
27 |     """
28 |     def __init__(self, fpaths, y, task, arch):
29 |         self.fpaths = fpaths
30 |         self.y = y
31 |         self.task = task
32 |         self.arch = arch
33 | 
34 |     def __call__(self):
35 |         graph_list = []
36 |         fnames = ['-'.join(f.rsplit('/')[-1].split('-')[:3])
37 |                   for f in self.fpaths]
38 | 
39 |         for idx, name in enumerate(fnames):
40 |             try:
41 |                 df = self.y.loc[name]
42 |             except KeyError:
43 |                 pass
44 |             else:
45 |                 label = int(df['time_bin'])
46 |                 c = int(df['censorship'])
47 |                 t = float(df['survival_time'])
48 | 
49 |                 if self.task == 'discr_surv':
50 |                     y_out = torch.Tensor([label, c, t]).reshape(1, -1)
51 |                 else:
52 |                     y_out = torch.Tensor([c, t]).reshape(1, -1)
53 | 
54 |                 for slide_path in self.fpaths[name]:
55 |                     temp = torch.load(slide_path)
56 | 
57 |                     if self.arch in ['amil_gcn', 'amil_gcn_varpool']:
58 |                         graph_data = Data(edge_index=temp.edge_latent, x=temp.x, y=y_out)
59 |                     else:
60 |                         graph_data = Data(edge_index=temp.edge_index, x=temp.x, y=y_out)
61 |                     graph_list.append(graph_data)
62 | 
63 |         return graph_list
64 | 
65 | # class GraphDataset:
66 | #     """
67 | #     Graph Dataset.
68 | # 69 | # Parameters 70 | # ---------- 71 | # data_dir: str 72 | # Directory where graph .pt files are located 73 | # y: Dataframe 74 | # Dataframe holding clinical information 75 | # task: str 76 | # rank_surv, cox_surv, or discr_surv 77 | # 78 | # Outputs 79 | # ------- 80 | # graph_list: list of torch_geometric.data.Data 81 | # 82 | # """ 83 | # def __init__(self, data_dir, y, task): 84 | # self.data_dir = data_dir 85 | # self.y = y 86 | # self.task = task 87 | # 88 | # def __call__(self): 89 | # graph_list = [] 90 | # flist = glob(os.path.join(self.data_dir, '*')) 91 | # fnames = ['-'.join(f.rsplit('/')[-1].split('-')[:3]) for f in flist] 92 | # 93 | # for idx, f in enumerate(fnames): 94 | # try: 95 | # df = self.y.loc[f] 96 | # except KeyError: 97 | # pass 98 | # else: 99 | # label = int(df['time_bin']) 100 | # c = int(df['censorship']) 101 | # t = float(df['survival_time']) 102 | # 103 | # if self.task == 'discr_surv': 104 | # y_out = torch.Tensor([label, c, t]).reshape(1, -1) 105 | # else: 106 | # y_out = torch.Tensor([c, t]).reshape(1, -1) 107 | # 108 | # temp = torch.load(flist[idx]) 109 | # graph_data = Data(edge_index=temp.edge_index, x=temp.x, y=y_out) 110 | # graph_list.append(graph_data) 111 | # 112 | # return graph_list 113 | -------------------------------------------------------------------------------- /var_pool/nn/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/datasets/__init__.py -------------------------------------------------------------------------------- /var_pool/nn/datasets/fixed_bag_size.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data._utils.collate import default_collate 3 | import numpy as np 4 | 5 | 6 | def to_fixed_size_bag(bag, fixed_bag_size=512): 7 | """ 8 | Takes an input bag and returns a bag of a fixed size through either random subsampling or zero padding. 9 | The random sampling is always done with torch so if you want reproducible results make sure to set the torch seed ahead of time. 10 | 11 | Parameters 12 | ---------- 13 | bag: torch.Tensor 14 | The bag of samples. 15 | 16 | fixed_bag_size: int 17 | Size of the desired bag. 18 | 19 | Output 20 | ------ 21 | new_bag, non_pad_size 22 | 23 | new_bag: 24 | The padded or subsampled bag with a fixed size. 25 | 26 | non_pad_size: int 27 | The number of samples in the bag corresponding to non padding elements i.e. 
this is min(bag_size, len(bag)) 28 | """ 29 | # borrowed from https://github.com/KatherLab/HIA/blob/main/models/model_Attmil.py 30 | fixed_bag_size = int(fixed_bag_size) 31 | this_bag_size = bag.shape[0] 32 | 33 | if this_bag_size > fixed_bag_size: 34 | 35 | # randomly subsample instances 36 | bag_idxs = torch.randperm(bag.shape[0]) # random smaple using torch 37 | bag_idxs = list(bag_idxs.numpy()) 38 | bag_idxs = bag_idxs[:fixed_bag_size] 39 | 40 | new_bag = bag[bag_idxs] 41 | 42 | non_pad_size = fixed_bag_size 43 | 44 | elif this_bag_size < fixed_bag_size: 45 | # zero-pad if we don't have enough samples 46 | n_feats = bag.shape[1] 47 | n_to_pad = fixed_bag_size - this_bag_size 48 | pad_size = (n_to_pad, n_feats) 49 | 50 | if isinstance(bag, torch.Tensor): 51 | pad = torch.zeros(pad_size, 52 | dtype=bag.dtype, 53 | device=bag.device) 54 | 55 | new_bag = torch.cat((bag, pad)) 56 | 57 | else: 58 | pad = np.zeros(pad_size, dtype=bag.dtype) 59 | new_bag = np.vstack([bag, pad]) 60 | 61 | non_pad_size = this_bag_size 62 | 63 | else: 64 | new_bag = bag 65 | non_pad_size = this_bag_size 66 | 67 | return new_bag, non_pad_size 68 | 69 | 70 | def get_collate_fixed_bag_size(fixed_bag_size): 71 | """ 72 | Gets the collate_fixed_bag_size function which sets every bag in the batch to a fixed number of instances (via random subsampling or padding) then calls torch.utils.data._utils.collate.default_collate. 73 | 74 | Note each bag array is replaced with a tuple of (bag, non_pad_size) where non_pad_size is as in to_fixed_size_bag. 75 | 76 | Parameters 77 | ---------- 78 | fixed_bag_size: int, str, None 79 | If an int, the size of the desired bag. If set to 'max', then will use the size of the largest bag in the batch. 80 | 81 | Output 82 | ------ 83 | collate_fixed_bag_size: callable 84 | """ 85 | 86 | def collate_fixed_bag_size(batch): 87 | """ 88 | Parameters 89 | ---------- 90 | batch: list 91 | The list of items in each batch. Each entry is either a an array (the bag features) or a tuple of length 2 (bag features and a response). 92 | 93 | Output 94 | ------ 95 | The output of default_collate(batch), but we have first applied to_fixed_size_bag() for each bag and replaced each bag with a tuple of (bag, non_pad_size). 
96 | """ 97 | 98 | # setup bag size for this batch 99 | if fixed_bag_size == 'max': 100 | # get the largest bag size in the batch 101 | bag_sizes = [] 102 | for item in batch: 103 | if isinstance(item, tuple): 104 | bs = item[0].shape[0] 105 | else: 106 | bs = item.shape[0] 107 | bag_sizes.append(bs) 108 | 109 | FBS = max(bag_sizes) 110 | else: 111 | FBS = fixed_bag_size 112 | 113 | # reformat each bag to be a tuple of (new_bag, non_pad_size) 114 | batch_size = len(batch) 115 | for i in range(batch_size): 116 | 117 | # pull out the bag 118 | if isinstance(batch[i], tuple): 119 | # if the items are tuples, the bag should be the first entry 120 | bag = batch[i][0] 121 | else: 122 | bag = batch[i] 123 | 124 | # make the new bag with a fixed size 125 | new_bag, non_pad_size = \ 126 | to_fixed_size_bag(bag=bag, fixed_bag_size=FBS) 127 | 128 | # replace the old bag with tuple (new_bag, non_pad_size) 129 | new_item = (new_bag, non_pad_size) 130 | if isinstance(batch[i], tuple): 131 | batch[i] = (new_item, *batch[i][1:]) 132 | else: 133 | 134 | batch[i] = new_item 135 | 136 | return default_collate(batch) 137 | 138 | return collate_fixed_bag_size 139 | -------------------------------------------------------------------------------- /var_pool/nn/seeds.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import random 5 | 6 | 7 | def set_seeds(device, seed=1): 8 | """ 9 | Sets seeds to get reproducible experiments. 10 | 11 | Parametrs 12 | --------- 13 | device: torch.device 14 | The device we are using. 15 | 16 | seed: int 17 | The seed. 18 | """ 19 | 20 | random.seed(seed) 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | if device.type == 'cuda': 25 | torch.cuda.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 27 | 28 | torch.backends.cudnn.benchmark = False 29 | torch.backends.cudnn.deterministic = True 30 | -------------------------------------------------------------------------------- /var_pool/nn/train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/train/.DS_Store -------------------------------------------------------------------------------- /var_pool/nn/train/EarlyStopper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import torch 4 | from copy import deepcopy 5 | 6 | 7 | class EarlyStopper: 8 | """ 9 | Checks early stopping criteria and saves a model checkpoint anytime a record is set. Note the model checkpoints are saved everytime a record is set i.e. at the beginning of the patience period. 10 | 11 | Parameters 12 | ---------- 13 | save_dir: str 14 | Directory to where the model checkpoints are saved. 15 | 16 | min_epoch: int 17 | Dont check early stopping before this epoch. Record model checkpoints will still be saved before min_epoch. 18 | 19 | patience: int 20 | How many calls to wait after last time validation loss improved to decide whether or not to stop. Note this corresponds to the number of calls to EarlyStopper e.g. if we only check the validation score every K epochs then patience_epochs = K * patience. 21 | 22 | patience_min_improve: float 23 | The minimum improvement over the previous best score for us to reset the patience coutner. E.g. 
if the metric is very slowly improving we might want to stop training. 24 | 25 | abs_scale: bool 26 | Whether or not the patience_min_improve should be put on absolue (new - prev) or relative scale (new - prev)/prev. 27 | 28 | min_good: float 29 | Want to minimize the score metric (e.g. validation loss). 30 | 31 | verbose: bool 32 | Whether or not to print progress. 33 | 34 | Attributes 35 | ---------- 36 | patience_counter_: int 37 | The current patience counter. 38 | 39 | best_score_: float 40 | The best observed score so far. 41 | 42 | epochs_with_records_: list of in 43 | The epochs where a record was set. 44 | """ 45 | def __init__(self, save_dir, min_epoch=20, patience=10, min_good=True, 46 | patience_min_improve=0, abs_scale=True, verbose=True): 47 | 48 | self.save_dir = save_dir 49 | 50 | self.min_epoch = min_epoch 51 | self.patience = patience 52 | 53 | self.patience_min_improve = patience_min_improve 54 | self.abs_scale = abs_scale 55 | 56 | self.verbose = verbose 57 | self.min_good = min_good 58 | 59 | self._reset_tracking() 60 | 61 | def __call__(self, model, score, epoch, ckpt_name='checkpoint.pt'): 62 | """ 63 | Check early stopping criterion. 64 | 65 | Parametres 66 | ---------- 67 | model: 68 | The model to maybe save. 69 | 70 | score: float 71 | The metric we are scoring e.g. validation loss. 72 | 73 | epoch: int 74 | Which epoch just finished. Assumes zero indexed i.e. epoch=0 means we just finished the first epoch. 75 | 76 | ckpt_name: str 77 | The name of the checkpoint file. 78 | 79 | Output 80 | ----- 81 | stop_early: bool 82 | """ 83 | 84 | ################################################# 85 | # Check if new record + save record checkpoint # 86 | ################################################# 87 | prev_best = deepcopy(self.best_score_) 88 | 89 | # check if this score is a record 90 | if (self.min_good and score < self.best_score_) or \ 91 | ((not self.min_good) and score > self.best_score_): 92 | 93 | self.best_score_ = score 94 | is_record = True 95 | 96 | # always save a record to disk 97 | os.makedirs(self.save_dir, exist_ok=True) 98 | fpath = os.path.join(self.save_dir, ckpt_name) 99 | torch.save(model.state_dict(), fpath) 100 | 101 | if self.verbose: 102 | print('New record set on epoch {} at {:1.5f}' 103 | ' (previously was {:1.5f})'. 
104 | format(epoch, self.best_score_, prev_best)) 105 | 106 | self.epochs_with_records_.append(epoch) 107 | 108 | else: 109 | is_record = False 110 | 111 | ######################## 112 | # Check early stopping # 113 | ######################## 114 | stop_early = False 115 | 116 | if (epoch + 1) >= self.min_epoch: 117 | # either increase or reset the counter since we are beyond the min epoch 118 | 119 | if is_record: # +1 for zero indexing 120 | # if we are passed the warm up period and just set a record 121 | 122 | # check if this is impressive record 123 | if abs(prev_best) == np.inf: 124 | # if the previous record was infintee this 125 | # was auotmatically an imporessive record 126 | is_impressive_record = True 127 | 128 | else: 129 | # compute difference on absolute or relative scale 130 | abs_diff = abs(score - prev_best) 131 | if self.abs_scale: 132 | diff_to_check = abs_diff 133 | else: 134 | epsilon = np.finfo(float).eps 135 | diff_to_check = abs_diff / (abs(prev_best) + epsilon) 136 | 137 | if diff_to_check >= self.patience_min_improve: 138 | is_impressive_record = True 139 | else: 140 | is_impressive_record = False 141 | else: 142 | # this was not a record 143 | is_impressive_record = False 144 | 145 | # reset the patience counter for impressive records 146 | # otherwise increase counter 147 | if is_impressive_record: 148 | self.patience_counter_ = 0 149 | else: 150 | self.patience_counter_ += 1 151 | 152 | if self.verbose: 153 | print("Early stopping counter {}/{}". 154 | format(self.patience_counter_, self.patience)) 155 | 156 | # if we have met our patience level we should stop! 157 | if self.patience_counter_ >= self.patience: 158 | stop_early = True 159 | 160 | return stop_early 161 | 162 | def _reset_tracking(self): 163 | """ 164 | resets the tracked data 165 | """ 166 | self.patience_counter_ = 0 167 | 168 | if self.min_good: 169 | self.best_score_ = np.Inf 170 | else: 171 | self.best_score_ = -np.Inf 172 | 173 | self.epochs_with_records_ = [] 174 | -------------------------------------------------------------------------------- /var_pool/nn/train/GradAccum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class GradAccum: 5 | """ 6 | Calculates loss function divosor for adjusting the loss when training with gradient accumulation. 7 | This object handles the case when number of batches is not divisible by the gradient accumultation and/or the case when the number of samples is not divisible by the batch size. 8 | 9 | Note this object assumes the loss function already averages over batches and that all but possibly the last batch has the same batch size. 10 | 11 | Be careful about this when using nn.DataParallel 12 | TODO: @iain think this through! 13 | 14 | This is also a small issue when using loss functions that have class weights, see https://github.com/pytorch/pytorch/issues/72047 15 | 16 | Parameters 17 | ---------- 18 | loader: torch DataLoader 19 | The data loader for the batches. 20 | 21 | grad_accum: int, None 22 | Number of gradient accumulation steps. 
23 | 24 | 25 | Example 26 | ------- 27 | GA = GradAccum(loader=loader, grad_accum=grad_accum) 28 | 29 | model.zero_grad() 30 | for batch_idx, (x, y_true) in enumerate(loader): 31 | 32 | 33 | y_pred = model(x) 34 | loss = loss_func(y_true, y_pred) # assume this averages over batch 35 | 36 | # adjust loss divisor 37 | loss_div, update_params = GA.get_loss_div(batch_idx) 38 | loss = loss / loss_div 39 | 40 | loss.backward() 41 | 42 | # step after gradeint accumulation 43 | if update_params: 44 | optimizer.step() 45 | optimizer.zero_grad() 46 | """ 47 | 48 | def __init__(self, loader, grad_accum): 49 | 50 | # no grad accum if <= 1 51 | grad_accum = 1 if (grad_accum is not None and grad_accum <= 1)\ 52 | else grad_accum 53 | 54 | ######### 55 | # Setup # 56 | ######### 57 | 58 | n_batches = len(loader) 59 | n_samples = len(loader.dataset) 60 | batch_size = loader.batch_size 61 | 62 | # adjust for dropping last batch 63 | if loader.drop_last and (n_samples % batch_size == 0): 64 | # we see exactly full batches 65 | n_samples = n_batches * batch_size 66 | 67 | # number of gradient accumulation batches 68 | # i.e. the effective number of batches 69 | n_grad_accum_batches = int(np.ceil(n_batches / grad_accum)) 70 | 71 | # number of batches the last grad accum batch will have 72 | if n_batches % grad_accum == 0: 73 | n_batchs_in_last_ga_batch = grad_accum 74 | else: 75 | n_batchs_in_last_ga_batch = n_batches % grad_accum 76 | 77 | # number of samples in last batch 78 | if n_samples % batch_size == 0: 79 | n_samples_last_batch = batch_size 80 | else: 81 | n_samples_last_batch = n_samples % batch_size 82 | 83 | # the last grad accum batch may see 84 | # a different number of grad accum batches 85 | # and/or an uneven batch 86 | n_samples_in_last_grad_accum_batch = \ 87 | (n_batchs_in_last_ga_batch - 1) * batch_size +\ 88 | n_samples_last_batch 89 | 90 | # for the final gradient update batch we dont necessarily check 91 | # batch_idx % grad_accum == 0 92 | final_ga_batch_update_crit = n_batches % grad_accum 93 | 94 | # store data we need 95 | self.grad_accum = grad_accum 96 | self.batch_size = batch_size 97 | self.n_grad_accum_batches = n_grad_accum_batches 98 | self.n_samples_in_last_grad_accum_batch = \ 99 | n_samples_in_last_grad_accum_batch 100 | 101 | self.final_ga_batch_update_crit = final_ga_batch_update_crit 102 | 103 | def get_loss_div(self, batch_idx): 104 | """ 105 | Gets the quantity we should divide loss function by to account for gradient accumulation 106 | 107 | Parameters 108 | ---------- 109 | batch_idx: int 110 | The current batch index. 111 | 112 | Output 113 | ------ 114 | loss_div, update_params 115 | 116 | loss_div: float 117 | The loss divisor that accounts for gradient accumulation. 118 | 119 | update_params: bool 120 | Whether or not to update parameters with a gradient step on this batch_idx. 121 | """ 122 | 123 | if self.grad_accum is None: 124 | return 1. 
125 | 126 | # which gradient accumulation batch we are on 127 | grad_accum_batch_idx = batch_idx // self.grad_accum 128 | 129 | if (grad_accum_batch_idx + 1) < self.n_grad_accum_batches: 130 | # NOT on last grad accum batch 131 | # current div is batch_size and we want to 132 | # change it to batch_size * grad_accum 133 | loss_div = self.grad_accum 134 | 135 | # update parameters every grad_accum times 136 | update_params = (batch_idx + 1) % self.grad_accum == 0 137 | 138 | else: 139 | # ON last grad accum batch 140 | # current div is batch_size and we want to 141 | # change it to n_samples_in_last_grad_accum_batch 142 | loss_div = self.batch_size / self.n_samples_in_last_grad_accum_batch 143 | 144 | # update parameters on very last batch index 145 | update_params = (batch_idx + 1) % self.grad_accum == \ 146 | self.final_ga_batch_update_crit 147 | 148 | return loss_div, update_params 149 | -------------------------------------------------------------------------------- /var_pool/nn/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/train/__init__.py -------------------------------------------------------------------------------- /var_pool/nn/train/tests/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/train/tests/.DS_Store -------------------------------------------------------------------------------- /var_pool/nn/train/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/nn/train/tests/__init__.py -------------------------------------------------------------------------------- /var_pool/nn/train/tests/test_GradAccum.py: -------------------------------------------------------------------------------- 1 | from var_pool.nn.train.tests.utils_grad_accum import check_GradAccum 2 | 3 | 4 | def test_GradAccum(): 5 | 6 | assert check_GradAccum(n_samples=5, batch_size=1, grad_accum=2) 7 | assert check_GradAccum(n_samples=5, batch_size=1, grad_accum=3) 8 | 9 | assert check_GradAccum(n_samples=4, batch_size=2, grad_accum=1) 10 | assert check_GradAccum(n_samples=4, batch_size=2, grad_accum=2) 11 | assert check_GradAccum(n_samples=5, batch_size=2, grad_accum=2) 12 | 13 | assert check_GradAccum(n_samples=10, batch_size=3, grad_accum=2) 14 | assert check_GradAccum(n_samples=10, batch_size=3, grad_accum=3) 15 | assert check_GradAccum(n_samples=10, batch_size=3, grad_accum=4) 16 | -------------------------------------------------------------------------------- /var_pool/nn/train/tests/utils_grad_accum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import DataLoader, Dataset 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from var_pool.nn.train.GradAccum import GradAccum 9 | 10 | 11 | def check_GradAccum(n_samples, batch_size, grad_accum): 12 | """ 13 | Checks the GradAccum() object by checking that gradient accumulation is equivalent to using a new batch size of batch_size x grad_accum. 14 | 15 | Parameters 16 | ---------- 17 | n_samples: int 18 | Number of smaples in the dataset. 19 | 20 | batch_size: int 21 | the batch size. 
22 | 23 | grad_accum: int 24 | Number of gradient accumulation setp. 25 | 26 | Output 27 | ------ 28 | test_passes: bool 29 | Whether or not the test passes. 30 | """ 31 | 32 | np.random.seed(1) 33 | torch.manual_seed(0) 34 | n_features = 5 35 | 36 | # n_samples = 4 37 | # batch_size = 4 38 | # grad_accum = 2 39 | 40 | # n_samples = 10 41 | # batch_size = 3 42 | # grad_accum = 2 43 | 44 | # Setup data set 45 | X = np.random.normal(size=(n_samples, n_features)) 46 | y = np.random.normal(size=n_samples) 47 | model = LinearRegression(n_features) 48 | dataset = RegressionDataset(X, y) 49 | 50 | loader = DataLoader(dataset, batch_size=batch_size, shuffle=False) 51 | 52 | # gradient accumulation with batch size B should be equivalent to using 53 | # batch size = B x grad_accum 54 | loader_BxG = DataLoader(dataset, batch_size=batch_size * grad_accum, 55 | shuffle=False) 56 | 57 | # run loops 58 | ga = epoch_with_grad_accum(model=model, loader=loader, 59 | grad_accum=grad_accum) 60 | BxG = epoch_manual(model=model, loader_BxG=loader_BxG) 61 | 62 | # check results 63 | return all([torch.allclose(a, b) for (a, b) in zip(ga, BxG)]) 64 | # test_passes = True 65 | # for a, b in zip(ga, BxG): 66 | # # print(a - b) 67 | # # assert torch.allclose(a, b) 68 | # if not torch.allclose(a, b): 69 | # test_passes = False 70 | 71 | 72 | class RegressionDataset(Dataset): 73 | def __init__(self, X, y): 74 | self.X = np.array(X) 75 | self.y = np.array(y).reshape(-1) 76 | 77 | def __len__(self): 78 | return self.X.shape[0] 79 | 80 | def __getitem__(self, idx): 81 | return self.X[idx, :], self.y[idx] 82 | 83 | 84 | class LinearRegression(nn.Module): 85 | def __init__(self, n_features): 86 | super().__init__() 87 | self.coef = nn.Linear(n_features, 1) 88 | 89 | def forward(self, input): 90 | return self.coef(input) 91 | 92 | 93 | def initialize_to_zero(model): 94 | for m in model.modules(): 95 | if isinstance(m, nn.Linear): 96 | m.weight.data.zero_() 97 | m.bias.data.zero_() 98 | 99 | 100 | def epoch_manual(model, loader_BxG, verbose=False): 101 | """ 102 | Run one epoch with loader that uses batch size n_batches x grad accum. 103 | 104 | Output 105 | ------ 106 | model_params: list of tensors 107 | The model parameters after one epoch. 108 | """ 109 | 110 | # setup loss/optimizer 111 | loss_func = nn.MSELoss() 112 | initialize_to_zero(model) # always init at zero 113 | optimizer = optim.SGD(params=model.parameters(), 114 | lr=1, momentum=0, weight_decay=0) 115 | 116 | if verbose: 117 | print('initial parameters', list(model.parameters()), '\n') 118 | 119 | model.zero_grad() 120 | for batch_idx, (x, y_true) in enumerate(loader_BxG): 121 | 122 | x = x.float() 123 | y_true = y_true.unsqueeze(1) 124 | 125 | y_pred = model(x.float()) 126 | loss = loss_func(y_true.float(), y_pred) 127 | loss.backward() 128 | 129 | if verbose: 130 | print('grad batch_idx={}'.format(batch_idx), 131 | [p.grad for p in model.parameters()]) 132 | 133 | optimizer.step() 134 | optimizer.zero_grad() 135 | 136 | return [p.data for p in model.parameters()] 137 | 138 | 139 | def epoch_with_grad_accum(model, loader, grad_accum, verbose=False): 140 | """ 141 | Run one epoch with gradient accumulation 142 | 143 | Output 144 | ------ 145 | model_params: list of tensors 146 | The model parameters after one epoch. 
147 | """ 148 | 149 | GA_helper = GradAccum(loader=loader, grad_accum=grad_accum) 150 | 151 | # setup loss/optimizer 152 | loss_func = nn.MSELoss() 153 | initialize_to_zero(model) # always init at zero 154 | optimizer = optim.SGD(params=model.parameters(), 155 | lr=1, momentum=0, weight_decay=0) 156 | 157 | if verbose: 158 | print('initial parameters', list(model.parameters()), '\n') 159 | 160 | model.zero_grad() 161 | for batch_idx, (x, y_true) in enumerate(loader): 162 | x = x.float() 163 | y_true = y_true.unsqueeze(1) 164 | 165 | y_pred = model(x.float()) 166 | loss = loss_func(y_true.float(), y_pred) 167 | 168 | # adjust loss divisor 169 | loss_div, update_params = GA_helper.get_loss_div(batch_idx) 170 | loss = loss / loss_div 171 | 172 | if verbose: 173 | print('loss_div batch_idx={}'.format(batch_idx), loss_div) 174 | 175 | loss.backward() 176 | 177 | # step after gradeint accumulation 178 | if update_params: 179 | if verbose: 180 | print('grad batch_idx={}'.format(batch_idx), 181 | [p.grad for p in model.parameters()]) 182 | 183 | optimizer.step() 184 | optimizer.zero_grad() 185 | 186 | if verbose: 187 | print() 188 | print(GA_helper.__dict__) 189 | 190 | return [p.data for p in model.parameters()] 191 | -------------------------------------------------------------------------------- /var_pool/nn/tune_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from numbers import Number 4 | from time import time 5 | from datetime import datetime 6 | 7 | from var_pool.file_utils import join_and_make 8 | 9 | 10 | def run_train(tune_params, tune_idx, device, 11 | script_fpath, train_dir, 12 | fixed_params, fixed_flags, 13 | skip_if_results_exist=False): 14 | """ 15 | Calls train.py for one tuning parameter setting and saves the tuning parameters settings to disk. 16 | 17 | Parameters 18 | ---------- 19 | tune_idx: int 20 | Index of this tune setting. 21 | 22 | tune_params: dict 23 | The parameters that are tuned over and saved to disk. Any boolean values will be considered flags in the command string (and only called if they are True). 24 | 25 | device: None, int, list of int 26 | Which cuda devices(s) to use. 27 | 28 | script_fpath: str 29 | Path to the train.py script. 30 | 31 | train_dir: str 32 | Directory that containes the tune folders as tune_1, tune_2, ... 33 | 34 | fixed_params: dict 35 | The fixed parameters for each tune setting 36 | 37 | fixed_flags: list of str 38 | The flags that fixed for each tune setting. 39 | 40 | skip_if_results_exist: bool 41 | Dont run train.py if the corresponding results file already exists. Useful for picking up if tuning crashed half way through! 42 | """ 43 | 44 | name = 'tune_{}'.format(tune_idx) 45 | tune_save_dir = join_and_make(train_dir, name) 46 | 47 | # Maybe skip this if results is already there 48 | results_fpath = os.path.join(tune_save_dir, 'results.yaml') 49 | if skip_if_results_exist and os.path.exists(results_fpath): 50 | print("Skipping tune_idx {} because the results file exists". 
51 | format(tune_idx)) 52 | return 53 | 54 | # Save the tunable parameters 55 | os.makedirs(tune_save_dir, exist_ok=True) 56 | tune_params_fpath = os.path.join(tune_save_dir, 'tune_params.yaml') 57 | with open(tune_params_fpath, 'w') as f: 58 | yaml.dump(tune_params, f) 59 | 60 | ########################## 61 | # setup parameter string # 62 | ########################## 63 | all_params = {**tune_params, **fixed_params} 64 | all_params['name'] = name 65 | all_flags = fixed_flags 66 | 67 | arg_str = '' 68 | for (k, v) in all_params.items(): 69 | 70 | if isinstance(v, bool): 71 | # bools get turned into flags 72 | if v: 73 | arg_str += ' --{}'.format(k) 74 | else: 75 | arg_str += ' --{} {}'.format(k, v) 76 | 77 | for flag in all_flags: 78 | arg_str += ' --{}'.format(flag) 79 | 80 | ###################### 81 | # setup cuda devices # 82 | ###################### 83 | if device is None: 84 | cuda_prefix = '' 85 | elif isinstance(device, Number): 86 | cuda_prefix = 'CUDA_VISIBLE_DEVICES={}'.format(int(device)) 87 | else: 88 | # list of devices 89 | cuda_prefix = 'CUDA_VISIBLE_DEVICES=' 90 | cuda_prefix += ''.join([str(int(d)) for d in device]) 91 | 92 | ################# 93 | # setup command # 94 | ################# 95 | command = cuda_prefix + ' python {}'.format(script_fpath) + arg_str 96 | 97 | current_time = datetime.now().strftime("%H:%M:%S") 98 | print("\n\n\n=====================================") 99 | print("Starting tune index {} at {} with params {}". 100 | format(tune_idx, current_time, tune_params)) 101 | print(command) 102 | print("=====================================") 103 | start_time = time() 104 | os.system(command) 105 | 106 | runtime = time() - start_time 107 | print("=====================================") 108 | print("Finished running tune index {} after {:1.2f} seconds " 109 | "with parameters {}". 110 | format(tune_idx, runtime, tune_params)) 111 | print("=====================================\n\n\n") 112 | -------------------------------------------------------------------------------- /var_pool/nn/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.optim as optim 3 | 4 | 5 | def get_network_summary(net): 6 | """ 7 | Prints a summary of neural network including the number of parameters. 8 | 9 | Output 10 | ---------- 11 | summary: str 12 | A text summary of the network. 13 | """ 14 | 15 | num_params = 0 16 | num_params_train = 0 17 | 18 | summary = str(net) 19 | 20 | for param in net.parameters(): 21 | n = param.numel() 22 | num_params += n 23 | if param.requires_grad: 24 | num_params_train += n 25 | 26 | summary += '\n\nTotal number of parameters: {}'.\ 27 | format(num_params) 28 | summary += '\nTotal number of trainable parameters: {}'.\ 29 | format(num_params_train) 30 | 31 | return summary 32 | 33 | 34 | def initialize_weights(module): 35 | """ 36 | Initializes the weights for a neural network. 37 | """ 38 | for m in module.modules(): 39 | if isinstance(m, nn.Linear): 40 | nn.init.xavier_normal_(m.weight) 41 | 42 | if m.bias is not None: 43 | m.bias.data.zero_() 44 | 45 | elif isinstance(m, nn.BatchNorm1d): 46 | nn.init.constant_(m.weight, 1) 47 | nn.init.constant_(m.bias, 0) 48 | 49 | 50 | def get_optim(model, algo='adam', lr=1e-4, weight_decay=1e-5): 51 | """ 52 | Sets up the optimizer for the trainable parametres. 53 | 54 | Parameters 55 | ---------- 56 | model: 57 | 58 | algo: str 59 | The optimization algorithm. Must be one of ['adam', 'sgd'] 60 | 61 | lr: float 62 | The learning rate. 
63 | 64 | weight_decay: None, float 65 | Weight decay (L2 penalty) 66 | 67 | Output 68 | ------ 69 | optim: torch.optim.Optimizer 70 | The setup optimizer. 71 | """ 72 | 73 | # pull out trainable parameters 74 | trainable_params = filter(lambda p: p.requires_grad, model.parameters()) 75 | 76 | if algo == "adam": 77 | optimizer = optim.Adam(params=trainable_params, lr=lr, 78 | weight_decay=weight_decay) 79 | 80 | elif algo == 'sgd': 81 | optimizer = optim.SGD(params=trainable_params, lr=lr, 82 | momentum=0.9, weight_decay=weight_decay) 83 | 84 | else: 85 | raise NotImplementedError("{} not currently implemented".format(algo)) 86 | 87 | return optimizer 88 | -------------------------------------------------------------------------------- /var_pool/processing/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/processing/.DS_Store -------------------------------------------------------------------------------- /var_pool/processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/processing/__init__.py -------------------------------------------------------------------------------- /var_pool/processing/clf_utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from sklearn.preprocessing import LabelEncoder 3 | import numpy as np 4 | 5 | 6 | def dict_split_clf_df(y_df, label_col='label', 7 | split_col='split', index_col='sample_id'): 8 | """ 9 | Splits the y label data frame into train, validation, and test sets. 10 | 11 | Parameters 12 | ---------- 13 | df: pd.DataFrame 14 | A data frame whose first three columns are ['sample_id', 'label', 'split'] 15 | 'sample_id' is the identifier for each sample 16 | 'label' is the class label 17 | 'split' should be one of ['train', 'val', 'test'] indicating which hold out set the sample is in 18 | 19 | label_col: str 20 | Name of the column containing the y label. 21 | 22 | spilt_col: str 23 | Name of the column containing the train/test/val splits. 24 | 25 | index_col: str 26 | Name of the column containin the sample identifies; this will be used to index the pd.Series in y_split. 27 | 28 | Output 29 | ------ 30 | y_split, label_enc 31 | 32 | y_split: dict of pd.Series 33 | The keys include ['train', 'val', 'test']. The series include the class labels converted to indices. The pd.Series are indexed by the index_col. 34 | 35 | label_enc: sklearn.preprocessing.LabelEncoder 36 | The label encoder use to convert the sample categories to indices. 37 | """ 38 | 39 | # pull out the important columns and copy. 
40 | cols = [index_col, label_col, split_col] 41 | y_df = y_df[cols].copy() 42 | 43 | # tranform class names to indices 44 | label_enc = LabelEncoder() 45 | y_idxs = label_enc.fit_transform(y_df[label_col].values) 46 | y_df[label_col] = y_idxs 47 | 48 | # split into train/val/test sets 49 | y_split = {} 50 | for kind, df in y_df.groupby(split_col): 51 | assert kind in ['train', 'val', 'test'] 52 | 53 | # make pandas series for each split 54 | y_split[kind] = pd.Series(df[label_col].values, 55 | name='label', 56 | index=df[index_col]) 57 | 58 | return y_split, label_enc 59 | 60 | 61 | def get_weights_for_balanced_clf(y): 62 | """ 63 | Gets sample weights for the WeightedRandomSampler() for the data loader to make balanced training datasets in each epoch. 64 | 65 | Let class_counts, shape (n_classes, ) be the vector of class counts. The sample weights for the ith observation is 66 | 67 | n_samples / class_counts[y[i]] 68 | 69 | Parameters 70 | ---------- 71 | y: array-like, (n_samples, ) 72 | The observed class indices. 73 | 74 | Output 75 | ------ 76 | sample_weights: array-like, (n_samples, ) 77 | The sample weights 78 | """ 79 | 80 | y = pd.Series(y) 81 | class_counts = y.value_counts() # class counts 82 | 83 | n_samples = len(y) 84 | 85 | sample_weights = n_samples / np.array([class_counts[cl_idx] 86 | for cl_idx in y]) 87 | 88 | return sample_weights 89 | -------------------------------------------------------------------------------- /var_pool/processing/data_split.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split 2 | from sklearn.utils import check_random_state 3 | import numpy as np 4 | 5 | # TODO: test code to add 6 | # n_samples = 99 7 | # train_idxs, val_idxs, test_idxs = train_test_val_split(n_samples) 8 | # assert len(train_idxs) + len(val_idxs) + len(test_idxs) == n_samples 9 | def train_test_val_split(n_samples, 10 | train_size=0.8, val_size=0.1, test_size=0.1, 11 | shuffle=True, random_state=None, stratify=None): 12 | """ 13 | Creates the indices for a train/validation/test set. 14 | 15 | Parameters 16 | ---------- 17 | samples: int, array-like 18 | Either the total number of samples or an array containing the sample identifiers. 19 | 20 | train_size, val_size, test_size: float 21 | The train/validation/test proportions. Need to add to 1. 22 | 23 | random_state : int, RandomState instance or None, default=None 24 | Controls the shuffling applied to the data before applying the split. 25 | Pass an int for reproducible output across multiple function calls. 26 | 27 | shuffle : bool, default=True 28 | Whether or not to shuffle the data before splitting. If shuffle=False 29 | then stratify must be None. 30 | 31 | stratify : array-like, default=None 32 | If not None, data is split in a stratified fashion, using this as 33 | the class labels. 
34 | 35 | Output 36 | ------ 37 | train_idxs, val_idxs, test_idxs 38 | """ 39 | 40 | # ssert all([train_size > 0, val_size > 0, test_size > 0]) 41 | # assert np.allclose(train_size + val_size + test_size, 1) 42 | rng = check_random_state(random_state) 43 | 44 | idxs = np.arange(int(n_samples)) 45 | # # user input list of samples 46 | # if not isinstance(samples, Number): 47 | # samples = np.array(samples).reshape(-1) 48 | # idxs = np.arange(len(samples)) 49 | # else: 50 | # idxs = np.arange(int(samples)) 51 | 52 | if test_size == 0: 53 | train_idxs, val_idxs = train_test_split(idxs, 54 | train_size=train_size, 55 | test_size=val_size, 56 | random_state=rng, 57 | shuffle=shuffle, 58 | stratify=stratify) 59 | 60 | test_idxs = [] 61 | 62 | return train_idxs, val_idxs, test_idxs 63 | 64 | # split test set off 65 | tr_val_idxs, test_idxs = train_test_split(idxs, 66 | train_size=train_size + val_size, 67 | test_size=test_size, 68 | random_state=rng, 69 | shuffle=shuffle, 70 | stratify=stratify) 71 | 72 | if stratify is not None: 73 | stratify = stratify[tr_val_idxs] 74 | 75 | # calculate number of train samples 76 | n_train = len(tr_val_idxs) * (train_size / (train_size + val_size)) 77 | n_train = int(n_train) 78 | 79 | # split train and validation set 80 | train_idxs, val_idxs = train_test_split(tr_val_idxs, 81 | train_size=n_train, 82 | random_state=rng, shuffle=shuffle, 83 | stratify=stratify) 84 | 85 | return train_idxs, val_idxs, test_idxs 86 | -------------------------------------------------------------------------------- /var_pool/processing/discr_surv_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for discrete time survival models. 3 | 4 | Much of this code is adopted from https://github.com/mahmoodlab/Patch-GCN/blob/c6455a3a01c4ca20cde6ddb9a6f9cd807253a4f7/datasets/dataset_survival.py 5 | """ 6 | import pandas as pd 7 | import numpy as np 8 | from sksurv.metrics import concordance_index_censored 9 | 10 | 11 | def get_discrete_surv_bins(surv_df, n_bins, eps=1e-6, 12 | time_col='survival_time', censor_col='censorship'): 13 | """ 14 | Bins survival times into discrete time intervals and add corresponding bin labels to survival data frame. 15 | 16 | Parameters 17 | ---------- 18 | surv_df: pd.DataFrame, (n_samples, n_features) 19 | The dataframe containing the surivial information. 20 | 21 | n_bins: int 22 | Number bins to bin time into. 23 | 24 | eps: float 25 | TODO: document. 26 | 27 | time_col: str 28 | Name of the time column. 29 | 30 | censor_col: str 31 | Name of the binary censorship column. 32 | 33 | Output 34 | ------ 35 | time_bin_idx, bins 36 | 37 | time_bin_idx: array-like, (n_samples, ) 38 | The index of the survival time bin label. 39 | 40 | bins: array-like, shape (n_bins + 1) 41 | The bin intervals cutoffs. 
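    Notes
    -----
    The bin edges are quantiles of the uncensored survival times; the outer
    edges are then widened by eps so that every sample (censored or not)
    falls inside a bin. A minimal usage sketch, assuming the default column
    names above:

    >>> y_discr, bins = get_discrete_surv_bins(surv_df, n_bins=4)
    >>> surv_df['time_bin'] = y_discr  # len(bins) == n_bins + 1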
42 | """ 43 | 44 | surv_df = surv_df.copy() 45 | 46 | # pull out all uncensored times 47 | censor_mask = surv_df[censor_col].astype(bool) 48 | uncensored_df = surv_df[~censor_mask] 49 | times_no_censor = uncensored_df[time_col] 50 | 51 | # TODO: document 52 | _, q_bins = pd.qcut(times_no_censor, q=n_bins, retbins=True, labels=False) 53 | q_bins[-1] = surv_df[time_col].max() + eps 54 | q_bins[0] = surv_df[time_col].min() - eps 55 | 56 | # y_discrete is the index label corresponding to the discrete time interval 57 | y_discr, bins = pd.cut(surv_df[time_col], bins=q_bins, 58 | retbins=True, labels=False, 59 | right=False, include_lowest=True) 60 | 61 | y_discr = y_discr.astype(int) 62 | 63 | return y_discr, bins 64 | 65 | # df.insert(loc=2, column='label', value=y_discr.values.astype(int)) 66 | # surv_df['time_bin'] = y_discr 67 | 68 | # TODO: I don't think we actually need this -- can we get rid of it? 69 | # create label dictionary for bin X censorship class 70 | # bin_cen_label_dict = {} 71 | # key_count = 0 72 | # for bin_idx in range(len(q_bins)-1): 73 | # for c in [0, 1]: 74 | # bin_cen_label_dict.update({(bin_idx, c): key_count}) 75 | # key_count += 1 76 | # add index for bin X censorship class 77 | # for idx in surv_df.index: 78 | # # add label for bins X censorship status 79 | # bin_idx, c = surv_df.loc[idx, ['time_bin', censor_col]] 80 | # surv_bin_cen_idx = bin_cen_label_dict[(bin_idx, int(c))] 81 | # surv_df.loc[idx, 'time_bin_X_censor'] = surv_bin_cen_idx 82 | 83 | # format to ints 84 | # cols = ['time_bin', 'time_bin_X_censor'] 85 | # surv_df[cols] = surv_df[cols].astype(int) 86 | # return surv_df, bins, bin_cen_label_dict 87 | 88 | 89 | def dict_split_discr_surv_df(y_df, time_bin_col='time_bin', 90 | time_col='survival_time', 91 | censor_col='censorship', 92 | index_col='sample_id', 93 | split_col='split'): 94 | """ 95 | Splits the discrete survival response data frame into train, validation, and test sets. 96 | 97 | Parameters 98 | ---------- 99 | df: pd.DataFrame 100 | A data frame whose first four columns are [sample_id, time_bin_col, censor_col, split_col] 101 | sample_id: is the identifier for each sample 102 | time_bin: is discrete time bin. 103 | censorship: is the censorship indicator 104 | split: should be one of ['train', 'val', 'test'] indicating which hold out set the sample is in 105 | 106 | time_bin_col: str 107 | Name of the time bin column. 108 | 109 | time_col: str 110 | Name of the survival time column. 111 | 112 | censor_col: str 113 | Name of the censorship column. 114 | 115 | index_col: str 116 | Name of the column containin the sample identifies; this will be used to index the pd.Series in y_split. 117 | 118 | spilt_col: str 119 | Name of the column containing the train/test/val splits. 120 | 121 | Output 122 | ------ 123 | y_split 124 | 125 | y_split: dict of pd.DataFrame 126 | The keys include ['train', 'val', 'test']. The data frames are indexed by the index_col and contain columns ['time_bin', 'censorship', 'survival_time']. 127 | """ 128 | 129 | # pull out the important columns and copy. 
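    # expected y_df layout (any other columns are ignored): index_col ->
    # sample identifier, time_bin_col -> discrete time bin index (typically
    # produced by get_discrete_surv_bins above), censor_col -> 0/1
    # censorship indicator, time_col -> survival time, split_col -> one of
    # {'train', 'val', 'test'}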
130 | cols = [index_col, time_bin_col, censor_col, time_col, split_col] 131 | y_df = y_df[cols].copy() 132 | 133 | # tranform class names to indices 134 | 135 | # split into train/val/test sets 136 | y_split = {} 137 | for kind, df in y_df.groupby(split_col): 138 | assert kind in ['train', 'val', 'test'] 139 | 140 | ########## 141 | # format # 142 | ########## 143 | 144 | # subset and set index 145 | df = df[[index_col, time_bin_col, censor_col, time_col]].\ 146 | set_index(index_col) 147 | 148 | # format numbers 149 | ensure_int_cols = [time_bin_col, censor_col] 150 | df[ensure_int_cols] = df[ensure_int_cols].astype(int) 151 | 152 | # standardize names 153 | df = df.rename(columns={time_bin_col: 'time_bin', 154 | censor_col: 'censorship', 155 | time_col: 'survival_time'} 156 | ) 157 | 158 | y_split[kind] = df 159 | 160 | return y_split 161 | 162 | 163 | def get_weights_for_balanced_binXcensor(surv_df, 164 | time_bin_col='time_bin', 165 | censor_col='censorship'): 166 | """ 167 | Gets sample weights for the WeightedRandomSampler() for the data loader used by a discrete survival task. 168 | 169 | - let tb_X_c, shape (n_samples, ) be the labels for the time bin X censorship status classes 170 | - let tb_X_c_counts, shape (n_classes_tb_X_c, ) be the counts of number of sampels in each of these classes 171 | 172 | the the sample weight for sample i is given by 173 | 174 | n_samples / tb_X_c_counts[tb_X_c[i]] 175 | 176 | Parameters 177 | ---------- 178 | surv_df: pd.DataFrame, (n_samples, 2) 179 | The survival response data frame containing the time bin columns and censorship status. 180 | 181 | time_bin_col: str 182 | Name of the time bin column. 183 | 184 | censor_col: str 185 | Name of the censorship column. 186 | 187 | Output 188 | ------ 189 | sample_weights: array-like, (n_samples, ) 190 | The sample weights 191 | """ 192 | # This is implementing https://github.com/mahmoodlab/Patch-GCN/blob/c6455a3a01c4ca20cde6ddb9a6f9cd807253a4f7/utils/utils.py#L184 193 | 194 | # create time_bin X censorship status labels 195 | tb_X_c = surv_df[time_bin_col].astype(str)\ 196 | + '_X_' \ 197 | + surv_df[censor_col].astype(str) 198 | 199 | tb_X_c_counts = tb_X_c.value_counts() # class counts 200 | 201 | # class counts for each sample 202 | sample_tbXc_couts = np.array([tb_X_c_counts[label] 203 | for label in tb_X_c]) 204 | 205 | n_samples = len(tb_X_c) 206 | 207 | sample_weights = n_samples / sample_tbXc_couts 208 | 209 | return sample_weights 210 | 211 | 212 | def get_perm_c_index_quantile(event, time, n_perm=1000, q=0.95): 213 | """ 214 | Gets the qth quantile from the permutation distribution of the c-index. 215 | 216 | Parameters 217 | ---------- 218 | event : array-like, shape = (n_samples,) 219 | Boolean array denotes whether an event occurred. 220 | 221 | time : array-like, shape = (n_samples,) 222 | Array containing the time of an event or time of censoring. 223 | 224 | n_perm: int 225 | Number of permutation samples to draw. 226 | 227 | q : array_like of float 228 | Quantile or sequence of quantiles to compute, which must be between 0 and 1 inclusive. 229 | 230 | Output 231 | ------ 232 | quantiles: float 233 | The qth quantile of the permutation distribution. 234 | """ 235 | 236 | perm_samples = [] 237 | random_estimate = np.arange(len(time)) 238 | for _ in range(n_perm): 239 | 240 | # randomly permuted estimate! 
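        # Permuting the risk estimates breaks any association with the
        # observed (event, time) pairs, so the c-indices computed below form
        # a null ("chance level") distribution; the q-th quantile of this
        # distribution is returned as a significance cutoff for the c-index.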
241 | random_estimate = np.random.permutation(random_estimate) 242 | 243 | ci_perm = concordance_index_censored(event_indicator=event, 244 | event_time=time, 245 | estimate=random_estimate)[0] 246 | 247 | perm_samples.append(ci_perm) 248 | 249 | return np.quantile(a=perm_samples, q=q) 250 | -------------------------------------------------------------------------------- /var_pool/script_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from pprint import pformat 4 | import numpy as np 5 | 6 | from var_pool.nn.utils import get_network_summary 7 | 8 | 9 | def parse_mil_task_yaml(fpath): 10 | """ 11 | Parameters 12 | ---------- 13 | fpath: str 14 | Path to yaml file containging task information. 15 | 16 | Output 17 | ------ 18 | feat_dir, y_fpath, train_dir, task 19 | """ 20 | assert os.path.exists(fpath), 'No file named {} found'.format(fpath) 21 | 22 | with open(fpath) as file: 23 | data = yaml.safe_load(file) 24 | 25 | return data['feats_dir'], data['y_fpath'], data['train_dir'], data['task'] 26 | 27 | 28 | def write_training_summary(fpath, task, y_split, args, 29 | runtime, val_loss, val_metrics, 30 | train_metrics, train_loss, 31 | model, loss_func, 32 | n_bag_feats, n_epoch_completed, 33 | train_bag_size_summary=None, 34 | epochs_with_records=None): 35 | """ 36 | Writes a summary of training a neural network. 37 | 38 | Parameters 39 | ---------- 40 | fpath: str 41 | Filepath for where to save the text file. 42 | 43 | Output 44 | ------ 45 | summary: str 46 | The summary string that was saved to disk. 47 | """ 48 | with open(fpath, 'w') as f: 49 | f.write("Final validation loss: {:1.3f}".format(val_loss)) 50 | f.write("\nFinal validation metrics\n") 51 | f.write(pformat(val_metrics)) 52 | 53 | f.write("\n\nFinal train loss: {:1.3f}".format(train_loss)) 54 | f.write("\nFinal train metrics\n") 55 | f.write(pformat(train_metrics)) 56 | 57 | if epochs_with_records is not None: 58 | epochs_with_records = np.array(epochs_with_records) + 1 59 | f.write("\nRecords set on epochs: {}\n".format(epochs_with_records)) 60 | 61 | runtime_min = runtime / 60 62 | f.write("\n\nTraining model took {:1.2f} minutes " 63 | "({:1.2f} minutes per epoch)". 64 | format(runtime_min, runtime_min / n_epoch_completed)) 65 | f.write("\nTraining completed {}/{} epochs". 66 | format(n_epoch_completed, args.n_epochs)) 67 | 68 | f.write("\n\ntask = {}\n".format(task)) 69 | for k, y in y_split.items(): 70 | f.write("Number of {} samples = {}\n".format(k, len(y))) 71 | 72 | f.write("\nNumber of bag features {}".format(n_bag_feats)) 73 | 74 | if train_bag_size_summary is not None: 75 | f.write("\nTraining bag size summary \n") 76 | f.write(pformat(train_bag_size_summary)) 77 | 78 | f.write("\n\nargs=") 79 | f.write(str(args)) 80 | 81 | f.write('\n\n') 82 | f.write(str(loss_func)) 83 | f.write('\n\n') 84 | f.write(get_network_summary(model)) 85 | 86 | # return the summary we just saved 87 | with open(fpath, 'r') as f: 88 | summary = f.read() 89 | return summary 90 | 91 | 92 | def write_test_summary(fpath, task, split, eval_loss, eval_metrics, y_eval, 93 | n_bag_feats, eval_bag_size_summary, loss_func, model): 94 | """ 95 | Writes a summary of the test results. 96 | 97 | Parameters 98 | ---------- 99 | fpath: str 100 | Filepath for where to save the text file. 101 | 102 | Output 103 | ------ 104 | summary: str 105 | The summary string that was saved to disk. 
106 | """ 107 | with open(fpath, 'w') as f: 108 | f.write("Final {} loss: {:1.3f}".format(split, eval_loss)) 109 | f.write("\nFinal {} metrics\n".format(split)) 110 | f.write(pformat(eval_metrics)) 111 | 112 | f.write("\n\ntask = {}\n".format(task)) 113 | f.write("{} {} samples\n".format(len(y_eval), split)) 114 | 115 | f.write("\nNumber of bag features {}".format(n_bag_feats)) 116 | f.write("\n{} bag size summary \n") 117 | f.write(pformat(eval_bag_size_summary)) 118 | 119 | f.write('\n\n') 120 | f.write(str(loss_func)) 121 | f.write('\n\n') 122 | f.write(get_network_summary(model)) 123 | 124 | # return the summary we just saved 125 | with open(fpath, 'r') as f: 126 | summary = f.read() 127 | return summary 128 | 129 | 130 | def descr_stats(values): 131 | """ 132 | Returns a dict of descriptive statistics (mean, min, max, etc) of an array of values. 133 | """ 134 | values = np.array(values).reshape(-1) 135 | 136 | return {'mean': np.mean(values), 137 | 'median': np.median(values), 138 | 'std': np.std(values), 139 | 'min': np.min(values), 140 | 'max': np.max(values), 141 | 'num': len(values) 142 | } 143 | -------------------------------------------------------------------------------- /var_pool/utils.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | import numpy as np 3 | 4 | 5 | def format_command_line_args(kws={}, flags=[]): 6 | """ 7 | Parameters 8 | ---------- 9 | kws: dict 10 | Arguments that look like --key value. If a list is provided it will be input as --key value_1 value_2, ... 11 | 12 | flags: list of str 13 | Arguments that look like --flag 14 | 15 | Output 16 | ------ 17 | args: str 18 | The argument string 19 | """ 20 | 21 | args = '' 22 | for key, value in kws.items(): 23 | 24 | args += ' --{}'.format(key) 25 | 26 | if np.isscalar(value): 27 | args += ' {}'.format(value) 28 | else: 29 | for v in value: 30 | args += ' {}'.format(v) 31 | 32 | for flag in flags: 33 | args += ' --{}'.format(flag) 34 | 35 | return args 36 | 37 | 38 | def get_counts_and_props(y, n_classes=None, class_names=None): 39 | """ 40 | Gets the counts and proportions for each class from a vector of class index predictions. 41 | 42 | Parameters 43 | ---------- 44 | y: array-like, (n_samples, ) 45 | The predicted class indices. 46 | 47 | n_classes: None, int 48 | (Optional) The number of class labels; if not provided will try to guess by either len(class_names) (if provided) or max(y) + 1. 49 | 50 | class_names: None, str, list of str 51 | (Optional) If provided will return the counts/props as a dict with the class names as keys. If set to 'default' then will name the classes class_0, class_1, ... If None, will return the counts/props as lists 52 | 53 | Output 54 | ------ 55 | counts, props 56 | 57 | if class_names==None 58 | counts: array-like, (n_classes, ) 59 | The counts for each class. 60 | 61 | props: array-like, (n_classes, ) 62 | The proportions for each class. 63 | 64 | if class_names is provided then counts/props will be a dicts 65 | """ 66 | y = np.array(y).reshape(-1) 67 | 68 | if n_classes is None: 69 | if class_names is not None and not isinstance(class_names, str): 70 | n_classes = len(class_names) 71 | 72 | else: 73 | # +1 bc of zero indexing! 
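            # e.g. y = [0, 2, 1, 2] -> max(y) + 1 = 3 classes; note this
            # guess undercounts if the largest class index never appears in
            # y, in which case n_classes should be passed explicitly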
74 | n_classes = max(y) + 1 75 | 76 | ######################## 77 | # compute counts/props # 78 | ######################## 79 | counts = np.zeros(n_classes) 80 | for idx, cnt in Counter(y).items(): 81 | counts[idx] = cnt 82 | 83 | props = counts / len(y) 84 | 85 | ################# 86 | # Format output # 87 | ################# 88 | if class_names is None: 89 | return counts, props 90 | 91 | if isinstance(class_names, str) and class_names == 'default': 92 | class_names = ['class_' + str(i) for i in range(n_classes)] 93 | 94 | counts_dict = {} 95 | props_dict = {} 96 | for cl_idx, name in enumerate(class_names): 97 | 98 | counts_dict[name] = counts[cl_idx] 99 | props_dict[name] = props[cl_idx] 100 | 101 | return counts_dict, props_dict 102 | 103 | 104 | def get_traceback(e): 105 | """ 106 | Returns the traceback from any expection. 107 | Parameters 108 | ---------- 109 | e: BaseException 110 | Any excpetion. 111 | Output 112 | ------ 113 | str 114 | """ 115 | # https://stackoverflow.com/questions/3702675/how-to-catch-and-print-the-full-exception-traceback-without-halting-exiting-the 116 | import traceback 117 | return ''.join(traceback.format_exception(None, e, e.__traceback__)) 118 | -------------------------------------------------------------------------------- /var_pool/viz/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/viz/.DS_Store -------------------------------------------------------------------------------- /var_pool/viz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mahmoodlab/varpool/2e40ab4d9a129e430fdb7dcf3304f55388f2ad88/var_pool/viz/__init__.py -------------------------------------------------------------------------------- /var_pool/viz/top_attn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import h5py 4 | from pathlib import Path 5 | import os 6 | import torch 7 | from openslide import open_slide 8 | 9 | from var_pool.file_utils import get_file_names 10 | from var_pool.viz.utils import read_region, make_image_grid, get_idxs_largest 11 | 12 | 13 | def viz_top_attn_patches(model, wsi_fpath, h5_fpath, autogen_fpath, 14 | device=None, 15 | n_top_patches=10, save_fpath=None): 16 | """ 17 | Visualizes the top attended patches. Handles the case when a patient has multiple WSIs. 18 | 19 | Parameteres 20 | ----------- 21 | model: nn.Module 22 | The model. 23 | 24 | wsi_fpath: str, list or str 25 | Path(s) to WSI image. 26 | 27 | h5_fpath: str, list of str 28 | Path(s) to hdf5 file containing the features and coordinates. 29 | 30 | autogen_fpath: str. 31 | Path to autogen file containint image metadata needed to load the patches. 32 | 33 | device: None, device 34 | The device to put the tensors on. 35 | 36 | save_fpath: None, str 37 | (Optional) Path to save the top attneded patch grid. 38 | 39 | Output 40 | ------ 41 | top_attn_patches: PIL.Image 42 | The top attended patches concatenated together in a grid. 
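    Notes
    -----
    The model is assumed to expose an enc_and_attend(bag) method that returns
    per-patch attention scores (as the attention-based MIL architectures in
    var_pool/nn/arch are expected to). A minimal usage sketch; the file paths
    below are illustrative, and the autogen csv is the patch-extraction
    metadata file written by CLAM:

    >>> grid = viz_top_attn_patches(model,
    ...                             wsi_fpath='slide_1.svs',
    ...                             h5_fpath='slide_1.h5',
    ...                             autogen_fpath='process_list_autogen.csv',
    ...                             n_top_patches=10,
    ...                             save_fpath='top_attn_patches.png')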
43 | """ 44 | 45 | autogen = pd.read_csv(autogen_fpath, index_col='slide_id') 46 | feats, coords, patch_ids = load_patient_patch_data(wsi_fpath, h5_fpath) 47 | model.eval() 48 | 49 | # Compute attention scores 50 | bag = torch.from_numpy(feats).unsqueeze(0) 51 | if device is not None: 52 | bag = bag.to(device) 53 | with torch.no_grad(): 54 | attn_scores, _ = model.enc_and_attend(bag) 55 | attn_scores = attn_scores.detach().cpu().numpy().squeeze() 56 | 57 | # get top attended patches 58 | idxs_top_attn = get_idxs_largest(attn_scores, k=n_top_patches) 59 | 60 | # Load each of the top patches 61 | top_patches = [] 62 | for patient_level_idx in idxs_top_attn: 63 | 64 | # get the WSI/patch coordinate for this patch 65 | wsi_idx, patch_idx = patch_ids[patient_level_idx] 66 | 67 | wsi = open_slide(wsi_fpath[wsi_idx]) 68 | 69 | # metadata needed to load the patches 70 | wsi_fname = get_file_names(wsi_fpath[wsi_idx]) + '.svs' 71 | patch_level, patch_size, custom_downsample = \ 72 | autogen.loc[wsi_fname]\ 73 | [['patch_level', 'patch_size', 'custom_downsample']] 74 | 75 | location = coords[wsi_idx][patch_idx, :] 76 | 77 | patch = read_region(wsi=wsi, 78 | location=location, 79 | level=patch_level, 80 | patch_size=patch_size, 81 | custom_downsample=custom_downsample) 82 | 83 | top_patches.append(patch) 84 | 85 | # concatenated top patches together 86 | top_patches = make_image_grid(top_patches, pad=3, n_cols=10) 87 | 88 | # maybe save the image 89 | if save_fpath is not None: 90 | save_dir = Path(save_fpath).parent 91 | os.makedirs(save_dir, exist_ok=True) 92 | 93 | top_patches.save(save_fpath) 94 | 95 | return top_patches 96 | 97 | 98 | def load_patient_patch_data(wsi_fpath, h5_fpath): 99 | """ 100 | Loads the patch features and coordinates for a given patient who may have multiple WSIs. 101 | 102 | Parameters 103 | ---------- 104 | wsi_fpath: str, list or str 105 | Path(s) to WSI image. 106 | 107 | h5_fpath: str, list of str 108 | Path(s) to hdf5 file containing the features and coordinates. 109 | 110 | Output 111 | ----- 112 | feats, coords, patch_ids 113 | 114 | feats: array like, (n_patches_tot, n_feats) 115 | The patch features for each WSI. 116 | 117 | coords: list of array-like 118 | The patch coordinates for each WSI. 119 | 120 | patch_ids: array-like, (n_patches_tot, 2) 121 | The first column identifies which WSI the patch belongs to, the second column identifies which patch it is. 
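    Notes
    -----
    For a patient with two slides the outputs line up as sketched below
    (paths and shapes are illustrative; the feature dimension is 1024 for the
    default CLAM encoder):

    >>> feats, coords, patch_ids = load_patient_patch_data(
    ...     wsi_fpath=['slide_a.svs', 'slide_b.svs'],
    ...     h5_fpath=['slide_a.h5', 'slide_b.h5'])
    >>> feats.shape       # (n_patches_a + n_patches_b, 1024)
    >>> patch_ids[:, 0]   # 0 for slide_a patches, 1 for slide_b patches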
122 | """ 123 | 124 | # ensure wsi_fpath/h5_fpath are lists of str 125 | if isinstance(wsi_fpath, str): 126 | wsi_fpath = [wsi_fpath] 127 | h5_fpath = [h5_fpath] 128 | assert len(wsi_fpath) == len(h5_fpath) 129 | n_wsis = len(wsi_fpath) 130 | 131 | # Load feats/coords for each WSI 132 | coords = [] 133 | feats = [] 134 | patch_ids = [] 135 | for wsi_idx in range(n_wsis): 136 | 137 | # Load coords/instance features for this image 138 | with h5py.File(h5_fpath[wsi_idx], 'r') as hdf5_file: 139 | coords.append(np.array(hdf5_file['coords'])) 140 | feats.append(np.array(hdf5_file['features'])) 141 | 142 | # patch_id = (wsi_idx, patch_idx) 143 | n_patches = hdf5_file['features'].shape[0] 144 | patch_ids.append(np.vstack([np.repeat(wsi_idx, repeats=n_patches), 145 | np.arange(n_patches)]).T) 146 | 147 | patch_ids = np.vstack(patch_ids) 148 | feats = np.vstack(feats) 149 | 150 | return feats, coords, patch_ids 151 | -------------------------------------------------------------------------------- /var_pool/viz/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | from numbers import Number 5 | 6 | 7 | def read_region(wsi, location, level, patch_size, custom_downsample): 8 | """ 9 | Wrapper for wsi.read_region that handles downsampling. If we want a patch at 20x mag, but the WSI object only has 40x mag available then we read in a patch at 40x that is twice the desired size then downsample it. 10 | 11 | Parameters 12 | ---------- 13 | wsi: openslide.OpenSlide 14 | The WSI object. 15 | 16 | location: tuple of ints 17 | The coordinates to read the patch in at the specified level before possibly downsampling. 18 | 19 | level: int 20 | The level at which to read the patch in from the WSI object. 21 | 22 | patch_size: int, tuple of ints 23 | The patch size to read in before possibliy downsampling. 24 | 25 | custom_downsample: int 26 | Downsample factor. 27 | 28 | Output 29 | ------ 30 | img: PIL.Image.Image 31 | The region image. 32 | """ 33 | 34 | # process read patch size to tuple 35 | if isinstance(patch_size, Number): 36 | _read_patch_size = (int(patch_size), int(patch_size)) 37 | else: 38 | _read_patch_size = tuple(patch_size) 39 | assert len(_read_patch_size) == 2 40 | 41 | # read in raw patch from image 42 | img = wsi.read_region(location=location, 43 | level=level, 44 | size=_read_patch_size 45 | ).convert('RGB') 46 | 47 | # possible resize image 48 | if custom_downsample > 1: 49 | target_patch_size = (_read_patch_size[0] // custom_downsample, 50 | _read_patch_size[1] // custom_downsample) 51 | 52 | img = img.resize(size=target_patch_size) 53 | 54 | return img 55 | 56 | 57 | def make_image_grid(imgs, pad=3, n_cols=10): 58 | """ 59 | Makes a grid of PIL images. 60 | 61 | Parameters 62 | ---------- 63 | imgs: list of PIL imags. 64 | The images to display in a grdie. 65 | 66 | pad: int 67 | Amount of padded to put between images. 68 | 69 | n_cols: int 70 | Maximum number of columns to put in the grid. 71 | 72 | Output 73 | ------ 74 | grid: PIL.Image 75 | The image grid. 
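    Notes
    -----
    Images are laid out row by row with at most n_cols images per row. A
    minimal usage sketch:

    >>> grid = make_image_grid(top_patches, pad=3, n_cols=10)
    >>> grid.save('top_patches.png')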
76 |     """
77 |     n_images = len(imgs)
78 |     n_rows = int(np.ceil(n_images / n_cols))
79 | 
80 |     img_width = max([img.size[0] for img in imgs])   # PIL sizes are (width, height)
81 |     img_height = max([img.size[1] for img in imgs])
82 | 
83 |     grid = Image.new('RGBA',
84 |                      size=(n_cols * img_width + (n_cols - 1) * pad,
85 |                            n_rows * img_height + (n_rows - 1) * pad),
86 |                      color=(255, 255, 255))
87 | 
88 |     for idx in range(n_images):
89 | 
90 |         row_idx = idx // n_cols
91 |         col_idx = idx % n_cols
92 | 
93 |         grid.paste(imgs[idx],
94 |                    box=(col_idx * (img_width + pad),
95 |                         row_idx * (img_height + pad))
96 |                    )
97 |     return grid
98 | 
99 | 
100 | def save_fig(fpath, dpi=100, bbox_inches='tight'):
101 |     """
102 |     Saves and closes a plot.
103 |     """
104 |     plt.savefig(fpath, dpi=dpi, bbox_inches=bbox_inches)
105 |     plt.close()
106 | 
107 | 
108 | def get_idxs_largest(values, k):
109 |     """
110 |     Gets the indices of the k largest elements of an array.
111 | 
112 |     Parameters
113 |     ----------
114 |     values: array-like
115 |         The values.
116 | 
117 |     Output
118 |     ------
119 |     idxs: array-like, shape (k, )
120 |         The idxs of the k largest elements (in no particular order).
121 |     """
122 |     values = np.array(values)
123 |     assert values.ndim == 1
124 |     return np.argpartition(values, -k)[-k:]
125 | 
126 | 
127 | def get_idxs_smallest(values, k):
128 |     """
129 |     Gets the indices of the k smallest elements of an array.
130 | 
131 |     Parameters
132 |     ----------
133 |     values: array-like
134 |         The values.
135 | 
136 |     Output
137 |     ------
138 |     idxs: array-like, shape (k, )
139 |         The idxs of the k smallest elements (in no particular order).
140 |     """
141 |     values = np.array(values)
142 |     assert values.ndim == 1
143 |     return np.argpartition(values, k)[:k]
144 | 
--------------------------------------------------------------------------------