├── LICENSE
├── README.md
├── clustering.ipynb
├── configs
│   ├── multimodal.yaml
│   ├── robustness.yaml
│   └── unimodal.yaml
├── data
│   ├── README.md
│   ├── extract_brain_embeddings.py
│   ├── files
│   │   ├── clinical_data.tsv
│   │   └── supplementary_data.tsv
│   ├── get_wsi_thumbnails.py
│   ├── mappings
│   │   └── wsi_mapping.json
│   ├── preprocessing.ipynb
│   ├── rna_preprocessors
│   │   ├── trf_0.joblib
│   │   ├── trf_1.joblib
│   │   ├── trf_2.joblib
│   │   ├── trf_3.joblib
│   │   └── trf_4.joblib
│   ├── run_mri_pretraining.py
│   ├── run_wsi_pretraining.py
│   └── scripts
│       ├── gdc_manifest_20230918_WSI_LGG.txt
│       ├── gdc_manifest_20231124_WSI_GBM.txt
│       ├── gdc_manifest_20231204_RNA_GBM_LGG.txt
│       └── gdc_manifest_20231205_DNAm_GBM_LGG.txt
├── drim
│   ├── __init__.py
│   ├── commons
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   └── optim_contrastive.py
│   ├── datasets.py
│   ├── dnam
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   └── models.py
│   ├── fusion
│   │   ├── __init__.py
│   │   ├── fusion.py
│   │   └── utils.py
│   ├── helpers.py
│   ├── logger.py
│   ├── losses.py
│   ├── models.py
│   ├── mri
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   ├── models.py
│   │   └── transforms.py
│   ├── multimodal
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   └── models.py
│   ├── rna
│   │   ├── __init__.py
│   │   ├── datasets.py
│   │   └── models.py
│   ├── trainers.py
│   ├── utils.py
│   └── wsi
│       ├── __init__.py
│       ├── datasets.py
│       ├── func.py
│       ├── models.py
│       └── transforms.py
├── requirements.txt
├── robustness.py
├── static
│   └── DRIM.png
├── train_aux_multimodal.py
├── train_drimsurv.py
├── train_drimu.py
├── train_unimodal.py
└── train_vanilla_multimodal.py
/README.md:
--------------------------------------------------------------------------------
1 | # DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data
2 |
3 | **[Go to pdf](https://papers.miccai.org/miccai-2024/paper/1276_paper.pdf)**
4 | This is the code associated with the paper **DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data**, accepted at [MICCAI2024](https://conferences.miccai.org/2024/en/).
5 |
6 | 
7 | __________________
8 | # Data Preprocessing
9 |
10 | Navigate to the `data` folder to handle downloading and preprocessing the data.
11 | The file-size limit of the supplementary material does not allow the WSI pre-trained models to be supplied, so only the MRI pre-trained model is provided in `data/models/`.
12 |
13 | **N.B**: The data used comes from the TCGA-GBMLGG cohort and must have the following structure:
14 |
15 | ```
16 | ├── GBMLGG
17 | │   ├── MRI
18 | │   ├── WSI
19 | │   ├── DNAm
20 | │   └── RNA
21 | ```
22 | The `data` folder must also contain a `files/train_brain.csv` and a `files/test_brain.csv` listing, for each sample and each modality, the path to the corresponding file (a minimal sketch is given below). Again, check the `README.md` in the `data` folder for further explanations.
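A minimal sketch of what is expected (run from the repository root; column names other than the four modalities and the survival targets are assumptions used here only for illustration):
```
import pandas as pd

df = pd.read_csv("data/files/train_brain.csv")
print(df.columns.tolist())
# e.g. something like: ['patient_id', 'DNAm', 'WSI', 'RNA', 'MRI', 'time', 'event']
# where each modality column stores the path to the corresponding file.
```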
23 |
24 | ________
25 | # Quickstart
26 | Start by creating a dedicated conda environment and installing all the required packages
27 | ```
28 | $ conda create python=3.10.10 --name drim
29 | $ conda activate drim
30 | $ pip install -r requirements.txt
31 | ```
32 | **N.B:** This repository relies on [Hydra](https://hydra.cc/docs/intro/) to run experiments and benefits from all the advantages it offers (*multiruns, sweeps, flexibility, ...*).
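For instance, any key defined in the `configs/*.yaml` files can be overridden directly from the command line (the values below are illustrative):
```
$ python train_unimodal.py general.modalities=RNA general.epochs=10 optimizer.params.lr=5e-4
```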
33 |
34 | ________
35 | # Training loops
36 | Here are a number of scripts for training and cross-validation.
37 |
38 | **Unimodal**
39 |
40 | For example, to launch a unimodal training session on each of the modalities, using Hydra's features, you can enter the following in the CLI:
41 | ```
42 | $ python train_unimodal.py -m general.modalities=DNAm,WSI,RNA,MRI
43 | ```
44 |
45 | **Simple Multimodal**
46 |
47 | To use a simple fusion architecture, with a free choice for each of $h_u$ and $h_s$:
48 | ```
49 | $ python train_vanilla_multimodal.py fusion.name=maf
50 | ```
51 | And once again, to sweep over all the available fusion functions:
52 | ```
53 | $ python train_vanilla_multimodal.py -m fusion.name=mean,concat,masked_mean,tensor,sum,max,maf
54 | ```
55 |
56 | **Auxiliary Multimodal**
57 |
58 | In this repository, in addition to the fusion, you can apply a cost function to the $s_i^m$ representations. This can be either the cost function used in DRIM, directly adapted from the Supervised Contrastive loss (1), or the MMO loss (2), which tends to orthogonalise the representations.
59 | (See the paper for more details on the notation.)
60 | ```
61 | $ python train_aux_multimodal.py aux_loss.name=contrastive
62 | ```
63 | ```
64 | $ python train_aux_multimodal.py aux_loss.name=mmo aux_loss.alpha=0.5
65 | ```
66 | ## DRIM
67 | The DRIM method is very flexible and can be adapted to any type of task. For example, we offer two alternatives: DRIMSurv and DRIMU.
68 |
69 |
70 | **DRIMSurv**
71 |
72 | DRIMSurv enables end-to-end training of a survival model by training the encoders at the same time as the higher attention blocks needed to merge the various representations (Fig. 1c).
73 | ```
74 | $ python train_drimsurv.py
75 | ```
76 |
77 | **DRIMU**
78 |
79 | The other alternative, DRIMU, does not require labels and is trained using a reconstruction scheme. The code provided allows the backbone encoders to be pre-trained and then frozen for fine-tuning over 10 epochs on a survival task, with a fusion scheme similar to DRIMSurv (Fig. 1b).
80 | ```
81 | $ python train_drimu.py
82 | ```
83 |
84 | ## Robustness
85 | Once each model has been trained on the 5 splits, it is possible to test how it performs and how robust it is to different combinations of input modalities. Simply run the `robustness.py` script, which is based on the `configs/robustness.yaml` configuration file, and look at the reported CS score.
86 | To check all available methods:
87 | ```
88 | $ python robustness.py -m method=drim,max,tensor,concat
89 | ```
90 |
91 | ## Stratifying high-risk vs low-risk patients
92 | All the code related to this stratification, as well as the results of the log-rank test, can be found in the notebook `clustering.ipynb`.
93 |
94 | ________
95 | # Logging
96 | All our training sessions were logged on [Weights and Biases](https://wandb.ai/), and you can use it to track the progress of your own runs. To do this, add your entity and project to the configuration files `configs/unimodal.yaml` and `configs/multimodal.yaml`:
97 |
98 | ```
99 | wandb:
100 |   entity: MyAwesomeNickname
101 |   project: MyAwesomeProject
102 | ```
103 | Otherwise, training progression and evaluation will be displayed directly in the console.
104 |
105 | ________
106 | # Citation
107 | ```
108 | @InProceedings{robinet_drim_2024,
109 |     author    = {Robinet, Lucas and Berjaoui, Ahmad and Kheil, Ziad and Cohen-Jonathan Moyal, Elizabeth},
110 |     title     = {{DRIM: Learning Disentangled Representations from Incomplete Multimodal Healthcare Data}},
111 |     booktitle = {Proceedings of Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024},
112 |     year      = {2024},
113 |     publisher = {Springer Nature Switzerland},
114 | }
115 | ```
116 |
--------------------------------------------------------------------------------
/clustering.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "d23a603f",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "! pip install lifelines\n",
11 | "! pip install scikit-survival"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": null,
17 | "id": "bb41018f",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "# Standard libraries\n",
22 | "from collections import defaultdict\n",
23 | "\n",
24 | "# Third-party libraries\n",
25 | "import pandas as pd\n",
26 | "import torch\n",
27 | "import numpy as np\n",
28 | "from omegaconf import DictConfig, OmegaConf\n",
29 | "from sksurv.nonparametric import kaplan_meier_estimator\n",
30 | "import matplotlib.pyplot as plt\n",
31 | "from lifelines.statistics import logrank_test\n",
32 | "\n",
33 | "# Local dependencies\n",
34 | "from drim.helpers import get_datasets, get_targets, get_encoder\n",
35 | "from drim.utils import log_transform, seed_everything, get_dataframes, prepare_data, interpolate_dataframe\n",
36 | "from drim.multimodal import MultimodalDataset, DRIMSurv, MultimodalModel\n",
37 | "from drim.datasets import SurvivalDataset\n",
38 | "from drim.models import MultimodalWrapper\n",
39 | "\n",
40 | "\n",
41 | "seed = 1999\n",
42 | "n_outs = 20\n",
43 | "import matplotlib.pyplot as plt\n",
44 | "plt.rcParams[\"font.family\"] = \"serif\""
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "b81dfa02",
51 | "metadata": {
52 | "scrolled": false
53 | },
54 | "outputs": [],
55 | "source": [
56 | "cfg = OmegaConf.load('./configs/robustness.yaml')\n",
57 | "for method in ['tensor', 'concat', 'max', 'drim']:\n",
58 | " if method == 'tensor':\n",
59 | " cfg.general.dim = 32\n",
60 | " else:\n",
61 | " cfg.general.dim = 128\n",
62 | " scores = []\n",
63 | " for fold in range(5):\n",
64 | " seed_everything(seed)\n",
65 | " dataframes = get_dataframes(fold)\n",
66 | " dataframes = {split: prepare_data(dataframe, ['DNAm', 'WSI', 'RNA', 'MRI']) for split, dataframe in dataframes.items()}\n",
67 | " test_datasets = {}\n",
68 | " encoders = {}\n",
69 | " if method == 'drim':\n",
70 | " encoders_u = {}\n",
71 | "\n",
72 | " for modality in ['DNAm', 'WSI', 'RNA', 'MRI']:\n",
73 | " datasets = get_datasets(dataframes, modality, fold, return_mask=True)\n",
74 | " test_datasets[modality] = datasets['test']\n",
75 | " encoder = get_encoder(modality, cfg).cuda()\n",
76 | " encoders[modality] = encoder\n",
77 | " if method == 'drim':\n",
78 | " encoder_u = get_encoder(modality, cfg).cuda()\n",
79 | " encoders_u[modality] = encoder_u\n",
80 | " \n",
81 | " targets, cut = get_targets(dataframes, cfg.general.n_outs)\n",
82 | " dataset_test = MultimodalDataset(test_datasets, return_mask=True)\n",
83 | " test_data = SurvivalDataset(dataset_test, *targets['test'])\n",
84 | " loader = torch.utils.data.DataLoader(test_data, shuffle=False, batch_size=24)\n",
85 | " if method == 'drim':\n",
86 | " from drim.fusion import MaskedAttentionFusion\n",
87 | " fusion = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)\n",
88 | " fusion_u = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)\n",
89 | " fusion.cuda()\n",
90 | " fusion_u.cuda()\n",
91 | " encoder = DRIMSurv(encoders_sh=encoders, encoders_u=encoders_u, fusion_s=fusion, fusion_u=fusion_u)\n",
92 | " model = MultimodalWrapper(encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs)\n",
93 | " model.load_state_dict(torch.load(f'./models/drimsurv_split_{int(fold)}.pth'))\n",
94 | " else:\n",
95 | " if method == 'max':\n",
96 | " from drim.fusion import ShallowFusion\n",
97 | " fusion = ShallowFusion('max')\n",
98 | " elif method == 'tensor':\n",
99 | " from drim.fusion import TensorFusion\n",
100 | " fusion = TensorFusion(modalities=['DNAm', 'WSI', 'RNA', 'MRI'], input_dim=cfg.general.dim, projected_dim=cfg.general.dim, output_dim=cfg.general.dim, dropout=0.)\n",
101 | " elif method == 'concat':\n",
102 | " from drim.fusion import ShallowFusion\n",
103 | " fusion = ShallowFusion('concat')\n",
104 | " elif method == 'maf':\n",
105 | " from drim.fusion import MaskedAttentionFusion\n",
106 | " fusion = MaskedAttentionFusion(dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128)\n",
107 | " \n",
108 | " fusion.cuda()\n",
109 | " encoder = MultimodalModel(encoders, fusion= fusion)\n",
110 | " if method == 'concat':\n",
111 | " size = cfg.general.dim * 4\n",
112 | " else:\n",
113 | " size = cfg.general.dim\n",
114 | " model = MultimodalWrapper(encoder, embedding_dim=size, n_outs=cfg.general.n_outs)\n",
115 | " if method == 'max':\n",
116 | " prefix = 'vanilla'\n",
117 | " else:\n",
118 | " prefix = 'aux_mmo'\n",
119 | " \n",
120 | " model.load_state_dict(torch.load(f'./models/{prefix}_{method}_split_{int(fold)}.pth'))\n",
121 | "\n",
122 | " \n",
123 | " model.cuda()\n",
124 | " model.eval()\n",
125 | "\n",
126 | " \n",
127 | " hazards = []\n",
128 | "\n",
129 | " with torch.no_grad():\n",
130 | " for batch in loader:\n",
131 | " data, time, event = batch\n",
132 | " data, mask = data\n",
133 | " outputs = model(data, mask, return_embedding=False)\n",
134 | " hazards.append(outputs)\n",
135 | " \n",
136 | " hazards = interpolate_dataframe(pd.DataFrame((1 - torch.cat(hazards, dim=0).sigmoid()).add(1e-7).log().cumsum(1).exp().cpu().numpy().transpose(), cut))\n",
137 | " scores.append(-hazards.sum(0).values)\n",
138 | " \n",
139 | " scores = torch.from_numpy(np.stack(scores).mean(0))\n",
140 | " high_scores = scores > scores.median()\n",
141 | " low_scores = scores <= scores.median()\n",
142 | " data_high = dataframes['test'].iloc[high_scores.numpy()]\n",
143 | " data_low = dataframes['test'].iloc[low_scores.numpy()]\n",
144 | " results = logrank_test(data_high[\"time\"], data_low[\"time\"], data_high[\"event\"], data_low[\"event\"])\n",
145 | "\n",
146 | " p_value = results.p_value\n",
147 | " fig = plt.figure(dpi=500)\n",
148 | " time, survival_prob, conf_int = kaplan_meier_estimator(\n",
149 | " data_low[\"event\"].astype(bool),\n",
150 | " data_low[\"time\"],\n",
151 | " conf_type=\"log-log\",\n",
152 | " )\n",
153 | "\n",
154 | " plt.step(time, survival_prob, where=\"post\", label=f\"Low risk\")\n",
155 | " plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step=\"post\")\n",
156 | " time, survival_prob, conf_int = kaplan_meier_estimator(\n",
157 | " data_high[\"event\"].astype(bool),\n",
158 | " data_high[\"time\"],\n",
159 | " conf_type=\"log-log\",\n",
160 | " )\n",
161 | "\n",
162 | " plt.step(time, survival_prob, where=\"post\", label=f\"High risk\")\n",
163 | " plt.fill_between(time, conf_int[0], conf_int[1], alpha=0.25, step=\"post\")\n",
164 | " plt.ylim(0, 1)\n",
165 | " plt.ylabel(r\"Probability of survival $\\hat{S}(t)$\")\n",
166 | " plt.xlabel(\"Time $t$\") \n",
167 | " plt.legend(loc=\"best\")\n",
168 | " if method == 'tensor' or method == 'concat':\n",
169 | " title = method.capitalize()\n",
170 | " title += ' w/ MMO'\n",
171 | " else:\n",
172 | " title = method.capitalize()\n",
173 | " plt.title(title)\n",
174 | " p_value = f'{p_value:0.2}'\n",
175 | " plt.text(11,0.8,r'$p_{value}$='+p_value, fontsize=12.8)\n",
176 | " plt.grid()\n",
177 | " plt.savefig(f'{method}.pdf')\n",
178 | " #tikzplotlib_fix_ncols(fig)\n",
179 | " #tikzplotlib.save(f'{method}.tikz')\n",
180 | " plt.show()"
181 | ]
182 | }
183 | ],
184 | "metadata": {
185 | "kernelspec": {
186 | "display_name": "Python 3 (ipykernel)",
187 | "language": "python",
188 | "name": "python3"
189 | },
190 | "language_info": {
191 | "codemirror_mode": {
192 | "name": "ipython",
193 | "version": 3
194 | },
195 | "file_extension": ".py",
196 | "mimetype": "text/x-python",
197 | "name": "python",
198 | "nbconvert_exporter": "python",
199 | "pygments_lexer": "ipython3",
200 | "version": "3.10.10"
201 | }
202 | },
203 | "nbformat": 4,
204 | "nbformat_minor": 5
205 | }
206 |
--------------------------------------------------------------------------------
/configs/multimodal.yaml:
--------------------------------------------------------------------------------
1 |
2 | general:
3 | seed: 1999
4 | val_split: 4
5 | modalities: [DNAm, WSI, RNA, MRI]
6 | n_outs: 20
7 | dropout: 0.2
8 | epochs: 30
9 | epochs_finetune: 10
10 | dim: 128
11 | save_path: ./models/
12 | n_folds: 5
13 |
14 | dataloader:
15 | batch_size: 24
16 | pin_memory: true
17 | num_workers: 40
18 | persistent_workers: true
19 |
20 | fusion:
21 | name: maf
22 | params:
23 | depth: 1
24 | heads: 16
25 | dim_head: 64
26 | mlp_dim: 128
27 |
28 | optimizer:
29 | name: AdamW
30 | params:
31 | lr: 1e-3
32 | weight_decay: 1e-2
33 |
34 | scheduler:
35 | T_max: ${general.epochs}
36 | eta_min: 5e-6
37 |
38 | aux_loss:
39 | name: contrastive
40 | alpha: 1.
41 |
42 | disentangled:
43 | gamma: 0.8
44 | dsm_lr: 1e-3
45 | dsm_wd: 3e-4
--------------------------------------------------------------------------------
/configs/robustness.yaml:
--------------------------------------------------------------------------------
1 | method: tensor
2 | general:
3 | dim: 128
4 | n_outs: 20
5 | dropout: 0.
6 | seed: 1999
--------------------------------------------------------------------------------
/configs/unimodal.yaml:
--------------------------------------------------------------------------------
1 |
2 | general:
3 | seed: 1999
4 | val_split: 4
5 | modalities: DNAm
6 | n_outs: 20
7 | dropout: 0.2
8 | epochs: 50
9 | save_path: ./models/
10 | dim: 128
11 | n_folds: 5
12 |
13 | dataloader:
14 | batch_size: 24
15 | pin_memory: true
16 | num_workers: 40
17 | persistent_workers: true
18 |
19 | optimizer:
20 | name: AdamW
21 | params:
22 | lr: 1e-3
23 | weight_decay: 1e-2
24 |
25 | scheduler:
26 | T_max: ${general.epochs}
27 | eta_min: 5e-6
28 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data
2 | All data used in this study are from The Cancer Genome Atlas (TCGA) program [(Weinstein et al., Nat Genet 2013)](https://www.nature.com/articles/ng.2764). They are all publicly available on the [GDC Data Portal](https://portal.gdc.cancer.gov/).
3 |
4 | # Structure
5 | The structure of this part of the repository is organised as follows:
6 | ```
7 | ├── scripts
8 | │ ├── gdc_manifest_20230619_DNAm_GBM_LGG.txt
9 | │ ├── gdc_manifest_20230918_WSI_LGG.txt
10 | │ ├── gdc_manifest_20231124_WSI_GBM.txt
11 | │ ├── gdc_manifest_20231203_RNA_GBM_LGG.txt
12 | ├── files
13 | │ ├── clinical_data.tsv
14 | │ ├── supplementary_data.tsv
15 | │ ├── ... <- preprocessed dataframes created from preprocessing.ipynb
16 | ├── mappings
17 | │ ├── wsi_mapping.json
18 | │ ├── BraTS2023_2017_GLI_Mapping.xlsx
19 | ├── models
20 | │ ├── wsi_encoder.pth
21 | │ ├── t1ce-flair_tumorTrue.pth
22 | │ ├── ...
23 | ├── rna_preprocessors
24 | │ ├── trf_0.joblib
25 | │ ├── trf_1.joblib
26 | │ ├── trf_2.joblib
27 | │ ├── trf_3.joblib
28 | │ ├── trf_4.joblib
29 | ├── get_wsi_thumbnails.py
30 | ├── run_mri_pretraining.py
31 | ├── run_wsi_pretraining.py
32 | ├── extract_brain_embeddings.py
33 | ├── preprocessing.ipynb
34 | ```
35 | The `scripts` folder contains all the manifest files needed to download data from the [GDC Data Portal](https://portal.gdc.cancer.gov/) (see their documentation).
36 |
37 | For instance, suppose we want the data to be stored in the `~/TCGA/GBMLGG` folder. To download the RNA-Seq data, one can use the following command line:
38 | ```
39 | $ mkdir ~/TCGA/GBMLGG/raw_RNA
40 | $ sudo /opt/gdc-client download \
41 | -d ~/TCGA/GBMLGG/raw_RNA/ \
42 | -m scripts/gdc_manifest_20231203_RNA_GBM_LGG.txt
43 | ```
44 | For data pre-processing, we proceeded in a similar way to [Vale-Silva et al., 2022](https://www.nature.com/articles/s41598-021-92799-4); the `clinical_data.tsv` file is therefore identical. The supplementary data (`files/supplementary_data.tsv`) contains information about the methylation and IDH status of the patients.
45 | It can be obtained using the following R code:
46 |
47 | ```
48 | library(TCGAWorkflowData)
49 | library(DT)
50 | library(TCGAbiolinks)
51 |
52 | gdc <- TCGAquery_subtype(tumor='all')
53 | output_path <- 'files/supplementary_data.tsv'
54 | readr::write_tsv(gdc, output_path)
55 | ```
56 | - In the `mappings` folder, the `wsi_mapping.json` file gives the path to the corresponding WSI for each patient (see `preprocessing.ipynb` for its construction); an illustrative look-up is sketched after this list.
57 | - The `BraTS2023_2017_GLI_Mapping.xlsx` file is used to link the BraTS competition MRIs to the TCGA patients; you can download it directly from the [BraTS competition homepage](https://www.synapse.org/#!Synapse:syn51156910/wiki/621282) by signing up for the competition, or through The Cancer Imaging Archive (TCIA).
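As a minimal sketch, each entry of `wsi_mapping.json` maps a patient to the path of its WSI; the exact keys and paths depend on where the slides were downloaded, so the example output below is hypothetical:
```
import json

with open("mappings/wsi_mapping.json") as f:
    wsi_mapping = json.load(f)

# Print one (patient, path) pair,
# e.g. ('TCGA-02-0047', '~/TCGA/GBMLGG/raw_WSI/<id>/TCGA-02-0047-...-DX1.svs')
print(next(iter(wsi_mapping.items())))
```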
58 |
59 | The `models` folder is self-explanatory: it contains the pretrained encoders used for the multimodal training. The file-size limit of the supplementary material does not allow the WSI pre-trained models to be supplied, so only the MRI pre-trained model is provided.
60 |
61 | # Preprocessing
62 | All the pre-processing is detailed in the notebook `preprocessing.ipynb`.
63 | All the Python files at the root of this folder are used to build the final data files used later; their function is described in more detail in the notebook, and example invocations are given after the list below. In short:
64 | - `get_wsi_thumbnails.py`: is used to obtain a tissue mask and a low-resolution image for each WSI.
65 | - `run_mri_pretraining.py`: is used to train the specific encoder for MRI data.
66 | - `run_wsi_pretraining.py`: is used to train the specific encoder for WSI data.
67 | - `extract_brain_embeddings.py`: is used to extract embeddings from MRI data, where each volume is centered around the tumor and contains only t1ce and flair modalities (2x64x64x64).
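For example, the scripts can be launched as follows (run from the repository root; the data paths are illustrative, and each script's `--help` lists the full set of options):
```
$ python data/get_wsi_thumbnails.py --data_path ~/TCGA/GBMLGG/WSI --downscale_factor 6
$ python data/run_mri_pretraining.py --data_path ~/TCGA/GBMLGG/MRI
$ python data/run_wsi_pretraining.py --data_path ~/TCGA/GBMLGG/WSI
$ python data/extract_brain_embeddings.py
```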
68 |
69 | The `rna_preprocessors` folder contains the fitted objects used to preprocess the RNA data for each validation split (a loading sketch is given below).
70 | The code generating these files, which are used throughout the training sessions, is in `preprocessing.ipynb`.
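For instance, the preprocessor fitted for a given split can be loaded and inspected with joblib (illustrative, from the repository root):
```
import joblib

trf = joblib.load("data/rna_preprocessors/trf_0.joblib")  # preprocessor fitted for split 0
print(trf)
```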
71 | # Acknowledgements
72 | As a large part of this folder is inspired by [Vale-Silva's MultiSurv paper](https://www.nature.com/articles/s41598-021-92799-4), it seems important to thank him for his very clear and reproducible work and to point to [his GitHub repository](https://github.com/luisvalesilva/multisurv/tree/master/data).
--------------------------------------------------------------------------------
/data/extract_brain_embeddings.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 | import torch
4 | import csv
5 | from drim.mri.transforms import tumor_transfo
6 | from drim.mri.datasets import MRIProcessor
7 | from drim.mri.models import MRIEncoder
8 | from drim.utils import clean_state_dict
9 | import tqdm
10 | import os
11 |
12 | # Load the data
13 | data = pd.read_csv("./data/files/dataframe_brain.csv")
14 | encoder = MRIEncoder(2, 512, False)
15 | encoder.load_state_dict(
16 | clean_state_dict(torch.load("data/models/t1ce-flair_tumorTrue.pth")), strict=False
17 | )
18 | encoder.eval()
19 | for mri_path in tqdm.tqdm(data.MRI):
20 | if pd.isna(mri_path):
21 | continue
22 | process = MRIProcessor(
23 | mri_path,
24 | tumor_centered=True,
25 | transform=tumor_transfo,
26 | modalities=["t1ce", "flair"],
27 | size=(64, 64, 64),
28 | )
29 | mri = process.process().unsqueeze(0)
30 | with torch.no_grad():
31 | embedding = encoder(mri)
32 |
33 | # Save the embedding
34 | with open(os.path.join(mri_path, "embedding.csv"), "w") as f:
35 | writer = csv.writer(f)
36 | writer.writerow([str(i) for i in range(512)])
37 | writer.writerow(embedding.squeeze().numpy().tolist())
38 |
--------------------------------------------------------------------------------
/data/get_wsi_thumbnails.py:
--------------------------------------------------------------------------------
1 | import pyvips
2 | import os
3 | import tqdm
4 | import cv2
5 | import numpy as np
6 | from typing import Tuple
7 | from PIL import Image
8 | import argparse
9 |
10 |
11 | def segment(
12 | img_rgba: np.ndarray,
13 | sthresh: int = 25,
14 | sthresh_up: int = 255,
15 | mthresh: int = 9,
16 | otsu: bool = True,
17 | ) -> Tuple[np.ndarray, np.ndarray]:
18 | img = cv2.cvtColor(img_rgba, cv2.COLOR_RGBA2RGB)
19 |
20 | img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) # Convert to HSV space
21 | # img_med = filters.median(img_hsv))#, np.ones((3, 3,3))
22 | img_med = cv2.medianBlur(img_hsv[:, :, 1], mthresh) # Apply median blurring
23 | # img_med = img_hsv
24 | if otsu:
25 | _, img_otsu = cv2.threshold(
26 | img_med, 0, sthresh_up, cv2.THRESH_OTSU + cv2.THRESH_BINARY
27 | )
28 | else:
29 | _, img_otsu = cv2.threshold(img_med, sthresh, sthresh_up, cv2.THRESH_BINARY)
30 | masked_image = cv2.bitwise_and(img, img, mask=img_otsu)
31 | return masked_image, img_otsu
32 |
33 |
34 | def main(data_path, downscale_factor):
35 | """
36 | This script creates thumbnails and low resolution masks of the WSI images.
37 | The data folder must be organized as follows (from a TCGA download):
38 | - data_path
39 | - id1
40 | - filename1.svs
41 | - ...
42 | - id2
43 | - filename2.svs
44 | - ...
45 | - ...
46 |
47 | Args:
48 | data_path: Path to the input directory.
49 | downscale_factor: Downscale factor for x20 magnification.
50 |
51 | Returns:
52 | None
53 |
54 | ----
55 | Example usage:
56 | python get_wsi_thumbnails.py --data_path /data/ --downscale_factor 6
57 |
58 | """
59 | subdirectories = os.listdir(data_path)
60 | for subdirectory in tqdm.tqdm(subdirectories):
61 | subdirectory_path = os.path.join(data_path, subdirectory)
62 | filenames = os.listdir(subdirectory_path)
63 | wsi_filename = [f for f in filenames if f.endswith("svs") or f.endswith("tif")][
64 | 0
65 | ]
66 | slide = pyvips.Image.new_from_file(
67 | os.path.join(subdirectory_path, wsi_filename)
68 | )
69 | if int(float(slide.get("aperio.AppMag"))) == 40:
70 | d = downscale_factor + 1
71 | else:
72 | d = downscale_factor
73 | thumbnail = pyvips.Image.thumbnail(
74 | os.path.join(subdirectory_path, wsi_filename),
75 | slide.width / (2**d),
76 | height=slide.height / (2**d),
77 | ).numpy()
78 |
79 | thumbnail = cv2.cvtColor(thumbnail, cv2.COLOR_RGBA2RGB)
80 | thumbnail_hsv = cv2.cvtColor(thumbnail, cv2.COLOR_RGB2HSV)
81 | # to filter out felt-tip marks
82 | mask_hsv = np.tile(thumbnail_hsv[:, :, 1] < 160, (3, 1, 1)).transpose(1, 2, 0)
83 | thumbnail *= mask_hsv
84 | masked_image, mask = segment(thumbnail)
85 | masked_image = Image.fromarray(masked_image).convert("RGB")
86 | # save
87 | masked_image.save(os.path.join(subdirectory_path, "thumbnail.jpg"))
88 | np.save(os.path.join(subdirectory_path, "mask.npy"), mask)
89 |
90 |
91 | if __name__ == "__main__":
92 | parser = argparse.ArgumentParser(
93 | description="Create thumbnails and low resolution masks of the WSI images."
94 | )
95 | parser.add_argument("--data_path", "-p", help="Path to the input directory.")
96 | parser.add_argument(
97 | "--downscale_factor",
98 | "-d",
99 | type=int,
100 | help="Downscale factor for x20 magnification.",
101 | )
102 | args = parser.parse_args()
103 | main(args.data_path, args.downscale_factor)
104 |
--------------------------------------------------------------------------------
/data/rna_preprocessors/trf_0.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_0.joblib
--------------------------------------------------------------------------------
/data/rna_preprocessors/trf_1.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_1.joblib
--------------------------------------------------------------------------------
/data/rna_preprocessors/trf_2.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_2.joblib
--------------------------------------------------------------------------------
/data/rna_preprocessors/trf_3.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_3.joblib
--------------------------------------------------------------------------------
/data/rna_preprocessors/trf_4.joblib:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/data/rna_preprocessors/trf_4.joblib
--------------------------------------------------------------------------------
/data/run_mri_pretraining.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | import argparse
3 | import os
4 |
5 | # Third-party libraries
6 | import torch
7 | import wandb
8 | import numpy as np
9 | import pandas as pd
10 |
11 | # Local libraries
12 | from drim.utils import seed_everything, seed_worker
13 | from drim.logger import logger
14 | from drim.mri.datasets import DatasetBraTSTumorCentered
15 | from drim.mri.models import MRIEncoder
16 | from drim.mri.transforms import get_tumor_transforms
17 | from drim.commons.optim_contrastive import training_loop_contrastive
18 | from drim.commons.losses import ContrastiveLoss
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--data_path", type=str, default="../TCGA/GBMLGG/MRI")
24 | parser.add_argument("--batch_size", type=int, default=48)
25 | parser.add_argument("--epochs", type=int, default=30)
26 | parser.add_argument("--lr", type=float, default=1e-4)
27 | parser.add_argument("--modalities", nargs="+", default=["t1ce", "flair"])
28 | parser.add_argument("--weight_decay", type=float, default=1e-6)
29 | parser.add_argument("--entity", type=str, default=None)
30 | parser.add_argument("--project", type=str, default="DRIM")
31 | parser.add_argument("--temperature", type=float, default=0.07)
32 | parser.add_argument("--tumor_centered", type=bool, default=True)
33 | parser.add_argument("--n_cpus", type=int, default=40)
34 | parser.add_argument("--n_gpus", type=int, default=4)
35 | parser.add_argument("--k", type=int, default=3)
36 | parser.add_argument("--seed", type=int, default=1999)
37 | args = parser.parse_args()
38 |
39 | # Set seed
40 | seed_everything(args.seed)
41 | patients = os.listdir(args.data_path)
42 | dataframe = pd.read_csv("data/files/dataframe_brain.csv")
43 | dataframe_test = dataframe[dataframe["group"] == "test"]
44 | # get patient ids where MRI is not NaN
45 | dataframe_test = dataframe_test[~dataframe_test["MRI"].isna()]
46 | patients_to_exclude = [
47 | patient_path.split("/")[-1] for patient_path in dataframe_test.MRI.values
48 | ]
49 | patients = [patient for patient in patients if patient not in patients_to_exclude]
50 |
51 | # split patients into train and val by taking a random 75% of patients for training
52 | train_patients = np.random.choice(
53 | patients, int(len(patients) * 0.75), replace=False
54 | )
55 | val_patients = [patient for patient in patients if patient not in train_patients]
56 | logger.info("Performing contrastive training on BraTS dataset.")
57 | logger.info("Modalities used : {}.", args.modalities)
58 | if bool(args.tumor_centered):
59 | logger.info("Using tumor centered dataset.")
60 | category = "tumor"
61 | sizes = (64, 64, 64)
62 | train_dataset = DatasetBraTSTumorCentered(
63 | args.data_path,
64 | args.modalities,
65 | patients=train_patients,
66 | sizes=sizes,
67 | return_mask=False,
68 | transform=get_tumor_transforms(sizes),
69 | )
70 | val_dataset = DatasetBraTSTumorCentered(
71 | args.data_path,
72 | args.modalities,
73 | patients=val_patients,
74 | sizes=sizes,
75 | return_mask=False,
76 | transform=get_tumor_transforms(sizes),
77 | )
78 | else:
79 | raise NotImplementedError
80 |
81 | train_loader = torch.utils.data.DataLoader(
82 | train_dataset,
83 | batch_size=args.batch_size,
84 | shuffle=True,
85 | pin_memory=True,
86 | num_workers=args.n_cpus,
87 | persistent_workers=True,
88 | worker_init_fn=seed_worker,
89 | )
90 | val_loader = torch.utils.data.DataLoader(
91 | val_dataset,
92 | batch_size=args.batch_size,
93 | shuffle=False,
94 | num_workers=args.n_cpus,
95 | pin_memory=True,
96 | persistent_workers=True,
97 | worker_init_fn=seed_worker,
98 | )
99 |
100 | model = MRIEncoder(projection_head=True, in_channels=len(args.modalities))
101 |
102 | logger.info(
103 | "Number of parameters : {}",
104 | sum(p.numel() for p in model.parameters() if p.requires_grad),
105 | )
106 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
107 |
108 | if args.n_gpus > 1:
109 | model = torch.nn.DataParallel(model, device_ids=list(range(args.n_gpus)))
110 |
111 | model.to(device)
112 | logger.info("Using {} gpus to train the model", args.n_gpus)
113 |
114 | contrastive_loss = ContrastiveLoss(temperature=args.temperature, k=args.k)
115 |
116 | optimizer = torch.optim.Adam(
117 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay
118 | )
119 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
120 | optimizer, T_max=args.epochs, eta_min=5e-6
121 | )
122 |
123 | path_to_save = (
124 | f"data/models/{'-'.join(args.modalities)}_tumor{str(args.tumor_centered)}.pth"
125 | )
126 | # if a wandb entity is provided, log the training on wandb
127 | wandb_logging = True if args.entity is not None else False
128 | if wandb_logging:
129 | run = wandb.init(
130 | project=args.project,
131 | entity=args.entity,
132 | name=f"Pretraining_MRI",
133 | reinit=True,
134 | config=vars(args),
135 | )
136 | logger.info("Training started!")
137 | _, _ = training_loop_contrastive(
138 | model,
139 | args.epochs,
140 | contrastive_loss,
141 | optimizer,
142 | lr_scheduler,
143 | train_loader,
144 | val_loader,
145 | device,
146 | path_to_save,
147 | wandb_logging=wandb_logging,
148 | k=args.k,
149 | )
150 | if wandb_logging:
151 | run.finish()
152 | logger.info("Training finished!")
153 |
--------------------------------------------------------------------------------
/data/run_wsi_pretraining.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | import argparse
3 | import os
4 |
5 | # Third-party libraries
6 | import torch
7 | import wandb
8 | import numpy as np
9 | from imutils import paths
10 |
11 | # Local dependencies
12 | from drim.utils import seed_everything, seed_worker
13 | from drim.logger import logger
14 | from drim.wsi.transforms import contrastive_wsi_transforms
15 | from drim.wsi.datasets import PatchDataset
16 | from drim.wsi.models import ResNetWrapperSimCLR
17 | from drim.commons.optim_contrastive import training_loop_contrastive
18 | from drim.commons.losses import ContrastiveLoss
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument(
24 | "--data_path",
25 | nargs="+",
26 | default=["/media/hddext/medical/TCGA-GBM_WSI", "/archive/medical/tcga_lgg/wsi"],
27 | )
28 | parser.add_argument("--batch_size", type=int, default=200)
29 | parser.add_argument("--epochs", type=int, default=30)
30 | parser.add_argument("--lr", type=float, default=1e-3)
31 | parser.add_argument("--temperature", type=float, default=0.07)
32 | parser.add_argument("--weight_decay", type=float, default=1e-6)
33 | parser.add_argument("--out_dim", type=int, default=256)
34 | parser.add_argument("--entity", type=str, default=None)
35 | parser.add_argument("--project", type=str, default="DRIM")
36 | parser.add_argument("--n_cpus", type=int, default=30)
37 | parser.add_argument("--n_gpus", type=int, default=4)
38 | parser.add_argument("--k", type=int, default=10)
39 | parser.add_argument("--seed", type=int, default=1999)
40 | args = parser.parse_args()
41 |
42 | # Set seed
43 | seed_everything(args.seed)
44 | # get every filepaths
45 | filepaths = []
46 | for path in args.data_path:
47 | temp_filepaths = list(paths.list_images(path))
48 | temp_filepaths = [
49 | filepath for filepath in temp_filepaths if "patches" in filepath
50 | ]
51 | filepaths += temp_filepaths
52 |
53 | filepaths = np.array(filepaths)
54 | # split patient into train and val by taking random 75% of patients for training
55 |
56 | train_idx = np.random.choice(
57 | np.arange(len(filepaths)), int(len(filepaths) * 0.75), replace=False
58 | )
59 | train_mask = np.zeros(len(filepaths), dtype=bool)
60 | train_mask[train_idx] = True
61 | train_filepaths = filepaths[train_idx]
62 | val_filepaths = filepaths[~train_mask]
63 |
64 | logger.info("Performing contrastive training on TCGA dataset.")
65 | logger.info("Number of training images : {}.", len(train_filepaths))
66 | logger.info("Number of validation images : {}.", len(val_filepaths))
67 |
68 | train_dataset = PatchDataset(train_filepaths, transforms=contrastive_wsi_transforms)
69 | val_dataset = PatchDataset(val_filepaths, transforms=contrastive_wsi_transforms)
70 |
71 | train_loader = torch.utils.data.DataLoader(
72 | train_dataset,
73 | batch_size=args.batch_size,
74 | shuffle=True,
75 | pin_memory=True,
76 | worker_init_fn=seed_worker,
77 | )
78 | val_loader = torch.utils.data.DataLoader(
79 | val_dataset,
80 | batch_size=args.batch_size,
81 | shuffle=False,
82 | pin_memory=True,
83 | worker_init_fn=seed_worker,
84 | )
85 |
86 | model = ResNetWrapperSimCLR(out_dim=args.out_dim, projection_head=True)
87 |
88 | logger.info(
89 | "Number of parameters : {}",
90 | sum(p.numel() for p in model.parameters() if p.requires_grad),
91 | )
92 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
93 |
94 | if args.n_gpus > 1:
95 | model = torch.nn.DataParallel(model, device_ids=list(range(args.n_gpus)))
96 |
97 | model.to(device)
98 | logger.info("Using {} gpus to train the model", args.n_gpus)
99 |
100 | contrastive_loss = ContrastiveLoss(temperature=args.temperature, k=args.k)
101 |
102 | optimizer = torch.optim.Adam(
103 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay
104 | )
105 | lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
106 | optimizer, T_max=40, eta_min=5e-6
107 | )
108 |
109 | path_to_save = f"data/models/wsi_encoder.pth"
110 | wandb_logging = True if args.entity is not None else False
111 | if wandb_logging:
112 | run = wandb.init(
113 | project=args.project,
114 | entity=args.entity,
115 | name=f"Pretraining_WSI",
116 | reinit=True,
117 | config=vars(args),
118 | )
119 | logger.info("Training started!")
120 | _, _ = training_loop_contrastive(
121 | model,
122 | args.epochs,
123 | contrastive_loss,
124 | optimizer,
125 | lr_scheduler,
126 | train_loader,
127 | val_loader,
128 | device,
129 | path_to_save,
130 | wandb_logging=wandb_logging,
131 | k=args.k,
132 | )
133 | if wandb_logging:
134 | run.finish()
135 | logger.info("Training finished!")
136 |
--------------------------------------------------------------------------------
/drim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/drim/__init__.py
--------------------------------------------------------------------------------
/drim/commons/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/drim/commons/__init__.py
--------------------------------------------------------------------------------
/drim/commons/losses.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.loss import _Loss
2 | import torch.nn.functional as F
3 | from warnings import warn
4 | import torch
5 |
6 |
7 | class ContrastiveLoss(_Loss):
8 | """
9 | Deeply inspired by, and partly copied from, a previous PR on MONAI (https://github.com/Project-MONAI/MONAI/blob/dev/monai/losses/contrastive.py)
10 |
11 | Compute the Contrastive loss defined in:
12 |
13 | Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International
14 | conference on machine learning. PMLR, 2020. (http://proceedings.mlr.press/v119/chen20j.html)
15 |
16 | Adapted from:
17 | https://github.com/Sara-Ahmed/SiT/blob/1aacd6adcd39b71efc903d16b4e9095b97dda76f/losses.py#L5
18 |
19 | """
20 |
21 | def __init__(
22 | self, temperature: float = 0.5, k: int = 3, batch_size: int = -1
23 | ) -> None:
24 | """
25 | Args:
26 | temperature: Can be scaled between 0 and 1 for learning from negative samples, ideally set to 0.5.
27 |
28 | Raises:
29 | ValueError: When an input of dimension length > 2 is passed
30 | ValueError: When input and target are of different shapes
31 |
32 | """
33 | super().__init__()
34 | self.temperature = temperature
35 | self.k = k
36 |
37 | if batch_size != -1:
38 | warn(
39 | "batch_size is no longer required to be set. It will be estimated dynamically in the forward call"
40 | )
41 |
42 | def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
43 | """
44 | Args:
45 | input: the shape should be B[F].
46 | target: the shape should be B[F].
47 | """
48 | if len(target.shape) > 2 or len(input.shape) > 2:
49 | raise ValueError(
50 | f"Either target or input has dimensions greater than 2 where target "
51 | f"shape is ({target.shape}) and input shape is ({input.shape})"
52 | )
53 |
54 | if target.shape != input.shape:
55 | raise ValueError(
56 | f"ground truth has differing shape ({target.shape}) from input ({input.shape})"
57 | )
58 |
59 | temperature_tensor = torch.as_tensor(self.temperature).to(input.device)
60 | batch_size = input.shape[0]
61 |
62 | negatives_mask = ~torch.eye(batch_size * 2, batch_size * 2, dtype=torch.bool)
63 | negatives_mask = torch.clone(negatives_mask.type(torch.float)).to(input.device)
64 |
65 | repr = torch.cat([input, target], dim=0)
66 | sim_matrix = F.cosine_similarity(repr.unsqueeze(1), repr.unsqueeze(0), dim=2)
67 | sim_ij = torch.diag(sim_matrix, batch_size)
68 | sim_ji = torch.diag(sim_matrix, -batch_size)
69 |
70 | positives = torch.cat([sim_ij, sim_ji], dim=0)
71 | nominator = torch.exp(positives / temperature_tensor)
72 | denominator = negatives_mask * torch.exp(sim_matrix / temperature_tensor)
73 |
74 | loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
75 | # create pos mask
76 | self_mask = torch.eye(
77 | sim_matrix.shape[0], dtype=torch.bool, device=sim_matrix.device
78 | )
79 | pos_mask = self_mask.roll(shifts=sim_matrix.shape[0] // 2, dims=0)
80 | sim_matrix.masked_fill_(self_mask, -9e15)
81 | comb_sim = torch.cat(
82 | [
83 | sim_matrix[pos_mask][:, None], # First position positive example
84 | sim_matrix.masked_fill(pos_mask, -9e15),
85 | ],
86 | dim=-1,
87 | )
88 | sim_argsort = comb_sim.argsort(dim=-1, descending=True).argmin(dim=-1)
89 |
90 | return (
91 | torch.sum(loss_partial) / (2 * batch_size),
92 | (sim_argsort == 0).float().mean(),
93 | (sim_argsort < self.k).float().mean(),
94 | 1 + sim_argsort.float().mean(),
95 | )
96 |
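# Illustrative usage sketch: the loss takes two projected views of the same batch
# (shape B x F is assumed here) and returns the loss together with top-1 and top-k
# retrieval accuracies and the mean rank of the positive pair.
if __name__ == "__main__":
    loss_fn = ContrastiveLoss(temperature=0.07, k=3)
    z1, z2 = torch.randn(8, 128), torch.randn(8, 128)
    loss, top1, top3, mean_pos = loss_fn(z1, z2)
    print(loss.item(), top1.item(), top3.item(), mean_pos.item())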
--------------------------------------------------------------------------------
/drim/commons/optim_contrastive.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from collections import defaultdict
3 | from typing import Tuple
4 |
5 | # Third party libraries
6 | import torch
7 | import wandb
8 | import tqdm
9 | import numpy as np
10 |
11 | # Local dependencies
12 | from ..logger import logger
13 |
14 |
15 | def training_loop_contrastive(
16 | model: torch.nn.Module,
17 | epochs: int,
18 | loss_fn: torch.nn.modules.loss._Loss,
19 | optimizer: torch.optim.Optimizer,
20 | scheduler: torch.optim.lr_scheduler._LRScheduler,
21 | train_dl: torch.utils.data.DataLoader,
22 | valid_dl: torch.utils.data.DataLoader,
23 | device: torch.device,
24 | path_to_save: str,
25 | wandb_logging: bool,
26 | k: int = 3,
27 | ) -> Tuple[torch.nn.Module, dict]:
28 | metrics = defaultdict(list)
29 | best_loss = np.infty
30 | for epoch in range(epochs):
31 | logger.info(f"Epoch {epoch + 1}/{epochs}")
32 | logger.info("-" * 10)
33 | model.train()
34 | metrics["lr"].append(optimizer.state_dict()["param_groups"][0]["lr"])
35 | epoch_loss = 0.0
36 | top_1, top_k, mean_pos = 0.0, 0.0, 0.0
37 | for batch_data in tqdm.tqdm(train_dl, desc="Training...", total=len(train_dl)):
38 | if isinstance(batch_data, dict):
39 | inputs, inputs_2 = (
40 | batch_data["image"].to(device),
41 | batch_data["image_2"].to(device),
42 | )
43 | else:
44 | inputs, inputs_2 = batch_data[0].to(device), batch_data[1].to(device)
45 | optimizer.zero_grad()
46 | outputs, _ = model(inputs)
47 | outputs_2, _ = model(inputs_2)
48 | loss, acc_top_1, acc_top_k, acc_mean_pos = loss_fn(outputs, outputs_2)
49 | top_1 += acc_top_1 * inputs.shape[0]
50 | top_k += acc_top_k * inputs.shape[0]
51 | mean_pos += acc_mean_pos * inputs.shape[0]
52 | loss.backward()
53 | optimizer.step()
54 | epoch_loss += loss.item() * inputs.shape[0]
55 |
56 | epoch_loss /= len(train_dl.dataset)
57 | print(f"Training loss: {epoch_loss:.4f}")
58 |
59 | metrics["train/loss"].append(epoch_loss)
60 | metrics["train/top_1"].append(top_1 / len(train_dl.dataset))
61 | metrics[f"train/top_{k}"].append(top_k / len(train_dl.dataset))
62 | metrics["train/mean_pos"].append(mean_pos / len(train_dl.dataset))
63 | model.eval()
64 | with torch.no_grad():
65 | val_loss = 0.0
66 | top_1, top_k, mean_pos = 0.0, 0.0, 0.0
67 | for batch_data in tqdm.tqdm(valid_dl, desc="Validation..."):
68 | if isinstance(batch_data, dict):
69 | inputs, inputs_2 = (
70 | batch_data["image"].to(device),
71 | batch_data["image_2"].to(device),
72 | )
73 | else:
74 | inputs, inputs_2 = batch_data[0].to(device), batch_data[1].to(
75 | device
76 | )
77 | outputs, _ = model(inputs)
78 | outputs_2, _ = model(inputs_2)
79 |
80 | loss, acc_top_1, acc_top_k, acc_mean_pos = loss_fn(outputs, outputs_2)
81 | top_1 += acc_top_1 * inputs.shape[0]
82 | top_k += acc_top_k * inputs.shape[0]
83 | mean_pos += acc_mean_pos * inputs.shape[0]
84 | val_loss += loss.item() * inputs.shape[0]
85 |
86 | val_loss /= len(valid_dl.dataset)
87 | print(f"Validation loss: {val_loss:.4f}")
88 | metrics["val/top_1"].append(top_1 / len(valid_dl.dataset))
89 | metrics[f"val/top_{k}"].append(top_k / len(valid_dl.dataset))
90 | metrics["val/mean_pos"].append(mean_pos / len(valid_dl.dataset))
91 | metrics["val/loss"].append(val_loss)
92 |
93 | if wandb_logging:
94 | wandb.log({k: v[-1] for k, v in metrics.items()})
95 |
96 | scheduler.step()
97 | if metrics["val/loss"][-1] < best_loss:
98 | best_loss = metrics["val/loss"][-1]
99 | torch.save(model.state_dict(), path_to_save)
100 |
101 | return model, metrics
102 |
--------------------------------------------------------------------------------
/drim/datasets.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import pandas as pd
3 | import torch
4 | from typing import Any, Tuple
5 |
6 |
7 | __all__ = ["SurvivalDataset"]
8 |
9 |
10 | class _BaseDataset(torch.utils.data.Dataset):
11 | def __init__(self, dataframe: pd.DataFrame, return_mask: bool = False) -> None:
12 | self.dataframe = dataframe
13 | self.return_mask = return_mask
14 |
15 | @abstractmethod
16 | def __getitem__(self, idx: int):
17 | raise NotImplementedError
18 |
19 | def __len__(self) -> int:
20 | return len(self.dataframe)
21 |
22 |
23 | class SurvivalDataset(torch.utils.data.Dataset):
24 | def __init__(
25 | self, dataset: torch.utils.data.Dataset, time: torch.tensor, event: torch.tensor
26 | ) -> None:
27 | self.dataset = dataset
28 | self.time, self.event = torch.from_numpy(time), torch.from_numpy(event)
29 |
30 | def __getitem__(self, idx: int) -> Tuple[Any, torch.Tensor, torch.Tensor]:
31 | return self.dataset[idx], self.time[idx], self.event[idx]
32 |
33 | def __len__(self) -> int:
34 | return len(self.dataset)
35 |
--------------------------------------------------------------------------------
/drim/dnam/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import DNAmDataset
2 | from .models import DNAmDecoder, DNAmEncoder
3 |
--------------------------------------------------------------------------------
/drim/dnam/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 | from typing import Union, Tuple
4 | from ..datasets import _BaseDataset
5 |
6 |
7 | class DNAmDataset(_BaseDataset):
8 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]:
9 | sample = self.dataframe.iloc[idx]
10 | if not pd.isna(sample.DNAm):
11 | out = (
12 | torch.from_numpy(pd.read_csv(sample.DNAm).Beta_value.fillna(0.0).values)
13 | .float()
14 | .unsqueeze(0)
15 | )
16 | mask = True
17 | else:
18 | out = torch.zeros(1, 25978).float()
19 | mask = False
20 |
21 | if self.return_mask:
22 | return out, mask
23 | else:
24 | return out
25 |
--------------------------------------------------------------------------------
/drim/dnam/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops.layers.torch import Rearrange
4 |
5 |
6 | class DNAmEncoder(nn.Module):
7 | """
8 | A vanilla encoder based on 1-d convolution for DNAm data.
9 | """
10 |
11 | def __init__(self, embedding_dim: int, dropout: float) -> None:
12 | super().__init__()
13 | self.encoder = nn.Sequential(
14 | nn.Conv1d(1, 8, 9, 3),
15 | nn.GELU(),
16 | nn.BatchNorm1d(8),
17 | nn.Dropout(dropout),
18 | nn.Conv1d(8, 32, 9, 3),
19 | nn.GELU(),
20 | nn.BatchNorm1d(32),
21 | nn.Dropout(dropout),
22 | nn.Conv1d(32, 64, 9, 3),
23 | nn.GELU(),
24 | nn.BatchNorm1d(64),
25 | nn.Dropout(dropout),
26 | nn.Conv1d(64, 128, 9, 3),
27 | nn.GELU(),
28 | nn.BatchNorm1d(128),
29 | nn.Dropout(dropout),
30 | nn.Conv1d(128, 256, 9, 3),
31 | nn.GELU(),
32 | nn.BatchNorm1d(256),
33 | nn.Dropout(dropout),
34 | nn.Conv1d(256, embedding_dim, 9, 3),
35 | nn.GELU(),
36 | nn.BatchNorm1d(embedding_dim),
37 | nn.Dropout(dropout),
38 | nn.AdaptiveAvgPool1d(1),
39 | )
40 |
41 | def forward(self, x: torch.Tensor) -> torch.Tensor:
42 | return self.encoder(x).squeeze(-1)
43 |
44 |
45 | class DNAmDecoder(nn.Module):
46 | """
47 | A vanilla decoder based on 1-d transposed convolution for DNAm data.
48 | """
49 |
50 | def __init__(self, embedding_dim: int, dropout: float) -> None:
51 | super().__init__()
52 | self.decoder = nn.Sequential(
53 | nn.Linear(embedding_dim, embedding_dim * 33),
54 | Rearrange("b (n e) -> b n e", e=33),
55 | nn.GELU(),
56 | nn.BatchNorm1d(embedding_dim),
57 | nn.Dropout(dropout),
58 | nn.ConvTranspose1d(embedding_dim, 256, 9, 3),
59 | nn.GELU(),
60 | nn.BatchNorm1d(256),
61 | nn.Dropout(dropout),
62 | nn.ConvTranspose1d(256, 128, 9, 3, 1),
63 | nn.GELU(),
64 | nn.BatchNorm1d(128),
65 | nn.Dropout(dropout),
66 | nn.ConvTranspose1d(128, 64, 9, 3, 1),
67 | nn.GELU(),
68 | nn.BatchNorm1d(64),
69 | nn.Dropout(dropout),
70 | nn.ConvTranspose1d(64, 32, 9, 3, 2),
71 | nn.GELU(),
72 | nn.BatchNorm1d(32),
73 | nn.Dropout(dropout),
74 | nn.ConvTranspose1d(32, 8, 9, 3, 1),
75 | nn.GELU(),
76 | nn.BatchNorm1d(8),
77 | nn.Dropout(dropout),
78 | nn.ConvTranspose1d(8, 1, 8, 3, 2),
79 | nn.Sigmoid(),
80 | )
81 |
82 | def forward(self, x: torch.Tensor) -> torch.Tensor:
83 | x = self.decoder(x)
84 | return x
85 |
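# Illustrative usage sketch: DNAmDataset (drim/dnam/datasets.py) yields beta-value tensors
# of shape (1, 25978), so a batch fed to the encoder has shape (B, 1, 25978).
if __name__ == "__main__":
    encoder = DNAmEncoder(embedding_dim=128, dropout=0.2)
    beta_values = torch.zeros(4, 1, 25978)
    print(encoder(beta_values).shape)  # torch.Size([4, 128])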
--------------------------------------------------------------------------------
/drim/fusion/__init__.py:
--------------------------------------------------------------------------------
1 | from .fusion import *
2 |
--------------------------------------------------------------------------------
/drim/fusion/fusion.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Dict, List, Optional, Tuple
3 | import torch
4 | import random
5 | from einops import repeat
6 | from .utils import MultiHeadSelfAttention, FeedForward
7 |
8 | __all__ = ["ShallowFusion", "TensorFusion", "MaskedAttentionFusion"]
9 |
10 |
11 | class ShallowFusion(nn.Module):
12 | """
13 | Class computing the simplest fusion methods. Adapted from MultiSurv (https://www.nature.com/articles/s41598-021-92799-4?proof=t).
14 | It also contains the code from Cheerla et al. (https://github.com/gevaertlab/MultimodalPrognosis/blob/master/experiments/chart4.py).
15 | """
16 |
17 | def __init__(self, fusion_type: str = "mean") -> None:
18 | super().__init__()
19 | self.fusion_type = fusion_type
20 |
21 | def forward(
22 | self, x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] = None
23 | ) -> torch.Tensor:
24 | if self.fusion_type == "mean":
25 | x = torch.stack(list(x.values()), dim=0).mean(dim=0)
26 | elif self.fusion_type == "max":
27 | x, _ = torch.stack(list(x.values()), dim=0).max(dim=0)
28 | elif self.fusion_type == "sum":
29 | x = torch.stack(list(x.values()), dim=0).sum(dim=0)
30 | elif self.fusion_type == "concat":
31 | x = torch.cat(list(x.values()), dim=-1)
32 | elif self.fusion_type == "masked_mean":
33 | x = self.masked_mean(x, mask)
34 |
35 | return x
36 |
37 | @staticmethod
38 | def masked_mean(x, mask):
39 | num = sum((x[modality] * mask[modality][:, None]) for modality in x.keys())
40 | den = sum(mask[modality] for modality in mask.keys())[:, None].float()
41 | return num / den
42 |
43 |
44 | class TensorFusion(nn.Module):
45 | """
46 | Multimodal Tensor Fusion via Kronecker Product and Gating-Based Attention derived from [1, 2].
47 |
48 | [1] https://ieeexplore.ieee.org/document/9186053
49 | [2] https://link.springer.com/chapter/10.1007/978-3-030-87240-3_64
50 |
51 | NB: This is an extension of the original fusion, which can take more than 3 modalities.
52 |
53 | This fusion is not particularly well adapted to missing modalities, and its number of trainable parameters can grow very large.
54 | Be careful with the number of modalities and the size of the tensors. Moreover, it is not invariant to modality permutation.
55 |
56 | For instance: [modalities (m), input dimension (i), projected dimension (p), output dimension (o) ~ number of parameters (millions)]
57 | m: 3, i: 128, p: 128, o: 128 ~ 280M
58 | m: 4, i: 16, p: 16, o: 16 ~ 1.4M
59 | m: 4, i: 32, p: 32, o: 32 ~ 38M
60 | """
61 |
62 | def __init__(
63 | self,
64 | modalities: List[str],
65 | input_dim: int,
66 | projected_dim: int,
67 | output_dim: int,
68 | gate: bool = True,
69 | skip: bool = True,
70 | dropout: float = 0.1,
71 | pairs: Optional[Dict[str, str]] = None,
72 | ) -> None:
73 | # raise a warning because of the important number of trainable parameters.
74 | super().__init__()
75 | self.skip = skip
76 | self.modalities = sorted(modalities)
77 | self.dropout = dropout
78 | self.gate = gate
79 | if gate:
80 | if pairs:
81 | self.pairs = pairs
82 | else:
83 | self.pairs = {}
84 | for modality in self.modalities:
85 | i = self.modalities.index(modality)
86 | j = i
87 | while j == i:
88 | j = random.randint(0, len(self.modalities) - 1)
89 | self.pairs[modality] = self.modalities[j]
90 |
91 | for modality in self.modalities:
92 | if gate:
93 | setattr(
94 | self,
95 | f"{modality.lower()}_linear",
96 | nn.Sequential(nn.Linear(input_dim, projected_dim), nn.ReLU()),
97 | )
98 | setattr(
99 | self,
100 | f"{modality.lower()}_bilinear",
101 | nn.Bilinear(input_dim, input_dim, projected_dim),
102 | )
103 | setattr(
104 | self,
105 | f"{modality.lower()}_last_linear",
106 | nn.Sequential(
107 | nn.Linear(projected_dim, projected_dim),
108 | nn.ReLU(),
109 | nn.Dropout(p=dropout),
110 | ),
111 | )
112 | else:
113 | setattr(
114 | self,
115 | f"{modality.lower()}_proj",
116 | nn.Sequential(nn.Linear(input_dim, projected_dim), nn.ReLU()),
117 | )
118 |
119 | self.dropout = nn.Dropout(dropout)
120 | self.post_fusion = nn.Sequential(
121 | nn.Linear((input_dim + 1) ** len(modalities), output_dim),
122 | nn.ReLU(),
123 | nn.Dropout(p=dropout),
124 | )
125 | if self.skip:
126 | self.skip = nn.Sequential(
127 | nn.Linear(output_dim + ((input_dim + 1) * len(modalities)), output_dim),
128 | nn.ReLU(),
129 | nn.Dropout(p=dropout),
130 | )
131 |
132 | def forward(self, x: Dict[str, torch.Tensor], mask=None) -> torch.Tensor:
133 | if self.gate:
134 | h = {k: getattr(self, f"{k.lower()}_linear")(v) for k, v in x.items()}
135 | x = {
136 | k: getattr(self, f"{k.lower()}_bilinear")(v, x[self.pairs[k]])
137 | for k, v in x.items()
138 | }
139 | x = {
140 | k: getattr(self, f"{k.lower()}_last_linear")(nn.Sigmoid()(x[k]) * h[k])
141 | for k, v in x.items()
142 | }
143 |
144 | else:
145 | x = {k: getattr(self, f"{k.lower()}_proj")(v) for k, v in x.items()}
146 |
147 | x = {
148 | k: torch.cat((v, torch.FloatTensor(v.shape[0], 1).fill_(1).to(v.device)), 1)
149 | for k, v in x.items()
150 | }
151 | out = torch.bmm(
152 | x[self.modalities[0]].unsqueeze(2), x[self.modalities[1]].unsqueeze(1)
153 | ).flatten(start_dim=1)
154 | for modality in self.modalities[2:]:
155 | out = torch.bmm(out.unsqueeze(2), x[modality].unsqueeze(1)).flatten(
156 | start_dim=1
157 | )
158 |
159 | out = self.dropout(out)
160 | out = self.post_fusion(out)
161 | if self.skip:
162 | out = torch.cat((out, *list(x.values())), 1)
163 | out = self.skip(out)
164 |
165 | return out
166 |
167 |
168 | class MaskedAttentionFusion(nn.Module):
169 | def __init__(
170 | self,
171 | dim: int,
172 | depth: int,
173 | heads: int,
174 | dim_head: int,
175 | mlp_dim: int,
176 | dropout: float = 0.0,
177 | ):
178 | super().__init__()
179 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
180 | self.norm = nn.LayerNorm(dim)
181 | self.layers = nn.ModuleList([])
182 | for _ in range(depth):
183 | self.layers.append(
184 | nn.ModuleList(
185 | [
186 | MultiHeadSelfAttention(
187 | dim, heads=heads, dim_head=dim_head, dropout=dropout
188 | ),
189 | FeedForward(dim, mlp_dim, dropout=dropout),
190 | ]
191 | )
192 | )
193 |
194 | def forward(
195 | self, x: Dict[str, torch.Tensor], mask: Optional[Dict[str, torch.Tensor]] = None
196 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
197 | keys = list(x.keys())
198 | x = torch.stack([x[k] for k in keys], 1)
199 | scores = []
200 | cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=x.shape[0]).to(x.device)
201 | x = torch.cat((cls_tokens, x), dim=1)
202 | if mask is not None:
203 | mask = torch.stack(
204 | [torch.ones(x.shape[0]).bool().to(mask[keys[0]].device)]
205 | + [mask[k] for k in keys],
206 | 1,
207 | )
208 |
209 | for attn, ff in self.layers:
210 | temp_x, score = attn(x, mask=mask)
211 | scores.append(score)
212 | x = temp_x + x
213 | x = ff(x) + x
214 | x = x[:, 0]
215 | scores = self._get_attn_scores(torch.stack(scores))
216 | return self.norm(x), {k: scores[:, i] for i, k in enumerate(keys)}
217 |
218 | @staticmethod
219 | def _get_attn_scores(x):
220 | return x.mean(dim=2)[:, :, 0, 1:].mean(0)
221 |
--------------------------------------------------------------------------------
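The parameter estimates in the docstring above are dominated by the post-fusion projection `nn.Linear((input_dim + 1) ** len(modalities), output_dim)`, whose input is the flattened outer product of the bias-augmented modality embeddings. A minimal sketch of the arithmetic, assuming only the layer definition shown above:

```
# Rough parameter count of the post-fusion linear layer:
# ((input_dim + 1) ** n_modalities) * output_dim weights, plus output_dim biases.
def post_fusion_params(n_modalities: int, input_dim: int, output_dim: int) -> int:
    return (input_dim + 1) ** n_modalities * output_dim + output_dim

print(post_fusion_params(3, 128, 128))  # ~275M, matching the "~280M" estimate
print(post_fusion_params(4, 16, 16))    # ~1.3M, matching the "~1.4M" estimate
print(post_fusion_params(4, 32, 32))    # ~38M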
/drim/fusion/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from einops import rearrange
4 | from typing import Optional
5 |
6 |
7 | def softmax_one(x, dim=None):
8 | """
9 | Quiet softmax function as presented by Evan Miller in his blog post "Attention is Off by One".
10 |
11 | https://www.evanmiller.org/attention-is-off-by-one.html
12 | https://github.com/kyegomez/AttentionIsOFFByOne
13 | """
14 | x = x - x.max(dim=dim, keepdim=True).values
15 | exp_x = torch.exp(x)
16 |
17 | return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True))
18 |
19 |
20 | class FeedForward(nn.Module):
21 | def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
22 | super().__init__()
23 | self.net = nn.Sequential(
24 | nn.LayerNorm(dim),
25 | nn.Linear(dim, hidden_dim),
26 | nn.GELU(),
27 | nn.Dropout(dropout),
28 | nn.Linear(hidden_dim, dim),
29 | nn.Dropout(dropout),
30 | )
31 |
32 | def forward(self, x: torch.Tensor) -> torch.Tensor:
33 | return self.net(x)
34 |
35 |
36 | class MultiHeadSelfAttention(nn.Module):
37 | """
38 | Adapted from lucidrains vit-pytorch. It handles mask alongside the features vector.
39 | """
40 |
41 | def __init__(
42 | self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0
43 | ) -> None:
44 | super().__init__()
45 | inner_dim = dim_head * heads
46 | project_out = not (heads == 1 and dim_head == dim)
47 |
48 | self.heads = heads
49 | self.scale = dim_head**-0.5
50 |
51 | self.norm = nn.LayerNorm(dim)
52 |
53 | self.dropout = nn.Dropout(dropout)
54 |
55 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
56 |
57 | self.to_out = (
58 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
59 | if project_out
60 | else nn.Identity()
61 | )
62 |
63 | def forward(
64 | self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
65 | ) -> torch.Tensor:
66 | x = self.norm(x)
67 |
68 | qkv = self.to_qkv(x).chunk(3, dim=-1)
69 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
70 |
71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
72 |
73 | if mask is not None:
74 | mask_value = torch.finfo(dots.dtype).min
75 | mask = mask[:, None, :, None] * mask[:, None, None, :]
76 | dots.masked_fill_(~mask, mask_value)
77 |
78 | attn = softmax_one(dots, dim=-1)
79 | attn_d = self.dropout(attn)
80 |
81 | out = torch.matmul(attn_d, v)
82 | out = rearrange(out, "b h n d -> b n (h d)")
83 | return self.to_out(out), attn
84 |
--------------------------------------------------------------------------------
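A quick illustration of `softmax_one`: because of the extra `1` in the denominator, the attention weights over a row can sum to less than one, which lets a head attend to (almost) nothing instead of being forced to distribute all of its mass. A minimal sketch:

```
import torch
from drim.fusion.utils import softmax_one

logits = torch.zeros(1, 3)
print(torch.softmax(logits, dim=-1))  # tensor([[0.3333, 0.3333, 0.3333]]), sums to 1
print(softmax_one(logits, dim=-1))    # tensor([[0.2500, 0.2500, 0.2500]]), sums to 0.75
```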
/drim/helpers.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from typing import Dict
3 |
4 | # Third-party libraries
5 | import pandas as pd
6 | from omegaconf import DictConfig
7 | from pycox.models import LogisticHazard
8 |
9 | # Local dependencies
10 | from drim.utils import prepare_data, get_target_survival, log_transform
11 |
12 |
13 | def get_encoder(modality: str, cfg: DictConfig):
14 | if modality == "DNAm":
15 | from drim.dnam import DNAmEncoder
16 |
17 | encoder = DNAmEncoder(
18 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout
19 | )
20 | elif modality == "RNA":
21 | from drim.rna import RNAEncoder
22 |
23 | encoder = RNAEncoder(embedding_dim=cfg.general.dim, dropout=cfg.general.dropout)
24 | elif modality == "WSI":
25 | from drim.wsi import WSIEncoder
26 |
27 | encoder = WSIEncoder(
28 | cfg.general.dim,
29 | depth=1,
30 | heads=8,
31 | dropout=cfg.general.dropout,
32 | emb_dropout=cfg.general.dropout,
33 | )
34 | elif modality == "MRI":
35 | from drim.mri import MRIEmbeddingEncoder
36 |
37 | encoder = MRIEmbeddingEncoder(
38 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout
39 | )
40 | else:
41 | raise NotImplementedError(f"Modality {modality} not implemented")
42 |
43 | return encoder
44 |
45 |
46 | def get_decoder(modality: str, cfg: DictConfig):
47 | if modality == "DNAm":
48 | from drim.dnam import DNAmDecoder
49 |
50 | decoder = DNAmDecoder(
51 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout
52 | )
53 | elif modality == "RNA":
54 | from drim.rna import RNADecoder
55 |
56 | decoder = RNADecoder(embedding_dim=cfg.general.dim, dropout=cfg.general.dropout)
57 | elif modality == "WSI":
58 | from drim.wsi import WSIDecoder
59 |
60 | decoder = WSIDecoder(cfg.general.dim, dropout=cfg.general.dropout)
61 | elif modality == "MRI":
62 | from drim.mri import MRIEmbeddingDecoder
63 |
64 | decoder = MRIEmbeddingDecoder(
65 | embedding_dim=cfg.general.dim, dropout=cfg.general.dropout
66 | )
67 | else:
68 | raise NotImplementedError(f"Modality {modality} not implemented")
69 |
70 | return decoder
71 |
72 |
73 | def get_datasets(
74 | dataframes: Dict[str, pd.DataFrame],
75 | modality: str,
76 | fold: int,
77 | return_mask: bool = False,
78 | ):
79 | if modality == "DNAm":
80 | from drim.dnam import DNAmDataset
81 |
82 | datasets = {
83 | split: DNAmDataset(dataframe, return_mask=return_mask)
84 | for split, dataframe in dataframes.items()
85 | }
86 | elif modality == "RNA":
87 | from drim.rna import RNADataset
88 | from joblib import load
89 |
90 | rna_processor = load(f"./data/rna_preprocessors/trf_{int(fold)}.joblib")
91 | datasets = {
92 | split: RNADataset(dataframe, rna_processor, return_mask=return_mask)
93 | for split, dataframe in dataframes.items()
94 | }
95 | elif modality == "WSI":
96 | from drim.wsi import WSIDataset
97 |
98 | datasets = {}
99 | datasets["train"] = WSIDataset(
100 | dataframes["train"], k=10, is_train=True, return_mask=return_mask
101 | )
102 | datasets["val"] = WSIDataset(
103 | dataframes["val"], k=10, is_train=False, return_mask=return_mask
104 | )
105 | datasets["test"] = WSIDataset(
106 | dataframes["test"], k=10, is_train=False, return_mask=return_mask
107 | )
108 |
109 | elif modality == "MRI":
110 | from drim.mri import MRIEmbeddingDataset
111 |
112 | datasets = {
113 | split: MRIEmbeddingDataset(dataframe, return_mask=return_mask)
114 | for split, dataframe in dataframes.items()
115 | }
116 | else:
117 | raise NotImplementedError(f"Modality {modality} not implemented")
118 |
119 | return datasets
120 |
121 |
122 | def get_targets(dataframes: Dict[str, pd.DataFrame], n_outs: int):
123 | labtrans = LogisticHazard.label_transform(n_outs)
124 | labtrans.fit(*get_target_survival(dataframes["train"]))
125 |
126 | return {
127 | split: labtrans.transform(*get_target_survival(dataframe))
128 | for split, dataframe in dataframes.items()
129 | }, labtrans.cuts
130 |
--------------------------------------------------------------------------------
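A hedged sketch of how these factory helpers are combined for a single modality. The config keys (`general.dim`, `general.dropout`) are the ones read above; the dataframe layout (a per-modality path column, `None` meaning missing) and the number of time bins are illustrative assumptions:

```
import pandas as pd
from omegaconf import OmegaConf
from drim.helpers import get_encoder, get_datasets, get_targets

cfg = OmegaConf.create({"general": {"dim": 128, "dropout": 0.2}})
encoder = get_encoder("DNAm", cfg)

# one dataframe per split; the "DNAm" path column is a hypothetical layout (None = missing sample)
dataframes = {split: pd.DataFrame({"DNAm": [None] * 8}) for split in ("train", "val", "test")}
datasets = get_datasets(dataframes, modality="DNAm", fold=0, return_mask=True)

# get_targets additionally requires the survival columns expected by get_target_survival:
# targets, cuts = get_targets(dataframes, n_outs=20)
```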
/drim/logger.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from loguru import logger
3 |
4 | logger.remove()
5 | logger.add(
6 | sys.stdout,
7 | colorize=True,
8 | format=(
9 | "GBMLGG:" " {message}"
10 | ),
11 | )
12 |
--------------------------------------------------------------------------------
/drim/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class DiscriminatorLoss(nn.Module):
7 | """
8 |     Implementation of the traditional discriminator loss used to approximate the Jensen-Shannon divergence.
9 | """
10 |
11 | def __init__(self) -> None:
12 | super().__init__()
13 |
14 | def forward(
15 | self, real_logits: torch.Tensor, fake_logits: torch.Tensor
16 | ) -> torch.Tensor:
17 | """
18 | Compute discriminator loss
19 |
20 | Args:
21 | real_logits:
22 | torch.tensor containing logits extracted from p(x)p(y), size (bsz, 1)
23 | fake_logits:
24 | torch.tensor containing logits extracted from p(x, y), size (bsz, 1)
25 |
26 | Returns:
27 | The computed scalar loss.
28 | """
29 | # Discriminator should predict real logits as logits from the real distribution
30 | discriminator_real = torch.nn.functional.binary_cross_entropy_with_logits(
31 | input=real_logits, target=torch.ones_like(real_logits)
32 | )
33 | # Discriminator should predict fake logits as logits from the generated distribution
34 | discriminator_fake = torch.nn.functional.binary_cross_entropy_with_logits(
35 | input=fake_logits, target=torch.zeros_like(fake_logits)
36 | )
37 |
38 | discriminator_loss = discriminator_real + discriminator_fake
39 |
40 | return discriminator_loss
41 |
42 |
43 | class ContrastiveLoss(nn.Module):
44 | """
45 | This code is adapted from the code for supervised contrastive learning (https://arxiv.org/pdf/2004.11362.pdf).
46 | """
47 |
48 | def __init__(
49 | self,
50 | temperature: float = 0.07,
51 | contrast_mode: str = "all",
52 | base_temperature: float = 0.07,
53 | ) -> None:
54 | super(ContrastiveLoss, self).__init__()
55 | self.temperature = temperature
56 | self.contrast_mode = contrast_mode
57 | self.base_temperature = base_temperature
58 |
59 | def forward(
60 | self, x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] = None
61 | ) -> torch.Tensor:
62 | """Compute shared loss for our disentangled framework.
63 |
64 | Args:
65 | x:
66 | Dict associating each modality (key) to the corresponding embedding, tensor of size (bsz, d)
67 | mask:
68 | Dict associating each modality (key) to a boolean tensor of size (bsz,) indicating if the modality is indeed present in the minibatch
69 | Returns:
70 | The computed scalar loss.
71 | """
72 | keys = list(x.keys())
73 | features = torch.stack([x[key] for key in keys], 1)
74 | modality_mask = torch.stack([mask[key] for key in keys], 1)
75 | batch_size = features.shape[0]
76 |
77 | contrast_count = features.shape[1]
78 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
79 | modality_mask = torch.cat(torch.unbind(modality_mask, dim=1), dim=0).view(1, -1)
80 |
81 | # compute logits
82 | anchor_dot_contrast = torch.div(
83 | torch.matmul(contrast_feature, contrast_feature.T), self.temperature
84 | )
85 | # for numerical stability
86 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
87 | logits = anchor_dot_contrast - logits_max.detach()
88 |
89 | mask = torch.eye(batch_size, dtype=torch.float32).to(modality_mask.device)
90 | mask = mask.repeat(contrast_count, contrast_count)
91 | # mask-out self-contrast cases
92 | logits_mask = torch.scatter(
93 | (modality_mask[:, None, :, None] * modality_mask[:, None, None, :])
94 | .squeeze()
95 | .long(),
96 | 1,
97 | torch.arange(batch_size * contrast_count)
98 | .view(-1, 1)
99 | .to(modality_mask.device),
100 | 0,
101 | )
102 | mask = mask * logits_mask
103 | # compute log_prob
104 | exp_logits = torch.exp(logits) * logits_mask
105 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
106 |
107 | # compute mean of log-likelihood over positive
108 | mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-6)
109 | # loss
110 | loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
111 | loss = loss.view(contrast_count, batch_size).mean()
112 |
113 | return loss
114 |
115 |
116 | class MMOLoss(nn.Module):
117 | """
118 | Loss function used in the paper:
119 |     Deep Orthogonal Fusion: Multimodal Prognostic Biomarker Discovery Integrating Radiology,
120 | Pathology, Genomic, and Clinical Data.
121 | https://link.springer.com/chapter/10.1007/978-3-030-87240-3_64
122 |
123 |     N.B: In the paper and by design, this loss is not well suited to missing modalities; we propose an alternative here.
124 | """
125 |
126 | def __init__(self) -> None:
127 | super(MMOLoss, self).__init__()
128 |
129 | def forward(
130 | self, x: Dict[str, torch.Tensor], mask: Dict[str, torch.Tensor] = None
131 | ) -> torch.Tensor:
132 | """
133 | Compute auxiliary orthogonal loss.
134 |
135 | Args:
136 | x:
137 | Dict associating each modality (key) to the corresponding embedding, tensor of size (bsz, d)
138 | mask:
139 | Dict associating each modality (key) to a boolean tensor of size (bsz,) indicating if the modality is indeed present in the minibatch,
140 | Returns:
141 | The computed scalar loss.
142 | """
143 | if mask is None:
144 | mask = {key: torch.ones_like(x[key]).bool() for key in x.keys()}
145 |
146 | loss = 0.0
147 | for key in x.keys():
148 | loss += torch.max(
149 | torch.tensor(1), torch.linalg.matrix_norm(x[key][mask[key]], ord="nuc")
150 | )
151 |
152 | full = torch.stack(list(x.values()))[torch.stack(list(mask.values()))]
153 | loss -= torch.linalg.matrix_norm(full, ord="nuc")
154 | loss /= full.size(0)
155 | return loss
156 |
--------------------------------------------------------------------------------
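A minimal sketch of how the two auxiliary losses are called, assuming two modalities, a batch of 8 L2-normalised 64-d embeddings, and a missing-modality mask (shapes only, not the actual training configuration):

```
import torch
from drim.losses import ContrastiveLoss, MMOLoss

x = {"RNA": torch.nn.functional.normalize(torch.randn(8, 64), dim=-1),
     "DNAm": torch.nn.functional.normalize(torch.randn(8, 64), dim=-1)}
mask = {"RNA": torch.ones(8, dtype=torch.bool),           # RNA present for every sample
        "DNAm": torch.tensor([True] * 6 + [False] * 2)}   # DNAm missing for two samples

print(ContrastiveLoss()(x, mask))  # shared (cross-modal) contrastive term
print(MMOLoss()(x, mask))          # orthogonality term over the present embeddings
```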
/drim/models.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | class _BaseWrapper(nn.Module):
7 | """
8 | Base wrapper for taking the final embedding and linking to the survival task
9 | """
10 |
11 | def __init__(
12 | self, encoder: nn.Module, embedding_dim: int, n_outs: int, device: str = "cuda"
13 | ) -> None:
14 | super().__init__()
15 | self.encoder = encoder
16 | self.final = nn.Linear(embedding_dim, n_outs).to(device)
17 | self.device = device
18 |
19 |
20 | class UnimodalWrapper(_BaseWrapper):
21 | """
22 |     Survival wrapper taking an encoder and linking it to the survival task.
23 | """
24 |
25 | def forward(self, x: torch.Tensor, return_embedding: bool = False) -> torch.Tensor:
26 | """
27 | Args:
28 | x:
29 | tensor containing the raw input (bsz, n_features)
30 | return_embedding:
31 |                 boolean indicating whether the pre-survival-layer embedding should also be returned
32 |
33 |         Returns:
34 |             Either the predicted hazards, or the predicted hazards and the last embedding.
35 | """
36 | x = x.to(self.device)
37 | x = self.encoder(x)
38 | if return_embedding:
39 | return self.final(x), x
40 | else:
41 | return self.final(x)
42 |
43 |
44 | class MultimodalWrapper(_BaseWrapper):
45 | def forward(
46 | self,
47 | x: Dict[str, torch.Tensor],
48 | mask: Optional[Dict[str, torch.Tensor]] = None,
49 | return_embedding: bool = False,
50 | ) -> torch.Tensor:
51 | """
52 | Args:
53 | x:
54 | Dict associating each modality (key) to the raw input tensor (bsz, n_features)
55 | mask:
56 | Dict associating each modality (key) to a boolean tensor of size (bsz,) indicating if the modality is indeed present in the minibatch
57 | return_embedding:
58 |                 boolean indicating whether the pre-survival-layer embedding should also be returned
59 |
60 |         Returns:
61 |             Either the predicted hazards, or the predicted hazards and the last embedding.
62 | """
63 | x = self.encoder(x, mask=mask, return_embedding=return_embedding)
64 | if not isinstance(x, tuple):
65 | # for fine-tuning or simple multimodal training
66 | return self.final(x.to(self.device))
67 | else:
68 | if len(x) == 2 and isinstance(x[0], dict):
69 |                 # for the self-supervised setting
70 | return x
71 | else:
72 | out = self.final(x[0].to(self.device))
73 | if len(x) == 2:
74 | # auxiliary training
75 | return out, x[1]
76 | else:
77 | # disentangled training
78 | return out, x[1], x[2]
79 |
80 |
81 | class Discriminator(nn.Module):
82 | """
83 | Generic discriminator for the adversarial training
84 | """
85 |
86 | def __init__(self, embedding_dim: int, dropout: float) -> None:
87 | super().__init__()
88 | self.model = nn.Sequential(
89 | nn.Linear(embedding_dim, 1024),
90 | nn.Dropout(dropout),
91 | nn.GELU(),
92 | nn.Linear(1024, embedding_dim),
93 | nn.Dropout(dropout),
94 | nn.GELU(),
95 | nn.Linear(embedding_dim, 1),
96 | )
97 |
98 | def forward(self, x: torch.Tensor) -> torch.Tensor:
99 | return self.model(x)
100 |
--------------------------------------------------------------------------------
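A minimal sketch of the unimodal survival wrapper, assuming a stand-in encoder, a 128-d embedding, 20 discretised time bins and CPU execution (the training scripts build the encoder with `get_encoder` and run on CUDA):

```
import torch
import torch.nn as nn
from drim.models import UnimodalWrapper

encoder = nn.Sequential(nn.Linear(512, 128), nn.ReLU())   # stand-in encoder
model = UnimodalWrapper(encoder, embedding_dim=128, n_outs=20, device="cpu")

x = torch.randn(4, 512)
hazards = model(x)                                         # (4, 20) hazard logits over time bins
hazards, embedding = model(x, return_embedding=True)       # also returns the 128-d embedding
```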
/drim/mri/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import MRIEmbeddingDataset
2 | from .models import MRIEmbeddingDecoder, MRIEmbeddingEncoder
3 |
--------------------------------------------------------------------------------
/drim/mri/datasets.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List, Dict, Tuple
2 | import os
3 | import numpy as np
4 | import nibabel as nib
5 | from abc import abstractmethod
6 | from torch.utils.data import Dataset
7 | import monai
8 | import torch
9 | import pandas as pd
10 | from ..datasets import _BaseDataset
11 |
12 | __all__ = ["DatasetBraTS", "DatasetBraTSTumorCentered", "MRIDataset"]
13 |
14 |
15 | class _BaseDatasetBraTS(Dataset):
16 | def __init__(
17 | self,
18 | path: str,
19 | modality: Union[str, List[str]],
20 | patients: Optional[List[str]],
21 | return_mask: bool,
22 | transform: Optional["monai.transforms"] = None,
23 | ) -> None:
24 | self.path = path
25 | self.modality = modality
26 | self.patients = np.array(os.listdir(path)) if patients is None else patients
27 | self.transform = transform
28 | self.return_mask = return_mask
29 |
30 | def __len__(self):
31 | return len(self.patients)
32 |
33 | @abstractmethod
34 | def __getitem__(self, idx):
35 | raise NotImplementedError
36 |
37 | def _load_nifti_modalities(self, patient: str) -> np.ndarray:
38 | if len(self.modality) == 1:
39 | img = nib.load(
40 | os.path.join(
41 | self.path, patient, patient + "_" + self.modality[0] + ".nii.gz"
42 | )
43 | ).get_fdata()
44 | img = np.expand_dims(img, 0)
45 |
46 | else:
47 | early_fused = []
48 | for modality in self.modality:
49 | early_fused.append(
50 | nib.load(
51 | os.path.join(
52 | self.path, patient, patient + "_" + modality + ".nii.gz"
53 | )
54 | ).get_fdata()
55 | )
56 |
57 | img = np.stack(early_fused)
58 |
59 | return img
60 |
61 | @staticmethod
62 | def _monaify(img: np.ndarray, mask: Optional[np.ndarray]) -> Dict[str, np.ndarray]:
63 | item = {"image": img}
64 | if mask is not None:
65 | item["mask"] = mask
66 |
67 | return item
68 |
69 | def _load_mask(self, patient: str) -> np.ndarray:
70 | return nib.load(
71 | os.path.join(self.path, patient, patient + "_seg.nii.gz")
72 | ).get_fdata()
73 |
74 |
75 | class DatasetBraTS(_BaseDatasetBraTS):
76 | def __getitem__(self, idx) -> Dict[str, Union[torch.tensor, np.ndarray]]:
77 | patient = self.patients[idx]
78 | mask = (
79 | super(DatasetBraTS, self)._load_mask(patient=patient)
80 | if self.return_mask
81 | else None
82 | )
83 | img = super(DatasetBraTS, self)._load_nifti_modalities(patient=patient)
84 |
85 | item = super(DatasetBraTS, self)._monaify(img=img, mask=mask)
86 | if self.transform is not None:
87 | item = self.transform(item)
88 |
89 | return item
90 |
91 |
92 | class DatasetBraTSTumorCentered(_BaseDatasetBraTS):
93 | def __init__(
94 | self,
95 | path: str,
96 | modality: Union[str, List[str]],
97 | patients: Optional[List[str]],
98 | sizes: Tuple[int, ...],
99 | transform: Optional["monai.transforms"] = None,
100 | return_mask: bool = False,
101 | ) -> None:
102 | super(DatasetBraTSTumorCentered, self).__init__(
103 | path=path,
104 | modality=modality,
105 | patients=patients,
106 | return_mask=return_mask,
107 | transform=transform,
108 | )
109 | self.sizes = sizes
110 |
111 | def __getitem__(self, idx) -> Dict[str, Union[torch.tensor, np.ndarray]]:
112 | patient = self.patients[idx]
113 | mask = super(DatasetBraTSTumorCentered, self)._load_mask(patient=patient)
114 | img = super(DatasetBraTSTumorCentered, self)._load_nifti_modalities(
115 | patient=patient
116 | )
117 | item = self._compute_subvolumes(img=img, mask=mask)
118 | if self.transform is not None:
119 | item = self.transform(item)
120 |
121 | return item
122 |
123 | def _compute_subvolumes(
124 | self, img: np.ndarray, mask: np.ndarray
125 | ) -> Dict[str, np.ndarray]:
126 | centroid = self._compute_centroid(mask=mask)
127 | # bounds for boolean indexing
128 | lower_bound, upper_bound = self._get_bounds(
129 | centroid=centroid, input_dims=img.shape[1:]
130 | )
131 | img = img[
132 | :,
133 | lower_bound[0] : upper_bound[0],
134 | lower_bound[1] : upper_bound[1],
135 | lower_bound[2] : upper_bound[2],
136 | ]
137 | mask = (
138 | mask[
139 | lower_bound[0] : upper_bound[0],
140 | lower_bound[1] : upper_bound[1],
141 | lower_bound[2] : upper_bound[2],
142 | ]
143 | if self.return_mask
144 | else None
145 | )
146 | return super(DatasetBraTSTumorCentered, self)._monaify(img=img, mask=mask)
147 |
148 | def _get_bounds(
149 | self, centroid: np.ndarray, input_dims: Tuple[int, ...]
150 | ) -> Tuple[np.ndarray, np.ndarray]:
151 | lower = (centroid - (np.array(self.sizes) / 2)).astype(int)
152 | upper = (centroid + (np.array(self.sizes) / 2)).astype(int)
153 | return np.clip(lower, 0, input_dims), np.clip(upper, 0, input_dims)
154 |
155 | @staticmethod
156 | def _compute_centroid(mask: np.ndarray) -> np.ndarray:
157 | return np.mean(np.argwhere(mask), axis=0).astype(int)
158 |
159 |
160 | class MRIProcessor:
161 | def __init__(
162 | self,
163 | base_path: str,
164 | tumor_centered: bool,
165 | transform: "augmentations",
166 | modalities: List[str] = ["t1", "t1ce", "t2", "flair"],
167 | size: Tuple[int, ...] = (64, 64, 64),
168 | ) -> None:
169 | self.base_path = base_path
170 | self.size = size
171 | self.tumor_centered = tumor_centered
172 | self.modalities = modalities
173 | self.transform = transform
174 |
175 | def _compute_subvolumes(
176 | self, img: np.ndarray, mask: np.ndarray
177 |     ) -> Tuple[np.ndarray, np.ndarray]:
178 | centroid = self._compute_centroid(mask=mask)
179 | # bounds for boolean indexing
180 | lower_bound, upper_bound = self._get_bounds(
181 | centroid=centroid, input_dims=img.shape[1:]
182 | )
183 | img = img[
184 | :,
185 | lower_bound[0] : upper_bound[0],
186 | lower_bound[1] : upper_bound[1],
187 | lower_bound[2] : upper_bound[2],
188 | ]
189 | return img, mask
190 |
191 | def _get_bounds(
192 | self, centroid: np.ndarray, input_dims: Tuple[int, ...]
193 | ) -> Tuple[np.ndarray, np.ndarray]:
194 | lower = (centroid - (np.array(self.size) / 2)).astype(int)
195 | upper = (centroid + (np.array(self.size) / 2)).astype(int)
196 | return np.clip(lower, 0, input_dims), np.clip(upper, 0, input_dims)
197 |
198 | @staticmethod
199 | def _compute_centroid(mask: np.ndarray) -> np.ndarray:
200 | return np.mean(np.argwhere(mask), axis=0).astype(int)
201 |
202 | def _load_nifti_modalities(self, base_path: str, patient: str) -> np.ndarray:
203 | if len(self.modalities) == 1:
204 | img = nib.load(
205 |                 os.path.join(base_path, patient + "_" + self.modalities[0] + ".nii.gz")
206 | ).get_fdata()
207 | img = np.expand_dims(img, 0)
208 |
209 | else:
210 | early_fused = []
211 | for modality in self.modalities:
212 | early_fused.append(
213 | nib.load(
214 | os.path.join(base_path, patient + "_" + modality + ".nii.gz")
215 | ).get_fdata()
216 | )
217 |
218 | img = np.stack(early_fused)
219 |
220 | return img
221 |
222 | def _load_mask(self, base_path: str, patient: str) -> np.ndarray:
223 | return nib.load(os.path.join(base_path, patient + "_seg.nii.gz")).get_fdata()
224 |
225 | def process(self) -> torch.Tensor:
226 | patient = self.base_path.split("/")[-1]
227 | img = self._load_nifti_modalities(base_path=self.base_path, patient=patient)
228 | if self.tumor_centered:
229 | mask = self._load_mask(base_path=self.base_path, patient=patient)
230 | img, _ = self._compute_subvolumes(img=img, mask=mask)
231 |
232 | img = self.transform(img)
233 |
234 | return img
235 |
236 |
237 | class MRIEmbeddingDataset(_BaseDataset):
238 | def __init__(self, dataframe, return_mask: bool = False):
239 | super().__init__(dataframe, return_mask)
240 |
241 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]:
242 | sample = self.dataframe.iloc[idx]
243 | if not pd.isna(sample.MRI):
244 | mri = torch.from_numpy(
245 | pd.read_csv(os.path.join(sample.MRI, "embedding.csv")).values[0]
246 | ).float()
247 | mask = True
248 | else:
249 | mri = torch.zeros(512)
250 | mask = False
251 |
252 | if self.return_mask:
253 | return mri, mask
254 | else:
255 | return mri
256 |
257 |
258 | class MRIDataset(_BaseDataset):
259 | def __init__(
260 | self,
261 | dataframe: pd.DataFrame,
262 | transform: "augmentations",
263 | tumor_centered: bool,
264 | modalities: List[str] = ["t1ce", "flair"],
265 | return_mask: bool = False,
266 | sizes: Tuple[int, ...] = (64, 64, 64),
267 | ) -> None:
268 | super().__init__(dataframe, return_mask)
269 | self.sizes = sizes
270 | self.modalities = modalities
271 | self.tumor_centered = tumor_centered
272 | self.transform = transform
273 |
274 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]:
275 | sample = self.dataframe.iloc[idx]
276 | if not pd.isna(sample.MRI):
277 | processor = MRIProcessor(
278 | sample.MRI,
279 | transform=self.transform,
280 | modalities=self.modalities,
281 | size=self.sizes,
282 | tumor_centered=self.tumor_centered,
283 | )
284 | mri = processor.process()
285 | mask = True
286 | else:
287 | mri = torch.zeros(
288 |                 len(self.modalities), self.sizes[0], self.sizes[1], self.sizes[2]
289 | ).float()
290 | mask = False
291 |
292 | if self.return_mask:
293 | return mri, mask
294 | else:
295 | return mri
296 |
--------------------------------------------------------------------------------
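A minimal sketch of `MRIEmbeddingDataset`, assuming a dataframe whose `MRI` column points to a folder containing a pre-computed `embedding.csv` (a missing entry falls back to a zero vector and a `False` mask):

```
import pandas as pd
from drim.mri import MRIEmbeddingDataset

# hypothetical dataframe; NaN/None in the MRI column marks a missing modality
df = pd.DataFrame({"MRI": ["path/to/patient_dir", None]})
dataset = MRIEmbeddingDataset(df, return_mask=True)
mri, present = dataset[1]   # missing sample: mri is torch.zeros(512), present is False
```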
/drim/mri/models.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from typing import Tuple, Union
3 |
4 | # Third-party libraries
5 | import torch
6 | import torch.nn as nn
7 | from monai.networks.nets import resnet10
8 |
9 |
10 | class MRIEmbeddingEncoder(nn.Module):
11 | def __init__(self, embedding_dim: int, dropout: float):
12 | super(MRIEmbeddingEncoder, self).__init__()
13 | self.encoder = nn.Sequential(
14 | nn.Linear(512, 256),
15 | nn.Dropout(dropout),
16 | nn.GELU(),
17 | nn.Linear(256, embedding_dim),
18 | )
19 |
20 | def forward(self, x: torch.tensor) -> torch.tensor:
21 | return self.encoder(x)
22 |
23 |
24 | class MRIEncoder(nn.Module):
25 | def __init__(
26 | self, in_channels: int, embedding_dim: int = 512, projection_head: bool = True
27 | ) -> None:
28 | super().__init__()
29 | if projection_head:
30 | self.projection_head = nn.Sequential(
31 | nn.Linear(embedding_dim, embedding_dim),
32 | nn.ReLU(inplace=True),
33 | nn.Linear(embedding_dim, int(embedding_dim / 2)),
34 | )
35 |
36 | self.encoder = resnet10(
37 | spatial_dims=3, n_input_channels=in_channels, num_classes=1
38 | )
39 | if embedding_dim != 512:
40 | self.encoder.fc = nn.Sequential(
41 | nn.Linear(512, 256), nn.LeakyReLU(), nn.Linear(256, embedding_dim)
42 | )
43 | else:
44 | self.encoder.fc = nn.Identity()
45 |
46 | def forward(
47 | self, x: torch.tensor
48 | ) -> Union[torch.tensor, Tuple[torch.tensor, torch.tensor]]:
49 | x = self.encoder(x)
50 | if hasattr(self, "projection_head"):
51 | return self.projection_head(x), x
52 | else:
53 | return x
54 |
55 |
56 | class MRIEmbeddingDecoder(nn.Module):
57 | def __init__(self, embedding_dim: int, dropout: float):
58 | super(MRIEmbeddingDecoder, self).__init__()
59 | self.decoder = nn.Sequential(
60 | nn.Linear(embedding_dim, 256),
61 | nn.Dropout(dropout),
62 | nn.GELU(),
63 | nn.Linear(256, 512),
64 | )
65 |
66 | def forward(self, x: torch.tensor) -> torch.tensor:
67 | return self.decoder(x)
68 |
--------------------------------------------------------------------------------
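A quick shape check for the two encoders above: `MRIEncoder` consumes raw 3-D volumes and, with the projection head enabled, returns both the projection and the backbone embedding, while `MRIEmbeddingEncoder` consumes pre-extracted 512-d embeddings. A minimal sketch with illustrative sizes:

```
import torch
from drim.mri.models import MRIEncoder, MRIEmbeddingEncoder

encoder_3d = MRIEncoder(in_channels=2, embedding_dim=512, projection_head=True)
projection, embedding = encoder_3d(torch.randn(2, 2, 64, 64, 64))  # (2, 256) and (2, 512)

encoder = MRIEmbeddingEncoder(embedding_dim=128, dropout=0.1)
z = encoder(torch.randn(2, 512))                                   # (2, 128)
```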
/drim/mri/transforms.py:
--------------------------------------------------------------------------------
1 | from monai.transforms import (
2 | Compose,
3 | NormalizeIntensityd,
4 | SpatialPadd,
5 | CopyItemsd,
6 | OneOf,
7 | RandCoarseDropoutd,
8 | RandCoarseShuffled,
9 | ToTensord,
10 | NormalizeIntensity,
11 | SpatialPad,
12 | )
13 |
14 |
15 | def get_tumor_transforms(roi_size):
16 | return Compose(
17 | [
18 | NormalizeIntensityd(keys=["image"], channel_wise=True, nonzero=True),
19 | SpatialPadd(keys=["image"], spatial_size=roi_size),
20 | CopyItemsd(
21 | keys=["image"], times=1, names=["image_2"], allow_missing_keys=False
22 | ),
23 | OneOf(
24 | transforms=[
25 | RandCoarseDropoutd(
26 | keys=["image"],
27 | prob=1.0,
28 | holes=8,
29 | spatial_size=10,
30 | dropout_holes=True,
31 | max_spatial_size=18,
32 | ),
33 | RandCoarseDropoutd(
34 | keys=["image"],
35 | prob=1.0,
36 | holes=6,
37 | spatial_size=18,
38 | dropout_holes=False,
39 | max_spatial_size=24,
40 | ),
41 | ]
42 | ),
43 | RandCoarseShuffled(
44 | keys=["image"], prob=0.8, holes=4, spatial_size=4, max_spatial_size=4
45 | ),
46 | OneOf(
47 | transforms=[
48 | RandCoarseDropoutd(
49 | keys=["image_2"],
50 | prob=1.0,
51 | holes=8,
52 | spatial_size=10,
53 | dropout_holes=True,
54 | max_spatial_size=18,
55 | ),
56 | RandCoarseDropoutd(
57 | keys=["image_2"],
58 | prob=1.0,
59 | holes=6,
60 | spatial_size=18,
61 | dropout_holes=False,
62 | max_spatial_size=24,
63 | ),
64 | ]
65 | ),
66 | RandCoarseShuffled(
67 | keys=["image_2"], prob=0.8, holes=10, spatial_size=8, max_spatial_size=8
68 | ),
69 | ToTensord(keys=["image", "image_2"]),
70 | ]
71 | )
72 |
73 |
74 | tumor_transfo = Compose(
75 | [
76 | NormalizeIntensity(nonzero=True, channel_wise=True),
77 | SpatialPad(spatial_size=(64, 64, 64)),
78 | ]
79 | )
80 |
--------------------------------------------------------------------------------
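A minimal sketch of the tumor-centered self-supervision transforms, assuming a 2-channel sub-volume like those produced by the tumor-centered datasets above (the array content is random, for illustration only):

```
import numpy as np
from drim.mri.transforms import get_tumor_transforms

transforms = get_tumor_transforms(roi_size=(64, 64, 64))
item = {"image": np.random.rand(2, 60, 64, 64).astype("float32")}  # padded up to 64^3 by SpatialPadd
out = transforms(item)  # out["image"] and out["image_2"] are two differently corrupted views
```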
/drim/multimodal/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import MultimodalDataset
2 | from .models import *
3 |
--------------------------------------------------------------------------------
/drim/multimodal/datasets.py:
--------------------------------------------------------------------------------
1 | from ..datasets import _BaseDataset
2 | from typing import Dict, List, Tuple, Union
3 | import torch
4 |
5 |
6 | class MultimodalDataset(_BaseDataset):
7 | def __init__(
8 | self, datasets: Dict[str, _BaseDataset], return_mask: bool = True
9 | ) -> None:
10 | modalities = list(datasets.keys())
11 | self._modality_sanity_check(modalities)
12 | if return_mask:
13 | for modality in modalities:
14 | assert datasets[
15 | modality
16 |                 ].return_mask, f"The dataset for modality {modality} does not return a mask; build it with return_mask=True or pass return_mask=False here"
17 | self.return_mask = return_mask
18 | self.datasets = datasets
19 |
20 | def __len__(self):
21 |         return len(next(iter(self.datasets.values())))
22 |
23 | def __getitem__(
24 | self, idx: int
25 | ) -> Union[
26 | Tuple[Dict[str, torch.Tensor], Dict[str, bool]], Dict[str, torch.Tensor]
27 | ]:
28 | x = {}
29 | if self.return_mask:
30 | mask = {}
31 |
32 | for modality in self.datasets.keys():
33 | out = self.datasets[modality][idx]
34 | if self.return_mask:
35 | x[modality], mask[modality] = out[0], out[1]
36 | else:
37 | x[modality] = out
38 |
39 | if self.return_mask:
40 | return x, mask
41 | else:
42 | return x
43 |
44 | @staticmethod
45 | def _modality_sanity_check(
46 | modalities,
47 | available_modalities: List[str] = ["RNA", "MRI", "Clinical", "DNAm", "WSI"],
48 | ) -> None:
49 | for modality in modalities:
50 | assert (
51 | modality in available_modalities
52 | ), f"The requested modality: {modality} is not available, please choose modalities among {available_modalities}"
53 |
--------------------------------------------------------------------------------
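A minimal sketch of combining per-modality datasets into a `MultimodalDataset`; every underlying dataset has to be built with `return_mask=True` so the missing-modality masks propagate. The dataframe layout is a hypothetical single-row example where both modalities are missing:

```
import pandas as pd
from drim.multimodal import MultimodalDataset
from drim.mri import MRIEmbeddingDataset
from drim.rna import RNADataset

# hypothetical dataframe with one path column per modality (None = modality missing)
df = pd.DataFrame({"MRI": [None], "RNA": [None]})
dataset = MultimodalDataset(
    {"MRI": MRIEmbeddingDataset(df, return_mask=True),
     "RNA": RNADataset(df, preprocessor=None, return_mask=True)},  # preprocessor is unused for missing rows
    return_mask=True,
)
x, mask = dataset[0]  # x["MRI"]: zeros(512), x["RNA"]: zeros(1, 16304), both masks are False
```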
/drim/multimodal/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Dict, List, Optional, Union, Tuple
4 |
5 | __all__ = ["MultimodalModel", "MultimodalEncoder", "DRIMSurv", "DRIMU"]
6 |
7 |
8 | class MultimodalEncoder(nn.Module):
9 | def __init__(
10 | self, encoders: Dict[str, nn.Module], devices: Optional[Dict[str, str]] = None
11 | ) -> None:
12 | super().__init__()
13 | modalities = list(encoders.keys())
14 | self._modality_sanity_check(modalities)
15 | if not devices:
16 | devices = {k: next(v.parameters()).device for k, v in encoders.items()}
17 |
18 | self.devices = devices
19 | for key, model in encoders.items():
20 | setattr(self, f"{key.lower()}_encoder", model.to(self.devices[key]))
21 |
22 | def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
23 | return self._forward_encoders(x)
24 |
25 | def _forward_encoders(
26 | self, x: Dict[str, torch.Tensor], unique: Optional[bool] = None
27 | ) -> Dict[str, torch.Tensor]:
28 | # put the input tensors to the corresponding device
29 | x = {k: v.to(self.devices[k]) for k, v in x.items()}
30 | # forward pass
31 | if unique:
32 | x = {k: getattr(self, f"{k.lower()}_encoder_u")(v) for k, v in x.items()}
33 | else:
34 | x = {k: getattr(self, f"{k.lower()}_encoder")(v) for k, v in x.items()}
35 |
36 | # normalize embeddings
37 | x = {k: nn.functional.normalize(v, p=2, dim=-1) for k, v in x.items()}
38 |
39 | return x
40 |
41 | def _put_on_devices(
42 | self,
43 | x: Dict[str, torch.Tensor],
44 | mask: Optional[Dict[str, torch.Tensor]],
45 | key: str,
46 | ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]]]:
47 | x = {k: v.to(self.devices[key]) for k, v in x.items()}
48 | if mask is not None:
49 | mask = {k: v.to(self.devices[key]) for k, v in mask.items()}
50 | return x, mask
51 |
52 | @staticmethod
53 | def _modality_sanity_check(
54 | modalities: List[str],
55 | available_modalities: List[str] = ["RNA", "MRI", "DNAm", "WSI"],
56 | ) -> None:
57 | for modality in modalities:
58 | assert (
59 | modality in available_modalities
60 | ), f"The requested modality: {modality} is not available, please choose modalities among {available_modalities}"
61 |
62 |
63 | class MultimodalModel(MultimodalEncoder):
64 | def __init__(
65 | self,
66 | encoders: Dict[str, nn.Module],
67 | fusion: nn.Module,
68 | devices: Optional[Dict[str, str]] = None,
69 | ) -> None:
70 | super().__init__(encoders, devices)
71 | if "fusion" not in self.devices.keys():
72 | try:
73 | self.devices["fusion"] = next(fusion.parameters()).device
74 | except:
75 |                 # if the fusion has no parameters, it is a parameter-free function
76 | self.devices["fusion"] = "cuda:0"
77 | self.fusion = fusion
78 |
79 | def forward(
80 | self,
81 | x: Dict[str, torch.Tensor],
82 | mask: Optional[Dict[str, torch.Tensor]] = None,
83 | return_embedding: bool = False,
84 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]:
85 | x = super().forward(x)
86 |
87 | x, mask = self._put_on_devices(x, mask, "fusion")
88 | fused = self.fusion(x, mask)
89 | # if fusion returns a tuple, then it is a masked attention fusion
90 | if isinstance(fused, tuple):
91 | fused, _ = fused
92 | if return_embedding:
93 | return fused, x
94 | else:
95 | return fused
96 |
97 |
98 | class DRIMSurv(MultimodalEncoder):
99 | def __init__(
100 | self,
101 | encoders_sh: Dict[str, nn.Module],
102 | encoders_u: Dict[str, nn.Module],
103 | fusion_s: nn.Module,
104 | fusion_u: nn.Module,
105 | devices: Optional[Dict[str, str]] = None,
106 | ) -> None:
107 | super().__init__(encoders_sh, devices)
108 | for key, model in encoders_u.items():
109 | setattr(self, f"{key.lower()}_encoder_u", model.to(f"{self.devices[key]}"))
110 |
111 |         if "fusion_s" not in self.devices.keys():
112 | try:
113 | self.devices["fusion_s"] = next(fusion_s.parameters()).device
114 | except:
115 |                 # if the fusion has no parameters, it is a parameter-free function
116 | self.devices["fusion_s"] = "cuda:0"
117 |
118 | if "fusion_u" not in self.devices.keys():
119 | try:
120 | self.devices["fusion_u"] = next(fusion_u.parameters()).device
121 | except:
122 |                 # if the fusion has no parameters, it is a parameter-free function
123 | self.devices["fusion_u"] = "cuda:0"
124 | self.fusion_s = fusion_s
125 | self.fusion_u = fusion_u
126 |
127 | def forward(
128 | self,
129 | x: Dict[str, torch.Tensor],
130 | mask: Optional[Dict[str, torch.Tensor]] = None,
131 | return_embedding: bool = False,
132 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]:
133 | x_s = super()._forward_encoders(x)
134 | x_u = super()._forward_encoders(x, unique=True)
135 |
136 | fused = self.drim_forward(x_s, x_u, mask)
137 | if isinstance(fused, tuple):
138 | fused, _ = fused
139 | if return_embedding:
140 | return fused, x_s, x_u
141 | else:
142 | return fused
143 |
144 | def drim_forward(
145 | self,
146 | x_s: Dict[str, torch.Tensor],
147 | x_u: Dict[str, torch.Tensor],
148 | mask: Optional[Dict[str, torch.Tensor]] = None,
149 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]:
150 | x_s, mask = self._put_on_devices(x_s, mask, "fusion_s")
151 | fused_s = self.fusion_s(x_s, mask)
152 | if isinstance(fused_s, tuple):
153 | fused_s, _ = fused_s
154 | x_t = {**x_u, "shared": fused_s}
155 | mask.update({"shared": torch.ones_like(next(iter(mask.values())))})
156 | # put everything on the same device
157 | x_t, mask = self._put_on_devices(x_t, mask, "fusion_u")
158 | fused = self.fusion_u(x_t, mask)
159 | return fused
160 |
161 |
162 | class DRIMU(DRIMSurv):
163 | def forward(
164 | self,
165 | x: Dict[str, torch.Tensor],
166 | mask: Optional[Dict[str, torch.Tensor]] = None,
167 | return_embedding: bool = False,
168 | ) -> Union[torch.Tensor, Tuple[torch.tensor, Dict[str, torch.Tensor]]]:
169 | x_s = super()._forward_encoders(x)
170 | x_u = super()._forward_encoders(x, unique=True)
171 | if return_embedding:
172 | return x_s, x_u
173 | else:
174 | fused = self.drim_forward(x_s, x_u, mask)
175 | if isinstance(fused, tuple):
176 | fused, _ = fused
177 |
178 | return fused
179 |
--------------------------------------------------------------------------------
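A hedged sketch of assembling the DRIM-Surv architecture from pieces defined elsewhere in the repository: one shared and one unique encoder per modality, a `MaskedAttentionFusion` over the shared embeddings, and a second fusion over the unique embeddings plus the fused shared token. The encoder choices and dimensions are illustrative, not the paper's exact configuration:

```
from drim.multimodal import DRIMSurv
from drim.fusion.fusion import MaskedAttentionFusion
from drim.models import MultimodalWrapper
from drim.dnam import DNAmEncoder
from drim.rna import RNAEncoder

dim = 64
encoders_sh = {"DNAm": DNAmEncoder(embedding_dim=dim, dropout=0.1),
               "RNA": RNAEncoder(embedding_dim=dim, dropout=0.1)}
encoders_u = {"DNAm": DNAmEncoder(embedding_dim=dim, dropout=0.1),
              "RNA": RNAEncoder(embedding_dim=dim, dropout=0.1)}
fusion_s = MaskedAttentionFusion(dim=dim, depth=1, heads=4, dim_head=16, mlp_dim=128)
fusion_u = MaskedAttentionFusion(dim=dim, depth=1, heads=4, dim_head=16, mlp_dim=128)

model = DRIMSurv(encoders_sh, encoders_u, fusion_s, fusion_u)
# link the fused embedding to the discretised survival task (20 hypothetical time bins)
model = MultimodalWrapper(model, embedding_dim=dim, n_outs=20, device="cpu")
```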
/drim/rna/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import RNADataset
2 | from .models import RNADecoder, RNAEncoder, RNAEncoderMlp
3 |
--------------------------------------------------------------------------------
/drim/rna/datasets.py:
--------------------------------------------------------------------------------
1 | from ..datasets import _BaseDataset
2 | import pandas as pd
3 | import torch
4 | from typing import Union, Tuple
5 | import numpy as np
6 |
7 |
8 | def log_transform(x):
9 | return np.log(x + 1)
10 |
11 |
12 | class RNADataset(_BaseDataset):
13 | """
14 | Simple dataset for RNA data.
15 | """
16 |
17 | def __init__(
18 | self,
19 | dataframe: pd.DataFrame,
20 | preprocessor: "sklearn.pipeline.Pipeline",
21 | return_mask: bool = False,
22 | ) -> None:
23 | super().__init__(dataframe, return_mask)
24 | self.preprocessor = preprocessor
25 |
26 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]:
27 | sample = self.dataframe.iloc[idx]
28 | if not pd.isna(sample.RNA):
29 | out = torch.from_numpy(
30 | self.preprocessor.transform(
31 | pd.read_csv(sample.RNA)["fpkm_uq_unstranded"].values.reshape(1, -1)
32 | )
33 | ).float()
34 | mask = True
35 | else:
36 | out = torch.zeros(1, 16304).float()
37 | mask = False
38 |
39 | if self.return_mask:
40 | return out, mask
41 | else:
42 | return out
43 |
--------------------------------------------------------------------------------
/drim/rna/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from einops.layers.torch import Rearrange
4 |
5 |
6 | class RNAEncoder(nn.Module):
7 | """
8 | A vanilla encoder based on 1-d convolution for RNA data.
9 | """
10 |
11 | def __init__(self, embedding_dim: int, dropout: float) -> None:
12 | super().__init__()
13 | self.encoder = nn.Sequential(
14 | nn.Conv1d(1, 8, 9, 3),
15 | nn.GELU(),
16 | nn.BatchNorm1d(8),
17 | nn.Dropout(dropout),
18 | nn.Conv1d(8, 32, 9, 3),
19 | nn.GELU(),
20 | nn.BatchNorm1d(32),
21 | nn.Dropout(dropout),
22 | nn.Conv1d(32, 64, 9, 3),
23 | nn.GELU(),
24 | nn.BatchNorm1d(64),
25 | nn.Dropout(dropout),
26 | nn.Conv1d(64, 128, 9, 3),
27 | nn.GELU(),
28 | nn.BatchNorm1d(128),
29 | nn.Dropout(dropout),
30 | nn.Conv1d(128, 256, 9, 3),
31 | nn.GELU(),
32 | nn.BatchNorm1d(256),
33 | nn.Dropout(dropout),
34 | nn.Conv1d(256, embedding_dim, 9, 3),
35 | nn.GELU(),
36 | nn.BatchNorm1d(embedding_dim),
37 | nn.Dropout(dropout),
38 | nn.AdaptiveAvgPool1d(1),
39 | )
40 |
41 | def forward(self, x: torch.Tensor) -> torch.Tensor:
42 | x = self.encoder(x).squeeze(-1)
43 | return x
44 |
45 |
46 | class RNAEncoderMlp(nn.Module):
47 | def __init__(self, embedding_dim: int, dropout: float) -> None:
48 | super().__init__()
49 | self.encoder = nn.Sequential(
50 | nn.BatchNorm1d(18000),
51 | nn.Linear(18000, 512),
52 | nn.GELU(),
53 | nn.BatchNorm1d(512),
54 | nn.Dropout(dropout),
55 | nn.Linear(512, embedding_dim),
56 | )
57 |
58 | def forward(self, x: torch.Tensor) -> torch.Tensor:
59 | return self.encoder(x).squeeze(-1)
60 |
61 |
62 | class RNADecoder(nn.Module):
63 | """
64 | A vanilla decoder based on 1-d transposed convolution for RNA data.
65 | """
66 |
67 | def __init__(self, embedding_dim: int, dropout: float) -> None:
68 | super().__init__()
69 | self.decoder = nn.Sequential(
70 | nn.Linear(embedding_dim, embedding_dim * 20),
71 | Rearrange("b (n e) -> b n e", e=20),
72 | nn.GELU(),
73 | nn.BatchNorm1d(embedding_dim),
74 | nn.Dropout(dropout),
75 | nn.ConvTranspose1d(embedding_dim, 256, 9, 3),
76 | nn.GELU(),
77 | nn.BatchNorm1d(256),
78 | nn.Dropout(dropout),
79 | nn.ConvTranspose1d(256, 128, 9, 3, 2),
80 | nn.GELU(),
81 | nn.BatchNorm1d(128),
82 | nn.Dropout(dropout),
83 | nn.ConvTranspose1d(128, 64, 9, 3, 1),
84 | nn.GELU(),
85 | nn.BatchNorm1d(64),
86 | nn.Dropout(dropout),
87 | nn.ConvTranspose1d(64, 32, 9, 3, 4),
88 | nn.GELU(),
89 | nn.BatchNorm1d(32),
90 | nn.Dropout(dropout),
91 | nn.ConvTranspose1d(32, 8, 9, 3, 1),
92 | nn.GELU(),
93 | nn.BatchNorm1d(8),
94 | nn.Dropout(dropout),
95 | nn.ConvTranspose1d(8, 1, 9, 3, 2),
96 | )
97 |
98 | def forward(self, x: torch.Tensor) -> torch.Tensor:
99 | x = self.decoder(x)
100 | return x
101 |
--------------------------------------------------------------------------------
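A quick shape check for the convolutional RNA encoder: `RNADataset` yields tensors of shape `(1, 16304)`, so a batch has shape `(batch, 1, 16304)` and the adaptive pooling reduces it to one embedding per sample. A minimal sketch:

```
import torch
from drim.rna import RNAEncoder

encoder = RNAEncoder(embedding_dim=128, dropout=0.1).eval()  # eval: freeze BatchNorm running stats
with torch.no_grad():
    z = encoder(torch.randn(2, 1, 16304))                    # -> (2, 128)
```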
/drim/trainers.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from typing import Dict, Optional
3 | from torchinfo import summary
4 | from collections import defaultdict
5 | import numpy as np
6 | from pycox.evaluation import EvalSurv
7 | import wandb
8 | from .utils import interpolate_dataframe
9 | import torch
10 | from omegaconf import DictConfig, OmegaConf
11 | from .logger import logger
12 | import pandas as pd
13 | import tqdm
14 | from .losses import DiscriminatorLoss
15 |
16 |
17 | class _BaseTrainer:
18 | """
19 | Base trainer for all experiments
20 | """
21 |
22 | def __init__(
23 | self,
24 | model: nn.Module,
25 | optimizer: torch.optim.Optimizer,
26 | scheduler: torch.optim.lr_scheduler._LRScheduler,
27 | task_criterion: nn.modules.loss._Loss,
28 | dataloaders: Dict[str, torch.utils.data.DataLoader],
29 | cfg: DictConfig,
30 | wandb_logging: bool,
31 | ):
32 | self.model = model
33 | self.optimizer = optimizer
34 | self.scheduler = scheduler
35 | self.task_criterion = task_criterion
36 | self.dataloaders = dataloaders
37 | self.cfg = cfg
38 | self.wandb_logging = wandb_logging
39 |
40 | def _print_summary(self) -> None:
41 | """
42 | Print the model summary and the number of trainable parameters
43 | """
44 | logger.info(
45 | "Trainable parameters {}",
46 | sum(p.numel() for p in self.model.parameters() if p.requires_grad),
47 | )
48 | logger.info(self.model)
49 |
50 | def _initialize_logger(self) -> None:
51 | """
52 | Initialize the WandB logger
53 | """
54 | if "wandb" in self.cfg:
55 | self.wandb_logging = True
56 | if not isinstance(self.cfg.general.modalities, str):
57 | name = "_".join(self.cfg.general.modalities)
58 | else:
59 | name = self.cfg.general.modalities
60 | wandb.init(
61 | name=self.prefix + name,
62 | config={
63 | k: v
64 | for k, v in OmegaConf.to_container(self.cfg).items()
65 | if k != "wandb"
66 | },
67 | **self.cfg.wandb,
68 | )
69 | else:
70 | self.wandb_logging = False
71 |
72 | def put_model_on_correct_mode(self, split: str) -> None:
73 | if split == "train":
74 | self.model.train()
75 | torch.set_grad_enabled(True)
76 | else:
77 | self.model.eval()
78 | torch.set_grad_enabled(False)
79 |
80 | def clear(self) -> None:
81 | if self.wandb_logging:
82 | wandb.finish()
83 |
84 | return
85 |
86 | def fit(self, epochs: Optional[int] = None) -> None:
87 | """
88 | Train the self.model for a given number of epochs
89 | """
90 | # self._initialize_logger()
91 | # self.wandb_logging = False
92 | self._print_summary()
93 | metrics = defaultdict(list)
94 | if not epochs:
95 | epochs = self.cfg.general.epochs
96 | for epoch in range(epochs):
97 | logger.info(f"Epoch {epoch + 1}/{epochs}")
98 | logger.info("-" * 10)
99 | metrics["lr"].append(self.optimizer.state_dict()["param_groups"][0]["lr"])
100 |
101 | train_losses = self.shared_loop(split="train")
102 | for key, value in train_losses.items():
103 | logger.info(f"Train {key} {value:.4f}")
104 | metrics[f"train/{key}"].append(value)
105 |
106 | val_metrics = self.shared_loop(split="val")
107 | for key, value in val_metrics.items():
108 | logger.info(f"Val {key} {value:.4f}")
109 | metrics[f"val/{key}"].append(value)
110 |
111 | if self.wandb_logging:
112 | wandb.log({key: value[-1] for key, value in metrics.items()})
113 |
114 | self.scheduler.step()
115 | torch.save(self.model.state_dict(), self.cfg.general.save_path)
116 |
117 | def evaluate(self, split: str) -> Dict[str, float]:
118 | outputs = self.shared_loop(split)
119 | for key, value in outputs.items():
120 | logger.info(f"{split} {key} {value:.4f}")
121 | if self.wandb_logging:
122 | # update outputs
123 | outputs = {f"{split}/{key}": value for key, value in outputs.items()}
124 | wandb.log(outputs)
125 |
126 | return outputs
127 |
128 |
129 | class BaseSurvivalTrainer(_BaseTrainer):
130 | def __init__(
131 | self,
132 | model: nn.Module,
133 | optimizer: torch.optim.Optimizer,
134 | scheduler: torch.optim.lr_scheduler._LRScheduler,
135 | task_criterion: nn.modules.loss._Loss,
136 | dataloaders: Dict[str, torch.utils.data.DataLoader],
137 | cfg: DictConfig,
138 | wandb_logging: bool,
139 | cuts: np.ndarray,
140 | ):
141 | super().__init__(
142 | model, optimizer, scheduler, task_criterion, dataloaders, cfg, wandb_logging
143 | )
144 | self.cuts = cuts
145 |
146 | def compute_survival_metrics(self, outputs, time, event):
147 | """
148 | Compute the survival metrics with the PyCox package.
149 | """
150 | hazard = torch.cat(outputs, dim=0)
151 | survival = (1 - hazard.sigmoid()).add(1e-7).log().cumsum(1).exp().cpu().numpy()
152 | survival = interpolate_dataframe(pd.DataFrame(survival.transpose(), self.cuts))
153 | evaluator = EvalSurv(
154 | survival, time.cpu().numpy(), event.cpu().numpy(), censor_surv="km"
155 | )
156 | c_index = evaluator.concordance_td()
157 | ibs = evaluator.integrated_brier_score(np.linspace(0, time.cpu().numpy().max()))
158 | inbll = evaluator.integrated_nbll(np.linspace(0, time.cpu().numpy().max()))
159 | cs_score = (c_index + (1 - ibs)) / 2
160 | return {"c_index": c_index, "ibs": ibs, "inbll": inbll, "cs_score": cs_score}
161 |
162 | def shared_loop(self, split: str) -> Dict[str, float]:
163 | self.put_model_on_correct_mode(split)
164 | total_task_loss = 0.0
165 | raw_predictions = []
166 | times = []
167 | events = []
168 | for batch in tqdm.tqdm(self.dataloaders[split]):
169 | data, time, event = batch
170 | # check if data contains mask
171 | if isinstance(data, list):
172 | data, mask = data
173 | outputs = self.model(data, mask, return_embedding=False)
174 | else:
175 | outputs = self.model(data, return_embedding=False)
176 |
177 | loss = self.task_criterion(
178 | outputs, time.to(outputs.device), event.to(outputs.device)
179 | )
180 | if split == "train":
181 | self.optimizer.zero_grad()
182 | loss.backward()
183 | self.optimizer.step()
184 | total_task_loss += loss.item() * time.size(0)
185 | raw_predictions.append(outputs)
186 | times.append(time)
187 | events.append(event)
188 |
189 | outputs = {"task_loss": total_task_loss / len(self.dataloaders[split].dataset)}
190 | if split != "train":
191 | task_metrics = self.compute_survival_metrics(
192 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0)
193 | )
194 | outputs.update(task_metrics)
195 |
196 | return outputs
197 |
198 |
199 | class AuxSurvivalTrainer(BaseSurvivalTrainer):
200 | """
201 |     This class assumes that an auxiliary loss is added to the survival loss. This auxiliary loss can be either the MMO loss or the contrastive (CL) loss.
202 | """
203 |
204 | def __init__(
205 | self,
206 | model: nn.Module,
207 | optimizer: torch.optim.Optimizer,
208 | scheduler: torch.optim.lr_scheduler._LRScheduler,
209 | task_criterion: nn.modules.loss._Loss,
210 | dataloaders: Dict[str, torch.utils.data.DataLoader],
211 | cfg: DictConfig,
212 | wandb_logging: bool,
213 | cuts: np.ndarray,
214 | aux_loss: torch.nn.modules.loss._Loss,
215 | ) -> None:
216 | super().__init__(
217 | model,
218 | optimizer,
219 | scheduler,
220 | task_criterion,
221 | dataloaders,
222 | cfg,
223 | wandb_logging,
224 | cuts,
225 | )
226 | self.aux_loss = aux_loss
227 |
228 | def shared_loop(self, split: str) -> Dict[str, float]:
229 | self.put_model_on_correct_mode(split)
230 | raw_predictions = []
231 | embeddings = defaultdict(list)
232 | masks = defaultdict(list)
233 | total_task_loss = 0.0
234 | total_aux_loss = 0.0
235 | times = []
236 | events = []
237 | for batch in tqdm.tqdm(self.dataloaders[split]):
238 | data, time, event = batch
239 | # check if data contains mask
240 | if isinstance(data, list):
241 | data, mask = data
242 | outputs, batch_embeddings = self.model(
243 | data, mask, return_embedding=True
244 | )
245 | else:
246 | outputs, batch_embeddings = self.model(data, return_embedding=True)
247 |
248 | task_loss = self.task_criterion(
249 | outputs, time.to(outputs.device), event.to(outputs.device)
250 | )
251 | # put mask on the same device as embeddings
252 | mask = {k: v.to(batch_embeddings[k].device) for k, v in mask.items()}
253 | aux_loss = self.aux_loss(batch_embeddings, mask)
254 | loss = task_loss + self.cfg.aux_loss.alpha * aux_loss
255 | if split == "train":
256 | self.optimizer.zero_grad()
257 | loss.backward()
258 | self.optimizer.step()
259 |
260 | total_task_loss += task_loss.item() * time.size(0)
261 | total_aux_loss += aux_loss.item() * time.size(0)
262 | raw_predictions.append(outputs)
263 | times.append(time)
264 | events.append(event)
265 |
266 | outputs = {
267 | "task_loss": total_task_loss / len(self.dataloaders[split].dataset),
268 | "aux_loss": total_aux_loss / len(self.dataloaders[split].dataset),
269 | }
270 | if split != "train":
271 | task_metrics = self.compute_survival_metrics(
272 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0)
273 | )
274 | embeddings = {k: torch.cat(v) for k, v in embeddings.items()}
275 | masks = {k: torch.cat(v) for k, v in masks.items()}
276 | outputs.update(task_metrics)
277 |
278 | return outputs
279 |
280 |
281 | class DRIMSurvTrainer(AuxSurvivalTrainer):
282 | def __init__(
283 | self,
284 | model: nn.Module,
285 | discriminators: Dict[str, nn.Module],
286 | optimizer: torch.optim.Optimizer,
287 | optimizers_dsm: Dict[str, torch.optim.Optimizer],
288 | scheduler: torch.optim.lr_scheduler._LRScheduler,
289 | task_criterion: nn.modules.loss._Loss,
290 | dataloaders: Dict[str, torch.utils.data.DataLoader],
291 | cfg: DictConfig,
292 | wandb_logging: bool,
293 | cuts: np.ndarray,
294 | aux_loss: torch.nn.modules.loss._Loss,
295 | ) -> None:
296 | super().__init__(
297 | model,
298 | optimizer,
299 | scheduler,
300 | task_criterion,
301 | dataloaders,
302 | cfg,
303 | wandb_logging,
304 | cuts,
305 | aux_loss,
306 | )
307 | self.discriminators = discriminators
308 | self.optimizers_dsm = optimizers_dsm
309 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
310 | self.discriminator_loss = DiscriminatorLoss()
311 |
312 | def shared_loop(self, split: str) -> Dict[str, float]:
313 | self.put_model_on_correct_mode(split)
314 | raw_predictions = []
315 | embeddings = defaultdict(list)
316 | masks = defaultdict(list)
317 | total_task_loss = 0.0
318 | total_sh_loss = 0.0
319 | total_u_loss = 0.0
320 | total_discrimators_loss = {k: 0.0 for k in self.optimizers_dsm.keys()}
321 | times = []
322 | events = []
323 | for batch in tqdm.tqdm(self.dataloaders[split]):
324 | data, time, event = batch
325 | # check if data contains mask
326 | if isinstance(data, list):
327 | data, mask = data
328 | outputs, batch_embeddings_sh, batch_embeddings_u = self.model(
329 | data, mask, return_embedding=True
330 | )
331 | else:
332 | outputs, batch_embeddings_sh, batch_embeddings_u = self.model(
333 | data, return_embedding=True
334 | )
335 |
336 | # link to survival loss
337 | task_loss = self.task_criterion(
338 | outputs, time.to(outputs.device), event.to(outputs.device)
339 | )
340 | # compute shared loss
341 | sh_loss = self.aux_loss(
342 | {k: v.to(self.device) for k, v in batch_embeddings_sh.items()},
343 | {k: v.to(self.device) for k, v in mask.items()},
344 | )
345 | # put mask on cpu
346 | mask = {k: v.cpu() for k, v in mask.items()}
347 |             # prepare the discriminator inputs: x_r is supposed to be drawn from P(x, y) while x_r_prime is drawn from P(x)P(y)
348 | x_r = {
349 | k: torch.cat(
350 | [
351 | batch_embeddings_sh[k][mask[k]]
352 | .detach()
353 | .to(batch_embeddings_u[k].device),
354 | batch_embeddings_u[k][mask[k]],
355 | ],
356 | dim=1,
357 | )
358 | for k in batch_embeddings_u.keys()
359 | }
360 | x_r_prime = {
361 | k: torch.cat(
362 | [
363 | torch.roll(
364 | batch_embeddings_sh[k][mask[k]]
365 | .detach()
366 | .to(batch_embeddings_u[k].device),
367 | 1,
368 | 0,
369 | ),
370 | batch_embeddings_u[k][mask[k]],
371 | ],
372 | dim=1,
373 | )
374 | for k in batch_embeddings_sh.keys()
375 | }
376 | # compute unique loss
377 | u_loss = torch.tensor(0.0).to(self.device)
378 | for modality in x_r.keys():
379 |                 # skip modalities with fewer than two present samples
380 | if sum(mask[modality]) <= 1:
381 | continue
382 | # get fake logits
383 | fake_logits = self.discriminators[modality](x_r[modality])
384 | u_loss += torch.nn.functional.binary_cross_entropy_with_logits(
385 | input=fake_logits, target=torch.ones_like(fake_logits)
386 | ).to(self.device)
387 |
388 | # update shared and unique encoders according to their respective loss and the task loss
389 | loss_encoders = (
390 | task_loss
391 | + self.cfg.aux_loss.alpha * sh_loss
392 | + self.cfg.disentangled.gamma * u_loss
393 | )
394 | if split == "train":
395 | self.optimizer.zero_grad()
396 | loss_encoders.backward()
397 | self.optimizer.step()
398 |
399 | # update discriminators
400 | for modality in x_r.keys():
401 |                 # skip modalities with fewer than two present samples
402 | if sum(mask[modality]) <= 1:
403 | continue
404 |
405 | fake_logits = self.discriminators[modality](x_r[modality].detach())
406 | real_logits = self.discriminators[modality](
407 | x_r_prime[modality].detach()
408 | )
409 | loss_dsm = self.discriminator_loss(
410 | real_logits=real_logits, fake_logits=fake_logits
411 | )
412 |
413 | if split == "train":
414 | self.optimizers_dsm[modality].zero_grad()
415 | loss_dsm.backward()
416 | self.optimizers_dsm[modality].step()
417 |
418 | total_discrimators_loss[modality] += loss_dsm.item() * time.size(0)
419 |
420 | total_task_loss += task_loss.item() * time.size(0)
421 | total_sh_loss += sh_loss.item() * time.size(0)
422 | total_u_loss += u_loss.item() * time.size(0)
423 | raw_predictions.append(outputs)
424 | times.append(time)
425 | events.append(event)
426 |
427 | outputs = {
428 | "shared_loss": total_sh_loss / len(self.dataloaders[split].dataset),
429 | "unique_loss": total_u_loss / len(self.dataloaders[split].dataset),
430 | "task_loss": total_task_loss / len(self.dataloaders[split].dataset),
431 | }
432 | outputs.update(
433 | {
434 | f"discriminator_{k}": v / len(self.dataloaders[split].dataset)
435 | for k, v in total_discriminators_loss.items()
436 | }
437 | )
438 |
439 | if split != "train":
440 | task_metrics = self.compute_survival_metrics(
441 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0)
442 | )
443 | embeddings = {k: torch.cat(v) for k, v in embeddings.items()}
444 | masks = {k: torch.cat(v) for k, v in masks.items()}
445 | outputs = {**outputs, **task_metrics}
446 |
447 | return outputs
448 |
449 |
450 | class DRIMUTrainer(DRIMSurvTrainer):
451 | def __init__(
452 | self,
453 | model: nn.Module,
454 | decoders: Dict[str, nn.Module],
455 | discriminators: Dict[str, nn.Module],
456 | optimizer: torch.optim.Optimizer,
457 | optimizers_dsm: Dict[str, torch.optim.Optimizer],
458 | scheduler: torch.optim.lr_scheduler._LRScheduler,
459 | task_criterion: nn.modules.loss._Loss,
460 | dataloaders: Dict[str, torch.utils.data.DataLoader],
461 | cfg: DictConfig,
462 | wandb_logging: bool,
463 | cuts: np.ndarray,
464 | aux_loss: torch.nn.modules.loss._Loss,
465 | ) -> None:
466 | super().__init__(
467 | model,
468 | discriminators,
469 | optimizer,
470 | optimizers_dsm,
471 | scheduler,
472 | task_criterion,
473 | dataloaders,
474 | cfg,
475 | wandb_logging,
476 | cuts,
477 | aux_loss,
478 | )
479 | self.decoders = decoders
480 |
481 | def shared_loop(self, split: str) -> Dict[str, float]:
482 | self.put_model_on_correct_mode(split)
483 | total_loss = 0.0
484 | embeddings = defaultdict(list)
485 | masks = defaultdict(list)
486 | total_sh_loss = 0.0
487 | total_u_loss = 0.0
488 | total_discriminators_loss = {k: 0.0 for k in self.optimizers_dsm.keys()}
489 | total_decoders_loss = {k: 0.0 for k in self.decoders.keys()}
490 | for batch in tqdm.tqdm(self.dataloaders[split]):
491 | data, time, _ = batch
492 | # check if data contains mask
493 | if isinstance(data, list):
494 | data, mask = data
495 | batch_embeddings_sh, batch_embeddings_u = self.model(
496 | data, mask, return_embedding=True
497 | )
498 | else:
499 | batch_embeddings_sh, batch_embeddings_u = self.model(
500 | data, return_embedding=True
501 | )
502 |
503 | sh_loss = self.aux_loss(
504 | {k: v.to(self.device) for k, v in batch_embeddings_sh.items()},
505 | {k: v.to(self.device) for k, v in mask.items()},
506 | )
507 | # put mask on cpu
508 | mask = {k: v.cpu() for k, v in mask.items()}
509 | x_r = {
510 | k: torch.cat(
511 | [
512 | batch_embeddings_sh[k][mask[k]]
513 | .detach()
514 | .to(batch_embeddings_u[k].device),
515 | batch_embeddings_u[k][mask[k]],
516 | ],
517 | dim=1,
518 | )
519 | for k in batch_embeddings_u.keys()
520 | }
521 | x_r_prime = {
522 | k: torch.cat(
523 | [
524 | torch.roll(
525 | batch_embeddings_sh[k][mask[k]]
526 | .detach()
527 | .to(batch_embeddings_u[k].device),
528 | 1,
529 | 0,
530 | ),
531 | batch_embeddings_u[k][mask[k]],
532 | ],
533 | dim=1,
534 | )
535 | for k in batch_embeddings_sh.keys()
536 | }
537 | u_loss = torch.tensor(0.0).to(self.device)
538 | decoder_loss_iteration = torch.tensor(0.0).to(self.device)
539 | for modality in x_r.keys():
540 | if sum(mask[modality]) <= 1:
541 | continue
542 | fake_logits = self.discriminators[modality](x_r[modality])
543 | u_loss += torch.nn.functional.binary_cross_entropy_with_logits(
544 | input=fake_logits, target=torch.ones_like(fake_logits)
545 | ).to(self.device)
546 | decoder_outputs = self.decoders[modality](
547 | batch_embeddings_u[modality][mask[modality]]
548 | )
549 | decoder_loss = torch.nn.functional.mse_loss(
550 | decoder_outputs,
551 | data[modality][mask[modality]].to(decoder_outputs.device),
552 | )
553 | decoder_loss_iteration += decoder_loss.to(self.device)
554 | total_decoders_loss[modality] += decoder_loss.item() * time.size(0)
555 |
556 | loss_encoders = (
557 | self.cfg.aux_loss.alpha * sh_loss
558 | + self.cfg.disentangled.gamma * u_loss
559 | + decoder_loss_iteration
560 | )
561 | if split == "train":
562 | self.optimizer.zero_grad()
563 | loss_encoders.backward()
564 | self.optimizer.step()
565 | else:
566 | for modality in batch_embeddings_sh.keys():
567 | embeddings[modality].append(batch_embeddings_sh[modality].cpu())
568 | masks[modality].append(mask[modality].cpu())
569 |
570 | losses_dsm = {}
571 | for modality in x_r.keys():
572 | if sum(mask[modality]) <= 1:
573 | losses_dsm[modality] = 0.0
574 | continue
575 |
576 | fake_logits = self.discriminators[modality](x_r[modality].detach())
577 | real_logits = self.discriminators[modality](
578 | x_r_prime[modality].detach()
579 | )
580 | loss_dsm = self.discriminator_loss(
581 | real_logits=real_logits, fake_logits=fake_logits
582 | )
583 |
584 | if split == "train":
585 | self.optimizers_dsm[modality].zero_grad()
586 | loss_dsm.backward()
587 | self.optimizers_dsm[modality].step()
588 |
589 | total_discriminators_loss[modality] += loss_dsm.item() * time.size(0)
590 |
591 | total_sh_loss += sh_loss.item() * time.size(0)
592 | total_u_loss += u_loss.item() * time.size(0)
593 | total_loss += loss_encoders.item() * time.size(0)
594 | outputs = {
595 | "loss": total_loss / len(self.dataloaders[split].dataset),
596 | "sh_loss": total_sh_loss / len(self.dataloaders[split].dataset),
597 | "unique_loss": total_u_loss / len(self.dataloaders[split].dataset),
598 | }
599 | outputs.update(
600 | {
601 | f"discriminator_{k}": v / len(self.dataloaders[split].dataset)
602 | for k, v in total_discriminators_loss.items()
603 | }
604 | )
605 | outputs.update(
606 | {
607 | f"decoder_{k}": v / len(self.dataloaders[split].dataset)
608 | for k, v in total_decoders_loss.items()
609 | }
610 | )
611 |
612 | return outputs
613 |
614 | def finetune(self):
615 | self.shared_loop = self.shared_loop_finetune
616 | self.optimizer = torch.optim.AdamW(
617 | self.model.parameters(), **self.cfg.optimizer.params
618 | )
619 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
620 | self.optimizer, T_max=30, eta_min=5e-6
621 | )
622 | for name, param in self.model.named_parameters():
623 | if "_encoder" in name:
624 | param.requires_grad = False
625 | self.fit(epochs=self.cfg.general.epochs_finetune)
626 |
627 | def shared_loop_finetune(self, split: str) -> Dict[str, float]:
628 | self.put_model_on_correct_mode(split)
629 | total_loss = 0.0
630 | raw_predictions = []
631 | times = []
632 | events = []
633 | for batch in tqdm.tqdm(self.dataloaders[split]):
634 | data, time, event = batch
635 | # check if data contains mask
636 | if isinstance(data, list):
637 | data, mask = data
638 | outputs = self.model(data, mask, return_embedding=False)
639 | else:
640 | outputs = self.model(data, return_embedding=False)
641 |
642 | loss = self.task_criterion(
643 | outputs, time.to(outputs.device), event.to(outputs.device)
644 | )
645 | if split == "train":
646 | self.optimizer.zero_grad()
647 | loss.backward()
648 | self.optimizer.step()
649 | total_loss += loss.item() * time.size(0)
650 | raw_predictions.append(outputs)
651 | times.append(time)
652 | events.append(event)
653 |
654 | outputs = {"task_loss": total_loss / len(self.dataloaders[split].dataset)}
655 | if split != "train":
656 | task_metrics = self.compute_survival_metrics(
657 | raw_predictions, torch.cat(times, dim=0), torch.cat(events, dim=0)
658 | )
659 | outputs.update(task_metrics)
660 |
661 | return outputs
662 |
--------------------------------------------------------------------------------
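The pairing trick used by both trainers above is easy to lose in the surrounding bookkeeping. The minimal sketch below isolates it with toy tensors: `x_r` concatenates the shared and unique embeddings of the *same* samples (joint distribution), `x_r_prime` rolls the shared part by one position (product of marginals), and the two binary cross-entropy terms play the roles of the encoder-side `u_loss` and the discriminator update. The small `nn.Sequential` discriminator is a hypothetical stand-in for `drim.models.Discriminator`, and the GAN-style objective is only assumed to match what `DiscriminatorLoss` computes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch, dim = 8, 16                # toy sizes; the real code uses cfg.general.dim
z_sh = torch.randn(batch, dim)    # shared embedding of one modality
z_u = torch.randn(batch, dim)     # unique embedding of the same modality

# Joint pairs: shared and unique codes of the same sample.
x_r = torch.cat([z_sh.detach(), z_u], dim=1)
# Marginal pairs: shared codes rolled by one sample, breaking the pairing.
x_r_prime = torch.cat([torch.roll(z_sh.detach(), 1, 0), z_u], dim=1)

# Hypothetical stand-in for drim.models.Discriminator.
disc = nn.Sequential(nn.Linear(2 * dim, 64), nn.ReLU(), nn.Linear(64, 1))

# Encoder side (u_loss): make joint pairs indistinguishable from marginal ones,
# i.e. push the discriminator towards predicting "real" on x_r.
u_loss = F.binary_cross_entropy_with_logits(disc(x_r), torch.ones(batch, 1))

# Discriminator side: tell marginal ("real") pairs apart from joint ("fake") ones.
d_loss = 0.5 * (
    F.binary_cross_entropy_with_logits(disc(x_r_prime.detach()), torch.ones(batch, 1))
    + F.binary_cross_entropy_with_logits(disc(x_r.detach()), torch.zeros(batch, 1))
)
print(u_loss.item(), d_loss.item())
```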
/drim/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | from typing import List, Union, Dict, Any, Tuple
4 | import torch
5 | import random
6 |
7 |
8 | def seed_everything(seed: int) -> None:
9 | import monai
10 |
11 | monai.utils.set_determinism(seed=seed, additional_settings=None)
12 | np.random.seed(seed)
13 | random.seed(seed)
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed)
17 | torch.backends.cudnn.deterministic = True
18 | torch.backends.cudnn.benchmark = False
19 |
20 |
21 | def clean_state_dict(state_dict: dict) -> dict:
22 | new_state_dict = {}
23 | for key, value in state_dict.items():
24 | if key.startswith("module."):
25 | new_state_dict[key[7:]] = value
26 | else:
27 | new_state_dict[key] = value
28 | return new_state_dict
29 |
30 |
31 | def prepare_data(
32 | dataframe: pd.DataFrame, modalities: Union[List[str], str], min_k_modality: int = 2
33 | ) -> pd.DataFrame:
34 | if isinstance(modalities, str):
35 | modalities = [modalities]
36 |
37 | # columns = ['project_id', 'submitter_id']
38 | # dataframe = dataframe.drop(columns=columns)
39 |
40 | if len(modalities) == 1:
41 | try:
42 | mask = ~dataframe[modalities[0]].isna()
43 | dataframe = dataframe.loc[mask]
44 | except:
45 | pass
46 |
47 | else:
48 | temp_df = dataframe[modalities]
49 | dataframe = dataframe[temp_df.count(axis=1) >= min_k_modality]
50 |
51 | return dataframe
52 |
53 |
54 | get_target_survival = lambda df: (df["time"].values, df["event"].values)
55 |
56 |
57 | def log_transform(x: np.ndarray) -> np.ndarray:
58 | return np.log(1 + x)
59 |
60 |
61 | def create_nan_dataframe(
62 | num_row: int, num_col: int, name_col: List[str]
63 | ) -> pd.DataFrame:
64 | df = pd.DataFrame(np.zeros((num_row, num_col)), columns=name_col)
65 | df[:] = np.nan
66 | return df
67 |
68 |
69 | def interpolate_dataframe(dataframe: pd.DataFrame, n: int = 10) -> pd.DataFrame:
70 | dataframe.reset_index(inplace=True)
71 | dataframe_list = []
72 | for i, idx in enumerate(dataframe.index):
73 | df_temp = dataframe[dataframe.index == idx]
74 | dataframe_list.append(df_temp)
75 | if i != len(dataframe) - 1:
76 | dataframe_list.append(
77 | create_nan_dataframe(n, df_temp.shape[1], df_temp.columns)
78 | )
79 |
80 | dataframe = pd.concat(dataframe_list).interpolate("linear")
81 | dataframe = dataframe.set_index("index")
82 | return dataframe
83 |
84 |
85 | def seed_worker(worker_id):
86 | worker_seed = torch.initial_seed() % 2**32
87 | np.random.seed(worker_seed)
88 | random.seed(worker_seed)
89 |
90 |
91 | def get_dataframes(fold: int) -> Dict[str, pd.DataFrame]:
92 | dataframe = pd.read_csv("data/files/train_brain.csv")
93 | dataframe_train = (
94 | dataframe[dataframe["split"] != fold].copy().drop(columns=["split"])
95 | )
96 | dataframe_val = dataframe[dataframe["split"] == fold].copy().drop(columns=["split"])
97 | dataframe_test = pd.read_csv("data/files/test_brain.csv")
98 | return {"train": dataframe_train, "val": dataframe_val, "test": dataframe_test}
99 |
--------------------------------------------------------------------------------
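A quick, self-contained illustration of the two helpers above that are easiest to misread: `prepare_data` keeps patients with at least `min_k_modality` non-missing modality columns, and `interpolate_dataframe` densifies a step-wise curve by inserting `n` NaN rows between points and filling them linearly. Toy data only, assuming the repository root is on `PYTHONPATH`.

```python
import numpy as np
import pandas as pd
from drim.utils import prepare_data, interpolate_dataframe

# One row per patient; NaN marks a missing modality file.
cohort = pd.DataFrame(
    {
        "submitter_id": ["a", "b", "c"],
        "DNAm": ["a_dnam.csv", np.nan, "c_dnam.csv"],
        "RNA": [np.nan, np.nan, "c_rna.csv"],
    }
)
# Keep patients with at least two of the requested modalities: only "c" survives.
print(prepare_data(cohort, ["DNAm", "RNA"], min_k_modality=2))

# Densify a coarse survival curve before plotting / evaluation.
curve = pd.DataFrame({"S": [1.0, 0.6, 0.3]}, index=[0.0, 100.0, 200.0])
print(interpolate_dataframe(curve, n=4).head(6))
```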
/drim/wsi/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import WSIDataset
2 | from .models import WSIDecoder, WSIEncoder, ResNetWrapperSimCLR
3 |
--------------------------------------------------------------------------------
/drim/wsi/datasets.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple, Union
2 | from PIL import Image
3 | import torch
4 | from ..datasets import _BaseDataset
5 | import pandas as pd
6 |
7 |
8 | class PatchDataset(torch.utils.data.Dataset):
9 | def __init__(
10 | self, filepaths: Tuple[str, ...], transforms: Callable
11 | ) -> None:
12 | self.filepaths = filepaths
13 | self.transforms = transforms
14 |
15 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
16 | path = self.filepaths[idx]
17 | image = Image.open(path)
18 | image_1, image_2 = self.transforms(image)
19 | return image_1, image_2
20 |
21 | def __len__(self) -> int:
22 | return len(self.filepaths)
23 |
24 |
25 | class WSIDataset(_BaseDataset):
26 | def __init__(
27 | self,
28 | dataframe: pd.DataFrame,
29 | k: int,
30 | is_train: bool = True,
31 | return_mask: bool = False,
32 | ) -> None:
33 | super().__init__(dataframe, return_mask)
34 | self.k = k
35 | self.is_train = is_train
36 |
37 | def __getitem__(self, idx: int) -> Union[Tuple[torch.Tensor, bool], torch.Tensor]:
38 | sample = self.dataframe.iloc[idx]
39 | if not pd.isna(sample.WSI):
40 | data = pd.read_csv(sample.WSI)
41 | # sample k patch embeddings at random during training, otherwise take the first k
42 | if self.is_train:
43 | data = data.sample(self.k)
44 | else:
45 | data = data.iloc[: self.k]
46 |
47 | data = torch.from_numpy(data.values).float()
48 | mask = True
49 | else:
50 | data = torch.zeros(self.k, 512).float()
51 | mask = False
52 |
53 | if self.return_mask:
54 | return data, mask
55 | else:
56 | return data
57 |
--------------------------------------------------------------------------------
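`WSIDataset` expects each `WSI` cell to point to a per-slide CSV of patch embeddings and silently substitutes a zero bag plus a `False` mask when the slide is missing. Below is a minimal round trip on a throwaway file, assuming `_BaseDataset` simply stores the dataframe and the `return_mask` flag and that the repository root is on `PYTHONPATH`.

```python
import os, tempfile
import numpy as np
import pandas as pd
from drim.wsi import WSIDataset

# Fake per-slide embedding file: 50 patches x 512 features.
tmp_dir = tempfile.mkdtemp()
csv_path = os.path.join(tmp_dir, "slide_0.csv")
pd.DataFrame(np.random.randn(50, 512)).to_csv(csv_path, index=False)

frame = pd.DataFrame({"WSI": [csv_path, np.nan]})  # second patient has no slide
dataset = WSIDataset(frame, k=10, is_train=True, return_mask=True)

bag, mask = dataset[0]
print(bag.shape, mask)   # torch.Size([10, 512]) True  -- 10 patches sampled at random
bag, mask = dataset[1]
print(bag.shape, mask)   # torch.Size([10, 512]) False -- zero-filled placeholder
```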
/drim/wsi/func.py:
--------------------------------------------------------------------------------
1 | """
2 | Largely copy/pasted from vit_pytorch (https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py)
3 | """
4 | import torch
5 | from torch import nn
6 | from einops import rearrange
7 |
8 |
9 | def softmax_one(x, dim=None):
10 | # subtract the max for stability
11 | x = x - x.max(dim=dim, keepdim=True).values
12 | # compute exponentials
13 | exp_x = torch.exp(x)
14 | # compute softmax values with an extra one added to the denominator
15 | return exp_x / (1 + exp_x.sum(dim=dim, keepdim=True))
16 |
17 |
18 | # helpers
19 |
20 |
21 | def pair(t):
22 | return t if isinstance(t, tuple) else (t, t)
23 |
24 |
25 | # classes
26 |
27 |
28 | class FeedForward(nn.Module):
29 | def __init__(self, dim, hidden_dim, dropout=0.0):
30 | super().__init__()
31 | self.net = nn.Sequential(
32 | nn.LayerNorm(dim),
33 | nn.Linear(dim, hidden_dim),
34 | nn.GELU(),
35 | nn.Dropout(dropout),
36 | nn.Linear(hidden_dim, dim),
37 | nn.Dropout(dropout),
38 | )
39 |
40 | def forward(self, x):
41 | return self.net(x)
42 |
43 |
44 | class Attention(nn.Module):
45 | def __init__(self, dim, heads=8, dim_head=64, dropout=0.0):
46 | super().__init__()
47 | inner_dim = dim_head * heads
48 | project_out = not (heads == 1 and dim_head == dim)
49 |
50 | self.heads = heads
51 | self.scale = dim_head**-0.5
52 |
53 | self.norm = nn.LayerNorm(dim)
54 |
55 | self.dropout = nn.Dropout(dropout)
56 |
57 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
58 |
59 | self.to_out = (
60 | nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
61 | if project_out
62 | else nn.Identity()
63 | )
64 |
65 | def forward(self, x):
66 | x = self.norm(x)
67 |
68 | qkv = self.to_qkv(x).chunk(3, dim=-1)
69 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=self.heads), qkv)
70 |
71 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
72 |
73 | attn = softmax_one(dots, dim=-1)
74 | attn = self.dropout(attn)
75 |
76 | out = torch.matmul(attn, v)
77 | out = rearrange(out, "b h n d -> b n (h d)")
78 | return self.to_out(out)
79 |
80 |
81 | class Transformer(nn.Module):
82 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.0):
83 | super().__init__()
84 | self.norm = nn.LayerNorm(dim)
85 | self.layers = nn.ModuleList([])
86 | for _ in range(depth):
87 | self.layers.append(
88 | nn.ModuleList(
89 | [
90 | Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
91 | FeedForward(dim, mlp_dim, dropout=dropout),
92 | ]
93 | )
94 | )
95 |
96 | def forward(self, x):
97 | for attn, ff in self.layers:
98 | x = attn(x) + x
99 | x = ff(x) + x
100 |
101 | return self.norm(x)
102 |
--------------------------------------------------------------------------------
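`softmax_one` implements the "softmax plus one" (quiet softmax) variant: because the denominator carries an extra `+1`, the attention weights of a row may sum to less than one, which lets a head attend to nothing when no patch is relevant. A two-line check, assuming the repository root is on `PYTHONPATH`:

```python
import torch
from drim.wsi.func import softmax_one

logits = torch.tensor([[2.0, 1.0, 0.1]])
print(torch.softmax(logits, dim=-1).sum(dim=-1))  # tensor([1.])
print(softmax_one(logits, dim=-1).sum(dim=-1))    # ~tensor([0.6028]), strictly below 1
```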
/drim/wsi/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision
4 | from typing import Tuple, Union
5 | from einops import repeat
6 | from .func import Transformer
7 | from einops.layers.torch import Rearrange
8 |
9 |
10 | class ResNetWrapperSimCLR(nn.Module):
11 | """
12 | A wrapper for the ResNet34 model with a projection head for SimCLR.
13 | """
14 |
15 | def __init__(self, out_dim: int, projection_head: bool = True) -> None:
16 | super().__init__()
17 | self.encoder = torchvision.models.resnet34(pretrained=False)
18 | self.encoder.fc = nn.Identity()
19 | if projection_head:
20 | self.projection_head = nn.Sequential(
21 | nn.Linear(512, 512), nn.ReLU(inplace=True), nn.Linear(512, out_dim)
22 | )
23 |
24 | def forward(
25 | self, x: torch.Tensor
26 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
27 | x = self.encoder(x)
28 | if hasattr(self, "projection_head"):
29 | return self.projection_head(x), x
30 | else:
31 | return x
32 |
33 |
34 | class WSIEncoder(nn.Module):
35 | """
36 | An attention-based encoder for WSI data.
37 | """
38 |
39 | def __init__(
40 | self,
41 | embedding_dim: int,
42 | depth: int,
43 | heads: int,
44 | dim: int = 512,
45 | pool: str = "cls",
46 | dim_head: int = 64,
47 | mlp_dim: int = 128,
48 | dropout: float = 0.0,
49 | emb_dropout: float = 0.0,
50 | ) -> None:
51 | super().__init__()
52 | self.layer_norm = nn.LayerNorm(dim)
53 |
54 | # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
55 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
56 | self.dropout = nn.Dropout(emb_dropout)
57 |
58 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
59 |
60 | self.pool = pool
61 | self.to_latent = (
62 | nn.Identity() if embedding_dim == dim else nn.Linear(dim, embedding_dim)
63 | )
64 |
65 | def forward(self, x: torch.Tensor) -> torch.Tensor:
66 | x = self.layer_norm(x)
67 | b, n, _ = x.shape
68 |
69 | cls_tokens = repeat(self.cls_token, "1 1 d -> b 1 d", b=b)
70 | x = torch.cat((cls_tokens, x), dim=1)
71 | # x += self.pos_embedding[:, :(n + 1)]
72 | x = self.dropout(x)
73 |
74 | x = self.transformer(x)
75 |
76 | x = x.mean(dim=1) if self.pool == "mean" else x[:, 0]
77 | x = self.to_latent(x)
78 |
79 | return x
80 |
81 |
82 | class WSIDecoder(nn.Module):
83 | """
84 | A vanilla MLP-based decoder for WSI data.
85 | """
86 |
87 | def __init__(self, embedding_dim: int, dropout: float) -> None:
88 | super().__init__()
89 | self.decoder = nn.Sequential(
90 | nn.Linear(embedding_dim, 256),
91 | nn.BatchNorm1d(256),
92 | nn.Dropout(dropout),
93 | nn.LeakyReLU(),
94 | nn.Linear(256, 5120),
95 | Rearrange("b (p e) -> b p e", p=10),
96 | )
97 |
98 | def forward(self, x: torch.Tensor) -> torch.Tensor:
99 | return self.decoder(x)
100 |
--------------------------------------------------------------------------------
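Shape-wise, `WSIEncoder` turns a bag of patch embeddings into a single slide-level vector via the CLS token, and `WSIDecoder` maps such a vector back to a fixed bag of 10 pseudo-patch embeddings (5120 = 10 x 512). A small sketch with random inputs, assuming the repository root is on `PYTHONPATH`:

```python
import torch
from drim.wsi import WSIEncoder, WSIDecoder

encoder = WSIEncoder(embedding_dim=64, depth=1, heads=4)  # dim defaults to 512
decoder = WSIDecoder(embedding_dim=64, dropout=0.1)

bag = torch.randn(2, 10, 512)   # 2 slides, 10 patch embeddings of 512 features each
z = encoder(bag)                # CLS-token summary of each bag
print(z.shape)                  # torch.Size([2, 64])
print(decoder(z).shape)         # torch.Size([2, 10, 512])
```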
/drim/wsi/transforms.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | import torch
3 |
4 |
5 | class Dropout:
6 | def __init__(self, dropout_prob, p):
7 | self.dropout_prob = dropout_prob
8 | self.p = p
9 |
10 | def __call__(self, img):
11 | if torch.rand(1).item() < self.p:
12 | mask = ~torch.bernoulli(torch.full_like(img, self.dropout_prob)).bool()
13 | return img * mask
14 | else:
15 | return img
16 |
17 |
18 | class NviewsAugment(object):
19 | def __init__(self, transforms, n_views=2):
20 | self.transforms = transforms
21 | self.n_views = n_views
22 |
23 | def __call__(self, x):
24 | return [self.transforms(x) for _ in range(self.n_views)]
25 |
26 |
27 | contrastive_base = transforms.Compose(
28 | [
29 | transforms.RandomResizedCrop(size=256, scale=(0.35, 1.0)),
30 | transforms.RandomHorizontalFlip(),
31 | transforms.RandomVerticalFlip(),
32 | transforms.RandomApply([transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)], p=0.8),
33 | transforms.RandomGrayscale(p=0.2),
34 | transforms.GaussianBlur(kernel_size=11), # , sigma=(0.1, 1.)),
35 | transforms.ToTensor(),
36 | Dropout(0.1, 0.3),
37 | ]
38 | )
39 |
40 |
41 | contrastive_wsi_transforms = NviewsAugment(contrastive_base, n_views=2)
42 |
--------------------------------------------------------------------------------
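`contrastive_wsi_transforms` is what `PatchDataset` applies during SimCLR pre-training: it returns two independently augmented views of the same tile. A quick sanity check on a synthetic patch, assuming the repository root is on `PYTHONPATH`:

```python
from PIL import Image
from drim.wsi.transforms import contrastive_wsi_transforms

patch = Image.new("RGB", (256, 256), color=(220, 170, 190))  # stand-in for an H&E tile
view_1, view_2 = contrastive_wsi_transforms(patch)           # two independent augmentations
print(view_1.shape, view_2.shape)                            # torch.Size([3, 256, 256]) twice
```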
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.24.3
2 | matplotlib
3 | imutils
4 | torch==1.13.1
5 | torchvision==0.14.1
6 | monai==1.1.0
7 | tqdm
8 | pyvips==2.2.1
9 | opencv-python==4.7.0.72
10 | pandas==2.0.1
11 | hydra-core==1.3.2
12 | einops==0.6.1
13 | pycox==0.2.3
14 | loguru
--------------------------------------------------------------------------------
/robustness.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from typing import Union, List
3 | from collections import defaultdict
4 |
5 | # Third-party libraries
6 | import pandas as pd
7 | import hydra
8 | from omegaconf import DictConfig
9 | import torch
10 | import numpy as np
11 | from pycox.evaluation import EvalSurv
12 |
13 | # Local dependencies
14 | from drim.helpers import get_datasets, get_targets, get_encoder
15 | from drim.logger import logger
16 | from drim.utils import (
17 | interpolate_dataframe,
18 | log_transform,
19 | seed_everything,
20 | get_dataframes,
21 | )
22 | from drim.multimodal import MultimodalDataset, DRIMSurv, MultimodalModel
23 | from drim.datasets import SurvivalDataset
24 | from drim.models import MultimodalWrapper
25 |
26 |
27 | def prepare_data(
28 | dataframe: pd.DataFrame, modalities: Union[List[str], str]
29 | ) -> pd.DataFrame:
30 | available_modalities = ["DNAm", "WSI", "RNA", "MRI"]
31 |
32 | # columns = ['project_id', 'submitter_id']
33 | # dataframe = dataframe.drop(columns=columns)
34 | # get remaining modalities
35 | remaining = [
36 | modality for modality in available_modalities if modality not in modalities
37 | ]
38 | if len(modalities) == 1:
39 | try:
40 | mask = ~dataframe[modalities[0]].isna()
41 | dataframe = dataframe.loc[mask]
42 | except:
43 | pass
44 |
45 | else:
46 | temp_df = dataframe[modalities]
47 | dataframe = dataframe[temp_df.count(axis=1) >= len(modalities)]
48 |
49 | # blank out the columns of the remaining modalities (set them to NaN)
50 | dataframe = dataframe.assign(**{modality: None for modality in remaining})
51 | return dataframe
52 |
53 |
54 | @hydra.main(version_base=None, config_path="configs", config_name="robustness")
55 | def main(cfg: DictConfig) -> None:
56 | logger.info("Starting multimodal robustness test.")
57 | logger.info("Method tested {}.", cfg.method)
58 | modality_combinations = [
59 | ["WSI", "MRI"],
60 | ["WSI", "RNA"],
61 | ["DNAm", "WSI"],
62 | ["DNAm", "RNA"],
63 | ["DNAm", "MRI"],
64 | ["RNA", "MRI"],
65 | ["DNAm", "WSI", "RNA"],
66 | ["DNAm", "WSI", "MRI"],
67 | ["DNAm", "RNA", "MRI"],
68 | ["WSI", "RNA", "MRI"],
69 | ["DNAm", "WSI", "RNA", "MRI"],
70 | ]
71 | for combination in modality_combinations:
72 | if cfg.method == "tensor":
73 | cfg.general.dim = 32
74 | logs = defaultdict(list)
75 | for fold in range(5):
76 | seed_everything(cfg.general.seed)
77 | dataframes = get_dataframes(fold)
78 | dataframes = {
79 | split: prepare_data(dataframe, combination)
80 | for split, dataframe in dataframes.items()
81 | }
82 | test_datasets = {}
83 | encoders = {}
84 | if cfg.method == "drim":
85 | encoders_u = {}
86 |
87 | for modality in ["DNAm", "WSI", "RNA", "MRI"]:
88 | datasets = get_datasets(dataframes, modality, fold, return_mask=True)
89 | test_datasets[modality] = datasets["test"]
90 | encoder = get_encoder(modality, cfg).cuda()
91 | encoders[modality] = encoder
92 | if cfg.method == "drim":
93 | encoder_u = get_encoder(modality, cfg).cuda()
94 | encoders_u[modality] = encoder_u
95 |
96 | targets, cut = get_targets(dataframes, cfg.general.n_outs)
97 | dataset_test = MultimodalDataset(test_datasets, return_mask=True)
98 | test_data = SurvivalDataset(dataset_test, *targets["test"])
99 | loader = torch.utils.data.DataLoader(
100 | test_data, shuffle=False, batch_size=24
101 | )
102 | if cfg.method == "drim":
103 | from drim.fusion import MaskedAttentionFusion
104 |
105 | fusion = MaskedAttentionFusion(
106 | dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128
107 | )
108 | fusion_u = MaskedAttentionFusion(
109 | dim=cfg.general.dim, depth=1, heads=16, dim_head=64, mlp_dim=128
110 | )
111 | fusion.cuda()
112 | fusion_u.cuda()
113 | encoder = DRIMSurv(
114 | encoders_sh=encoders,
115 | encoders_u=encoders_u,
116 | fusion_s=fusion,
117 | fusion_u=fusion_u,
118 | )
119 | model = MultimodalWrapper(
120 | encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs
121 | )
122 | model.load_state_dict(
123 | torch.load(f"./models/drimsurv_split_{int(fold)}.pth")
124 | )
125 | else:
126 | if cfg.method == "max":
127 | from drim.fusion import ShallowFusion
128 |
129 | fusion = ShallowFusion("max")
130 | elif cfg.method == "tensor":
131 | from drim.fusion import TensorFusion
132 |
133 | fusion = TensorFusion(
134 | modalities=["DNAm", "WSI", "RNA", "MRI"],
135 | input_dim=cfg.general.dim,
136 | projected_dim=cfg.general.dim,
137 | output_dim=cfg.general.dim,
138 | dropout=0.0,
139 | )
140 | elif cfg.method == "concat":
141 | from drim.fusion import ShallowFusion
142 |
143 | fusion = ShallowFusion("concat")
144 |
145 | fusion.cuda()
146 | encoder = MultimodalModel(encoders, fusion=fusion)
147 | if cfg.method == "concat":
148 | size = cfg.general.dim * 4
149 | else:
150 | size = cfg.general.dim
151 | model = MultimodalWrapper(
152 | encoder, embedding_dim=size, n_outs=cfg.general.n_outs
153 | )
154 | if cfg.method == "max":
155 | prefix = "vanilla"
156 | else:
157 | prefix = "aux_contrastive"
158 | model.load_state_dict(
159 | torch.load(f"./models/{prefix}_{cfg.method}_split_{int(fold)}.pth")
160 | )
161 |
162 | model.cuda()
163 | model.eval()
164 | hazards = []
165 | times = []
166 | events = []
167 | with torch.no_grad():
168 | for batch in loader:
169 | data, time, event = batch
170 | data, mask = data
171 | outputs = model(data, mask, return_embedding=False)
172 | hazards.append(outputs)
173 | times.append(time)
174 | events.append(event)
175 |
176 | times = torch.cat(times, dim=0).cpu().numpy()
177 | events = torch.cat(events, dim=0).cpu().numpy()
178 | hazards = interpolate_dataframe(
179 | pd.DataFrame(
180 | (1 - torch.cat(hazards, dim=0).sigmoid())
181 | .add(1e-7)
182 | .log()
183 | .cumsum(1)
184 | .exp()
185 | .cpu()
186 | .numpy()
187 | .transpose(),
188 | cut,
189 | )
190 | )
191 | ev = EvalSurv(hazards, times, events, censor_surv="km")
192 | c_index = ev.concordance_td()
193 | ibs = ev.integrated_brier_score(np.linspace(0, times.max()))
194 | CS = (c_index + (1 - ibs)) / 2
195 | logs["c_index"].append(c_index)
196 | logs["ibs"].append(ibs)
197 | logs["CS"].append(CS)
198 |
199 | logger.info(
200 | f"{combination} - CS: {np.mean(logs['CS']):.3f} $\pm$ {np.std(logs['CS']):.3f}"
201 | )
202 |
203 |
204 | if __name__ == "__main__":
205 | main()
206 |
--------------------------------------------------------------------------------
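The only non-obvious arithmetic in `robustness.py` is the conversion from discrete-time hazard logits to a survival curve, S(t_k) = prod_{j<=k} (1 - h_j), computed in log-space with a `1e-7` stabiliser before being interpolated and fed to `EvalSurv`. A standalone numeric check for one patient over four intervals:

```python
import torch

logits = torch.tensor([[-2.0, -1.0, 0.0, 1.0]])  # one patient, four time intervals
hazards = logits.sigmoid()

# Log-space product as in robustness.py ...
surv_log = (1 - hazards).add(1e-7).log().cumsum(dim=1).exp()
# ... equals the plain cumulative product up to the stabiliser.
surv_prod = torch.cumprod(1 - hazards, dim=1)

print(surv_log)   # ~tensor([[0.8808, 0.6439, 0.3220, 0.0866]])
print(surv_prod)
```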
/static/DRIM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lucas-rbnt/DRIM/fffd026639e7f475e9d00f13f994920578d0e06e/static/DRIM.png
--------------------------------------------------------------------------------
/train_aux_multimodal.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from collections import defaultdict
3 |
4 | # Third-party libraries
5 | from omegaconf import DictConfig, OmegaConf
6 | import hydra
7 | from torch.utils.data import DataLoader
8 | import torch
9 | from pycox.models.loss import NLLLogistiHazardLoss
10 | import numpy as np
11 |
12 | # Local dependencies
13 | from drim.trainers import AuxSurvivalTrainer
14 | from drim.multimodal import MultimodalDataset
15 | from drim.multimodal import MultimodalModel
16 | from drim.models import MultimodalWrapper
17 | from drim.datasets import SurvivalDataset
18 | from drim.logger import logger
19 | from drim.utils import (
20 | seed_everything,
21 | seed_worker,
22 | prepare_data,
23 | get_dataframes,
24 | log_transform,
25 | )
26 | from drim.helpers import get_encoder, get_datasets, get_targets
27 |
28 |
29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal")
30 | def main(cfg: DictConfig) -> None:
31 | cv_metrics = defaultdict(list)
32 | # check if wandb key is in cfg
33 | if "wandb" in cfg:
34 | import wandb
35 |
36 | wandb_logging = True
37 | wandb.init(
38 | name=f"aux_{cfg.aux_loss.name}_{cfg.fusion.name}_"
39 | + "_".join(cfg.general.modalities),
40 | config={
41 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb"
42 | },
43 | **cfg.wandb,
44 | )
45 | else:
46 | wandb_logging = False
47 | logger.info("Starting multimodal cross-validation.")
48 | logger.info("Modalities used: {}.", cfg.general.modalities)
49 | for fold in range(cfg.general.n_folds):
50 | logger.info("Starting fold {}", fold)
51 | seed_everything(cfg.general.seed)
52 | # Load the data
53 | dataframes = get_dataframes(fold)
54 | dataframes = {
55 | split: prepare_data(dataframe, cfg.general.modalities)
56 | for split, dataframe in dataframes.items()
57 | }
58 | cfg.general.save_path = (
59 | f"./models/aux_{cfg.aux_loss.name}_{cfg.fusion.name}_split_{int(fold)}.pth"
60 | )
61 |
62 | for split, dataframe in dataframes.items():
63 | logger.info(f"{split} samples: {len(dataframe)}")
64 |
65 | train_datasets = {}
66 | val_datasets = {}
67 | test_datasets = {}
68 | encoders = {}
69 | logger.info("Loading models and preparing corresponding dataset...")
70 | for modality in cfg.general.modalities:
71 | encoder = get_encoder(modality, cfg).cuda()
72 | encoders[modality] = encoder
73 | datasets = get_datasets(dataframes, modality, fold, return_mask=True)
74 | train_datasets[modality] = datasets["train"]
75 | val_datasets[modality] = datasets["val"]
76 | test_datasets[modality] = datasets["test"]
77 |
78 | targets, cuts = get_targets(dataframes, cfg.general.n_outs)
79 |
80 | dataset_train = MultimodalDataset(train_datasets, return_mask=True)
81 | dataset_val = MultimodalDataset(val_datasets, return_mask=True)
82 | dataset_test = MultimodalDataset(test_datasets, return_mask=True)
83 | train_data = SurvivalDataset(dataset_train, *targets["train"])
84 | val_data = SurvivalDataset(dataset_val, *targets["val"])
85 | test_data = SurvivalDataset(dataset_test, *targets["test"])
86 |
87 | dataloaders = {
88 | "train": DataLoader(
89 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader
90 | ),
91 | "val": DataLoader(
92 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
93 | ),
94 | "test": DataLoader(
95 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
96 | ),
97 | }
98 |
99 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]:
100 | from drim.fusion import ShallowFusion
101 |
102 | fusion = ShallowFusion(cfg.fusion.name)
103 | elif cfg.fusion.name == "maf":
104 | from drim.fusion import MaskedAttentionFusion
105 |
106 | fusion = MaskedAttentionFusion(
107 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params
108 | )
109 | elif cfg.fusion.name == "tensor":
110 | from drim.fusion import TensorFusion
111 |
112 | fusion = TensorFusion(
113 | modalities=cfg.general.modalities,
114 | input_dim=cfg.general.dim,
115 | projected_dim=cfg.general.dim,
116 | output_dim=cfg.general.dim,
117 | dropout=cfg.general.dropout,
118 | )
119 | else:
120 | raise NotImplementedError
121 |
122 | fusion.cuda()
123 |
124 | encoder = MultimodalModel(encoders, fusion=fusion)
125 | if cfg.fusion.name == "concat":
126 | size = cfg.general.dim * len(cfg.general.modalities)
127 | else:
128 | size = cfg.general.dim
129 |
130 | model = MultimodalWrapper(encoder, size, n_outs=cfg.general.n_outs)
131 | # model = model.cuda()
132 | logger.info("Done!")
133 |
134 | # define optimizer and scheduler
135 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params)
136 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
137 | optimizer, **cfg.scheduler
138 | )
139 |
140 | # define task criterion
141 | task_criterion = NLLLogistiHazardLoss()
142 |
143 | # define auxiliary criterion
144 | if cfg.aux_loss.name == "contrastive":
145 | from drim.losses import ContrastiveLoss
146 |
147 | aux_loss = ContrastiveLoss()
148 | elif cfg.aux_loss.name == "mmo":
149 | from drim.losses import MMOLoss
150 |
151 | aux_loss = MMOLoss()
152 |
153 | trainer = AuxSurvivalTrainer(
154 | model=model,
155 | optimizer=optimizer,
156 | scheduler=scheduler,
157 | dataloaders=dataloaders,
158 | task_criterion=task_criterion,
159 | cfg=cfg,
160 | wandb_logging=wandb_logging,
161 | cuts=cuts,
162 | aux_loss=aux_loss,
163 | )
164 |
165 | trainer.fit()
166 | val_logs = trainer.evaluate("val")
167 | test_logs = trainer.evaluate("test")
168 | # add to cv_metrics
169 | for key, value in val_logs.items():
170 | cv_metrics[key].append(value)
171 |
172 | for key, value in test_logs.items():
173 | cv_metrics[key].append(value)
174 |
175 | logger.info("Fold {} done!", fold)
176 |
177 | # log the mean ± std of the validation and test metrics across folds
178 | logs = {}
179 | for key, value in cv_metrics.items():
180 | if key in [
181 | "test/c_index",
182 | "test/cs_score",
183 | "test/inbll",
184 | "test/ibs",
185 | "val/c_index",
186 | "val/cs_score",
187 | "val/inbll",
188 | "val/ibs",
189 | ]:
190 | mean, std = np.mean(value), np.std(value)
191 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}")
192 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean
193 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std
194 |
195 | if wandb_logging:
196 | wandb.log(logs)
197 | wandb.finish()
198 |
199 |
200 | if __name__ == "__main__":
201 | main()
202 |
--------------------------------------------------------------------------------
/train_drimsurv.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from collections import defaultdict
3 |
4 | # Third-party libraries
5 | from omegaconf import DictConfig, OmegaConf
6 | import hydra
7 | from torch.utils.data import DataLoader
8 | import torch
9 | from pycox.models.loss import NLLLogistiHazardLoss
10 | import numpy as np
11 |
12 | # Local dependencies
13 | from drim.trainers import DRIMSurvTrainer
14 | from drim.multimodal import MultimodalDataset, DRIMSurv
15 | from drim.models import MultimodalWrapper, Discriminator
16 | from drim.losses import ContrastiveLoss
17 | from drim.datasets import SurvivalDataset
18 | from drim.logger import logger
19 | from drim.utils import (
20 | seed_everything,
21 | seed_worker,
22 | prepare_data,
23 | get_dataframes,
24 | log_transform,
25 | )
26 | from drim.helpers import get_encoder, get_datasets, get_targets
27 |
28 |
29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal")
30 | def main(cfg: DictConfig) -> None:
31 | cv_metrics = defaultdict(list)
32 | # check if wandb key is in cfg
33 | if "wandb" in cfg:
34 | import wandb
35 |
36 | wandb_logging = True
37 | wandb.init(
38 | name="DRIMSurv_" + "_".join(cfg.general.modalities),
39 | config={
40 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb"
41 | },
42 | **cfg.wandb,
43 | )
44 | else:
45 | wandb_logging = False
46 | logger.info("Starting multimodal cross-validation.")
47 | logger.info("Modalities used: {}.", cfg.general.modalities)
48 | for fold in range(cfg.general.n_folds):
49 | logger.info("Starting fold {}", fold)
50 | seed_everything(cfg.general.seed)
51 | # Load the data
52 | dataframes = get_dataframes(fold)
53 | dataframes = {
54 | split: prepare_data(dataframe, cfg.general.modalities)
55 | for split, dataframe in dataframes.items()
56 | }
57 | cfg.general.save_path = f"./models/drimsurv_split_{int(fold)}.pth"
58 |
59 | for split, dataframe in dataframes.items():
60 | logger.info(f"{split} samples: {len(dataframe)}")
61 |
62 | train_datasets = {}
63 | val_datasets = {}
64 | test_datasets = {}
65 | encoders = {}
66 | encoders_u = {}
67 | logger.info("Loading models and preparing corresponding dataset...")
68 | for modality in cfg.general.modalities:
69 | encoder = get_encoder(modality, cfg).cuda()
70 | encoders[modality] = encoder
71 | encoder_u = get_encoder(modality, cfg).cuda()
72 | encoders_u[modality] = encoder_u
73 | datasets = get_datasets(dataframes, modality, fold, return_mask=True)
74 | train_datasets[modality] = datasets["train"]
75 | val_datasets[modality] = datasets["val"]
76 | test_datasets[modality] = datasets["test"]
77 |
78 | targets, cuts = get_targets(dataframes, cfg.general.n_outs)
79 |
80 | dataset_train = MultimodalDataset(train_datasets, return_mask=True)
81 | dataset_val = MultimodalDataset(val_datasets, return_mask=True)
82 | dataset_test = MultimodalDataset(test_datasets, return_mask=True)
83 | train_data = SurvivalDataset(dataset_train, *targets["train"])
84 | val_data = SurvivalDataset(dataset_val, *targets["val"])
85 | test_data = SurvivalDataset(dataset_test, *targets["test"])
86 |
87 | dataloaders = {
88 | "train": DataLoader(
89 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader
90 | ),
91 | "val": DataLoader(
92 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
93 | ),
94 | "test": DataLoader(
95 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
96 | ),
97 | }
98 |
99 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]:
100 | from drim.fusion import ShallowFusion
101 |
102 | fusion = ShallowFusion(cfg.fusion.name)
103 | fusion_u = ShallowFusion(cfg.fusion.name)
104 | elif cfg.fusion.name == "maf":
105 | from drim.fusion import MaskedAttentionFusion
106 |
107 | fusion = MaskedAttentionFusion(
108 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params
109 | )
110 | fusion_u = MaskedAttentionFusion(
111 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params
112 | )
113 | elif cfg.fusion.name == "tensor":
114 | from drim.fusion import TensorFusion
115 |
116 | fusion = TensorFusion(
117 | modalities=cfg.general.modalities,
118 | input_dim=cfg.general.dim,
119 | projected_dim=cfg.general.dim,
120 | output_dim=cfg.general.dim,
121 | dropout=cfg.general.dropout,
122 | )
123 | fusion_u = TensorFusion(
124 | modalities=cfg.general.modalities + ["shared"],
125 | input_dim=cfg.general.dim,
126 | projected_dim=cfg.general.dim,
127 | output_dim=cfg.general.dim,
128 | dropout=cfg.general.dropout,
129 | )
130 | else:
131 | raise NotImplementedError
132 |
133 | fusion.cuda()
134 | fusion_u.cuda()
135 |
136 | encoder = DRIMSurv(
137 | encoders_sh=encoders,
138 | encoders_u=encoders_u,
139 | fusion_s=fusion,
140 | fusion_u=fusion_u,
141 | )
142 |
143 | model = MultimodalWrapper(
144 | encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs
145 | )
146 | logger.info("Done!")
147 |
148 | logger.info("Preparing discriminators..")
149 | discriminators = {
150 | k: Discriminator(
151 | embedding_dim=cfg.general.dim * 2, dropout=cfg.general.dropout
152 | ).to("cuda")
153 | for k in encoders.keys()
154 | }
155 |
156 | # define optimizer and scheduler
157 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params)
158 | optimizers_dsm = {
159 | k: torch.optim.AdamW(
160 | discriminators[k].parameters(),
161 | lr=cfg.disentangled.dsm_lr,
162 | weight_decay=cfg.disentangled.dsm_wd,
163 | )
164 | for k in encoders.keys()
165 | }
166 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
167 | optimizer, **cfg.scheduler
168 | )
169 |
170 | # define task criterion
171 | task_criterion = NLLLogistiHazardLoss()
172 |
173 | # define auxiliary criterion
174 | aux_loss = ContrastiveLoss()
175 |
176 | trainer = DRIMSurvTrainer(
177 | model=model,
178 | discriminators=discriminators,
179 | optimizer=optimizer,
180 | optimizers_dsm=optimizers_dsm,
181 | scheduler=scheduler,
182 | dataloaders=dataloaders,
183 | task_criterion=task_criterion,
184 | aux_loss=aux_loss,
185 | cfg=cfg,
186 | wandb_logging=wandb_logging,
187 | cuts=cuts,
188 | )
189 |
190 | trainer.fit()
191 | val_logs = trainer.evaluate("val")
192 | test_logs = trainer.evaluate("test")
193 | # add to cv_metrics
194 | for key, value in val_logs.items():
195 | cv_metrics[key].append(value)
196 |
197 | for key, value in test_logs.items():
198 | cv_metrics[key].append(value)
199 |
200 | logger.info("Fold {} done!", fold)
201 |
202 | # log the mean ± std of the validation and test metrics across folds
203 | logs = {}
204 | for key, value in cv_metrics.items():
205 | if key in [
206 | "test/c_index",
207 | "test/cs_score",
208 | "test/inbll",
209 | "test/ibs",
210 | "val/c_index",
211 | "val/cs_score",
212 | "val/inbll",
213 | "val/ibs",
214 | ]:
215 | mean, std = np.mean(value), np.std(value)
216 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}")
217 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean
218 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std
219 |
220 | if wandb_logging:
221 | wandb.log(logs)
222 | wandb.finish()
223 |
224 |
225 | if __name__ == "__main__":
226 | main()
227 |
--------------------------------------------------------------------------------
/train_drimu.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from collections import defaultdict
3 |
4 | # Third-party libraries
5 | from omegaconf import DictConfig, OmegaConf
6 | import hydra
7 | from torch.utils.data import DataLoader
8 | import torch
9 | from pycox.models.loss import NLLLogistiHazardLoss
10 | import numpy as np
11 |
12 | # Local dependencies
13 | from drim.trainers import DRIMUTrainer
14 | from drim.multimodal import MultimodalDataset, DRIMU
15 | from drim.models import MultimodalWrapper, Discriminator
16 | from drim.losses import ContrastiveLoss
17 | from drim.datasets import SurvivalDataset
18 | from drim.logger import logger
19 | from drim.utils import (
20 | seed_everything,
21 | seed_worker,
22 | prepare_data,
23 | get_dataframes,
24 | log_transform,
25 | )
26 | from drim.helpers import get_encoder, get_datasets, get_targets, get_decoder
27 |
28 |
29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal")
30 | def main(cfg: DictConfig) -> None:
31 | cv_metrics = defaultdict(list)
32 | # check if wandb key is in cfg
33 | if "wandb" in cfg:
34 | import wandb
35 |
36 | wandb_logging = True
37 | wandb.init(
38 | name="DRIMU_" + "_".join(cfg.general.modalities),
39 | config={
40 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb"
41 | },
42 | **cfg.wandb,
43 | )
44 | else:
45 | wandb_logging = False
46 | logger.info("Starting multimodal cross-validation.")
47 | logger.info("Modalities used: {}.", cfg.general.modalities)
48 | for fold in range(cfg.general.n_folds):
49 | logger.info("Starting fold {}", fold)
50 | seed_everything(cfg.general.seed)
51 | # Load the data
52 | dataframes = get_dataframes(fold)
53 | dataframes = {
54 | split: prepare_data(dataframe, cfg.general.modalities)
55 | for split, dataframe in dataframes.items()
56 | }
57 | cfg.general.save_path = f"./models/drimu_split_{int(fold)}.pth"
58 |
59 | for split, dataframe in dataframes.items():
60 | logger.info(f"{split} samples: {len(dataframe)}")
61 |
62 | train_datasets = {}
63 | val_datasets = {}
64 | test_datasets = {}
65 | encoders = {}
66 | encoders_u = {}
67 | decoders = {}
68 | logger.info("Loading models and preparing corresponding dataset...")
69 | for modality in cfg.general.modalities:
70 | encoder = get_encoder(modality, cfg).cuda()
71 | encoders[modality] = encoder
72 | encoder_u = get_encoder(modality, cfg).cuda()
73 | encoders_u[modality] = encoder_u
74 | decoders[modality] = get_decoder(modality, cfg).cuda()
75 | datasets = get_datasets(dataframes, modality, fold, return_mask=True)
76 | train_datasets[modality] = datasets["train"]
77 | val_datasets[modality] = datasets["val"]
78 | test_datasets[modality] = datasets["test"]
79 |
80 | targets, cuts = get_targets(dataframes, cfg.general.n_outs)
81 |
82 | dataset_train = MultimodalDataset(train_datasets, return_mask=True)
83 | dataset_val = MultimodalDataset(val_datasets, return_mask=True)
84 | dataset_test = MultimodalDataset(test_datasets, return_mask=True)
85 | train_data = SurvivalDataset(dataset_train, *targets["train"])
86 | val_data = SurvivalDataset(dataset_val, *targets["val"])
87 | test_data = SurvivalDataset(dataset_test, *targets["test"])
88 |
89 | dataloaders = {
90 | "train": DataLoader(
91 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader
92 | ),
93 | "val": DataLoader(
94 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
95 | ),
96 | "test": DataLoader(
97 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
98 | ),
99 | }
100 |
101 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]:
102 | from drim.fusion import ShallowFusion
103 |
104 | fusion = ShallowFusion(cfg.fusion.name)
105 | fusion_u = ShallowFusion(cfg.fusion.name)
106 | elif cfg.fusion.name == "maf":
107 | from drim.fusion import MaskedAttentionFusion
108 |
109 | fusion = MaskedAttentionFusion(
110 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params
111 | )
112 | fusion_u = MaskedAttentionFusion(
113 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params
114 | )
115 | elif cfg.fusion.name == "tensor":
116 | from drim.fusion import TensorFusion
117 |
118 | fusion = TensorFusion(
119 | modalities=cfg.general.modalities,
120 | input_dim=cfg.general.dim,
121 | projected_dim=cfg.general.dim,
122 | output_dim=cfg.general.dim,
123 | dropout=cfg.general.dropout,
124 | )
125 | fusion_u = TensorFusion(
126 | modalities=cfg.general.modalities,
127 | input_dim=cfg.general.dim,
128 | projected_dim=cfg.general.dim,
129 | output_dim=cfg.general.dim,
130 | dropout=cfg.general.dropout,
131 | )
132 | else:
133 | raise NotImplementedError
134 |
135 | fusion.cuda()
136 | fusion_u.cuda()
137 |
138 | encoder = DRIMU(
139 | encoders_sh=encoders,
140 | encoders_u=encoders_u,
141 | fusion_s=fusion,
142 | fusion_u=fusion_u,
143 | )
144 |
145 | model = MultimodalWrapper(
146 | encoder, embedding_dim=cfg.general.dim, n_outs=cfg.general.n_outs
147 | )
148 | logger.info("Done!")
149 |
150 | logger.info("Preparing discriminators..")
151 | discriminators = {
152 | k: Discriminator(
153 | embedding_dim=cfg.general.dim * 2, dropout=cfg.general.dropout
154 | ).to("cuda")
155 | for k in encoders.keys()
156 | }
157 |
158 | # define optimizer and scheduler
159 | decoder_parameters = []
160 | for k in decoders.keys():
161 | decoder_parameters += list(decoders[k].parameters())
162 | optimizer = torch.optim.AdamW(
163 | list(model.parameters()) + decoder_parameters, **cfg.optimizer.params
164 | )
165 | optimizers_dsm = {
166 | k: torch.optim.AdamW(
167 | discriminators[k].parameters(),
168 | lr=cfg.disentangled.dsm_lr,
169 | weight_decay=cfg.disentangled.dsm_wd,
170 | )
171 | for k in encoders.keys()
172 | }
173 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
174 | optimizer, eta_min=cfg.optimizer.params.lr, T_max=cfg.general.epochs
175 | )
176 |
177 | task_criterion = NLLLogistiHazardLoss()
178 | # define auxiliary criterion
179 | aux_loss = ContrastiveLoss()
180 |
181 | trainer = DRIMUTrainer(
182 | model,
183 | decoders=decoders,
184 | discriminators=discriminators,
185 | optimizer=optimizer,
186 | optimizers_dsm=optimizers_dsm,
187 | scheduler=scheduler,
188 | task_criterion=task_criterion,
189 | dataloaders=dataloaders,
190 | cfg=cfg,
191 | cuts=cuts,
192 | wandb_logging=wandb_logging,
193 | aux_loss=aux_loss,
194 | )
195 |
196 | trainer.fit()
197 | trainer.finetune()
198 | val_logs = trainer.evaluate("val")
199 | test_logs = trainer.evaluate("test")
200 | # add to cv_metrics
201 | for key, value in val_logs.items():
202 | cv_metrics[key].append(value)
203 |
204 | for key, value in test_logs.items():
205 | cv_metrics[key].append(value)
206 |
207 | logger.info("Fold {} done!", fold)
208 |
209 | # log the mean ± std of the validation and test metrics across folds
210 | logs = {}
211 | for key, value in cv_metrics.items():
212 | if key in [
213 | "test/c_index",
214 | "test/cs_score",
215 | "test/inbll",
216 | "test/ibs",
217 | "val/c_index",
218 | "val/cs_score",
219 | "val/inbll",
220 | "val/ibs",
221 | ]:
222 | mean, std = np.mean(value), np.std(value)
223 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}")
224 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean
225 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std
226 |
227 | if wandb_logging:
228 | wandb.log(logs)
229 | wandb.finish()
230 |
231 |
232 | if __name__ == "__main__":
233 | main()
234 |
--------------------------------------------------------------------------------
/train_unimodal.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from collections import defaultdict
3 |
4 | # Third-party libraries
5 | import hydra
6 | from omegaconf import DictConfig, OmegaConf
7 | from torch.utils.data import DataLoader
8 | import torch
9 | from pycox.models.loss import NLLLogistiHazardLoss
10 | import numpy as np
11 |
12 | # Local dependencies
13 | from drim.models import UnimodalWrapper
14 | from drim.trainers import BaseSurvivalTrainer
15 | from drim.datasets import SurvivalDataset
16 | from drim.logger import logger
17 | from drim.utils import (
18 | seed_everything,
19 | seed_worker,
20 | prepare_data,
21 | get_dataframes,
22 | log_transform,
23 | )
24 | from drim.helpers import get_encoder, get_datasets, get_targets
25 |
26 |
27 | @hydra.main(version_base=None, config_path="configs", config_name="unimodal")
28 | def main(cfg: DictConfig) -> None:
29 | cv_metrics = defaultdict(list)
30 | # check if wandb key is in cfg
31 | if "wandb" in cfg:
32 | import wandb
33 |
34 | wandb_logging = True
35 | wandb.init(
36 | name=cfg.general.modalities,
37 | config={
38 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb"
39 | },
40 | **cfg.wandb,
41 | )
42 | else:
43 | wandb_logging = False
44 |
45 | logger.info("Starting unimodal cross-validation.")
46 | logger.info("Modality used: {}.", cfg.general.modalities)
47 | for fold in range(cfg.general.n_folds):
48 | logger.info("Starting fold {}", fold)
49 | seed_everything(cfg.general.seed)
50 | # Load the data
51 | dataframes = get_dataframes(fold)
52 | # restrict the unimodal cohort to patients also present in the multimodal cohort to ensure fair comparisons
53 | dataframes_multi = {
54 | split: prepare_data(dataframe, ["DNAm", "WSI", "RNA", "MRI"])
55 | for split, dataframe in dataframes.items()
56 | }
57 |
58 | dataframes = {
59 | split: prepare_data(dataframe, cfg.general.modalities)
60 | for split, dataframe in dataframes.items()
61 | }
62 |
63 | dataframes = {
64 | split: dataframe[
65 | dataframe["submitter_id"].isin(dataframes_multi[split]["submitter_id"])
66 | ]
67 | for split, dataframe in dataframes.items()
68 | }
69 | cfg.general.save_path = (
70 | f"./models/{cfg.general.modalities}_split_{int(fold)}.pth"
71 | )
72 | for split, dataframe in dataframes.items():
73 | logger.info(f"{split} samples: {len(dataframe)}")
74 |
75 | # Load the model
76 | logger.info("Loading model and preparing corresponding dataset...")
77 | encoder = get_encoder(cfg.general.modalities, cfg)
78 | datasets = get_datasets(
79 | dataframes, cfg.general.modalities, fold, return_mask=False
80 | )
81 | targets, cuts = get_targets(dataframes, cfg.general.n_outs)
82 | train_data = SurvivalDataset(datasets["train"], *targets["train"])
83 | val_data = SurvivalDataset(datasets["val"], *targets["val"])
84 | test_data = SurvivalDataset(datasets["test"], *targets["test"])
85 |
86 | dataloaders = {
87 | "train": DataLoader(
88 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader
89 | ),
90 | "val": DataLoader(
91 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
92 | ),
93 | "test": DataLoader(
94 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
95 | ),
96 | }
97 |
98 | model = UnimodalWrapper(encoder, cfg.general.dim, n_outs=cfg.general.n_outs)
99 | model = model.cuda()
100 | logger.info("Done!")
101 |
102 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params)
103 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
104 | optimizer, **cfg.scheduler
105 | )
106 |
107 | task_criterion = NLLLogistiHazardLoss()
108 | trainer = BaseSurvivalTrainer(
109 | model=model,
110 | optimizer=optimizer,
111 | scheduler=scheduler,
112 | dataloaders=dataloaders,
113 | task_criterion=task_criterion,
114 | cfg=cfg,
115 | wandb_logging=wandb_logging,
116 | cuts=cuts,
117 | )
118 |
119 | trainer.fit()
120 | val_logs = trainer.evaluate("val")
121 | test_logs = trainer.evaluate("test")
122 | # add to cv_metrics
123 | for key, value in val_logs.items():
124 | cv_metrics[key].append(value)
125 |
126 | for key, value in test_logs.items():
127 | cv_metrics[key].append(value)
128 |
129 | logger.info("Fold {} done!", fold)
130 |
131 | # log the mean ± std of the validation and test metrics across folds
132 | logs = {}
133 | for key, value in cv_metrics.items():
134 | if key in [
135 | "test/c_index",
136 | "test/cs_score",
137 | "test/inbll",
138 | "test/ibs",
139 | "val/c_index",
140 | "val/cs_score",
141 | "val/inbll",
142 | "val/ibs",
143 | ]:
144 | mean, std = np.mean(value), np.std(value)
145 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}")
146 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean
147 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std
148 |
149 | if wandb_logging:
150 | wandb.log(logs)
151 | wandb.finish()
152 |
153 |
154 | if __name__ == "__main__":
155 | main()
156 |
--------------------------------------------------------------------------------
/train_vanilla_multimodal.py:
--------------------------------------------------------------------------------
1 | # Standard libraries
2 | from collections import defaultdict
3 |
4 | # Third-party libraries
5 | from omegaconf import DictConfig, OmegaConf
6 | import hydra
7 | from torch.utils.data import DataLoader
8 | import torch
9 | from pycox.models.loss import NLLLogistiHazardLoss
10 | import numpy as np
11 |
12 | # Local dependencies
13 | from drim.trainers import BaseSurvivalTrainer
14 | from drim.multimodal import MultimodalDataset
15 | from drim.multimodal import MultimodalModel
16 | from drim.models import MultimodalWrapper
17 | from drim.datasets import SurvivalDataset
18 | from drim.logger import logger
19 | from drim.utils import (
20 | seed_everything,
21 | seed_worker,
22 | prepare_data,
23 | get_dataframes,
24 | log_transform,
25 | )
26 | from drim.helpers import get_encoder, get_datasets, get_targets
27 |
28 |
29 | @hydra.main(version_base=None, config_path="configs", config_name="multimodal")
30 | def main(cfg: DictConfig) -> None:
31 | cv_metrics = defaultdict(list)
32 | # check if wandb key is in cfg
33 | if "wandb" in cfg:
34 | import wandb
35 |
36 | wandb_logging = True
37 | wandb.init(
38 | name=f"vanilla_{cfg.fusion.name}_" + "_".join(cfg.general.modalities),
39 | config={
40 | k: v for k, v in OmegaConf.to_container(cfg).items() if k != "wandb"
41 | },
42 | **cfg.wandb,
43 | )
44 | else:
45 | wandb_logging = False
46 | logger.info("Starting multimodal cross-validation.")
47 | logger.info("Modalities used: {}.", cfg.general.modalities)
48 | for fold in range(cfg.general.n_folds):
49 | logger.info("Starting fold {}", fold)
50 | seed_everything(cfg.general.seed)
51 | # Load the data
52 | dataframes = get_dataframes(fold)
53 | dataframes = {
54 | split: prepare_data(dataframe, cfg.general.modalities)
55 | for split, dataframe in dataframes.items()
56 | }
57 | cfg.general.save_path = (
58 | f"./models/vanilla_{cfg.fusion.name}_split_{int(fold)}.pth"
59 | )
60 |
61 | for split, dataframe in dataframes.items():
62 | logger.info(f"{split} samples: {len(dataframe)}")
63 |
64 | train_datasets = {}
65 | val_datasets = {}
66 | test_datasets = {}
67 | encoders = {}
68 | logger.info("Loading models and preparing corresponding dataset...")
69 | for modality in cfg.general.modalities:
70 | encoder = get_encoder(modality, cfg).cuda()
71 | encoders[modality] = encoder
72 | datasets = get_datasets(dataframes, modality, fold, return_mask=True)
73 | train_datasets[modality] = datasets["train"]
74 | val_datasets[modality] = datasets["val"]
75 | test_datasets[modality] = datasets["test"]
76 |
77 | targets, cuts = get_targets(dataframes, cfg.general.n_outs)
78 |
79 | dataset_train = MultimodalDataset(train_datasets, return_mask=True)
80 | dataset_val = MultimodalDataset(val_datasets, return_mask=True)
81 | dataset_test = MultimodalDataset(test_datasets, return_mask=True)
82 | train_data = SurvivalDataset(dataset_train, *targets["train"])
83 | val_data = SurvivalDataset(dataset_val, *targets["val"])
84 | test_data = SurvivalDataset(dataset_test, *targets["test"])
85 |
86 | dataloaders = {
87 | "train": DataLoader(
88 | train_data, shuffle=True, worker_init_fn=seed_worker, **cfg.dataloader
89 | ),
90 | "val": DataLoader(
91 | val_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
92 | ),
93 | "test": DataLoader(
94 | test_data, shuffle=False, worker_init_fn=seed_worker, **cfg.dataloader
95 | ),
96 | }
97 |
98 | if cfg.fusion.name in ["mean", "concat", "max", "sum", "masked_mean"]:
99 | from drim.fusion import ShallowFusion
100 |
101 | fusion = ShallowFusion(cfg.fusion.name)
102 | elif cfg.fusion.name == "maf":
103 | from drim.fusion import MaskedAttentionFusion
104 |
105 | fusion = MaskedAttentionFusion(
106 | dim=cfg.general.dim, dropout=cfg.general.dropout, **cfg.fusion.params
107 | )
108 | elif cfg.fusion.name == "tensor":
109 | from drim.fusion import TensorFusion
110 |
111 | fusion = TensorFusion(
112 | modalities=cfg.general.modalities,
113 | input_dim=cfg.general.dim,
114 | projected_dim=cfg.general.dim,
115 | output_dim=cfg.general.dim,
116 | dropout=cfg.general.dropout,
117 | )
118 | else:
119 | raise NotImplementedError
120 |
121 | fusion.cuda()
122 |
123 | encoder = MultimodalModel(encoders, fusion=fusion)
124 | if cfg.fusion.name == "concat":
125 | size = cfg.general.dim * len(cfg.general.modalities)
126 | else:
127 | size = cfg.general.dim
128 |
129 | model = MultimodalWrapper(encoder, size, n_outs=cfg.general.n_outs)
130 | # model = model.cuda()
131 | logger.info("Done!")
132 |
133 | # define optimizer and scheduler
134 | optimizer = torch.optim.AdamW(model.parameters(), **cfg.optimizer.params)
135 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
136 | optimizer, **cfg.scheduler
137 | )
138 |
139 | # define task criterion
140 | task_criterion = NLLLogistiHazardLoss()
141 | trainer = BaseSurvivalTrainer(
142 | model=model,
143 | optimizer=optimizer,
144 | scheduler=scheduler,
145 | dataloaders=dataloaders,
146 | task_criterion=task_criterion,
147 | cfg=cfg,
148 | wandb_logging=wandb_logging,
149 | cuts=cuts,
150 | )
151 |
152 | trainer.fit()
153 | val_logs = trainer.evaluate("val")
154 | test_logs = trainer.evaluate("test")
155 | # add to cv_metrics
156 | for key, value in val_logs.items():
157 | cv_metrics[key].append(value)
158 |
159 | for key, value in test_logs.items():
160 | cv_metrics[key].append(value)
161 |
162 | logger.info("Fold {} done!", fold)
163 |
164 | # log the mean ± std of the validation and test metrics across folds
165 | logs = {}
166 | for key, value in cv_metrics.items():
167 | if key in [
168 | "test/c_index",
169 | "test/cs_score",
170 | "test/inbll",
171 | "test/ibs",
172 | "val/c_index",
173 | "val/cs_score",
174 | "val/inbll",
175 | "val/ibs",
176 | ]:
177 | mean, std = np.mean(value), np.std(value)
178 | logger.info(f"{key}: {mean:.4f} ± {std:.4f}")
179 | logs[f"fin/{'_'.join(key.split('/'))}_mean"] = mean
180 | logs[f"fin/{'_'.join(key.split('/'))}_std"] = std
181 |
182 | if wandb_logging:
183 | wandb.log(logs)
184 | wandb.finish()
185 |
186 |
187 | if __name__ == "__main__":
188 | main()
189 |
--------------------------------------------------------------------------------