├── LICENSE ├── README.md ├── clustering.ipynb ├── configs ├── multimodal.yaml ├── robustness.yaml └── unimodal.yaml ├── data ├── README.md ├── extract_brain_embeddings.py ├── files │ ├── clinical_data.tsv │ └── supplementary_data.tsv ├── get_wsi_thumbnails.py ├── mappings │ └── wsi_mapping.json ├── preprocessing.ipynb ├── rna_preprocessors │ ├── trf_0.joblib │ ├── trf_1.joblib │ ├── trf_2.joblib │ ├── trf_3.joblib │ └── trf_4.joblib ├── run_mri_pretraining.py ├── run_wsi_pretraining.py └── scripts │ ├── gdc_manifest_20230918_WSI_LGG.txt │ ├── gdc_manifest_20231124_WSI_GBM.txt │ ├── gdc_manifest_20231204_RNA_GBM_LGG.txt │ └── gdc_manifest_20231205_DNAm_GBM_LGG.txt ├── drim ├── __init__.py ├── commons │ ├── __init__.py │ ├── losses.py │ └── optim_contrastive.py ├── datasets.py ├── dnam │ ├── __init__.py │ ├── datasets.py │ └── models.py ├── fusion │ ├── __init__.py │ ├── fusion.py │ └── utils.py ├── helpers.py ├── logger.py ├── losses.py ├── models.py ├── mri │ ├── __init__.py │ ├── datasets.py │ ├── models.py │ └── transforms.py ├── multimodal │ ├── __init__.py │ ├── datasets.py │ └── models.py ├── rna │ ├── __init__.py │ ├── datasets.py │ └── models.py ├── trainers.py ├── utils.py └── wsi │ ├── __init__.py │ ├── datasets.py │ ├── func.py │ ├── models.py │ └── transforms.py ├── requirements.txt ├── robustness.py ├── static └── DRIM.png ├── train_aux_multimodal.py ├── train_drimsurv.py ├── train_drimu.py ├── train_unimodal.py └── train_vanilla_multimodal.py /README.md: -------------------------------------------------------------------------------- 1 | # DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data 2 | 3 | **[Go to pdf](https://papers.miccai.org/miccai-2024/paper/1276_paper.pdf)**
4 | This is the code associated with the paper **DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data** (accepted at [MICCAI2024](https://conferences.miccai.org/2024/en/)). 5 |
6 | ![DRIM](static/DRIM.png) 7 | __________________ 8 | # Data Preprocessing 9 | 10 | Navigate to the `data` folder for instructions on downloading and preprocessing the data. 11 | The maximum file size limit of the supplemental does not allow the WSI pre-trained models to be supplied, so only the MRI pre-trained model is provided in `data/models/`. 12 | 13 | **N.B**: The data used are from the TCGA-GBMLGG cohort and must have the following structure: 14 | 15 | ``` 16 | ├── GBMLGG 17 | │ ├── MRI 18 | │ ├── WSI 19 | │ ├── DNAm 20 | │ ├── RNA 21 | ``` 22 | The `data` folder must also contain a `files/train_brain.csv` and a `files/test_brain.csv` listing, for each sample and each modality, the path to the corresponding file. Again, check the `README.md` in the `data` folder for further explanations. 23 | 24 | ________ 25 | # Quickstart 26 | Start by creating a dedicated conda environment and installing all the required packages: 27 | ``` 28 | $ conda create python=3.10.10 --name drim 29 | $ conda activate drim 30 | $ pip install -r requirements.txt 31 | ``` 32 | **N.B:** This repository relies on [Hydra](https://hydra.cc/docs/intro/) to run experiments and benefits from all the advantages it offers (*multiruns, sweeps, flexibility...*). 33 | 34 | ________ 35 | # Training loops 36 | Below are the scripts used for training and cross-validation. 37 |
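All of these scripts are driven by the Hydra configs in `configs/`, so any value can be overridden or swept directly from the command line. For example (assuming `general.val_split` selects the validation fold, as suggested by the per-split checkpoints used later; the values below are purely illustrative):
```
$ python train_unimodal.py general.modalities=RNA optimizer.params.lr=5e-4
$ python train_unimodal.py -m general.modalities=DNAm,RNA general.val_split=0,1,2,3,4
```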
38 | **Unimodal** 39 | 40 | For example, to launch a unimodal training run on each of the modalities using Hydra's multirun feature (`-m`), enter the following in the CLI: 41 | ``` 42 | $ python train_unimodal.py -m general.modalities=DNAm,WSI,RNA,MRI 43 | ``` 44 | 45 | **Simple Multimodal** 46 |
47 | To use a simple fusion architecture, with the choice of the fusion function for each of $h_u$ and $h_s$: 48 | ``` 49 | $ python train_vanilla_multimodal.py fusion.name=maf 50 | ``` 51 | And, once again, to sweep over all the available fusion functions: 52 | ``` 53 | $ python train_vanilla_multimodal.py -m fusion.name=mean,concat,masked_mean,tensor,sum,max,maf 54 | ``` 55 |
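To give a flavour of the simpler fusion functions, here is a condensed sketch of the masked mean, which averages only the modalities actually present for each sample. It mirrors `ShallowFusion.masked_mean` in `drim/fusion/fusion.py`; the standalone function form below is just for illustration.
```python
from typing import Dict
import torch

def masked_mean(x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor]) -> torch.Tensor:
    # x[m]: (batch, dim) embedding for modality m; mask[m]: (batch,) boolean presence flag
    num = sum(x[m] * mask[m][:, None] for m in x)
    den = sum(mask[m] for m in mask)[:, None].float()
    return num / den
```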
56 | **Auxiliary Multimodal** 57 | 58 | In this repository, in addition to the fusion, you can add an auxiliary cost function on the $s_i^m$ representations. This can be either the cost function used in DRIM, directly adapted from the Supervised Contrastive loss (1), or the MMO loss (2), which tends to orthogonalise the representations. 59 | (See the paper for more details on the notation.) 60 | ``` 61 | $ python train_aux_multimodal.py aux_loss.name=contrastive 62 | ``` 63 | ``` 64 | $ python train_aux_multimodal.py aux_loss.name=mmo aux_loss.alpha=0.5 65 | ``` 66 | ## DRIM 67 | The DRIM method is very flexible and can be adapted to any type of task. For example, we offer two alternatives: DRIMSurv and DRIMU. 68 |
69 |
70 | **DRIMSurv** 71 |
72 | DRIMSurv enables end-to-end training of a survival model: the encoders are trained at the same time as the higher attention blocks needed to merge the various representations (Fig. 1c). 73 | ``` 74 | $ python train_drimsurv.py 75 | ``` 76 |
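For reference, here is a condensed sketch of how a trained DRIMSurv model is rebuilt for inference, mirroring the construction used in `clustering.ipynb`. The hyperparameters and the checkpoint path follow the defaults in `configs/` and the per-split naming used in the notebook, but treat them as assumptions and check them against your own training configuration.
```python
import torch
from omegaconf import OmegaConf

from drim.helpers import get_encoder
from drim.fusion import MaskedAttentionFusion
from drim.multimodal import DRIMSurv
from drim.models import MultimodalWrapper

cfg = OmegaConf.load("./configs/robustness.yaml")
modalities = ["DNAm", "WSI", "RNA", "MRI"]

# one shared and one unique encoder per modality
encoders_sh = {m: get_encoder(m, cfg) for m in modalities}
encoders_u = {m: get_encoder(m, cfg) for m in modalities}

# two masked attention blocks: one fuses the shared representations, one the unique ones
fusion_s = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)
fusion_u = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)

encoder = DRIMSurv(encoders_sh=encoders_sh, encoders_u=encoders_u, fusion_s=fusion_s, fusion_u=fusion_u)
model = MultimodalWrapper(encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs)
model.load_state_dict(torch.load("./models/drimsurv_split_0.pth", map_location="cpu"))
model.eval()
```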
77 | **DRIMU** 78 | 79 | The other alternative, DRIMU, does not require labels and is trained with a reconstruction scheme. The provided code pre-trains the backbone encoders and then freezes them for a 10-epoch fine-tuning on a survival task, with a fusion scheme similar to DRIMSurv (Fig. 1b). 80 | ``` 81 | $ python train_drimu.py 82 | ``` 83 | 84 | ## Robustness 85 | Once each model has been trained on the 5 splits, it is possible to test how each performs and how robust it is to different combinations of input modalities. Simply run the `robustness.py` script, which is based on the `configs/robustness.yaml` configuration file, and look at the output CS score performance. 86 | To check all available methods: 87 | ``` 88 | $ python robustness.py -m method=drim,max,tensor,concat 89 | ``` 90 | 91 | ## Stratifying high-risk vs low-risk patients 92 | All the code related to this stratification, as well as the results of the logrank test, can be found in the notebook `clustering.ipynb`. 93 | 94 | ________ 95 | # Logging 96 | All the training sessions have been logged on Weights and Biases. You can also use [Weights and Biases](https://wandb.ai/) to track the progress of your runs. To do this, add your project and entity to the configuration files `configs/unimodal.yaml` and `configs/multimodal.yaml`: 97 | 98 | ``` 99 | wandb: 100 | entity: MyAwesomeNickname 101 | project: MyAwesomeProject 102 | ``` 103 | Otherwise, training progression and evaluation will be displayed directly in the console. 104 | 105 | ________ 106 | # Citation 107 | ``` 108 | @InProceedings{robinet_drim_2024, 109 | author = { Robinet, Lucas and Berjaoui, Ahmad and Kheil, Ziad and Cohen-Jonathan Moyal, Elizabeth}, 110 | title = { { DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data } }, 111 | booktitle = {proceedings of Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024}, 112 | year = {2024}, 113 | publisher = {Springer Nature Switzerland}, 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /clustering.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "d23a603f", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "! pip install lifelines\n", 11 | "! 
pip install scikit-survival" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "bb41018f", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "# Standard libraries\n", 22 | "from collections import defaultdict\n", 23 | "\n", 24 | "# Third-party libraries\n", 25 | "import pandas as pd\n", 26 | "import torch\n", 27 | "import numpy as np\n", 28 | "from omegaconf import DictConfig, OmegaConf\n", 29 | "from sksurv.nonparametric import kaplan_meier_estimator\n", 30 | "import matplotlib.pyplot as plt\n", 31 | "from lifelines.statistics import logrank_test\n", 32 | "\n", 33 | "# Local dependencies\n", 34 | "from drim.helpers import get_datasets, get_targets, get_encoder\n", 35 | "from drim.utils import log_transform, seed_everything, get_dataframes, prepare_data, interpolate_dataframe\n", 36 | "from drim.multimodal import MultimodalDataset, DRIMSurv, MultimodalModel\n", 37 | "from drim.datasets import SurvivalDataset\n", 38 | "from drim.models import MultimodalWrapper\n", 39 | "\n", 40 | "\n", 41 | "seed = 1999\n", 42 | "n_outs = 20\n", 43 | "import matplotlib.pyplot as plt\n", 44 | "plt.rcParams[\"font.family\"] = \"serif\"" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "b81dfa02", 51 | "metadata": { 52 | "scrolled": false 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "cfg = OmegaConf.load('./configs/robustness.yaml')\n", 57 | "for method in ['tensor', 'concat', 'max', 'drim']:\n", 58 | " if method == 'tensor':\n", 59 | " cfg.general.dim = 32\n", 60 | " else:\n", 61 | " cfg.general.dim = 128\n", 62 | " scores = []\n", 63 | " for fold in range(5):\n", 64 | " seed_everything(seed)\n", 65 | " dataframes = get_dataframes(fold)\n", 66 | " dataframes = {split: prepare_data(dataframe, ['DNAm', 'WSI', 'RNA', 'MRI']) for split, dataframe in dataframes.items()}\n", 67 | " test_datasets = {}\n", 68 | " encoders = {}\n", 69 | " if method == 'drim':\n", 70 | " encoders_u = {}\n", 71 | "\n", 72 | " for modality in ['DNAm', 'WSI', 'RNA', 'MRI']:\n", 73 | " datasets = get_datasets(dataframes, modality, fold, return_mask=True)\n", 74 | " test_datasets[modality] = datasets['test']\n", 75 | " encoder = get_encoder(modality, cfg).cuda()\n", 76 | " encoders[modality] = encoder\n", 77 | " if method == 'drim':\n", 78 | " encoder_u = get_encoder(modality, cfg).cuda()\n", 79 | " encoders_u[modality] = encoder_u\n", 80 | " \n", 81 | " targets, cut = get_targets(dataframes, cfg.general.n_outs)\n", 82 | " dataset_test = MultimodalDataset(test_datasets, return_mask=True)\n", 83 | " test_data = SurvivalDataset(dataset_test, *targets['test'])\n", 84 | " loader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=24)\n", 85 | " if method == 'drim':\n", 86 | " from drim.fusion import MaskedAttentionFusion\n", 87 | " fusion = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)\n", 88 | " fusion_u = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)\n", 89 | " fusion.cuda()\n", 90 | " fusion_u.cuda()\n", 91 | " encoder = DRIMSurv(encoders_sh=encoders, encoders_u=encoders_u, fusion_s=fusion, fusion_u=fusion_u)\n", 92 | " model = MultimodalWrapper(encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs)\n", 93 | " model.load_state_dict(torch.load(f'./models/drimsurv_split_{int(fold)}.pth'))\n", 94 | " else:\n", 95 | " if method == 'max':\n", 96 | " from drim.fusion import ShallowFusion\n", 97 | " fusion = ShallowFusion('max')\n", 98 | " elif 
method == 'tensor':\n", 99 | " from drim.fusion import TensorFusion\n", 100 | " fusion = TensorFusion(modalities=['DNAm', 'WSI', 'RNA', 'MRI'], input_dim=cfg.general.dim, projected_dim=cfg.general.dim, output_dim=cfg.general.dim, dropout=0.)\n", 101 | " elif method == 'concat':\n", 102 | " from drim.fusion import ShallowFusion\n", 103 | " fusion = ShallowFusion('concat')\n", 104 | " elif method == 'maf':\n", 105 | " from drim.fusion import MaskedAttentionFusion\n", 106 | " fusion = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)\n", 107 | " \n", 108 | " fusion.cuda()\n", 109 | " encoder = MultimodalModel(encoders, fusion= fusion)\n", 110 | " if method == 'concat':\n", 111 | " size = cfg.general.dim * 4\n", 112 | " else:\n", 113 | " size = cfg.general.dim\n", 114 | " model = MultimodalWrapper(encoder, embedding_dim=size, n_outs=cfg.general.n_outs)\n", 115 | " if method == 'max':\n", 116 | " prefix = 'vanilla'\n", 117 | " else:\n", 118 | " prefix = 'aux_mmo'\n", 119 | " \n", 120 | " model.load_state_dict(torch.load(f'./models/{prefix}_{method}_split_{int(fold)}.pth'))\n", 121 | "\n", 122 | " \n", 123 | " model.cuda()\n", 124 | " model.eval()\n", 125 | "\n", 126 | " \n", 127 | " hazards = []\n", 128 | "\n", 129 | " with torch.no_grad():\n", 130 | " for batch in loader:\n", 131 | " data, time, event = batch\n", 132 | " data, mask = data\n", 133 | " outputs = model(data, mask, return_embedding=False)\n", 134 | " hazards.append(outputs)\n", 135 | " \n", 136 | " hazards = interpolate_dataframe(pd.DataFrame((1 - torch.cat(hazards, dim=0).sigmoid()).add(1e-7).log().cumsum(1).exp().cpu().numpy().transpose(), cut))\n", 137 | " scores.append(-hazards.sum(0).values)\n", 138 | " \n", 139 | " scores = torch.from_numpy(np.stack(scores).mean(0))\n", 140 | " high_scores = scores > scores.median()\n", 141 | " low_scores = scores <= scores.median()\n", 142 | " data_high = dataframes['test'].iloc[high_scores.numpy()]\n", 143 | " data_low = dataframes['test'].iloc[low_scores.numpy()]\n", 144 | " results = logrank_test(data_high[\"time\"], data_low[\"time\"], data_high[\"event\"], data_low[\"event\"])\n", 145 | "\n", 146 | " p_value = results.p_value\n", 147 | " fig = plt.figure(dpi=500)\n", 148 | " time, survival_prob, conf_int = kaplan_meier_estimator(\n", 149 | " data_low[\"event\"].astype(bool),\n", 150 | " data_low[\"time\"],\n", 151 | " conf_type=\"log-log\",\n", 152 | " )\n", 153 | "\n", 154 | " plt.step(time, survival_prob, where=\"post\", label=f\"Low risk\")\n", 155 | " plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step=\"post\")\n", 156 | " time, survival_prob, conf_int = kaplan_meier_estimator(\n", 157 | " data_high[\"event\"].astype(bool),\n", 158 | " data_high[\"time\"],\n", 159 | " conf_type=\"log-log\",\n", 160 | " )\n", 161 | "\n", 162 | " plt.step(time, survival_prob, where=\"post\", label=f\"High risk\")\n", 163 | " plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step=\"post\")\n", 164 | " plt.ylim(0, 1)\n", 165 | " plt.ylabel(r\"Probability of survival $\\hat{S}(t)$\")\n", 166 | " plt.xlabel(\"Time $t$\") \n", 167 | " plt.legend(loc=\"best\")\n", 168 | " if method == 'tensor' or method == 'concat':\n", 169 | " title = method.capitalize()\n", 170 | " title += ' w/ MMO'\n", 171 | " else:\n", 172 | " title = method.capitalize()\n", 173 | " plt.title(title)\n", 174 | " p_value = f'{p_value:0.2}'\n", 175 | " plt.text(11,0.8,r'$p_{value}$='+p_value, fontsize=12.8)\n", 176 | " plt.grid()\n", 177 | " plt.savefig(f'{method}.pdf')\n", 
178 | " #tikzplotlib_fix_ncols(fig)\n", 179 | " #tikzplotlib.save(f'{method}.tikz')\n", 180 | " plt.show()" 181 | ] 182 | } 183 | ], 184 | "metadata": { 185 | "kernelspec": { 186 | "display_name": "Python 3 (ipykernel)", 187 | "language": "python", 188 | "name": "python3" 189 | }, 190 | "language_info": { 191 | "codemirror_mode": { 192 | "name": "ipython", 193 | "version": 3 194 | }, 195 | "file_extension": ".py", 196 | "mimetype": "text/x-python", 197 | "name": "python", 198 | "nbconvert_exporter": "python", 199 | "pygments_lexer": "ipython3", 200 | "version": "3.10.10" 201 | } 202 | }, 203 | "nbformat": 4, 204 | "nbformat_minor": 5 205 | } 206 | -------------------------------------------------------------------------------- /configs/multimodal.yaml: -------------------------------------------------------------------------------- 1 | 2 | general: 3 | seed: 1999 4 | val_split: 4 5 | modalities: [DNAm, WSI, RNA, MRI] 6 | n_outs: 20 7 | dropout: 0.2 8 | epochs: 30 9 | epochs_finetune: 10 10 | dim: 128 11 | save_path: ./models/ 12 | n_folds: 5 13 | 14 | dataloader: 15 | batch_size: 24 16 | pin_memory: true 17 | num_workers: 40 18 | persistent_workers: true 19 | 20 | fusion: 21 | name: maf 22 | params: 23 | depth: 1 24 | heads: 16 25 | dim_head: 64 26 | mlp_dim: 128 27 | 28 | optimizer: 29 | name: AdamW 30 | params: 31 | lr: 1e-3 32 | weight_decay: 1e-2 33 | 34 | scheduler: 35 | T_max: ${general.epochs} 36 | eta_min: 5e-6 37 | 38 | aux_loss: 39 | name: contrastive 40 | alpha: 1. 41 | 42 | disentangled: 43 | gamma: 0.8 44 | dsm_lr: 1e-3 45 | dsm_wd: 3e-4 -------------------------------------------------------------------------------- /configs/robustness.yaml: -------------------------------------------------------------------------------- 1 | method: tensor 2 | general: 3 | dim: 128 4 | n_outs: 20 5 | dropout: 0. 6 | seed: 1999 -------------------------------------------------------------------------------- /configs/unimodal.yaml: -------------------------------------------------------------------------------- 1 | 2 | general: 3 | seed: 1999 4 | val_split: 4 5 | modalities: DNAm 6 | n_outs: 20 7 | dropout: 0.2 8 | epochs: 50 9 | save_path: ./models/ 10 | dim: 128 11 | n_folds: 5 12 | 13 | dataloader: 14 | batch_size: 24 15 | pin_memory: true 16 | num_workers: 40 17 | persistent_workers: true 18 | 19 | optimizer: 20 | name: AdamW 21 | params: 22 | lr: 1e-3 23 | weight_decay: 1e-2 24 | 25 | scheduler: 26 | T_max: ${general.epochs} 27 | eta_min: 5e-6 28 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data 2 | All data used in this study are from The Cancer Genome Atlas (TCGA) program [(Weinstein et al., Nat Genet 2013)](https://www.nature.com/articles/ng.2764). They are all publicly available on the [GDC Data Portal](https://portal.gdc.cancer.gov/). 3 | 4 | # Structure 5 | The structure of this given part is organised as follows 6 | ``` 7 | ├── scripts 8 | │ ├── gdc_manifest_20230619_DNAm_GBM_LGG.txt 9 | │ ├── gdc_manifest_20230918_WSI_LGG.txt 10 | │ ├── gdc_manifest_20231124_WSI_GBM.txt 11 | │ ├── gdc_manifest_20231203_RNA_GBM_LGG.txt 12 | ├── files 13 | │ ├── clinical_data.tsv 14 | │ ├── supplementary_data.tsv 15 | │ ├── ... 
<- preprocessed dataframes created from preprocessing.ipynb 16 | ├── mappings 17 | │ ├── wsi_mapping.json 18 | │ ├── BraTS2023_2017_GLI_Mapping.xlsx 19 | ├── models 20 | │ ├── wsi_encoder.pth 21 | │ ├── t1ce_flair_tumorTrue.pth 22 | │ ├── ... 23 | ├── rna_preprocessors 24 | │ ├── trf_0.joblib 25 | │ ├── trf_1.joblib 26 | │ ├── trf_2.joblib 27 | │ ├── trf_3.joblib 28 | │ ├── trf_4.joblib 29 | ├── get_wsi_thumbnails.py 30 | ├── run_mri_pretraining.py 31 | ├── run_wsi_pretraining.py 32 | ├── extract_brain_embeddings.py 33 | ├── preprocessing.ipynb 34 | ``` 35 | The `scripts` folder contains all the manifest files needed to download data from the [GDC Data Portal](https://portal.gdc.cancer.gov/) (see their documentation). 36 | 37 | For instance, suppose we want the data to be stored in the `~/TCGA/GBMLGG` folder. To download RNASeq data, one can use the following command line: 38 | ``` 39 | $ mkdir ~/TCGA/GBMLGG/raw_RNA 40 | $ sudo /opt/gdc-client download \ 41 | -d ~/TCGA/GBMLGG/raw_RNA/ \ 42 | -m scripts/gdc_manifest_20231203_RNA_GBM_LGG.txt 43 | ``` 44 | For data pre-processing we proceeded in a similar way to [Vale-Silva et al., 2022](https://www.nature.com/articles/s41598-021-92799-4). Thus the `clinical_data.tsv` file is identical. The supplementary data (`files/supplementary_data.tsv`) contains information about the methylation and IDH status of patients. 45 | It can be obtained using the following R code: 46 | 47 | ``` 48 | library(TCGAWorkflowData) 49 | library(DT) 50 | library(TCGAbiolinks) 51 | 52 | gdc <- TCGAquery_subtype(tumor='all') 53 | output_path <- 'files/supplementary_data.tsv' 54 | readr::write_tsv(gdc, output_path) 55 | ``` 56 | - In the `mappings` folder, the `wsi_mapping.json` file gives the path to the corresponding WSI for each patient (see `preprocessing.ipynb` for its construction). 57 | - The `BraTS2023_2017_GLI_Mapping.xlsx` file is used to link the BraTS competition MRIs to the TCGA patients; you can download it directly from the [BraTS competition homepage](https://www.synapse.org/#!Synapse:syn51156910/wiki/621282) by signing up to the competition, or through The Cancer Imaging Archive (TCIA). 58 | 59 | The `models` folder is self-explanatory: it contains the pretrained encoders used for the multimodal training. The maximum file size limit of the supplemental does not allow the WSI pre-trained models to be supplied, so only the MRI pre-trained model is provided. 60 | 61 | # Preprocessing 62 | All the pre-processing is detailed in the file `preprocessing.ipynb`. 63 | All the Python files at the root are used to build the final data file used later. Their function is described in more detail in the notebook. To make a long story short: 64 | - `get_wsi_thumbnails.py`: obtains a tissue mask and a low-resolution image for each WSI. 65 | - `run_mri_pretraining.py`: trains the specific encoder for MRI data. 66 | - `run_wsi_pretraining.py`: trains the specific encoder for WSI data. 67 | - `extract_brain_embeddings.py`: extracts embeddings from MRI data, where each volume is centered around the tumor and contains only t1ce and flair modalities (2x64x64x64). 68 | 69 | The `rna_preprocessors` folder contains the fitted objects used to preprocess the RNA data according to the validation split. 70 | The script generating these files, which are used throughout the training sessions, is in `preprocessing.ipynb`; a short usage sketch follows. 
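As an illustration, here is a minimal sketch of how one of these per-fold preprocessors could be loaded and applied. It assumes the saved objects are scikit-learn-style transformers serialised with joblib; the RNA dataframe filename and layout below are hypothetical, so refer to `preprocessing.ipynb` for the authoritative pipeline.
```python
import joblib
import pandas as pd

# preprocessor fitted on the training portion of validation split 0
trf = joblib.load("rna_preprocessors/trf_0.joblib")

# hypothetical samples x genes dataframe produced by preprocessing.ipynb
rna_df = pd.read_csv("files/rna_expression.csv", index_col=0)
rna_processed = trf.transform(rna_df.values)
```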
71 | # Acknowledgements 72 | As a large part of this folder is largely inspired by [Vale-Silva's MultiSurv paper](https://www.nature.com/articles/s41598-021-92799-4), it seems important to thank him again for his very clear and reproducible work and to point it out. [Vale-Silva's GitHub repository](https://github.com/luisvalesilva/multisurv/tree/master/data) -------------------------------------------------------------------------------- /data/extract_brain_embeddings.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | import torch 4 | import csv 5 | from drim.mri.transforms import tumor_transfo 6 | from drim.mri.datasets import MRIProcessor 7 | from drim.mri.models import MRIEncoder 8 | from drim.utils import clean_state_dict 9 | import tqdm 10 | import os 11 | 12 | # Load the data 13 | data = pd.read_csv("./data/files/dataframe_brain.csv") 14 | encoder = MRIEncoder(2, 512, False) 15 | encoder.load_state_dict( 16 | clean_state_dict(torch.load("data/models/t1ce-flair_tumorTrue.pth")), strict=False 17 | ) 18 | encoder.eval() 19 | for mri_path in tqdm.tqdm(data.MRI): 20 | if pd.isna(mri_path): 21 | continue 22 | process = MRIProcessor( 23 | mri_path, 24 | tumor_centered=True, 25 | transform=tumor_transfo, 26 | modalities=["t1ce", "flair"], 27 | size=(64, 64, 64), 28 | ) 29 | mri = process.process().unsqueeze(0) 30 | with torch.no_grad(): 31 | embedding = encoder(mri) 32 | 33 | # Save the embedding 34 | with open(os.path.join(mri_path, "embedding.csv"), "w") as f: 35 | writer = csv.writer(f) 36 | writer.writerow([str(i) for i in range(512)]) 37 | writer.writerow(embedding.squeeze().numpy().tolist()) 38 | -------------------------------------------------------------------------------- /data/get_wsi_thumbnails.py: -------------------------------------------------------------------------------- 1 | import pyvips 2 | import os 3 | import tqdm 4 | import cv2 5 | import numpy as np 6 | from typing import Tuple 7 | from PIL import Image 8 | import argparse 9 | 10 | 11 | def segment( 12 | img_rgba: np.ndarray, 13 | sthresh: int = 25, 14 | sthresh_up: int = 255, 15 | mthresh: int = 9, 16 | otsu: bool = True, 17 | ) -> Tuple[np.ndarray, np.ndarray]: 18 | img = cv2.cvtColor(img_rgba, cv2.COLOR_RGBA2RGB) 19 | 20 | img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) # Convert to HSV space 21 | # img_med = filters.median(img_hsv))#, np.ones((3, 3,3)) 22 | img_med = cv2.medianBlur(img_hsv[:, :, 1], mthresh) # Apply median blurring 23 | # img_med = img_hsv 24 | if otsu: 25 | _, img_otsu = cv2.threshold( 26 | img_med, 0, sthresh_up, cv2.THRESH_OTSU + cv2.THRESH_BINARY 27 | ) 28 | else: 29 | _, img_otsu = cv2.threshold(img_med, sthresh, sthresh_up, cv2.THRESH_BINARY) 30 | masked_image = cv2.bitwise_and(img, img, mask=img_otsu) 31 | return masked_image, img_otsu 32 | 33 | 34 | def main(data_path, downscale_factor): 35 | """ 36 | This script creates thumbnails and low resolution masks of the WSI images. 37 | The data folder must be organized as follows (from a TCGA download): 38 | - data_path 39 | - id1 40 | - filename1.svs 41 | - ... 42 | - id2 43 | - filename2.svs 44 | - ... 45 | - ... 46 | 47 | Args: 48 | data_path: Path to the input directory. 49 | downscale_factor: Downscale factor for x20 magnification. 
50 | 51 | Returns: 52 | None 53 | 54 | ---- 55 | Example usage: 56 | python get_wsi_thumbnails.py --data_path /data/ --downscale_factor 6 57 | 58 | """ 59 | subdirectories = os.listdir(data_path) 60 | for subdirectory in tqdm.tqdm(subdirectories): 61 | subdirectory_path = os.path.join(data_path, subdirectory) 62 | filenames = os.listdir(subdirectory_path) 63 | wsi_filename = [f for f in filenames if f.endswith("svs") or f.endswith("tif")][ 64 | 0 65 | ] 66 | slide = pyvips.Image.new_from_file( 67 | os.path.join(subdirectory_path, wsi_filename) 68 | ) 69 | if int(float(slide.get("aperio.AppMag"))) == 40: 70 | d = downscale_factor + 1 71 | else: 72 | d = downscale_factor 73 | thumbnail = pyvips.Image.thumbnail( 74 | os.path.join(subdirectory_path, wsi_filename), 75 | slide.width / (2**d), 76 | height=slide.height / (2**d), 77 | ).numpy() 78 | 79 | thumbnail = cv2.cvtColor(thumbnail, cv2.COLOR_RGBA2RGB) 80 | thumbnail_hsv = cv2.cvtColor(thumbnail, cv2.COLOR_RGB2HSV) 81 | # to filter out felt-tip marks 82 | mask_hsv = np.tile(thumbnail_hsv[:, :, 1] < 160, (3, 1, 1)).transpose(1, 2, 0) 83 | thumbnail *= mask_hsv 84 | masked_image, mask = segment(thumbnail) 85 | masked_image = Image.fromarray(masked_image).convert("RGB") 86 | # save 87 | masked_image.save(os.path.join(subdirectory_path, "thumbnail.jpg")) 88 | np.save(os.path.join(subdirectory_path, "mask.npy"), mask) 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser( 93 | description="Create thumbnails and low resolution masks of the WSI images." 94 | ) 95 | parser.add_argument("--data_path", "-p", help="Path to the input directory.") 96 | parser.add_argument( 97 | "--downscale_factor", 98 | "-d", 99 | type=int, 100 | help="Downscale factor for x20 magnification.", 101 | ) 102 | args = parser.parse_args() 103 | main(args.data_path, args.downscale_factor) 104 | -------------------------------------------------------------------------------- /data/rna_preprocessors/trf_0.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_0.joblib -------------------------------------------------------------------------------- /data/rna_preprocessors/trf_1.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_1.joblib -------------------------------------------------------------------------------- /data/rna_preprocessors/trf_2.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_2.joblib -------------------------------------------------------------------------------- /data/rna_preprocessors/trf_3.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_3.joblib -------------------------------------------------------------------------------- /data/rna_preprocessors/trf_4.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_4.joblib 
-------------------------------------------------------------------------------- /data/run_mri_pretraining.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | import argparse 3 | import os 4 | 5 | # Third-party libraries 6 | import torch 7 | import wandb 8 | import numpy as np 9 | import pandas as pd 10 | 11 | # Local libraries 12 | from drim.utils import seed_everything, seed_worker 13 | from drim.logger import logger 14 | from drim.mri.datasets import DatasetBraTSTumorCentered 15 | from drim.mri.models import MRIEncoder 16 | from drim.mri.transforms import get_tumor_transforms 17 | from drim.commons.optim_contrastive import training_loop_contrastive 18 | from drim.commons.losses import ContrastiveLoss 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--data_path", type=str, default="../TCGA/GBMLGG/MRI") 24 | parser.add_argument("--batch_size", type=int, default=48) 25 | parser.add_argument("--epochs", type=int, default=30) 26 | parser.add_argument("--lr", type=float, default=1e-4) 27 | parser.add_argument("--modalities", nargs="+", default=["t1ce", "flair"]) 28 | parser.add_argument("--weight_decay", type=float, default=1e-6) 29 | parser.add_argument("--entity", type=str, default=None) 30 | parser.add_argument("--project", type=str, default="DRIM") 31 | parser.add_argument("--temperature", type=float, default=0.07) 32 | parser.add_argument("--tumor_centered", type=bool, default=True) 33 | parser.add_argument("--n_cpus", type=int, default=40) 34 | parser.add_argument("--n_gpus", type=int, default=4) 35 | parser.add_argument("--k", type=int, default=3) 36 | parser.add_argument("--seed", type=int, default=1999) 37 | args = parser.parse_args() 38 | 39 | # Set seed 40 | seed_everything(args.seed) 41 | patients = os.listdir(args.data_path) 42 | dataframe = pd.read_csv("data/files/dataframe_brain.csv") 43 | dataframe_test = dataframe[dataframe["group"] == "test"] 44 | # get patient ids where MRI is not NaN 45 | dataframe_test = dataframe_test[~dataframe_test["MRI"].isna()] 46 | patients_to_exclude = [ 47 | patient_path.split("/")[-1] for patient_path in dataframe_test.MRI.values 48 | ] 49 | patients = [patient for patient in patients if patient not in patients_to_exclude] 50 | 51 | # split patient into train and val by taking random 70% of patients for training 52 | train_patients = np.random.choice( 53 | patients, int(len(patients) * 0.75), replace=False 54 | ) 55 | val_patients = [patient for patient in patients if patient not in train_patients] 56 | logger.info("Performing contrastive training on BraTS dataset.") 57 | logger.info("Modalities used : {}.", args.modalities) 58 | if bool(args.tumor_centered): 59 | logger.info("Using tumor centered dataset.") 60 | category = "tumor" 61 | sizes = (64, 64, 64) 62 | train_dataset = DatasetBraTSTumorCentered( 63 | args.data_path, 64 | args.modalities, 65 | patients=train_patients, 66 | sizes=sizes, 67 | return_mask=False, 68 | transform=get_tumor_transforms(sizes), 69 | ) 70 | val_dataset = DatasetBraTSTumorCentered( 71 | args.data_path, 72 | args.modalities, 73 | patients=val_patients, 74 | sizes=sizes, 75 | return_mask=False, 76 | transform=get_tumor_transforms(sizes), 77 | ) 78 | else: 79 | raise NotImplementedError 80 | 81 | train_loader = torch.utils.data.DataLoader( 82 | train_dataset, 83 | batch_size=args.batch_size, 84 | shuffle=True, 85 | pin_memory=True, 86 | num_workers=args.n_cpus, 87 | persistent_workers=True, 88 | 
worker_init_fn=seed_worker, 89 | ) 90 | val_loader = torch.utils.data.DataLoader( 91 | val_dataset, 92 | batch_size=args.batch_size, 93 | shuffle=False, 94 | num_workers=args.n_cpus, 95 | pin_memory=True, 96 | persistent_workers=True, 97 | worker_init_fn=seed_worker, 98 | ) 99 | 100 | model = MRIEncoder(projection_head=True, in_channels=len(args.modalities)) 101 | 102 | logger.info( 103 | "Number of parameters : {}", 104 | sum(p.numel() for p in model.parameters() if p.requires_grad), 105 | ) 106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 107 | 108 | if args.n_gpus > 1: 109 | model = torch.nn.DataParallel(model, device_ids=list(range(args.n_gpus))) 110 | 111 | model.to(device) 112 | logger.info("Using {} gpus to train the model", args.n_gpus) 113 | 114 | contrastive_loss = ContrastiveLoss(temperature=args.temperature, k=args.k) 115 | 116 | optimizer = torch.optim.Adam( 117 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay 118 | ) 119 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 120 | optimizer, T_max=args.epochs, eta_min=5e-6 121 | ) 122 | 123 | path_to_save = ( 124 | f"data/models/{'-'.join(args.modalities)}_tumor{str(args.tumor_centered)}.pth" 125 | ) 126 | # if a wandb entity is provided, log the training on wandb 127 | wandb_logging = True if args.entity is not None else False 128 | if wandb_logging: 129 | run = wandb.init( 130 | project=args.project, 131 | entity=args.entity, 132 | name=f"Pretraining_MRI", 133 | reinit=True, 134 | config=vars(args), 135 | ) 136 | logger.info("Training started!") 137 | _, _ = training_loop_contrastive( 138 | model, 139 | args.epochs, 140 | contrastive_loss, 141 | optimizer, 142 | lr_scheduler, 143 | train_loader, 144 | val_loader, 145 | device, 146 | path_to_save, 147 | wandb_logging=wandb_logging, 148 | k=args.k, 149 | ) 150 | if wandb_logging: 151 | run.finish() 152 | logger.info("Training finished!") 153 | -------------------------------------------------------------------------------- /data/run_wsi_pretraining.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | import argparse 3 | import os 4 | 5 | # Third-party libraries 6 | import torch 7 | import wandb 8 | import numpy as np 9 | from imutils import paths 10 | 11 | # Local dependencies 12 | from drim.utils import seed_everything, seed_worker 13 | from drim.logger import logger 14 | from drim.wsi.transforms import contrastive_wsi_transforms 15 | from drim.wsi.datasets import PatchDataset 16 | from drim.wsi.models import ResNetWrapperSimCLR 17 | from drim.commons.optim_contrastive import training_loop_contrastive 18 | from drim.commons.losses import ContrastiveLoss 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument( 24 | "--data_path", 25 | nargs="+", 26 | default=["/media/hddext/medical/TCGA-GBM_WSI", "/archive/medical/tcga_lgg/wsi"], 27 | ) 28 | parser.add_argument("--batch_size", type=int, default=200) 29 | parser.add_argument("--epochs", type=int, default=30) 30 | parser.add_argument("--lr", type=float, default=1e-3) 31 | parser.add_argument("--temperature", type=float, default=0.07) 32 | parser.add_argument("--weight_decay", type=float, default=1e-6) 33 | parser.add_argument("--out_dim", type=int, default=256) 34 | parser.add_argument("--entity", type=str, default=None) 35 | parser.add_argument("--project", type=str, default="DRIM") 36 | parser.add_argument("--n_cpus", type=int, default=30) 37 | 
parser.add_argument("--n_gpus", type=int, default=4) 38 | parser.add_argument("--k", type=int, default=10) 39 | parser.add_argument("--seed", type=int, default=1999) 40 | args = parser.parse_args() 41 | 42 | # Set seed 43 | seed_everything(args.seed) 44 | # get every filepaths 45 | filepaths = [] 46 | for path in args.data_path: 47 | temp_filepaths = list(paths.list_images(path)) 48 | temp_filepaths = [ 49 | filepath for filepath in temp_filepaths if "patches" in filepath 50 | ] 51 | filepaths += temp_filepaths 52 | 53 | filepaths = np.array(filepaths) 54 | # split patient into train and val by taking random 75% of patients for training 55 | 56 | train_idx = np.random.choice( 57 | np.arange(len(filepaths)), int(len(filepaths) * 0.75), replace=False 58 | ) 59 | train_mask = np.zeros(len(filepaths), dtype=bool) 60 | train_mask[train_idx] = True 61 | train_filepaths = filepaths[train_idx] 62 | val_filepaths = filepaths[~train_mask] 63 | 64 | logger.info("Performing contrastive training on TCGA dataset.") 65 | logger.info("Number of training images : {}.", len(train_filepaths)) 66 | logger.info("Number of validation images : {}.", len(val_filepaths)) 67 | 68 | train_dataset = PatchDataset(train_filepaths, transforms=contrastive_wsi_transforms) 69 | val_dataset = PatchDataset(val_filepaths, transforms=contrastive_wsi_transforms) 70 | 71 | train_loader = torch.utils.data.DataLoader( 72 | train_dataset, 73 | batch_size=args.batch_size, 74 | shuffle=True, 75 | pin_memory=True, 76 | worker_init_fn=seed_worker, 77 | ) 78 | val_loader = torch.utils.data.DataLoader( 79 | val_dataset, 80 | batch_size=args.batch_size, 81 | shuffle=False, 82 | pin_memory=True, 83 | worker_init_fn=seed_worker, 84 | ) 85 | 86 | model = ResNetWrapperSimCLR(out_dim=args.out_dim, projection_head=True) 87 | 88 | logger.info( 89 | "Number of parameters : {}", 90 | sum(p.numel() for p in model.parameters() if p.requires_grad), 91 | ) 92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 93 | 94 | if args.n_gpus > 1: 95 | model = torch.nn.DataParallel(model, device_ids=list(range(args.n_gpus))) 96 | 97 | model.to(device) 98 | logger.info("Using {} gpus to train the model", args.n_gpus) 99 | 100 | contrastive_loss = ContrastiveLoss(temperature=args.temperature, k=args.k) 101 | 102 | optimizer = torch.optim.Adam( 103 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay 104 | ) 105 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 106 | optimizer, T_max=40, eta_min=5e-6 107 | ) 108 | 109 | path_to_save = f"data/models/wsi_encoder.pth" 110 | wandb_logging = True if args.entity is not None else False 111 | if wandb_logging: 112 | run = wandb.init( 113 | project=args.project, 114 | entity=args.entity, 115 | name=f"Pretraining_WSI", 116 | reinit=True, 117 | config=vars(args), 118 | ) 119 | logger.info("Training started!") 120 | _, _ = training_loop_contrastive( 121 | model, 122 | args.epochs, 123 | contrastive_loss, 124 | optimizer, 125 | lr_scheduler, 126 | train_loader, 127 | val_loader, 128 | device, 129 | path_to_save, 130 | wandb_logging=wandb_logging, 131 | k=args.k, 132 | ) 133 | if wandb_logging: 134 | run.finish() 135 | logger.info("Training finished!") 136 | -------------------------------------------------------------------------------- /drim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/drim/__init__.py 
-------------------------------------------------------------------------------- /drim/commons/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/drim/commons/__init__.py -------------------------------------------------------------------------------- /drim/commons/losses.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.loss import _Loss 2 | import torch.nn.functional as F 3 | from warnings import warn 4 | import torch 5 | 6 | 7 | class ContrastiveLoss(_Loss): 8 | """ 9 | Deeply inspired and copy/paste from a previous PR on MONAI (https://github.com/Project-MONAI/MONAI/blob/dev/monai/losses/contrastive.py) 10 | 11 | Compute the Contrastive loss defined in: 12 | 13 | Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International 14 | conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v119/chen20j.html) 15 | 16 | Adapted from: 17 | https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/losses.py#L5 18 | 19 | """ 20 | 21 | def __init__( 22 | self, temperature: float = 0.5, k: int = 3, batch_size: int = -1 23 | ) -> None: 24 | """ 25 | Args: 26 | temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5. 27 | 28 | Raises: 29 | ValueError: When an input of dimension length > 2 is passed 30 | ValueError: When input and target are of different shapes 31 | 32 | """ 33 | super().__init__() 34 | self.temperature = temperature 35 | self.k = k 36 | 37 | if batch_size != -1: 38 | warn( 39 | "batch_size is no longer required to be set. It will be estimated dynamically in the forward call" 40 | ) 41 | 42 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 43 | """ 44 | Args: 45 | input: the shape should be B[F]. 46 | target: the shape should be B[F]. 
47 | """ 48 | if len(target.shape) > 2 or len(input.shape) > 2: 49 | raise ValueError( 50 | f"Either target or input has dimensions greater than 2 where target " 51 | f"shape is ({target.shape}) and input shape is ({input.shape})" 52 | ) 53 | 54 | if target.shape != input.shape: 55 | raise ValueError( 56 | f"ground truth has differing shape ({target.shape}) from input ({input.shape})" 57 | ) 58 | 59 | temperature_tensor = torch.as_tensor(self.temperature).to(input.device) 60 | batch_size = input.shape[0] 61 | 62 | negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool) 63 | negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device) 64 | 65 | repr = torch.cat([input, target], dim=0) 66 | sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2) 67 | sim_ij = torch.diag(sim_matrix, batch_size) 68 | sim_ji = torch.diag(sim_matrix, -batch_size) 69 | 70 | positives = torch.cat([sim_ij, sim_ji], dim=0) 71 | nominator = torch.exp(positives / temperature_tensor) 72 | denominator = negatives_mask * torch.exp(sim_matrix / temperature_tensor) 73 | 74 | loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1)) 75 | # create pos mask 76 | self_mask = torch.eye( 77 | sim_matrix.shape[0], dtype=torch.bool, device=sim_matrix.device 78 | ) 79 | pos_mask = self_mask.roll(shifts=sim_matrix.shape[0] // 2, dims=0) 80 | sim_matrix.masked_fill_(self_mask, -9e15) 81 | comb_sim = torch.cat( 82 | [ 83 | sim_matrix[pos_mask][:, None], # First position positive example 84 | sim_matrix.masked_fill(pos_mask, -9e15), 85 | ], 86 | dim=-1, 87 | ) 88 | sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1) 89 | 90 | return ( 91 | torch.sum(loss_partial) / (2 * batch_size), 92 | (sim_argsort == 0).float().mean(), 93 | (sim_argsort < self.k).float().mean(), 94 | 1 + sim_argsort.float().mean(), 95 | ) 96 | -------------------------------------------------------------------------------- /drim/commons/optim_contrastive.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from collections import defaultdict 3 | from typing import Tuple 4 | 5 | # Third party libraries 6 | import torch 7 | import wandb 8 | import tqdm 9 | import numpy as np 10 | 11 | # Local dependencies 12 | from ..logger import logger 13 | 14 | 15 | def training_loop_contrastive( 16 | model: torch.nn.Module, 17 | epochs: int, 18 | loss_fn: torch.nn.modules.loss._Loss, 19 | optimizer: torch.optim.Optimizer, 20 | scheduler: torch.optim.lr_scheduler._LRScheduler, 21 | train_dl: torch.utils.data.DataLoader, 22 | valid_dl: torch.utils.data.DataLoader, 23 | device: torch.device, 24 | path_to_save: str, 25 | wandb_logging: bool, 26 | k: int = 3, 27 | ) -> Tuple[torch.nn.Module, dict]: 28 | metrics = defaultdict(list) 29 | best_loss = np.infty 30 | for epoch in range(epochs): 31 | logger.info(f"Epoch {epoch + 1}/{epochs}") 32 | logger.info("-" * 10) 33 | model.train() 34 | metrics["lr"].append(optimizer.state_dict()["param_groups"][0]["lr"]) 35 | epoch_loss = 0.0 36 | top_1, top_k, mean_pos = 0.0, 0.0, 0.0 37 | for batch_data in tqdm.tqdm(train_dl, desc="Training...", total=len(train_dl)): 38 | if isinstance(batch_data, dict): 39 | inputs, inputs_2 = ( 40 | batch_data["image"].to(device), 41 | batch_data["image_2"].to(device), 42 | ) 43 | else: 44 | inputs, inputs_2 = batch_data[0].to(device), batch_data[1].to(device) 45 | optimizer.zero_grad() 46 | outputs, _ = model(inputs) 47 | outputs_2, _ = model(inputs_2) 48 | loss, 
acc_top_1, acc_top_k, acc_mean_pos = loss_fn(outputs, outputs_2) 49 | top_1 += acc_top_1 * inputs.shape[0] 50 | top_k += acc_top_k * inputs.shape[0] 51 | mean_pos += acc_mean_pos * inputs.shape[0] 52 | loss.backward() 53 | optimizer.step() 54 | epoch_loss += loss.item() * inputs.shape[0] 55 | 56 | epoch_loss /= len(train_dl.dataset) 57 | print(f"Training loss: {epoch_loss:.4f}") 58 | 59 | metrics["train/loss"].append(epoch_loss) 60 | metrics["train/top_1"].append(top_1 / len(train_dl.dataset)) 61 | metrics[f"train/top_{k}"].append(top_k / len(train_dl.dataset)) 62 | metrics["train/mean_pos"].append(mean_pos / len(train_dl.dataset)) 63 | model.eval() 64 | with torch.no_grad(): 65 | val_loss = 0.0 66 | top_1, top_k, mean_pos = 0.0, 0.0, 0.0 67 | for batch_data in tqdm.tqdm(valid_dl, desc="Validation..."): 68 | if isinstance(batch_data, dict): 69 | inputs, inputs_2 = ( 70 | batch_data["image"].to(device), 71 | batch_data["image_2"].to(device), 72 | ) 73 | else: 74 | inputs, inputs_2 = batch_data[0].to(device), batch_data[1].to( 75 | device 76 | ) 77 | outputs, _ = model(inputs) 78 | outputs_2, _ = model(inputs_2) 79 | 80 | loss, acc_top_1, acc_top_k, acc_mean_pos = loss_fn(outputs, outputs_2) 81 | top_1 += acc_top_1 * inputs.shape[0] 82 | top_k += acc_top_k * inputs.shape[0] 83 | mean_pos += acc_mean_pos * inputs.shape[0] 84 | val_loss += loss.item() * inputs.shape[0] 85 | 86 | val_loss /= len(valid_dl.dataset) 87 | print(f"Validation loss: {val_loss:.4f}") 88 | metrics["val/top_1"].append(top_1 / len(valid_dl.dataset)) 89 | metrics[f"val/top_{k}"].append(top_k / len(valid_dl.dataset)) 90 | metrics["val/mean_pos"].append(mean_pos / len(valid_dl.dataset)) 91 | metrics["val/loss"].append(val_loss) 92 | 93 | if wandb_logging: 94 | wandb.log({k: v[-1] for k, v in metrics.items()}) 95 | 96 | scheduler.step() 97 | if metrics["val/loss"][-1] < best_loss: 98 | best_loss = metrics["val/loss"][-1] 99 | torch.save(model.state_dict(), path_to_save) 100 | 101 | return model, metrics 102 | -------------------------------------------------------------------------------- /drim/datasets.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import pandas as pd 3 | import torch 4 | from typing import Any, Tuple 5 | 6 | 7 | __all__ = ["SurvivalDataset"] 8 | 9 | 10 | class _BaseDataset(torch.utils.data.Dataset): 11 | def __init__(self, dataframe: pd.DataFrame, return_mask: bool = False) -> None: 12 | self.dataframe = dataframe 13 | self.return_mask = return_mask 14 | 15 | @abstractmethod 16 | def __getitem__(self, idx: int): 17 | raise NotImplementedError 18 | 19 | def __len__(self) -> int: 20 | return len(self.dataframe) 21 | 22 | 23 | class SurvivalDataset(torch.utils.data.Dataset): 24 | def __init__( 25 | self, dataset: torch.utils.data.Dataset, time: torch.tensor, event: torch.tensor 26 | ) -> None: 27 | self.dataset = dataset 28 | self.time, self.event = torch.from_numpy(time), torch.from_numpy(event) 29 | 30 | def __getitem__(self, idx: int) -> Tuple[Any, torch.Tensor, torch.Tensor]: 31 | return self.dataset[idx], self.time[idx], self.event[idx] 32 | 33 | def __len__(self) -> int: 34 | return len(self.dataset) 35 | -------------------------------------------------------------------------------- /drim/dnam/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import DNAmDataset 2 | from .models import DNAmDecoder, DNAmEncoder 3 | 
-------------------------------------------------------------------------------- /drim/dnam/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | from typing import Union, Tuple 4 | from ..datasets import _BaseDataset 5 | 6 | 7 | class DNAmDataset(_BaseDataset): 8 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]: 9 | sample = self.dataframe.iloc[idx] 10 | if not pd.isna(sample.DNAm): 11 | out = ( 12 | torch.from_numpy(pd.read_csv(sample.DNAm).Beta_value.fillna(0.0).values) 13 | .float() 14 | .unsqueeze(0) 15 | ) 16 | mask = True 17 | else: 18 | out = torch.zeros(1, 25978).float() 19 | mask = False 20 | 21 | if self.return_mask: 22 | return out, mask 23 | else: 24 | return out 25 | -------------------------------------------------------------------------------- /drim/dnam/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.layers.torch import Rearrange 4 | 5 | 6 | class DNAmEncoder(nn.Module): 7 | """ 8 | A vanilla encoder based on 1-d convolution for DNAm data. 9 | """ 10 | 11 | def __init__(self, embedding_dim: int, dropout: float) -> None: 12 | super().__init__() 13 | self.encoder = nn.Sequential( 14 | nn.Conv1d(1, 8, 9, 3), 15 | nn.GELU(), 16 | nn.BatchNorm1d(8), 17 | nn.Dropout(dropout), 18 | nn.Conv1d(8, 32, 9, 3), 19 | nn.GELU(), 20 | nn.BatchNorm1d(32), 21 | nn.Dropout(dropout), 22 | nn.Conv1d(32, 64, 9, 3), 23 | nn.GELU(), 24 | nn.BatchNorm1d(64), 25 | nn.Dropout(dropout), 26 | nn.Conv1d(64, 128, 9, 3), 27 | nn.GELU(), 28 | nn.BatchNorm1d(128), 29 | nn.Dropout(dropout), 30 | nn.Conv1d(128, 256, 9, 3), 31 | nn.GELU(), 32 | nn.BatchNorm1d(256), 33 | nn.Dropout(dropout), 34 | nn.Conv1d(256, embedding_dim, 9, 3), 35 | nn.GELU(), 36 | nn.BatchNorm1d(embedding_dim), 37 | nn.Dropout(dropout), 38 | nn.AdaptiveAvgPool1d(1), 39 | ) 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return self.encoder(x).squeeze(-1) 43 | 44 | 45 | class DNAmDecoder(nn.Module): 46 | """ 47 | A vanilla decoder based on 1-d transposed convolution for DNAm data. 
48 | """ 49 | 50 | def __init__(self, embedding_dim: int, dropout: float) -> None: 51 | super().__init__() 52 | self.decoder = nn.Sequential( 53 | nn.Linear(embedding_dim, embedding_dim * 33), 54 | Rearrange("b (n e) -> b n e", e=33), 55 | nn.GELU(), 56 | nn.BatchNorm1d(embedding_dim), 57 | nn.Dropout(dropout), 58 | nn.ConvTranspose1d(embedding_dim, 256, 9, 3), 59 | nn.GELU(), 60 | nn.BatchNorm1d(256), 61 | nn.Dropout(dropout), 62 | nn.ConvTranspose1d(256, 128, 9, 3, 1), 63 | nn.GELU(), 64 | nn.BatchNorm1d(128), 65 | nn.Dropout(dropout), 66 | nn.ConvTranspose1d(128, 64, 9, 3, 1), 67 | nn.GELU(), 68 | nn.BatchNorm1d(64), 69 | nn.Dropout(dropout), 70 | nn.ConvTranspose1d(64, 32, 9, 3, 2), 71 | nn.GELU(), 72 | nn.BatchNorm1d(32), 73 | nn.Dropout(dropout), 74 | nn.ConvTranspose1d(32, 8, 9, 3, 1), 75 | nn.GELU(), 76 | nn.BatchNorm1d(8), 77 | nn.Dropout(dropout), 78 | nn.ConvTranspose1d(8, 1, 8, 3, 2), 79 | nn.Sigmoid(), 80 | ) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | x = self.decoder(x) 84 | return x 85 | -------------------------------------------------------------------------------- /drim/fusion/__init__.py: -------------------------------------------------------------------------------- 1 | from .fusion import * 2 | -------------------------------------------------------------------------------- /drim/fusion/fusion.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Dict, List, Optional, Tuple 3 | import torch 4 | import random 5 | from einops import repeat 6 | from .utils import MultiHeadSelfAttention, FeedForward 7 | 8 | __all__ = ["ShallowFusion", "TensorFusion", "MaskedAttentionFusion"] 9 | 10 | 11 | class ShallowFusion(nn.Module): 12 | """ 13 | Class computing the simplest fusion methods. Adapted from MultiSurv (https://www.nature.com/articles/s41598-021-92799-4?proof=t). 14 | It also contains the code from Cheerla et al. (https://github.com/gevaertlab/MultimodalPrognosis/blob/master/experiments/chart4.py). 15 | """ 16 | 17 | def __init__(self, fusion_type: str = "mean") -> None: 18 | super().__init__() 19 | self.fusion_type = fusion_type 20 | 21 | def forward( 22 | self, x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] = None 23 | ) -> torch.Tensor: 24 | if self.fusion_type == "mean": 25 | x = torch.stack(list(x.values()), dim=0).mean(dim=0) 26 | elif self.fusion_type == "max": 27 | x, _ = torch.stack(list(x.values()), dim=0).max(dim=0) 28 | elif self.fusion_type == "sum": 29 | x = torch.stack(list(x.values()), dim=0).sum(dim=0) 30 | elif self.fusion_type == "concat": 31 | x = torch.cat(list(x.values()), dim=-1) 32 | elif self.fusion_type == "masked_mean": 33 | x = self.masked_mean(x, mask) 34 | 35 | return x 36 | 37 | @staticmethod 38 | def masked_mean(x, mask): 39 | num = sum((x[modality] * mask[modality][:, None]) for modality in x.keys()) 40 | den = sum(mask[modality] for modality in mask.keys())[:, None].float() 41 | return num / den 42 | 43 | 44 | class TensorFusion(nn.Module): 45 | """ 46 | Multimodal Tensor Fusion via Kronecker Product and Gating-Based Attention derived from [1, 2]. 47 | 48 | [1] https://ieeexplore.ieee.org/document/9186053 49 | [2] https://link.springer.com/chapter/10.1007/978-3-030-87240-3_64 50 | 51 | NB: This is an extension of the original loss, which can take more than 3 modalities. 52 | 53 | This fusion is not particularly well adapted to missing modalities and its number of trainable parameters can go high. 
54 | Be careful with the number of modalities and the size of the tensors. Moreover it is not invariant to modality permutation. 55 | 56 | For instance: [modalities (m), input dimension (i), projected dimension (p), output dimension (o) ~ number of parameters (millions)] 57 | m: 3, i: 128, p: 128, o: 128 ~ 280M 58 | m: 4, i: 16, p: 16, o: 16 ~ 1M4 59 | m: 4, i: 32, p: 32, o: 32 ~ 38M 60 | """ 61 | 62 | def __init__( 63 | self, 64 | modalities: List[str], 65 | input_dim: int, 66 | projected_dim: int, 67 | output_dim: int, 68 | gate: bool = True, 69 | skip: bool = True, 70 | dropout: float = 0.1, 71 | pairs: Optional[Dict[str, str]] = None, 72 | ) -> None: 73 | # raise a warning because of the important number of trainable parameters. 74 | super().__init__() 75 | self.skip = skip 76 | self.modalities = sorted(modalities) 77 | self.dropout = dropout 78 | self.gate = gate 79 | if gate: 80 | if pairs: 81 | self.pairs = pairs 82 | else: 83 | self.pairs = {} 84 | for modality in self.modalities: 85 | i = self.modalities.index(modality) 86 | j = i 87 | while j == i: 88 | j = random.randint(0, len(self.modalities) - 1) 89 | self.pairs[modality] = self.modalities[j] 90 | 91 | for modality in self.modalities: 92 | if gate: 93 | setattr( 94 | self, 95 | f"{modality.lower()}_linear", 96 | nn.Sequential(nn.Linear(input_dim, projected_dim), nn.ReLU()), 97 | ) 98 | setattr( 99 | self, 100 | f"{modality.lower()}_bilinear", 101 | nn.Bilinear(input_dim, input_dim, projected_dim), 102 | ) 103 | setattr( 104 | self, 105 | f"{modality.lower()}_last_linear", 106 | nn.Sequential( 107 | nn.Linear(projected_dim, projected_dim), 108 | nn.ReLU(), 109 | nn.Dropout(p=dropout), 110 | ), 111 | ) 112 | else: 113 | setattr( 114 | self, 115 | f"{modality.lower()}_proj", 116 | nn.Sequential(nn.Linear(input_dim, projected_dim), nn.ReLU()), 117 | ) 118 | 119 | self.dropout = nn.Dropout(dropout) 120 | self.post_fusion = nn.Sequential( 121 | nn.Linear((input_dim + 1) ** len(modalities), output_dim), 122 | nn.ReLU(), 123 | nn.Dropout(p=dropout), 124 | ) 125 | if self.skip: 126 | self.skip = nn.Sequential( 127 | nn.Linear(output_dim + ((input_dim + 1) * len(modalities)), output_dim), 128 | nn.ReLU(), 129 | nn.Dropout(p=dropout), 130 | ) 131 | 132 | def forward(self, x: Dict[str, torch.Tensor], mask=None) -> torch.Tensor: 133 | if self.gate: 134 | h = {k: getattr(self, f"{k.lower()}_linear")(v) for k, v in x.items()} 135 | x = { 136 | k: getattr(self, f"{k.lower()}_bilinear")(v, x[self.pairs[k]]) 137 | for k, v in x.items() 138 | } 139 | x = { 140 | k: getattr(self, f"{k.lower()}_last_linear")(nn.Sigmoid()(x[k]) * h[k]) 141 | for k, v in x.items() 142 | } 143 | 144 | else: 145 | x = {k: getattr(self, f"{k.lower()}_proj")(v) for k, v in x.items()} 146 | 147 | x = { 148 | k: torch.cat((v, torch.FloatTensor(v.shape[0], 1).fill_(1).to(v.device)), 1) 149 | for k, v in x.items() 150 | } 151 | out = torch.bmm( 152 | x[self.modalities[0]].unsqueeze(2), x[self.modalities[1]].unsqueeze(1) 153 | ).flatten(start_dim=1) 154 | for modality in self.modalities[2:]: 155 | out = torch.bmm(out.unsqueeze(2), x[modality].unsqueeze(1)).flatten( 156 | start_dim=1 157 | ) 158 | 159 | out = self.dropout(out) 160 | out = self.post_fusion(out) 161 | if self.skip: 162 | out = torch.cat((out, *list(x.values())), 1) 163 | out = self.skip(out) 164 | 165 | return out 166 | 167 | 168 | class MaskedAttentionFusion(nn.Module): 169 | def __init__( 170 | self, 171 | dim: int, 172 | depth: int, 173 | heads: int, 174 | dim_head: int, 175 | mlp_dim: int, 176 | dropout: 
float = 0.0, 177 | ): 178 | super().__init__() 179 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 180 | self.norm = nn.LayerNorm(dim) 181 | self.layers = nn.ModuleList([]) 182 | for _ in range(depth): 183 | self.layers.append( 184 | nn.ModuleList( 185 | [ 186 | MultiHeadSelfAttention( 187 | dim, heads=heads, dim_head=dim_head, dropout=dropout 188 | ), 189 | FeedForward(dim, mlp_dim, dropout=dropout), 190 | ] 191 | ) 192 | ) 193 | 194 | def forward( 195 | self, x: Dict[str, torch.Tensor], mask: Optional[Dict[str, torch.Tensor]] = None 196 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 197 | keys = list(x.keys()) 198 | x = torch.stack([x[k] for k in keys], 1) 199 | scores = [] 200 | cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=x.shape[0]).to(x.device) 201 | x = torch.cat((cls_tokens, x), dim=1) 202 | if mask is not None: 203 | mask = torch.stack( 204 | [torch.ones(x.shape[0]).bool().to(mask[keys[0]].device)] 205 | + [mask[k] for k in keys], 206 | 1, 207 | ) 208 | 209 | for attn, ff in self.layers: 210 | temp_x, score = attn(x, mask=mask) 211 | scores.append(score) 212 | x = temp_x + x 213 | x = ff(x) + x 214 | x = x[:, 0] 215 | scores = self._get_attn_scores(torch.stack(scores)) 216 | return self.norm(x), {k: scores[:, i] for i, k in enumerate(keys)} 217 | 218 | @staticmethod 219 | def _get_attn_scores(x): 220 | return x.mean(dim=2)[:, :, 0, 1:].mean(0) 221 | -------------------------------------------------------------------------------- /drim/fusion/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops import rearrange 4 | from typing import Optional 5 | 6 | 7 | def softmax_one(x, dim=None): 8 | """ 9 | Quiet softmax function as presented by Evan Miller in his blog post "Attention is Off by One". 10 | 11 | https://www.evanmiller.org/attention-is-off-by-one.html 12 | https://github.com/kyegomez/AttentionIsOFFByOne 13 | """ 14 | x = x - x.max(dim=dim, keepdim=True).values 15 | exp_x = torch.exp(x) 16 | 17 | return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True)) 18 | 19 | 20 | class FeedForward(nn.Module): 21 | def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None: 22 | super().__init__() 23 | self.net = nn.Sequential( 24 | nn.LayerNorm(dim), 25 | nn.Linear(dim, hidden_dim), 26 | nn.GELU(), 27 | nn.Dropout(dropout), 28 | nn.Linear(hidden_dim, dim), 29 | nn.Dropout(dropout), 30 | ) 31 | 32 | def forward(self, x: torch.Tensor) -> torch.Tensor: 33 | return self.net(x) 34 | 35 | 36 | class MultiHeadSelfAttention(nn.Module): 37 | """ 38 | Adapted from lucidrains vit-pytorch. It handles mask alongside the features vector. 
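    Shape note (a minimal sketch; the dimensions below are illustrative): `x` is expected to be
    (batch, n_tokens, dim) and `mask`, when provided, a boolean (batch, n_tokens) tensor. The
    pairwise mask `mask[:, None, :, None] * mask[:, None, None, :]` removes every query/key pair
    that involves a missing token by filling the corresponding logits with the dtype minimum
    before the quiet softmax, and the forward pass returns both the projected output and the
    attention weights.

        attn = MultiHeadSelfAttention(dim=64, heads=4, dim_head=16)
        x = torch.randn(2, 5, 64)
        mask = torch.ones(2, 5, dtype=torch.bool)
        out, scores = attn(x, mask=mask)  # out: (2, 5, 64), scores: (2, 4, 5, 5)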
39 | """ 40 | 41 | def __init__( 42 | self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0 43 | ) -> None: 44 | super().__init__() 45 | inner_dim = dim_head * heads 46 | project_out = not (heads == 1 and dim_head == dim) 47 | 48 | self.heads = heads 49 | self.scale = dim_head**-0.5 50 | 51 | self.norm = nn.LayerNorm(dim) 52 | 53 | self.dropout = nn.Dropout(dropout) 54 | 55 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 56 | 57 | self.to_out = ( 58 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 59 | if project_out 60 | else nn.Identity() 61 | ) 62 | 63 | def forward( 64 | self, x: torch.Tensor, mask: Optional[torch.Tensor] = None 65 | ) -> torch.Tensor: 66 | x = self.norm(x) 67 | 68 | qkv = self.to_qkv(x).chunk(3, dim=-1) 69 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) 70 | 71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 72 | 73 | if mask is not None: 74 | mask_value = torch.finfo(dots.dtype).min 75 | mask = mask[:, None, :, None] * mask[:, None, None, :] 76 | dots.masked_fill_(~mask, mask_value) 77 | 78 | attn = softmax_one(dots, dim=-1) 79 | attn_d = self.dropout(attn) 80 | 81 | out = torch.matmul(attn_d, v) 82 | out = rearrange(out, "b h n d -> b n (h d)") 83 | return self.to_out(out), attn 84 | -------------------------------------------------------------------------------- /drim/helpers.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from typing import Dict 3 | 4 | # Third-party libraries 5 | import pandas as pd 6 | from omegaconf import DictConfig 7 | from pycox.models import LogisticHazard 8 | 9 | # Local dependencies 10 | from drim.utils import prepare_data, get_target_survival, log_transform 11 | 12 | 13 | def get_encoder(modality: str, cfg: DictConfig): 14 | if modality == "DNAm": 15 | from drim.dnam import DNAmEncoder 16 | 17 | encoder = DNAmEncoder( 18 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout 19 | ) 20 | elif modality == "RNA": 21 | from drim.rna import RNAEncoder 22 | 23 | encoder = RNAEncoder(embedding_dim=cfg.general.dim, dropout=cfg.general.dropout) 24 | elif modality == "WSI": 25 | from drim.wsi import WSIEncoder 26 | 27 | encoder = WSIEncoder( 28 | cfg.general.dim, 29 | depth=1, 30 | heads=8, 31 | dropout=cfg.general.dropout, 32 | emb_dropout=cfg.general.dropout, 33 | ) 34 | elif modality == "MRI": 35 | from drim.mri import MRIEmbeddingEncoder 36 | 37 | encoder = MRIEmbeddingEncoder( 38 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout 39 | ) 40 | else: 41 | raise NotImplementedError(f"Modality {modality} not implemented") 42 | 43 | return encoder 44 | 45 | 46 | def get_decoder(modality: str, cfg: DictConfig): 47 | if modality == "DNAm": 48 | from drim.dnam import DNAmDecoder 49 | 50 | decoder = DNAmDecoder( 51 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout 52 | ) 53 | elif modality == "RNA": 54 | from drim.rna import RNADecoder 55 | 56 | decoder = RNADecoder(embedding_dim=cfg.general.dim, dropout=cfg.general.dropout) 57 | elif modality == "WSI": 58 | from drim.wsi import WSIDecoder 59 | 60 | decoder = WSIDecoder(cfg.general.dim, dropout=cfg.general.dropout) 61 | elif modality == "MRI": 62 | from drim.mri import MRIEmbeddingDecoder 63 | 64 | decoder = MRIEmbeddingDecoder( 65 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout 66 | ) 67 | else: 68 | raise NotImplementedError(f"Modality {modality} not implemented") 69 | 70 | return decoder 71 | 72 | 73 | def 
get_datasets( 74 | dataframes: Dict[str, pd.DataFrame], 75 | modality: str, 76 | fold: int, 77 | return_mask: bool = False, 78 | ): 79 | if modality == "DNAm": 80 | from drim.dnam import DNAmDataset 81 | 82 | datasets = { 83 | split: DNAmDataset(dataframe, return_mask=return_mask) 84 | for split, dataframe in dataframes.items() 85 | } 86 | elif modality == "RNA": 87 | from drim.rna import RNADataset 88 | from joblib import load 89 | 90 | rna_processor = load(f"./data/rna_preprocessors/trf_{int(fold)}.joblib") 91 | datasets = { 92 | split: RNADataset(dataframe, rna_processor, return_mask=return_mask) 93 | for split, dataframe in dataframes.items() 94 | } 95 | elif modality == "WSI": 96 | from drim.wsi import WSIDataset 97 | 98 | datasets = {} 99 | datasets["train"] = WSIDataset( 100 | dataframes["train"], k=10, is_train=True, return_mask=return_mask 101 | ) 102 | datasets["val"] = WSIDataset( 103 | dataframes["val"], k=10, is_train=False, return_mask=return_mask 104 | ) 105 | datasets["test"] = WSIDataset( 106 | dataframes["test"], k=10, is_train=False, return_mask=return_mask 107 | ) 108 | 109 | elif modality == "MRI": 110 | from drim.mri import MRIEmbeddingDataset 111 | 112 | datasets = { 113 | split: MRIEmbeddingDataset(dataframe, return_mask=return_mask) 114 | for split, dataframe in dataframes.items() 115 | } 116 | else: 117 | raise NotImplementedError(f"Modality {modality} not implemented") 118 | 119 | return datasets 120 | 121 | 122 | def get_targets(dataframes: Dict[str, pd.DataFrame], n_outs: int): 123 | labtrans = LogisticHazard.label_transform(n_outs) 124 | labtrans.fit(*get_target_survival(dataframes["train"])) 125 | 126 | return { 127 | split: labtrans.transform(*get_target_survival(dataframe)) 128 | for split, dataframe in dataframes.items() 129 | }, labtrans.cuts 130 | -------------------------------------------------------------------------------- /drim/logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from loguru import logger 3 | 4 | logger.remove() 5 | logger.add( 6 | sys.stdout, 7 | colorize=True, 8 | format=( 9 | "GBMLGG:" " {message}" 10 | ), 11 | ) 12 | -------------------------------------------------------------------------------- /drim/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class DiscriminatorLoss(nn.Module): 7 | """ 8 | Implementation of the traditionnal discriminator loss to approximate Jensen-Shannon Divergence. 9 | """ 10 | 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | def forward( 15 | self, real_logits: torch.Tensor, fake_logits: torch.Tensor 16 | ) -> torch.Tensor: 17 | """ 18 | Compute discriminator loss 19 | 20 | Args: 21 | real_logits: 22 | torch.tensor containing logits extracted from p(x)p(y), size (bsz, 1) 23 | fake_logits: 24 | torch.tensor containing logits extracted from p(x, y), size (bsz, 1) 25 | 26 | Returns: 27 | The computed scalar loss. 
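        Usage sketch (mirroring DRIMSurvTrainer further below; `discriminator`, `shared` and
        `unique` are placeholder names): the discriminator is trained to separate pairs drawn
        from the joint p(x, y) from pairs drawn from the product of marginals p(x)p(y), which
        gives a GAN-style estimate of their Jensen-Shannon divergence.

            criterion = DiscriminatorLoss()
            fake_logits = discriminator(torch.cat([shared, unique], dim=1))                    # joint pairs
            real_logits = discriminator(torch.cat([torch.roll(shared, 1, 0), unique], dim=1))  # shuffled pairs
            d_loss = criterion(real_logits=real_logits, fake_logits=fake_logits)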
28 | """ 29 | # Discriminator should predict real logits as logits from the real distribution 30 | discriminator_real = torch.nn.functional.binary_cross_entropy_with_logits( 31 | input=real_logits, target=torch.ones_like(real_logits) 32 | ) 33 | # Discriminator should predict fake logits as logits from the generated distribution 34 | discriminator_fake = torch.nn.functional.binary_cross_entropy_with_logits( 35 | input=fake_logits, target=torch.zeros_like(fake_logits) 36 | ) 37 | 38 | discriminator_loss = discriminator_real + discriminator_fake 39 | 40 | return discriminator_loss 41 | 42 | 43 | class ContrastiveLoss(nn.Module): 44 | """ 45 | This code is adapted from the code for supervised contrastive learning (https://arxiv.org/pdf/2004.11362.pdf). 46 | """ 47 | 48 | def __init__( 49 | self, 50 | temperature: float = 0.07, 51 | contrast_mode: str = "all", 52 | base_temperature: float = 0.07, 53 | ) -> None: 54 | super(ContrastiveLoss, self).__init__() 55 | self.temperature = temperature 56 | self.contrast_mode = contrast_mode 57 | self.base_temperature = base_temperature 58 | 59 | def forward( 60 | self, x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] = None 61 | ) -> torch.Tensor: 62 | """Compute shared loss for our disentangled framework. 63 | 64 | Args: 65 | x: 66 | Dict associating each modality (key) to the corresponding embedding, tensor of size (bsz, d) 67 | mask: 68 | Dict associating each modality (key) to a boolean tensor of size (bsz,) indicating if the modality is indeed present in the minibatch 69 | Returns: 70 | The computed scalar loss. 71 | """ 72 | keys = list(x.keys()) 73 | features = torch.stack([x[key] for key in keys], 1) 74 | modality_mask = torch.stack([mask[key] for key in keys], 1) 75 | batch_size = features.shape[0] 76 | 77 | contrast_count = features.shape[1] 78 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 79 | modality_mask = torch.cat(torch.unbind(modality_mask, dim=1), dim=0).view(1, -1) 80 | 81 | # compute logits 82 | anchor_dot_contrast = torch.div( 83 | torch.matmul(contrast_feature, contrast_feature.T), self.temperature 84 | ) 85 | # for numerical stability 86 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 87 | logits = anchor_dot_contrast - logits_max.detach() 88 | 89 | mask = torch.eye(batch_size, dtype=torch.float32).to(modality_mask.device) 90 | mask = mask.repeat(contrast_count, contrast_count) 91 | # mask-out self-contrast cases 92 | logits_mask = torch.scatter( 93 | (modality_mask[:, None, :, None] * modality_mask[:, None, None, :]) 94 | .squeeze() 95 | .long(), 96 | 1, 97 | torch.arange(batch_size * contrast_count) 98 | .view(-1, 1) 99 | .to(modality_mask.device), 100 | 0, 101 | ) 102 | mask = mask * logits_mask 103 | # compute log_prob 104 | exp_logits = torch.exp(logits) * logits_mask 105 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) 106 | 107 | # compute mean of log-likelihood over positive 108 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6) 109 | # loss 110 | loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos 111 | loss = loss.view(contrast_count, batch_size).mean() 112 | 113 | return loss 114 | 115 | 116 | class MMOLoss(nn.Module): 117 | """ 118 | Loss function used in the paper: 119 | Deep Orthogonal Fusion: Multimodal: Prognostic Biomarker Discovery Integrating Radiology, 120 | Pathology, Genomic, and Clinical Data. 
121 | https://link.springer.com/chapter/10.1007/978-3-030-87240-3_64 122 | 123 | N.B: In the paper and by design, this loss is not well-suited to deal with missing modality, we propose here an alternative. 124 | """ 125 | 126 | def __init__(self) -> None: 127 | super(MMOLoss, self).__init__() 128 | 129 | def forward( 130 | self, x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] = None 131 | ) -> torch.Tensor: 132 | """ 133 | Compute auxiliary orthogonal loss. 134 | 135 | Args: 136 | x: 137 | Dict associating each modality (key) to the corresponding embedding, tensor of size (bsz, d) 138 | mask: 139 | Dict associating each modality (key) to a boolean tensor of size (bsz,) indicating if the modality is indeed present in the minibatch, 140 | Returns: 141 | The computed scalar loss. 142 | """ 143 | if mask is None: 144 | mask = {key: torch.ones_like(x[key]).bool() for key in x.keys()} 145 | 146 | loss = 0.0 147 | for key in x.keys(): 148 | loss += torch.max( 149 | torch.tensor(1), torch.linalg.matrix_norm(x[key][mask[key]], ord="nuc") 150 | ) 151 | 152 | full = torch.stack(list(x.values()))[torch.stack(list(mask.values()))] 153 | loss -= torch.linalg.matrix_norm(full, ord="nuc") 154 | loss /= full.size(0) 155 | return loss 156 | -------------------------------------------------------------------------------- /drim/models.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class _BaseWrapper(nn.Module): 7 | """ 8 | Base wrapper for taking the final embedding and linking to the survival task 9 | """ 10 | 11 | def __init__( 12 | self, encoder: nn.Module, embedding_dim: int, n_outs: int, device: str = "cuda" 13 | ) -> None: 14 | super().__init__() 15 | self.encoder = encoder 16 | self.final = nn.Linear(embedding_dim, n_outs).to(device) 17 | self.device = device 18 | 19 | 20 | class UnimodalWrapper(_BaseWrapper): 21 | """ 22 | Survival wrapper taking an encoder et linking it to the survival task 23 | """ 24 | 25 | def forward(self, x: torch.Tensor, return_embedding: bool = False) -> torch.Tensor: 26 | """ 27 | Args: 28 | x: 29 | tensor containing the raw input (bsz, n_features) 30 | return_embedding: 31 | boolean indicating if the pre-survival layer embedding will be used 32 | 33 | Returns 34 | Whether the predicted hazards or the predicted hazards and the last embedding. 35 | """ 36 | x = x.to(self.device) 37 | x = self.encoder(x) 38 | if return_embedding: 39 | return self.final(x), x 40 | else: 41 | return self.final(x) 42 | 43 | 44 | class MultimodalWrapper(_BaseWrapper): 45 | def forward( 46 | self, 47 | x: Dict[str, torch.Tensor], 48 | mask: Optional[Dict[str, torch.Tensor]] = None, 49 | return_embedding: bool = False, 50 | ) -> torch.Tensor: 51 | """ 52 | Args: 53 | x: 54 | Dict associating each modality (key) to the raw input tensor (bsz, n_features) 55 | mask: 56 | Dict associating each modality (key) to a boolean tensor of size (bsz,) indicating if the modality is indeed present in the minibatch 57 | return_embedding: 58 | boolean indicating if the pre-survival layer embedding will be used 59 | 60 | Returns 61 | Whether the predicted hazards or the predicted hazards and the last embedding. 
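        Dispatch note (summarising the branches below): a plain multimodal encoder yields a
        single fused tensor and only the hazards are returned; with return_embedding=True an
        auxiliary model returns (hazards, embeddings) and a DRIMSurv-style model returns
        (hazards, shared embeddings, unique embeddings); in the self-supervised setting, where
        the encoder already returns a (shared dict, unique dict) tuple, that tuple is passed
        through unchanged. A minimal call sketch (names are placeholders):
        hazards, shared, unique = wrapper(x, mask, return_embedding=True).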
62 | """ 63 | x = self.encoder(x, mask=mask, return_embedding=return_embedding) 64 | if not isinstance(x, tuple): 65 | # for fine-tuning or simple multimodal training 66 | return self.final(x.to(self.device)) 67 | else: 68 | if len(x) == 2 and isinstance(x[0], dict): 69 | # for the self supervised settings 70 | return x 71 | else: 72 | out = self.final(x[0].to(self.device)) 73 | if len(x) == 2: 74 | # auxiliary training 75 | return out, x[1] 76 | else: 77 | # disentangled training 78 | return out, x[1], x[2] 79 | 80 | 81 | class Discriminator(nn.Module): 82 | """ 83 | Generic discriminator for the adversarial training 84 | """ 85 | 86 | def __init__(self, embedding_dim: int, dropout: float) -> None: 87 | super().__init__() 88 | self.model = nn.Sequential( 89 | nn.Linear(embedding_dim, 1024), 90 | nn.Dropout(dropout), 91 | nn.GELU(), 92 | nn.Linear(1024, embedding_dim), 93 | nn.Dropout(dropout), 94 | nn.GELU(), 95 | nn.Linear(embedding_dim, 1), 96 | ) 97 | 98 | def forward(self, x: torch.Tensor) -> torch.Tensor: 99 | return self.model(x) 100 | -------------------------------------------------------------------------------- /drim/mri/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import MRIEmbeddingDataset 2 | from .models import MRIEmbeddingDecoder, MRIEmbeddingEncoder 3 | -------------------------------------------------------------------------------- /drim/mri/datasets.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List, Dict, Tuple 2 | import os 3 | import numpy as np 4 | import nibabel as nib 5 | from abc import abstractmethod 6 | from torch.utils.data import Dataset 7 | import monai 8 | import torch 9 | import pandas as pd 10 | from ..datasets import _BaseDataset 11 | 12 | __all__ = ["DatasetBraTS", "DatasetBraTSTumorCentered", "MRIDataset"] 13 | 14 | 15 | class _BaseDatasetBraTS(Dataset): 16 | def __init__( 17 | self, 18 | path: str, 19 | modality: Union[str, List[str]], 20 | patients: Optional[List[str]], 21 | return_mask: bool, 22 | transform: Optional["monai.transforms"] = None, 23 | ) -> None: 24 | self.path = path 25 | self.modality = modality 26 | self.patients = np.array(os.listdir(path)) if patients is None else patients 27 | self.transform = transform 28 | self.return_mask = return_mask 29 | 30 | def __len__(self): 31 | return len(self.patients) 32 | 33 | @abstractmethod 34 | def __getitem__(self, idx): 35 | raise NotImplementedError 36 | 37 | def _load_nifti_modalities(self, patient: str) -> np.ndarray: 38 | if len(self.modality) == 1: 39 | img = nib.load( 40 | os.path.join( 41 | self.path, patient, patient + "_" + self.modality[0] + ".nii.gz" 42 | ) 43 | ).get_fdata() 44 | img = np.expand_dims(img, 0) 45 | 46 | else: 47 | early_fused = [] 48 | for modality in self.modality: 49 | early_fused.append( 50 | nib.load( 51 | os.path.join( 52 | self.path, patient, patient + "_" + modality + ".nii.gz" 53 | ) 54 | ).get_fdata() 55 | ) 56 | 57 | img = np.stack(early_fused) 58 | 59 | return img 60 | 61 | @staticmethod 62 | def _monaify(img: np.ndarray, mask: Optional[np.ndarray]) -> Dict[str, np.ndarray]: 63 | item = {"image": img} 64 | if mask is not None: 65 | item["mask"] = mask 66 | 67 | return item 68 | 69 | def _load_mask(self, patient: str) -> np.ndarray: 70 | return nib.load( 71 | os.path.join(self.path, patient, patient + "_seg.nii.gz") 72 | ).get_fdata() 73 | 74 | 75 | class DatasetBraTS(_BaseDatasetBraTS): 76 | def __getitem__(self, 
idx) -> Dict[str, Union[torch.tensor, np.ndarray]]: 77 | patient = self.patients[idx] 78 | mask = ( 79 | super(DatasetBraTS, self)._load_mask(patient=patient) 80 | if self.return_mask 81 | else None 82 | ) 83 | img = super(DatasetBraTS, self)._load_nifti_modalities(patient=patient) 84 | 85 | item = super(DatasetBraTS, self)._monaify(img=img, mask=mask) 86 | if self.transform is not None: 87 | item = self.transform(item) 88 | 89 | return item 90 | 91 | 92 | class DatasetBraTSTumorCentered(_BaseDatasetBraTS): 93 | def __init__( 94 | self, 95 | path: str, 96 | modality: Union[str, List[str]], 97 | patients: Optional[List[str]], 98 | sizes: Tuple[int, ...], 99 | transform: Optional["monai.transforms"] = None, 100 | return_mask: bool = False, 101 | ) -> None: 102 | super(DatasetBraTSTumorCentered, self).__init__( 103 | path=path, 104 | modality=modality, 105 | patients=patients, 106 | return_mask=return_mask, 107 | transform=transform, 108 | ) 109 | self.sizes = sizes 110 | 111 | def __getitem__(self, idx) -> Dict[str, Union[torch.tensor, np.ndarray]]: 112 | patient = self.patients[idx] 113 | mask = super(DatasetBraTSTumorCentered, self)._load_mask(patient=patient) 114 | img = super(DatasetBraTSTumorCentered, self)._load_nifti_modalities( 115 | patient=patient 116 | ) 117 | item = self._compute_subvolumes(img=img, mask=mask) 118 | if self.transform is not None: 119 | item = self.transform(item) 120 | 121 | return item 122 | 123 | def _compute_subvolumes( 124 | self, img: np.ndarray, mask: np.ndarray 125 | ) -> Dict[str, np.ndarray]: 126 | centroid = self._compute_centroid(mask=mask) 127 | # bounds for boolean indexing 128 | lower_bound, upper_bound = self._get_bounds( 129 | centroid=centroid, input_dims=img.shape[1:] 130 | ) 131 | img = img[ 132 | :, 133 | lower_bound[0] : upper_bound[0], 134 | lower_bound[1] : upper_bound[1], 135 | lower_bound[2] : upper_bound[2], 136 | ] 137 | mask = ( 138 | mask[ 139 | lower_bound[0] : upper_bound[0], 140 | lower_bound[1] : upper_bound[1], 141 | lower_bound[2] : upper_bound[2], 142 | ] 143 | if self.return_mask 144 | else None 145 | ) 146 | return super(DatasetBraTSTumorCentered, self)._monaify(img=img, mask=mask) 147 | 148 | def _get_bounds( 149 | self, centroid: np.ndarray, input_dims: Tuple[int, ...] 150 | ) -> Tuple[np.ndarray, np.ndarray]: 151 | lower = (centroid - (np.array(self.sizes) / 2)).astype(int) 152 | upper = (centroid + (np.array(self.sizes) / 2)).astype(int) 153 | return np.clip(lower, 0, input_dims), np.clip(upper, 0, input_dims) 154 | 155 | @staticmethod 156 | def _compute_centroid(mask: np.ndarray) -> np.ndarray: 157 | return np.mean(np.argwhere(mask), axis=0).astype(int) 158 | 159 | 160 | class MRIProcessor: 161 | def __init__( 162 | self, 163 | base_path: str, 164 | tumor_centered: bool, 165 | transform: "augmentations", 166 | modalities: List[str] = ["t1", "t1ce", "t2", "flair"], 167 | size: Tuple[int, ...] 
= (64, 64, 64), 168 | ) -> None: 169 | self.base_path = base_path 170 | self.size = size 171 | self.tumor_centered = tumor_centered 172 | self.modalities = modalities 173 | self.transform = transform 174 | 175 | def _compute_subvolumes( 176 | self, img: np.ndarray, mask: np.ndarray 177 | ) -> Dict[str, np.ndarray]: 178 | centroid = self._compute_centroid(mask=mask) 179 | # bounds for boolean indexing 180 | lower_bound, upper_bound = self._get_bounds( 181 | centroid=centroid, input_dims=img.shape[1:] 182 | ) 183 | img = img[ 184 | :, 185 | lower_bound[0] : upper_bound[0], 186 | lower_bound[1] : upper_bound[1], 187 | lower_bound[2] : upper_bound[2], 188 | ] 189 | return img, mask 190 | 191 | def _get_bounds( 192 | self, centroid: np.ndarray, input_dims: Tuple[int, ...] 193 | ) -> Tuple[np.ndarray, np.ndarray]: 194 | lower = (centroid - (np.array(self.size) / 2)).astype(int) 195 | upper = (centroid + (np.array(self.size) / 2)).astype(int) 196 | return np.clip(lower, 0, input_dims), np.clip(upper, 0, input_dims) 197 | 198 | @staticmethod 199 | def _compute_centroid(mask: np.ndarray) -> np.ndarray: 200 | return np.mean(np.argwhere(mask), axis=0).astype(int) 201 | 202 | def _load_nifti_modalities(self, base_path: str, patient: str) -> np.ndarray: 203 | if len(self.modalities) == 1: 204 | img = nib.load( 205 | os.path.join(base_path, patient + "_" + self.modality[0] + ".nii.gz") 206 | ).get_fdata() 207 | img = np.expand_dims(img, 0) 208 | 209 | else: 210 | early_fused = [] 211 | for modality in self.modalities: 212 | early_fused.append( 213 | nib.load( 214 | os.path.join(base_path, patient + "_" + modality + ".nii.gz") 215 | ).get_fdata() 216 | ) 217 | 218 | img = np.stack(early_fused) 219 | 220 | return img 221 | 222 | def _load_mask(self, base_path: str, patient: str) -> np.ndarray: 223 | return nib.load(os.path.join(base_path, patient + "_seg.nii.gz")).get_fdata() 224 | 225 | def process(self) -> torch.Tensor: 226 | patient = self.base_path.split("/")[-1] 227 | img = self._load_nifti_modalities(base_path=self.base_path, patient=patient) 228 | if self.tumor_centered: 229 | mask = self._load_mask(base_path=self.base_path, patient=patient) 230 | img, _ = self._compute_subvolumes(img=img, mask=mask) 231 | 232 | img = self.transform(img) 233 | 234 | return img 235 | 236 | 237 | class MRIEmbeddingDataset(_BaseDataset): 238 | def __init__(self, dataframe, return_mask: bool = False): 239 | super().__init__(dataframe, return_mask) 240 | 241 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]: 242 | sample = self.dataframe.iloc[idx] 243 | if not pd.isna(sample.MRI): 244 | mri = torch.from_numpy( 245 | pd.read_csv(os.path.join(sample.MRI, "embedding.csv")).values[0] 246 | ).float() 247 | mask = True 248 | else: 249 | mri = torch.zeros(512) 250 | mask = False 251 | 252 | if self.return_mask: 253 | return mri, mask 254 | else: 255 | return mri 256 | 257 | 258 | class MRIDataset(_BaseDataset): 259 | def __init__( 260 | self, 261 | dataframe: pd.DataFrame, 262 | transform: "augmentations", 263 | tumor_centered: bool, 264 | modalities: List[str] = ["t1ce", "flair"], 265 | return_mask: bool = False, 266 | sizes: Tuple[int, ...] 
= (64, 64, 64), 267 | ) -> None: 268 | super().__init__(dataframe, return_mask) 269 | self.sizes = sizes 270 | self.modalities = modalities 271 | self.tumor_centered = tumor_centered 272 | self.transform = transform 273 | 274 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]: 275 | sample = self.dataframe.iloc[idx] 276 | if not pd.isna(sample.MRI): 277 | processor = MRIProcessor( 278 | sample.MRI, 279 | transform=self.transform, 280 | modalities=self.modalities, 281 | size=self.sizes, 282 | tumor_centered=self.tumor_centered, 283 | ) 284 | mri = processor.process() 285 | mask = True 286 | else: 287 | mri = torch.zeros( 288 | len(self.modalities), self.sizes[0], self.sizes[1], self.sizes[1] 289 | ).float() 290 | mask = False 291 | 292 | if self.return_mask: 293 | return mri, mask 294 | else: 295 | return mri 296 | -------------------------------------------------------------------------------- /drim/mri/models.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from typing import Tuple, Union 3 | 4 | # Third-party libraries 5 | import torch 6 | import torch.nn as nn 7 | from monai.networks.nets import resnet10 8 | 9 | 10 | class MRIEmbeddingEncoder(nn.Module): 11 | def __init__(self, embedding_dim: int, dropout: float): 12 | super(MRIEmbeddingEncoder, self).__init__() 13 | self.encoder = nn.Sequential( 14 | nn.Linear(512, 256), 15 | nn.Dropout(dropout), 16 | nn.GELU(), 17 | nn.Linear(256, embedding_dim), 18 | ) 19 | 20 | def forward(self, x: torch.tensor) -> torch.tensor: 21 | return self.encoder(x) 22 | 23 | 24 | class MRIEncoder(nn.Module): 25 | def __init__( 26 | self, in_channels: int, embedding_dim: int = 512, projection_head: bool = True 27 | ) -> None: 28 | super().__init__() 29 | if projection_head: 30 | self.projection_head = nn.Sequential( 31 | nn.Linear(embedding_dim, embedding_dim), 32 | nn.ReLU(inplace=True), 33 | nn.Linear(embedding_dim, int(embedding_dim / 2)), 34 | ) 35 | 36 | self.encoder = resnet10( 37 | spatial_dims=3, n_input_channels=in_channels, num_classes=1 38 | ) 39 | if embedding_dim != 512: 40 | self.encoder.fc = nn.Sequential( 41 | nn.Linear(512, 256), nn.LeakyReLU(), nn.Linear(256, embedding_dim) 42 | ) 43 | else: 44 | self.encoder.fc = nn.Identity() 45 | 46 | def forward( 47 | self, x: torch.tensor 48 | ) -> Union[torch.tensor, Tuple[torch.tensor, torch.tensor]]: 49 | x = self.encoder(x) 50 | if hasattr(self, "projection_head"): 51 | return self.projection_head(x), x 52 | else: 53 | return x 54 | 55 | 56 | class MRIEmbeddingDecoder(nn.Module): 57 | def __init__(self, embedding_dim: int, dropout: float): 58 | super(MRIEmbeddingDecoder, self).__init__() 59 | self.decoder = nn.Sequential( 60 | nn.Linear(embedding_dim, 256), 61 | nn.Dropout(dropout), 62 | nn.GELU(), 63 | nn.Linear(256, 512), 64 | ) 65 | 66 | def forward(self, x: torch.tensor) -> torch.tensor: 67 | return self.decoder(x) 68 | -------------------------------------------------------------------------------- /drim/mri/transforms.py: -------------------------------------------------------------------------------- 1 | from monai.transforms import ( 2 | Compose, 3 | NormalizeIntensityd, 4 | SpatialPadd, 5 | CopyItemsd, 6 | OneOf, 7 | RandCoarseDropoutd, 8 | RandCoarseShuffled, 9 | ToTensord, 10 | NormalizeIntensity, 11 | SpatialPad, 12 | ) 13 | 14 | 15 | def get_tumor_transforms(roi_size): 16 | return Compose( 17 | [ 18 | NormalizeIntensityd(keys=["image"], channel_wise=True, nonzero=True), 19 | 
SpatialPadd(keys=["image"], spatial_size=roi_size), 20 | CopyItemsd( 21 | keys=["image"], times=1, names=["image_2"], allow_missing_keys=False 22 | ), 23 | OneOf( 24 | transforms=[ 25 | RandCoarseDropoutd( 26 | keys=["image"], 27 | prob=1.0, 28 | holes=8, 29 | spatial_size=10, 30 | dropout_holes=True, 31 | max_spatial_size=18, 32 | ), 33 | RandCoarseDropoutd( 34 | keys=["image"], 35 | prob=1.0, 36 | holes=6, 37 | spatial_size=18, 38 | dropout_holes=False, 39 | max_spatial_size=24, 40 | ), 41 | ] 42 | ), 43 | RandCoarseShuffled( 44 | keys=["image"], prob=0.8, holes=4, spatial_size=4, max_spatial_size=4 45 | ), 46 | OneOf( 47 | transforms=[ 48 | RandCoarseDropoutd( 49 | keys=["image_2"], 50 | prob=1.0, 51 | holes=8, 52 | spatial_size=10, 53 | dropout_holes=True, 54 | max_spatial_size=18, 55 | ), 56 | RandCoarseDropoutd( 57 | keys=["image_2"], 58 | prob=1.0, 59 | holes=6, 60 | spatial_size=18, 61 | dropout_holes=False, 62 | max_spatial_size=24, 63 | ), 64 | ] 65 | ), 66 | RandCoarseShuffled( 67 | keys=["image_2"], prob=0.8, holes=10, spatial_size=8, max_spatial_size=8 68 | ), 69 | ToTensord(keys=["image", "image_2"]), 70 | ] 71 | ) 72 | 73 | 74 | tumor_transfo = Compose( 75 | [ 76 | NormalizeIntensity(nonzero=True, channel_wise=True), 77 | SpatialPad(spatial_size=(64, 64, 64)), 78 | ] 79 | ) 80 | -------------------------------------------------------------------------------- /drim/multimodal/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import MultimodalDataset 2 | from .models import * 3 | -------------------------------------------------------------------------------- /drim/multimodal/datasets.py: -------------------------------------------------------------------------------- 1 | from ..datasets import _BaseDataset 2 | from typing import Dict, List, Tuple, Union 3 | import torch 4 | 5 | 6 | class MultimodalDataset(_BaseDataset): 7 | def __init__( 8 | self, datasets: Dict[str, _BaseDataset], return_mask: bool = True 9 | ) -> None: 10 | modalities = list(datasets.keys()) 11 | self._modality_sanity_check(modalities) 12 | if return_mask: 13 | for modality in modalities: 14 | assert datasets[ 15 | modality 16 | ].return_mask, f"The dataset for modality {modality} does not return a mask, please set return_mask to False" 17 | self.return_mask = return_mask 18 | self.datasets = datasets 19 | 20 | def __len__(self): 21 | return len(self.datasets.values().__iter__().__next__()) 22 | 23 | def __getitem__( 24 | self, idx: int 25 | ) -> Union[ 26 | Tuple[Dict[str, torch.Tensor], Dict[str, bool]], Dict[str, torch.Tensor] 27 | ]: 28 | x = {} 29 | if self.return_mask: 30 | mask = {} 31 | 32 | for modality in self.datasets.keys(): 33 | out = self.datasets[modality][idx] 34 | if self.return_mask: 35 | x[modality], mask[modality] = out[0], out[1] 36 | else: 37 | x[modality] = out 38 | 39 | if self.return_mask: 40 | return x, mask 41 | else: 42 | return x 43 | 44 | @staticmethod 45 | def _modality_sanity_check( 46 | modalities, 47 | available_modalities: List[str] = ["RNA", "MRI", "Clinical", "DNAm", "WSI"], 48 | ) -> None: 49 | for modality in modalities: 50 | assert ( 51 | modality in available_modalities 52 | ), f"The requested modality: {modality} is not available, please choose modalities among {available_modalities}" 53 | -------------------------------------------------------------------------------- /drim/multimodal/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn as nn 3 | from typing import Dict, List, Optional, Union, Tuple 4 | 5 | __all__ = ["MultimodalModel", "MultimodalEncoder", "DRIMSurv", "DRIMU"] 6 | 7 | 8 | class MultimodalEncoder(nn.Module): 9 | def __init__( 10 | self, encoders: Dict[str, nn.Module], devices: Optional[Dict[str, str]] = None 11 | ) -> None: 12 | super().__init__() 13 | modalities = list(encoders.keys()) 14 | self._modality_sanity_check(modalities) 15 | if not devices: 16 | devices = {k: next(v.parameters()).device for k, v in encoders.items()} 17 | 18 | self.devices = devices 19 | for key, model in encoders.items(): 20 | setattr(self, f"{key.lower()}_encoder", model.to(self.devices[key])) 21 | 22 | def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 23 | return self._forward_encoders(x) 24 | 25 | def _forward_encoders( 26 | self, x: Dict[str, torch.Tensor], unique: Optional[bool] = None 27 | ) -> Dict[str, torch.Tensor]: 28 | # put the input tensors to the corresponding device 29 | x = {k: v.to(self.devices[k]) for k, v in x.items()} 30 | # forward pass 31 | if unique: 32 | x = {k: getattr(self, f"{k.lower()}_encoder_u")(v) for k, v in x.items()} 33 | else: 34 | x = {k: getattr(self, f"{k.lower()}_encoder")(v) for k, v in x.items()} 35 | 36 | # normalize embeddings 37 | x = {k: nn.functional.normalize(v, p=2, dim=-1) for k, v in x.items()} 38 | 39 | return x 40 | 41 | def _put_on_devices( 42 | self, 43 | x: Dict[str, torch.Tensor], 44 | mask: Optional[Dict[str, torch.Tensor]], 45 | key: str, 46 | ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]]]: 47 | x = {k: v.to(self.devices[key]) for k, v in x.items()} 48 | if mask is not None: 49 | mask = {k: v.to(self.devices[key]) for k, v in mask.items()} 50 | return x, mask 51 | 52 | @staticmethod 53 | def _modality_sanity_check( 54 | modalities: List[str], 55 | available_modalities: List[str] = ["RNA", "MRI", "DNAm", "WSI"], 56 | ) -> None: 57 | for modality in modalities: 58 | assert ( 59 | modality in available_modalities 60 | ), f"The requested modality: {modality} is not available, please choose modalities among {available_modalities}" 61 | 62 | 63 | class MultimodalModel(MultimodalEncoder): 64 | def __init__( 65 | self, 66 | encoders: Dict[str, nn.Module], 67 | fusion: nn.Module, 68 | devices: Optional[Dict[str, str]] = None, 69 | ) -> None: 70 | super().__init__(encoders, devices) 71 | if "fusion" not in self.devices.keys(): 72 | try: 73 | self.devices["fusion"] = next(fusion.parameters()).device 74 | except: 75 | # if fusion do not have parameters, then it is a function 76 | self.devices["fusion"] = "cuda:0" 77 | self.fusion = fusion 78 | 79 | def forward( 80 | self, 81 | x: Dict[str, torch.Tensor], 82 | mask: Optional[Dict[str, torch.Tensor]] = None, 83 | return_embedding: bool = False, 84 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]: 85 | x = super().forward(x) 86 | 87 | x, mask = self._put_on_devices(x, mask, "fusion") 88 | fused = self.fusion(x, mask) 89 | # if fusion returns a tuple, then it is a masked attention fusion 90 | if isinstance(fused, tuple): 91 | fused, _ = fused 92 | if return_embedding: 93 | return fused, x 94 | else: 95 | return fused 96 | 97 | 98 | class DRIMSurv(MultimodalEncoder): 99 | def __init__( 100 | self, 101 | encoders_sh: Dict[str, nn.Module], 102 | encoders_u: Dict[str, nn.Module], 103 | fusion_s: nn.Module, 104 | fusion_u: nn.Module, 105 | devices: Optional[Dict[str, str]] = None, 106 | ) -> None: 107 | super().__init__(encoders_sh, devices) 108 | for key, model 
in encoders_u.items(): 109 | setattr(self, f"{key.lower()}_encoder_u", model.to(f"{self.devices[key]}")) 110 | 111 | if "fusion_sh" not in self.devices.keys(): 112 | try: 113 | self.devices["fusion_s"] = next(fusion_s.parameters()).device 114 | except: 115 | # if fusion do not have parameters, then it is a function 116 | self.devices["fusion_s"] = "cuda:0" 117 | 118 | if "fusion_u" not in self.devices.keys(): 119 | try: 120 | self.devices["fusion_u"] = next(fusion_u.parameters()).device 121 | except: 122 | # if fusion do not have parameters, then it is a function 123 | self.devices["fusion_u"] = "cuda:0" 124 | self.fusion_s = fusion_s 125 | self.fusion_u = fusion_u 126 | 127 | def forward( 128 | self, 129 | x: Dict[str, torch.Tensor], 130 | mask: Optional[Dict[str, torch.Tensor]] = None, 131 | return_embedding: bool = False, 132 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]: 133 | x_s = super()._forward_encoders(x) 134 | x_u = super()._forward_encoders(x, unique=True) 135 | 136 | fused = self.drim_forward(x_s, x_u, mask) 137 | if isinstance(fused, tuple): 138 | fused, _ = fused 139 | if return_embedding: 140 | return fused, x_s, x_u 141 | else: 142 | return fused 143 | 144 | def drim_forward( 145 | self, 146 | x_s: Dict[str, torch.Tensor], 147 | x_u: Dict[str, torch.Tensor], 148 | mask: Optional[Dict[str, torch.Tensor]] = None, 149 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]: 150 | x_s, mask = self._put_on_devices(x_s, mask, "fusion_s") 151 | fused_s = self.fusion_s(x_s, mask) 152 | if isinstance(fused_s, tuple): 153 | fused_s, _ = fused_s 154 | x_t = {**x_u, "shared": fused_s} 155 | mask.update({"shared": torch.ones_like(next(iter(mask.values())))}) 156 | # put everything on the same device 157 | x_t, mask = self._put_on_devices(x_t, mask, "fusion_u") 158 | fused = self.fusion_u(x_t, mask) 159 | return fused 160 | 161 | 162 | class DRIMU(DRIMSurv): 163 | def forward( 164 | self, 165 | x: Dict[str, torch.Tensor], 166 | mask: Optional[Dict[str, torch.Tensor]] = None, 167 | return_embedding: bool = False, 168 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]: 169 | x_s = super()._forward_encoders(x) 170 | x_u = super()._forward_encoders(x, unique=True) 171 | if return_embedding: 172 | return x_s, x_u 173 | else: 174 | fused = self.drim_forward(x_s, x_u, mask) 175 | if isinstance(fused, tuple): 176 | fused, _ = fused 177 | 178 | return fused 179 | -------------------------------------------------------------------------------- /drim/rna/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import RNADataset 2 | from .models import RNADecoder, RNAEncoder, RNAEncoderMlp 3 | -------------------------------------------------------------------------------- /drim/rna/datasets.py: -------------------------------------------------------------------------------- 1 | from ..datasets import _BaseDataset 2 | import pandas as pd 3 | import torch 4 | from typing import Union, Tuple 5 | import numpy as np 6 | 7 | 8 | def log_transform(x): 9 | return np.log(x + 1) 10 | 11 | 12 | class RNADataset(_BaseDataset): 13 | """ 14 | Simple dataset for RNA data. 
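    Each item is the fitted `preprocessor` applied to the `fpkm_uq_unstranded` column of the
    per-sample CSV, returned as a (1, 16304) float tensor (the same width as the zero tensor
    used when the RNA file is missing). Minimal usage sketch (the dataframe and fold index
    are placeholders):

        from joblib import load
        preprocessor = load("./data/rna_preprocessors/trf_0.joblib")
        dataset = RNADataset(dataframe, preprocessor, return_mask=True)
        rna, mask = dataset[0]  # rna: torch.Size([1, 16304]), mask: bool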
15 | """ 16 | 17 | def __init__( 18 | self, 19 | dataframe: pd.DataFrame, 20 | preprocessor: "sklearn.pipeline.Pipeline", 21 | return_mask: bool = False, 22 | ) -> None: 23 | super().__init__(dataframe, return_mask) 24 | self.preprocessor = preprocessor 25 | 26 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]: 27 | sample = self.dataframe.iloc[idx] 28 | if not pd.isna(sample.RNA): 29 | out = torch.from_numpy( 30 | self.preprocessor.transform( 31 | pd.read_csv(sample.RNA)["fpkm_uq_unstranded"].values.reshape(1, -1) 32 | ) 33 | ).float() 34 | mask = True 35 | else: 36 | out = torch.zeros(1, 16304).float() 37 | mask = False 38 | 39 | if self.return_mask: 40 | return out, mask 41 | else: 42 | return out 43 | -------------------------------------------------------------------------------- /drim/rna/models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from einops.layers.torch import Rearrange 4 | 5 | 6 | class RNAEncoder(nn.Module): 7 | """ 8 | A vanilla encoder based on 1-d convolution for RNA data. 9 | """ 10 | 11 | def __init__(self, embedding_dim: int, dropout: float) -> None: 12 | super().__init__() 13 | self.encoder = nn.Sequential( 14 | nn.Conv1d(1, 8, 9, 3), 15 | nn.GELU(), 16 | nn.BatchNorm1d(8), 17 | nn.Dropout(dropout), 18 | nn.Conv1d(8, 32, 9, 3), 19 | nn.GELU(), 20 | nn.BatchNorm1d(32), 21 | nn.Dropout(dropout), 22 | nn.Conv1d(32, 64, 9, 3), 23 | nn.GELU(), 24 | nn.BatchNorm1d(64), 25 | nn.Dropout(dropout), 26 | nn.Conv1d(64, 128, 9, 3), 27 | nn.GELU(), 28 | nn.BatchNorm1d(128), 29 | nn.Dropout(dropout), 30 | nn.Conv1d(128, 256, 9, 3), 31 | nn.GELU(), 32 | nn.BatchNorm1d(256), 33 | nn.Dropout(dropout), 34 | nn.Conv1d(256, embedding_dim, 9, 3), 35 | nn.GELU(), 36 | nn.BatchNorm1d(embedding_dim), 37 | nn.Dropout(dropout), 38 | nn.AdaptiveAvgPool1d(1), 39 | ) 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | x = self.encoder(x).squeeze(-1) 43 | return x 44 | 45 | 46 | class RNAEncoderMlp(nn.Module): 47 | def __init__(self, embedding_dim: int, dropout: float) -> None: 48 | super().__init__() 49 | self.encoder = nn.Sequential( 50 | nn.BatchNorm1d(18000), 51 | nn.Linear(18000, 512), 52 | nn.GELU(), 53 | nn.BatchNorm1d(512), 54 | nn.Dropout(dropout), 55 | nn.Linear(512, embedding_dim), 56 | ) 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | return self.encoder(x).squeeze(-1) 60 | 61 | 62 | class RNADecoder(nn.Module): 63 | """ 64 | A vanilla decoder based on 1-d transposed convolution for RNA data. 
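    Length check (a back-of-the-envelope sketch using the standard ConvTranspose1d output
    formula L_out = (L_in - 1) * stride - 2 * padding + kernel_size): starting from the
    (embedding_dim, 20) map produced by the Rearrange layer, the stack expands the sequence as
    20 -> 66 -> 200 -> 604 -> 1810 -> 5434 -> 16304, so the reconstruction matches the
    16304-dimensional vectors produced by RNADataset.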
65 | """ 66 | 67 | def __init__(self, embedding_dim: int, dropout: float) -> None: 68 | super().__init__() 69 | self.decoder = nn.Sequential( 70 | nn.Linear(embedding_dim, embedding_dim * 20), 71 | Rearrange("b (n e) -> b n e", e=20), 72 | nn.GELU(), 73 | nn.BatchNorm1d(embedding_dim), 74 | nn.Dropout(dropout), 75 | nn.ConvTranspose1d(embedding_dim, 256, 9, 3), 76 | nn.GELU(), 77 | nn.BatchNorm1d(256), 78 | nn.Dropout(dropout), 79 | nn.ConvTranspose1d(256, 128, 9, 3, 2), 80 | nn.GELU(), 81 | nn.BatchNorm1d(128), 82 | nn.Dropout(dropout), 83 | nn.ConvTranspose1d(128, 64, 9, 3, 1), 84 | nn.GELU(), 85 | nn.BatchNorm1d(64), 86 | nn.Dropout(dropout), 87 | nn.ConvTranspose1d(64, 32, 9, 3, 4), 88 | nn.GELU(), 89 | nn.BatchNorm1d(32), 90 | nn.Dropout(dropout), 91 | nn.ConvTranspose1d(32, 8, 9, 3, 1), 92 | nn.GELU(), 93 | nn.BatchNorm1d(8), 94 | nn.Dropout(dropout), 95 | nn.ConvTranspose1d(8, 1, 9, 3, 2), 96 | ) 97 | 98 | def forward(self, x: torch.Tensor) -> torch.Tensor: 99 | x = self.decoder(x) 100 | return x 101 | -------------------------------------------------------------------------------- /drim/trainers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from typing import Dict, Optional 3 | from torchinfo import summary 4 | from collections import defaultdict 5 | import numpy as np 6 | from pycox.evaluation import EvalSurv 7 | import wandb 8 | from .utils import interpolate_dataframe 9 | import torch 10 | from omegaconf import DictConfig, OmegaConf 11 | from .logger import logger 12 | import pandas as pd 13 | import tqdm 14 | from .losses import DiscriminatorLoss 15 | 16 | 17 | class _BaseTrainer: 18 | """ 19 | Base trainer for all experiments 20 | """ 21 | 22 | def __init__( 23 | self, 24 | model: nn.Module, 25 | optimizer: torch.optim.Optimizer, 26 | scheduler: torch.optim.lr_scheduler._LRScheduler, 27 | task_criterion: nn.modules.loss._Loss, 28 | dataloaders: Dict[str, torch.utils.data.DataLoader], 29 | cfg: DictConfig, 30 | wandb_logging: bool, 31 | ): 32 | self.model = model 33 | self.optimizer = optimizer 34 | self.scheduler = scheduler 35 | self.task_criterion = task_criterion 36 | self.dataloaders = dataloaders 37 | self.cfg = cfg 38 | self.wandb_logging = wandb_logging 39 | 40 | def _print_summary(self) -> None: 41 | """ 42 | Print the model summary and the number of trainable parameters 43 | """ 44 | logger.info( 45 | "Trainable parameters {}", 46 | sum(p.numel() for p in self.model.parameters() if p.requires_grad), 47 | ) 48 | logger.info(self.model) 49 | 50 | def _initialize_logger(self) -> None: 51 | """ 52 | Initialize the WandB logger 53 | """ 54 | if "wandb" in self.cfg: 55 | self.wandb_logging = True 56 | if not isinstance(self.cfg.general.modalities, str): 57 | name = "_".join(self.cfg.general.modalities) 58 | else: 59 | name = self.cfg.general.modalities 60 | wandb.init( 61 | name=self.prefix + name, 62 | config={ 63 | k: v 64 | for k, v in OmegaConf.to_container(self.cfg).items() 65 | if k != "wandb" 66 | }, 67 | **self.cfg.wandb, 68 | ) 69 | else: 70 | self.wandb_logging = False 71 | 72 | def put_model_on_correct_mode(self, split: str) -> None: 73 | if split == "train": 74 | self.model.train() 75 | torch.set_grad_enabled(True) 76 | else: 77 | self.model.eval() 78 | torch.set_grad_enabled(False) 79 | 80 | def clear(self) -> None: 81 | if self.wandb_logging: 82 | wandb.finish() 83 | 84 | return 85 | 86 | def fit(self, epochs: Optional[int] = None) -> None: 87 | """ 88 | Train the self.model for a given 
number of epochs 89 | """ 90 | # self._initialize_logger() 91 | # self.wandb_logging = False 92 | self._print_summary() 93 | metrics = defaultdict(list) 94 | if not epochs: 95 | epochs = self.cfg.general.epochs 96 | for epoch in range(epochs): 97 | logger.info(f"Epoch {epoch + 1}/{epochs}") 98 | logger.info("-" * 10) 99 | metrics["lr"].append(self.optimizer.state_dict()["param_groups"][0]["lr"]) 100 | 101 | train_losses = self.shared_loop(split="train") 102 | for key, value in train_losses.items(): 103 | logger.info(f"Train {key} {value:.4f}") 104 | metrics[f"train/{key}"].append(value) 105 | 106 | val_metrics = self.shared_loop(split="val") 107 | for key, value in val_metrics.items(): 108 | logger.info(f"Val {key} {value:.4f}") 109 | metrics[f"val/{key}"].append(value) 110 | 111 | if self.wandb_logging: 112 | wandb.log({key: value[-1] for key, value in metrics.items()}) 113 | 114 | self.scheduler.step() 115 | torch.save(self.model.state_dict(), self.cfg.general.save_path) 116 | 117 | def evaluate(self, split: str) -> Dict[str, float]: 118 | outputs = self.shared_loop(split) 119 | for key, value in outputs.items(): 120 | logger.info(f"{split} {key} {value:.4f}") 121 | if self.wandb_logging: 122 | # update outputs 123 | outputs = {f"{split}/{key}": value for key, value in outputs.items()} 124 | wandb.log(outputs) 125 | 126 | return outputs 127 | 128 | 129 | class BaseSurvivalTrainer(_BaseTrainer): 130 | def __init__( 131 | self, 132 | model: nn.Module, 133 | optimizer: torch.optim.Optimizer, 134 | scheduler: torch.optim.lr_scheduler._LRScheduler, 135 | task_criterion: nn.modules.loss._Loss, 136 | dataloaders: Dict[str, torch.utils.data.DataLoader], 137 | cfg: DictConfig, 138 | wandb_logging: bool, 139 | cuts: np.ndarray, 140 | ): 141 | super().__init__( 142 | model, optimizer, scheduler, task_criterion, dataloaders, cfg, wandb_logging 143 | ) 144 | self.cuts = cuts 145 | 146 | def compute_survival_metrics(self, outputs, time, event): 147 | """ 148 | Compute the survival metrics with the PyCox package. 
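        Sketch of the conversion performed below: the network emits one logit per discrete time
        interval of the `self.cuts` grid (the LogisticHazard discretisation set up in
        helpers.get_targets), and the survival curve is the running product
        S(t_k) = prod_{j <= k} (1 - sigmoid(h_j)), computed in log space for numerical stability.
        The interpolated curves are scored with pycox's EvalSurv (time-dependent concordance,
        integrated Brier score, integrated negative binomial log-likelihood) and summarised as
        cs_score = (c_index + (1 - ibs)) / 2.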
149 | """ 150 | hazard = torch.cat(outputs, dim=0) 151 | survival = (1 - hazard.sigmoid()).add(1e-7).log().cumsum(1).exp().cpu().numpy() 152 | survival = interpolate_dataframe(pd.DataFrame(survival.transpose(), self.cuts)) 153 | evaluator = EvalSurv( 154 | survival, time.cpu().numpy(), event.cpu().numpy(), censor_surv="km" 155 | ) 156 | c_index = evaluator.concordance_td() 157 | ibs = evaluator.integrated_brier_score(np.linspace(0, time.cpu().numpy().max())) 158 | inbll = evaluator.integrated_nbll(np.linspace(0, time.cpu().numpy().max())) 159 | cs_score = (c_index + (1 - ibs)) / 2 160 | return {"c_index": c_index, "ibs": ibs, "inbll": inbll, "cs_score": cs_score} 161 | 162 | def shared_loop(self, split: str) -> Dict[str, float]: 163 | self.put_model_on_correct_mode(split) 164 | total_task_loss = 0.0 165 | raw_predictions = [] 166 | times = [] 167 | events = [] 168 | for batch in tqdm.tqdm(self.dataloaders[split]): 169 | data, time, event = batch 170 | # check if data contains mask 171 | if isinstance(data, list): 172 | data, mask = data 173 | outputs = self.model(data, mask, return_embedding=False) 174 | else: 175 | outputs = self.model(data, return_embedding=False) 176 | 177 | loss = self.task_criterion( 178 | outputs, time.to(outputs.device), event.to(outputs.device) 179 | ) 180 | if split == "train": 181 | self.optimizer.zero_grad() 182 | loss.backward() 183 | self.optimizer.step() 184 | total_task_loss += loss.item() * time.size(0) 185 | raw_predictions.append(outputs) 186 | times.append(time) 187 | events.append(event) 188 | 189 | outputs = {"task_loss": total_task_loss / len(self.dataloaders[split].dataset)} 190 | if split != "train": 191 | task_metrics = self.compute_survival_metrics( 192 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0) 193 | ) 194 | outputs.update(task_metrics) 195 | 196 | return outputs 197 | 198 | 199 | class AuxSurvivalTrainer(BaseSurvivalTrainer): 200 | """ 201 | This class assumes that an auxiliary loss is added to the survival loss. This auxiliary loss can be either MMO Loss or CL loss. 
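    Per-batch objective, as implemented in shared_loop below (`alpha` is read from
    cfg.aux_loss.alpha):

        loss = task_criterion(outputs, time, event) + alpha * aux_loss(batch_embeddings, mask)

    where `batch_embeddings` is the per-modality embedding dict returned by the wrapped model
    and `mask` flags which modalities are actually present in the minibatch.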
202 | """ 203 | 204 | def __init__( 205 | self, 206 | model: nn.Module, 207 | optimizer: torch.optim.Optimizer, 208 | scheduler: torch.optim.lr_scheduler._LRScheduler, 209 | task_criterion: nn.modules.loss._Loss, 210 | dataloaders: Dict[str, torch.utils.data.DataLoader], 211 | cfg: DictConfig, 212 | wandb_logging: bool, 213 | cuts: np.ndarray, 214 | aux_loss: torch.nn.modules.loss._Loss, 215 | ) -> None: 216 | super().__init__( 217 | model, 218 | optimizer, 219 | scheduler, 220 | task_criterion, 221 | dataloaders, 222 | cfg, 223 | wandb_logging, 224 | cuts, 225 | ) 226 | self.aux_loss = aux_loss 227 | 228 | def shared_loop(self, split: str) -> Dict[str, float]: 229 | self.put_model_on_correct_mode(split) 230 | raw_predictions = [] 231 | embeddings = defaultdict(list) 232 | masks = defaultdict(list) 233 | total_task_loss = 0.0 234 | total_aux_loss = 0.0 235 | times = [] 236 | events = [] 237 | for batch in tqdm.tqdm(self.dataloaders[split]): 238 | data, time, event = batch 239 | # check if data contains mask 240 | if isinstance(data, list): 241 | data, mask = data 242 | outputs, batch_embeddings = self.model( 243 | data, mask, return_embedding=True 244 | ) 245 | else: 246 | outputs, batch_embeddings = self.model(data, return_embedding=True) 247 | 248 | task_loss = self.task_criterion( 249 | outputs, time.to(outputs.device), event.to(outputs.device) 250 | ) 251 | # put mask on the same device as embeddings 252 | mask = {k: v.to(batch_embeddings[k].device) for k, v in mask.items()} 253 | aux_loss = self.aux_loss(batch_embeddings, mask) 254 | loss = task_loss + self.cfg.aux_loss.alpha * aux_loss 255 | if split == "train": 256 | self.optimizer.zero_grad() 257 | loss.backward() 258 | self.optimizer.step() 259 | 260 | total_task_loss += task_loss.item() * time.size(0) 261 | total_aux_loss += aux_loss.item() * time.size(0) 262 | raw_predictions.append(outputs) 263 | times.append(time) 264 | events.append(event) 265 | 266 | outputs = { 267 | "task_loss": total_task_loss / len(self.dataloaders[split].dataset), 268 | "aux_loss": total_aux_loss / len(self.dataloaders[split].dataset), 269 | } 270 | if split != "train": 271 | task_metrics = self.compute_survival_metrics( 272 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0) 273 | ) 274 | embeddings = {k: torch.cat(v) for k, v in embeddings.items()} 275 | masks = {k: torch.cat(v) for k, v in masks.items()} 276 | outputs.update(task_metrics) 277 | 278 | return outputs 279 | 280 | 281 | class DRIMSurvTrainer(AuxSurvivalTrainer): 282 | def __init__( 283 | self, 284 | model: nn.Module, 285 | discriminators: Dict[str, nn.Module], 286 | optimizer: torch.optim.Optimizer, 287 | optimizers_dsm: Dict[str, torch.optim.Optimizer], 288 | scheduler: torch.optim.lr_scheduler._LRScheduler, 289 | task_criterion: nn.modules.loss._Loss, 290 | dataloaders: Dict[str, torch.utils.data.DataLoader], 291 | cfg: DictConfig, 292 | wandb_logging: bool, 293 | cuts: np.ndarray, 294 | aux_loss: torch.nn.modules.loss._Loss, 295 | ) -> None: 296 | super().__init__( 297 | model, 298 | optimizer, 299 | scheduler, 300 | task_criterion, 301 | dataloaders, 302 | cfg, 303 | wandb_logging, 304 | cuts, 305 | aux_loss, 306 | ) 307 | self.discriminators = discriminators 308 | self.optimizers_dsm = optimizers_dsm 309 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 310 | self.discriminator_loss = DiscriminatorLoss() 311 | 312 | def shared_loop(self, split: str) -> Dict[str, float]: 313 | self.put_model_on_correct_mode(split) 314 | raw_predictions = [] 
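        # Outline of the loop below: the shared encoders are aligned across modalities by the
        # auxiliary shared loss, while each unique encoder is trained adversarially against its
        # per-modality discriminator. Pairs (shared, unique) taken from the same sample stand in
        # for the joint p(x, y); pairs whose shared part is rolled within the batch stand in for
        # p(x)p(y). The encoders are pushed to make joint pairs indistinguishable from shuffled
        # ones, and the discriminators are then updated on the detached embeddings with
        # DiscriminatorLoss.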
315 | embeddings = defaultdict(list) 316 | masks = defaultdict(list) 317 | total_task_loss = 0.0 318 | total_sh_loss = 0.0 319 | total_u_loss = 0.0 320 | total_discrimators_loss = {k: 0.0 for k in self.optimizers_dsm.keys()} 321 | times = [] 322 | events = [] 323 | for batch in tqdm.tqdm(self.dataloaders[split]): 324 | data, time, event = batch 325 | # check if data contains mask 326 | if isinstance(data, list): 327 | data, mask = data 328 | outputs, batch_embeddings_sh, batch_embeddings_u = self.model( 329 | data, mask, return_embedding=True 330 | ) 331 | else: 332 | outputs, batch_embeddings_sh, batch_embeddings_u = self.model( 333 | data, return_embedding=True 334 | ) 335 | 336 | # link to survival loss 337 | task_loss = self.task_criterion( 338 | outputs, time.to(outputs.device), event.to(outputs.device) 339 | ) 340 | # compute shared loss 341 | sh_loss = self.aux_loss( 342 | {k: v.to(self.device) for k, v in batch_embeddings_sh.items()}, 343 | {k: v.to(self.device) for k, v in mask.items()}, 344 | ) 345 | # put mask on cpu 346 | mask = {k: v.cpu() for k, v in mask.items()} 347 | # prepare inputs for unique encoders, x_r is suppose to be drawn from P(x, y) while x_r_prime is drawn from P(x)P(y) 348 | x_r = { 349 | k: torch.cat( 350 | [ 351 | batch_embeddings_sh[k][mask[k]] 352 | .detach() 353 | .to(batch_embeddings_u[k].device), 354 | batch_embeddings_u[k][mask[k]], 355 | ], 356 | dim=1, 357 | ) 358 | for k in batch_embeddings_u.keys() 359 | } 360 | x_r_prime = { 361 | k: torch.cat( 362 | [ 363 | torch.roll( 364 | batch_embeddings_sh[k][mask[k]] 365 | .detach() 366 | .to(batch_embeddings_u[k].device), 367 | 1, 368 | 0, 369 | ), 370 | batch_embeddings_u[k][mask[k]], 371 | ], 372 | dim=1, 373 | ) 374 | for k in batch_embeddings_sh.keys() 375 | } 376 | # compute unique loss 377 | u_loss = torch.tensor(0.0).to(self.device) 378 | for modality in x_r.keys(): 379 | # if mask is empty then continue 380 | if sum(mask[modality]) <= 1: 381 | continue 382 | # get fake logits 383 | fake_logits = self.discriminators[modality](x_r[modality]) 384 | u_loss += torch.nn.functional.binary_cross_entropy_with_logits( 385 | input=fake_logits, target=torch.ones_like(fake_logits) 386 | ).to(self.device) 387 | 388 | # update shared and unique encoders according to their respective loss and the task loss 389 | loss_encoders = ( 390 | task_loss 391 | + self.cfg.aux_loss.alpha * sh_loss 392 | + self.cfg.disentangled.gamma * u_loss 393 | ) 394 | if split == "train": 395 | self.optimizer.zero_grad() 396 | loss_encoders.backward() 397 | self.optimizer.step() 398 | 399 | # update discriminators 400 | for modality in x_r.keys(): 401 | # if mask is empty then continue 402 | if sum(mask[modality]) <= 1: 403 | continue 404 | 405 | fake_logits = self.discriminators[modality](x_r[modality].detach()) 406 | real_logits = self.discriminators[modality]( 407 | x_r_prime[modality].detach() 408 | ) 409 | loss_dsm = self.discriminator_loss( 410 | real_logits=real_logits, fake_logits=fake_logits 411 | ) 412 | 413 | if split == "train": 414 | self.optimizers_dsm[modality].zero_grad() 415 | loss_dsm.backward() 416 | self.optimizers_dsm[modality].step() 417 | 418 | total_discrimators_loss[modality] += loss_dsm.item() * time.size(0) 419 | 420 | total_task_loss += task_loss.item() * time.size(0) 421 | total_sh_loss += sh_loss.item() * time.size(0) 422 | total_u_loss += u_loss.item() * time.size(0) 423 | raw_predictions.append(outputs) 424 | times.append(time) 425 | events.append(event) 426 | 427 | outputs = { 428 | "shared_loss": 
total_sh_loss / len(self.dataloaders[split].dataset), 429 | "unique_loss": total_u_loss / len(self.dataloaders[split].dataset), 430 | "task_loss": total_task_loss / len(self.dataloaders[split].dataset), 431 | } 432 | outputs.update( 433 | { 434 | f"discriminator_{k}": v / len(self.dataloaders[split].dataset) 435 | for k, v in total_discrimators_loss.items() 436 | } 437 | ) 438 | 439 | if split != "train": 440 | task_metrics = self.compute_survival_metrics( 441 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0) 442 | ) 443 | embeddings = {k: torch.cat(v) for k, v in embeddings.items()} 444 | masks = {k: torch.cat(v) for k, v in masks.items()} 445 | outputs = {**outputs, **task_metrics} 446 | 447 | return outputs 448 | 449 | 450 | class DRIMUTrainer(DRIMSurvTrainer): 451 | def __init__( 452 | self, 453 | model: nn.Module, 454 | decoders: Dict[str, nn.Module], 455 | discriminators: Dict[str, nn.Module], 456 | optimizer: torch.optim.Optimizer, 457 | optimizers_dsm: Dict[str, torch.optim.Optimizer], 458 | scheduler: torch.optim.lr_scheduler._LRScheduler, 459 | task_criterion: nn.modules.loss._Loss, 460 | dataloaders: Dict[str, torch.utils.data.DataLoader], 461 | cfg: DictConfig, 462 | wandb_logging: bool, 463 | cuts: np.ndarray, 464 | aux_loss: torch.nn.modules.loss._Loss, 465 | ) -> None: 466 | super().__init__( 467 | model, 468 | discriminators, 469 | optimizer, 470 | optimizers_dsm, 471 | scheduler, 472 | task_criterion, 473 | dataloaders, 474 | cfg, 475 | wandb_logging, 476 | cuts, 477 | aux_loss, 478 | ) 479 | self.decoders = decoders 480 | 481 | def shared_loop(self, split: str) -> Dict[str, float]: 482 | self.put_model_on_correct_mode(split) 483 | total_loss = 0.0 484 | embeddings = defaultdict(list) 485 | masks = defaultdict(list) 486 | total_sh_loss = 0.0 487 | total_u_loss = 0.0 488 | total_discriminators_loss = {k: 0.0 for k in self.optimizers_dsm.keys()} 489 | total_decoders_loss = {k: 0.0 for k in self.decoders.keys()} 490 | for batch in tqdm.tqdm(self.dataloaders[split]): 491 | data, time, _ = batch 492 | # check if data contains mask 493 | if isinstance(data, list): 494 | data, mask = data 495 | batch_embeddings_sh, batch_embeddings_u = self.model( 496 | data, mask, return_embedding=True 497 | ) 498 | else: 499 | batch_embeddings_sh, batch_embeddings_u = self.model( 500 | data, return_embedding=True 501 | ) 502 | 503 | sh_loss = self.aux_loss( 504 | {k: v.to(self.device) for k, v in batch_embeddings_sh.items()}, 505 | {k: v.to(self.device) for k, v in mask.items()}, 506 | ) 507 | # put mask on cpu 508 | mask = {k: v.cpu() for k, v in mask.items()} 509 | x_r = { 510 | k: torch.cat( 511 | [ 512 | batch_embeddings_sh[k][mask[k]] 513 | .detach() 514 | .to(batch_embeddings_u[k].device), 515 | batch_embeddings_u[k][mask[k]], 516 | ], 517 | dim=1, 518 | ) 519 | for k in batch_embeddings_u.keys() 520 | } 521 | x_r_prime = { 522 | k: torch.cat( 523 | [ 524 | torch.roll( 525 | batch_embeddings_sh[k][mask[k]] 526 | .detach() 527 | .to(batch_embeddings_u[k].device), 528 | 1, 529 | 0, 530 | ), 531 | batch_embeddings_u[k][mask[k]], 532 | ], 533 | dim=1, 534 | ) 535 | for k in batch_embeddings_sh.keys() 536 | } 537 | u_loss = torch.tensor(0.0).to(self.device) 538 | decoder_loss_iteration = torch.tensor(0.0).to(self.device) 539 | for modality in x_r.keys(): 540 | if sum(mask[modality]) <= 1: 541 | continue 542 | fake_logits = self.discriminators[modality](x_r[modality]) 543 | u_loss += torch.nn.functional.binary_cross_entropy_with_logits( 544 | input=fake_logits, 
target=torch.ones_like(fake_logits) 545 | ).to(self.device) 546 | decoder_outputs = self.decoders[modality]( 547 | batch_embeddings_u[modality][mask[modality]] 548 | ) 549 | decoder_loss = torch.nn.functional.mse_loss( 550 | decoder_outputs, 551 | data[modality][mask[modality]].to(decoder_outputs.device), 552 | ) 553 | decoder_loss_iteration += decoder_loss.to(self.device) 554 | total_decoders_loss[modality] += decoder_loss.item() * time.size(0) 555 | 556 | loss_encoders = ( 557 | self.cfg.aux_loss.alpha * sh_loss 558 | + self.cfg.disentangled.gamma * u_loss 559 | + decoder_loss_iteration 560 | ) 561 | if split == "train": 562 | self.optimizer.zero_grad() 563 | loss_encoders.backward() 564 | self.optimizer.step() 565 | else: 566 | for modality in batch_embeddings_sh.keys(): 567 | embeddings[modality].append(batch_embeddings_sh[modality].cpu()) 568 | masks[modality].append(mask[modality].cpu()) 569 | 570 | losses_dsm = {} 571 | for modality in x_r.keys(): 572 | if sum(mask[modality]) <= 1: 573 | losses_dsm[modality] = 0.0 574 | continue 575 | 576 | fake_logits = self.discriminators[modality](x_r[modality].detach()) 577 | real_logits = self.discriminators[modality]( 578 | x_r_prime[modality].detach() 579 | ) 580 | loss_dsm = self.discriminator_loss( 581 | real_logits=real_logits, fake_logits=fake_logits 582 | ) 583 | 584 | if split == "train": 585 | self.optimizers_dsm[modality].zero_grad() 586 | loss_dsm.backward() 587 | self.optimizers_dsm[modality].step() 588 | 589 | total_discriminators_loss[modality] += loss_dsm.item() 590 | 591 | total_sh_loss += sh_loss.item() * time.size(0) 592 | total_u_loss += u_loss.item() * time.size(0) 593 | 594 | outputs = { 595 | "loss": total_loss / len(self.dataloaders[split].dataset), 596 | "sh_loss": total_sh_loss / len(self.dataloaders[split].dataset), 597 | "unique_loss": total_u_loss / len(self.dataloaders[split].dataset), 598 | } 599 | outputs.update( 600 | { 601 | f"discriminator_{k}": v / len(self.dataloaders[split].dataset) 602 | for k, v in total_discriminators_loss.items() 603 | } 604 | ) 605 | outputs.update( 606 | { 607 | f"decoder_{k}": v / len(self.dataloaders[split].dataset) 608 | for k, v in total_decoders_loss.items() 609 | } 610 | ) 611 | 612 | return outputs 613 | 614 | def finetune(self): 615 | self.shared_loop = self.shared_loop_finetune 616 | self.optimizer = torch.optim.AdamW( 617 | self.model.parameters(), **self.cfg.optimizer.params 618 | ) 619 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 620 | self.optimizer, T_max=30, eta_min=5e-6 621 | ) 622 | for name, param in self.model.named_parameters(): 623 | if "_encoder" in name: 624 | param.requires_grad = False 625 | self.fit(epochs=self.cfg.general.epochs_finetune) 626 | 627 | def shared_loop_finetune(self, split: str) -> Dict[str, float]: 628 | self.put_model_on_correct_mode(split) 629 | total_loss = 0.0 630 | raw_predictions = [] 631 | times = [] 632 | events = [] 633 | for batch in tqdm.tqdm(self.dataloaders[split]): 634 | data, time, event = batch 635 | # check if data contains mask 636 | if isinstance(data, list): 637 | data, mask = data 638 | outputs = self.model(data, mask, return_embedding=False) 639 | else: 640 | outputs = self.model(data, return_embedding=False) 641 | 642 | loss = self.task_criterion( 643 | outputs, time.to(outputs.device), event.to(outputs.device) 644 | ) 645 | if split == "train": 646 | self.optimizer.zero_grad() 647 | loss.backward() 648 | self.optimizer.step() 649 | total_loss += loss.item() * time.size(0) 650 | 
raw_predictions.append(outputs) 651 | times.append(time) 652 | events.append(event) 653 | 654 | outputs = {"task_loss": total_loss / len(self.dataloaders[split].dataset)} 655 | if split != "train": 656 | task_metrics = self.compute_survival_metrics( 657 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0) 658 | ) 659 | outputs.update(task_metrics) 660 | 661 | return outputs 662 | -------------------------------------------------------------------------------- /drim/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from typing import List, Union, Dict, Any, Tuple 4 | import torch 5 | import random 6 | 7 | 8 | def seed_everything(seed: int) -> None: 9 | import monai 10 | 11 | monai.utils.set_determinism(seed=seed, additional_settings=None) 12 | np.random.seed(seed) 13 | random.seed(seed) 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | 21 | def clean_state_dict(state_dict: dict) -> dict: 22 | new_state_dict = {} 23 | for key, value in state_dict.items(): 24 | if key.startswith("module."): 25 | new_state_dict[key[7:]] = value 26 | else: 27 | new_state_dict[key] = value 28 | return new_state_dict 29 | 30 | 31 | def prepare_data( 32 | dataframe: pd.DataFrame, modalities: Union[List[str], str], min_k_modality: int = 2 33 | ) -> pd.DataFrame: 34 | if isinstance(modalities, str): 35 | modalities = [modalities] 36 | 37 | # columns = ['project_id', 'submitter_id'] 38 | # dataframe = dataframe.drop(columns=columns) 39 | 40 | if len(modalities) == 1: 41 | try: 42 | mask = ~dataframe[modalities[0]].isna() 43 | dataframe = dataframe.loc[mask] 44 | except: 45 | pass 46 | 47 | else: 48 | temp_df = dataframe[modalities] 49 | dataframe = dataframe[temp_df.count(axis=1) >= min_k_modality] 50 | 51 | return dataframe 52 | 53 | 54 | get_target_survival = lambda df: (df["time"].values, df["event"].values) 55 | 56 | 57 | def log_transform(x: np.ndarray) -> np.ndarray: 58 | return np.log(1 + x) 59 | 60 | 61 | def create_nan_dataframe( 62 | num_row: int, num_col: int, name_col: List[str] 63 | ) -> pd.DataFrame: 64 | df = pd.DataFrame(np.zeros((num_row, num_col)), columns=name_col) 65 | df[:] = np.nan 66 | return df 67 | 68 | 69 | def interpolate_dataframe(dataframe: pd.DataFrame, n: int = 10) -> pd.DataFrame: 70 | dataframe.reset_index(inplace=True) 71 | dataframe_list = [] 72 | for i, idx in enumerate(dataframe.index): 73 | df_temp = dataframe[dataframe.index == idx] 74 | dataframe_list.append(df_temp) 75 | if i != len(dataframe) - 1: 76 | dataframe_list.append( 77 | create_nan_dataframe(n, df_temp.shape[1], df_temp.columns) 78 | ) 79 | 80 | dataframe = pd.concat(dataframe_list).interpolate("linear") 81 | dataframe = dataframe.set_index("index") 82 | return dataframe 83 | 84 | 85 | def seed_worker(worker_id): 86 | worker_seed = torch.initial_seed() % 2**32 87 | np.random.seed(worker_seed) 88 | random.seed(worker_seed) 89 | 90 | 91 | def get_dataframes(fold: int) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: 92 | dataframe = pd.read_csv("data/files/train_brain.csv") 93 | dataframe_train = ( 94 | dataframe[dataframe["split"] != fold].copy().drop(columns=["split"]) 95 | ) 96 | dataframe_val = dataframe[dataframe["split"] == fold].copy().drop(columns=["split"]) 97 | dataframe_test = pd.read_csv("data/files/test_brain.csv") 98 | return {"train": 
dataframe_train, "val": dataframe_val, "test": dataframe_test} 99 | -------------------------------------------------------------------------------- /drim/wsi/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import WSIDataset 2 | from .models import WSIDecoder, WSIEncoder, ResNetWrapperSimCLR 3 | -------------------------------------------------------------------------------- /drim/wsi/datasets.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | from PIL import Image 3 | import torch 4 | from ..datasets import _BaseDataset 5 | import pandas as pd 6 | 7 | 8 | class PatchDataset(torch.utils.data.Dataset): 9 | def __init__( 10 | self, filepaths: Tuple[str, ...], transforms: "torchvision.transforms" 11 | ) -> None: 12 | self.filepaths = filepaths 13 | self.transforms = transforms 14 | 15 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 16 | path = self.filepaths[idx] 17 | image = Image.open(path) 18 | image_1, image_2 = self.transforms(image) 19 | return image_1, image_2 20 | 21 | def __len__(self) -> int: 22 | return len(self.filepaths) 23 | 24 | 25 | class WSIDataset(_BaseDataset): 26 | def __init__( 27 | self, 28 | dataframe: pd.DataFrame, 29 | k: int, 30 | is_train: bool = True, 31 | return_mask: bool = False, 32 | ) -> None: 33 | super().__init__(dataframe, return_mask) 34 | self.k = k 35 | self.is_train = is_train 36 | 37 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]: 38 | sample = self.dataframe.iloc[idx] 39 | if not pd.isna(sample.WSI): 40 | data = pd.read_csv(sample.WSI) 41 | # get k random embeddings 42 | if self.is_train: 43 | data = data.sample(self.k) 44 | else: 45 | data = data.iloc[: self.k] 46 | 47 | data = torch.from_numpy(data.values).float() 48 | mask = True 49 | else: 50 | data = torch.zeros(self.k, 512).float() 51 | mask = False 52 | 53 | if self.return_mask: 54 | return data, mask 55 | else: 56 | return data 57 | -------------------------------------------------------------------------------- /drim/wsi/func.py: -------------------------------------------------------------------------------- 1 | """ 2 | Deeply c/p from vit_pytorch (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py) 3 | """ 4 | import torch 5 | from torch import nn 6 | from einops import rearrange 7 | 8 | 9 | def softmax_one(x, dim=None): 10 | # subtract the max for stability 11 | x = x - x.max(dim=dim, keepdim=True).values 12 | # compute exponentials 13 | exp_x = torch.exp(x) 14 | # compute softmax values and add on in the denominator 15 | return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True)) 16 | 17 | 18 | # helpers 19 | 20 | 21 | def pair(t): 22 | return t if isinstance(t, tuple) else (t, t) 23 | 24 | 25 | # classes 26 | 27 | 28 | class FeedForward(nn.Module): 29 | def __init__(self, dim, hidden_dim, dropout=0.0): 30 | super().__init__() 31 | self.net = nn.Sequential( 32 | nn.LayerNorm(dim), 33 | nn.Linear(dim, hidden_dim), 34 | nn.GELU(), 35 | nn.Dropout(dropout), 36 | nn.Linear(hidden_dim, dim), 37 | nn.Dropout(dropout), 38 | ) 39 | 40 | def forward(self, x): 41 | return self.net(x) 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0): 46 | super().__init__() 47 | inner_dim = dim_head * heads 48 | project_out = not (heads == 1 and dim_head == dim) 49 | 50 | self.heads = heads 51 | self.scale = dim_head**-0.5 52 | 53 | self.norm = 
nn.LayerNorm(dim) 54 | 55 | self.dropout = nn.Dropout(dropout) 56 | 57 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) 58 | 59 | self.to_out = ( 60 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) 61 | if project_out 62 | else nn.Identity() 63 | ) 64 | 65 | def forward(self, x): 66 | x = self.norm(x) 67 | 68 | qkv = self.to_qkv(x).chunk(3, dim=-1) 69 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv) 70 | 71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 72 | 73 | attn = softmax_one(dots, dim=-1) 74 | attn = self.dropout(attn) 75 | 76 | out = torch.matmul(attn, v) 77 | out = rearrange(out, "b h n d -> b n (h d)") 78 | return self.to_out(out) 79 | 80 | 81 | class Transformer(nn.Module): 82 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0): 83 | super().__init__() 84 | self.norm = nn.LayerNorm(dim) 85 | self.layers = nn.ModuleList([]) 86 | for _ in range(depth): 87 | self.layers.append( 88 | nn.ModuleList( 89 | [ 90 | Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout), 91 | FeedForward(dim, mlp_dim, dropout=dropout), 92 | ] 93 | ) 94 | ) 95 | 96 | def forward(self, x): 97 | for attn, ff in self.layers: 98 | x = attn(x) + x 99 | x = ff(x) + x 100 | 101 | return self.norm(x) 102 | -------------------------------------------------------------------------------- /drim/wsi/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | from typing import Tuple, Union 5 | from einops import repeat 6 | from .func import Transformer 7 | from einops.layers.torch import Rearrange 8 | 9 | 10 | class ResNetWrapperSimCLR(nn.Module): 11 | """ 12 | A wrapper for the ResNet34 model with a projection head for SimCLR. 13 | """ 14 | 15 | def __init__(self, out_dim: int, projection_head: bool = True) -> None: 16 | super().__init__() 17 | self.encoder = torchvision.models.resnet34(pretrained=False) 18 | self.encoder.fc = nn.Identity() 19 | if projection_head: 20 | self.projection_head = nn.Sequential( 21 | nn.Linear(512, 512), nn.ReLU(inplace=True), nn.Linear(512, out_dim) 22 | ) 23 | 24 | def forward( 25 | self, x: torch.tensor 26 | ) -> Union[torch.tensor, Tuple[torch.tensor, torch.tensor]]: 27 | x = self.encoder(x) 28 | if hasattr(self, "projection_head"): 29 | return self.projection_head(x), x 30 | else: 31 | return x 32 | 33 | 34 | class WSIEncoder(nn.Module): 35 | """ 36 | A attention-based encoder for WSI data. 
37 | """ 38 | 39 | def __init__( 40 | self, 41 | embedding_dim: int, 42 | depth: int, 43 | heads: int, 44 | dim: int = 512, 45 | pool: str = "cls", 46 | dim_head: int = 64, 47 | mlp_dim: int = 128, 48 | dropout: float = 0.0, 49 | emb_dropout: float = 0.0, 50 | ) -> None: 51 | super().__init__() 52 | self.layer_norm = nn.LayerNorm(dim) 53 | 54 | # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 55 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 56 | self.dropout = nn.Dropout(emb_dropout) 57 | 58 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 59 | 60 | self.pool = pool 61 | self.to_latent = ( 62 | nn.Identity() if embedding_dim == dim else nn.Linear(dim, embedding_dim) 63 | ) 64 | 65 | def forward(self, x: torch.Tensor) -> torch.Tensor: 66 | x = self.layer_norm(x) 67 | b, n, _ = x.shape 68 | 69 | cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b) 70 | x = torch.cat((cls_tokens, x), dim=1) 71 | # x += self.pos_embedding[:, :(n + 1)] 72 | x = self.dropout(x) 73 | 74 | x = self.transformer(x) 75 | 76 | x = x.mean(dim=1) if self.pool == "mean" else x[:, 0] 77 | x = self.to_latent(x) 78 | 79 | return x 80 | 81 | 82 | class WSIDecoder(nn.Module): 83 | """ 84 | A vanilla mlp-based decoder for WSI data. 85 | """ 86 | 87 | def __init__(self, embedding_dim: int, dropout: float) -> None: 88 | super().__init__() 89 | self.decoder = nn.Sequential( 90 | nn.Linear(embedding_dim, 256), 91 | nn.BatchNorm1d(256), 92 | nn.Dropout(dropout), 93 | nn.LeakyReLU(), 94 | nn.Linear(256, 5120), 95 | Rearrange("b (p e) -> b p e", p=10), 96 | ) 97 | 98 | def forward(self, x: torch.Tensor) -> torch.Tensor: 99 | return self.decoder(x) 100 | -------------------------------------------------------------------------------- /drim/wsi/transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | import torch 3 | 4 | 5 | class Dropout: 6 | def __init__(self, dropout_prob, p): 7 | self.dropout_prob = dropout_prob 8 | self.p = p 9 | 10 | def __call__(self, img): 11 | if torch.rand(1).item() < self.p: 12 | mask = ~torch.bernoulli(torch.full_like(img, self.dropout_prob)).bool() 13 | return img * mask 14 | else: 15 | return img 16 | 17 | 18 | class NviewsAugment(object): 19 | def __init__(self, transforms, n_views=2): 20 | self.transforms = transforms 21 | self.n_views = n_views 22 | 23 | def __call__(self, x): 24 | return [self.transforms(x) for i in range(self.n_views)] 25 | 26 | 27 | contrastive_base = transforms.Compose( 28 | [ 29 | transforms.RandomResizedCrop(size=256, scale=(0.35, 1.0)), 30 | transforms.RandomHorizontalFlip(), 31 | transforms.RandomVerticalFlip(), 32 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8), 33 | transforms.RandomGrayscale(p=0.2), 34 | transforms.GaussianBlur(kernel_size=11), # , sigma=(0.1, 1.)), 35 | transforms.ToTensor(), 36 | Dropout(0.1, 0.3), 37 | ] 38 | ) 39 | 40 | 41 | contrastive_wsi_transforms = NviewsAugment(contrastive_base, n_views=2) 42 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.24.3 2 | matplotlib 3 | imutils 4 | torch==1.13.1 5 | torchvision==0.14.1 6 | monai==1.1.0 7 | tqdm 8 | pyvips==2.2.1 9 | opencv-python==4.7.0.72 10 | pandas==2.0.1 11 | hydra-core==1.3.2 12 | einops==0.6.1 13 | pycox==0.2.3 14 | loguru 
-------------------------------------------------------------------------------- /robustness.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from typing import Union, List 3 | from collections import defaultdict 4 | 5 | # Third-party libraries 6 | import pandas as pd 7 | import hydra 8 | from omegaconf import DictConfig 9 | import torch 10 | import numpy as np 11 | from pycox.evaluation import EvalSurv 12 | 13 | # Local dependencies 14 | from drim.helpers import get_datasets, get_targets, get_encoder 15 | from drim.logger import logger 16 | from drim.utils import ( 17 | interpolate_dataframe, 18 | log_transform, 19 | seed_everything, 20 | get_dataframes, 21 | ) 22 | from drim.multimodal import MultimodalDataset, DRIMSurv, MultimodalModel 23 | from drim.datasets import SurvivalDataset 24 | from drim.models import MultimodalWrapper 25 | 26 | 27 | def prepare_data( 28 | dataframe: pd.DataFrame, modalities: Union[List[str], str] 29 | ) -> pd.DataFrame: 30 | available_modalities = ["DNAm", "WSI", "RNA", "MRI"] 31 | 32 | # columns = ['project_id', 'submitter_id'] 33 | # dataframe = dataframe.drop(columns=columns) 34 | # get remaining modalities 35 | remaining = [ 36 | modality for modality in available_modalities if modality not in modalities 37 | ] 38 | if len(modalities) == 1: 39 | try: 40 | mask = ~dataframe[modalities[0]].isna() 41 | dataframe = dataframe.loc[mask] 42 | except: 43 | pass 44 | 45 | else: 46 | temp_df = dataframe[modalities] 47 | dataframe = dataframe[temp_df.count(axis=1) >= len(modalities)] 48 | 49 | # put the whole columns of remaining modalities to NaN 50 | dataframe = dataframe.assign(**{modality: None for modality in remaining}) 51 | return dataframe 52 | 53 | 54 | @hydra.main(version_base=None, config_path="configs", config_name="robustness") 55 | def main(cfg: DictConfig) -> None: 56 | logger.info("Starting multimodal robustness test.") 57 | logger.info("Method tested {}.", cfg.method) 58 | modality_combinations = [ 59 | ["WSI", "MRI"], 60 | ["WSI", "RNA"], 61 | ["DNAm", "WSI"], 62 | ["DNAm", "RNA"], 63 | ["DNAm", "MRI"], 64 | ["RNA", "MRI"], 65 | ["DNAm", "WSI", "RNA"], 66 | ["DNAm", "WSI", "MRI"], 67 | ["DNAm", "RNA", "MRI"], 68 | ["WSI", "RNA", "MRI"], 69 | ["DNAm", "WSI", "RNA", "MRI"], 70 | ] 71 | for combination in modality_combinations: 72 | if cfg.method == "tensor": 73 | cfg.general.dim = 32 74 | logs = defaultdict(list) 75 | for fold in range(5): 76 | seed_everything(cfg.general.seed) 77 | dataframes = get_dataframes(fold) 78 | dataframes = { 79 | split: prepare_data(dataframe, combination) 80 | for split, dataframe in dataframes.items() 81 | } 82 | test_datasets = {} 83 | encoders = {} 84 | if cfg.method == "drim": 85 | encoders_u = {} 86 | 87 | for modality in ["DNAm", "WSI", "RNA", "MRI"]: 88 | datasets = get_datasets(dataframes, modality, fold, return_mask=True) 89 | test_datasets[modality] = datasets["test"] 90 | encoder = get_encoder(modality, cfg).cuda() 91 | encoders[modality] = encoder 92 | if cfg.method == "drim": 93 | encoder_u = get_encoder(modality, cfg).cuda() 94 | encoders_u[modality] = encoder_u 95 | 96 | targets, cut = get_targets(dataframes, cfg.general.n_outs) 97 | dataset_test = MultimodalDataset(test_datasets, return_mask=True) 98 | test_data = SurvivalDataset(dataset_test, *targets["test"]) 99 | loader = torch.utils.data.DataLoader( 100 | test_data, shuffle=False, batch_size=24 101 | ) 102 | if cfg.method == "drim": 103 | from drim.fusion import MaskedAttentionFusion 104 | 
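# For every modality combination and fold, the block below rebuilds the architecture
# used at training time (a DRIMSurv model with two MaskedAttentionFusion modules when
# cfg.method == "drim", a shallow or tensor fusion MultimodalModel otherwise), loads the
# matching checkpoint from ./models/, and evaluates it on the test split in which the
# modalities outside the current combination were blanked out by prepare_data.
# The score reported at the end is CS = (concordance index + (1 - integrated Brier score)) / 2.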
105 | fusion = MaskedAttentionFusion( 106 | dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128 107 | ) 108 | fusion_u = MaskedAttentionFusion( 109 | dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128 110 | ) 111 | fusion.cuda() 112 | fusion_u.cuda() 113 | encoder = DRIMSurv( 114 | encoders_sh=encoders, 115 | encoders_u=encoders_u, 116 | fusion_s=fusion, 117 | fusion_u=fusion_u, 118 | ) 119 | model = MultimodalWrapper( 120 | encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs 121 | ) 122 | model.load_state_dict( 123 | torch.load(f"./models/drimsurv_split_{int(fold)}.pth") 124 | ) 125 | else: 126 | if cfg.method == "max": 127 | from drim.fusion import ShallowFusion 128 | 129 | fusion = ShallowFusion("max") 130 | elif cfg.method == "tensor": 131 | from drim.fusion import TensorFusion 132 | 133 | fusion = TensorFusion( 134 | modalities=["DNAm", "WSI", "RNA", "MRI"], 135 | input_dim=cfg.general.dim, 136 | projected_dim=cfg.general.dim, 137 | output_dim=cfg.general.dim, 138 | dropout=0.0, 139 | ) 140 | elif cfg.method == "concat": 141 | from drim.fusion import ShallowFusion 142 | 143 | fusion = ShallowFusion("concat") 144 | 145 | fusion.cuda() 146 | encoder = MultimodalModel(encoders, fusion=fusion) 147 | if cfg.method == "concat": 148 | size = cfg.general.dim * 4 149 | else: 150 | size = cfg.general.dim 151 | model = MultimodalWrapper( 152 | encoder, embedding_dim=size, n_outs=cfg.general.n_outs 153 | ) 154 | if cfg.method == "max": 155 | prefix = "vanilla" 156 | else: 157 | prefix = "aux_contrastive" 158 | model.load_state_dict( 159 | torch.load(f"./models/{prefix}_{cfg.method}_split_{int(fold)}.pth") 160 | ) 161 | 162 | model.cuda() 163 | model.eval() 164 | hazards = [] 165 | times = [] 166 | events = [] 167 | with torch.no_grad(): 168 | for batch in loader: 169 | data, time, event = batch 170 | data, mask = data 171 | outputs = model(data, mask, return_embedding=False) 172 | hazards.append(outputs) 173 | times.append(time) 174 | events.append(event) 175 | 176 | times = torch.cat(times, dim=0).cpu().numpy() 177 | events = torch.cat(events, dim=0).cpu().numpy() 178 | hazards = interpolate_dataframe( 179 | pd.DataFrame( 180 | (1 - torch.cat(hazards, dim=0).sigmoid()) 181 | .add(1e-7) 182 | .log() 183 | .cumsum(1) 184 | .exp() 185 | .cpu() 186 | .numpy() 187 | .transpose(), 188 | cut, 189 | ) 190 | ) 191 | ev = EvalSurv(hazards, times, events, censor_surv="km") 192 | c_index = ev.concordance_td() 193 | ibs = ev.integrated_brier_score(np.linspace(0, times.max())) 194 | CS = (c_index + (1 - ibs)) / 2 195 | logs["c_index"].append(c_index) 196 | logs["ibs"].append(ibs) 197 | logs["CS"].append(CS) 198 | 199 | logger.info( 200 | f"{combination} - CS: {np.mean(logs['CS']):.3f} $\pm$ {np.std(logs['CS']):.3f}" 201 | ) 202 | 203 | 204 | if __name__ == "__main__": 205 | main() 206 | -------------------------------------------------------------------------------- /static/DRIM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/static/DRIM.png -------------------------------------------------------------------------------- /train_aux_multimodal.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from collections import defaultdict 3 | 4 | # Third-party libraries 5 | from omegaconf import DictConfig, OmegaConf 6 | import hydra 7 | from torch.utils.data import DataLoader 8 | import 
torch 9 | from pycox.models.loss import NLLLogistiHazardLoss 10 | import numpy as np 11 | 12 | # Local dependencies 13 | from drim.trainers import AuxSurvivalTrainer 14 | from drim.multimodal import MultimodalDataset 15 | from drim.multimodal import MultimodalModel 16 | from drim.models import MultimodalWrapper 17 | from drim.datasets import SurvivalDataset 18 | from drim.logger import logger 19 | from drim.utils import ( 20 | seed_everything, 21 | seed_worker, 22 | prepare_data, 23 | get_dataframes, 24 | log_transform, 25 | ) 26 | from drim.helpers import get_encoder, get_datasets, get_targets 27 | 28 | 29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal") 30 | def main(cfg: DictConfig) -> None: 31 | cv_metrics = defaultdict(list) 32 | # check if wandb key is in cfg 33 | if "wandb" in cfg: 34 | import wandb 35 | 36 | wandb_logging = True 37 | wandb.init( 38 | name=f"aux_{cfg.aux_loss.name}_{cfg.fusion.name}_" 39 | + "_".join(cfg.general.modalities), 40 | config={ 41 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb" 42 | }, 43 | **cfg.wandb, 44 | ) 45 | else: 46 | wandb_logging = False 47 | logger.info("Starting multimodal cross-validation.") 48 | logger.info("Modalities used: {}.", cfg.general.modalities) 49 | for fold in range(cfg.general.n_folds): 50 | logger.info("Starting fold {}", fold) 51 | seed_everything(cfg.general.seed) 52 | # Load the data 53 | dataframes = get_dataframes(fold) 54 | dataframes = { 55 | split: prepare_data(dataframe, cfg.general.modalities) 56 | for split, dataframe in dataframes.items() 57 | } 58 | cfg.general.save_path = ( 59 | f"./models/aux_{cfg.aux_loss.name}_{cfg.fusion.name}_split_{int(fold)}.pth" 60 | ) 61 | 62 | for split, dataframe in dataframes.items(): 63 | logger.info(f"{split} samples: {len(dataframe)}") 64 | 65 | train_datasets = {} 66 | val_datasets = {} 67 | test_datasets = {} 68 | encoders = {} 69 | logger.info("Loading models and preparing corresponding dataset...") 70 | for modality in cfg.general.modalities: 71 | encoder = get_encoder(modality, cfg).cuda() 72 | encoders[modality] = encoder 73 | datasets = get_datasets(dataframes, modality, fold, return_mask=True) 74 | train_datasets[modality] = datasets["train"] 75 | val_datasets[modality] = datasets["val"] 76 | test_datasets[modality] = datasets["test"] 77 | 78 | targets, cuts = get_targets(dataframes, cfg.general.n_outs) 79 | 80 | dataset_train = MultimodalDataset(train_datasets, return_mask=True) 81 | dataset_val = MultimodalDataset(val_datasets, return_mask=True) 82 | dataset_test = MultimodalDataset(test_datasets, return_mask=True) 83 | train_data = SurvivalDataset(dataset_train, *targets["train"]) 84 | val_data = SurvivalDataset(dataset_val, *targets["val"]) 85 | test_data = SurvivalDataset(dataset_test, *targets["test"]) 86 | 87 | dataloaders = { 88 | "train": DataLoader( 89 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader 90 | ), 91 | "val": DataLoader( 92 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 93 | ), 94 | "test": DataLoader( 95 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 96 | ), 97 | } 98 | 99 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]: 100 | from drim.fusion import ShallowFusion 101 | 102 | fusion = ShallowFusion(cfg.fusion.name) 103 | elif cfg.fusion.name == "maf": 104 | from drim.fusion import MaskedAttentionFusion 105 | 106 | fusion = MaskedAttentionFusion( 107 | dim=cfg.general.dim, dropout=cfg.general.dropout, 
**cfg.fusion.params 108 | ) 109 | elif cfg.fusion.name == "tensor": 110 | from drim.fusion import TensorFusion 111 | 112 | fusion = TensorFusion( 113 | modalities=cfg.general.modalities, 114 | input_dim=cfg.general.dim, 115 | projected_dim=cfg.general.dim, 116 | output_dim=cfg.general.dim, 117 | dropout=cfg.general.dropout, 118 | ) 119 | else: 120 | raise NotImplementedError 121 | 122 | fusion.cuda() 123 | 124 | encoder = MultimodalModel(encoders, fusion=fusion) 125 | if cfg.fusion.name == "concat": 126 | size = cfg.general.dim * len(cfg.general.modalities) 127 | else: 128 | size = cfg.general.dim 129 | 130 | model = MultimodalWrapper(encoder, size, n_outs=cfg.general.n_outs) 131 | # model = model.cuda() 132 | logger.info("Done!") 133 | 134 | # define optimizer and scheduler 135 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params) 136 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 137 | optimizer, **cfg.scheduler 138 | ) 139 | 140 | # define task criterion 141 | task_criterion = NLLLogistiHazardLoss() 142 | 143 | # define auxiliary criterion 144 | if cfg.aux_loss.name == "contrastive": 145 | from drim.losses import ContrastiveLoss 146 | 147 | aux_loss = ContrastiveLoss() 148 | elif cfg.aux_loss.name == "mmo": 149 | from drim.losses import MMOLoss 150 | 151 | aux_loss = MMOLoss() 152 | 153 | trainer = AuxSurvivalTrainer( 154 | model=model, 155 | optimizer=optimizer, 156 | scheduler=scheduler, 157 | dataloaders=dataloaders, 158 | task_criterion=task_criterion, 159 | cfg=cfg, 160 | wandb_logging=wandb_logging, 161 | cuts=cuts, 162 | aux_loss=aux_loss, 163 | ) 164 | 165 | trainer.fit() 166 | val_logs = trainer.evaluate("val") 167 | test_logs = trainer.evaluate("test") 168 | # add to cv_metrics 169 | for key, value in val_logs.items(): 170 | cv_metrics[key].append(value) 171 | 172 | for key, value in test_logs.items(): 173 | cv_metrics[key].append(value) 174 | 175 | logger.info("Fold {} done!", fold) 176 | 177 | # log first the mean ± std of the validation metrics 178 | logs = {} 179 | for key, value in cv_metrics.items(): 180 | if key in [ 181 | "test/c_index", 182 | "test/cs_score", 183 | "test/inbll", 184 | "test/ibs", 185 | "val/c_index", 186 | "val/cs_score", 187 | "val/inbll", 188 | "val/ibs", 189 | ]: 190 | mean, std = np.mean(value), np.std(value) 191 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}") 192 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean 193 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std 194 | 195 | if wandb_logging: 196 | wandb.log(logs) 197 | wandb.finish() 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /train_drimsurv.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from collections import defaultdict 3 | 4 | # Third-party libraries 5 | from omegaconf import DictConfig, OmegaConf 6 | import hydra 7 | from torch.utils.data import DataLoader 8 | import torch 9 | from pycox.models.loss import NLLLogistiHazardLoss 10 | import numpy as np 11 | 12 | # Local dependencies 13 | from drim.trainers import DRIMSurvTrainer 14 | from drim.multimodal import MultimodalDataset, DRIMSurv 15 | from drim.models import MultimodalWrapper, Discriminator 16 | from drim.losses import ContrastiveLoss 17 | from drim.datasets import SurvivalDataset 18 | from drim.logger import logger 19 | from drim.utils import ( 20 | seed_everything, 21 | seed_worker, 22 | prepare_data, 23 | get_dataframes, 24 
| log_transform, 25 | ) 26 | from drim.helpers import get_encoder, get_datasets, get_targets 27 | 28 | 29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal") 30 | def main(cfg: DictConfig) -> None: 31 | cv_metrics = defaultdict(list) 32 | # check if wandb key is in cfg 33 | if "wandb" in cfg: 34 | import wandb 35 | 36 | wandb_logging = True 37 | wandb.init( 38 | name="DRIMSurv_" + "_".join(cfg.general.modalities), 39 | config={ 40 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb" 41 | }, 42 | **cfg.wandb, 43 | ) 44 | else: 45 | wandb_logging = False 46 | logger.info("Starting multimodal cross-validation.") 47 | logger.info("Modalities used: {}.", cfg.general.modalities) 48 | for fold in range(cfg.general.n_folds): 49 | logger.info("Starting fold {}", fold) 50 | seed_everything(cfg.general.seed) 51 | # Load the data 52 | dataframes = get_dataframes(fold) 53 | dataframes = { 54 | split: prepare_data(dataframe, cfg.general.modalities) 55 | for split, dataframe in dataframes.items() 56 | } 57 | cfg.general.save_path = f"./models/drimsurv_split_{int(fold)}.pth" 58 | 59 | for split, dataframe in dataframes.items(): 60 | logger.info(f"{split} samples: {len(dataframe)}") 61 | 62 | train_datasets = {} 63 | val_datasets = {} 64 | test_datasets = {} 65 | encoders = {} 66 | encoders_u = {} 67 | logger.info("Loading models and preparing corresponding dataset...") 68 | for modality in cfg.general.modalities: 69 | encoder = get_encoder(modality, cfg).cuda() 70 | encoders[modality] = encoder 71 | encoder_u = get_encoder(modality, cfg).cuda() 72 | encoders_u[modality] = encoder_u 73 | datasets = get_datasets(dataframes, modality, fold, return_mask=True) 74 | train_datasets[modality] = datasets["train"] 75 | val_datasets[modality] = datasets["val"] 76 | test_datasets[modality] = datasets["test"] 77 | 78 | targets, cuts = get_targets(dataframes, cfg.general.n_outs) 79 | 80 | dataset_train = MultimodalDataset(train_datasets, return_mask=True) 81 | dataset_val = MultimodalDataset(val_datasets, return_mask=True) 82 | dataset_test = MultimodalDataset(test_datasets, return_mask=True) 83 | train_data = SurvivalDataset(dataset_train, *targets["train"]) 84 | val_data = SurvivalDataset(dataset_val, *targets["val"]) 85 | test_data = SurvivalDataset(dataset_test, *targets["test"]) 86 | 87 | dataloaders = { 88 | "train": DataLoader( 89 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader 90 | ), 91 | "val": DataLoader( 92 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 93 | ), 94 | "test": DataLoader( 95 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 96 | ), 97 | } 98 | 99 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]: 100 | from drim.fusion import ShallowFusion 101 | 102 | fusion = ShallowFusion(cfg.fusion.name) 103 | fusion_u = ShallowFusion(cfg.fusion.name) 104 | elif cfg.fusion.name == "maf": 105 | from drim.fusion import MaskedAttentionFusion 106 | 107 | fusion = MaskedAttentionFusion( 108 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params 109 | ) 110 | fusion_u = MaskedAttentionFusion( 111 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params 112 | ) 113 | elif cfg.fusion.name == "tensor": 114 | from drim.fusion import TensorFusion 115 | 116 | fusion = TensorFusion( 117 | modalities=cfg.general.modalities, 118 | input_dim=cfg.general.dim, 119 | projected_dim=cfg.general.dim, 120 | output_dim=cfg.general.dim, 121 | 
dropout=cfg.general.dropout, 122 | ) 123 | fusion_u = TensorFusion( 124 | modalities=cfg.general.modalities + ["shared"], 125 | input_dim=cfg.general.dim, 126 | projected_dim=cfg.general.dim, 127 | output_dim=cfg.general.dim, 128 | dropout=cfg.general.dropout, 129 | ) 130 | else: 131 | raise NotImplementedError 132 | 133 | fusion.cuda() 134 | fusion_u.cuda() 135 | 136 | encoder = DRIMSurv( 137 | encoders_sh=encoders, 138 | encoders_u=encoders_u, 139 | fusion_s=fusion, 140 | fusion_u=fusion_u, 141 | ) 142 | 143 | model = MultimodalWrapper( 144 | encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs 145 | ) 146 | logger.info("Done!") 147 | 148 | logger.info("Preparing discriminators..") 149 | discriminators = { 150 | k: Discriminator( 151 | embedding_dim=cfg.general.dim * 2, dropout=cfg.general.dropout 152 | ).to("cuda") 153 | for k in encoders.keys() 154 | } 155 | 156 | # define optimizer and scheduler 157 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params) 158 | optimizers_dsm = { 159 | k: torch.optim.AdamW( 160 | discriminators[k].parameters(), 161 | lr=cfg.disentangled.dsm_lr, 162 | weight_decay=cfg.disentangled.dsm_wd, 163 | ) 164 | for k in encoders.keys() 165 | } 166 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 167 | optimizer, **cfg.scheduler 168 | ) 169 | 170 | # define task criterion 171 | task_criterion = NLLLogistiHazardLoss() 172 | 173 | # define auxiliary criterion 174 | aux_loss = ContrastiveLoss() 175 | 176 | trainer = DRIMSurvTrainer( 177 | model=model, 178 | discriminators=discriminators, 179 | optimizer=optimizer, 180 | optimizers_dsm=optimizers_dsm, 181 | scheduler=scheduler, 182 | dataloaders=dataloaders, 183 | task_criterion=task_criterion, 184 | aux_loss=aux_loss, 185 | cfg=cfg, 186 | wandb_logging=wandb_logging, 187 | cuts=cuts, 188 | ) 189 | 190 | trainer.fit() 191 | val_logs = trainer.evaluate("val") 192 | test_logs = trainer.evaluate("test") 193 | # add to cv_metrics 194 | for key, value in val_logs.items(): 195 | cv_metrics[key].append(value) 196 | 197 | for key, value in test_logs.items(): 198 | cv_metrics[key].append(value) 199 | 200 | logger.info("Fold {} done!", fold) 201 | 202 | # log first the mean ± std of the validation metrics 203 | logs = {} 204 | for key, value in cv_metrics.items(): 205 | if key in [ 206 | "test/c_index", 207 | "test/cs_score", 208 | "test/inbll", 209 | "test/ibs", 210 | "val/c_index", 211 | "val/cs_score", 212 | "val/inbll", 213 | "val/ibs", 214 | ]: 215 | mean, std = np.mean(value), np.std(value) 216 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}") 217 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean 218 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std 219 | 220 | if wandb_logging: 221 | wandb.log(logs) 222 | wandb.finish() 223 | 224 | 225 | if __name__ == "__main__": 226 | main() 227 | -------------------------------------------------------------------------------- /train_drimu.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from collections import defaultdict 3 | 4 | # Third-party libraries 5 | from omegaconf import DictConfig, OmegaConf 6 | import hydra 7 | from torch.utils.data import DataLoader 8 | import torch 9 | from pycox.models.loss import NLLLogistiHazardLoss 10 | import numpy as np 11 | 12 | # Local dependencies 13 | from drim.trainers import DRIMUTrainer 14 | from drim.multimodal import MultimodalDataset, DRIMU 15 | from drim.models import MultimodalWrapper, Discriminator 16 | from drim.losses import 
ContrastiveLoss 17 | from drim.datasets import SurvivalDataset 18 | from drim.logger import logger 19 | from drim.utils import ( 20 | seed_everything, 21 | seed_worker, 22 | prepare_data, 23 | get_dataframes, 24 | log_transform, 25 | ) 26 | from drim.helpers import get_encoder, get_datasets, get_targets, get_decoder 27 | 28 | 29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal") 30 | def main(cfg: DictConfig) -> None: 31 | cv_metrics = defaultdict(list) 32 | # check if wandb key is in cfg 33 | if "wandb" in cfg: 34 | import wandb 35 | 36 | wandb_logging = True 37 | wandb.init( 38 | name="DRIMU_" + "_".join(cfg.general.modalities), 39 | config={ 40 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb" 41 | }, 42 | **cfg.wandb, 43 | ) 44 | else: 45 | wandb_logging = False 46 | logger.info("Starting multimodal cross-validation.") 47 | logger.info("Modalities used: {}.", cfg.general.modalities) 48 | for fold in range(cfg.general.n_folds): 49 | logger.info("Starting fold {}", fold) 50 | seed_everything(cfg.general.seed) 51 | # Load the data 52 | dataframes = get_dataframes(fold) 53 | dataframes = { 54 | split: prepare_data(dataframe, cfg.general.modalities) 55 | for split, dataframe in dataframes.items() 56 | } 57 | cfg.general.save_path = f"./models/drimu_split_{int(fold)}.pth" 58 | 59 | for split, dataframe in dataframes.items(): 60 | logger.info(f"{split} samples: {len(dataframe)}") 61 | 62 | train_datasets = {} 63 | val_datasets = {} 64 | test_datasets = {} 65 | encoders = {} 66 | encoders_u = {} 67 | decoders = {} 68 | logger.info("Loading models and preparing corresponding dataset...") 69 | for modality in cfg.general.modalities: 70 | encoder = get_encoder(modality, cfg).cuda() 71 | encoders[modality] = encoder 72 | encoder_u = get_encoder(modality, cfg).cuda() 73 | encoders_u[modality] = encoder_u 74 | decoders[modality] = get_decoder(modality, cfg).cuda() 75 | datasets = get_datasets(dataframes, modality, fold, return_mask=True) 76 | train_datasets[modality] = datasets["train"] 77 | val_datasets[modality] = datasets["val"] 78 | test_datasets[modality] = datasets["test"] 79 | 80 | targets, cuts = get_targets(dataframes, cfg.general.n_outs) 81 | 82 | dataset_train = MultimodalDataset(train_datasets, return_mask=True) 83 | dataset_val = MultimodalDataset(val_datasets, return_mask=True) 84 | dataset_test = MultimodalDataset(test_datasets, return_mask=True) 85 | train_data = SurvivalDataset(dataset_train, *targets["train"]) 86 | val_data = SurvivalDataset(dataset_val, *targets["val"]) 87 | test_data = SurvivalDataset(dataset_test, *targets["test"]) 88 | 89 | dataloaders = { 90 | "train": DataLoader( 91 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader 92 | ), 93 | "val": DataLoader( 94 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 95 | ), 96 | "test": DataLoader( 97 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 98 | ), 99 | } 100 | 101 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]: 102 | from drim.fusion import ShallowFusion 103 | 104 | fusion = ShallowFusion(cfg.fusion.name) 105 | fusion_u = ShallowFusion(cfg.fusion.name) 106 | elif cfg.fusion.name == "maf": 107 | from drim.fusion import MaskedAttentionFusion 108 | 109 | fusion = MaskedAttentionFusion( 110 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params 111 | ) 112 | fusion_u = MaskedAttentionFusion( 113 | dim=cfg.general.dim, dropout=cfg.general.dropout, 
**cfg.fusion.params 114 | ) 115 | elif cfg.fusion.name == "tensor": 116 | from drim.fusion import TensorFusion 117 | 118 | fusion = TensorFusion( 119 | modalities=cfg.general.modalities, 120 | input_dim=cfg.general.dim, 121 | projected_dim=cfg.general.dim, 122 | output_dim=cfg.general.dim, 123 | dropout=cfg.general.dropout, 124 | ) 125 | fusion_u = TensorFusion( 126 | modalities=cfg.general.modalities, 127 | input_dim=cfg.general.dim, 128 | projected_dim=cfg.general.dim, 129 | output_dim=cfg.general.dim, 130 | dropout=cfg.general.dropout, 131 | ) 132 | else: 133 | raise NotImplementedError 134 | 135 | fusion.cuda() 136 | fusion_u.cuda() 137 | 138 | encoder = DRIMU( 139 | encoders_sh=encoders, 140 | encoders_u=encoders_u, 141 | fusion_s=fusion, 142 | fusion_u=fusion_u, 143 | ) 144 | 145 | model = MultimodalWrapper( 146 | encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs 147 | ) 148 | logger.info("Done!") 149 | 150 | logger.info("Preparing discriminators..") 151 | discriminators = { 152 | k: Discriminator( 153 | embedding_dim=cfg.general.dim * 2, dropout=cfg.general.dropout 154 | ).to("cuda") 155 | for k in encoders.keys() 156 | } 157 | 158 | # define optimizer and scheduler 159 | decoder_parameters = [] 160 | for k in decoders.keys(): 161 | decoder_parameters += list(decoders[k].parameters()) 162 | optimizer = torch.optim.AdamW( 163 | list(model.parameters()) + decoder_parameters, **cfg.optimizer.params 164 | ) 165 | optimizers_dsm = { 166 | k: torch.optim.AdamW( 167 | discriminators[k].parameters(), 168 | lr=cfg.disentangled.dsm_lr, 169 | weight_decay=cfg.disentangled.dsm_wd, 170 | ) 171 | for k in encoders.keys() 172 | } 173 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 174 | optimizer, eta_min=cfg.optimizer.params.lr, T_max=cfg.general.epochs 175 | ) 176 | 177 | task_criterion = NLLLogistiHazardLoss() 178 | # define auxiliary criterion 179 | aux_loss = ContrastiveLoss() 180 | 181 | trainer = DRIMUTrainer( 182 | model, 183 | decoders=decoders, 184 | discriminators=discriminators, 185 | optimizer=optimizer, 186 | optimizers_dsm=optimizers_dsm, 187 | scheduler=scheduler, 188 | task_criterion=task_criterion, 189 | dataloaders=dataloaders, 190 | cfg=cfg, 191 | cuts=cuts, 192 | wandb_logging=wandb_logging, 193 | aux_loss=aux_loss, 194 | ) 195 | 196 | trainer.fit() 197 | trainer.finetune() 198 | val_logs = trainer.evaluate("val") 199 | test_logs = trainer.evaluate("test") 200 | # add to cv_metrics 201 | for key, value in val_logs.items(): 202 | cv_metrics[key].append(value) 203 | 204 | for key, value in test_logs.items(): 205 | cv_metrics[key].append(value) 206 | 207 | logger.info("Fold {} done!", fold) 208 | 209 | # log first the mean ± std of the validation metrics 210 | logs = {} 211 | for key, value in cv_metrics.items(): 212 | if key in [ 213 | "test/c_index", 214 | "test/cs_score", 215 | "test/inbll", 216 | "test/ibs", 217 | "val/c_index", 218 | "val/cs_score", 219 | "val/inbll", 220 | "val/ibs", 221 | ]: 222 | mean, std = np.mean(value), np.std(value) 223 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}") 224 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean 225 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std 226 | 227 | if wandb_logging: 228 | wandb.log(logs) 229 | wandb.finish() 230 | 231 | 232 | if __name__ == "__main__": 233 | main() 234 | -------------------------------------------------------------------------------- /train_unimodal.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from 
collections import defaultdict 3 | 4 | # Third-party libraries 5 | import hydra 6 | from omegaconf import DictConfig, OmegaConf 7 | from torch.utils.data import DataLoader 8 | import torch 9 | from pycox.models.loss import NLLLogistiHazardLoss 10 | import numpy as np 11 | 12 | # Local dependencies 13 | from drim.models import UnimodalWrapper 14 | from drim.trainers import BaseSurvivalTrainer 15 | from drim.datasets import SurvivalDataset 16 | from drim.logger import logger 17 | from drim.utils import ( 18 | seed_everything, 19 | seed_worker, 20 | prepare_data, 21 | get_dataframes, 22 | log_transform, 23 | ) 24 | from drim.helpers import get_encoder, get_datasets, get_targets 25 | 26 | 27 | @hydra.main(version_base=None, config_path="configs", config_name="unimodal") 28 | def main(cfg: DictConfig) -> None: 29 | cv_metrics = defaultdict(list) 30 | # check if wandb key is in cfg 31 | if "wandb" in cfg: 32 | import wandb 33 | 34 | wandb_logging = True 35 | wandb.init( 36 | name=cfg.general.modalities, 37 | config={ 38 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb" 39 | }, 40 | **cfg.wandb, 41 | ) 42 | else: 43 | wandb_logging = False 44 | 45 | logger.info("Starting unimodal cross-validation.") 46 | logger.info("Modality used: {}.", cfg.general.modalities) 47 | for fold in range(cfg.general.n_folds): 48 | logger.info("Starting fold {}", fold) 49 | seed_everything(cfg.general.seed) 50 | # Load the data 51 | dataframes = get_dataframes(fold) 52 | # take only the intersection between multimodal data and unimodal to ensure fair comparisons 53 | dataframes_multi = { 54 | split: prepare_data(dataframe, ["DNAm", "WSI", "RNA", "MRI"]) 55 | for split, dataframe in dataframes.items() 56 | } 57 | 58 | dataframes = { 59 | split: prepare_data(dataframe, cfg.general.modalities) 60 | for split, dataframe in dataframes.items() 61 | } 62 | 63 | dataframes = { 64 | split: dataframe[ 65 | dataframe["submitter_id"].isin(dataframes_multi[split]["submitter_id"]) 66 | ] 67 | for split, dataframe in dataframes.items() 68 | } 69 | cfg.general.save_path = ( 70 | f"./models/{cfg.general.modalities}_split_{int(fold)}.pth" 71 | ) 72 | for split, dataframe in dataframes.items(): 73 | logger.info(f"{split} samples: {len(dataframe)}") 74 | 75 | # Load the model 76 | logger.info("Loading model and preparing corresponding dataset...") 77 | encoder = get_encoder(cfg.general.modalities, cfg) 78 | datasets = get_datasets( 79 | dataframes, cfg.general.modalities, fold, return_mask=False 80 | ) 81 | targets, cuts = get_targets(dataframes, cfg.general.n_outs) 82 | train_data = SurvivalDataset(datasets["train"], *targets["train"]) 83 | val_data = SurvivalDataset(datasets["val"], *targets["val"]) 84 | test_data = SurvivalDataset(datasets["test"], *targets["test"]) 85 | 86 | dataloaders = { 87 | "train": DataLoader( 88 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader 89 | ), 90 | "val": DataLoader( 91 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 92 | ), 93 | "test": DataLoader( 94 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 95 | ), 96 | } 97 | 98 | model = UnimodalWrapper(encoder, cfg.general.dim, n_outs=cfg.general.n_outs) 99 | model = model.cuda() 100 | logger.info("Done!") 101 | 102 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params) 103 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 104 | optimizer, **cfg.scheduler 105 | ) 106 | 107 | task_criterion = NLLLogistiHazardLoss() 108 | trainer = 
BaseSurvivalTrainer( 109 | model=model, 110 | optimizer=optimizer, 111 | scheduler=scheduler, 112 | dataloaders=dataloaders, 113 | task_criterion=task_criterion, 114 | cfg=cfg, 115 | wandb_logging=wandb_logging, 116 | cuts=cuts, 117 | ) 118 | 119 | trainer.fit() 120 | val_logs = trainer.evaluate("val") 121 | test_logs = trainer.evaluate("test") 122 | # add to cv_metrics 123 | for key, value in val_logs.items(): 124 | cv_metrics[key].append(value) 125 | 126 | for key, value in test_logs.items(): 127 | cv_metrics[key].append(value) 128 | 129 | logger.info("Fold {} done!", fold) 130 | 131 | # log first the mean ± std of the validation metrics 132 | logs = {} 133 | for key, value in cv_metrics.items(): 134 | if key in [ 135 | "test/c_index", 136 | "test/cs_score", 137 | "test/inbll", 138 | "test/ibs", 139 | "val/c_index", 140 | "val/cs_score", 141 | "val/inbll", 142 | "val/ibs", 143 | ]: 144 | mean, std = np.mean(value), np.std(value) 145 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}") 146 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean 147 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std 148 | 149 | if wandb_logging: 150 | wandb.log(logs) 151 | wandb.finish() 152 | 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /train_vanilla_multimodal.py: -------------------------------------------------------------------------------- 1 | # Standard libraries 2 | from collections import defaultdict 3 | 4 | # Third-party libraries 5 | from omegaconf import DictConfig, OmegaConf 6 | import hydra 7 | from torch.utils.data import DataLoader 8 | import torch 9 | from pycox.models.loss import NLLLogistiHazardLoss 10 | import numpy as np 11 | 12 | # Local dependencies 13 | from drim.trainers import BaseSurvivalTrainer 14 | from drim.multimodal import MultimodalDataset 15 | from drim.multimodal import MultimodalModel 16 | from drim.models import MultimodalWrapper 17 | from drim.datasets import SurvivalDataset 18 | from drim.logger import logger 19 | from drim.utils import ( 20 | seed_everything, 21 | seed_worker, 22 | prepare_data, 23 | get_dataframes, 24 | log_transform, 25 | ) 26 | from drim.helpers import get_encoder, get_datasets, get_targets 27 | 28 | 29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal") 30 | def main(cfg: DictConfig) -> None: 31 | cv_metrics = defaultdict(list) 32 | # check if wandb key is in cfg 33 | if "wandb" in cfg: 34 | import wandb 35 | 36 | wandb_logging = True 37 | wandb.init( 38 | name=f"vanilla_{cfg.fusion.name}_" + "_".join(cfg.general.modalities), 39 | config={ 40 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb" 41 | }, 42 | **cfg.wandb, 43 | ) 44 | else: 45 | wandb_logging = False 46 | logger.info("Starting multimodal cross-validation.") 47 | logger.info("Modalities used: {}.", cfg.general.modalities) 48 | for fold in range(cfg.general.n_folds): 49 | logger.info("Starting fold {}", fold) 50 | seed_everything(cfg.general.seed) 51 | # Load the data 52 | dataframes = get_dataframes(fold) 53 | dataframes = { 54 | split: prepare_data(dataframe, cfg.general.modalities) 55 | for split, dataframe in dataframes.items() 56 | } 57 | cfg.general.save_path = ( 58 | f"./models/vanilla_{cfg.fusion.name}_split_{int(fold)}.pth" 59 | ) 60 | 61 | for split, dataframe in dataframes.items(): 62 | logger.info(f"{split} samples: {len(dataframe)}") 63 | 64 | train_datasets = {} 65 | val_datasets = {} 66 | test_datasets = {} 67 | encoders = {} 68 | 
logger.info("Loading models and preparing corresponding dataset...") 69 | for modality in cfg.general.modalities: 70 | encoder = get_encoder(modality, cfg).cuda() 71 | encoders[modality] = encoder 72 | datasets = get_datasets(dataframes, modality, fold, return_mask=True) 73 | train_datasets[modality] = datasets["train"] 74 | val_datasets[modality] = datasets["val"] 75 | test_datasets[modality] = datasets["test"] 76 | 77 | targets, cuts = get_targets(dataframes, cfg.general.n_outs) 78 | 79 | dataset_train = MultimodalDataset(train_datasets, return_mask=True) 80 | dataset_val = MultimodalDataset(val_datasets, return_mask=True) 81 | dataset_test = MultimodalDataset(test_datasets, return_mask=True) 82 | train_data = SurvivalDataset(dataset_train, *targets["train"]) 83 | val_data = SurvivalDataset(dataset_val, *targets["val"]) 84 | test_data = SurvivalDataset(dataset_test, *targets["test"]) 85 | 86 | dataloaders = { 87 | "train": DataLoader( 88 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader 89 | ), 90 | "val": DataLoader( 91 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 92 | ), 93 | "test": DataLoader( 94 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader 95 | ), 96 | } 97 | 98 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]: 99 | from drim.fusion import ShallowFusion 100 | 101 | fusion = ShallowFusion(cfg.fusion.name) 102 | elif cfg.fusion.name == "maf": 103 | from drim.fusion import MaskedAttentionFusion 104 | 105 | fusion = MaskedAttentionFusion( 106 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params 107 | ) 108 | elif cfg.fusion.name == "tensor": 109 | from drim.fusion import TensorFusion 110 | 111 | fusion = TensorFusion( 112 | modalities=cfg.general.modalities, 113 | input_dim=cfg.general.dim, 114 | projected_dim=cfg.general.dim, 115 | output_dim=cfg.general.dim, 116 | dropout=cfg.general.dropout, 117 | ) 118 | else: 119 | raise NotImplementedError 120 | 121 | fusion.cuda() 122 | 123 | encoder = MultimodalModel(encoders, fusion=fusion) 124 | if cfg.fusion.name == "concat": 125 | size = cfg.general.dim * len(cfg.general.modalities) 126 | else: 127 | size = cfg.general.dim 128 | 129 | model = MultimodalWrapper(encoder, size, n_outs=cfg.general.n_outs) 130 | # model = model.cuda() 131 | logger.info("Done!") 132 | 133 | # define optimizer and scheduler 134 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params) 135 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 136 | optimizer, **cfg.scheduler 137 | ) 138 | 139 | # define task criterion 140 | task_criterion = NLLLogistiHazardLoss() 141 | trainer = BaseSurvivalTrainer( 142 | model=model, 143 | optimizer=optimizer, 144 | scheduler=scheduler, 145 | dataloaders=dataloaders, 146 | task_criterion=task_criterion, 147 | cfg=cfg, 148 | wandb_logging=wandb_logging, 149 | cuts=cuts, 150 | ) 151 | 152 | trainer.fit() 153 | val_logs = trainer.evaluate("val") 154 | test_logs = trainer.evaluate("test") 155 | # add to cv_metrics 156 | for key, value in val_logs.items(): 157 | cv_metrics[key].append(value) 158 | 159 | for key, value in test_logs.items(): 160 | cv_metrics[key].append(value) 161 | 162 | logger.info("Fold {} done!", fold) 163 | 164 | # log first the mean ± std of the validation metrics 165 | logs = {} 166 | for key, value in cv_metrics.items(): 167 | if key in [ 168 | "test/c_index", 169 | "test/cs_score", 170 | "test/inbll", 171 | "test/ibs", 172 | "val/c_index", 173 | "val/cs_score", 174 | 
"val/inbll", 175 | "val/ibs", 176 | ]: 177 | mean, std = np.mean(value), np.std(value) 178 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}") 179 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean 180 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std 181 | 182 | if wandb_logging: 183 | wandb.log(logs) 184 | wandb.finish() 185 | 186 | 187 | if __name__ == "__main__": 188 | main() 189 | --------------------------------------------------------------------------------