├── .gitignore ├── LICENSE ├── README.md ├── evaluation ├── CorrelationStats.py ├── __init__.py ├── evaluate_model.py └── predict_independent_dataset.py ├── examples ├── gene_list.csv └── ref_file.csv ├── images ├── overview_new.png ├── seq-logo.png └── sequoia-logo.png ├── pre_processing ├── __init__.py ├── compute_features_hdf5.py ├── download_RNASeq_TCGAbiolinks.R ├── kmean_features.py ├── patch_gen_hdf5.py ├── patient_splits.zip └── test_wsis.pkl ├── requirements.txt ├── scripts ├── extract_kmean_features.sh ├── extract_patch.sh ├── extract_resnet_features.sh ├── run_he2rna.sh ├── run_train.sh └── run_visualize.sh ├── spatial_vis ├── __init__.py ├── gbm_celltype_analysis.py ├── get_emd.py └── visualize.py └── src ├── __init__.py ├── folds ├── test-blca-0.npy ├── test-blca-1.npy ├── test-blca-2.npy ├── test-blca-3.npy ├── test-blca-4.npy ├── test-brca-0.npy ├── test-brca-1.npy ├── test-brca-2.npy ├── test-brca-3.npy ├── test-brca-4.npy ├── test-coad-0.npy ├── test-coad-1.npy ├── test-coad-2.npy ├── test-coad-3.npy ├── test-coad-4.npy ├── test-gbm-0.npy ├── test-gbm-1.npy ├── test-gbm-2.npy ├── test-gbm-3.npy ├── test-gbm-4.npy ├── test-hnsc-0.npy ├── test-hnsc-1.npy ├── test-hnsc-2.npy ├── test-hnsc-3.npy ├── test-hnsc-4.npy ├── test-kirc-0.npy ├── test-kirc-1.npy ├── test-kirc-2.npy ├── test-kirc-3.npy ├── test-kirc-4.npy ├── test-kirp-0.npy ├── test-kirp-1.npy ├── test-kirp-2.npy ├── test-kirp-3.npy ├── test-kirp-4.npy ├── test-lihc-0.npy ├── test-lihc-1.npy ├── test-lihc-2.npy ├── test-lihc-3.npy ├── test-lihc-4.npy ├── test-luad-0.npy ├── test-luad-1.npy ├── test-luad-2.npy ├── test-luad-3.npy ├── test-luad-4.npy ├── test-lusc-0.npy ├── test-lusc-1.npy ├── test-lusc-2.npy ├── test-lusc-3.npy ├── test-lusc-4.npy ├── test-paad-0.npy ├── test-paad-1.npy ├── test-paad-2.npy ├── test-paad-3.npy ├── test-paad-4.npy ├── test-prad-0.npy ├── test-prad-1.npy ├── test-prad-2.npy ├── test-prad-3.npy ├── test-prad-4.npy ├── test-skcm-0.npy ├── test-skcm-1.npy ├── test-skcm-2.npy ├── test-skcm-3.npy ├── test-skcm-4.npy ├── test-stad-0.npy ├── test-stad-1.npy ├── test-stad-2.npy ├── test-stad-3.npy ├── test-stad-4.npy ├── test-thca-0.npy ├── test-thca-1.npy ├── test-thca-2.npy ├── test-thca-3.npy ├── test-thca-4.npy ├── test-ucec-0.npy ├── test-ucec-1.npy ├── test-ucec-2.npy ├── test-ucec-3.npy └── test-ucec-4.npy ├── he2rna.py ├── main.py ├── pretrain_gtex.py ├── read_data.py ├── resnet.py ├── tformer_lin.py ├── utils.py └── vit.py /.gitignore: -------------------------------------------------------------------------------- 1 | env/ 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 GevaertLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | 6 | # :evergreen_tree: SEQUOIA: Digital profiling of cancer transcriptomes with linearized attention 7 | 8 | **Abstract** 9 | 10 | _Cancer is a heterogeneous disease requiring costly genetic profiling for better understanding and management. Recent advances in deep learning have enabled cost-effective predictions of genetic alterations from whole slide images (WSIs). While transformers have driven significant progress in non-medical domains, their application to WSIs lags behind due to high model complexity and limited dataset sizes. Here, we introduce SEQUOIA, a linearized transformer model that predicts cancer transcriptomic profiles from WSIs. SEQUOIA is developed using 7,584 tumor samples across 16 cancer types, with its generalization capacity validated on two independent cohorts comprising 1,368 tumors. Accurately predicted genes are associated with key cancer processes, including inflammatory response, cell cycles and metabolism. Further, we demonstrate the value of SEQUOIA in stratifying the risk of breast cancer recurrence and in resolving spatial gene expression at loco-regional levels. SEQUOIA hence deciphers clinically relevant information from WSIs, opening avenues for personalized cancer management._ 11 | 12 | **Overview** 13 |

14 | 15 |

16 | 17 | ## Folder structure 18 | 19 | - `scripts`: example bash (driver) scripts to run the pre-processing, training and evaluation. 20 | - `examples`: example input files. 21 | - `pre-processing`: pre-processing scripts. 22 | - `evaluation`: evaluation scripts and output gene list ordered by index. 23 | - `spatial_vis`: scripts for generating spatial predictions of gene expression values. 24 | - `src`: main files for models and training. 25 | 26 | ## System requirements 27 | 28 | Software dependencies and versions are listed in requirements.txt 29 | 30 | ## Installation 31 | 32 | First, clone this git repository: `git clone https://github.com/gevaertlab/sequoia-pub.git` 33 | 34 | Then, create a conda environment: `conda create -n sequoia python=3.9` and activate: `conda activate sequoia` 35 | 36 | Install the openslide library: `conda install -c conda-forge openslide==4.0.0` 37 | 38 | Install the required package dependencies: `pip install -r requirements.txt` 39 | 40 | Finally, install [Openslide](https://openslide.org/download/) (>v3.4.0) 41 | 42 | Expected installation time in normal Linux environment: 15 mins 43 | 44 | ## Pre-processing 45 | 46 | Scripts for pre-processing are located in the `pre-processing` folder. 47 | 48 | ## Construct reference file 49 | All computational processes requires a *reference.csv* file, which has one row per WSI and their corresponding gene expression values. The RNA columns are named with the following format 'rna_{GENENAME}'. An optional 'tcga_project' column indicates the TCGA project that data belongs to. See `examples/ref_file.csv` for an example. A complete list of genes (columns) used for the model developed and evaluation in our paper can be found in `examples/gene_list.csv`. 50 | 51 | For inference using pre-training model weights, simply set all gene expression value in the table to 0. 52 | For model training, we need to replace 0 with real gene expression data. 53 | 54 | ### Step 1: Patch extraction 55 | 56 | To extract patches from whole-slide images (WSIs), please use the script `patch_gen_hdf5.py`. 57 | An example script to run the patch extraction: `scripts/extract_patch.sh` 58 | 59 | Note, the ```--start``` and ```--end``` parameters indicate the rows (WSIs) in the *reference.csv* file that need to be extracted. This is useful to execute the script in parallel. 60 | 61 | ### Step 2: Obtain resnet/uni features 62 | 63 | To obtain resnet/uni features from patches, please use the script `compute_features_hdf5.py`. The script converts each patch into a linear feature vector. 64 | 65 | Note: if you use the UNI model, you need to follow the installation procedure in the original [github](https://github.com/mahmoodlab/UNI) and install the necessary [required packages](https://github.com/mahmoodlab/UNI/blob/main/setup.py). 66 | 67 | An example script to run the patch extraction: `scripts/extract_resnet_features.sh` 68 | 69 | ### Step 3: Obtain k-Means features 70 | 71 | The next step once the resnet/uni features have been obtained is to compute the 100 clusters used as input for the model. They are computed per slide, so it is pretty straightforward, and it is pretty fast. 72 | 73 | An example script to run the patch extraction: `scripts/extract_kmean_features.sh` 74 | 75 | - Outputs from Step 2 and Step 3: 76 | *features* folder, this contains for each WSI a HDF5 file that stores both the features obtained using the resnet/uni (inside the **resnet_features** or **uni_features** dataset) as well as the output from the K-means algorithm (inside **cluster_features** dataset). 77 | 78 | Expected run time: depend on the hardware (CPU/GPU) and the number of slides 79 | 80 | ## Running evaluation of our pre-trained model on an independent dataset 81 | 82 | After WSI pre-processing, our pre-trained SEQUOIA model (UNI features and linearized transformer aggregation) can be evaluated on the WSIs of an independent dataset by running ``evaluation/predict_independent_dataset.py``. 83 | 84 | We released the weights for each cancer type, from each of the five folds on [HuggingFace](https://huggingface.co/gevaertlab), so make sure to login (See [HuggingFace docs](https://huggingface.co/docs/huggingface_hub/en/quick-start#login-command) for more information): 85 | 86 | ``` 87 | from huggingface_hub import login 88 | login() 89 | ``` 90 | 91 | The gene names corresponding to the output can be found in the `evaluation/gene_list.csv` file. 92 | 93 | ## Pre-training, fine-tunning and loading pre-trained weights 94 | 95 | ### Step 1 (Optional): pretrain models on the GTEx data 96 | 97 | To pretrain the weights of the model on normal tissues, please use the script `pretrain_gtex.py`. The process requires an input *reference.csv* file, indicating the gene expression values for each WSI. See `examples/ref_file.csv` for an example. 98 | 99 | ### Step 2 (Optional): load the same train/validation/test splits that we used 100 | 101 | The TCGA splits for each fold are available in the `patient_splits.zip` file in the [pre_processing](https://github.com/gevaertlab/sequoia-pub/blob/master/pre_processing/patient_splits.zip) folder. 102 | 103 | To load the splits from the numpy file, unzip the `patient_splits.zip` folder. To use: 104 | 105 | ``` 106 | split = np.load(f'TCGA-{cancer}.npy'), allow_pickle=True).item() 107 | for i in range(5): 108 | train_patients = split[f'fold_{i}']['train'] 109 | val_patients = split[f'fold_{i}']['val'] 110 | test_patients = split[f'fold_{i}']['test'] 111 | ``` 112 | 113 | Note that these contain only the patient ID, not the entire WSI filename. The WSI file names within each test fold are available in `test_wsis.pkl` in the same [pre_processing](https://github.com/gevaertlab/sequoia-pub/blob/master/pre_processing/) folder. To use: 114 | 115 | ``` 116 | with open('test_wsis.pkl','rb') as f: 117 | data = pickle.load(f) 118 | test_wsis = data[f'{cancer}']['split_{i}'] 119 | ``` 120 | 121 | Concatenating all the WSIs from a particular cancer type across all the folds results in all the WSI IDs that were used for that cancer type. So to find the exact WSI filenames used in the train/validation split from fold 0, match the patient IDs from `train_patients` and `val_patients` above to the WSI IDs across folds 1-4 in `test_wsis.pkl`: 122 | 123 | ``` 124 | train_patients = split['fold_0']['train'] 125 | val_patients = split['fold_0']['val'] 126 | wsis = np.concatenate([data['brca'][f'split_{i}']['wsi_file_name'] for i in range(1,5)]) 127 | train_wsis = [i for i in wsis if '-'.join(i.split('-')[:3]) in train_patients] 128 | val_wsis = [i for i in wsis if '-'.join(i.split('-')[:3]) in val_patients] 129 | 130 | ``` 131 | 132 | ### Step 3 (Optional): load published model checkpoint 133 | 134 | As mentioned above, our pre-trained checkpoint weights for SEQUOIA are available on [HuggingFace](https://huggingface.co/gevaertlab). Patients that were present in the test set in each fold can be found in the [pre_processing](https://github.com/gevaertlab/sequoia-pub/blob/master/pre_processing/) as explained in Step 2 above. Make sure to login to HuggingFace (see above). 135 | 136 | Then use: 137 | ``` 138 | from src.tformer_lin import ViS 139 | 140 | cancer = 'brca' 141 | i = 0 ## fold number 142 | model = ViS.from_pretrained(f"gevaertlab/sequoia-{cancer}-{i}") 143 | ``` 144 | The gene names corresponding to the output can be found in the `evaluation/gene_list.csv` file. 145 | 146 | ### Step 4: Train or fine-tune SEQUOIA on the TCGA data 147 | 148 | Now we can train the model from scratch or fine-tune it on the TCGA data. Here is an example bash script to run the process: `scripts/run_train.sh` 149 | 150 | The parameters are explained within the `main.py` file. 151 | 152 | Some points that we want to emphasize: 153 | - If you pre-trained on a dataset that contains a different number of genes than the finetuning dataset, you need to set the ```--change_num_genes``` parameter to 1 and specify in the ```--num_genes``` parameter how many genes were used for pretraining. To indicate the path to the pretrained weights, use the ```--checkpoint``` parameters. 154 | - ```--model_type``` is used to define the aggregation type. For the SEQUOIA model (linearized transformer aggregation) use 'vis'. 155 | 156 | ## Benchmarking 157 | 158 | For running the benchmarked variations of the architecture: 159 | - MLP aggregation: for this part we made use of the implementation from HE2RNA, which can be found in `he2rna.py`. An example run script is provided in `scripts/run_he2rna.sh` 160 | - transformer aggregation: this model type is implemented in the `main.py`. use ```--model_type``` 'vit'. 161 | 162 | 163 | ## Evaluation 164 | 165 | Pearson correlation and RMSE values are calculated to compare the predicted gene expression values to the ground truth. The significantly well predicted genes are selected using correlation coefficient, p value, rmse, and by statistical comparisons to an untrained model with the same architecture. 166 | 167 | Evaluation script: `evaluation/evaluate_model.py`. Output: three dataframes `all_genes.csv`: contains evaluation metrics for all genes, `sig_genes.csv`: metrics for only the significant genes and `num_sig_genes.csv` contains the number of significant genes per cancer type with this model. 168 | 169 | ## Spatial gene expression predictions 170 | 171 | Scripts for predicting spatial gene expression levels within the same tissue slide are wrapped in: `spatial_vis` 172 | 173 | - ```visualize.py``` is the file to generate spatial predictions made with a saved SEQUOIA model. 174 | - the arguments are explained in the file. an example run file is provided in `scripts/run_visualize.sh` 175 | - output: the output is a dataframe that contains the following columns: 176 | ``` 177 | - xcoord: the x coordinate of a tile (absolute position of tile in the WSI -- note that adjacent tiles will have coordinates that are tile_width apart!) 178 | - ycoord: same as xcoord for the y 179 | - xcoord_tf: the x coordinate of a tile when transforming the original coordinates to start in the left upper corner at position x=0,y=0 and with distance 1 between tiles (i.e. next tile has coordinate x=1,y=0) 180 | - ycoord_tf: same as xcoord_tf for the y 181 | - gene_{x}: for each gene, there will be a column 'gene_{x}' that contains the spatial prediction for that gene of the model from fold {x}, with x = 1..num_folds 182 | - gene: for each gene there will also be a column without the _{x} part, which represents the average across the used folds 183 | ``` 184 | - ```get_emd.py``` contains code to calculate the two dimensional Earth Mover's Distance between a prediction map (generated with ```visualize.py``` script) and ground truth spatial transcriptomics. 185 | - ```gbm_celltype_analysis.py``` contains (1) code to examine spatial co-expression of genes for the four meta-modules described in the paper; (2) code to visualize spatial organization of meta-modules on the considered slides. 186 | 187 | 188 | # License 189 | 190 | © [Gevaert's Lab](https://med.stanford.edu/gevaertlab.html) MIT License 191 | 192 | 193 | 194 | -------------------------------------------------------------------------------- /evaluation/CorrelationStats.py: -------------------------------------------------------------------------------- 1 | """ 2 | Functions for calculating the statistical significant differences between two dependent or independent correlation 3 | coefficients. 4 | The Fisher and Steiger method is adopted from the R package http://personality-project.org/r/html/paired.r.html 5 | and is described in detail in the book 'Statistical Methods for Psychology' 6 | The Zou method is adopted from http://seriousstats.wordpress.com/2012/02/05/comparing-correlations/ 7 | Credit goes to the authors of above mentioned packages! 8 | 9 | Author: Philipp Singer (www.philippsinger.info) 10 | """ 11 | 12 | from __future__ import division 13 | 14 | __author__ = 'psinger' 15 | 16 | import numpy as np 17 | from scipy.stats import t, norm 18 | from math import atanh, pow 19 | from numpy import tanh 20 | 21 | def rz_ci(r, n, conf_level = 0.95): 22 | zr_se = pow(1/(n - 3), .5) 23 | moe = norm.ppf(1 - (1 - conf_level)/float(2)) * zr_se 24 | zu = atanh(r) + moe 25 | zl = atanh(r) - moe 26 | return tanh((zl, zu)) 27 | 28 | def rho_rxy_rxz(rxy, rxz, ryz): 29 | num = (ryz-1/2.*rxy*rxz)*(1-pow(rxy,2)-pow(rxz,2)-pow(ryz,2))+pow(ryz,3) 30 | den = (1 - pow(rxy,2)) * (1 - pow(rxz,2)) 31 | return num/float(den) 32 | 33 | def dependent_corr(xy, xz, yz, n, twotailed=True, conf_level=0.95, method='steiger'): 34 | """ 35 | Calculates the statistic significance between two dependent correlation coefficients 36 | @param xy: correlation coefficient between x and y 37 | @param xz: correlation coefficient between x and z 38 | @param yz: correlation coefficient between y and z 39 | @param n: number of elements in x, y and z 40 | @param twotailed: whether to calculate a one or two tailed test, only works for 'steiger' method 41 | @param conf_level: confidence level, only works for 'zou' method 42 | @param method: defines the method uses, 'steiger' or 'zou' 43 | @return: t and p-val 44 | """ 45 | if method == 'steiger': 46 | d = xy - xz 47 | determin = 1 - xy * xy - xz * xz - yz * yz + 2 * xy * xz * yz 48 | av = (xy + xz)/2 49 | cube = (1 - yz) * (1 - yz) * (1 - yz) 50 | 51 | t2 = d * np.sqrt((n - 1) * (1 + yz)/(((2 * (n - 1)/(n - 3)) * determin + av * av * cube))) 52 | p = 1 - t.cdf(abs(t2), n - 3) 53 | if twotailed: 54 | p *= 2 55 | return t2, p 56 | 57 | elif method == 'zou': 58 | L1 = rz_ci(xy, n, conf_level=conf_level)[0] 59 | U1 = rz_ci(xy, n, conf_level=conf_level)[1] 60 | L2 = rz_ci(xz, n, conf_level=conf_level)[0] 61 | U2 = rz_ci(xz, n, conf_level=conf_level)[1] 62 | rho_r12_r13 = rho_rxy_rxz(xy, xz, yz) 63 | lower = xy - xz - pow((pow((xy - L1), 2) + pow((U2 - xz), 2) - 2 * rho_r12_r13 * (xy - L1) * (U2 - xz)), 0.5) 64 | upper = xy - xz + pow((pow((U1 - xy), 2) + pow((xz - L2), 2) - 2 * rho_r12_r13 * (U1 - xy) * (xz - L2)), 0.5) 65 | return lower, upper 66 | else: 67 | raise Exception('Wrong method!') 68 | 69 | def independent_corr(xy, ab, n, n2 = None, twotailed=True, conf_level=0.95, method='fisher'): 70 | """ 71 | Calculates the statistic significance between two independent correlation coefficients 72 | @param xy: correlation coefficient between x and y 73 | @param xz: correlation coefficient between a and b 74 | @param n: number of elements in xy 75 | @param n2: number of elements in ab (if distinct from n) 76 | @param twotailed: whether to calculate a one or two tailed test, only works for 'fisher' method 77 | @param conf_level: confidence level, only works for 'zou' method 78 | @param method: defines the method uses, 'fisher' or 'zou' 79 | @return: z and p-val 80 | """ 81 | 82 | if method == 'fisher': 83 | xy_z = 0.5 * np.log((1 + xy)/(1 - xy)) 84 | ab_z = 0.5 * np.log((1 + ab)/(1 - ab)) 85 | if n2 is None: 86 | n2 = n 87 | 88 | se_diff_r = np.sqrt(1/(n - 3) + 1/(n2 - 3)) 89 | diff = xy_z - ab_z 90 | z = abs(diff / se_diff_r) 91 | p = (1 - norm.cdf(z)) 92 | if twotailed: 93 | p *= 2 94 | 95 | return z, p 96 | elif method == 'zou': 97 | L1 = rz_ci(xy, n, conf_level=conf_level)[0] 98 | U1 = rz_ci(xy, n, conf_level=conf_level)[1] 99 | L2 = rz_ci(ab, n2, conf_level=conf_level)[0] 100 | U2 = rz_ci(ab, n2, conf_level=conf_level)[1] 101 | lower = xy - ab - pow((pow((xy - L1), 2) + pow((U2 - ab), 2)), 0.5) 102 | upper = xy - ab + pow((pow((U1 - xy), 2) + pow((ab - L2), 2)), 0.5) 103 | return lower, upper 104 | else: 105 | raise Exception('Wrong method!') 106 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/evaluate_model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | import pickle as pl 5 | import pdb 6 | import os 7 | 8 | from sklearn.metrics import mean_squared_error 9 | from statsmodels.stats.multitest import fdrcorrection 10 | from scipy import stats 11 | 12 | import seaborn as sns 13 | import matplotlib.pyplot as plt 14 | 15 | from evaluation.CorrelationStats import dependent_corr 16 | 17 | if __name__=='__main__': 18 | 19 | model_dir = 'model_path' 20 | folds = 5 21 | cancers = ['brca', 'coad', 'gbm', 'kirp', 'kirc', 'luad', 'lusc', 'paad', 22 | 'prad', 'skcm', 'thca', 'ucec', 'hnsc', 'stad', 'blca', 'lihc'] 23 | 24 | save_path = os.path.join(model_dir, 'results') 25 | if not os.path.exists(save_path): 26 | os.makedirs(save_path) 27 | 28 | df_list = [] 29 | for cancer_type in cancers: 30 | try: 31 | print(cancer_type) 32 | with open(os.path.join(model_dir, cancer_type, 'test_results.pkl'), 'rb') as f: 33 | test_res = pl.load(f) 34 | 35 | real = [] 36 | pred = [] 37 | random = [] 38 | wsi = [] 39 | genes = test_res['genes'] 40 | 41 | for k in range(folds): 42 | data = test_res[f'split_{k}'] 43 | n_sample = len(data['preds'][:, 0]) 44 | wsi.extend(data['wsi_file_name']) 45 | real.append(pd.DataFrame(data['real'], index = data['wsi_file_name'], columns = genes)) 46 | pred.append(pd.DataFrame(data['preds'], index = data['wsi_file_name'], columns = genes)) 47 | random.append(pd.DataFrame(data['random'], index = data['wsi_file_name'], columns = genes)) 48 | 49 | df_real = pd.concat(real) 50 | df_pred = pd.concat(pred) 51 | df_random = pd.concat(random) 52 | 53 | #Make sure the index (samples) are identical in all the dataframes 54 | assert np.all(df_real.index == df_pred.index) 55 | assert np.all(df_real.index == df_random.index) 56 | 57 | pred_r = [] 58 | random_r = [] 59 | test_p = [] 60 | pearson_p = [] 61 | rmse_pred = [] 62 | rmse_random = [] 63 | rmse_quantile_norm = [] 64 | rmse_mean_norm = [] 65 | valid_genes = [] 66 | 67 | for i, gene in enumerate(genes): 68 | real = df_real.loc[:, gene] 69 | pred = df_pred.loc[:, gene] 70 | random = df_random.loc[:, gene] 71 | 72 | if len(set(pred)) == 1 or len(set(real)) ==1 or len(set(random)) == 1: 73 | xy, xy, yz = 0, 0, 0 74 | p1, p2, p3, p = 1, 1, 1, 1 75 | else: 76 | xy, p1 = stats.pearsonr(real, pred) 77 | xz, p2 = stats.pearsonr(real, random) 78 | yz, p3 = stats.pearsonr(pred, random) 79 | t, p = dependent_corr(xy, xz, yz, len(real), twotailed=False, conf_level=0.95, method='steiger') 80 | 81 | pred_r.append(xy) 82 | random_r.append(xz) 83 | test_p.append(p) 84 | pearson_p.append(p1) 85 | 86 | # RMSE test 87 | rmse_p = mean_squared_error(real, pred, squared=False) 88 | rmse_r = mean_squared_error(real, random, squared=False) 89 | rmse_q = rmse_p / (np.quantile(real, 0.75) - np.quantile(real, 0.25) + 1e-5) 90 | rmse_m = rmse_p / np.mean(real) 91 | 92 | rmse_pred.append(rmse_p) 93 | rmse_random.append(rmse_r) 94 | rmse_quantile_norm.append(rmse_q) 95 | rmse_mean_norm.append(rmse_m) 96 | valid_genes.append(gene) 97 | 98 | combine_res = pd.DataFrame({'pred_real_r': pred_r,\ 99 | 'random_real_r': random_r,\ 100 | 'pearson_p': pearson_p,\ 101 | 'Steiger_p': test_p,\ 102 | 'rmse_pred': rmse_pred, \ 103 | 'rmse_random': rmse_random, 104 | 'rmse_quantile_norm': rmse_quantile_norm, 105 | 'rmse_mean_norm': rmse_mean_norm}, 106 | index=valid_genes) 107 | 108 | combine_res = combine_res.sort_values('pred_real_r', ascending = False) 109 | 110 | # In case of constant values, replace correlation coefficient to 0 111 | combine_res['pred_real_r'] = combine_res['pred_real_r'].fillna(0) 112 | combine_res['random_real_r'] = combine_res['random_real_r'].fillna(0) 113 | 114 | # Correct pearson p values 115 | combine_res['pearson_p'] = combine_res['pearson_p'].fillna(1) 116 | _, fdr_pearson_p = fdrcorrection(combine_res['pearson_p']) 117 | combine_res['fdr_pearson_p'] = fdr_pearson_p 118 | 119 | # Correct steiger p values 120 | combine_res['Steiger_p'] = combine_res['Steiger_p'].fillna(1) 121 | _, fdr_Steiger_p = fdrcorrection(combine_res['Steiger_p']) 122 | combine_res['fdr_Steiger_p'] = fdr_Steiger_p 123 | 124 | combine_res['cancer'] = cancer_type 125 | df_list.append(combine_res) 126 | 127 | except: 128 | print(f'no data for {cancer_type}') 129 | 130 | all_res = pd.concat(df_list) 131 | sig_res = all_res[(all_res['pred_real_r'] > 0) & \ 132 | (all_res['pearson_p'] < 0.05) & \ 133 | (all_res['rmse_pred'] < all_res['rmse_random']) & \ 134 | (all_res['pred_real_r'] > all_res['random_real_r']) & \ 135 | (all_res['Steiger_p'] < 0.05) & \ 136 | (all_res['fdr_Steiger_p'] < 0.2)] 137 | 138 | all_res.to_csv(os.path.join(save_path, 'all_genes.csv')) 139 | sig_res.to_csv(os.path.join(save_path, 'sig_genes.csv')) 140 | 141 | df_num_sig = sig_res['cancer'].value_counts().reset_index() 142 | df_num_sig.columns = ['cancer', 'num_genes'] 143 | df_num_sig.to_csv(os.path.join(save_path, 'num_sign_genes.csv')) -------------------------------------------------------------------------------- /evaluation/predict_independent_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pickle 5 | from torch.utils.data import DataLoader 6 | import torch.nn as nn 7 | 8 | from src.read_data import SuperTileRNADataset 9 | from src.utils import filter_no_features, custom_collate_fn 10 | from src.vit import train, evaluate, predict 11 | from src.tformer_lin import ViS 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='Getting features') 16 | parser.add_argument('--ref_file', type=str, required=True, help='Reference file') 17 | parser.add_argument('--feature_path', type=str, default='', help='Directory where pre-processed WSI features are stored') 18 | parser.add_argument('--feature_use', type=str, default='cluster_mean_features', help='Which feature to use for training the model') 19 | parser.add_argument('--folds', type=int, default=5, help='Folds for pre-trained model') 20 | parser.add_argument('--seed', type=int, default=99, help='Seed for random generation') 21 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 22 | parser.add_argument('--depth', type=int, default=6, help='Transformer depth') 23 | parser.add_argument('--num-heads', type=int, default=16, help='Number of attention heads') 24 | parser.add_argument('--tcga_project', default=None, type=str, default='', help='The tcga_project we want to use') 25 | parser.add_argument('--save_dir', type=str, default='', help='Where to save results') 26 | parser.add_argument('--exp_name', type=str, default='exp', help='Experiment name') 27 | 28 | ############################################## variables ############################################## 29 | 30 | args = parser.parse_args() 31 | np.random.seed(args.seed) 32 | torch.manual_seed(args.seed) 33 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 34 | print(device) 35 | 36 | ############################################## saving ############################################## 37 | 38 | save_dir = os.path.join(args.save_dir, args.exp_name) 39 | if not os.path.exists(save_dir): 40 | os.makedirs(save_dir) 41 | 42 | ############################################## data prep ############################################## 43 | 44 | path_csv = args.ref_file 45 | df = pd.read_csv(path_csv) 46 | 47 | # filter out WSIs for which we don't have features and filter on TCGA project 48 | df = filter_no_features(df, feature_path = args.feature_path, feature_name = args.feature_use) 49 | genes = [c[4:] for c in df.columns if "rna_" in c] 50 | if 'tcga_project' in df.columns and args.tcga_project: 51 | df = df[df['tcga_project'].isin(args.tcga_project)].reset_index(drop=True) 52 | 53 | # init test dataloader 54 | test_dataset = SuperTileRNADataset(df, args.feature_path, args.feature_use) 55 | test_dataloader = DataLoader(test_dataset, 56 | num_workers=0, pin_memory=True, 57 | shuffle=False, batch_size=args.batch_size, 58 | collate_fn=custom_collate_fn) 59 | feature_dim = test_dataset.feature_dim 60 | 61 | res_preds = [] 62 | res_random = [] 63 | cancer = args.tcga_project.split('-')[-1].lower() 64 | 65 | for fold in range(args.folds): 66 | 67 | # load model from huggingface 68 | model = ViS.from_pretrained(f"gevaertlab/sequoia-{cancer}-{fold}") 69 | model.to(device) 70 | 71 | # model prediction on test set 72 | preds, wsis, projs = predict(model, test_dataloader, run=None) 73 | 74 | # random predictions 75 | random_model = ViS(num_outputs=test_dataset.num_genes, 76 | input_dim=feature_dim, 77 | depth=args.depth, nheads=args.num_heads, 78 | dimensions_f=64, dimensions_c=64, dimensions_s=64, device=device) 79 | random_model.to(device) 80 | random_preds, _, _ = predict(random_model, test_dataloader, run=None) 81 | 82 | # save predictions 83 | res_preds.append(preds) 84 | res_random.append(random_preds) 85 | 86 | # calculate average across folds 87 | avg_preds = np.mean(res_preds, axis = 0) 88 | avg_random = np.mean(res_random, axis = 0) 89 | 90 | df_pred = pd.DataFrame(avg_preds, index = wsis, columns = genes) 91 | df_random = pd.DataFrame(avg_random, index = wsis, columns = genes) 92 | 93 | test_results = {'pred': df_pred, 'random': df_random} 94 | 95 | with open(os.path.join(save_dir, 'test_results.pkl'), 'wb') as f: 96 | pickle.dump(test_results, f, protocol=pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /images/overview_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/images/overview_new.png -------------------------------------------------------------------------------- /images/seq-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/images/seq-logo.png -------------------------------------------------------------------------------- /images/sequoia-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/images/sequoia-logo.png -------------------------------------------------------------------------------- /pre_processing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/pre_processing/__init__.py -------------------------------------------------------------------------------- /pre_processing/compute_features_hdf5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pickle 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision import transforms 10 | import h5py 11 | import timm 12 | from PIL import Image 13 | 14 | import pdb 15 | 16 | from src.read_data import * 17 | from src.resnet import resnet50 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser(description='Getting features') 21 | 22 | parser.add_argument('--feat_type', default="uni", type=str, required=True, help='Which feature extractor to use, either "resnet" or "uni"') 23 | parser.add_argument('--ref_file', default="/examples/ref_file.csv", type=str, required=True, help='Path with reference csv file') 24 | parser.add_argument('--patch_data_path', default="/examples/Patches_hdf5", type=str, required=True, help='Directory where the patch is saved') 25 | parser.add_argument('--feature_path', type=str, default="/examples/features", help='Output directory to save features') 26 | parser.add_argument('--max_patch_number', type=int, default=4000, help='Max number of patches to use per slide') 27 | parser.add_argument('--seed', type=int, default=99, help='Seed for random generation') 28 | parser.add_argument("--tcga_projects", help="the tcga_projects we want to use", default=None, type=str, nargs='*') 29 | parser.add_argument('--start', type=int, default=0, help='Start slide index for parallelization') 30 | parser.add_argument('--end', type=int, default=None, help='End slide index for parallelization') 31 | args = parser.parse_args() 32 | 33 | np.random.seed(args.seed) 34 | torch.manual_seed(args.seed) 35 | 36 | print(10*'-') 37 | print('Args for this experiment \n') 38 | print(args) 39 | print(10*'-') 40 | 41 | random.seed(args.seed) 42 | 43 | path_csv = args.ref_file 44 | patch_data_path = args.patch_data_path 45 | 46 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 47 | 48 | if args.feat_type == 'resnet': 49 | transforms_val = torch.nn.Sequential( 50 | transforms.ConvertImageDtype(torch.float), 51 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])) 52 | else: 53 | transforms_val = transforms.Compose([ 54 | transforms.Resize(224), 55 | transforms.ToTensor(), 56 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),]) 57 | 58 | if args.feat_type == 'resnet': 59 | model = resnet50(pretrained=True).to(device) 60 | model.eval() 61 | else: 62 | local_dir = "" # add dir for saved model 63 | model = timm.create_model("vit_large_patch16_224", img_size=224, patch_size=16, 64 | init_values=1e-5, num_classes=0, dynamic_img_size=True) 65 | model.load_state_dict(torch.load(os.path.join(local_dir, 66 | "pytorch_model.bin"), map_location="cpu"), strict=True) 67 | model.to(device) 68 | model.eval() 69 | 70 | print('Loading dataset...') 71 | 72 | df = pd.read_csv(path_csv) 73 | df = df.drop_duplicates(["wsi_file_name"]) # there could be duplicated WSIs mapped to different RNA files and we only need features for each WSI 74 | 75 | # Filter tcga projects 76 | if args.tcga_projects: 77 | df = df[df['tcga_project'].isin(args.tcga_projects)] 78 | 79 | # indexing based on values for parallelization 80 | if args.start is not None and args.end is not None: 81 | df = df.iloc[args.start:args.end] 82 | elif args.start is not None: 83 | df = df.iloc[args.start:] 84 | elif args.end is not None: 85 | df = df.iloc[:args.end] 86 | 87 | print(f'Number of slides = {df.shape[0]}') 88 | 89 | for i, row in tqdm(df.iterrows()): 90 | WSI = row['wsi_file_name'] 91 | WSI_slide = WSI.split('.')[0] 92 | project = row['tcga_project'] 93 | WSI = WSI.replace('.svs', '') # in the ref file of prad there is a .svs that should not be there 94 | 95 | if not os.path.exists(os.path.join(patch_data_path, WSI_slide)): 96 | print('Not exist {}'.format(os.path.join(patch_data_path, WSI_slide))) 97 | continue 98 | 99 | path = os.path.join(patch_data_path, WSI_slide, WSI_slide + '.hdf5') 100 | path_h5 = os.path.join(args.feature_path, project, WSI) 101 | 102 | if not os.path.exists(path_h5): 103 | os.makedirs(path_h5) 104 | 105 | if os.path.exists(os.path.join(path_h5, "complete_resnet.txt")): 106 | print(f'{WSI}: Resnet features already obtained') 107 | continue 108 | 109 | try: 110 | with h5py.File(path, 'r') as f_read: 111 | keys = list(f_read.keys()) 112 | if len(keys) > args.max_patch_number: 113 | keys = random.sample(keys, args.max_patch_number) 114 | 115 | features_tiles = [] 116 | for key in tqdm(keys): 117 | image = f_read[key][:] 118 | if args.feat_type == 'resnet': 119 | image = torch.from_numpy(image).permute(2,0,1) 120 | image = transforms_val(image).to(device) 121 | with torch.no_grad(): 122 | features = model.forward_extract(image[None,:]) 123 | features_tiles.append(features[0].detach().cpu().numpy()) 124 | else: 125 | image = Image.fromarray(image).convert("RGB") 126 | image = transforms_val(image).to(device) 127 | with torch.no_grad(): 128 | features = model(image[None,:]) 129 | features_tiles.append(features[0].detach().cpu().numpy()) 130 | 131 | features_tiles = np.asarray(features_tiles) 132 | n_tiles = len(features_tiles) 133 | 134 | f_write = h5py.File(os.path.join(path_h5, WSI+'.h5'), "w") 135 | dset = f_write.create_dataset(f"{args.feat_type}_features", data=features_tiles) 136 | f_write.close() 137 | 138 | with open(os.path.join(path_h5, "complete_tile.txt"), 'w') as f_sum: 139 | f_sum.write(f"Total n patch = {n_tiles}") 140 | 141 | except Exception as e: 142 | print(e) 143 | print(WSI) 144 | continue 145 | 146 | -------------------------------------------------------------------------------- /pre_processing/download_RNASeq_TCGAbiolinks.R: -------------------------------------------------------------------------------- 1 | # Download FPKM-UQ value from GDC data portal 2 | library(TCGAbiolinks) 3 | library(SummarizedExperiment) 4 | 5 | master.dir <- "." 6 | save.dir <- file.path(master.dir, "gene_expression_FPKM_UQ") 7 | 8 | cancer.types <- c("LUAD", "LUSC", "BRCA", "GBM", "COAD", "KIRC", "PAAD", "PRAD") 9 | 10 | gene.numbers <- c() 11 | protein.numbers <- c() 12 | patient.numbers <- c() 13 | 14 | for (cancer in cancer.types){ 15 | data <- fpkm.data <- NULL 16 | query <- GDCquery( 17 | project = paste0("TCGA-", cancer), 18 | data.category = "Transcriptome Profiling", 19 | data.type = "Gene Expression Quantification", 20 | workflow.type = "STAR - Counts" 21 | ) 22 | 23 | GDCdownload(query = query) 24 | data <- GDCprepare(query = query) 25 | data <- data[which(rowData(data)$gene_type %in% c("protein_coding", 'miRNA', "lncRNA")), ] 26 | fpkm.data <- assays(data)$fpkm_uq_unstrand 27 | rownames(fpkm.data) <- rowData(data)$gene_name 28 | fpkm.data <- fpkm.data[which(rowMedians(fpkm.data)>0), ] 29 | gene.types <- rowData(data)$gene_type[match(rownames(fpkm.data), rowData(data)$gene_name)] 30 | gene.numbers <- c(gene.numbers, nrow(fpkm.data)) 31 | protein.numbers <- c(protein.numbers, table(gene.types)['protein_coding']) 32 | patient.numbers <- c(patient.numbers, ncol(fpkm.data)) 33 | 34 | write.table(fpkm.data, paste0(save.dir, "/", cancer, ".txt"), sep = " ") 35 | } 36 | 37 | df_gene_number <- data.frame(cancer = cancer.types, n_gene = gene.numbers, n_protein_coding = protein.numbers, n_patient = patient.numbers) 38 | 39 | write.csv(df_gene_number, paste0(save.dir, "/", "gene_number_summary_3.csv")) 40 | 41 | -------------------------------------------------------------------------------- /pre_processing/kmean_features.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pdb 5 | from tqdm import tqdm 6 | 7 | import numpy as np 8 | import h5py 9 | from sklearn.cluster import KMeans 10 | 11 | from src.read_data import * 12 | from src.utils import exists 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser(description='Getting features') 16 | parser.add_argument('--ref_file', default="/examples/ref_file.csv", type=str, required=True, help='Path with reference csv file') 17 | parser.add_argument('--patch_data_path', default="/examples/Patches_hdf5", type=str, required=True, help='Directory where the patch is saved') 18 | parser.add_argument('--feature_path', type=str, default="/examples/features", help='Output directory to save features') 19 | parser.add_argument('--num_clusters', type=int, default=100, 20 | help='Number of clusters for the kmeans') 21 | parser.add_argument("--tcga_projects", help="the tcga_projects we want to use", 22 | default=None, type=str, nargs='*') 23 | parser.add_argument('--start', type=int, default=0, 24 | help='Start slide index for parallelization') 25 | parser.add_argument('--end', type=int, default=None, 26 | help='End slide index for parallelization') 27 | parser.add_argument("--gtex", help="using gtex data", 28 | action="store_true") 29 | parser.add_argument('--gtex_tissue', type=str, default=None, 30 | help='GTex tissue being used') 31 | parser.add_argument('--seed', type=int, default=99, 32 | help='Seed for random generation') 33 | 34 | args = parser.parse_args() 35 | 36 | np.random.seed(args.seed) 37 | torch.manual_seed(args.seed) 38 | 39 | print(10*'-') 40 | print('Args for this experiment \n') 41 | print(args) 42 | print(10*'-') 43 | 44 | path_csv = args.ref_file 45 | 46 | df = pd.read_csv(path_csv) 47 | df = df.drop_duplicates(['wsi_file_name']) 48 | 49 | # Filter tcga projects 50 | if args.tcga_projects: 51 | df = df[df['tcga_project'].isin(args.tcga_projects)] 52 | 53 | print(f'Total number of slides = {df.shape[0]}') 54 | 55 | # indexing based on values for parallelization 56 | if exists(args.start) and exists(args.end): 57 | df = df.iloc[args.start:args.end] 58 | elif exists(args.start): 59 | df = df.iloc[args.start:] 60 | elif args.end: 61 | df = df.iloc[:args.end] 62 | 63 | print(f'New number of slides = {df.shape[0]}') 64 | 65 | for i, row in tqdm(df.iterrows()): 66 | WSI = row['wsi_file_name'] 67 | if args.gtex: 68 | project = args.gtex_tissue 69 | else: 70 | project = df.iloc[0]['tcga_project'] 71 | WSI = WSI.replace('.svs', '') 72 | 73 | path = os.path.join(args.feature_path, project, WSI) 74 | try: 75 | f = h5py.File(os.path.join(path,WSI+'.h5'), "r+") 76 | except: 77 | print(f'Cannot open file {path}') 78 | continue 79 | try: 80 | features = f['resnet_features'] 81 | except: 82 | print(f'No resnet features for {path}') 83 | f.close() 84 | continue 85 | 86 | if features.shape[0] < args.num_clusters: 87 | print(f'{WSI} less number of patches than clusters') 88 | f.close() 89 | continue 90 | 91 | if 'cluster_features' in f.keys(): 92 | print(f'{WSI}: Cluster feature already available') 93 | f.close() 94 | continue 95 | 96 | kmeans = KMeans(n_clusters=args.num_clusters, random_state=0).fit(features) 97 | clusters = kmeans.labels_ 98 | 99 | mean_features = [] 100 | for pos in tqdm(range(args.num_clusters)): 101 | indexes = np.where(clusters == pos) 102 | features_aux = features[indexes] 103 | mean_features.append(np.mean(features_aux, axis=0)) 104 | 105 | mean_features = np.asarray(mean_features) 106 | 107 | try: 108 | dset = f.create_dataset("cluster_features", data=mean_features) 109 | f.close() 110 | except Exception as e: 111 | print(f"{WSI}: Error in creating cluster_feauture") 112 | print(e) 113 | f.close() 114 | print('Done!') 115 | -------------------------------------------------------------------------------- /pre_processing/patch_gen_hdf5.py: -------------------------------------------------------------------------------- 1 | # This script generates pathes from whole slide images (e.g. from TCGA) and save the extracted patches into a hdf5 file. 2 | # This will save all the paches but keeps the file number small. 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from openslide import OpenSlide 7 | from multiprocessing import Pool, Value, Lock 8 | import os 9 | from skimage.color import rgb2hsv 10 | from skimage.filters import threshold_otsu 11 | from skimage.io import imsave, imread 12 | from skimage.exposure.exposure import is_low_contrast 13 | from skimage.transform import resize 14 | from scipy.ndimage import binary_dilation, binary_erosion 15 | import argparse 16 | import logging 17 | import h5py 18 | from tqdm import tqdm 19 | 20 | import pickle 21 | import re 22 | import pdb 23 | import pandas as pd 24 | 25 | def get_mask_image(img_RGB, RGB_min=50): 26 | img_HSV = rgb2hsv(img_RGB) 27 | 28 | background_R = img_RGB[:, :, 0] > threshold_otsu(img_RGB[:, :, 0]) 29 | background_G = img_RGB[:, :, 1] > threshold_otsu(img_RGB[:, :, 1]) 30 | background_B = img_RGB[:, :, 2] > threshold_otsu(img_RGB[:, :, 2]) 31 | tissue_RGB = np.logical_not(background_R & background_G & background_B) 32 | tissue_S = img_HSV[:, :, 1] > threshold_otsu(img_HSV[:, :, 1]) 33 | min_R = img_RGB[:, :, 0] > RGB_min 34 | min_G = img_RGB[:, :, 1] > RGB_min 35 | min_B = img_RGB[:, :, 2] > RGB_min 36 | 37 | mask = tissue_S & tissue_RGB & min_R & min_G & min_B 38 | return mask 39 | 40 | def get_mask(slide, level='max', RGB_min=50): 41 | #read svs image at a certain level and compute the otsu mask 42 | if level == 'max': 43 | level = len(slide.level_dimensions) - 1 44 | # note the shape of img_RGB is the transpose of slide.level_dimensions 45 | img_RGB = np.transpose(np.array(slide.read_region((0, 0),level,slide.level_dimensions[level]).convert('RGB')), 46 | axes=[1, 0, 2]) 47 | 48 | tissue_mask = get_mask_image(img_RGB, RGB_min) 49 | return tissue_mask, level 50 | 51 | def extract_patches(slide_path, mask_path, patch_size, patches_output_dir, slide_id, max_patches_per_slide=2000): 52 | 53 | patch_folder = os.path.join(patches_output_dir, slide_id) 54 | if not os.path.isdir(patch_folder): 55 | os.makedirs(patch_folder) 56 | 57 | patch_folder_mask = os.path.join(mask_path, slide_id) 58 | if not os.path.isdir(patch_folder_mask): 59 | os.makedirs(patch_folder_mask) 60 | 61 | if os.path.exists(os.path.join(patch_folder, "complete.txt")): 62 | print(f'{slide_id}: patches have already been extreacted') 63 | return 64 | 65 | path_hdf5 = os.path.join(patch_folder, f"{slide_id}.hdf5") 66 | hdf = h5py.File(path_hdf5, 'w') 67 | 68 | slide = OpenSlide(slide_path) 69 | mask, mask_level = get_mask(slide) 70 | mask = binary_dilation(mask, iterations=3) 71 | mask = binary_erosion(mask, iterations=3) 72 | np.save(os.path.join(patch_folder_mask, "mask.npy"), mask) 73 | 74 | mask_level = len(slide.level_dimensions) - 1 75 | 76 | PATCH_LEVEL = 0 77 | BACKGROUND_THRESHOLD = .2 78 | 79 | try: 80 | ratio_x = slide.level_dimensions[PATCH_LEVEL][0] / slide.level_dimensions[mask_level][0] 81 | ratio_y = slide.level_dimensions[PATCH_LEVEL][1] / slide.level_dimensions[mask_level][1] 82 | 83 | xmax, ymax = slide.level_dimensions[PATCH_LEVEL] 84 | 85 | # handle slides with 40 magnification at base level 86 | resize_factor = float(slide.properties.get('aperio.AppMag', 20)) / 20.0 87 | if not slide.properties.get('aperio.AppMag', 20): print(f"magnifications for {slide_id} is not found, using default magnification 20X") 88 | 89 | patch_size_resized = (int(resize_factor * patch_size[0]), int(resize_factor * patch_size[1])) 90 | print(f"patch size for {slide_id}: {patch_size_resized}") 91 | 92 | i = 0 93 | indices = [(x, y) for x in range(0, xmax, patch_size_resized[0]) for y in 94 | range(0, ymax, patch_size_resized[0])] 95 | 96 | # here, we generate all the pathes with valid mask 97 | if max_patches_per_slide is None: 98 | max_patches_per_slide = len(indices) 99 | 100 | np.random.seed(5) 101 | np.random.shuffle(indices) 102 | 103 | for x, y in indices: 104 | # check if in background mask 105 | x_mask = int(x / ratio_x) 106 | y_mask = int(y / ratio_y) 107 | if mask[x_mask, y_mask] == 1: 108 | patch = slide.read_region((x, y), PATCH_LEVEL, patch_size_resized).convert('RGB') 109 | try: 110 | mask_patch = get_mask_image(np.array(patch)) 111 | mask_patch = binary_dilation(mask_patch, iterations=3) 112 | except Exception as e: 113 | print("error with slide id {} patch {}".format(slide_id, i)) 114 | print(e) 115 | if (mask_patch.sum() > BACKGROUND_THRESHOLD * mask_patch.size) and not (is_low_contrast(patch)): 116 | if resize_factor != 1.0: 117 | patch = patch.resize(patch_size) 118 | patch = np.array(patch) 119 | tile_name = f"{x}_{y}" 120 | hdf.create_dataset(tile_name, data=patch) 121 | i = i + 1 122 | if i >= max_patches_per_slide: 123 | break 124 | 125 | hdf.close() 126 | 127 | if i == 0: 128 | print("no patch extracted for slide {}".format(slide_id)) 129 | else: 130 | with open(os.path.join(patch_folder, "complete.txt"), 'w') as f: 131 | f.write('Process complete!\n') 132 | f.write(f"Total n patch = {i}") 133 | print(f"{slide_id} complete, total n patch = {i}") 134 | 135 | except Exception as e: 136 | print("error with slide id {} patch {}".format(slide_id, i)) 137 | print(e) 138 | 139 | def get_slide_id(slide_name): 140 | return slide_name.split('.')[0] 141 | 142 | def process(opts): 143 | slide_path, patch_size, patches_output_dir, mask_path, slide_id, max_patches_per_slide = opts 144 | extract_patches(slide_path, mask_path, patch_size, 145 | patches_output_dir, slide_id, max_patches_per_slide) 146 | 147 | 148 | parser = argparse.ArgumentParser(description='Generate patches from a given folder of images') 149 | parser.add_argument('--ref_file', default="examples/ref_file.csv", required=False, metavar='ref_file', type=str, 150 | help='Path to the ref_file, if provided, only the WSIs in the ref file will be processed') 151 | parser.add_argument('--wsi_path', default="examples/HE", metavar='WSI_PATH', type=str, 152 | help='Path to the input directory of WSI files') 153 | parser.add_argument('--patch_path', default="examples/Patches_hdf5" ,metavar='PATCH_PATH', type=str, 154 | help='Path to the output directory of patch images') 155 | parser.add_argument('--mask_path', default="examples/Patches_hdf5", metavar='MASK_PATH', type=str, 156 | help='Path to the directory of numpy masks') 157 | parser.add_argument('--patch_size', default=256, type=int, help='patch size, ' 158 | 'default 256') 159 | parser.add_argument('--start', type=int, default=0, 160 | help='Start slide index for parallelization') 161 | parser.add_argument('--end', type=int, default=None, 162 | help='End slide index for parallelization') 163 | parser.add_argument('--max_patches_per_slide', default=None, type=int) 164 | parser.add_argument('--debug', default=0, type=int, 165 | help='whether to use debug mode') 166 | parser.add_argument('--parallel', default=1, type=int, 167 | help='whether to use parallel computation') 168 | 169 | 170 | if __name__ == '__main__': 171 | 172 | args = parser.parse_args() 173 | slide_list = os.listdir(args.wsi_path) 174 | slide_list = [s for s in slide_list if s.endswith('.svs') or s.endswith('.tiff')] 175 | 176 | if args.ref_file: 177 | ref_file = pd.read_csv(args.ref_file) 178 | selected_slides = list(ref_file['wsi_file_name']) 179 | wsi_files = [f'{s}.svs' for s in selected_slides] 180 | slide_list = list(set(slide_list) & set(wsi_files)) 181 | slide_list = sorted(slide_list) 182 | 183 | if args.start is not None and args.end is not None: 184 | slide_list = slide_list[args.start:args.end] 185 | elif args.start is not None: 186 | slide_list = slide_list[args.start:] 187 | elif args.end is not None: 188 | slide_list = slide_list[:args.end] 189 | 190 | if args.debug: 191 | slide_list = slide_list[0:5] 192 | args.max_patches_per_slide = 20 193 | 194 | print(f"Found {len(slide_list)} slides") 195 | 196 | opts = [ 197 | (os.path.join(args.wsi_path, s), (args.patch_size, args.patch_size), args.patch_path, args.mask_path, 198 | get_slide_id(s), args.max_patches_per_slide) for 199 | (i, s) in enumerate(slide_list)] 200 | 201 | if args.parallel: 202 | pool = Pool(processes=4) 203 | pool.map(process, opts) 204 | else: 205 | for opt in opts: 206 | process(opt) 207 | 208 | 209 | -------------------------------------------------------------------------------- /pre_processing/patient_splits.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/pre_processing/patient_splits.zip -------------------------------------------------------------------------------- /pre_processing/test_wsis.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/pre_processing/test_wsis.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.29.3 2 | anndata==0.10.7 3 | appdirs==1.4.4 4 | array-api-compat==1.6 5 | certifi==2024.2.2 6 | charset-normalizer==3.3.2 7 | click==8.1.7 8 | contourpy==1.2.1 9 | cycler==0.12.1 10 | docker-pycreds==0.4.0 11 | einops==0.4.1 12 | exceptiongroup==1.2.1 13 | filelock==3.13.4 14 | fonttools==4.51.0 15 | fsspec==2024.3.1 16 | get-annotations==0.1.2 17 | gitdb==4.0.11 18 | GitPython==3.1.43 19 | h5py==3.11.0 20 | huggingface-hub==0.22.2 21 | idna==3.7 22 | imageio==2.34.0 23 | importlib-resources==6.4.0 24 | Jinja2==3.1.3 25 | joblib==1.4.0 26 | kiwisolver==1.4.5 27 | lazy-loader==0.4 28 | legacy-api-wrap==1.4 29 | llvmlite==0.42.0 30 | lmdb==1.4.1 31 | MarkupSafe==2.1.5 32 | matplotlib==3.8.4 33 | mpmath==1.3.0 34 | natsort==8.4.0 35 | networkx==3.2.1 36 | numba==0.59.1 37 | numpy==1.26.4 38 | nvidia-cublas-cu12==12.1.3.1 39 | nvidia-cuda-cupti-cu12==12.1.105 40 | nvidia-cuda-nvrtc-cu12==12.1.105 41 | nvidia-cuda-runtime-cu12==12.1.105 42 | nvidia-cudnn-cu12==8.9.2.26 43 | nvidia-cufft-cu12==11.0.2.54 44 | nvidia-curand-cu12==10.3.2.106 45 | nvidia-cusolver-cu12==11.4.5.107 46 | nvidia-cusparse-cu12==12.1.0.106 47 | nvidia-nccl-cu12==2.19.3 48 | nvidia-nvjitlink-cu12==12.4.127 49 | nvidia-nvtx-cu12==12.1.105 50 | opencv-python==4.9.0.80 51 | openslide-python==1.3.1 52 | packaging==24.0 53 | pandas==2.2.2 54 | patsy==0.5.6 55 | pillow==10.3.0 56 | POT==0.9.3 57 | protobuf==4.25.3 58 | psutil==5.9.8 59 | py-lz4framed==0.14.0 60 | pynndescent==0.5.12 61 | pyparsing==3.1.2 62 | python-dateutil==2.9.0.post0 63 | pytz==2024.1 64 | PyYAML==6.0.1 65 | requests==2.31.0 66 | safetensors==0.4.3 67 | scanpy==1.10.1 68 | scikit-image==0.22.0 69 | scikit-learn==1.4.2 70 | scipy==1.13.0 71 | seaborn==0.13.2 72 | sentry-sdk==1.45.0 73 | session-info==1.0.0 74 | setproctitle==1.3.3 75 | six==1.16.0 76 | smmap==5.0.1 77 | statsmodels==0.14.2 78 | stdlib-list==0.10.0 79 | sympy==1.12 80 | threadpoolctl==3.4.0 81 | tifffile==2024.4.18 82 | torch==2.2.2 83 | torchvision==0.17.2 84 | tqdm==4.66.2 85 | triton==2.2.0 86 | typing-extensions==4.11.0 87 | tzdata==2024.1 88 | umap-learn==0.5.6 89 | urllib3==2.2.1 90 | wandb==0.16.6 91 | zipp==3.18.1 92 | -------------------------------------------------------------------------------- /scripts/extract_kmean_features.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 pre_processing/kmean_features.py \ 3 | --ref_file ./examples/ref_file.csv \ 4 | --patch_data_path ./examples/Patches_hdf5 \ 5 | --feature_path ./examples/features \ 6 | --num_clusters 100 -------------------------------------------------------------------------------- /scripts/extract_patch.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | 3 | python3 pre_processing/patch_gen_hdf5.py \ 4 | --ref_file ./examples/ref_file.csv \ 5 | --wsi_path ./examples/HE \ 6 | --patch_path ./examples/Patches_hdf5 \ 7 | --mask_path ./examples/Patches_hdf5 \ 8 | --patch_size 256 \ 9 | --max_patches_per_slide 4000 -------------------------------------------------------------------------------- /scripts/extract_resnet_features.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 pre_processing/compute_resnet_features_hdf5.py \ 4 | --ref_file ./examples/ref_file.csv \ 5 | --patch_data_path ./examples/Patches_hdf5 \ 6 | --feature_path ./examples/features \ 7 | --max_patch_number 4000 \ 8 | --feat_type resnet -------------------------------------------------------------------------------- /scripts/run_he2rna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python3 src/he2rna.py \ 3 | --path_csv examples/ref_file.csv \ 4 | --subfolder he2rna \ 5 | --exp_name BRCA \ 6 | --lr 1e-3 \ 7 | --checkpoint pretrained_models/model.pt \ 8 | --change_num_genes \ 9 | --num_genes 19198 \ 10 | --log 0 11 | 12 | -------------------------------------------------------------------------------- /scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 src/main.py \ 4 | --model_type vis \ 5 | --ref_file examples/ref_file.csv \ 6 | --save_dir output \ 7 | --cohort TCGA \ 8 | --exp_name run_train \ 9 | --batch_size 16 \ 10 | --checkpoint pretrained_models/model_best.pt \ 11 | --k 5 \ 12 | --train \ 13 | --log 0 \ 14 | --change_num_genes \ 15 | --num_genes 19198 \ 16 | --save_on loss+corr \ 17 | --stop_on loss+corr 18 | -------------------------------------------------------------------------------- /scripts/run_visualize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python3 src/visualize.py --study gbm \ 4 | --project spatial_GBM_pred \ 5 | --wsi_file_name HRI_251_T.tif \ 6 | --gene_names $path/gene_ids/top_1000_gbm.npy \ 7 | --save_folder top_1000_gbm \ 8 | --model_type vis \ 9 | --feat_type uni -------------------------------------------------------------------------------- /spatial_vis/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /spatial_vis/gbm_celltype_analysis.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import matplotlib.pyplot as plt 3 | from functools import reduce 4 | import numpy as np 5 | from scipy.stats import percentileofscore 6 | import os 7 | from tqdm import tqdm 8 | import math 9 | import seaborn as sns 10 | import matplotlib.colors 11 | 12 | def score2percentile(score, ref): 13 | if np.isnan(score): 14 | return score # deal with nans in visualization (set to black) 15 | percentile = percentileofscore(ref, score) 16 | return percentile 17 | 18 | if __name__=='__main__': 19 | root = '.' 20 | src_path = root + 'visualizations/spatial_GBM_pred/' 21 | folder = 'gbm_celltypes/' 22 | draw_heatmaps = True 23 | all_genes = np.load(root + 'gene_ids/gbm_experiments/all.npy', allow_pickle=True) 24 | 25 | slide_names = os.listdir(src_path + folder) 26 | slide_names = [i for i in slide_names if i not in ['corr_maps', 'spatial_maps']] 27 | all_corr_dfs = [] 28 | 29 | dests = [src_path + folder + '/corr_maps/', src_path + folder + '/spatial_maps/'] 30 | for dest in dests: 31 | if not os.path.exists(dest): 32 | os.makedirs(dest) 33 | 34 | ac = np.load(root + 'gene_ids/celltypes/AC.npy',allow_pickle=True) 35 | g1s = np.load(root + 'gene_ids/celltypes/G1S.npy',allow_pickle=True) 36 | g2m = np.load(root + 'gene_ids/celltypes/G2M.npy',allow_pickle=True) 37 | mes1 = np.load(root + 'gene_ids/celltypes/MES1.npy',allow_pickle=True) 38 | mes2 = np.load(root + 'gene_ids/celltypes/MES2.npy',allow_pickle=True) 39 | npc1 = np.load(root + 'gene_ids/celltypes/NPC1.npy',allow_pickle=True) 40 | npc2 = np.load(root + 'gene_ids/celltypes/NPC2.npy',allow_pickle=True) 41 | opc = np.load(root + 'gene_ids/celltypes/OPC.npy',allow_pickle=True) 42 | mapper = {} 43 | 44 | green = '#CEBC36' 45 | red = '#CE3649' 46 | blue = '#3648CE' 47 | purple = '#36CEBC' 48 | 49 | mapper.update(dict.fromkeys(ac, matplotlib.colors.to_rgb(purple))) # purple 50 | mapper.update(dict.fromkeys(g1s, matplotlib.colors.to_rgb(red))) # red 51 | mapper.update(dict.fromkeys(g2m, matplotlib.colors.to_rgb(red))) 52 | mapper.update(dict.fromkeys(mes1, matplotlib.colors.to_rgb(blue))) #blue 53 | mapper.update(dict.fromkeys(mes2, matplotlib.colors.to_rgb(blue))) 54 | mapper.update(dict.fromkeys(npc1, matplotlib.colors.to_rgb(green))) #green 55 | mapper.update(dict.fromkeys(npc2, matplotlib.colors.to_rgb(green))) 56 | mapper.update(dict.fromkeys(opc, matplotlib.colors.to_rgb(green))) 57 | 58 | max_lim = 0 59 | 60 | for slide_name in tqdm(slide_names): 61 | 62 | source_path = src_path + folder + '/' + slide_name 63 | path = source_path + '/stride-1.csv' 64 | df = pd.read_csv(path) 65 | 66 | df_max = max_lim = max(max(df.xcoord_tf), max(df.ycoord_tf)) 67 | if df_max > max_lim: 68 | max_lim = df_max 69 | 70 | all_genes = list(set(all_genes)&set(df.columns)) 71 | 72 | df = df.dropna(axis=0, how='any') 73 | df = df[['xcoord_tf','ycoord_tf']+all_genes] 74 | 75 | corrdf = df[all_genes].corr() 76 | kind = corrdf.columns.map(mapper) 77 | all_corr_dfs.append(corrdf) 78 | 79 | plt.close() 80 | plt.figure() 81 | pl = sns.clustermap(corrdf, row_colors=kind, cmap='magma') #, yticklabels=True, xticklabels=True, figsize=(50,50) 82 | pl.ax_row_dendrogram.set_visible(False) 83 | pl.ax_col_dendrogram.set_visible(False) 84 | plt.savefig(src_path + folder + '/corr_maps/' + slide_name + '_clustered.png', bbox_inches='tight', dpi=300) 85 | 86 | if draw_heatmaps: 87 | 88 | scaling_factor = 2 89 | max_lim += scaling_factor*5 90 | 91 | for slide_name in tqdm(slide_names): 92 | 93 | source_path = src_path + folder + '/' + slide_name 94 | path = source_path + '/stride-1.csv' 95 | df = pd.read_csv(path) 96 | all_genes = list(set(all_genes)&set(df.columns)) 97 | df = df.dropna(axis=0, how='any') 98 | df = df[['xcoord_tf','ycoord_tf']+all_genes] 99 | 100 | categories = [ac.tolist(), g1s.tolist()+g2m.tolist(), mes1.tolist()+mes2.tolist(), npc1.tolist()+npc2.tolist()+opc.tolist()] 101 | labels = ['ac', 'cc', 'mes', 'lin'] 102 | colors = {'ac':purple, 'cc':red, 'mes':blue, 'lin':green} 103 | 104 | for j,label in enumerate(labels): 105 | df[label] = df[[i for i in categories[j] if i in df.columns]].mean(axis=1) 106 | ref = df[label].values 107 | df[label + '_perc'] = df.apply(lambda row: score2percentile(row[label], ref), axis=1) 108 | 109 | df['color'] = df[[i+'_perc' for i in labels]].idxmax(axis=1) 110 | df['color'] = df['color'].str.replace('_perc', '') 111 | df['color'] = df['color'].map(colors) 112 | 113 | plt.close() 114 | fig, ax = plt.subplots() 115 | x_padding = int((max_lim-max(df.xcoord_tf))/2) 116 | y_padding = int((max_lim-max(df.ycoord_tf))/2) 117 | df['xcoord_tf'] += x_padding 118 | df['ycoord_tf'] += y_padding 119 | 120 | ax.scatter(df['xcoord_tf']*scaling_factor, 121 | df['ycoord_tf']*scaling_factor, 122 | s=17, 123 | c=df['color']) 124 | 125 | ax.set_xlim([0,max_lim*scaling_factor]) 126 | ax.set_ylim([0,max_lim*scaling_factor]) 127 | ax.set_facecolor("#F1EFF0") 128 | for p in ['top', 'right', 'bottom', 'left']: 129 | ax.spines[p].set_color('gray') #.set_visible(False) 130 | ax.spines[p].set_linewidth(1) 131 | ax.invert_yaxis() 132 | ax.set_aspect('equal') 133 | ax.tick_params(axis='both', which='both', length=0, labelsize=0) 134 | 135 | plt.savefig(src_path + folder + '/spatial_maps/' + slide_name + '.png', bbox_inches='tight', dpi=300) 136 | 137 | sum_df = all_corr_dfs[0] 138 | for i in range(1,len(all_corr_dfs)): 139 | sum_df += all_corr_dfs[i] 140 | sum_df = sum_df / len(all_corr_dfs) 141 | 142 | plt.close() 143 | plt.figure() 144 | kind = sum_df.columns.map(mapper) 145 | pl = sns.clustermap(sum_df, row_colors=kind, col_colors=kind, cmap='magma') 146 | pl.ax_row_dendrogram.set_visible(False) 147 | pl.ax_col_dendrogram.set_visible(False) 148 | plt.savefig(src_path + folder + '/corr_maps/total_clustered.png', bbox_inches='tight', dpi=300) 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /spatial_vis/get_emd.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | import matplotlib.cm as cm 6 | from matplotlib import rcParams 7 | 8 | import scanpy as sc 9 | import argparse 10 | 11 | import numpy as np 12 | import pandas as pd 13 | from scipy.stats import pearsonr 14 | import cv2 15 | import ot 16 | import gc 17 | from tqdm import tqdm 18 | 19 | from scipy.stats import percentileofscore 20 | 21 | def score2percentile(score, ref): 22 | if np.isnan(score): 23 | return score # deal with nans in visualization (set to black) 24 | percentile = percentileofscore(ref, score) 25 | return percentile 26 | 27 | def get_average(xcoord, ycoord, df, num_tiles): 28 | 29 | distances_x = np.power(df['x'] - xcoord, 2).values 30 | distances_y = np.power(df['y'] - ycoord, 2).values 31 | distances = np.sqrt(distances_x + distances_y) 32 | closest_samples = sorted(range(len(distances)), key = lambda sub: distances[sub])[:num_tiles] 33 | 34 | gene_expr_vals = [] 35 | for i in closest_samples: 36 | gene_expr_vals.append(df.iloc[i]['gene_expr']) 37 | 38 | return np.mean(gene_expr_vals) 39 | 40 | 41 | def median_filter(df, col, xcoord, ycoord, num_neighbors): 42 | window = df[ (df['xcoord_tf'] >= xcoord - num_neighbors) & 43 | (df['ycoord_tf'] >= ycoord - num_neighbors) & 44 | (df['xcoord_tf'] <= xcoord + num_neighbors) & 45 | (df['ycoord_tf'] <= ycoord + num_neighbors) ] 46 | 47 | full_window_size = (num_neighbors*2+1)**2 48 | if window.shape[0] > full_window_size/2: 49 | return np.median(window[col].values) 50 | 51 | return df[(df['xcoord_tf'] == xcoord) & (df['ycoord_tf'] == ycoord)][col].values[0] 52 | 53 | 54 | def img_to_sig(arr): 55 | """Convert a 2D array to a signature for cv2.EMD""" 56 | 57 | # cv2.EMD requires single-precision, floating-point input 58 | sig = np.empty((arr.size, 3), dtype=np.float32) 59 | count = 0 60 | for i in range(arr.shape[0]): 61 | for j in range(arr.shape[1]): 62 | sig[count] = np.array([arr[i,j], i, j]) 63 | count += 1 64 | return sig 65 | 66 | 67 | def calculate_emd(arr1, arr2, norm=False): 68 | assert arr1.shape == arr2.shape, "please provide consistent shapes" 69 | assert len(arr1.shape) == 2, "please give nxm matrix format" 70 | 71 | if (not np.any(arr1)) and (not np.any(arr2)): # if both are totally 0 then the EMD is 0 72 | return 0 73 | 74 | # if one of the two maps is totally zero and the other is not, the EMD is not defined 75 | # in that case we return NaN 76 | if not np.any(arr1): 77 | return np.nan 78 | if not np.any(arr2): 79 | return np.nan 80 | 81 | arr1 = arr1 / np.sum(arr1) 82 | arr2 = arr2 / np.sum(arr2) 83 | 84 | sig1 = img_to_sig(arr1) 85 | sig2 = img_to_sig(arr2) 86 | dist, _, _ = cv2.EMD(sig1, sig2, cv2.DIST_L2) 87 | 88 | if norm: 89 | dist = dist / np.sqrt(arr1.shape[0]*arr2.shape[0]) 90 | return dist 91 | 92 | 93 | def fill_arr(arr, x, y, val): 94 | arr[x,y] = val 95 | 96 | if __name__=='__main__': 97 | 98 | # get args 99 | parser = argparse.ArgumentParser(description='Getting features') 100 | parser.add_argument('--slide_nr', type=str, help='slide nr for which to run script') 101 | parser.add_argument('--pred_folder', type=str, help='folder with predictions to visualize') 102 | parser.add_argument('--save_folder', type=str, help='where to save results') 103 | parser.add_argument('--gene_names', type=str, help='name of genes to visualize (separated by comma) or path to npy array containing gene names') 104 | args = parser.parse_args() 105 | 106 | slide_nr = args.slide_nr 107 | preds_path = f'./visualizations/spatial_GBM_pred/{args.pred_folder}/' 108 | dest_path = f'./visualizations/comparisons/{args.save_folder}/' 109 | 110 | slide_name = 'HRI_'+str(slide_nr)+'_T.tif' 111 | print(slide_name) 112 | csv_path = preds_path + slide_name + '/stride-1.csv' 113 | 114 | gene_names = args.gene_names 115 | if '.npy' in gene_names: 116 | genes = np.load(gene_names, allow_pickle=True) 117 | else: 118 | genes = gene_names.split(",") 119 | 120 | dest_path += slide_name + '/' 121 | if not os.path.exists(dest_path): 122 | os.makedirs(dest_path) 123 | 124 | num_tiles = 4 # how many tiles in ground truth are equal to one tile in prediction (same for all of spatial gbm because prediction resolution is patch of 256x256 at 0.5um pp) 125 | 126 | correlations = {} 127 | pvals = {} 128 | sens_vals = {} 129 | emds = {} 130 | nr_gt_vals = {} 131 | 132 | # after converting the ground truth and prediction to 0-100 and applying median filtering to ground truth 133 | correlations_filt = {} 134 | pvals_filt = {} 135 | sens_vals_filt = {} 136 | emds_filt = {} 137 | nr_gt_vals_filt = {} 138 | 139 | rcParams['font.family'] = 'sans-serif' 140 | fig_resize = 1 141 | 142 | for i_, gene in tqdm(enumerate(genes)): 143 | 144 | try: 145 | 146 | # get ground truth data 147 | AnnData_dir = "./data/Spatial_Heiland/data/AnnDataObject/raw" 148 | adata = sc.read_h5ad(os.path.join(AnnData_dir, f'{slide_nr}_T.h5ad')) 149 | sc.pp.normalize_total(adata, inplace=True) 150 | sc.pp.log1p(adata) 151 | sc.pp.scale(adata) 152 | 153 | adata_subset = adata[:,gene] 154 | coords = adata_subset.obs[['x', 'y']].values 155 | gene_expr = np.asarray(adata_subset.X).flatten() 156 | 157 | df = pd.DataFrame(coords, columns=['x', 'y']) 158 | df['gene_expr'] = gene_expr 159 | df['x_tf'] = (df['x']-min(df['x'])).astype(int) # transform coordinates to regular grid 160 | df['y_tf'] = (df['y']-min(df['y'])).astype(int) 161 | 162 | # transform ground truth to same resolution 163 | df2 = pd.read_csv(csv_path) 164 | df2 = df2.dropna(axis=0, how='any') 165 | df2['ground_truth'] = df2.apply(lambda row: get_average(row['xcoord'], row['ycoord'], df, num_tiles=num_tiles), axis=1) 166 | df2 = df2.dropna(axis=0, how='any') 167 | 168 | # perform median filtering and transform to percentile 169 | # (med filtering only for ground truth, gene prediction is already smooth because of sliding window method) 170 | df2['ground_truth_filt'] = df2.apply(lambda row: median_filter(df2, 'ground_truth', row['xcoord_tf'], row['ycoord_tf'], 1), axis=1) 171 | ref = df2['ground_truth_filt'].values 172 | df2['ground_truth_filt'] = df2.apply(lambda row: score2percentile(row['ground_truth_filt'], ref), axis=1) 173 | 174 | ref2 = df2[gene].values 175 | df2[gene + '_filt'] = df2.apply(lambda row: score2percentile(row[gene], ref2), axis=1) 176 | 177 | for i, gt_col, gene_col in zip(range(2), ['ground_truth', 'ground_truth_filt'], [gene, gene + '_filt']): 178 | 179 | # get EMD 180 | max_x = max(df2.xcoord_tf) 181 | max_y = max(df2.ycoord_tf) 182 | arr0 = np.zeros((max_x+1, max_y+1)) 183 | df2.apply(lambda row: fill_arr(arr0, row['xcoord_tf'].astype(int),row['ycoord_tf'].astype(int), row[gene_col]),axis=1) 184 | arr1 = np.zeros((max_x+1, max_y+1)) 185 | df2.apply(lambda row: fill_arr(arr1, row['xcoord_tf'].astype(int),row['ycoord_tf'].astype(int), row[gt_col]),axis=1) 186 | 187 | arr0 = arr0 + np.abs(np.min(arr0)) 188 | arr1 = arr1 + np.abs(np.min(arr1)) 189 | emd = calculate_emd(arr0, arr1, norm=False) 190 | 191 | if i == 0: 192 | emds[gene] = emd 193 | else: 194 | emds_filt[gene] = emd 195 | 196 | ##### enough to only run this once 197 | # write the area, nr of tiles per slide to a file so normalization can be done afterwards 198 | if i_ == 0: 199 | filename = "./visualizations/spatial_GBM_pred/slide_info.txt" 200 | with open(filename, 'a') as file: 201 | file.write(f"{slide_name} \t {arr0.shape[0]*arr1.shape[0]} \t {df2.shape[0]} \n") 202 | 203 | # also write per slide, per gene, the number of unique values in the ground truth to file to detect any artefacts if needed 204 | nr_gt_vals[gene] = len(np.unique(df2['ground_truth'].values)) 205 | nr_gt_vals_filt[gene] = len(np.unique(df2['ground_truth_filt'].values)) 206 | 207 | except Exception as e: 208 | print(e) 209 | print(gene) 210 | 211 | gc.collect() 212 | 213 | emd_df = pd.DataFrame(emds.items(), columns=['gene', 'emd']) 214 | nr_gt_df = pd.DataFrame(nr_gt_vals.items(), columns=['gene', 'nr_gt_vals']) 215 | 216 | emd_df_filt = pd.DataFrame(emds_filt.items(), columns=['gene', 'emd_filt']) 217 | nr_gt_df_filt = pd.DataFrame(nr_gt_vals_filt.items(), columns=['gene', 'nr_gt_vals_filt']) 218 | 219 | total_df = pd.merge(pd.merge(pd.merge(pd.merge( emd_df, on='gene'), 220 | nr_gt_df, on='gene'), 221 | emd_df_filt, on='gene'), 222 | nr_gt_df_filt, on='gene') 223 | 224 | total_df.to_csv(dest_path + '/' + 'metrics.csv') 225 | print('Done') 226 | 227 | 228 | 229 | 230 | -------------------------------------------------------------------------------- /spatial_vis/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import json 5 | import numpy as np 6 | import pandas as pd 7 | from einops import rearrange 8 | 9 | from scipy.ndimage.morphology import binary_dilation 10 | import openslide 11 | from PIL import Image 12 | import timm 13 | import torch 14 | from torchvision import transforms 15 | 16 | from he2rna import HE2RNA 17 | from vit import ViT 18 | from resnet import resnet50 19 | from tformer_lin import ViS 20 | 21 | BACKGROUND_THRESHOLD = .5 22 | 23 | 24 | def read_pickle(path): 25 | objects = [] 26 | with (open(path, "rb")) as openfile: 27 | while True: 28 | try: 29 | objects.append(pickle.load(openfile)) 30 | except EOFError: 31 | break 32 | return objects 33 | 34 | 35 | def sliding_window_method(df, patch_size_resized, 36 | feat_model, model, inds_gene_of_interest, stride, 37 | feat_model_type, feat_dim, model_type='vis', device='cpu'): 38 | 39 | max_x = max(df['xcoord_tf']) 40 | max_y = max(df['ycoord_tf']) 41 | 42 | preds = {} # {key:value} where key is a gene index and value is a new dict that contains the predictions per tile for that gene 43 | for ind_gene in inds_gene_of_interest: 44 | preds[ind_gene] = {} 45 | 46 | for x in tqdm(range(0, max_x, stride)): 47 | for y in range(0, max_y, stride): 48 | 49 | window = df[((df['xcoord_tf']>=x) & (df['xcoord_tf']<(x+10))) & 50 | ((df['ycoord_tf']>=y) & (df['ycoord_tf']<(y+10)))] 51 | 52 | if window.shape[0] > ((10*10)/2): 53 | # get the patches 54 | features_all = [] 55 | for ind in window.index: 56 | col = df.iloc[ind]['xcoord'] 57 | row = df.iloc[ind]['ycoord'] 58 | if hasattr(slide, 'read_region'): 59 | patch = slide.read_region((col, row), 0, (patch_size_resized, patch_size_resized)).convert('RGB') 60 | else: 61 | patch = Image.fromarray(np.asarray(slide)[col:col+patch_size_resized,row:row+patch_size_resized]) 62 | patch_tf = transforms_(patch).unsqueeze(0).to(device) 63 | 64 | with torch.no_grad(): 65 | if feat_model_type == 'resnet': 66 | features = feat_model.forward_extract(patch_tf) 67 | else: 68 | features = feat_model(patch_tf) 69 | features_all.append(features) 70 | 71 | # if window contains less than 10x10 tiles, pad with 0 72 | features_all = torch.cat(features_all) 73 | if features_all.shape[0] < 100: 74 | padding = torch.cat([torch.zeros(1, feat_dim) for _ in range(100-features_all.shape[0])]).to(device) 75 | features_all = torch.cat([features_all, padding]) 76 | 77 | # get predictions 78 | with torch.no_grad(): 79 | if model_type == 'he2rna': 80 | features_all = torch.unsqueeze(features_all, dim=0) 81 | features_all = rearrange(features_all, 'b c f -> b f c') 82 | model_predictions = model(features_all) 83 | 84 | predictions = model_predictions.detach().cpu().numpy()[0] 85 | 86 | # add predictions to dict (same for all tiles in window) 87 | for ind_gene in inds_gene_of_interest: 88 | for _, key in enumerate(window.index): 89 | if stride == 10: 90 | preds[ind_gene][key] = predictions[ind_gene] 91 | else: 92 | if key not in preds[ind_gene].keys(): 93 | preds[ind_gene][key] = [predictions[ind_gene]] 94 | else: 95 | preds[ind_gene][key].append(predictions[ind_gene]) 96 | 97 | if stride < 10: 98 | for ind_gene in inds_gene_of_interest: 99 | for key in preds[ind_gene].keys(): 100 | preds[ind_gene][key] = np.mean(preds[ind_gene][key]) 101 | 102 | return preds 103 | 104 | if __name__=='__main__': 105 | 106 | print('Start running visualize script') 107 | 108 | ############################## get args 109 | parser = argparse.ArgumentParser(description='Getting features') 110 | parser.add_argument('--study', type=str, help='cancer study abbreviation, lowercase') 111 | parser.add_argument('--project', type=str, help='name of project (spatial_GBM_pred, TCGA-GBM, PESO, Breast-ST)') 112 | parser.add_argument('--gene_names', type=str, help='name of genes to visualize, separated by commas. if you want all the predicted genes, pass "all" ') 113 | parser.add_argument('--wsi_file_name', type=str, help='wsi filename') 114 | parser.add_argument('--save_folder', type=str, help='destination folder') 115 | parser.add_argument('--model_type', type=str, help='model to use: "he2rna", "vit" or "vis"') 116 | parser.add_argument('--feat_type', type=str, help='"resnet" or "uni"') 117 | parser.add_argument('--folds', type=str, help='folds to use in prediction split by comma', default='0,1,2,3,4') 118 | args = parser.parse_args() 119 | 120 | ############################## general 121 | study = args.study 122 | assert args.feat_type in ['resnet', 'uni'] 123 | assert args.model_type in ['vit', 'vis', 'he2rna'] 124 | 125 | ############################## get model 126 | checkpoint = f'{args.model_type}_{args.feat_type}/{study}/' 127 | obj = read_pickle(checkpoint + 'test_results.pkl')[0] 128 | gene_ids = obj['genes'] 129 | 130 | ############################## prepare data 131 | stride = 1 132 | patch_size = 256 # at 20x (0.5um pp) 133 | wsi_file_name = args.wsi_file_name 134 | project = args.project 135 | save_path = f'./visualizations/{project}/{args.save_folder}/{args.wsi_file_name}/' 136 | if not os.path.exists(save_path): 137 | os.makedirs(save_path) 138 | 139 | if args.gene_names != 'all': 140 | if '.npy' in args.gene_names: 141 | gene_names = np.load(args.gene_names,allow_pickle=True) 142 | else: 143 | gene_names = args.gene_names.split(",") 144 | else: 145 | gene_names = gene_ids 146 | 147 | # prepare and load WSI 148 | if 'TCGA' in wsi_file_name: 149 | slide_path = './TCGA/'+project+'/' 150 | mask_path = './TCGA/'+project+'_Masks/' 151 | mask = np.load(mask_path+wsi_file_name.replace('.svs', '')+'/'+'mask.npy') 152 | manual_resize = None # nr of um/px can be read from slide properties 153 | elif project == 'spatial_GBM_pred': 154 | slide_path = f'./Spatial_GBM/pyramid/' 155 | mask_path = './Spatial_GBM/masks/' 156 | px_df = pd.read_csv('./Spatial_Heiland/data/classify/spot_diameter.csv') 157 | mask = np.load(mask_path + wsi_file_name.replace('.tif', '.npy')) 158 | diam = px_df[px_df['slide_id']==wsi_file_name.split('_')[1] + '_T']['pixel_diameter'].values[0] 159 | um_px = 55/diam # um/px for the WSI 160 | manual_resize = 0.5/um_px 161 | elif project == 'Breast-ST': 162 | slide_path = './Gen-Pred/Breast-ST/wsis/' 163 | mask_path = './Gen-Pred/Breast-ST/masks/' 164 | mask = np.load(mask_path+wsi_file_name.replace('.tif', '.npy')) 165 | metadata = json.load(open(f"./Gen-Pred/Breast-ST/metadata/{wsi_file_name.replace('.tif','.json')}")) 166 | mag = eval(metadata['magnification'].replace('x','')) 167 | manual_resize = mag/20.0 168 | else: 169 | print('please provide correct file name format (containing "TCGA") or correct project id ("spatial_GBM_pred" or "Breast-ST")') 170 | exit() 171 | 172 | # load wsi and calculate patch size in original image (coordinates are at level 0 for openslide) and in mask 173 | slide = openslide.OpenSlide(slide_path + wsi_file_name) 174 | downsample_factor = int(slide.dimensions[0]/mask.shape[0]) # mask downsample factor 175 | slide_dim0, slide_dim1 = slide.dimensions[0], slide.dimensions[1] 176 | 177 | if manual_resize == None: 178 | resize_factor = float(slide.properties.get('aperio.AppMag',20)) / 20.0 179 | else: 180 | resize_factor = manual_resize 181 | 182 | patch_size_resized = int(resize_factor * patch_size) 183 | patch_size_in_mask = int(patch_size_resized/downsample_factor) 184 | 185 | # get valid coordinates (that have tissue) 186 | valid_idx = [] 187 | mask = (np.transpose(mask, axes=[1,0]))*1 188 | for col in range(0, slide_dim0-patch_size_resized, patch_size_resized): #slide.dimensions is (width, height) 189 | for row in range(0, slide_dim1-patch_size_resized, patch_size_resized): 190 | 191 | row_downs = int(row/downsample_factor) 192 | col_downs = int(col/downsample_factor) 193 | 194 | patch_in_mask = mask[row_downs:row_downs+patch_size_in_mask,col_downs:col_downs+patch_size_in_mask] 195 | patch_in_mask = binary_dilation(patch_in_mask, iterations=3) 196 | 197 | if patch_in_mask.sum() >= (BACKGROUND_THRESHOLD * patch_in_mask.size): 198 | # keep patch 199 | valid_idx.append((col, row)) 200 | 201 | # dataframe which contains coordinates of valid patches 202 | df = pd.DataFrame(valid_idx, columns=['xcoord', 'ycoord']) 203 | # rescale coordinates to (0,0) and step size 1 204 | df['xcoord_tf'] = ((df['xcoord']-min(df['xcoord']))/patch_size_resized).astype(int) 205 | df['ycoord_tf'] = ((df['ycoord']-min(df['ycoord']))/patch_size_resized).astype(int) 206 | 207 | print('Got dataframe of valid tiles') 208 | 209 | ############################## feature extractor model 210 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 211 | if args.feat_type == 'resnet': 212 | transforms_ = transforms.Compose([ 213 | transforms.Resize((256,265)), 214 | transforms.ToTensor(), 215 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 216 | ]) 217 | feat_model = resnet50(pretrained=True).to(device) 218 | feat_model.eval() 219 | else: 220 | feat_model = timm.create_model("vit_large_patch16_224", img_size=224, 221 | patch_size=16, init_values=1e-5, 222 | num_classes=0, dynamic_img_size=True) 223 | local_dir = "./Gen-Pred/src/spatial_vis/uni_ckpt/" 224 | feat_model.load_state_dict(torch.load(os.path.join(local_dir, 225 | "pytorch_model.bin"), map_location=device), strict=True) 226 | transforms_ = transforms.Compose([ 227 | transforms.Resize(224), 228 | transforms.ToTensor(), 229 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 230 | ]) 231 | feat_model = feat_model.to(device) 232 | feat_model.eval() 233 | 234 | ############################## get preds 235 | res_df = df.copy(deep=True) 236 | folds = [int(i) for i in args.folds.split(',')] 237 | num_folds = len(folds) 238 | 239 | for fold in folds: 240 | 241 | fold_ckpt = checkpoint + 'model_best_' + str(fold) + '.pt' 242 | if (fold == 0) and ((args.model_type == 'vit') or (args.model_type == 'vis')): 243 | fold_ckpt = fold_ckpt.replace('_0','') 244 | 245 | input_dim = 2048 if args.feat_type == 'resnet' else 1024 246 | if args.model_type == 'vit': 247 | model = ViT(num_outputs=len(gene_ids), 248 | dim=input_dim, depth=6, heads=16, mlp_dim=2048, dim_head = 64) 249 | model.load_state_dict(torch.load(fold_ckpt, map_location=torch.device(device))) 250 | elif args.model_type == 'he2rna': 251 | model = HE2RNA(input_dim=input_dim, layers=[256,256], 252 | ks=[1,2,5,10,20,50,100], 253 | output_dim=len(gene_ids), device=device) 254 | fold_ckpt = fold_ckpt.replace('best_','') 255 | model.load_state_dict(torch.load(fold_ckpt, map_location=torch.device(device)).state_dict()) 256 | elif args.model_type == 'vis': 257 | model = ViS(num_outputs=len(gene_ids), 258 | input_dim=input_dim, 259 | depth=6, nheads=16, 260 | dimensions_f=64, dimensions_c=64, dimensions_s=64, device=device) 261 | model.load_state_dict(torch.load(fold_ckpt, map_location=torch.device(device))) 262 | 263 | model = model.to(device) 264 | model.eval() 265 | 266 | # get indices of requested genes 267 | inds_gene_of_interest = [] 268 | for gene_name in gene_names: 269 | try: 270 | inds_gene_of_interest.append(gene_ids.index(gene_name)) 271 | except: 272 | print('gene not in predicted values '+gene_name) 273 | 274 | # get visualization 275 | preds = sliding_window_method(df=df, patch_size_resized=patch_size_resized, 276 | feat_model=feat_model, model=model, 277 | inds_gene_of_interest=inds_gene_of_interest, stride=stride, 278 | feat_model_type=args.feat_type, feat_dim=input_dim, model_type=args.model_type, device=device) 279 | 280 | for ind_gene in inds_gene_of_interest: 281 | res_df[gene_ids[ind_gene] + '_' + str(fold)] = res_df.index.map(preds[ind_gene]) 282 | 283 | for ind_gene in inds_gene_of_interest: 284 | res_df[gene_ids[ind_gene]] = res_df[[gene_ids[ind_gene] + '_' + str(i) for i in folds]].mean(axis=1) 285 | 286 | save_name = save_path + 'stride-' + str(stride) + '.csv' 287 | res_df.to_csv(save_name) 288 | 289 | print('Done') 290 | 291 | 292 | 293 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/__init__.py -------------------------------------------------------------------------------- /src/folds/test-blca-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-blca-0.npy -------------------------------------------------------------------------------- /src/folds/test-blca-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-blca-1.npy -------------------------------------------------------------------------------- /src/folds/test-blca-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-blca-2.npy -------------------------------------------------------------------------------- /src/folds/test-blca-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-blca-3.npy -------------------------------------------------------------------------------- /src/folds/test-blca-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-blca-4.npy -------------------------------------------------------------------------------- /src/folds/test-brca-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-brca-0.npy -------------------------------------------------------------------------------- /src/folds/test-brca-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-brca-1.npy -------------------------------------------------------------------------------- /src/folds/test-brca-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-brca-2.npy -------------------------------------------------------------------------------- /src/folds/test-brca-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-brca-3.npy -------------------------------------------------------------------------------- /src/folds/test-brca-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-brca-4.npy -------------------------------------------------------------------------------- /src/folds/test-coad-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-coad-0.npy -------------------------------------------------------------------------------- /src/folds/test-coad-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-coad-1.npy -------------------------------------------------------------------------------- /src/folds/test-coad-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-coad-2.npy -------------------------------------------------------------------------------- /src/folds/test-coad-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-coad-3.npy -------------------------------------------------------------------------------- /src/folds/test-coad-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-coad-4.npy -------------------------------------------------------------------------------- /src/folds/test-gbm-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-gbm-0.npy -------------------------------------------------------------------------------- /src/folds/test-gbm-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-gbm-1.npy -------------------------------------------------------------------------------- /src/folds/test-gbm-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-gbm-2.npy -------------------------------------------------------------------------------- /src/folds/test-gbm-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-gbm-3.npy -------------------------------------------------------------------------------- /src/folds/test-gbm-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-gbm-4.npy -------------------------------------------------------------------------------- /src/folds/test-hnsc-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-hnsc-0.npy -------------------------------------------------------------------------------- /src/folds/test-hnsc-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-hnsc-1.npy -------------------------------------------------------------------------------- /src/folds/test-hnsc-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-hnsc-2.npy -------------------------------------------------------------------------------- /src/folds/test-hnsc-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-hnsc-3.npy -------------------------------------------------------------------------------- /src/folds/test-hnsc-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-hnsc-4.npy -------------------------------------------------------------------------------- /src/folds/test-kirc-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirc-0.npy -------------------------------------------------------------------------------- /src/folds/test-kirc-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirc-1.npy -------------------------------------------------------------------------------- /src/folds/test-kirc-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirc-2.npy -------------------------------------------------------------------------------- /src/folds/test-kirc-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirc-3.npy -------------------------------------------------------------------------------- /src/folds/test-kirc-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirc-4.npy -------------------------------------------------------------------------------- /src/folds/test-kirp-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirp-0.npy -------------------------------------------------------------------------------- /src/folds/test-kirp-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirp-1.npy -------------------------------------------------------------------------------- /src/folds/test-kirp-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirp-2.npy -------------------------------------------------------------------------------- /src/folds/test-kirp-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirp-3.npy -------------------------------------------------------------------------------- /src/folds/test-kirp-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-kirp-4.npy -------------------------------------------------------------------------------- /src/folds/test-lihc-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lihc-0.npy -------------------------------------------------------------------------------- /src/folds/test-lihc-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lihc-1.npy -------------------------------------------------------------------------------- /src/folds/test-lihc-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lihc-2.npy -------------------------------------------------------------------------------- /src/folds/test-lihc-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lihc-3.npy -------------------------------------------------------------------------------- /src/folds/test-lihc-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lihc-4.npy -------------------------------------------------------------------------------- /src/folds/test-luad-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-luad-0.npy -------------------------------------------------------------------------------- /src/folds/test-luad-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-luad-1.npy -------------------------------------------------------------------------------- /src/folds/test-luad-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-luad-2.npy -------------------------------------------------------------------------------- /src/folds/test-luad-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-luad-3.npy -------------------------------------------------------------------------------- /src/folds/test-luad-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-luad-4.npy -------------------------------------------------------------------------------- /src/folds/test-lusc-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lusc-0.npy -------------------------------------------------------------------------------- /src/folds/test-lusc-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lusc-1.npy -------------------------------------------------------------------------------- /src/folds/test-lusc-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lusc-2.npy -------------------------------------------------------------------------------- /src/folds/test-lusc-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lusc-3.npy -------------------------------------------------------------------------------- /src/folds/test-lusc-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-lusc-4.npy -------------------------------------------------------------------------------- /src/folds/test-paad-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-paad-0.npy -------------------------------------------------------------------------------- /src/folds/test-paad-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-paad-1.npy -------------------------------------------------------------------------------- /src/folds/test-paad-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-paad-2.npy -------------------------------------------------------------------------------- /src/folds/test-paad-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-paad-3.npy -------------------------------------------------------------------------------- /src/folds/test-paad-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-paad-4.npy -------------------------------------------------------------------------------- /src/folds/test-prad-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-prad-0.npy -------------------------------------------------------------------------------- /src/folds/test-prad-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-prad-1.npy -------------------------------------------------------------------------------- /src/folds/test-prad-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-prad-2.npy -------------------------------------------------------------------------------- /src/folds/test-prad-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-prad-3.npy -------------------------------------------------------------------------------- /src/folds/test-prad-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-prad-4.npy -------------------------------------------------------------------------------- /src/folds/test-skcm-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-skcm-0.npy -------------------------------------------------------------------------------- /src/folds/test-skcm-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-skcm-1.npy -------------------------------------------------------------------------------- /src/folds/test-skcm-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-skcm-2.npy -------------------------------------------------------------------------------- /src/folds/test-skcm-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-skcm-3.npy -------------------------------------------------------------------------------- /src/folds/test-skcm-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-skcm-4.npy -------------------------------------------------------------------------------- /src/folds/test-stad-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-stad-0.npy -------------------------------------------------------------------------------- /src/folds/test-stad-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-stad-1.npy -------------------------------------------------------------------------------- /src/folds/test-stad-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-stad-2.npy -------------------------------------------------------------------------------- /src/folds/test-stad-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-stad-3.npy -------------------------------------------------------------------------------- /src/folds/test-stad-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-stad-4.npy -------------------------------------------------------------------------------- /src/folds/test-thca-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-thca-0.npy -------------------------------------------------------------------------------- /src/folds/test-thca-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-thca-1.npy -------------------------------------------------------------------------------- /src/folds/test-thca-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-thca-2.npy -------------------------------------------------------------------------------- /src/folds/test-thca-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-thca-3.npy -------------------------------------------------------------------------------- /src/folds/test-thca-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-thca-4.npy -------------------------------------------------------------------------------- /src/folds/test-ucec-0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-ucec-0.npy -------------------------------------------------------------------------------- /src/folds/test-ucec-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-ucec-1.npy -------------------------------------------------------------------------------- /src/folds/test-ucec-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-ucec-2.npy -------------------------------------------------------------------------------- /src/folds/test-ucec-3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-ucec-3.npy -------------------------------------------------------------------------------- /src/folds/test-ucec-4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gevaertlab/sequoia-pub/daafd609c941e64a5726509ee3dacb0268b56f96/src/folds/test-ucec-4.npy -------------------------------------------------------------------------------- /src/he2rna.py: -------------------------------------------------------------------------------- 1 | """ 2 | HE2RNA: definition of the algorithm to generate a model for gene expression prediction 3 | Copyright (C) 2020 Owkin Inc. 4 | This program is free software: you can redistribute it and/or modify 5 | it under the terms of the GNU General Public License as published by 6 | the Free Software Foundation, either version 3 of the License, or 7 | (at your option) any later version. 8 | This program is distributed in the hope that it will be useful, 9 | but WITHOUT ANY WARRANTY; without even the implied warranty of 10 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 | GNU General Public License for more details. 12 | You should have received a copy of the GNU General Public License 13 | along with this program. If not, see . 14 | """ 15 | 16 | from tkinter.messagebox import NO 17 | import numpy as np 18 | import pandas as pd 19 | import torch 20 | import time 21 | import os 22 | from torch import nn 23 | from torch.utils.data import DataLoader 24 | #from tensorboardX import SummaryWriter 25 | from tqdm import tqdm 26 | import datetime 27 | import wandb 28 | import argparse 29 | import json 30 | from sklearn.model_selection import train_test_split 31 | #from accelerate import Accelerator 32 | from einops import rearrange 33 | import pickle 34 | import h5py 35 | 36 | from huggingface_hub import PyTorchModelHubMixin 37 | 38 | from src.read_data import SuperTileRNADataset 39 | from src.utils import patient_split, patient_kfold, custom_collate_fn, filter_no_features 40 | 41 | 42 | class HE2RNA(nn.Module, PyTorchModelHubMixin): 43 | """Model that generates one score per tile and per predicted gene. 44 | Args 45 | output_dim (int): Output dimension, must match the number of genes to 46 | predict. 47 | layers (list): List of the layers' dimensions 48 | nonlin (torch.nn.modules.activation) 49 | ks (list): list of numbers of highest-scored tiles to keep in each 50 | channel. 51 | dropout (float) 52 | device (str): 'cpu' or 'cuda' 53 | mode (str): 'binary' or 'regression' 54 | """ 55 | def __init__(self, input_dim, output_dim, 56 | layers=[1], nonlin=nn.ReLU(), ks=[10], 57 | dropout=0.5, device='cpu', 58 | bias_init=None, **kwargs): 59 | super(HE2RNA, self).__init__() 60 | 61 | self.input_dim = input_dim 62 | self.output_dim = output_dim 63 | 64 | layers = [input_dim] + layers + [output_dim] 65 | self.layers = [] 66 | for i in range(len(layers) - 1): 67 | layer = nn.Conv1d(in_channels=layers[i], 68 | out_channels=layers[i+1], 69 | kernel_size=1, 70 | stride=1, 71 | bias=True) 72 | setattr(self, 'conv' + str(i), layer) 73 | self.layers.append(layer) 74 | if bias_init is not None: 75 | self.layers[-1].bias = bias_init 76 | self.ks = np.array(ks) 77 | 78 | self.nonlin = nonlin 79 | self.do = nn.Dropout(dropout) 80 | self.device = device 81 | self.to(self.device) 82 | 83 | def forward(self, x): 84 | if self.training: 85 | k = int(np.random.choice(self.ks)) 86 | return self.forward_fixed_k(x, k) 87 | else: 88 | pred = 0 89 | for k in self.ks: 90 | pred += self.forward_fixed_k(x, int(k)) / len(self.ks) 91 | return pred 92 | 93 | def forward_fixed_k(self, x, k): 94 | mask, _ = torch.max(x, dim=1, keepdim=True) 95 | mask = (mask > 0).float() 96 | x = self.conv(x) * mask 97 | t, _ = torch.topk(x, k, dim=2, largest=True, sorted=True) 98 | x = torch.sum(t * mask[:, :, :k], dim=2) / torch.sum(mask[:, :, :k], dim=2) 99 | return x 100 | 101 | def conv(self, x): 102 | x = x[:, x.shape[1] - self.input_dim:] 103 | for i in range(len(self.layers) - 1): 104 | x = self.do(self.nonlin(self.layers[i](x))) 105 | x = self.layers[-1](x) 106 | return x 107 | 108 | def training_epoch(model, dataloader, optimizer): 109 | """Train model for one epoch. 110 | """ 111 | model.train() 112 | loss_fn = nn.MSELoss() 113 | train_loss = [] 114 | for x, y, _, _ in tqdm(dataloader): 115 | x = x.float().to(model.device) 116 | # rearranging dimenions (b, c, f) to (b, c*f) 117 | x = rearrange(x, 'b c f -> b f c') 118 | y = y.float().to(model.device) 119 | pred = model(x) 120 | loss = loss_fn(pred, y) 121 | train_loss += [loss.detach().cpu().numpy()] 122 | optimizer.zero_grad() 123 | loss.backward() 124 | optimizer.step() 125 | train_loss = np.mean(train_loss) 126 | return train_loss 127 | 128 | ''' 129 | def compute_correlations(labels, preds, projects): 130 | metrics = [] 131 | for project in np.unique(projects): 132 | for i in range(labels.shape[1]): 133 | y_true = labels[projects == project, i] 134 | if len(np.unique(y_true)) > 1: 135 | y_prob = preds[projects == project, i] 136 | metrics.append(np.corrcoef(y_true, y_prob)[0, 1]) 137 | metrics = np.asarray(metrics) 138 | return np.mean(metrics) 139 | ''' 140 | def compute_correlations(labels, preds): 141 | metrics = [] 142 | for i in range(labels.shape[1]): 143 | y_true = labels[:, i] 144 | if len(np.unique(y_true)) > 1: 145 | y_prob = preds[:, i] 146 | metrics.append(np.corrcoef(y_true, y_prob)[0, 1]) 147 | metrics = np.asarray(metrics) 148 | metrics = metrics[~np.isnan(metrics)] 149 | return np.mean(metrics) 150 | 151 | def evaluate(model, dataloader): 152 | """Evaluate the model on the validation set and return loss and metrics. 153 | """ 154 | model.eval() 155 | loss_fn = nn.MSELoss() 156 | valid_loss = [] 157 | preds = [] 158 | labels = [] 159 | for x, y, _, _ in dataloader: 160 | # rearranging dimenions (b, c, f) to (b, c*f) 161 | x = x.float().to(model.device) 162 | x = rearrange(x, 'b c f -> b f c') 163 | pred = model(x) 164 | labels += [y] 165 | loss = loss_fn(pred, y.float().to(model.device)) 166 | valid_loss += [loss.detach().cpu().numpy()] 167 | pred = nn.ReLU()(pred) 168 | preds += [pred.detach().cpu().numpy()] 169 | valid_loss = np.mean(valid_loss) 170 | preds = np.concatenate(preds) 171 | labels = np.concatenate(labels) 172 | metrics = compute_correlations(labels, preds) 173 | return valid_loss, metrics 174 | 175 | def he2rna_predict(model, dataloader): 176 | """Perform prediction on the test set. 177 | """ 178 | model.eval() 179 | preds = [] 180 | wsis = [] 181 | projs = [] 182 | labels = [] 183 | for x, y, wsi_file_name, tcga_project in dataloader: 184 | x = x.float().to(model.device) 185 | wsis.append(wsi_file_name) 186 | projs.append(tcga_project) 187 | # rearranging dimenions (b, c, f) to (b, c*f) 188 | x = rearrange(x, 'b c f -> b f c') 189 | pred = model(x) 190 | pred = nn.ReLU()(pred) 191 | preds += [pred.detach().cpu().numpy()] 192 | labels += [y] 193 | preds = np.concatenate((preds), axis=0) 194 | wsis = np.concatenate((wsis), axis=0) 195 | projs = np.concatenate((projs), axis=0) 196 | labels = np.concatenate((labels), axis=0) 197 | return preds, labels, wsis, projs 198 | 199 | # def predict(model, dataloader): 200 | # """Perform prediction on the test set. 201 | # """ 202 | # model.eval() 203 | # labels = [] 204 | # preds = [] 205 | # for x, y, _ , _ in dataloader: 206 | # # rearranging dimenions (b, c, f) to (b, c*f) 207 | # x = x.float().to(model.device) 208 | # x = rearrange(x, 'b c f -> b f c') 209 | # pred = model(x) 210 | # labels += [y] 211 | # pred = nn.ReLU()(pred) 212 | # preds += [pred.detach().cpu().numpy()] 213 | # preds = np.concatenate(preds) 214 | # labels = np.concatenate(labels) 215 | # return preds, labels 216 | 217 | def fit(model, 218 | lr, 219 | train_loader, 220 | valid_loader, 221 | test_loader, 222 | params={}, 223 | fold=None, 224 | optimizer=None, 225 | path=None): 226 | """Fit the model and make prediction on evaluation set. 227 | Args: 228 | model (nn.Module) 229 | train_set (torch.utils.data.Dataset) 230 | valid_set (torch.utils.data.Dataset) 231 | params (dict): Dictionary for specifying training parameters. 232 | keys are 'max_epochs' (int, default=200), 'patience' (int, 233 | default=20) and 'batch_size' (int, default=16). 234 | optimizer (torch.optim.Optimizer): Optimizer for training the model 235 | test_set (None or torch.utils.data.Dataset): If None, return 236 | predictions on the validation set. 237 | path (str): Path to the folder where th model will be saved. 238 | logdir (str): Path for TensoboardX. 239 | """ 240 | 241 | if path is not None and not os.path.exists(path): 242 | os.mkdir(path) 243 | 244 | default_params = { 245 | 'max_epochs': 200, 246 | 'patience': 100} 247 | default_params.update(params) 248 | patience = default_params['patience'] 249 | max_epochs = default_params['max_epochs'] 250 | 251 | if optimizer is None: 252 | optimizer = torch.optim.Adam(list(model.parameters()), lr=lr, weight_decay=0.) 253 | 254 | metrics = 'correlations' 255 | epoch_since_best = 0 256 | start_time = time.time() 257 | 258 | if valid_loader != None: 259 | valid_loss, best = evaluate(model, valid_loader) 260 | print('{}: {:.3f}'.format(metrics, best)) 261 | 262 | if np.isnan(best): 263 | best = 0 264 | else: 265 | best = 0 266 | 267 | name = 'model' 268 | if fold != None: 269 | name = name + '_' + str(fold) 270 | 271 | try: 272 | for e in range(max_epochs): 273 | 274 | epoch_since_best += 1 275 | 276 | train_loss = training_epoch(model, train_loader, optimizer) 277 | dic_loss = {'train_loss': train_loss} 278 | 279 | print('Epoch {}/{} - {:.2f}s'.format(e + 1, max_epochs, time.time() - start_time)) 280 | start_time = time.time() 281 | 282 | if valid_loader != None: 283 | valid_loss, scores = evaluate( model, valid_loader) 284 | dic_loss['valid_loss'] = valid_loss 285 | score = np.mean(scores) 286 | 287 | if args.log: 288 | wandb.log({'epoch': e, 'score '+str(fold): score}) 289 | wandb.log({'epoch': e, 'valid loss fold '+str(fold): valid_loss}) 290 | wandb.log({'epoch': e, 'train loss fold '+str(fold): train_loss}) 291 | 292 | print('loss: {:.4f}, val loss: {:.4f}'.format(train_loss, valid_loss)) 293 | print('{}: {:.3f}'.format(metrics, score)) 294 | 295 | criterion = (score > best) 296 | 297 | if criterion: 298 | epoch_since_best = 0 299 | best = score 300 | if path is not None: 301 | torch.save(model, os.path.join(path, name + '.pt')) 302 | 303 | if epoch_since_best == patience: 304 | print('Early stopping at epoch {}'.format(e + 1)) 305 | break 306 | 307 | except KeyboardInterrupt: 308 | pass 309 | 310 | if path is not None and os.path.exists(os.path.join(path, name + '.pt')): 311 | model = torch.load(os.path.join(path, name + '.pt')) 312 | 313 | elif path is not None: 314 | torch.save(model, os.path.join(path, name + '.pt')) 315 | 316 | if (test_loader != None): 317 | preds_test, labels_test, wsis, projs = he2rna_predict(model, test_loader) 318 | return preds_test, labels_test, wsis, projs 319 | 320 | return model 321 | 322 | 323 | if __name__ == '__main__': 324 | parser = argparse.ArgumentParser(description='Getting features') 325 | parser.add_argument('--path_csv', type=str, help='path to csv file with gene expression info') 326 | parser.add_argument('--feature_path', type=str, default="features/", help='path to resnet/uni and clustered features') 327 | parser.add_argument('--checkpoint', type=str, help='pretrained model path') 328 | parser.add_argument('--change_num_genes', help="whether finetuning from a model trained on different number of genes", action="store_true") 329 | parser.add_argument('--num_genes', type=int, help='number of genes in output of pretrained model') 330 | parser.add_argument('--seed', type=int, default=99, help='Seed for random generation') 331 | parser.add_argument('--log', type=int, default=1, help='Whether to do the log with wandb') 332 | parser.add_argument('--k', type=int, default=5, help='Number of splits') 333 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 334 | parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') 335 | parser.add_argument('--num_workers', type=int, default=0, help='num workers dataloader') 336 | parser.add_argument('--tcga_projects', help="the tcga_projects we want to use", default=None, type=str, nargs='*') 337 | parser.add_argument('--exp_name', type=str, default="exp", help='Experiment name') 338 | parser.add_argument('--subfolder', type=str, default="", help='subfolder where result will be saved') 339 | parser.add_argument('--destfolder', type=str, default="", help='destination folder') 340 | 341 | args = parser.parse_args() 342 | 343 | np.random.seed(args.seed) 344 | torch.manual_seed(args.seed) 345 | 346 | save_dir = os.path.join(args.destfolder, args.subfolder, args.exp_name) 347 | if not os.path.exists(save_dir): 348 | os.makedirs(save_dir) 349 | 350 | experiment_name = args.exp_name 351 | 352 | if args.log: 353 | run = wandb.init(project="sequoia", entity='entity_name', config=args, name=experiment_name) 354 | 355 | path_csv = args.path_csv 356 | 357 | df = pd.read_csv(path_csv) 358 | if args.tcga_projects: 359 | df = df[df['tcga_project'].isin(args.tcga_projects)] 360 | 361 | ############################################## data prep ############################################## 362 | # filter out WSIs for which we don't have features 363 | df = filter_no_features(df, args.feature_path, 'cluster_features') 364 | 365 | ############################################## model training ##############################################v 366 | 367 | train_idxs, val_idxs, test_idxs = patient_kfold(df, n_splits=args.k) 368 | test_results_splits = {} 369 | i = 0 370 | for train_idx, val_idx, test_idx in zip(train_idxs, val_idxs, test_idxs): 371 | train_df = df.iloc[train_idx] 372 | val_df = df.iloc[val_idx] 373 | test_df = df.iloc[test_idx] 374 | 375 | train_dataset = SuperTileRNADataset(train_df, args.feature_path) 376 | val_dataset = SuperTileRNADataset(val_df, args.feature_path) 377 | test_dataset = SuperTileRNADataset(test_df, args.feature_path) 378 | 379 | # SET num_workers TO 0 WHEN WORKING WITH hdf5 FILES 380 | train_loader = DataLoader( 381 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=custom_collate_fn) 382 | 383 | valid_loader = DataLoader( 384 | val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=custom_collate_fn) 385 | 386 | test_loader = DataLoader( 387 | test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, collate_fn=custom_collate_fn) 388 | 389 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 390 | 391 | if not args.change_num_genes: 392 | model = HE2RNA(input_dim=2048, layers=[256,256], 393 | ks=[1,2,5,10,20,50,100], 394 | output_dim=train_dataset.num_genes, device=device) 395 | else: 396 | model = HE2RNA(input_dim=2048, layers=[256,256], 397 | ks=[1,2,5,10,20,50,100], 398 | output_dim=args.num_genes, device=device) 399 | 400 | if args.checkpoint: 401 | model.load_state_dict(torch.load(args.checkpoint).state_dict()) 402 | 403 | if args.change_num_genes: # change num genes: num genes was different for pretraining on gtex 404 | model.conv2 = nn.Conv1d(in_channels=model.conv1.out_channels, 405 | out_channels=train_dataset.num_genes, 406 | kernel_size=1, 407 | stride=1, 408 | bias=True).to(device) 409 | model.layers[-1] = model.conv2 410 | 411 | preds_random, labels_random, wsis_rand, projs_rand = he2rna_predict(model, test_loader) 412 | 413 | preds, labels, wsis, projs = fit(model=model, 414 | lr=args.lr, 415 | train_loader=train_loader, 416 | valid_loader=valid_loader, 417 | test_loader=test_loader, 418 | params={}, 419 | fold=i, 420 | optimizer=None, 421 | path=save_dir) 422 | 423 | test_results = { 424 | 'real': labels, 425 | 'preds': preds, 426 | 'random': preds_random, 427 | 'wsi_file_name': wsis, 428 | 'tcga_project': projs 429 | } 430 | 431 | test_results_splits[f'split_{i}'] = test_results 432 | i += 1 433 | 434 | test_results_splits['genes'] = [x[4:] for x in df.columns if 'rna_' in x] 435 | with open(os.path.join(save_dir, 'test_results.pkl'), 'wb') as f: 436 | pickle.dump(test_results_splits, f, protocol=pickle.HIGHEST_PROTOCOL) 437 | 438 | 439 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from tqdm import tqdm 4 | import pickle 5 | import h5py 6 | import wandb 7 | import random 8 | 9 | import torch.nn as nn 10 | from torch.utils.data import DataLoader 11 | from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, mean_squared_error 12 | 13 | from read_data import SuperTileRNADataset 14 | from utils import patient_kfold, filter_no_features, custom_collate_fn 15 | from vit import train, ViT, evaluate 16 | from tformer_lin import ViS 17 | 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser(description='Getting features') 21 | 22 | # general args 23 | parser.add_argument('--src_path', type=str, default='', help='project path') 24 | parser.add_argument('--ref_file', type=str, default=None, help='path to reference file') 25 | parser.add_argument('--sample-percent', type=float, default=None, help='Downsample available data to test the effect of having a smaller dataset. If None, no downsampling.') 26 | parser.add_argument('--tcga_projects', help="the tcga_projects we want to use, separated by comma", default=None, type=str) 27 | parser.add_argument('--feature_path', type=str, default="features/", help='path to resnet/uni and clustered features') 28 | parser.add_argument('--save_dir', type=str, default='saved_exp', help='parent destination folder') 29 | parser.add_argument('--cohort', type=str, default="TCGA", help='cohort name for creating the saving folder of the results') 30 | parser.add_argument('--exp_name', type=str, default="exp", help='Experiment name for creating the saving folder of the results') 31 | parser.add_argument('--filter_no_features', type=int, default=1, help='Whether to filter out samples with no features') 32 | parser.add_argument('--log', type=str, help='Experiment name to log') 33 | 34 | # model args 35 | parser.add_argument('--model_type', type=str, default='vit', help='"vit" for transformer or "vis" for linearized transformer') 36 | parser.add_argument('--depth', type=int, default=6, help='transformer depth') 37 | parser.add_argument('--num-heads', type=int, default=16, help='number of attention heads') 38 | parser.add_argument('--seed', type=int, default=99, help='Seed for random generation') 39 | parser.add_argument('--lr', type=float, default=1e-3, help='Learning rate') 40 | parser.add_argument('--batch_size', type=int, default=16, help='Batch size') 41 | parser.add_argument('--checkpoint', type=str, default=None, help='Checkpoint from trained model.') 42 | parser.add_argument('--train', help="if you want to train the model", action="store_true") 43 | parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train') 44 | parser.add_argument('--change_num_genes', type=int, default=0, help="whether finetuning from a model trained on different number of genes") 45 | parser.add_argument('--num_genes', type=int, default=None, help='number of genes on which pretrained model was trained') 46 | parser.add_argument('--k', type=int, default=5, help='Number of splits') 47 | parser.add_argument('--save_on', type=str, default='loss', help='which criterium to save model on, "loss" or "loss+corr"') 48 | parser.add_argument('--stop_on', type=str, default='loss', help='which criterium to do early stopping on, "loss" or "loss+corr"') 49 | 50 | args = parser.parse_args() 51 | 52 | ############################################## seeds ############################################## 53 | 54 | np.random.seed(args.seed) 55 | torch.manual_seed(args.seed) 56 | random.seed(args.seed) 57 | torch.backends.cudnn.benchmark = False # possibly reduced performance but better reproducibility 58 | torch.backends.cudnn.deterministic = True 59 | 60 | # reproducibility train dataloader 61 | def seed_worker(worker_id): 62 | worker_seed = torch.initial_seed() % 2**32 63 | np.random.seed(worker_seed) 64 | random.seed(worker_seed) 65 | g = torch.Generator() 66 | g.manual_seed(0) 67 | 68 | ############################################## logging ############################################## 69 | 70 | save_dir = os.path.join(args.src_path, args.save_dir, args.cohort, args.exp_name) 71 | if not os.path.exists(save_dir): 72 | os.makedirs(save_dir) 73 | 74 | run = None 75 | if args.log: 76 | run = wandb.init(project=args.log, config=args, name=args.exp_name) 77 | 78 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 79 | print(device) 80 | 81 | ############################################## data prep ############################################## 82 | 83 | df = pd.read_csv(args.ref_file) 84 | if args.sample_percent != None: 85 | df = df.sample(frac=args.sample_percent).reset_index(drop=True) 86 | 87 | if ('tcga_project' in df.columns) and (args.tcga_projects != None): 88 | projects = args.tcga_projects.split(',') 89 | df = df[df['tcga_project'].isin(projects)].reset_index(drop=True) 90 | print(f'Filtered project {projects}') 91 | 92 | if args.filter_no_features: 93 | df = filter_no_features(df, feature_path=args.feature_path, 'cluster_features') 94 | 95 | ############################################## kfold ############################################## 96 | train_idxs, val_idxs, test_idxs = patient_kfold(df, n_splits=args.k) 97 | 98 | test_results_splits = {} 99 | i = 0 100 | 101 | for train_idx, val_idx, test_idx in zip(train_idxs, val_idxs, test_idxs): 102 | train_df = df.iloc[train_idx] 103 | val_df = df.iloc[val_idx] 104 | test_df = df.iloc[test_idx] 105 | 106 | # save patient ids to file 107 | np.save(save_dir + '/train_'+str(i)+'.npy', np.unique(train_df.patient_id) ) 108 | np.save(save_dir + '/val_'+str(i)+'.npy', np.unique(val_df.patient_id) ) 109 | np.save(save_dir + '/test_'+str(i)+'.npy', np.unique(test_df.patient_id) ) 110 | 111 | # init dataset 112 | train_dataset = SuperTileRNADataset(train_df, args.feature_path) 113 | val_dataset = SuperTileRNADataset(val_df, args.feature_path) 114 | test_dataset = SuperTileRNADataset(test_df, args.feature_path) 115 | 116 | num_outputs = train_dataset.num_genes 117 | feature_dim = train_dataset.feature_dim 118 | 119 | # init dataloaders 120 | train_dataloader = DataLoader(train_dataset, 121 | num_workers=0, pin_memory=True, 122 | shuffle=shuffle, batch_size=args.batch_size, 123 | collate_fn=custom_collate_fn, 124 | worker_init_fn=seed_worker, 125 | generator=g) 126 | 127 | val_dataloader = DataLoader(val_dataset, 128 | num_workers=0, pin_memory=True, 129 | shuffle=True, batch_size=args.batch_size, 130 | collate_fn=custom_collate_fn) 131 | 132 | test_dataloader = DataLoader(test_dataset, 133 | num_workers=0, pin_memory=True, 134 | shuffle=False, batch_size=args.batch_size, 135 | collate_fn=custom_collate_fn) 136 | 137 | # get model 138 | if args.checkpoint and args.change_num_genes: # if finetuning from model trained on gtex 139 | model_path = os.path.join(args.checkpoint) 140 | if args.model_type == 'vit': 141 | model = ViT(num_outputs=args.change_num_genes, dim=feature_dim, 142 | depth=args.depth, heads=args.num_heads, 143 | mlp_dim=2048, dim_head=64, device=device) 144 | elif args.model_type == 'vis': 145 | model = ViS(num_outputs=args.change_num_genes, input_dim=feature_dim, 146 | depth=args.depth, nheads=args.num_heads, 147 | dimensions_f=64, dimensions_c=64, dimensions_s=64, device=device) 148 | else: 149 | print('please specify correct model type "vit" or "vis"') 150 | exit() 151 | 152 | model.load_state_dict(torch.load(model_path, map_location = device)) 153 | print(f'Loaded model from {model_path}') 154 | 155 | model.linear_head = nn.Sequential( 156 | nn.LayerNorm(feature_dim), 157 | nn.Linear(feature_dim, num_outputs)) 158 | 159 | else: # if training from scratch or continuing training same model (then load state dict in next if) 160 | if args.model_type == 'vit': 161 | model = ViT(num_outputs=num_outputs, dim=feature_dim, 162 | depth=args.depth, heads=args.num_heads, 163 | mlp_dim=2048, dim_head=64, device=device) 164 | elif args.model_type == 'vis': 165 | model = ViS(num_outputs=num_outputs, input_dim=feature_dim, 166 | depth=args.depth, nheads=args.num_heads, 167 | dimensions_f=64, dimensions_c=64, dimensions_s=64, device=device) 168 | else: 169 | print('please specify correct model type "vit" or "vis"') 170 | 171 | if args.checkpoint and not args.change_num_genes: 172 | suff = f'_{i}' if i > 0 else '' 173 | model_path = args.checkpoint + f'model_best{suff}.pt' 174 | print(f'Loading model from {model_path}') 175 | model.load_state_dict(torch.load(model_path, map_location=device)) 176 | 177 | model.to(device) 178 | 179 | # training 180 | optimizer = torch.optim.AdamW(list(model.parameters()), 181 | lr=args.lr, 182 | amsgrad=False, 183 | weight_decay=0.) 184 | dataloaders = { 'train': train_dataloader, 'val': val_dataloader} 185 | 186 | if args.train: 187 | model = train(model, dataloaders, optimizer, 188 | num_epochs=args.num_epochs, run=run,split=i, 189 | save_on=args.save_on, stop_on=args.stop_on, 190 | delta=0.5, save_dir=save_dir) 191 | 192 | preds, real, wsis, projs = evaluate(model, test_dataloader, run=run, suff='_'+str(i)) 193 | 194 | # get random model predictions 195 | if args.model_type == 'vit': 196 | random_model = ViT(num_outputs=num_outputs, dim=feature_dim, 197 | depth=args.depth, heads=args.num_heads, 198 | mlp_dim=2048, dim_head = 64, device = device) 199 | else: 200 | random_model = ViS(num_outputs=num_outputs, input_dim=feature_dim, 201 | depth=args.depth, nheads=args.num_heads, 202 | dimensions_f=64, dimensions_c=64, dimensions_s=64, device=device) 203 | random_model = random_model.to(device) 204 | random_preds, _, _, _ = evaluate(random_model, test_dataloader, run=run, suff='_'+str(i)+'_rand') 205 | 206 | test_results = { 207 | 'real': real, 208 | 'preds': preds, 209 | 'random': random_preds, 210 | 'wsi_file_name': wsis, 211 | 'tcga_project': projs 212 | } 213 | 214 | test_results_splits[f'split_{i}'] = test_results 215 | i += 1 216 | 217 | test_results_splits['genes'] = [x[4:] for x in df.columns if 'rna_' in x] 218 | with open(os.path.join(save_dir, 'test_results.pkl'), 'wb') as f: 219 | pickle.dump(test_results_splits, f, protocol=pickle.HIGHEST_PROTOCOL) -------------------------------------------------------------------------------- /src/pretrain_gtex.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | 5 | from torch.utils.data import DataLoader 6 | import wandb 7 | 8 | from src.read_data import * 9 | from src.vit import ViT, train 10 | from src.he2rna import HE2RNA, fit 11 | from src.tformer_lin import ViS 12 | 13 | def custom_collate_fn(batch): 14 | """Remove bad entries from the dataloader 15 | Args: 16 | batch (torch.Tensor): batch of tensors from the dataaset 17 | Returns: 18 | collate: Default collage for the dataloader 19 | """ 20 | 21 | try: 22 | batch = list(filter(lambda x: x[0] is not None, batch)) 23 | except: 24 | batch['image'] = [] 25 | return torch.utils.data.dataloader.default_collate(batch) 26 | 27 | def filter_no_features(df, feature_path = "examples/features"): 28 | no_features = [] 29 | for i, row in df.iterrows(): 30 | row = row.to_dict() 31 | wsi = row['wsi_file_name'] 32 | project = row['tcga_project'] 33 | path = os.path.join(feature_path, project, wsi, wsi+'.h5') 34 | if not os.path.exists(path): 35 | no_features.append(wsi) 36 | df = df[~df['wsi_file_name'].isin(no_features)] 37 | return df 38 | 39 | if __name__ == '__main__': 40 | parser = argparse.ArgumentParser(description='Getting features') 41 | parser.add_argument('--save_dir', type=str, default="/examples/pretrained_model", help='save directory') 42 | parser.add_argument('--path_csv', type=str, default="/examples/ref_file.csv", help='path to reference file with gene expression data') 43 | parser.add_argument('--feature_path', type=str, default="/examples/features", help='path to resnet and clustered features') 44 | parser.add_argument('--exp_name', type=str, default="exp", help='Experiment name used to create saved model name') 45 | parser.add_argument('--log', type=int, default=0, help='whether to log the loss') 46 | parser.add_argument('--model', type=str, default='vis', help='model to pretrain, "he2rna" for MLP aggregation, "vit" for transformer aggregation or "vis" for linearized transformer aggregation') 47 | parser.add_argument('--seed', type=int, default=99, help='Seed for random generation') 48 | parser.add_argument('--num_epochs', type=int, default=200, help='number of epochs to train') 49 | parser.add_argument('--batch_size', type=int, default=16, help='batch size to train') 50 | parser.add_argument('--n_workers', type=int, default=8, help='number of workers to train') 51 | parser.add_argument('--checkpoint', type=str, default=None, help='Checkpoint from trained model') 52 | parser.add_argument('--quick', type=int, default=0, help='Whether to run a quick exp for debugging') 53 | 54 | args = parser.parse_args() 55 | 56 | np.random.seed(args.seed) 57 | torch.manual_seed(args.seed) 58 | 59 | ############################################## logging and save dir ############################################## 60 | if args.exp_name == "": 61 | args.exp_name = '{date:%Y-%m-%d}'.format(date=datetime.datetime.now()) 62 | else: 63 | args.exp_name = '{date:%Y-%m-%d}'.format(date=datetime.datetime.now()) + "_" + args.exp_name 64 | 65 | save_dir = os.path.join(args.save_dir, args.exp_name) 66 | if not os.path.exists(save_dir): 67 | os.makedirs(save_dir) 68 | 69 | run = None 70 | if args.log: 71 | run = wandb.init(project="sequoia", entity='account_name', config=args, name=args.exp_name) 72 | 73 | ############################################## prepare data ############################################## 74 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 75 | print(device) 76 | 77 | df = pd.read_csv(args.path_csv) 78 | df = filter_no_features(df) 79 | 80 | if args.quick: 81 | df = df.iloc[0:20, :] 82 | args.num_epochs = 5 83 | 84 | dataset = SuperTileRNADataset(df, args.feature_path) 85 | 86 | dataloader = DataLoader(dataset, 87 | num_workers=args.n_workers, pin_memory=True, 88 | shuffle=True, batch_size=args.batch_size, 89 | collate_fn=custom_collate_fn) 90 | 91 | ############################################## model ############################################## 92 | 93 | if args.model == 'vis': 94 | model = ViS(num_outputs=dataset.num_genes, input_dim=dataset.feature_dim, 95 | depth=6, nheads=16, 96 | dimensions_f=64, dimensions_c=64, dimensions_s=64, device=device) 97 | elif args.model == 'vit': 98 | model = ViT(num_outputs=dataset.num_genes, 99 | dim=dataset.feature_dim, depth=6, heads=16, mlp_dim=2048, dim_head = 64, 100 | device=device) 101 | 102 | elif args.model == 'he2rna': 103 | model = HE2RNA(input_dim=dataset.feature_dim, layers=[256,256], 104 | ks=[1,2,5,10,20,50,100], 105 | output_dim=dataset.num_genes, device=device) 106 | else: 107 | print('please specify correct model name, "vit" or "he2rna"') 108 | exit() 109 | 110 | if args.checkpoint != None: 111 | model.load_state_dict(torch.load(args.checkpoint)) 112 | model = model.to(device) 113 | 114 | ############################################## training ############################################## 115 | optimizer = torch.optim.AdamW(list(model.parameters()), lr=3e-3,weight_decay=0.) 116 | dataloaders = {'train': dataloader,} 117 | 118 | if (args.model == 'vis') or (args.model == 'vit'): 119 | model = train(model, dataloaders, optimizer, num_epochs=args.num_epochs, phases=['train'], save_dir=save_dir, run=run) 120 | else: 121 | model = fit(model=model, lr=3e-3, train_loader=dataloaders['train'], valid_loader=None, test_loader=None, 122 | params={}, fold=None, optimizer=None, path=save_dir) 123 | 124 | print('Finished pre-training') 125 | -------------------------------------------------------------------------------- /src/read_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | import pandas as pd 7 | import torch 8 | from tqdm import tqdm 9 | import h5py 10 | 11 | 12 | class SuperTileRNADataset(Dataset): 13 | def __init__(self, csv_path: str, features_path, quick=None): 14 | self.csv_path = csv_path 15 | self.quick = quick 16 | self.features_path = features_path 17 | if type(csv_path) == str: 18 | self.data = pd.read_csv(csv_path) 19 | else: 20 | self.data = csv_path 21 | 22 | # find the number of genes 23 | row = self.data.iloc[0] 24 | rna_data = row[[x for x in row.keys() if 'rna_' in x]].values.astype(np.float32) 25 | self.num_genes = len(rna_data) 26 | 27 | # find the feature dimension, assume all images in the reference file have the same dimension 28 | path = os.path.join(self.features_path, row['tcga_project'], 29 | row['wsi_file_name'], row['wsi_file_name']+'.h5') 30 | f = h5py.File(path, 'r') 31 | features = f[self.feature_use][:] 32 | self.feature_dim = features.shape[1] 33 | f.close() 34 | 35 | def __len__(self): 36 | return self.data.shape[0] 37 | 38 | def __getitem__(self, idx): 39 | row = self.data.iloc[idx] 40 | path = os.path.join(self.features_path, row['tcga_project'], 41 | row['wsi_file_name'], row['wsi_file_name']+'.h5') 42 | rna_data = row[[x for x in row.keys() if 'rna_' in x]].values.astype(np.float32) 43 | rna_data = torch.tensor(rna_data, dtype=torch.float32) 44 | try: 45 | if 'GTEX' not in path: 46 | path = path.replace('.svs','') 47 | f = h5py.File(path, 'r') 48 | features = f['cluster_features'][:] 49 | f.close() 50 | features = torch.tensor(features, dtype=torch.float32) 51 | except Exception as e: 52 | print(e) 53 | print(path) 54 | features = None 55 | 56 | return features, rna_data, row['wsi_file_name'], row['tcga_project'] -------------------------------------------------------------------------------- /src/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch.nn.functional as F 5 | import torch 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152'] 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, num_classes=1000): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 110 | self.avgpool = nn.AvgPool2d(7) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. / n)) 117 | elif isinstance(m, nn.BatchNorm2d): 118 | m.weight.data.fill_(1) 119 | m.bias.data.zero_() 120 | 121 | def _make_layer(self, block, planes, blocks, stride=1): 122 | downsample = None 123 | if stride != 1 or self.inplanes != planes * block.expansion: 124 | downsample = nn.Sequential( 125 | nn.Conv2d(self.inplanes, planes * block.expansion, 126 | kernel_size=1, stride=stride, bias=False), 127 | nn.BatchNorm2d(planes * block.expansion), 128 | ) 129 | 130 | layers = [] 131 | layers.append(block(self.inplanes, planes, stride, downsample)) 132 | self.inplanes = planes * block.expansion 133 | for i in range(1, blocks): 134 | layers.append(block(self.inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | x = self.conv1(x) 140 | x = self.bn1(x) 141 | x = self.relu(x) 142 | x = self.maxpool(x) 143 | 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | 149 | x = self.avgpool(x) 150 | x = x.view(x.size(0), -1) 151 | x = self.fc(x) 152 | 153 | return x 154 | 155 | def forward_extract(self, x): 156 | x = self.conv1(x) 157 | x = self.bn1(x) 158 | x = self.relu(x) 159 | x = self.maxpool(x) 160 | 161 | x = self.layer1(x) 162 | x = self.layer2(x) 163 | x = self.layer3(x) 164 | x = self.layer4(x) 165 | 166 | x = self.avgpool(x) 167 | 168 | x = x.view(x.size(0), -1) 169 | 170 | return x 171 | 172 | class RNfour(nn.Module): 173 | 174 | def __init__(self, block, layers, num_classes=1000): 175 | self.inplanes = 64 176 | super(RNfour, self).__init__() 177 | self.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, 178 | bias=False) 179 | self.bn1 = nn.BatchNorm2d(64) 180 | self.relu = nn.ReLU(inplace=True) 181 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 182 | self.layer1 = self._make_layer(block, 64, layers[0]) 183 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 184 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 185 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 186 | self.avgpool = nn.AvgPool2d(7, stride=1) 187 | self.fc = nn.Linear(512 * block.expansion, num_classes) 188 | 189 | for m in self.modules(): 190 | if isinstance(m, nn.Conv2d): 191 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 192 | m.weight.data.normal_(0, math.sqrt(2. / n)) 193 | elif isinstance(m, nn.BatchNorm2d): 194 | m.weight.data.fill_(1) 195 | m.bias.data.zero_() 196 | 197 | def _make_layer(self, block, planes, blocks, stride=1): 198 | downsample = None 199 | if stride != 1 or self.inplanes != planes * block.expansion: 200 | downsample = nn.Sequential( 201 | nn.Conv2d(self.inplanes, planes * block.expansion, 202 | kernel_size=1, stride=stride, bias=False), 203 | nn.BatchNorm2d(planes * block.expansion), 204 | ) 205 | 206 | layers = [] 207 | layers.append(block(self.inplanes, planes, stride, downsample)) 208 | self.inplanes = planes * block.expansion 209 | for i in range(1, blocks): 210 | layers.append(block(self.inplanes, planes)) 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def forward(self, x): 215 | x = self.conv1(x) 216 | x = self.bn1(x) 217 | x = self.relu(x) 218 | x = self.maxpool(x) 219 | 220 | x = self.layer1(x) 221 | x = self.layer2(x) 222 | x = self.layer3(x) 223 | x = self.layer4(x) 224 | 225 | x = self.avgpool(x) 226 | x = x.view(x.size(0), -1) 227 | x = self.fc(x) 228 | 229 | return x 230 | 231 | def forward_extract(self, x): 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.avgpool(x) 243 | x = x.view(x.size(0), -1) 244 | 245 | return x 246 | 247 | class RNone(nn.Module): 248 | 249 | def __init__(self, block, layers, num_classes=1000): 250 | self.inplanes = 64 251 | super(RNone, self).__init__() 252 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 253 | bias=False) 254 | self.bn1 = nn.BatchNorm2d(64) 255 | self.relu = nn.ReLU(inplace=True) 256 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 257 | self.layer1 = self._make_layer(block, 64, layers[0]) 258 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 259 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 260 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 261 | self.avgpool = nn.AvgPool2d(7, stride=1) 262 | self.fc = nn.Linear(512 * block.expansion, num_classes) 263 | 264 | for m in self.modules(): 265 | if isinstance(m, nn.Conv2d): 266 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 267 | m.weight.data.normal_(0, math.sqrt(2. / n)) 268 | elif isinstance(m, nn.BatchNorm2d): 269 | m.weight.data.fill_(1) 270 | m.bias.data.zero_() 271 | 272 | def _make_layer(self, block, planes, blocks, stride=1): 273 | downsample = None 274 | if stride != 1 or self.inplanes != planes * block.expansion: 275 | downsample = nn.Sequential( 276 | nn.Conv2d(self.inplanes, planes * block.expansion, 277 | kernel_size=1, stride=stride, bias=False), 278 | nn.BatchNorm2d(planes * block.expansion), 279 | ) 280 | 281 | layers = [] 282 | layers.append(block(self.inplanes, planes, stride, downsample)) 283 | self.inplanes = planes * block.expansion 284 | for i in range(1, blocks): 285 | layers.append(block(self.inplanes, planes)) 286 | 287 | return nn.Sequential(*layers) 288 | 289 | def forward(self, x): 290 | x = self.conv1(x) 291 | x = self.bn1(x) 292 | x = self.relu(x) 293 | x = self.maxpool(x) 294 | 295 | x = self.layer1(x) 296 | x = self.layer2(x) 297 | x = self.layer3(x) 298 | x = self.layer4(x) 299 | 300 | x = self.avgpool(x) 301 | x = x.view(x.size(0), -1) 302 | x = self.fc(x) 303 | 304 | return x 305 | 306 | def forward_extract(self, x): 307 | x = self.conv1(x) 308 | x = self.bn1(x) 309 | x = self.relu(x) 310 | x = self.maxpool(x) 311 | 312 | x = self.layer1(x) 313 | x = self.layer2(x) 314 | x = self.layer3(x) 315 | x = self.layer4(x) 316 | 317 | x = self.avgpool(x) 318 | x = x.view(x.size(0), -1) 319 | 320 | return x 321 | 322 | 323 | class ResNetProject(nn.Module): 324 | 325 | def __init__(self, resnet, hdim=200, input_dim=2048, dropout=.3): 326 | super(ResNetProject, self).__init__() 327 | self.resnet = resnet 328 | self.hdim = hdim 329 | self.dropout = nn.Dropout(p=dropout) 330 | self.project = nn.Linear(input_dim, hdim) 331 | self.fc = nn.Linear(hdim, 1) 332 | 333 | def forward_extract(self, x): 334 | x = self.resnet.forward_extract(x) 335 | x = self.project(x) 336 | x = F.tanh(x) 337 | x = self.dropout(x) 338 | return x 339 | 340 | def forward(self, x): 341 | x = self.forward_extract(x) 342 | x = self.fc(x) 343 | return x 344 | 345 | 346 | def resnet18(pretrained=False, **kwargs): 347 | """Constructs a ResNet-18 model. 348 | 349 | Args: 350 | pretrained (bool): If True, returns a model pre-trained on ImageNet 351 | """ 352 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 353 | if pretrained: 354 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 355 | return model 356 | 357 | 358 | def resnet34(pretrained=False, **kwargs): 359 | """Constructs a ResNet-34 model. 360 | 361 | Args: 362 | pretrained (bool): If True, returns a model pre-trained on ImageNet 363 | """ 364 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 365 | if pretrained: 366 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 367 | return model 368 | 369 | 370 | def resnet50(pretrained=False, **kwargs): 371 | """Constructs a ResNet-50 model. 372 | 373 | Args: 374 | pretrained (bool): If True, returns a model pre-trained on ImageNet 375 | """ 376 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 377 | if pretrained: 378 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 379 | return model 380 | 381 | def resnet50_4channel(pretrained=False, **kwargs): 382 | """Constructs a ResNet-50 model. 383 | 384 | Args: 385 | pretrained (bool): If True, returns a model pre-trained on ImageNet 386 | """ 387 | new_model = RNfour(Bottleneck, [3, 4, 6, 3], **kwargs) 388 | 389 | if pretrained: 390 | 391 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 392 | new_model_dict = new_model.state_dict() 393 | 394 | # 1. filter out unnecessary keys 395 | filtered_pretrained_dict = {k: v for k, v in pretrained_dict.items() if k!='conv1.weight'} 396 | # 2. overwrite entries in the existing state dict 397 | new_model_dict.update(filtered_pretrained_dict) 398 | # 3. load the new state dict 399 | new_model.load_state_dict(new_model_dict) 400 | 401 | new_model.conv1.weight.data.normal_(0, 0.001) 402 | new_model.conv1.weight.data[:, :3, :, :] = pretrained_dict['conv1.weight'] 403 | 404 | 405 | return new_model 406 | 407 | def resnet50_1channel(pretrained=False, **kwargs): 408 | """Constructs a ResNet-50 model. 409 | 410 | Args: 411 | pretrained (bool): If True, returns a model pre-trained on ImageNet 412 | """ 413 | new_model = RNone(Bottleneck, [3, 4, 6, 3], **kwargs) 414 | 415 | if pretrained: 416 | 417 | pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 418 | new_model_dict = new_model.state_dict() 419 | 420 | # 1. filter out unnecessary keys 421 | filtered_pretrained_dict = {k: v for k, v in pretrained_dict.items() if k!='conv1.weight'} 422 | # 2. overwrite entries in the existing state dict 423 | new_model_dict.update(filtered_pretrained_dict) 424 | # 3. load the new state dict 425 | new_model.load_state_dict(new_model_dict) 426 | 427 | con1w=pretrained_dict['conv1.weight'] 428 | con1w_mean=torch.mean(con1w, dim=1, keepdim=True) 429 | 430 | new_model.conv1.weight.data=con1w_mean 431 | 432 | 433 | 434 | 435 | return new_model 436 | 437 | def resnet101(pretrained=False, **kwargs): 438 | """Constructs a ResNet-101 model. 439 | 440 | Args: 441 | pretrained (bool): If True, returns a model pre-trained on ImageNet 442 | """ 443 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 444 | if pretrained: 445 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 446 | return model 447 | 448 | 449 | def resnet152(pretrained=False, **kwargs): 450 | """Constructs a ResNet-152 model. 451 | 452 | Args: 453 | pretrained (bool): If True, returns a model pre-trained on ImageNet 454 | """ 455 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 456 | if pretrained: 457 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 458 | return model 459 | -------------------------------------------------------------------------------- /src/tformer_lin.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from einops import rearrange 4 | from huggingface_hub import PyTorchModelHubMixin 5 | 6 | 7 | class SummaryMixing(nn.Module): 8 | def __init__(self, input_dim, dimensions_f, dimensions_s, dimensions_c): 9 | super().__init__() 10 | 11 | self.local_norm = nn.LayerNorm(dimensions_f) 12 | self.summary_norm = nn.LayerNorm(dimensions_s) 13 | 14 | self.s = nn.Linear(input_dim, dimensions_s) 15 | self.f = nn.Linear(input_dim, dimensions_f) 16 | self.c = nn.Linear(dimensions_s + dimensions_f, dimensions_c) 17 | 18 | def forward(self, x): 19 | 20 | local_summ = torch.nn.GELU()(self.local_norm(self.f(x))) 21 | time_summ = self.s(x) 22 | time_summ = torch.nn.GELU()(self.summary_norm(torch.mean(time_summ, dim=1))) 23 | time_summ = time_summ.unsqueeze(1).repeat(1, x.shape[1], 1) 24 | out = torch.nn.GELU()(self.c(torch.cat([local_summ, time_summ], dim=-1))) 25 | 26 | return out 27 | 28 | 29 | class MultiHeadSummary(nn.Module): 30 | def __init__(self, nheads, input_dim, dimensions_f, dimensions_s, dimensions_c, dimensions_projection): 31 | super().__init__() 32 | 33 | self.mixers = nn.ModuleList([]) 34 | for _ in range(nheads): 35 | self.mixers.append(SummaryMixing(input_dim=input_dim, dimensions_f=dimensions_f, dimensions_s=dimensions_s, dimensions_c=dimensions_c)) 36 | 37 | self.projection = nn.Linear(nheads*dimensions_c, dimensions_projection) 38 | 39 | def forward(self, x): 40 | 41 | outs = [] 42 | for mixer in self.mixers: 43 | outs.append(mixer(x)) 44 | 45 | outs = torch.cat(outs, dim=-1) 46 | out = self.projection(outs) 47 | 48 | return out 49 | 50 | 51 | class FeedForward(nn.Module): 52 | def __init__(self, dim, hidden_dim): 53 | super().__init__() 54 | self.net = nn.Sequential( 55 | nn.LayerNorm(dim), 56 | nn.Linear(dim, hidden_dim), 57 | nn.GELU(), 58 | nn.Linear(hidden_dim, dim), 59 | ) 60 | def forward(self, x): 61 | return self.net(x) 62 | 63 | 64 | class SummaryTransformer(nn.Module): 65 | def __init__(self, input_dim, depth, nheads, dimensions_f, dimensions_s, dimensions_c): 66 | super().__init__() 67 | self.layers = nn.ModuleList([]) 68 | for _ in range(depth): 69 | self.layers.append(nn.ModuleList([ 70 | MultiHeadSummary(nheads, input_dim, dimensions_f, dimensions_s, dimensions_c, dimensions_projection=input_dim), 71 | FeedForward(input_dim, input_dim) 72 | ])) 73 | def forward(self, x): 74 | for attn, ff in self.layers: 75 | x = attn(x) + x # dimensions_projection needs to be equal to input_dim 76 | x = ff(x) + x # output_dim of feedforward needs to be equal to input_dim 77 | return x 78 | 79 | 80 | class ViS(nn.Module, PyTorchModelHubMixin): 81 | def __init__(self, num_outputs, input_dim, depth, nheads, 82 | dimensions_f, dimensions_s, dimensions_c, 83 | num_clusters=100, device='cuda:0'): 84 | super().__init__() 85 | 86 | self.pos_emb1D = nn.Parameter(torch.randn(num_clusters, input_dim)) 87 | 88 | self.transformer = SummaryTransformer(input_dim, depth, nheads, dimensions_f, dimensions_s, dimensions_c) 89 | 90 | self.to_latent = nn.Identity() 91 | self.linear_head = nn.Sequential( 92 | nn.LayerNorm(input_dim), 93 | nn.Linear(input_dim, num_outputs) 94 | ) 95 | self.device = device 96 | 97 | def forward(self, x): 98 | 99 | #pe = posemb_sincos_2d(x) 100 | x = rearrange(x, 'b ... d -> b (...) d') + self.pos_emb1D 101 | 102 | x = self.transformer(x) 103 | x = x.mean(dim = 1) 104 | 105 | x = self.to_latent(x) 106 | return self.linear_head(x) 107 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import h5py 5 | import pdb 6 | from sklearn.model_selection import train_test_split, KFold 7 | from sklearn.model_selection import StratifiedGroupKFold 8 | import torch 9 | 10 | def custom_collate_fn(batch): 11 | """Remove bad entries from the dataloader 12 | Args: 13 | batch (torch.Tensor): batch of tensors from the dataaset 14 | Returns: 15 | collate: Default collage for the dataloader 16 | """ 17 | batch = list(filter(lambda x: x[0] is not None, batch)) 18 | return torch.utils.data.dataloader.default_collate(batch) 19 | 20 | 21 | def filter_no_features(df, feature_path, feature_name): 22 | print(f'Filtering WSIs that do not have {feature_name} features') 23 | projects = np.unique(df.tcga_project) 24 | all_wsis_with_features = [] 25 | remove = [] 26 | for proj in projects: 27 | wsis_with_features = os.listdir(os.path.join(feature_path, proj)) 28 | for wsi in wsis_with_features: 29 | try: 30 | with h5py.File(os.path.join(feature_path, proj, wsi, wsi+'.h5'), "r") as f: 31 | cols = list(f.keys()) 32 | if feature_name not in cols: 33 | remove.append(wsi) 34 | except Exception as e: 35 | remove.append(wsi) 36 | all_wsis_with_features += wsis_with_features 37 | remove += df[~df['wsi_file_name'].isin(all_wsis_with_features)].wsi_file_name.values.tolist() 38 | print(f'Original shape: {df.shape}') 39 | df = df[~df['wsi_file_name'].isin(remove)].reset_index(drop=True) 40 | print(f'New shape: {df.shape}') 41 | return df 42 | 43 | 44 | def patient_split(dataset, random_state=0): 45 | """Perform patient split of any of the previously defined datasets. 46 | """ 47 | patients_unique = np.unique(dataset.patient_id) 48 | patients_train, patients_test = train_test_split( 49 | patients_unique, test_size=0.2, random_state=random_state) 50 | patients_train, patients_val = train_test_split( 51 | patients_train, test_size=0.2, random_state=random_state) 52 | 53 | indices = np.arange(len(dataset)) 54 | train_idx = indices[np.any(np.array(dataset.patient_id)[:, np.newaxis] == 55 | np.array(patients_train)[np.newaxis], axis=1)] 56 | valid_idx = indices[np.any(np.array(dataset.patient_id)[:, np.newaxis] == 57 | np.array(patients_val)[np.newaxis], axis=1)] 58 | test_idx = indices[np.any(np.array(dataset.patient_id)[:, np.newaxis] == 59 | np.array(patients_test)[np.newaxis], axis=1)] 60 | 61 | return train_idx, valid_idx, test_idx 62 | 63 | 64 | def match_patient_split(dataset, split): 65 | """Recover previously saved patient split 66 | """ 67 | train_patients, valid_patients, test_patients = split 68 | indices = np.arange(len(dataset)) 69 | train_idx = indices[np.any(dataset.patients[:, np.newaxis] == 70 | train_patients[np.newaxis], axis=1)] 71 | valid_idx = indices[np.any(dataset.patients[:, np.newaxis] == 72 | valid_patients[np.newaxis], axis=1)] 73 | test_idx = indices[np.any(dataset.patients[:, np.newaxis] == 74 | test_patients[np.newaxis], axis=1)] 75 | 76 | return train_idx, valid_idx, test_idx 77 | 78 | 79 | def patient_kfold(dataset, n_splits=5, random_state=0, valid_size=0.1): 80 | """Perform cross-validation with patient split. 81 | """ 82 | indices = np.arange(len(dataset)) 83 | 84 | patients_unique = np.unique(dataset.patient_id) 85 | 86 | skf = KFold(n_splits, shuffle=True, random_state=random_state) 87 | ind = skf.split(patients_unique) 88 | 89 | train_idx = [] 90 | valid_idx = [] 91 | test_idx = [] 92 | 93 | for k, (ind_train, ind_test) in enumerate(ind): 94 | 95 | patients_train = patients_unique[ind_train] 96 | patients_test = patients_unique[ind_test] 97 | 98 | test_idx.append(indices[np.any(np.array(dataset.patient_id)[:, np.newaxis] == 99 | np.array(patients_test)[np.newaxis], axis=1)]) 100 | 101 | if valid_size > 0: 102 | patients_train, patients_valid = train_test_split( 103 | patients_train, test_size=valid_size, random_state=0) 104 | valid_idx.append(indices[np.any(np.array(dataset.patient_id)[:, np.newaxis] == 105 | np.array(patients_valid)[np.newaxis], axis=1)]) 106 | 107 | train_idx.append(indices[np.any(np.array(dataset.patient_id)[:, np.newaxis] == 108 | np.array(patients_train)[np.newaxis], axis=1)]) 109 | 110 | return train_idx, valid_idx, test_idx 111 | 112 | 113 | def match_patient_kfold(dataset, splits): 114 | """Recover previously saved patient splits for cross-validation. 115 | """ 116 | 117 | indices = np.arange(len(dataset)) 118 | train_idx = [] 119 | valid_idx = [] 120 | test_idx = [] 121 | 122 | for train_patients, valid_patients, test_patients in splits: 123 | 124 | train_idx.append(indices[np.any(dataset.patients[:, np.newaxis] == 125 | train_patients[np.newaxis], axis=1)]) 126 | valid_idx.append(indices[np.any(dataset.patients[:, np.newaxis] == 127 | valid_patients[np.newaxis], axis=1)]) 128 | test_idx.append(indices[np.any(dataset.patients[:, np.newaxis] == 129 | test_patients[np.newaxis], axis=1)]) 130 | 131 | return train_idx, valid_idx, test_idx 132 | 133 | def exists(x): 134 | return x != None 135 | -------------------------------------------------------------------------------- /src/vit.py: -------------------------------------------------------------------------------- 1 | # Code from the awesome lucidrains: https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/simple_vit.py 2 | # with some of my own modifications 3 | 4 | import torch 5 | from torch import nn 6 | import os 7 | from tqdm import tqdm 8 | import numpy as np 9 | from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, mean_squared_error 10 | 11 | from einops import rearrange 12 | from src.he2rna import compute_correlations 13 | import pdb 14 | 15 | 16 | def pair(t): 17 | return t if isinstance(t, tuple) else (t, t) 18 | 19 | def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): 20 | _, h, w, dim, device, dtype = *patches.shape, patches.device, patches.dtype 21 | 22 | y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij') 23 | assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb' 24 | omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1) 25 | omega = 1. / (temperature ** omega) 26 | 27 | y = y.flatten()[:, None] * omega[None, :] 28 | x = x.flatten()[:, None] * omega[None, :] 29 | pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1) 30 | return pe.type(dtype) 31 | 32 | def smape(A, F): 33 | return 100/len(A) * np.sum(2 * np.abs(F - A) / (np.abs(A) + np.abs(F))) 34 | 35 | # classes 36 | 37 | class FeedForward(nn.Module): 38 | def __init__(self, dim, hidden_dim): 39 | super().__init__() 40 | self.net = nn.Sequential( 41 | nn.LayerNorm(dim), 42 | nn.Linear(dim, hidden_dim), 43 | nn.GELU(), 44 | nn.Linear(hidden_dim, dim), 45 | ) 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | class Attention(nn.Module): 50 | def __init__(self, dim, heads = 8, dim_head = 64): 51 | super().__init__() 52 | inner_dim = dim_head * heads 53 | self.heads = heads 54 | self.scale = dim_head ** -0.5 55 | self.norm = nn.LayerNorm(dim) 56 | 57 | self.attend = nn.Softmax(dim = -1) 58 | 59 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 60 | self.to_out = nn.Linear(inner_dim, dim, bias = False) 61 | 62 | def forward(self, x): 63 | x = self.norm(x) 64 | 65 | qkv = self.to_qkv(x).chunk(3, dim = -1) 66 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 67 | 68 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 69 | 70 | attn = self.attend(dots) 71 | 72 | out = torch.matmul(attn, v) 73 | out = rearrange(out, 'b h n d -> b n (h d)') 74 | return self.to_out(out) 75 | 76 | class Transformer(nn.Module): 77 | def __init__(self, dim, depth, heads, dim_head, mlp_dim): 78 | super().__init__() 79 | self.layers = nn.ModuleList([]) 80 | for _ in range(depth): 81 | self.layers.append(nn.ModuleList([ 82 | Attention(dim, heads = heads, dim_head = dim_head), 83 | FeedForward(dim, mlp_dim) 84 | ])) 85 | def forward(self, x): 86 | for attn, ff in self.layers: 87 | x = attn(x) + x 88 | x = ff(x) + x 89 | return x 90 | 91 | class ViT(nn.Module): 92 | def __init__(self, *, num_outputs, dim, depth, heads, mlp_dim, dim_head = 64, 93 | num_clusters=100, device='cuda'): 94 | super().__init__() 95 | 96 | self.pos_emb1D = nn.Parameter(torch.randn(num_clusters, dim)) 97 | 98 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim) 99 | 100 | self.to_latent = nn.Identity() 101 | self.linear_head = nn.Sequential( 102 | nn.LayerNorm(dim), 103 | nn.Linear(dim, num_outputs) 104 | ) 105 | self.device = device 106 | 107 | def forward(self, x): 108 | #pe = posemb_sincos_2d(x) 109 | x = rearrange(x, 'b ... d -> b (...) d') + self.pos_emb1D 110 | 111 | x = self.transformer(x) 112 | x = x.mean(dim = 1) 113 | 114 | x = self.to_latent(x) 115 | return self.linear_head(x) 116 | 117 | def train(model, dataloaders, optimizer, accelerator=None, 118 | num_epochs=200, save_dir='exp/', patience=20, 119 | run=None, verbose=True, phases=['train', 'val'], split=None, 120 | save_on='loss', stop_on='loss', delta=0.5): 121 | 122 | if save_dir is not None and not os.path.exists(save_dir): 123 | os.mkdir(save_dir) 124 | if split: 125 | save_path = os.path.join(save_dir, f'model_best_{split}.pt') 126 | else: 127 | save_path = os.path.join(save_dir, 'model_best.pt') 128 | 129 | loss_fn = nn.MSELoss() 130 | epoch_since_best = 0 131 | best_loss = np.inf 132 | 133 | # these are for early stopping on loss + score 134 | early_stop_on_loss_triggered = 0 135 | epoch_since_best_score = 0 136 | best_score = 0 137 | epoch_since_ok_loss = 0 138 | 139 | for epoch in tqdm(range(num_epochs)): 140 | for phase in phases: 141 | if phase == 'train': 142 | model.train() 143 | else: 144 | model.eval() 145 | losses = { 146 | 'train': [], 147 | 'val': [] 148 | } 149 | maes = { 150 | 'train': [], 151 | 'val': [] 152 | } 153 | scores = { 154 | 'train': [], 155 | 'val': [] 156 | } 157 | 158 | for s, (image, rna_data, _, _) in enumerate(dataloaders[phase]): 159 | if image == []: continue 160 | image = image.to(model.device) 161 | rna_data = rna_data.to(model.device) 162 | 163 | with torch.set_grad_enabled(phase == 'train'): 164 | pred = model(image) 165 | 166 | loss = loss_fn(pred, rna_data) 167 | mae = mean_absolute_error(rna_data.detach().cpu().numpy(), pred.detach().cpu().numpy()) 168 | score = compute_correlations(rna_data.detach().cpu().numpy(), pred.detach().cpu().numpy()) 169 | 170 | losses[phase] += [loss.detach().cpu().numpy()] 171 | maes[phase] += [mae] 172 | scores[phase] += [score] 173 | 174 | if phase == 'train': 175 | optimizer.zero_grad() 176 | if accelerator: 177 | accelerator.backward(loss) 178 | else: 179 | loss.backward() 180 | optimizer.step() 181 | 182 | losses[phase] = np.mean(losses[phase]) 183 | maes[phase] = np.mean(maes[phase]) 184 | scores[phase] = np.mean(scores[phase]) 185 | 186 | if phase == 'val': 187 | suffix = 'id' 188 | else: 189 | suffix = '' 190 | 191 | if run: 192 | run.log({'epoch': epoch, f'score {phase}{suffix} {split}': scores[phase]}) 193 | run.log({'epoch': epoch, f'{phase}{suffix} loss fold {split}': losses[phase]}) 194 | run.log({'epoch': epoch, f'{phase}{suffix} mae fold {split}': maes[phase]}) 195 | 196 | if verbose: 197 | print(f'Epoch {epoch}: {phase} loss {losses[phase]} mae {maes[phase]}') 198 | 199 | if (phase == 'val') or (len(phases) == 1): 200 | 201 | # only relevant for early stopping on loss+corr 202 | if early_stop_on_loss_triggered == 1: 203 | if losses[phase] < (best_loss + delta): # we allow loss to deviate a little bit from optimal while continuing training for good correlation 204 | epoch_since_ok_loss = 0 205 | else: 206 | epoch_since_ok_loss += 1 207 | 208 | # relevant for both early stopping and model save on loss/loss+corr 209 | if losses[phase] < best_loss: 210 | best_loss = losses[phase] 211 | epoch_since_best = 0 212 | if save_on == 'loss': 213 | torch.save(model.state_dict(), save_path) 214 | elif (save_on == 'loss+corr') and (early_stop_on_loss_triggered == 0): # first save model based on loss, later overwrite if there is a model at epoch with loss close to best loss and better correlation 215 | torch.save(model.state_dict(), save_path) 216 | else: 217 | epoch_since_best += 1 218 | 219 | # for early stopping and model save based on loss+corr 220 | if scores[phase] > best_score: 221 | best_score = scores[phase] 222 | epoch_since_best_score = 0 223 | if (save_on == 'loss+corr') and (early_stop_on_loss_triggered == 1): 224 | torch.save(model.state_dict(), save_path) 225 | print(f'Saved model on loss+corr at epoch {epoch} of better score and loss within {delta} of optimal loss') 226 | else: 227 | epoch_since_best_score += 1 228 | 229 | if epoch_since_best == patience: 230 | early_stop_on_loss_triggered = 1 231 | if stop_on == 'loss': 232 | print(f'Early stopping at epoch {epoch}!') 233 | break 234 | 235 | if stop_on == 'loss+corr': 236 | if (early_stop_on_loss_triggered == 1) and (epoch_since_best_score == patience): 237 | print(f'Early stopping at epoch {epoch} because neither loss nor score is improving anymore!') 238 | break 239 | 240 | if (early_stop_on_loss_triggered == 1) and (epoch_since_ok_loss == patience): 241 | print(f'Early stopping at epoch {epoch} because loss is not within {delta} of best loss anymore!') 242 | break 243 | return model 244 | 245 | def evaluate(model, dataloader, run=None, verbose=True, suff=''): 246 | model.eval() 247 | loss_fn = nn.MSELoss() 248 | losses = [] 249 | preds = [] 250 | real = [] 251 | wsis = [] 252 | projs = [] 253 | maes = [] 254 | smapes = [] 255 | for image, rna_data, wsi_file_name, tcga_project in tqdm(dataloader): 256 | 257 | if image == []: continue 258 | 259 | image = image.to(model.device) 260 | rna_data = rna_data.to(model.device) 261 | wsis.append(wsi_file_name) 262 | projs.append(tcga_project) 263 | 264 | pred = model(image) 265 | preds.append(pred.detach().cpu().numpy()) 266 | loss = loss_fn(pred, rna_data) 267 | real.append(rna_data.detach().cpu().numpy()) 268 | mae = mean_absolute_error(rna_data.detach().cpu().numpy(), pred.detach().cpu().numpy()) 269 | smape_var = smape(rna_data.detach().cpu().numpy(), pred.detach().cpu().numpy()) 270 | losses += [loss.detach().cpu().numpy()] 271 | maes += [mae] 272 | smapes += [smape_var] 273 | 274 | losses = np.mean(losses) 275 | maes = np.mean(maes) 276 | smapes = np.mean(smapes) 277 | if run: 278 | run.log({f'test_loss'+suff: losses}) 279 | run.log({f'test_MAE'+suff: maes}) 280 | run.log({f'test_MAPE'+suff: smapes}) 281 | if verbose: 282 | print(f'Test loss: {losses}') 283 | print(f'Test MAE: {mae}') 284 | print(f'Test MAPE: {smapes}') 285 | 286 | preds = np.concatenate((preds), axis=0) 287 | real = np.concatenate((real), axis=0) 288 | wsis = np.concatenate((wsis), axis=0) 289 | projs = np.concatenate((projs), axis=0) 290 | 291 | return preds, real, wsis, projs 292 | 293 | def predict(model, dataloader, run=None, verbose=True): 294 | model.eval() 295 | preds = [] 296 | wsis = [] 297 | projs = [] 298 | for image, rna_data, wsi_file_name, tcga_project in tqdm(dataloader): 299 | if image == []: continue 300 | image = image.to(model.device) 301 | wsis.append(wsi_file_name) 302 | projs.append(tcga_project) 303 | 304 | pred = model(image) 305 | preds.append(pred.detach().cpu().numpy()) 306 | 307 | preds = np.concatenate((preds), axis=0) 308 | wsis = np.concatenate((wsis), axis=0) 309 | projs = np.concatenate((projs), axis=0) 310 | 311 | return preds, wsis, projs --------------------------------------------------------------------------------