├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── configs ├── defaults.py └── test.yaml ├── dataloader ├── __init__.py └── feat_bag_dataset.py ├── dist_train_modal_fusion.py ├── figs └── arch.png ├── metrics ├── __init__.py ├── epoch_metric.py └── roc_auc.py ├── models ├── __init__.py ├── effnet.py ├── fusion.py ├── mil_net.py └── tabnet │ ├── __init__.py │ ├── sparsemax.py │ ├── tab_network.py │ └── utils.py ├── preprocessing ├── extract_feat_with_tta.py └── merge_patch_feat.py ├── sampledata.7z └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | .idea/ 4 | vis 5 | vis/ 6 | result/ 7 | *.py[cod] 8 | *$py.class 9 | data/* 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MMMI 2 | 3 | 4 | 5 | This repository is an official PyTorch implementation of the paper 6 | **"Multi-modal Multi-instance Learning using Weakly Correlated Histopathological Images and Tabular Clinical Information"** [paper](https://link.springer.com/chapter/10.1007/978-3-030-87237-3_51) 7 | from **MICCAI 2021**. 8 | 9 | ![](./figs/arch.png) 10 | 11 | ## Citation 12 | ```bibtex 13 | @inproceedings{li2021multi, 14 | title={Multi-modal Multi-instance Learning Using Weakly Correlated Histopathological Images and Tabular Clinical Information}, 15 | author={Li, Hang and Yang, Fan and Xing, Xiaohan and Zhao, Yu and Zhang, Jun and Liu, Yueping and Han, Mengxue and Huang, Junzhou and Wang, Liansheng and Yao, Jianhua}, 16 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 17 | pages={529--539}, 18 | year={2021}, 19 | organization={Springer} 20 | } 21 | ``` 22 | ## Installation 23 | ### Dependencies 24 | * Python 3.6 25 | * PyTorch >= 1.5.0 26 | * einops 27 | * numpy 28 | * scipy 29 | * sklearn 30 | * openslide 31 | * albumentations 32 | * opencv 33 | * efficientnet_pytorch 34 | * yacs 35 | 36 | ## Usage 37 | 38 | ### Evaluation 39 | ```shell script 40 | python3 -m torch.distributed.launch \ 41 | --nproc_per_node 4 \ 42 | --master_port=XXXX \ 43 | dist_train_modal_fusion.py \ 44 | --cfg configs/test.yaml \ 45 | 46 | ``` 47 | XXXX denotes your master port 48 | 49 | 50 | ### Note for data 51 | 52 | We provide sample data for testing the pipeline, but they are not real data used in the experiments. 53 | 54 | The dataset used in this research is private. 55 | 56 | You can refer the code in `preprocessing/extract_feat_with_tta.py` to process your own data. 57 | 58 | #### 1. Offline-feature extraction 59 | 60 | The features are extracted using EfficientNet-b0. 61 | The file is saved in the format: 62 | 63 | ``` 64 | { 65 | 'tr': np.ndarray, shape(n, 1280) 66 | 'val': np.ndarray, shape(n, 1280) 67 | } 68 | ``` 69 | 70 | #### 2. Combine feature files 71 | Combine the features of all Patches within each WSI into a single file to reduce the IO overhead during training. 72 | 73 | ```bash 74 | python3 preprocessing/merge_patch_feat.py 75 | ``` 76 | 77 | ## Disclaimer 78 | 79 | This tool is for research purpose and not approved for clinical use. 80 | 81 | This is not an official Tencent product. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzon/Multi-modal-Multi-instance-Learning/c30051bc87c8d40ec93d29ca216ae088816ac6b5/__init__.py -------------------------------------------------------------------------------- /configs/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN(new_allowed=True) -------------------------------------------------------------------------------- /configs/test.yaml: -------------------------------------------------------------------------------- 1 | dataset: 2 | name: 'mil-multi-modal-x20' 3 | df_path: './sampledata/sample_tabnet_feat.csv' 4 | tab_data_path: './sampledata/sample.pickle' 5 | 6 | scale1_feat_root: './sampledata' 7 | scale2_feat_root: './sampledata' 8 | scale3_feat_root: './sampledata' 9 | select_scale: 0 10 | 11 | 12 | 13 | 14 | 15 | model: 16 | arch: "attention-fusion-net" 17 | 18 | input_dim: 1280 19 | attention_dim: 256 20 | attention_out_dim: 1 21 | instance_attention_layers: 1 22 | feature_attention_layers: 1 23 | feature_represent_layers: 1 24 | num_modal: 4 25 | use_tabnet: true 26 | 27 | 28 | 29 | 30 | 31 | train: 32 | batch_size_per_gpu: 1 33 | num_epoch: 240 34 | start_epoch: 241 35 | epoch_iters: 5000 36 | optim: "Adam" 37 | lr: 0.02 38 | lr_pow: 0.9 39 | beta1: 0.9 40 | weight_decay: 1e-3 41 | fix_bn: False 42 | workers: 8 43 | disp_iter: 200 44 | seed: 304 45 | 46 | 47 | test: 48 | checkpoint: "./checkpoints/model_epoch_171.pth" 49 | result: "./" 50 | 51 | save_dir: "./results" 52 | 53 | local_rank: -1 -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzon/Multi-modal-Multi-instance-Learning/c30051bc87c8d40ec93d29ca216ae088816ac6b5/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/feat_bag_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import pickle 3 | from typing import Dict, Any 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | import torch.utils.data as data_utils 9 | from rich import print 10 | 11 | 12 | class FeatBagDataset(data_utils.Dataset): 13 | def __init__(self, bag_feat_root: str, bag_feat_aug_root: str, df: pd.DataFrame, cfg, shuffle_bag=False, 14 | is_train=False, local_rank=0) -> None: 15 | self.pids = df['pid'].values.tolist() 16 | self.bag_feat_root = bag_feat_root 17 | self.bag_feat_aug_root = bag_feat_aug_root 18 | self.is_train = is_train 19 | 20 | miss_cnt = 0 21 | 22 | targets = df['target'].values.tolist() 23 | 24 | exist_targets = [] 25 | exist_pids = [] 26 | 27 | for idx, pid in enumerate(self.pids): 28 | bag_fp = osp.join(self.bag_feat_aug_root, f'{pid}.pkl') 29 | if osp.exists(bag_fp): # and len(os.listdir(bag_fp)) > 0: 30 | exist_pids.append(pid) 31 | exist_targets.append(targets[idx]) 32 | else: 33 | miss_cnt += 1 34 | 35 | if cfg.local_rank == 0: 36 | print(f'Total : {len(self.pids)}, found {len(exist_pids)}, miss {miss_cnt}') 37 | self.pids = exist_pids 38 | self.targets = exist_targets 39 | self.shuffle_bag = shuffle_bag 40 | self.if_shuffled = False 41 | self.local_rank = local_rank 42 | 43 | def __len__(self): 44 | return len(self.pids) 45 | 46 | def __getitem__(self, idx) -> Dict: 47 | if (not self.if_shuffled) and self.is_train: 48 | np.random.seed(self.local_rank) 49 | p = np.random.permutation(len(self.pids)) 50 | self.pids = np.array(self.pids)[p] 51 | self.targets = np.array(self.targets)[p] 52 | self.if_shuffled = True 53 | 54 | c_pid = self.pids[idx] 55 | label = self.targets[idx] 56 | 57 | bag_fp = osp.join(self.bag_feat_aug_root, f'{c_pid}.pkl') 58 | pid_dir = osp.join(self.bag_feat_aug_root, c_pid) 59 | 60 | with open(bag_fp, 'rb') as infile: 61 | bag_feat_list_obj = pickle.load(infile) 62 | 63 | bag_feat = [] 64 | for aug_feat_dict in bag_feat_list_obj: 65 | if self.is_train: 66 | aug_feat = aug_feat_dict['tr'] 67 | aug_feat = np.vstack([aug_feat, np.expand_dims(aug_feat_dict['val'], 0)]) 68 | random_row = np.random.randint(0, aug_feat.shape[0]) 69 | choice_feat = aug_feat[random_row] 70 | bag_feat.append(choice_feat) 71 | else: 72 | aug_feat = aug_feat_dict['val'] 73 | bag_feat.append(aug_feat) 74 | 75 | bag_feat = np.vstack(bag_feat) 76 | 77 | if self.is_train and bag_feat.shape[0] <= 1: 78 | # print(f'Only one instance in {bag_fp}') 79 | rand_choose_idx = np.random.randint(0, len(self)) 80 | return self[rand_choose_idx] 81 | 82 | if self.shuffle_bag: 83 | instance_size = bag_feat.shape[0] 84 | shuffle_idx = np.random.permutation(instance_size) 85 | bag_feat = bag_feat[shuffle_idx] 86 | 87 | num_of_drop_columns = np.random.randint(0, 10) 88 | for _ in range(num_of_drop_columns): 89 | random_drop_column = np.random.randint(0, bag_feat.shape[1]) 90 | bag_feat[:, random_drop_column] = 0 91 | 92 | if np.random.rand() < 0.3: 93 | noise = np.random.normal(loc=0, scale=0.05, size=bag_feat.shape) 94 | bag_feat += noise 95 | 96 | return { 97 | 'data': torch.from_numpy(bag_feat).float(), 98 | 'name': c_pid, 99 | 'label': torch.tensor(label).float(), 100 | 'pid': c_pid 101 | } 102 | 103 | 104 | class FeatBagTabFeatDataset(data_utils.Dataset): 105 | """ 106 | 同时加载WSI的特征和Tabnetj计算后的特征 107 | """ 108 | 109 | def __init__(self, 110 | bag_feat_root: str, bag_feat_aug_root: str, 111 | df: pd.DataFrame, cfg, shuffle_bag=False, 112 | is_train=False, local_rank=0) -> None: 113 | """ 114 | 115 | Parameters 116 | ---------- 117 | bag_feat_root 118 | bag_feat_aug_root: 整合后的离线特征路径 119 | df: 120 | cfg: 121 | shuffle_bag: 122 | is_train: 123 | local_rank: 124 | """ 125 | self.pids = df['pid'].values.tolist() 126 | self.bag_feat_root = bag_feat_root 127 | self.bag_feat_aug_root = bag_feat_aug_root 128 | self.is_train = is_train 129 | 130 | self.tab_feat_df = pd.read_csv(cfg.dataset.tab_feat_df_fp) 131 | self.tab_feat_cols = [x for x in self.tab_feat_df.columns if x.startswith('feat')] 132 | 133 | miss_cnt = 0 134 | 135 | targets = df['target'].values.tolist() 136 | 137 | exist_targets = [] 138 | exist_pids = [] 139 | 140 | for idx, pid in enumerate(self.pids): 141 | bag_fp = osp.join(self.bag_feat_aug_root, f'{pid}.pkl') 142 | if osp.exists(bag_fp): # and len(os.listdir(bag_fp)) > 0: 143 | exist_pids.append(pid) 144 | exist_targets.append(targets[idx]) 145 | else: 146 | miss_cnt += 1 147 | 148 | if cfg.local_rank == 0: 149 | print(f'Tab feat : {len(self.tab_feat_cols)}') 150 | print(f'Total : {len(self.pids)}, found {len(exist_pids)}, miss {miss_cnt}') 151 | self.pids = exist_pids 152 | self.targets = exist_targets 153 | self.shuffle_bag = shuffle_bag 154 | self.if_shuffled = False 155 | self.local_rank = local_rank 156 | 157 | def __len__(self): 158 | return len(self.pids) 159 | 160 | def __getitem__(self, idx) -> Dict: 161 | if (not self.if_shuffled) and self.is_train: 162 | np.random.seed(self.local_rank) 163 | p = np.random.permutation(len(self.pids)) 164 | self.pids = np.array(self.pids)[p] 165 | self.targets = np.array(self.targets)[p] 166 | self.if_shuffled = True 167 | 168 | c_pid = self.pids[idx] 169 | label = self.targets[idx] 170 | 171 | bag_fp = osp.join(self.bag_feat_aug_root, f'{c_pid}.pkl') 172 | 173 | with open(bag_fp, 'rb') as infile: 174 | bag_feat_list_obj = pickle.load(infile) 175 | 176 | bag_feat = [] 177 | for aug_feat_dict in bag_feat_list_obj: 178 | if self.is_train: 179 | aug_feat = aug_feat_dict['tr'] 180 | aug_feat = np.vstack([aug_feat, np.expand_dims(aug_feat_dict['val'], 0)]) 181 | random_row = np.random.randint(0, aug_feat.shape[0]) 182 | choice_feat = aug_feat[random_row] 183 | bag_feat.append(choice_feat) 184 | else: 185 | aug_feat = aug_feat_dict['val'] 186 | bag_feat.append(aug_feat) 187 | 188 | bag_feat = np.vstack(bag_feat) 189 | 190 | if self.is_train and bag_feat.shape[0] <= 1: 191 | 192 | rand_choose_idx = np.random.randint(0, len(self)) 193 | return self[rand_choose_idx] 194 | 195 | if self.shuffle_bag: 196 | instance_size = bag_feat.shape[0] 197 | shuffle_idx = np.random.permutation(instance_size) 198 | bag_feat = bag_feat[shuffle_idx] 199 | 200 | num_of_drop_columns = np.random.randint(0, 10) 201 | for _ in range(num_of_drop_columns): 202 | random_drop_column = np.random.randint(0, bag_feat.shape[1]) 203 | bag_feat[:, random_drop_column] = 0 204 | 205 | if np.random.rand() < 0.3: 206 | noise = np.random.normal(loc=0, scale=0.05, size=bag_feat.shape) 207 | bag_feat += noise 208 | 209 | tab_feat = self.tab_feat_df[self.tab_feat_df.pid == c_pid][self.tab_feat_cols].values[0] 210 | 211 | return { 212 | 'data': torch.from_numpy(bag_feat).float(), 213 | 'name': c_pid, 214 | 'tab_feat': torch.from_numpy(tab_feat).float(), 215 | 'label': torch.tensor(label).float(), 216 | 'pid': c_pid 217 | } 218 | 219 | 220 | 221 | class ModalFusionDataset(data_utils.Dataset): 222 | """" 223 | Load multi-modal data 224 | """ 225 | def __init__(self, 226 | cli_feat: pd.DataFrame, 227 | cli_data: pd.DataFrame, 228 | scale1_feat_root, scale2_feat_root: None, scale3_feat_root: None, 229 | select_scale: int, 230 | cfg: Any, 231 | shuffle_bag=False, is_train=False): 232 | """ 233 | 234 | Args: 235 | cli_feat: 236 | cli_data: 237 | scale1_feat_root: 238 | scale2_feat_root: 239 | scale3_feat_root: 240 | select_scale: 241 | cfg: 242 | shuffle_bag: 243 | is_train: 244 | """ 245 | super(ModalFusionDataset, self).__init__() 246 | 247 | self.pids = cli_feat.pid.values.tolist() 248 | self.targets = cli_feat.target.values.tolist() 249 | 250 | self.cli_feat_df = cli_feat 251 | self.cli_feat_cols = [x for x in cli_feat.columns if x.startswith('feat')] 252 | 253 | self.cli_data_df = cli_data 254 | self.cli_data_cols = [x for x in cli_data.columns if x not in ['pid', 'split', 'target']] 255 | 256 | self.scale1_feat_root = scale1_feat_root 257 | self.scale2_feat_root = scale2_feat_root 258 | self.scale3_feat_root = scale3_feat_root 259 | self.select_scale = select_scale 260 | 261 | self.cfg = cfg 262 | self.shuffle_bag = shuffle_bag 263 | self.is_train = is_train 264 | 265 | exist_targets = [] 266 | exist_pids = [] 267 | miss_cnt = 0 268 | for idx, pid in enumerate(self.pids): 269 | bag_fp = osp.join(self.scale1_feat_root, f'{pid}.pkl') 270 | if osp.exists(bag_fp): 271 | exist_pids.append(pid) 272 | exist_targets.append(self.targets[idx]) 273 | else: 274 | miss_cnt += 1 275 | 276 | if cfg.local_rank == 0: 277 | print(f'Tab feat : {len(self.cli_feat_cols)}') 278 | print(f'Total : {len(self.pids)}, found {len(exist_pids)}, miss {miss_cnt}') 279 | 280 | self.pids = exist_pids 281 | self.targets = exist_targets 282 | 283 | @property 284 | def tab_data_shape(self): 285 | return len(self.cli_data_cols) 286 | 287 | def __len__(self): 288 | return len(self.pids) 289 | 290 | def load_feat_and_aug(self, bag_fp) -> np.ndarray: 291 | """ 292 | Load WSI feature bag 293 | Args: 294 | bag_fp: 295 | 296 | Returns: 297 | """ 298 | with open(bag_fp, 'rb') as infile: 299 | bag_feat_list_obj = pickle.load(infile) 300 | 301 | bag_feat = [] 302 | feat_names = [] 303 | for aug_feat_dict in bag_feat_list_obj: 304 | if self.is_train: 305 | aug_feat = aug_feat_dict['tr'] 306 | aug_feat = np.vstack([aug_feat, np.expand_dims(aug_feat_dict['val'], 0)]) 307 | random_row = np.random.randint(0, aug_feat.shape[0]) 308 | choice_feat = aug_feat[random_row] 309 | bag_feat.append(choice_feat) 310 | else: 311 | aug_feat = aug_feat_dict['val'] 312 | bag_feat.append(aug_feat) 313 | 314 | feat_names.append(aug_feat_dict['feat_name']) 315 | 316 | del bag_feat_list_obj 317 | bag_feat = np.vstack(bag_feat) 318 | 319 | 320 | 321 | if self.is_train: 322 | if np.random.rand() < 0.5: 323 | num_of_drop_columns = np.random.randint(0, 100) 324 | for _ in range(num_of_drop_columns): 325 | random_drop_column = np.random.randint(0, bag_feat.shape[1]) 326 | bag_feat[:, random_drop_column] = 0 327 | if np.random.rand() < 0.5: 328 | noise = np.random.normal(loc=0, scale=0.01, size=bag_feat.shape) 329 | bag_feat += noise 330 | if self.shuffle_bag: 331 | instance_size = bag_feat.shape[0] 332 | shuffle_idx = np.random.permutation(instance_size) 333 | bag_feat = bag_feat[shuffle_idx] 334 | 335 | return bag_feat, feat_names 336 | 337 | def __getitem__(self, idx) -> Dict: 338 | c_pid = self.pids[idx] 339 | label = self.targets[idx] 340 | ret = {} 341 | 342 | if self.select_scale == 0: 343 | for idx, feat_root in enumerate([self.scale1_feat_root, self.scale2_feat_root, self.scale3_feat_root]): 344 | bag_fp = osp.join(feat_root, f'{c_pid}.pkl') 345 | if osp.exists(bag_fp): 346 | bag_feat, feat_name = self.load_feat_and_aug(bag_fp) 347 | else: 348 | bag_feat = np.zeros((1, 1280)) 349 | feat_name = [] 350 | k = f'wsi_feat_scale{idx+1}' 351 | ret[k] = torch.from_numpy(bag_feat).float() 352 | ret[k+'_feat_name'] = feat_name 353 | else: 354 | if self.select_scale == 1: 355 | feat_root = self.scale1_feat_root 356 | elif self.select_scale == 2: 357 | feat_root = self.scale2_feat_root 358 | elif self.select_scale == 3: 359 | feat_root = self.scale3_feat_root 360 | 361 | bag_fp = osp.join(feat_root, f'{c_pid}.pkl') 362 | if osp.exists(bag_fp): 363 | bag_feat, feat_name = self.load_feat_and_aug(bag_fp) 364 | else: 365 | bag_feat = np.zeros((1, 1280)) 366 | feat_name = [] 367 | ret['wsi_feat_scale1'] = torch.from_numpy(bag_feat).float() 368 | ret['wsi_feat_scale1_feat_name'] = feat_name 369 | 370 | tab_feat = self.cli_feat_df[self.cli_feat_df.pid == c_pid][self.cli_feat_cols].values[0] 371 | tab_data = self.cli_data_df[self.cli_data_df.pid == c_pid][self.cli_data_cols].values[0] 372 | ret['name'] = c_pid 373 | ret['tab_feat'] = torch.from_numpy(tab_feat).float() 374 | ret['tab_data'] = torch.from_numpy(tab_data).float() 375 | ret['label'] = torch.tensor(label).float() 376 | ret['pid'] = c_pid 377 | return ret 378 | 379 | 380 | def mixup_data(x, alpha=1.0, use_cuda=False): 381 | '''Returns mixed inputs, pairs of targets, and lambda''' 382 | if alpha > 0: 383 | lam = np.random.beta(alpha, alpha) 384 | else: 385 | lam = 1 386 | 387 | batch_size = x.shape[0] 388 | if use_cuda: 389 | index = torch.randperm(batch_size).cuda() 390 | else: 391 | index = torch.randperm(batch_size) 392 | 393 | mixed_x = lam * x + (1 - lam) * x[index, :] 394 | return mixed_x 395 | 396 | class EMPOBJ: 397 | def __init__(self): 398 | self.local_rank = 0 399 | 400 | if __name__ == '__main__': 401 | import pandas as pd 402 | 403 | df = pd.read_csv('/path/to/your/table') 404 | bag_feat_root = "path/to/your/bag" 405 | cfg = EMPOBJ() 406 | from rich.progress import track 407 | cfg.local_rank = 0 408 | ds = ModalFusionDataset( 409 | cli_feat=df, 410 | scale1_feat_root='path/to/your/scale1/features', 411 | scale2_feat_root='path/to/your/scale2/features', 412 | scale3_feat_root='path/to/your/scale3/features', 413 | select_scale=0, 414 | cfg=cfg, 415 | shuffle_bag=True, 416 | is_train=True 417 | ) 418 | local_rank = 0 419 | dl = data_utils.DataLoader(ds, num_workers=4) 420 | for data in track(dl): 421 | tab_feat = data['tab_feat'].cuda(local_rank) 422 | wsi_feat_scale1 = data['wsi_feat_scale1'].cuda(local_rank) 423 | -------------------------------------------------------------------------------- /dist_train_modal_fusion.py: -------------------------------------------------------------------------------- 1 | # System libs 2 | import datetime 3 | 4 | import os 5 | import os.path as osp 6 | 7 | import random 8 | import pickle 9 | import argparse 10 | 11 | # Numerical libs 12 | import torch 13 | import torch.nn as nn 14 | 15 | import torch.utils.data as data_utils 16 | import pandas as pd 17 | import numpy as np 18 | 19 | # Our libs 20 | from configs.defaults import _C as train_config 21 | from utils import setup_logger 22 | 23 | import torch.distributed as dist 24 | from models.mil_net import MILFusion 25 | from dataloader.feat_bag_dataset import ModalFusionDataset 26 | 27 | from metrics import ROC_AUC 28 | 29 | 30 | import matplotlib 31 | from typing import Tuple 32 | 33 | old_print = print 34 | from rich import print 35 | 36 | 37 | matplotlib.use("Agg") 38 | 39 | def print_in_main_thread(msg: str,): 40 | if local_rank == 0: 41 | print(msg) 42 | 43 | def log_in_main_thread(msg: str): 44 | if local_rank == 0: 45 | logger.info(msg) 46 | 47 | 48 | def evaluate(model: nn.Module, val_loader, epoch, local_rank, final_test=False, dump_dir=None) -> Tuple[float, float, float]: 49 | """ 50 | distributed method for model inference, meter will automatically deal with the sync of multiple gpus 51 | Parameters 52 | ---------- 53 | model 54 | val_loader 55 | epoch 56 | local_rank 57 | 58 | Returns 59 | ------- 60 | 61 | """ 62 | 63 | auc_meter = ROC_AUC() 64 | model.eval() 65 | 66 | start_time = datetime.datetime.now() 67 | print(f'Start test at {start_time} at {local_rank}') 68 | 69 | with torch.no_grad(): 70 | for batch_nb, batch_data in enumerate(val_loader): 71 | if local_rank == 0: 72 | old_print(f'\r {batch_nb} / {len(val_loader)} ', end='') 73 | 74 | label = batch_data['label'].cuda(local_rank) 75 | label = label.view(label.size(0), 1).float() 76 | output, loss, __ = model(batch_data) 77 | 78 | auc_meter.update([torch.sigmoid(output.detach()).cpu().view(-1), label.view(-1).detach().cpu()]) 79 | 80 | print() 81 | 82 | end_time = datetime.datetime.now() 83 | print(f'End test at {end_time} at {local_rank}') 84 | all_pred = auc_meter.predictions 85 | print(all_pred) 86 | 87 | 88 | 89 | def main(cfg, local_rank): 90 | """ 91 | build 92 | prepare model training 93 | :param cfg: 94 | :param local_rank: 95 | :return: 96 | """ 97 | if local_rank == 0: 98 | logger.info(f'Build model') 99 | 100 | with open(cfg.dataset.tab_data_path, 'rb') as infile: 101 | tab_data = pickle.load(infile) 102 | cat_dims = tab_data['cat_dims'] 103 | cat_idxs = tab_data['cat_idxs'] 104 | 105 | tab_data_df = pd.read_csv(cfg.dataset.tab_data_path.rsplit('.', 1)[0] + '.csv') 106 | 107 | 108 | df_path = cfg.dataset.df_path 109 | if local_rank == 0: 110 | print(f'Load df from {os.path.abspath(df_path)}') 111 | df = pd.read_csv(df_path) 112 | test_df = df[df.split == 'test'] 113 | 114 | test_data_df = tab_data_df[tab_data_df.split == 'test'] 115 | 116 | if local_rank == 0: 117 | logger.info(f'Build dataset') 118 | 119 | """build dataset""" 120 | 121 | test_dataset = ModalFusionDataset( 122 | cli_feat=test_df, 123 | cli_data=test_data_df, 124 | scale1_feat_root=cfg.dataset.scale1_feat_root, 125 | scale2_feat_root=cfg.dataset.scale2_feat_root, 126 | scale3_feat_root=cfg.dataset.scale3_feat_root, 127 | select_scale=cfg.dataset.select_scale, 128 | cfg=cfg, 129 | shuffle_bag=False, 130 | is_train=False 131 | ) 132 | 133 | log_in_main_thread('Dataset load finish') 134 | test_sampler = data_utils.distributed.DistributedSampler(test_dataset, rank=local_rank) 135 | 136 | num_workers = cfg.train.workers 137 | 138 | test_loader = data_utils.DataLoader( 139 | test_dataset, 140 | batch_size=1, 141 | num_workers=num_workers, 142 | drop_last=False, 143 | shuffle=False, 144 | pin_memory=False, 145 | sampler=test_sampler 146 | ) 147 | 148 | """build model""" 149 | log_in_main_thread('Build model') 150 | if hasattr(cfg.model, 'fusion_method'): 151 | fusion = cfg.model.fusion_method 152 | else: 153 | fusion = 'mmtm' 154 | if hasattr(cfg.model, 'use_k_agg'): 155 | use_k_agg = cfg.model.use_k_agg 156 | k_agg = cfg.model.k_agg 157 | else: 158 | use_k_agg = False 159 | k_agg = 10 160 | 161 | if cfg.model.arch == 'm3d': 162 | logger.info(f'Adapt m3d') 163 | from models.mil_net import M3D 164 | model = M3D(img_feat_input_dim=1280, 165 | tab_feat_input_dim=32, 166 | img_feat_rep_layers=4, 167 | num_modal=cfg.model.num_modal, 168 | fusion=fusion, 169 | use_tabnet=cfg.model.use_tabnet, 170 | use_k_agg=use_k_agg, 171 | k_agg=k_agg, 172 | tab_indim=test_dataset.tab_data_shape, 173 | cat_dims=cat_dims, 174 | cat_idxs=cat_idxs, 175 | local_rank=local_rank) 176 | elif cfg.model.arch == 'attention_refine': 177 | logger.info(f'attention_refine') 178 | from models.mil_net import MILFusionAppend 179 | model = MILFusionAppend(img_feat_input_dim=1280, 180 | tab_feat_input_dim=32, 181 | img_feat_rep_layers=4, 182 | num_modal=cfg.model.num_modal, 183 | fusion=fusion, 184 | use_tabnet=cfg.model.use_tabnet, 185 | use_k_agg=use_k_agg, 186 | k_agg=k_agg, 187 | tab_indim=test_dataset.tab_data_shape, 188 | cat_dims=cat_dims, 189 | cat_idxs=cat_idxs, 190 | local_rank=local_rank) 191 | elif cfg.model.arch == 'attention_add': 192 | logger.info(f'attention_add') 193 | from models.mil_net import MILFusionAdd 194 | model = MILFusionAdd(img_feat_input_dim=1280, 195 | tab_feat_input_dim=32, 196 | img_feat_rep_layers=4, 197 | num_modal=cfg.model.num_modal, 198 | fusion=fusion, 199 | use_tabnet=cfg.model.use_tabnet, 200 | use_k_agg=use_k_agg, 201 | k_agg=k_agg, 202 | tab_indim=test_dataset.tab_data_shape, 203 | cat_dims=cat_dims, 204 | cat_idxs=cat_idxs, 205 | local_rank=local_rank) 206 | else: 207 | model = MILFusion(img_feat_input_dim=1280, 208 | tab_feat_input_dim=32, 209 | img_feat_rep_layers=4, 210 | num_modal=cfg.model.num_modal, 211 | fusion=fusion, 212 | use_tabnet=cfg.model.use_tabnet, 213 | use_k_agg=use_k_agg, 214 | k_agg=k_agg, 215 | tab_indim=test_dataset.tab_data_shape, 216 | cat_dims=cat_dims, 217 | cat_idxs=cat_idxs, 218 | local_rank=local_rank) 219 | 220 | 221 | model = model.cuda(local_rank) 222 | model = model.to(local_rank) 223 | 224 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) 225 | 226 | 227 | if local_rank == 0: 228 | logger.info(f'Start training') 229 | 230 | """ 231 | load best ckpt 232 | """ 233 | if hasattr(cfg.test, 'checkpoint'): 234 | ckpt_path = cfg.test.checkpoint 235 | if osp.exists(ckpt_path): 236 | bst_val_model_path = ckpt_path 237 | 238 | log_in_main_thread(f'Load model from {bst_val_model_path}') 239 | 240 | map_location = {'cuda:%d' % 0: 'cuda:%d' % local_rank} 241 | model.load_state_dict(torch.load(bst_val_model_path, map_location=map_location)) 242 | 243 | model.eval() 244 | 245 | evaluate(model, test_loader, cfg.train.num_epoch, local_rank, final_test=True, 246 | dump_dir=cfg.save_dir) 247 | 248 | def seed_everything(seed_value): 249 | random.seed(seed_value) 250 | np.random.seed(seed_value) 251 | torch.manual_seed(seed_value) 252 | os.environ['PYTHONHASHSEED'] = str(seed_value) 253 | 254 | if torch.cuda.is_available(): 255 | torch.cuda.manual_seed(seed_value) 256 | torch.cuda.manual_seed_all(seed_value) 257 | torch.backends.cudnn.deterministic = True 258 | torch.backends.cudnn.benchmark = False 259 | 260 | if __name__ == '__main__': 261 | 262 | parser = argparse.ArgumentParser( 263 | description="PyTorch WSI Multi modal training" 264 | ) 265 | parser.add_argument( 266 | "--cfg", 267 | metavar="FILE", 268 | help="path to config file", 269 | type=str, 270 | ) 271 | parser.add_argument( 272 | "opts", 273 | help="Modify config options using the command-line", 274 | default=None, 275 | nargs=argparse.REMAINDER, 276 | ) 277 | parser.add_argument('--local_rank', default=-1, type=int, 278 | help='node rank for distributed training') 279 | 280 | args = parser.parse_args() 281 | 282 | cfg = train_config 283 | cfg.merge_from_file(args.cfg) 284 | cfg.merge_from_list(args.opts) 285 | 286 | local_rank = args.local_rank 287 | # set dist 288 | torch.cuda.set_device(args.local_rank) 289 | dist.init_process_group(backend='nccl', rank=local_rank) 290 | 291 | print(f'local rank: {args.local_rank}') 292 | 293 | time_now = datetime.datetime.now() 294 | cfg.save_dir = osp.join(cfg.save_dir, 295 | f'{time_now.year}_{time_now.month}_{time_now.day}_{time_now.hour}_{time_now.minute}') 296 | 297 | 298 | if not os.path.isdir(cfg.save_dir): 299 | os.makedirs(cfg.save_dir, exist_ok=True) 300 | logger = setup_logger(distributed_rank=args.local_rank, filename=osp.join(cfg.save_dir, 'train_log.txt')) # TODO 301 | log_in_main_thread(f'Save result to : {cfg.save_dir}') 302 | 303 | 304 | if args.local_rank == 0: 305 | logger.info("Loaded configuration file {}".format(args.cfg)) 306 | logger.info("Running with config:\n{}".format(cfg)) 307 | logger.info("Outputing checkpoints to: {}".format(cfg.save_dir)) 308 | with open(os.path.join(cfg.save_dir, 'config.yaml'), 'w') as f: 309 | f.write("{}".format(cfg)) 310 | 311 | num_gpus = 1 312 | 313 | random.seed(cfg.train.seed) 314 | torch.manual_seed(cfg.train.seed) 315 | 316 | main(cfg, args.local_rank) 317 | 318 | 319 | -------------------------------------------------------------------------------- /figs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzon/Multi-modal-Multi-instance-Learning/c30051bc87c8d40ec93d29ca216ae088816ac6b5/figs/arch.png -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .roc_auc import ROC_AUC -------------------------------------------------------------------------------- /metrics/epoch_metric.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Sequence 2 | from functools import wraps 3 | import torch.distributed as dist 4 | import torch 5 | import pickle 6 | 7 | def all_gather(data): 8 | """ 9 | Run all_gather on arbitrary picklable data (not necessarily tensors) 10 | Args: 11 | data: any picklable object 12 | Returns: 13 | list[data]: list of data gathered from each rank 14 | """ 15 | world_size = dist.get_world_size() 16 | if world_size == 1: 17 | return [data] 18 | 19 | rank = dist.get_rank() 20 | device = torch.device('cuda', rank) 21 | 22 | 23 | # serialized to a Tensor 24 | buffer = pickle.dumps(data) 25 | storage = torch.ByteStorage.from_buffer(buffer) 26 | tensor = torch.ByteTensor(storage).to(device) 27 | 28 | # obtain Tensor size of each rank 29 | local_size = torch.LongTensor([tensor.numel()]).to(device) 30 | size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)] 31 | dist.all_gather(size_list, local_size) 32 | size_list = [int(size.item()) for size in size_list] 33 | max_size = max(size_list) 34 | 35 | tensor_list = [] 36 | for _ in size_list: 37 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device)) 38 | if local_size != max_size: 39 | padding = torch.ByteTensor(size=(max_size - local_size,)).to(device) 40 | tensor = torch.cat((tensor, padding), dim=0) 41 | dist.all_gather(tensor_list, tensor) 42 | 43 | data_list = [] 44 | for size, tensor in zip(size_list, tensor_list): 45 | buffer = tensor.cpu().numpy().tobytes()[:size] 46 | data_list.append(pickle.loads(buffer)) 47 | 48 | ret = [] 49 | for d in data_list: 50 | ret.append(d.clone().to(device)) 51 | return ret 52 | 53 | 54 | class NotComputableError(RuntimeError): 55 | """ 56 | Exception class to raise if Metric cannot be computed. 57 | """ 58 | 59 | 60 | def reinit__is_reduced(func: Callable) -> Callable: 61 | """Helper decorator for distributed configuration. 62 | See :doc:`metrics` on how to use it. 63 | """ 64 | 65 | @wraps(func) 66 | def wrapper(self, *args, **kwargs): 67 | func(self, *args, **kwargs) 68 | self._is_reduced = False 69 | 70 | wrapper._decorated = True 71 | return wrapper 72 | 73 | 74 | class EpochMetric: 75 | """Class for metrics that should be computed on the entire output history of a model. 76 | Model's output and targets are restricted to be of shape `(batch_size, n_classes)`. Output 77 | datatype should be `float32`. Target datatype should be `long`. 78 | .. warning:: 79 | Current implementation stores all input data (output and target) in as tensors before computing a metric. 80 | This can potentially lead to a memory error if the input data is larger than available RAM. 81 | .. warning:: 82 | Current implementation does not work with distributed computations. Results are not gather across all devices 83 | and computed results are valid for a single device only. 84 | - `update` must receive output of the form `(y_pred, y)` or `{'y_pred': y_pred, 'y': y}`. 85 | If target shape is `(batch_size, n_classes)` and `n_classes > 1` than it should be binary: e.g. `[[0, 1, 0, 1], ]`. 86 | Args: 87 | compute_fn (callable): a callable with the signature (`torch.tensor`, `torch.tensor`) takes as the input 88 | `predictions` and `targets` and returns a scalar. 89 | output_transform (callable, optional): a callable that is used to transform the 90 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 91 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 92 | you want to compute the metric with respect to one of the outputs. 93 | """ 94 | 95 | def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x): 96 | 97 | if not callable(compute_fn): 98 | raise TypeError("Argument compute_fn should be callable.") 99 | 100 | self._is_reduced = False 101 | self.compute_fn = compute_fn 102 | self.reset() 103 | 104 | @reinit__is_reduced 105 | def reset(self) -> None: 106 | self._predictions = [] 107 | self._targets = [] 108 | 109 | def _check_shape(self, output): 110 | y_pred, y = output 111 | if y_pred.ndimension() not in (1, 2): 112 | raise ValueError("Predictions should be of shape (batch_size, n_classes) or (batch_size, ).") 113 | 114 | if y.ndimension() not in (1, 2): 115 | raise ValueError("Targets should be of shape (batch_size, n_classes) or (batch_size, ).") 116 | 117 | if y.ndimension() == 2: 118 | if not torch.equal(y ** 2, y): 119 | raise ValueError("Targets should be binary (0 or 1).") 120 | 121 | def _check_type(self, output): 122 | y_pred, y = output 123 | if len(self._predictions) < 1: 124 | return 125 | dtype_preds = self._predictions[-1].type() 126 | if dtype_preds != y_pred.type(): 127 | raise ValueError( 128 | "Incoherent types between input y_pred and stored predictions: " 129 | "{} vs {}".format(dtype_preds, y_pred.type()) 130 | ) 131 | 132 | dtype_targets = self._targets[-1].type() 133 | if dtype_targets != y.type(): 134 | raise ValueError( 135 | "Incoherent types between input y and stored targets: " "{} vs {}".format(dtype_targets, y.type()) 136 | ) 137 | 138 | @reinit__is_reduced 139 | def update(self, output: Sequence[torch.Tensor]) -> None: 140 | self._check_shape(output) 141 | y_pred, y = output[0].detach(), output[1].detach() 142 | 143 | if y_pred.ndimension() == 2 and y_pred.shape[1] == 1: 144 | y_pred = y_pred.squeeze(dim=-1) 145 | 146 | if y.ndimension() == 2 and y.shape[1] == 1: 147 | y = y.squeeze(dim=-1) 148 | 149 | y_pred = y_pred.clone().to(y_pred.device) 150 | y = y.clone().to(y_pred.device) 151 | 152 | self._check_type((y_pred, y)) 153 | self._predictions.append(y_pred) 154 | self._targets.append(y) 155 | 156 | def compute(self) -> Sequence[torch.Tensor]: 157 | 158 | if len(self._predictions) < 1 or len(self._targets) < 1: 159 | raise NotComputableError("EpochMetric must have at least one example before it can be computed.") 160 | 161 | rank = dist.get_rank() 162 | device = torch.device('cuda', rank) 163 | 164 | _prediction_tensor = torch.cat(self._predictions, dim=0).to(device).view(-1) 165 | _target_tensor = torch.cat(self._targets, dim=0).to(device).view(-1) 166 | 167 | ws = dist.get_world_size() 168 | 169 | dist.barrier() 170 | if ws > 1 and not self._is_reduced: 171 | _prediction_output = all_gather(_prediction_tensor) 172 | _target_output = all_gather(_target_tensor) 173 | 174 | _prediction_tensor = torch.cat(_prediction_output, dim=0) 175 | _target_tensor = torch.cat(_target_output, dim=0) 176 | 177 | self._is_reduced = True 178 | _prediction_tensor = _prediction_tensor.cpu() 179 | _target_tensor = _target_tensor.cpu() 180 | 181 | result = torch.zeros(1).to(device) 182 | if dist.get_rank() == 0: 183 | # Run compute_fn on zero rank only 184 | result = self.compute_fn(_prediction_tensor, _target_tensor) 185 | 186 | result = torch.tensor(result.item()).to(device) 187 | if ws > 1: 188 | dist.broadcast(result, src=0) 189 | 190 | _prediction_tensor = _prediction_tensor.numpy() 191 | _target_tensor = _target_tensor.numpy() 192 | return result.item(), _prediction_tensor, _target_tensor 193 | 194 | @property 195 | def predictions(self): 196 | return self._predictions 197 | 198 | @property 199 | def targets(self): 200 | return self._targets 201 | -------------------------------------------------------------------------------- /metrics/roc_auc.py: -------------------------------------------------------------------------------- 1 | from .epoch_metric import EpochMetric 2 | 3 | def roc_auc_compute_fn(y_preds, y_targets): 4 | try: 5 | from sklearn.metrics import roc_auc_score 6 | except ImportError: 7 | raise RuntimeError("This contrib module requires sklearn to be installed.") 8 | 9 | y_true = y_targets.numpy() 10 | y_pred = y_preds.numpy() 11 | return roc_auc_score(y_true, y_pred) 12 | 13 | 14 | class ROC_AUC(EpochMetric): 15 | """Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC) 16 | accumulating predictions and the ground-truth during an epoch and applying 17 | `sklearn.metrics.roc_auc_score `_ . 19 | Args: 20 | output_transform (callable, optional): a callable that is used to transform the 21 | :class:`~ignite.engine.Engine`'s `process_function`'s output into the 22 | form expected by the metric. This can be useful if, for example, you have a multi-output model and 23 | you want to compute the metric with respect to one of the outputs. 24 | ROC_AUC expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence 25 | values. To apply an activation to y_pred, use output_transform as shown below: 26 | .. code-block:: python 27 | def activated_output_transform(output): 28 | y_pred, y = output 29 | y_pred = torch.sigmoid(y_pred) 30 | return y_pred, y 31 | roc_auc = ROC_AUC(activated_output_transform) 32 | """ 33 | 34 | def __init__(self, output_transform=lambda x: x): 35 | super(ROC_AUC, self).__init__(roc_auc_compute_fn, output_transform=output_transform) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzon/Multi-modal-Multi-instance-Learning/c30051bc87c8d40ec93d29ca216ae088816ac6b5/models/__init__.py -------------------------------------------------------------------------------- /models/effnet.py: -------------------------------------------------------------------------------- 1 | from efficientnet_pytorch import EfficientNet 2 | from torch import nn 3 | import torch 4 | 5 | efn_pretrained = { 6 | 'efficientnet-b0': '../pretrained/efficientnet-b0-355c32eb.pth', 7 | 3: '../pretrained/efficientnet-b3_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5', 8 | 4: '../pretrained/efficientnet-b4_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5', 9 | 5: '../pretrained/efficientnet-b5_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5', 10 | 6: '../pretrained/efficientnet-b6_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5', 11 | 7: '../pretrained/efficientnet-b7_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5' 12 | } 13 | 14 | class EffNet(nn.Module): 15 | def __init__(self, efname='efficientnet-b0'): 16 | super(EffNet, self).__init__() 17 | self.model = EfficientNet.from_name(efname) 18 | pretrain_model_fp = efn_pretrained[efname] 19 | print(f'Load pretrain model from {pretrain_model_fp}') 20 | self.model.load_state_dict(torch.load(pretrain_model_fp)) 21 | 22 | def forward(self, data): 23 | bs = data.shape[0] 24 | feat = self.model.extract_features(data) 25 | feat = nn.functional.adaptive_avg_pool2d(feat, output_size=(1)) 26 | feat = feat.view(bs, -1) 27 | return feat 28 | -------------------------------------------------------------------------------- /models/fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BilinearFusion(nn.Module): 6 | """ 7 | 8 | """ 9 | def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, dim1=32, dim2=32, scale_dim1=1, scale_dim2=1, mmhid=64, dropout_rate=0.25): 10 | super(BilinearFusion, self).__init__() 11 | self.skip = skip 12 | self.use_bilinear = use_bilinear 13 | self.gate1 = gate1 14 | self.gate2 = gate2 15 | 16 | dim1_og, dim2_og, dim1, dim2 = dim1, dim2, dim1//scale_dim1, dim2//scale_dim2 17 | skip_dim = dim1+dim2+2 if skip else 0 18 | 19 | self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU()) 20 | self.linear_z1 = nn.Bilinear(dim1_og, dim2_og, dim1) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim1)) 21 | self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate)) 22 | 23 | self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU()) 24 | self.linear_z2 = nn.Bilinear(dim1_og, dim2_og, dim2) if use_bilinear else nn.Sequential(nn.Linear(dim1_og+dim2_og, dim2)) 25 | self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate)) 26 | 27 | self.post_fusion_dropout = nn.Dropout(p=dropout_rate) 28 | self.encoder1 = nn.Sequential(nn.Linear((dim1+1)*(dim2+1), mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate)) 29 | self.encoder2 = nn.Sequential(nn.Linear(mmhid+skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate)) 30 | 31 | def forward(self, vec1, vec2): 32 | ### Gated Multimodal Units 33 | if self.gate1: 34 | h1 = self.linear_h1(vec1) 35 | z1 = self.linear_z1(vec1, vec2) if self.use_bilinear else self.linear_z1(torch.cat((vec1, vec2), dim=1)) 36 | o1 = self.linear_o1(nn.Sigmoid()(z1)*h1) 37 | else: 38 | o1 = self.linear_o1(vec1) 39 | 40 | if self.gate2: 41 | h2 = self.linear_h2(vec2) 42 | z2 = self.linear_z2(vec1, vec2) if self.use_bilinear else self.linear_z2(torch.cat((vec1, vec2), dim=1)) 43 | o2 = self.linear_o2(nn.Sigmoid()(z2)*h2) 44 | else: 45 | o2 = self.linear_o2(vec2) 46 | 47 | ### Fusion 48 | o1 = torch.cat((o1, torch.cuda.FloatTensor(o1.shape[0], 1).fill_(1)), 1) 49 | o2 = torch.cat((o2, torch.cuda.FloatTensor(o2.shape[0], 1).fill_(1)), 1) 50 | o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) # BATCH_SIZE X 1024 51 | out = self.post_fusion_dropout(o12) 52 | out = self.encoder1(out) 53 | if self.skip: out = torch.cat((out, o1, o2), 1) 54 | out = self.encoder2(out) 55 | return out 56 | 57 | 58 | class TrilinearFusion_B(nn.Module): 59 | def __init__(self, skip=1, use_bilinear=1, gate1=1, gate2=1, gate3=1, dim1=32, dim2=32, dim3=32, scale_dim1=1, 60 | scale_dim2=1, scale_dim3=1, mmhid=96, dropout_rate=0.25): 61 | super(TrilinearFusion_B, self).__init__() 62 | self.skip = skip 63 | self.use_bilinear = use_bilinear 64 | self.gate1 = gate1 65 | self.gate2 = gate2 66 | self.gate3 = gate3 67 | 68 | dim1_og, dim2_og, dim3_og, dim1, dim2, dim3 = dim1, dim2, dim3, dim1 // scale_dim1, dim2 // scale_dim2, dim3 // scale_dim3 69 | skip_dim = dim1 + dim2 + dim3 + 3 if skip else 0 70 | 71 | ### Path 72 | self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU()) 73 | self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential( 74 | nn.Linear(dim1_og + dim3_og, dim1)) 75 | self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate)) 76 | 77 | ### Graph 78 | self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU()) 79 | self.linear_z2 = nn.Bilinear(dim2_og, dim1_og, dim2) if use_bilinear else nn.Sequential( 80 | nn.Linear(dim2_og + dim1_og, dim2)) 81 | self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate)) 82 | 83 | ### Omic 84 | self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU()) 85 | self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential( 86 | nn.Linear(dim1_og + dim3_og, dim3)) 87 | self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate)) 88 | 89 | self.post_fusion_dropout = nn.Dropout(p=0.25) 90 | self.encoder1 = nn.Sequential(nn.Linear((dim1 + 1) * (dim2 + 1) * (dim3 + 1), mmhid), nn.ReLU(), 91 | nn.Dropout(p=dropout_rate)) 92 | self.encoder2 = nn.Sequential(nn.Linear(mmhid + skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate)) 93 | 94 | 95 | def forward(self, vec1, vec2, vec3): 96 | ### Gated Multimodal Units 97 | if self.gate1: 98 | h1 = self.linear_h1(vec1) 99 | z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1( 100 | torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic 101 | o1 = self.linear_o1(nn.Sigmoid()(z1) * h1) 102 | else: 103 | o1 = self.linear_o1(vec1) 104 | 105 | if self.gate2: 106 | h2 = self.linear_h2(vec2) 107 | z2 = self.linear_z2(vec2, vec1) if self.use_bilinear else self.linear_z2( 108 | torch.cat((vec2, vec1), dim=1)) # Gate Graph with Omic 109 | o2 = self.linear_o2(nn.Sigmoid()(z2) * h2) 110 | else: 111 | o2 = self.linear_o2(vec2) 112 | 113 | if self.gate3: 114 | h3 = self.linear_h3(vec3) 115 | z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3( 116 | torch.cat((vec1, vec3), dim=1)) # Gate Omic With Path 117 | o3 = self.linear_o3(nn.Sigmoid()(z3) * h3) 118 | else: 119 | o3 = self.linear_o3(vec3) 120 | 121 | ### Fusion 122 | o1 = torch.cat((o1, torch.cuda.FloatTensor(o1.shape[0], 1).fill_(1)), 1) 123 | o2 = torch.cat((o2, torch.cuda.FloatTensor(o2.shape[0], 1).fill_(1)), 1) 124 | o3 = torch.cat((o3, torch.cuda.FloatTensor(o3.shape[0], 1).fill_(1)), 1) 125 | o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) 126 | o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1) 127 | out = self.post_fusion_dropout(o123) 128 | out = self.encoder1(out) 129 | if self.skip: out = torch.cat((out, o1, o2, o3), 1) 130 | out = self.encoder2(out) 131 | return out 132 | 133 | 134 | 135 | class QualinearFusion(nn.Module): 136 | def __init__(self, skip=1, use_bilinear=1, 137 | gate1=1, gate2=1, gate3=1, gate4=1, 138 | dim1=32, dim2=32, dim3=32, dim4=32, 139 | scale_dim1=1,scale_dim2=1, scale_dim3=1, scale_dim4=1, 140 | mmhid=96, dropout_rate=0.25): 141 | super(QualinearFusion, self).__init__() 142 | self.skip = skip 143 | self.use_bilinear = use_bilinear 144 | self.gate1 = gate1 145 | self.gate2 = gate2 146 | self.gate3 = gate3 147 | self.gate4 = gate4 148 | 149 | dim1_og, dim2_og, dim3_og, dim4_og, dim1, dim2, dim3, dim4 = dim1, dim2, dim3, dim4, dim1 // scale_dim1, dim2 // scale_dim2, dim3 // scale_dim3, dim4 // scale_dim4 150 | # skip_dim = dim1 + dim2 + dim3 + 3 if skip else 0 151 | skip_dim = dim1 + dim2 + dim3 + dim4 + 4 if skip else 0 152 | 153 | ### tab 154 | self.linear_h1 = nn.Sequential(nn.Linear(dim1_og, dim1), nn.ReLU()) 155 | self.linear_z1 = nn.Bilinear(dim1_og, dim3_og, dim1) if use_bilinear else nn.Sequential( 156 | nn.Linear(dim1_og + dim3_og, dim1)) 157 | self.linear_o1 = nn.Sequential(nn.Linear(dim1, dim1), nn.ReLU(), nn.Dropout(p=dropout_rate)) 158 | 159 | ### scale1 160 | self.linear_h2 = nn.Sequential(nn.Linear(dim2_og, dim2), nn.ReLU()) 161 | self.linear_z2 = nn.Bilinear(dim2_og, dim1_og, dim2) if use_bilinear else nn.Sequential( 162 | nn.Linear(dim2_og + dim1_og, dim2)) 163 | self.linear_o2 = nn.Sequential(nn.Linear(dim2, dim2), nn.ReLU(), nn.Dropout(p=dropout_rate)) 164 | 165 | ### scale2 166 | self.linear_h3 = nn.Sequential(nn.Linear(dim3_og, dim3), nn.ReLU()) 167 | self.linear_z3 = nn.Bilinear(dim1_og, dim3_og, dim3) if use_bilinear else nn.Sequential( 168 | nn.Linear(dim1_og + dim3_og, dim3)) 169 | self.linear_o3 = nn.Sequential(nn.Linear(dim3, dim3), nn.ReLU(), nn.Dropout(p=dropout_rate)) 170 | 171 | ### scale3 172 | self.linear_h4 = nn.Sequential(nn.Linear(dim4_og, dim4), nn.ReLU()) 173 | self.linear_z4 = nn.Bilinear(dim1_og, dim3_og, dim4) if use_bilinear else nn.Sequential( 174 | nn.Linear(dim1_og + dim3_og, dim4)) 175 | self.linear_o4 = nn.Sequential(nn.Linear(dim4, dim4), nn.ReLU(), nn.Dropout(p=dropout_rate)) 176 | 177 | 178 | self.post_fusion_dropout = nn.Dropout(p=0.25) 179 | self.encoder1 = nn.Sequential(nn.Linear((dim1 + 1) * (dim2 + 1) * (dim3 + 1) * (dim4 + 1), mmhid), nn.ReLU(), 180 | nn.Dropout(p=dropout_rate)) 181 | self.encoder2 = nn.Sequential(nn.Linear(mmhid + skip_dim, mmhid), nn.ReLU(), nn.Dropout(p=dropout_rate)) 182 | 183 | 184 | def forward(self, vec1, vec2, vec3, vec4): 185 | ### Gated Multimodal Units 186 | print(vec1.shape, vec2.shape, vec3.shape, vec4.shape) 187 | if self.gate1: 188 | h1 = self.linear_h1(vec1) 189 | z1 = self.linear_z1(vec1, vec3) if self.use_bilinear else self.linear_z1( 190 | torch.cat((vec1, vec3), dim=1)) # Gate Path with Omic 191 | o1 = self.linear_o1(nn.Sigmoid()(z1) * h1) 192 | else: 193 | o1 = self.linear_o1(vec1) 194 | 195 | if self.gate2: 196 | h2 = self.linear_h2(vec2) 197 | z2 = self.linear_z2(vec2, vec1) if self.use_bilinear else self.linear_z2( 198 | torch.cat((vec2, vec1), dim=1)) # Gate Graph with Omic 199 | o2 = self.linear_o2(nn.Sigmoid()(z2) * h2) 200 | else: 201 | o2 = self.linear_o2(vec2) 202 | 203 | if self.gate3: 204 | h3 = self.linear_h3(vec3) 205 | z3 = self.linear_z3(vec1, vec3) if self.use_bilinear else self.linear_z3( 206 | torch.cat((vec1, vec3), dim=1)) # Gate Omic With Path 207 | o3 = self.linear_o3(nn.Sigmoid()(z3) * h3) 208 | else: 209 | o3 = self.linear_o3(vec3) 210 | 211 | if self.gate4: 212 | h4 = self.linear_h4(vec4) 213 | z4 = self.linear_z4(vec1, vec4) if self.use_bilinear else self.linear_z4( 214 | torch.cat((vec1, vec4), dim=1)) # Gate Omic With Path 215 | o4 = self.linear_o4(nn.Sigmoid()(z4) * h4) 216 | else: 217 | o4 = self.linear_o4(vec4) 218 | 219 | ### Fusion 220 | o1 = torch.cat((o1, torch.cuda.FloatTensor(o1.shape[0], 1).fill_(1)), 1) 221 | o2 = torch.cat((o2, torch.cuda.FloatTensor(o2.shape[0], 1).fill_(1)), 1) 222 | o3 = torch.cat((o3, torch.cuda.FloatTensor(o3.shape[0], 1).fill_(1)), 1) 223 | o4 = torch.cat((o4, torch.cuda.FloatTensor(o4.shape[0], 1).fill_(1)), 1) 224 | 225 | o12 = torch.bmm(o1.unsqueeze(2), o2.unsqueeze(1)).flatten(start_dim=1) 226 | o123 = torch.bmm(o12.unsqueeze(2), o3.unsqueeze(1)).flatten(start_dim=1) 227 | o1234 = torch.bmm(o123.unsqueeze(2), o4.unsqueeze(1)).flatten(start_dim=1) 228 | 229 | out = self.post_fusion_dropout(o1234) 230 | out = self.encoder1(out) 231 | if self.skip: out = torch.cat((out, o1, o2, o3, o4), 1) 232 | out = self.encoder2(out) 233 | return out -------------------------------------------------------------------------------- /models/mil_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from typing import Sequence 6 | from models.tabnet.tab_network import TabNet 7 | 8 | from .fusion import BilinearFusion 9 | from copy import deepcopy 10 | 11 | class Swish(nn.Module): 12 | def __init__(self): 13 | super(Swish, self).__init__() 14 | 15 | def forward(self, x): 16 | return x * torch.sigmoid(x) 17 | 18 | 19 | class MILNet(nn.Module): 20 | def __init__(self, input_dim=512, attention_dim=128, attention_out_dim=1, if_freeze_param=False, dropout_p=0.5, 21 | instance_attention_layers=1, instance_attention_dim=128, feature_attention_layers=1, 22 | feature_represent_dim=512, 23 | feature_represent_layers=1, 24 | feature_attention_dim=256): 25 | """ 26 | 27 | Args: 28 | input_dim: number of instance feature 29 | attention_dim: 30 | attention_out_dim: number of features after attention 31 | if_freeze_param: 32 | """ 33 | super(MILNet, self).__init__() # super(subClass, instance).method(args), python 2.x 34 | self.L = input_dim 35 | self.D = attention_dim 36 | self.K = attention_out_dim 37 | self.if_freeze_param = if_freeze_param 38 | 39 | """ 40 | instance attention layers 41 | """ 42 | attention_list = [ 43 | nn.Linear(feature_represent_dim, instance_attention_dim), 44 | nn.LeakyReLU(), 45 | ] 46 | for _ in range(instance_attention_layers): 47 | attention_list.extend([ 48 | nn.Dropout(dropout_p), 49 | nn.Linear(instance_attention_dim, instance_attention_dim), 50 | nn.LeakyReLU(), 51 | 52 | ]) 53 | attention_list.extend([ 54 | nn.Linear(instance_attention_dim, self.K) 55 | ]) 56 | self.attention = nn.Sequential(*attention_list) 57 | 58 | """ 59 | feature represent layers 60 | """ 61 | feature_represent_layer_list = [ 62 | nn.Linear(self.L, feature_represent_dim), 63 | nn.LeakyReLU(), 64 | ] 65 | for _ in range(feature_represent_layers): 66 | feature_represent_layer_list.extend([ 67 | nn.Linear(feature_represent_dim, feature_represent_dim), 68 | nn.LeakyReLU(), 69 | nn.Dropout(dropout_p), 70 | ]) 71 | self.feature_represent = nn.Sequential(*feature_represent_layer_list) 72 | 73 | """ 74 | feature attention layers 75 | """ 76 | feature_attention_list = [ 77 | nn.Linear(feature_represent_dim, feature_attention_dim), 78 | Swish(), 79 | ] 80 | for _ in range(feature_attention_layers): 81 | feature_attention_list.extend([ 82 | nn.Linear(feature_attention_dim, feature_attention_dim), 83 | nn.LeakyReLU(), 84 | ]) 85 | feature_attention_list.extend([ 86 | nn.Linear(feature_attention_dim, feature_represent_dim) 87 | ]) 88 | self.feature_attention = nn.Sequential(*feature_attention_list) 89 | 90 | """ 91 | final classifier 92 | """ 93 | self.classifier = nn.Sequential( 94 | nn.Linear(feature_represent_dim * self.K, feature_represent_dim), 95 | nn.LeakyReLU(), 96 | nn.Dropout(dropout_p), 97 | nn.Linear(feature_represent_dim, 1) 98 | ) 99 | 100 | def get_instance_attention_parameters(self): 101 | params = [] 102 | for param in self.attention.parameters(): 103 | params.append(param) 104 | return params 105 | 106 | def get_feature_attention_parameters(self): 107 | params = [] 108 | for param in self.feature_attention.parameters(): 109 | params.append(param) 110 | return params 111 | 112 | def get_classifier_parameters(self): 113 | params = [] 114 | for param in self.classifier.parameters(): 115 | params.append(param) 116 | return params 117 | 118 | def forward(self, batch_data: torch.Tensor): 119 | # H: NxL embedding 120 | # if len(batch_data.size()) == 5: 121 | # 1 num_instance num_channel, w, h 122 | if len(batch_data.size()) == 3: 123 | # 1 #instance #feat 124 | batch_data = batch_data.squeeze(0) 125 | 126 | bag = batch_data 127 | 128 | bag = self.feature_represent(bag) 129 | # instance attention 130 | A = self.attention(bag) # NxK attentions 131 | A = torch.transpose(A, 1, 0) # KxN 132 | A = F.softmax(A, dim=1) # softmax over N 133 | 134 | # feature attention 135 | feature_attention = self.feature_attention(bag) 136 | feature_attention = torch.sigmoid(feature_attention) 137 | bag = bag * feature_attention 138 | M = torch.mm(A, bag) # KxL 139 | 140 | Y_prob = self.classifier(M) 141 | return Y_prob, A 142 | 143 | 144 | class MILNetWithTabFeat(nn.Module): 145 | def __init__(self, input_dim=512, attention_dim=128, attention_out_dim=1, if_freeze_param=False, dropout_p=0.5, 146 | instance_attention_layers=1, instance_attention_dim=128, feature_attention_layers=1, 147 | feature_represent_dim=512, 148 | feature_represent_layers=1, 149 | feature_attention_dim=256, 150 | tabfeat_dim=32): 151 | """ 152 | 153 | Args: 154 | input_dim: number of instance feature 155 | attention_dim: 156 | attention_out_dim: number of features after attention 157 | if_freeze_param: 158 | """ 159 | super(MILNetWithTabFeat, self).__init__() # super(subClass, instance).method(args), python 2.x 160 | self.L = input_dim 161 | self.D = attention_dim 162 | self.K = attention_out_dim 163 | self.if_freeze_param = if_freeze_param 164 | 165 | """ 166 | instance attention layers 167 | """ 168 | attention_list = [ 169 | nn.Linear(feature_represent_dim, instance_attention_dim), 170 | nn.LeakyReLU(), 171 | ] 172 | for _ in range(instance_attention_layers): 173 | attention_list.extend([ 174 | nn.Dropout(dropout_p), 175 | nn.Linear(instance_attention_dim, instance_attention_dim), 176 | nn.LeakyReLU(), 177 | ]) 178 | attention_list.extend([ 179 | nn.Linear(instance_attention_dim, self.K) 180 | ]) 181 | self.attention = nn.Sequential(*attention_list) 182 | 183 | """ 184 | feature represent layers 185 | """ 186 | feature_represent_layer_list = [ 187 | nn.Linear(self.L, feature_represent_dim), 188 | nn.LeakyReLU(), 189 | ] 190 | for _ in range(feature_represent_layers): 191 | feature_represent_layer_list.extend([ 192 | nn.Linear(feature_represent_dim, feature_represent_dim), 193 | nn.LeakyReLU(), 194 | nn.Dropout(dropout_p), 195 | ]) 196 | self.feature_represent = nn.Sequential(*feature_represent_layer_list) 197 | 198 | """ 199 | feature attention layers 200 | """ 201 | feature_attention_list = [ 202 | nn.Linear(feature_represent_dim, feature_attention_dim), 203 | Swish(), 204 | ] 205 | for _ in range(feature_attention_layers): 206 | feature_attention_list.extend([ 207 | nn.Linear(feature_attention_dim, feature_attention_dim), 208 | nn.LeakyReLU(), 209 | ]) 210 | feature_attention_list.extend([ 211 | nn.Linear(feature_attention_dim, feature_represent_dim) 212 | ]) 213 | self.feature_attention = nn.Sequential(*feature_attention_list) 214 | 215 | """ 216 | final classifier 217 | """ 218 | self.classifier = nn.Sequential( 219 | nn.Linear(feature_represent_dim * self.K + tabfeat_dim, feature_represent_dim), 220 | nn.LeakyReLU(), 221 | nn.Dropout(dropout_p), 222 | nn.Linear(feature_represent_dim, 1) 223 | ) 224 | 225 | def get_instance_attention_parameters(self): 226 | params = [] 227 | for param in self.attention.parameters(): 228 | params.append(param) 229 | return params 230 | 231 | def get_feature_attention_parameters(self): 232 | params = [] 233 | for param in self.feature_attention.parameters(): 234 | params.append(param) 235 | return params 236 | 237 | def get_classifier_parameters(self): 238 | params = [] 239 | for param in self.classifier.parameters(): 240 | params.append(param) 241 | return params 242 | 243 | def forward(self, batch_data: torch.Tensor, tab_feat: torch.Tensor): 244 | # H: NxL embedding 245 | # if len(batch_data.size()) == 5: 246 | # 1 num_instance num_channel, w, h 247 | if len(batch_data.size()) == 3: 248 | # 1 #instance #feat 249 | batch_data = batch_data.squeeze(0) 250 | 251 | bag = batch_data 252 | 253 | bag = self.feature_represent(bag) 254 | # instance attention 255 | A = self.attention(bag) # NxK attentions 256 | A = torch.transpose(A, 1, 0) # KxN 257 | A = F.softmax(A, dim=1) # softmax over N 258 | 259 | # feature attention 260 | feature_attention = self.feature_attention(bag) 261 | feature_attention = torch.sigmoid(feature_attention) 262 | bag = bag * feature_attention 263 | M = torch.mm(A, bag) # KxL 264 | 265 | M = M.view(1, -1) 266 | tab_feat = tab_feat.view(1, -1) 267 | 268 | merge_feat = torch.cat([M, tab_feat], dim=1) 269 | Y_prob = self.classifier(merge_feat) 270 | 271 | return Y_prob 272 | 273 | 274 | class MMTMBi(nn.Module): 275 | """ 276 | bi moludal fusion 277 | """ 278 | 279 | def __init__(self, dim_tab, dim_img, ratio=4): 280 | """ 281 | 282 | Parameters 283 | ---------- 284 | dim_tab: feature dimension of tabular data 285 | dim_img: feature dimension of MIL image modal 286 | ratio 287 | """ 288 | super(MMTMBi, self).__init__() 289 | dim = dim_tab + dim_img 290 | dim_out = int(2 * dim / ratio) 291 | self.fc_squeeze = nn.Linear(dim, dim_out) 292 | 293 | self.fc_tab = nn.Linear(dim_out, dim_tab) 294 | self.fc_img = nn.Linear(dim_out, dim_img) 295 | self.relu = nn.ReLU() 296 | self.sigmoid = nn.Sigmoid() 297 | 298 | def forward(self, tab_feat, img_feat) -> Sequence[torch.Tensor]: 299 | """ 300 | 301 | Parameters 302 | ---------- 303 | tab_feat: b * c 304 | skeleton: b * c 305 | 306 | Returns 307 | 表格数据加权结果 308 | WSI 全局特征加权结果 309 | WSI 全局特征加权权重 310 | ------- 311 | 312 | """ 313 | 314 | squeeze = torch.cat([tab_feat, img_feat], dim=1) 315 | 316 | excitation = self.fc_squeeze(squeeze) 317 | excitation = self.relu(excitation) 318 | 319 | tab_out = self.fc_tab(excitation) 320 | img_out = self.fc_img(excitation) 321 | 322 | tab_out = self.sigmoid(tab_out) 323 | img_out = self.sigmoid(img_out) 324 | 325 | return tab_feat * tab_out, img_feat * img_out, img_out 326 | 327 | class MMTMTri(nn.Module): 328 | """ 329 | tri-modal fusion 330 | """ 331 | 332 | def __init__(self, dim_img, ratio=4): 333 | """ 334 | 335 | Parameters 336 | ---------- 337 | dim_tab: feature dimension of tabular data 338 | dim_img: feature dimension of MIL model 339 | ratio 340 | """ 341 | super(MMTMTri, self).__init__() 342 | dim = dim_img * 3 343 | dim_out = int(2 * dim / ratio) 344 | self.fc_squeeze = nn.Linear(dim, dim_out) 345 | 346 | 347 | self.fc_img_scale1 = nn.Linear(dim_out, dim_img) 348 | self.fc_img_scale2 = nn.Linear(dim_out, dim_img) 349 | self.fc_img_scale3 = nn.Linear(dim_out, dim_img) 350 | 351 | self.relu = nn.ReLU() 352 | self.sigmoid = nn.Sigmoid() 353 | 354 | def forward(self, img_feat_scale1, img_feat_scale2, img_feat_scale3) -> Sequence[torch.Tensor]: 355 | """ 356 | 357 | Parameters 358 | ---------- 359 | tab_feat: b * c 360 | skeleton: b * c 361 | 362 | Returns 363 | ------- 364 | 365 | """ 366 | 367 | squeeze = torch.cat([img_feat_scale1, img_feat_scale2, img_feat_scale3], dim=1) 368 | 369 | excitation = self.fc_squeeze(squeeze) 370 | excitation = self.relu(excitation) 371 | 372 | 373 | img_out_scale1 = self.fc_img_scale1(excitation) 374 | img_out_scale2 = self.fc_img_scale2(excitation) 375 | img_out_scale3 = self.fc_img_scale3(excitation) 376 | 377 | img_out_scale1 = self.sigmoid(img_out_scale1) 378 | img_out_scale2 = self.sigmoid(img_out_scale2) 379 | img_out_scale3 = self.sigmoid(img_out_scale3) 380 | 381 | return img_feat_scale1 * img_out_scale1, img_out_scale1, img_feat_scale2 * img_out_scale2, img_out_scale2, img_feat_scale2 * img_out_scale3, img_out_scale3 382 | 383 | class MMTMQuad(nn.Module): 384 | """ 385 | quad modal fusion 386 | """ 387 | 388 | def __init__(self, dim_tab, dim_img, ratio=4): 389 | """ 390 | 391 | Parameters 392 | ---------- 393 | dim_tab: feature dimension of tabular data 394 | dim_img: feature dimension of MIL model 395 | ratio 396 | """ 397 | super(MMTMQuad, self).__init__() 398 | dim = dim_tab + dim_img * 3 399 | dim_out = int(2 * dim / ratio) 400 | self.fc_squeeze = nn.Linear(dim, dim_out) 401 | 402 | self.fc_tab = nn.Linear(dim_out, dim_tab) 403 | 404 | self.fc_img_scale1 = nn.Linear(dim_out, dim_img) 405 | self.fc_img_scale2 = nn.Linear(dim_out, dim_img) 406 | self.fc_img_scale3 = nn.Linear(dim_out, dim_img) 407 | 408 | self.relu = nn.ReLU() 409 | self.sigmoid = nn.Sigmoid() 410 | 411 | def forward(self, tab_feat, img_feat_scale1, img_feat_scale2, img_feat_scale3) -> Sequence[torch.Tensor]: 412 | """ 413 | 414 | Parameters 415 | ---------- 416 | tab_feat: b * c 417 | skeleton: b * c 418 | 419 | Returns 420 | ------- 421 | 422 | """ 423 | 424 | squeeze = torch.cat([tab_feat, img_feat_scale1, img_feat_scale2, img_feat_scale3], dim=1) 425 | 426 | excitation = self.fc_squeeze(squeeze) 427 | excitation = self.relu(excitation) 428 | 429 | tab_out = self.fc_tab(excitation) 430 | img_out_scale1 = self.fc_img_scale1(excitation) 431 | img_out_scale2 = self.fc_img_scale2(excitation) 432 | img_out_scale3 = self.fc_img_scale3(excitation) 433 | 434 | tab_out = self.sigmoid(tab_out) 435 | img_out_scale1 = self.sigmoid(img_out_scale1) 436 | img_out_scale2 = self.sigmoid(img_out_scale2) 437 | img_out_scale3 = self.sigmoid(img_out_scale3) 438 | 439 | return tab_feat * tab_out, img_feat_scale1 * img_out_scale1, img_out_scale1, img_feat_scale2 * img_out_scale2, img_out_scale2, img_feat_scale2 * img_out_scale3, img_out_scale3 440 | 441 | 442 | class InstanceAttentionGate(nn.Module): 443 | def __init__(self, feat_dim): 444 | super(InstanceAttentionGate, self).__init__() 445 | self.trans = nn.Sequential( 446 | nn.Linear(feat_dim * 2, feat_dim), 447 | nn.LeakyReLU(), 448 | nn.Linear(feat_dim, 1), 449 | ) 450 | 451 | def forward(self, instance_feature, global_feature): 452 | feat = torch.cat([instance_feature, global_feature], dim=1) 453 | attention = self.trans(feat) 454 | return attention 455 | 456 | 457 | class MILFusion(nn.Module): 458 | def __init__(self, img_feat_input_dim=512, tab_feat_input_dim=32, 459 | img_feat_rep_layers=4, 460 | num_modal=2, 461 | use_tabnet=False, 462 | tab_indim=0, 463 | local_rank=0, 464 | cat_idxs=None, 465 | cat_dims=None, 466 | lambda_sparse=1e-3, 467 | fusion='mmtm', 468 | use_k_agg=False, 469 | k_agg=10, 470 | ): 471 | super(MILFusion, self).__init__() 472 | self.num_modal = num_modal 473 | self.local_rank = local_rank 474 | self.use_tabnet = use_tabnet 475 | self.tab_indim = tab_indim 476 | self.lambda_sparse = lambda_sparse 477 | # define K mean agg 478 | self.use_k_agg = use_k_agg 479 | self.k_agg = k_agg 480 | 481 | self.fusion_method = fusion 482 | if self.use_tabnet: 483 | self.tabnet = TabNet(input_dim=tab_indim, output_dim=1, 484 | n_d=32, n_a=32, n_steps=5, 485 | gamma=1.5, n_independent=2, n_shared=2, 486 | momentum=0.3, 487 | cat_idxs=cat_idxs, cat_dims=cat_dims) 488 | else: 489 | self.tabnet = None 490 | 491 | if self.use_tabnet and num_modal == 1: 492 | self.only_tabnet = True 493 | else: 494 | self.only_tabnet = False 495 | 496 | """ 497 | Control tabnet 498 | """ 499 | if self.only_tabnet: 500 | self.feature_fine_tuning = None 501 | else: 502 | """pretrained feature fine tune""" 503 | feature_fine_tuning_layers = [] 504 | for _ in range(img_feat_rep_layers): 505 | feature_fine_tuning_layers.extend([ 506 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 507 | nn.LeakyReLU(), 508 | ]) 509 | self.feature_fine_tuning = nn.Sequential(*feature_fine_tuning_layers) 510 | 511 | # 3 means three scales of images as modals 512 | if self.num_modal == 4 or self.num_modal == 3: 513 | self.feature_fine_tuning2 = nn.Sequential(*feature_fine_tuning_layers) 514 | self.feature_fine_tuning3 = nn.Sequential(*feature_fine_tuning_layers) 515 | else: 516 | self.feature_fine_tuning2 = None 517 | self.feature_fine_tuning3 = None 518 | 519 | if self.only_tabnet or self.num_modal == 3: 520 | self.table_feature_ft = None 521 | else: 522 | """tab feature fine tuning""" 523 | self.table_feature_ft = nn.Sequential( 524 | nn.Linear(tab_feat_input_dim, tab_feat_input_dim) 525 | ) 526 | 527 | # k agg score 528 | self.score_fc = nn.ModuleList() 529 | if self.use_k_agg: 530 | for _ in range(self.num_modal - 1): 531 | self.score_fc.append( 532 | nn.Sequential( 533 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 534 | nn.LeakyReLU(), 535 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 536 | nn.LeakyReLU(), 537 | nn.Linear(img_feat_input_dim, 1), 538 | nn.Sigmoid() 539 | ) 540 | ) 541 | 542 | 543 | """modal fusion""" 544 | self.wsi_select_gate = None 545 | # define different fusion methods and related output feature dimension and fusion module 546 | if self.only_tabnet: 547 | self.mmtm = None 548 | elif self.fusion_method == 'concat': 549 | self.fusion_out_dim = tab_feat_input_dim + img_feat_input_dim 550 | self.wsi_select_gate = nn.Sequential( 551 | nn.Linear(img_feat_input_dim, 1), 552 | nn.Sigmoid() 553 | ) 554 | self.mmtm = nn.Linear(self.fusion_out_dim, self.fusion_out_dim) 555 | elif self.fusion_method == 'bilinear': 556 | self.wsi_select_gate = nn.Sequential( 557 | nn.Linear(img_feat_input_dim, 1), 558 | nn.Sigmoid() 559 | ) 560 | self.fusion_out_dim = tab_feat_input_dim + img_feat_input_dim 561 | self.mmtm = nn.Bilinear(tab_feat_input_dim, img_feat_input_dim, self.fusion_out_dim) 562 | elif self.fusion_method == 'add': 563 | self.wsi_select_gate = nn.Sequential( 564 | nn.Linear(img_feat_input_dim, 1), 565 | nn.Sigmoid() 566 | ) 567 | self.fusion_out_dim = tab_feat_input_dim 568 | self.mmtm = nn.Linear(img_feat_input_dim * (num_modal - 1), tab_feat_input_dim) 569 | elif self.fusion_method == 'gate': 570 | self.wsi_select_gate = nn.Sequential( 571 | nn.Linear(img_feat_input_dim, 1), 572 | nn.Sigmoid() 573 | ) 574 | self.fusion_out_dim = 96 575 | self.mmtm = BilinearFusion(dim1=tab_feat_input_dim, dim2=img_feat_input_dim, mmhid=self.fusion_out_dim) 576 | 577 | elif self.num_modal == 2 and self.fusion_method == 'mmtm': 578 | self.fusion_out_dim = (img_feat_input_dim * 2) * (num_modal - 1) + tab_feat_input_dim 579 | self.mmtm = MMTMBi(dim_tab=tab_feat_input_dim, dim_img=img_feat_input_dim) 580 | elif self.num_modal == 3 and self.fusion_method == 'mmtm': 581 | self.fusion_out_dim = (img_feat_input_dim * 2) * 3 582 | self.mmtm = MMTMTri(dim_img=img_feat_input_dim) 583 | elif self.num_modal == 4 and self.fusion_method == 'mmtm': 584 | self.fusion_out_dim = (img_feat_input_dim * 2) * (num_modal - 1) + tab_feat_input_dim 585 | self.mmtm = MMTMQuad(dim_tab=tab_feat_input_dim, dim_img=img_feat_input_dim) 586 | else: 587 | raise NotImplementedError(f'num_modal {num_modal} not implemented') 588 | 589 | """instance selection""" 590 | if self.only_tabnet or self.fusion_method in ['concat', 'add', 'bilinear', 'gate']: 591 | self.instance_gate1 = None 592 | else: 593 | self.instance_gate1 = InstanceAttentionGate(img_feat_input_dim) 594 | 595 | if (self.num_modal == 4 or self.num_modal == 3)and self.fusion_method == 'mmtm': 596 | self.instance_gate2 = InstanceAttentionGate(img_feat_input_dim) 597 | self.instance_gate3 = InstanceAttentionGate(img_feat_input_dim) 598 | else: 599 | self.instance_gate2 = None 600 | self.instance_gate3 = None 601 | 602 | """classifier layer""" 603 | if self.only_tabnet: 604 | self.classifier = None 605 | else: 606 | self.classifier = nn.Sequential( 607 | nn.Linear(self.fusion_out_dim, self.fusion_out_dim), 608 | nn.Dropout(0.5), 609 | nn.Linear(self.fusion_out_dim, 1) 610 | ) 611 | 612 | def agg_k_cluster_by_score(self, data: torch.Tensor, score_fc: nn.Module): 613 | num_elements = data.shape[0] 614 | score = score_fc(data) 615 | 616 | """ 617 | >>> score = torch.rand(4,1) 618 | >>> top_score, top_idx = torch.topk(score, k=num_elements, dim=0) 619 | >>> top_score, top_idx 620 | (tensor([[0.3963], 621 | [0.0856], 622 | [0.0704], 623 | [0.0247]]), 624 | tensor([[1], 625 | [0], 626 | [3], 627 | [2]])) 628 | """ 629 | top_score, top_idx = torch.topk(score, k=num_elements, dim=0) 630 | """ 631 | >>> data 632 | tensor([[0.0672, 0.9001, 0.5660, 0.0522, 0.1543], 633 | [0.1965, 0.7711, 0.9737, 0.5269, 0.9255], 634 | [0.6761, 0.5801, 0.4687, 0.1683, 0.8136], 635 | [0.2091, 0.9620, 0.8105, 0.8210, 0.3391]]) 636 | >>> top_idx[:, 0] 637 | tensor([1, 0, 3, 2]) 638 | >>> data_sorted 639 | tensor([[0.1965, 0.7711, 0.9737, 0.5269, 0.9255], 640 | [0.0672, 0.9001, 0.5660, 0.0522, 0.1543], 641 | [0.2091, 0.9620, 0.8105, 0.8210, 0.3391], 642 | [0.6761, 0.5801, 0.4687, 0.1683, 0.8136]]) 643 | """ 644 | data_sorted = torch.zeros_like(data) 645 | data_sorted.index_copy_(dim=0, index=top_idx[:, 0], source=data) 646 | 647 | # set Batch with feature dim 648 | data_sorted = torch.transpose(data_sorted, 1, 0) 649 | data_sorted = data_sorted.unsqueeze(1) 650 | 651 | 652 | agg_result = nn.functional.adaptive_max_pool1d(data_sorted, self.k_agg) 653 | 654 | agg_result = agg_result.squeeze(1) 655 | agg_result = torch.transpose(agg_result, 1, 0) 656 | 657 | return agg_result 658 | 659 | def forward(self, data): 660 | 661 | attention_weight_out_list = [] 662 | if self.use_tabnet: 663 | if torch.cuda.is_available(): 664 | tab_data = data['tab_data'].cuda(self.local_rank) 665 | else: 666 | tab_data = data['tab_data'] 667 | if self.only_tabnet: 668 | tab_logit, M_loss = self.tabnet(tab_data) 669 | else: 670 | tab_logit, tab_feat, M_loss = self.tabnet(tab_data) 671 | 672 | tab_loss_weight = 1. 673 | else: 674 | tab_feat = data['tab_feat'].cuda(self.local_rank) 675 | tab_logit = torch.zeros((1, 1)).cuda(self.local_rank) 676 | M_loss = 0. 677 | tab_loss_weight = 0. 678 | 679 | if torch.cuda.is_available(): 680 | y = data['label'].cuda(self.local_rank) 681 | wsi_feat_scale1 = data['wsi_feat_scale1'].cuda(self.local_rank) 682 | else: 683 | y = data['label'] 684 | wsi_feat_scale1 = data['wsi_feat_scale1'] 685 | if len(wsi_feat_scale1.size()) == 3: 686 | # 1 #instance #feat 687 | wsi_feat_scale1 = wsi_feat_scale1.squeeze(0) 688 | scale1_bs = wsi_feat_scale1.shape[0] 689 | 690 | if self.only_tabnet: 691 | out = tab_logit 692 | tab_loss_weight = 0. 693 | elif self.fusion_method in ['concat', 'add', 'bilinear', 'gate']: 694 | tab_feat = self.table_feature_ft(tab_feat) 695 | # fusion first 696 | feat_list = [] 697 | for scale, ft_fc in zip(range(3), [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3]): 698 | if ft_fc is None: 699 | break 700 | wsi_feat = data[f'wsi_feat_scale{scale+1}'].cuda(self.local_rank) 701 | wsi_feat = wsi_feat.squeeze(0) 702 | wsi_feat = ft_fc(wsi_feat) 703 | 704 | feat_list.append(wsi_feat) 705 | 706 | wsi_feat_concat = torch.cat(feat_list, dim=0) 707 | 708 | global_wsi_feat_weight = self.wsi_select_gate(wsi_feat_concat) 709 | global_wsi_feat = torch.sum(wsi_feat_concat * global_wsi_feat_weight, dim=0, keepdim=True) 710 | 711 | if self.fusion_method == 'concat': 712 | fusion_feat = self.mmtm(torch.cat([tab_feat, global_wsi_feat], dim=1)) 713 | elif self.fusion_method == 'add': 714 | fusion_feat = self.mmtm(global_wsi_feat) + tab_feat 715 | elif self.fusion_method == 'bilinear': 716 | fusion_feat = self.mmtm(tab_feat, global_wsi_feat) 717 | elif self.fusion_method == 'gate': 718 | fusion_feat = self.mmtm(tab_feat, global_wsi_feat) 719 | 720 | out = self.classifier(fusion_feat) 721 | 722 | 723 | elif self.num_modal == 2: 724 | wsi_feat_scale1 = self.feature_fine_tuning(wsi_feat_scale1) 725 | wsi_feat_scale1_gloabl = torch.mean(wsi_feat_scale1, dim=0, keepdim=True) # instance level mean 726 | 727 | tab_feat_mmtm, wsi_feat1_gloabl, wsi_feat_scale1_gate = self.mmtm(tab_feat, wsi_feat_scale1_gloabl) 728 | 729 | # table feature calculate once more 730 | tab_feat_ft = self.table_feature_ft(tab_feat_mmtm) 731 | 732 | # weight on feature level 733 | wsi_feat_scale1 = wsi_feat_scale1 * wsi_feat_scale1_gate 734 | 735 | wsi_feat1_gloabl_repeat = wsi_feat1_gloabl.detach().repeat(scale1_bs, 1) 736 | 737 | # N * 1 738 | instance_attention_weight = self.instance_gate1(wsi_feat_scale1, wsi_feat1_gloabl_repeat) 739 | # 1 * N 740 | instance_attention_weight = torch.transpose(instance_attention_weight, 1, 0) 741 | 742 | instance_attention_weight = torch.softmax(instance_attention_weight, dim=1) 743 | 744 | attention_weight_out_list.append(instance_attention_weight.detach().clone()) 745 | 746 | # 1 * N 747 | wsi_feat_agg_scale1 = torch.mm(instance_attention_weight, wsi_feat_scale1) 748 | 749 | final_feat = torch.cat([tab_feat_ft, wsi_feat_agg_scale1, wsi_feat1_gloabl], dim=1) 750 | 751 | out = self.classifier(final_feat) 752 | 753 | elif self.num_modal == 3: 754 | """ 755 | Fuse 3 modalities 756 | """ 757 | wsi_feat_scale2 = data['wsi_feat_scale2'].cuda(self.local_rank) 758 | wsi_feat_scale3 = data['wsi_feat_scale3'].cuda(self.local_rank) 759 | if len(wsi_feat_scale2.size()) == 3: 760 | # 1 #instance #feat 761 | wsi_feat_scale2 = wsi_feat_scale2.squeeze(0) 762 | if len(wsi_feat_scale3.size()) == 3: 763 | # 1 #instance #feat 764 | wsi_feat_scale3 = wsi_feat_scale3.squeeze(0) 765 | 766 | """fine-tuning on 3 scales""" 767 | wsi_ft_feat_list = [] 768 | for ft_conv, wsi_feat in zip( 769 | [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3], 770 | [wsi_feat_scale1, wsi_feat_scale2, wsi_feat_scale3], 771 | ): 772 | wsi_ft_feat_list.append(ft_conv(wsi_feat)) 773 | 774 | if self.use_k_agg: 775 | agg_feat_list = [] 776 | for data_feat, score_fc in zip(wsi_ft_feat_list, self.score_fc): 777 | agg_feat_list.append(self.agg_k_cluster_by_score(data_feat, score_fc)) 778 | wsi_ft_feat_list = agg_feat_list 779 | 780 | wsi_feat_scale_gloabl_list = [] 781 | for data_feat, score_fc in zip(agg_feat_list, self.score_fc): 782 | feat_score = score_fc(data_feat) 783 | feat_attention = torch.softmax(feat_score, dim=0) 784 | global_feat = torch.sum(data_feat * feat_attention, dim=0, keepdim=True) 785 | wsi_feat_scale_gloabl_list.append(global_feat) 786 | 787 | else: 788 | """global representation of 3 scales features""" 789 | wsi_feat_scale_gloabl_list = [] 790 | for feat in wsi_ft_feat_list: 791 | wsi_feat_scale_gloabl_list.append(torch.mean(feat, dim=0, keepdim=True)) 792 | 793 | """mmtm""" 794 | wsi_feat1_gloabl, wsi_feat_scale1_gate, wsi_feat2_gloabl, wsi_feat_scale2_gate, wsi_feat3_gloabl, wsi_feat_scale3_gate = self.mmtm( 795 | *wsi_feat_scale_gloabl_list) 796 | 797 | """instance selection on 3 scales""" 798 | wsi_feat_agg_list = [] 799 | for wsi_feat_at_scale, wsi_feat_gate_at_scale, wsi_global_rep, instance_gate in zip( 800 | wsi_ft_feat_list, 801 | [wsi_feat_scale1_gate, wsi_feat_scale2_gate, wsi_feat_scale3_gate], 802 | wsi_feat_scale_gloabl_list, 803 | [self.instance_gate1, self.instance_gate2, self.instance_gate3] 804 | ): 805 | # 806 | bs_at_scale = wsi_feat_at_scale.shape[0] 807 | wsi_feat_at_scale = wsi_feat_at_scale * wsi_feat_gate_at_scale 808 | wsi_global_rep_repeat = wsi_feat_gate_at_scale.detach().repeat(bs_at_scale, 1) 809 | 810 | # N * 1 811 | instance_attention_weight = instance_gate(wsi_feat_at_scale, wsi_global_rep_repeat) 812 | # 1 * N 813 | instance_attention_weight = torch.transpose(instance_attention_weight, 1, 0) 814 | instance_attention_weight = torch.softmax(instance_attention_weight, dim=1) 815 | 816 | # instance aggregate 817 | wsi_feat_agg = torch.mm(instance_attention_weight, wsi_feat_at_scale) 818 | wsi_feat_agg_list.append(wsi_feat_agg) 819 | 820 | 821 | final_feat = torch.cat( 822 | [*wsi_feat_agg_list, wsi_feat1_gloabl, wsi_feat2_gloabl, wsi_feat3_gloabl], dim=1) 823 | 824 | out = self.classifier(final_feat) 825 | 826 | elif self.num_modal == 4: 827 | """ 828 | Fuse 4 modalities 829 | """ 830 | if torch.cuda.is_available(): 831 | wsi_feat_scale2 = data['wsi_feat_scale2'].cuda(self.local_rank) 832 | wsi_feat_scale3 = data['wsi_feat_scale3'].cuda(self.local_rank) 833 | else: 834 | wsi_feat_scale2 = data['wsi_feat_scale2'] 835 | wsi_feat_scale3 = data['wsi_feat_scale3'] 836 | if len(wsi_feat_scale2.size()) == 3: 837 | # 1 #instance #feat 838 | wsi_feat_scale2 = wsi_feat_scale2.squeeze(0) 839 | if len(wsi_feat_scale3.size()) == 3: 840 | # 1 #instance #feat 841 | wsi_feat_scale3 = wsi_feat_scale3.squeeze(0) 842 | 843 | if self.use_k_agg: 844 | if wsi_feat_scale1.shape[0] < self.k_agg: 845 | pad_size = self.k_agg - wsi_feat_scale1.shape[0] 846 | zero_size = (pad_size, *wsi_feat_scale1.shape[1:]) 847 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale1.device) 848 | wsi_feat_scale1 = torch.cat([wsi_feat_scale1, pad_tensor]) 849 | if wsi_feat_scale2.shape[0] < self.k_agg: 850 | pad_size = self.k_agg - wsi_feat_scale2.shape[0] 851 | zero_size = (pad_size, *wsi_feat_scale2.shape[1:]) 852 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale2.device) 853 | wsi_feat_scale2 = torch.cat([wsi_feat_scale2, pad_tensor]) 854 | if wsi_feat_scale3.shape[0] < self.k_agg: 855 | pad_size = self.k_agg - wsi_feat_scale3.shape[0] 856 | zero_size = (pad_size, *wsi_feat_scale3.shape[1:]) 857 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale3.device) 858 | wsi_feat_scale3 = torch.cat([wsi_feat_scale3, pad_tensor]) 859 | 860 | """fine-tuning 3 scales features""" 861 | wsi_ft_feat_list = [] 862 | for ft_conv, wsi_feat in zip( 863 | [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3], 864 | [wsi_feat_scale1, wsi_feat_scale2, wsi_feat_scale3], 865 | ): 866 | wsi_ft_feat_list.append(ft_conv(wsi_feat)) 867 | 868 | if self.use_k_agg: 869 | agg_feat_list = [] 870 | for data_feat, score_fc in zip(wsi_ft_feat_list, self.score_fc): 871 | agg_feat_list.append(self.agg_k_cluster_by_score(data_feat, score_fc)) 872 | wsi_ft_feat_list = agg_feat_list 873 | 874 | wsi_feat_scale_gloabl_list = [] 875 | for data_feat, score_fc in zip(agg_feat_list, self.score_fc): 876 | feat_score = score_fc(data_feat) 877 | feat_attention = torch.sigmoid(feat_score) 878 | 879 | attention_weight_out_list.append(feat_attention.detach().clone()) 880 | global_feat = torch.sum(data_feat * feat_attention, dim=0, keepdim=True) 881 | wsi_feat_scale_gloabl_list.append(global_feat) 882 | else: 883 | """global representation of 3 scales features""" 884 | wsi_feat_scale_gloabl_list = [] 885 | for feat in wsi_ft_feat_list: 886 | wsi_feat_scale_gloabl_list.append(torch.mean(feat, dim=0, keepdim=True)) 887 | 888 | 889 | """mmtm""" 890 | tab_feat_mmtm, wsi_feat1_gloabl, wsi_feat_scale1_gate, wsi_feat2_gloabl, wsi_feat_scale2_gate, wsi_feat3_gloabl, wsi_feat_scale3_gate = self.mmtm(tab_feat, *wsi_feat_scale_gloabl_list) 891 | 892 | """instance selection of 3 scales""" 893 | wsi_feat_agg_list = [] 894 | for wsi_feat_at_scale, wsi_feat_gate_at_scale, wsi_global_rep, instance_gate in zip( 895 | wsi_ft_feat_list, 896 | [wsi_feat_scale1_gate, wsi_feat_scale2_gate, wsi_feat_scale3_gate], 897 | wsi_feat_scale_gloabl_list, 898 | [self.instance_gate1, self.instance_gate2, self.instance_gate3] 899 | ): 900 | # 901 | bs_at_scale = wsi_feat_at_scale.shape[0] 902 | wsi_feat_at_scale = wsi_feat_at_scale * wsi_feat_gate_at_scale 903 | wsi_global_rep_repeat = wsi_feat_gate_at_scale.detach().repeat(bs_at_scale, 1) 904 | 905 | # N * 1 906 | instance_attention_weight = instance_gate(wsi_feat_at_scale, wsi_global_rep_repeat) 907 | # 1 * N 908 | instance_attention_weight = torch.transpose(instance_attention_weight, 1, 0) 909 | 910 | instance_attention_weight = torch.softmax(instance_attention_weight, dim=1) 911 | 912 | 913 | # instance aggregate 914 | wsi_feat_agg = torch.mm(instance_attention_weight, wsi_feat_at_scale) 915 | 916 | attention_weight_out_list.append(instance_attention_weight.detach().clone()) 917 | wsi_feat_agg_list.append(wsi_feat_agg) 918 | 919 | """tab feat ft""" 920 | tab_feat_ft = self.table_feature_ft(tab_feat_mmtm) 921 | 922 | final_feat = torch.cat([tab_feat_ft, *wsi_feat_agg_list, wsi_feat1_gloabl, wsi_feat2_gloabl, wsi_feat3_gloabl], dim=1) 923 | 924 | out = self.classifier(final_feat) 925 | 926 | 927 | pass 928 | y = y.view(-1, 1).float() 929 | loss = F.binary_cross_entropy_with_logits(out, y) + \ 930 | tab_loss_weight * F.binary_cross_entropy_with_logits(tab_logit, y) - \ 931 | self.lambda_sparse * M_loss 932 | 933 | return out, loss, attention_weight_out_list 934 | 935 | def get_params(self, base_lr): 936 | ret = [] 937 | 938 | if self.tabnet is not None: 939 | tabnet_params = [] 940 | for param in self.tabnet.parameters(): 941 | tabnet_params.append(param) 942 | ret.append({ 943 | 'params': tabnet_params, 944 | 'lr': base_lr 945 | }) 946 | 947 | cls_learning_rate_rate=100 948 | if self.classifier is not None: 949 | classifier_params = [] 950 | for param in self.classifier.parameters(): 951 | classifier_params.append(param) 952 | ret.append({ 953 | 'params': classifier_params, 954 | 'lr': base_lr / cls_learning_rate_rate, 955 | }) 956 | 957 | 958 | tab_learning_rate_rate = 100 959 | if self.table_feature_ft is not None: 960 | misc_params = [] 961 | for param in self.table_feature_ft.parameters(): 962 | misc_params.append(param) 963 | ret.append({ 964 | 'params': misc_params, 965 | 'lr': base_lr / tab_learning_rate_rate, 966 | }) 967 | 968 | mil_learning_rate_rate = 1000 969 | misc_params = [] 970 | for part in [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3, 971 | self.instance_gate1, self.instance_gate2, self.instance_gate3, 972 | self.wsi_select_gate, 973 | self.score_fc]: 974 | if part is not None: 975 | for param in part.parameters(): 976 | misc_params.append(param) 977 | ret.append({ 978 | 'params': misc_params, 979 | 'lr': base_lr / mil_learning_rate_rate, 980 | }) 981 | 982 | misc_learning_rate_rate = 100 983 | misc_params = [] 984 | for part in [self.mmtm, ]: 985 | if part is not None: 986 | for param in part.parameters(): 987 | misc_params.append(param) 988 | ret.append({ 989 | 'params': misc_params, 990 | 'lr': base_lr / misc_learning_rate_rate, 991 | }) 992 | 993 | return ret 994 | 995 | 996 | 997 | """ 998 | M3D part models 999 | 1000 | """ 1001 | # textnet inner concat number 1002 | neure_num = [23, 32, 32, 32, 32, 32, 16, 1] 1003 | 1004 | class TextNet(nn.Module): 1005 | def __init__(self, neure_num): 1006 | super(TextNet, self).__init__() 1007 | self.encoder = make_layers(neure_num[:3]) 1008 | self.feature = make_layers(neure_num[2:-1]) 1009 | self.fc = nn.Linear(neure_num[-2], neure_num[-1]) 1010 | self.sig = nn.Sigmoid() 1011 | self._initialize_weights() 1012 | 1013 | def forward(self, x): 1014 | encoder = self.encoder(x) 1015 | fea = self.feature(encoder) 1016 | y = self.fc(fea) 1017 | y = self.sig(y) 1018 | return y 1019 | 1020 | def _initialize_weights(self): 1021 | for m in self.modules(): 1022 | if isinstance(m, nn.Linear): 1023 | m.weight.data.normal_(0, 0.1) 1024 | m.bias.data.zero_() 1025 | 1026 | 1027 | def make_layers(cfg): 1028 | layers = [] 1029 | n = len(cfg) 1030 | input_dim = cfg[0] 1031 | for i in range(1, n): 1032 | output_dim = cfg[i] 1033 | if i < n - 1: 1034 | layers += [nn.Linear(input_dim, output_dim), nn.BatchNorm1d(output_dim), nn.ReLU(inplace=True)] 1035 | else: 1036 | layers += [nn.Linear(input_dim, output_dim), nn.ReLU(inplace=True)] 1037 | input_dim = output_dim 1038 | return nn.Sequential(*layers) 1039 | 1040 | 1041 | 1042 | class M3D(nn.Module): 1043 | def __init__(self, img_feat_input_dim=512, tab_feat_input_dim=32, 1044 | img_feat_rep_layers=4, 1045 | num_modal=2, 1046 | use_tabnet=False, 1047 | tab_indim=0, 1048 | local_rank=0, 1049 | cat_idxs=None, 1050 | cat_dims=None, 1051 | lambda_sparse=1e-3, 1052 | fusion='mmtm', 1053 | use_k_agg=False, 1054 | k_agg=10, 1055 | ): 1056 | super(M3D, self).__init__() 1057 | self.num_modal = num_modal 1058 | self.local_rank = local_rank 1059 | self.use_tabnet = use_tabnet 1060 | self.tab_indim = tab_indim 1061 | self.lambda_sparse = lambda_sparse 1062 | # define mean agg 1063 | self.use_k_agg = use_k_agg 1064 | self.k_agg = k_agg 1065 | 1066 | self.fusion_method = fusion 1067 | 1068 | feature_fine_tuning_layers = [] 1069 | for _ in range(img_feat_rep_layers): 1070 | feature_fine_tuning_layers.extend([ 1071 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1072 | nn.LeakyReLU(), 1073 | ]) 1074 | 1075 | self.feature_fine_tuning = nn.Sequential(*deepcopy(feature_fine_tuning_layers)) 1076 | self.feature_fine_tuning2 = nn.Sequential(*deepcopy(feature_fine_tuning_layers)) 1077 | self.feature_fine_tuning3 = nn.Sequential(*deepcopy(feature_fine_tuning_layers)) 1078 | 1079 | self.fc1 = nn.Sequential( 1080 | nn.Linear(img_feat_input_dim, 1), 1081 | nn.Sigmoid(), 1082 | ) 1083 | self.fc2 = nn.Sequential( 1084 | nn.Linear(img_feat_input_dim, 1), 1085 | nn.Sigmoid(), 1086 | ) 1087 | self.fc3 = nn.Sequential( 1088 | nn.Linear(img_feat_input_dim, 1), 1089 | nn.Sigmoid(), 1090 | ) 1091 | 1092 | self.text_net = TextNet(neure_num) 1093 | 1094 | def forward(self, data): 1095 | tab_data = data['tab_data'].cuda(self.local_rank) 1096 | wsi_feat_scale1 = data['wsi_feat_scale1'].cuda(self.local_rank) 1097 | wsi_feat_scale2 = data['wsi_feat_scale2'].cuda(self.local_rank) 1098 | wsi_feat_scale3 = data['wsi_feat_scale3'].cuda(self.local_rank) 1099 | y = data['label'].cuda(self.local_rank) 1100 | y = y.view(-1, 1).float() 1101 | 1102 | wsi_feat1 = self.feature_fine_tuning(wsi_feat_scale1) 1103 | wsi_feat2 = self.feature_fine_tuning2(wsi_feat_scale2) 1104 | wsi_feat3 = self.feature_fine_tuning3(wsi_feat_scale3) 1105 | 1106 | wsi_instance_predict1 = self.fc1(wsi_feat1) 1107 | wsi_instance_predict2 = self.fc2(wsi_feat2) 1108 | wsi_instance_predict3 = self.fc3(wsi_feat3) 1109 | 1110 | text_predict = self.text_net(tab_data) 1111 | 1112 | bag_predict = torch.max(wsi_instance_predict1) + torch.max(wsi_instance_predict2) \ 1113 | + torch.max(wsi_instance_predict3) + torch.max(text_predict) 1114 | 1115 | for_debug_predict = (torch.mean(wsi_instance_predict1) + torch.mean(wsi_instance_predict2) + torch.mean(wsi_instance_predict3)) / 3. 1116 | 1117 | bag_predict = bag_predict / 4. 1118 | bag_predict = bag_predict.view(-1, 1) 1119 | debug_loss = F.binary_cross_entropy(for_debug_predict.view(-1, 1), y) 1120 | loss = F.binary_cross_entropy(bag_predict, y) + 1e-10 * debug_loss 1121 | return bag_predict, loss 1122 | 1123 | def get_params(self, base_lr): 1124 | ret = [] 1125 | 1126 | 1127 | misc_params = [] 1128 | for part in [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3, 1129 | self.fc1, self.fc2, self.fc3]: 1130 | if part is not None: 1131 | for param in part.parameters(): 1132 | misc_params.append(param) 1133 | ret.append({ 1134 | 'params': misc_params, 1135 | 'lr': base_lr / 100, 1136 | }) 1137 | 1138 | misc_params = [] 1139 | for part in [self.text_net, ]: 1140 | if part is not None: 1141 | for param in part.parameters(): 1142 | misc_params.append(param) 1143 | ret.append({ 1144 | 'params': misc_params, 1145 | 'lr': base_lr / 100, 1146 | }) 1147 | 1148 | return ret 1149 | 1150 | 1151 | 1152 | 1153 | 1154 | class MILFusionAppend(nn.Module): 1155 | def __init__(self, img_feat_input_dim=512, tab_feat_input_dim=32, 1156 | img_feat_rep_layers=4, 1157 | num_modal=2, 1158 | use_tabnet=False, 1159 | tab_indim=0, 1160 | local_rank=0, 1161 | cat_idxs=None, 1162 | cat_dims=None, 1163 | lambda_sparse=1e-3, 1164 | fusion='mmtm', 1165 | use_k_agg=False, 1166 | k_agg=10, 1167 | ): 1168 | super(MILFusionAppend, self).__init__() 1169 | self.num_modal = num_modal 1170 | self.local_rank = local_rank 1171 | self.use_tabnet = use_tabnet 1172 | self.tab_indim = tab_indim 1173 | self.lambda_sparse = lambda_sparse 1174 | # define K mean agg 1175 | self.use_k_agg = use_k_agg 1176 | self.k_agg = k_agg 1177 | 1178 | self.fusion_method = fusion 1179 | if self.use_tabnet: 1180 | self.tabnet = TabNet(input_dim=tab_indim, output_dim=1, 1181 | n_d=32, n_a=32, n_steps=5, 1182 | gamma=1.5, n_independent=2, n_shared=2, 1183 | momentum=0.3, 1184 | cat_idxs=cat_idxs, cat_dims=cat_dims) 1185 | else: 1186 | self.tabnet = None 1187 | 1188 | if self.use_tabnet and num_modal == 1: 1189 | self.only_tabnet = True 1190 | else: 1191 | self.only_tabnet = False 1192 | 1193 | """ 1194 | Control tabnet 1195 | """ 1196 | if self.only_tabnet: 1197 | self.feature_fine_tuning = None 1198 | else: 1199 | """pretrained feature fine tune""" 1200 | feature_fine_tuning_layers = [] 1201 | for _ in range(img_feat_rep_layers): 1202 | feature_fine_tuning_layers.extend([ 1203 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1204 | nn.LeakyReLU(), 1205 | ]) 1206 | self.feature_fine_tuning = nn.Sequential(*feature_fine_tuning_layers) 1207 | 1208 | # 3 is the 3 scales of image 1209 | if self.num_modal == 4 or self.num_modal == 3: 1210 | self.feature_fine_tuning2 = nn.Sequential(*feature_fine_tuning_layers) 1211 | self.feature_fine_tuning3 = nn.Sequential(*feature_fine_tuning_layers) 1212 | else: 1213 | self.feature_fine_tuning2 = None 1214 | self.feature_fine_tuning3 = None 1215 | 1216 | if self.only_tabnet or self.num_modal == 3: 1217 | self.table_feature_ft = None 1218 | else: 1219 | """tab feature fine tuning""" 1220 | self.table_feature_ft = nn.Sequential( 1221 | nn.Linear(tab_feat_input_dim, tab_feat_input_dim) 1222 | ) 1223 | 1224 | # k agg score 1225 | self.score_fc = nn.ModuleList() 1226 | if self.use_k_agg: 1227 | for _ in range(self.num_modal - 1): 1228 | self.score_fc.append( 1229 | nn.Sequential( 1230 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1231 | nn.LeakyReLU(), 1232 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1233 | nn.LeakyReLU(), 1234 | nn.Linear(img_feat_input_dim, 1), 1235 | nn.Sigmoid() 1236 | ) 1237 | ) 1238 | 1239 | 1240 | """modal fusion""" 1241 | self.wsi_select_gate = None 1242 | 1243 | if self.only_tabnet: 1244 | self.mmtm = None 1245 | elif self.fusion_method == 'concat': 1246 | self.fusion_out_dim = tab_feat_input_dim + img_feat_input_dim 1247 | self.wsi_select_gate = nn.Sequential( 1248 | nn.Linear(img_feat_input_dim, 1), 1249 | nn.Sigmoid() 1250 | ) 1251 | self.mmtm = nn.Linear(self.fusion_out_dim, self.fusion_out_dim) 1252 | elif self.fusion_method == 'bilinear': 1253 | self.wsi_select_gate = nn.Sequential( 1254 | nn.Linear(img_feat_input_dim, 1), 1255 | nn.Sigmoid() 1256 | ) 1257 | self.fusion_out_dim = tab_feat_input_dim + img_feat_input_dim 1258 | self.mmtm = nn.Bilinear(tab_feat_input_dim, img_feat_input_dim, self.fusion_out_dim) 1259 | elif self.fusion_method == 'add': 1260 | self.wsi_select_gate = nn.Sequential( 1261 | nn.Linear(img_feat_input_dim, 1), 1262 | nn.Sigmoid() 1263 | ) 1264 | self.fusion_out_dim = tab_feat_input_dim 1265 | self.mmtm = nn.Linear(img_feat_input_dim * (num_modal - 1), tab_feat_input_dim) 1266 | elif self.fusion_method == 'gate': 1267 | self.wsi_select_gate = nn.Sequential( 1268 | nn.Linear(img_feat_input_dim, 1), 1269 | nn.Sigmoid() 1270 | ) 1271 | self.fusion_out_dim = 96 1272 | self.mmtm = BilinearFusion(dim1=tab_feat_input_dim, dim2=img_feat_input_dim, mmhid=self.fusion_out_dim) 1273 | elif self.num_modal == 2 and self.fusion_method == 'mmtm': 1274 | self.fusion_out_dim = (img_feat_input_dim * 2) * (num_modal - 1) + tab_feat_input_dim 1275 | self.mmtm = MMTMBi(dim_tab=tab_feat_input_dim, dim_img=img_feat_input_dim) 1276 | elif self.num_modal == 3 and self.fusion_method == 'mmtm': 1277 | self.fusion_out_dim = (img_feat_input_dim * 2) * 3 1278 | self.mmtm = MMTMTri(dim_img=img_feat_input_dim) 1279 | elif self.num_modal == 4 and self.fusion_method == 'mmtm': 1280 | self.fusion_out_dim = (img_feat_input_dim * 2) * (num_modal - 1) + tab_feat_input_dim 1281 | self.mmtm = MMTMQuad(dim_tab=tab_feat_input_dim, dim_img=img_feat_input_dim) 1282 | else: 1283 | raise NotImplementedError(f'num_modal {num_modal} not implemented') 1284 | 1285 | """instance selection""" 1286 | if self.only_tabnet or self.fusion_method in ['concat', 'add', 'bilinear', 'gate']: 1287 | self.instance_gate1 = None 1288 | else: 1289 | self.instance_gate1 = InstanceAttentionGate(img_feat_input_dim) 1290 | 1291 | if (self.num_modal == 4 or self.num_modal == 3)and self.fusion_method == 'mmtm': 1292 | self.instance_gate2 = InstanceAttentionGate(img_feat_input_dim) 1293 | self.instance_gate3 = InstanceAttentionGate(img_feat_input_dim) 1294 | else: 1295 | self.instance_gate2 = None 1296 | self.instance_gate3 = None 1297 | 1298 | """classifier layer""" 1299 | if self.only_tabnet: 1300 | self.classifier = None 1301 | else: 1302 | self.classifier = nn.Sequential( 1303 | nn.Linear(self.fusion_out_dim, self.fusion_out_dim), 1304 | nn.Dropout(0.5), 1305 | nn.Linear(self.fusion_out_dim, 1) 1306 | ) 1307 | 1308 | def agg_k_cluster_by_score(self, data: torch.Tensor, score_fc: nn.Module): 1309 | num_elements = data.shape[0] 1310 | score = score_fc(data) 1311 | 1312 | """ 1313 | >>> score = torch.rand(4,1) 1314 | >>> top_score, top_idx = torch.topk(score, k=num_elements, dim=0) 1315 | >>> top_score, top_idx 1316 | (tensor([[0.3963], 1317 | [0.0856], 1318 | [0.0704], 1319 | [0.0247]]), 1320 | tensor([[1], 1321 | [0], 1322 | [3], 1323 | [2]])) 1324 | """ 1325 | top_score, top_idx = torch.topk(score, k=num_elements, dim=0) 1326 | """ 1327 | >>> data 1328 | tensor([[0.0672, 0.9001, 0.5660, 0.0522, 0.1543], 1329 | [0.1965, 0.7711, 0.9737, 0.5269, 0.9255], 1330 | [0.6761, 0.5801, 0.4687, 0.1683, 0.8136], 1331 | [0.2091, 0.9620, 0.8105, 0.8210, 0.3391]]) 1332 | >>> top_idx[:, 0] 1333 | tensor([1, 0, 3, 2]) 1334 | >>> data_sorted 1335 | tensor([[0.1965, 0.7711, 0.9737, 0.5269, 0.9255], 1336 | [0.0672, 0.9001, 0.5660, 0.0522, 0.1543], 1337 | [0.2091, 0.9620, 0.8105, 0.8210, 0.3391], 1338 | [0.6761, 0.5801, 0.4687, 0.1683, 0.8136]]) 1339 | """ 1340 | data_sorted = torch.zeros_like(data) 1341 | data_sorted.index_copy_(dim=0, index=top_idx[:, 0], source=data) 1342 | 1343 | # Batch set as feature dim 1344 | data_sorted = torch.transpose(data_sorted, 1, 0) 1345 | data_sorted = data_sorted.unsqueeze(1) 1346 | 1347 | agg_result = nn.functional.adaptive_max_pool1d(data_sorted, self.k_agg) 1348 | 1349 | agg_result = agg_result.squeeze(1) 1350 | agg_result = torch.transpose(agg_result, 1, 0) 1351 | return agg_result 1352 | 1353 | def forward(self, data): 1354 | attention_weight_out_list = [] 1355 | if self.use_tabnet: 1356 | tab_data = data['tab_data'].cuda(self.local_rank) 1357 | if self.only_tabnet: 1358 | tab_logit, M_loss = self.tabnet(tab_data) 1359 | else: 1360 | tab_logit, tab_feat, M_loss = self.tabnet(tab_data) 1361 | 1362 | tab_loss_weight = 1. 1363 | else: 1364 | tab_feat = data['tab_feat'].cuda(self.local_rank) 1365 | tab_logit = torch.zeros((1, 1)).cuda(self.local_rank) 1366 | M_loss = 0. 1367 | tab_loss_weight = 0. 1368 | 1369 | y = data['label'].cuda(self.local_rank) 1370 | wsi_feat_scale1 = data['wsi_feat_scale1'].cuda(self.local_rank) 1371 | if len(wsi_feat_scale1.size()) == 3: 1372 | # 1 #instance #feat 1373 | wsi_feat_scale1 = wsi_feat_scale1.squeeze(0) 1374 | scale1_bs = wsi_feat_scale1.shape[0] 1375 | 1376 | # get the instance weight during the first forward 1377 | # 1378 | """ 1379 | Fuse 4 modalities 1380 | """ 1381 | wsi_feat_scale2 = data['wsi_feat_scale2'].cuda(self.local_rank) 1382 | wsi_feat_scale3 = data['wsi_feat_scale3'].cuda(self.local_rank) 1383 | if len(wsi_feat_scale2.size()) == 3: 1384 | # 1 #instance #feat 1385 | wsi_feat_scale2 = wsi_feat_scale2.squeeze(0) 1386 | if len(wsi_feat_scale3.size()) == 3: 1387 | # 1 #instance #feat 1388 | wsi_feat_scale3 = wsi_feat_scale3.squeeze(0) 1389 | 1390 | with torch.no_grad(): 1391 | if self.use_k_agg: 1392 | if wsi_feat_scale1.shape[0] < self.k_agg: 1393 | pad_size = self.k_agg - wsi_feat_scale1.shape[0] 1394 | zero_size = (pad_size, *wsi_feat_scale1.shape[1:]) 1395 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale1.device) 1396 | wsi_feat_scale1 = torch.cat([wsi_feat_scale1, pad_tensor]) 1397 | if wsi_feat_scale2.shape[0] < self.k_agg: 1398 | pad_size = self.k_agg - wsi_feat_scale2.shape[0] 1399 | zero_size = (pad_size, *wsi_feat_scale2.shape[1:]) 1400 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale2.device) 1401 | wsi_feat_scale2 = torch.cat([wsi_feat_scale2, pad_tensor]) 1402 | if wsi_feat_scale3.shape[0] < self.k_agg: 1403 | pad_size = self.k_agg - wsi_feat_scale3.shape[0] 1404 | zero_size = (pad_size, *wsi_feat_scale3.shape[1:]) 1405 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale3.device) 1406 | wsi_feat_scale3 = torch.cat([wsi_feat_scale3, pad_tensor]) 1407 | 1408 | """fine-tuning 3 scales""" 1409 | wsi_ft_feat_list = [] 1410 | for ft_conv, wsi_feat in zip( 1411 | [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3], 1412 | [wsi_feat_scale1, wsi_feat_scale2, wsi_feat_scale3], 1413 | ): 1414 | wsi_ft_feat_list.append(ft_conv(wsi_feat)) 1415 | 1416 | if self.use_k_agg: 1417 | agg_feat_list = [] 1418 | for data_feat, score_fc in zip(wsi_ft_feat_list, self.score_fc): 1419 | agg_feat_list.append(self.agg_k_cluster_by_score(data_feat, score_fc)) 1420 | wsi_ft_feat_list = agg_feat_list 1421 | 1422 | wsi_feat_scale_gloabl_list = [] 1423 | for data_feat, score_fc in zip(agg_feat_list, self.score_fc): 1424 | feat_score = score_fc(data_feat) 1425 | feat_attention = torch.sigmoid(feat_score) 1426 | 1427 | attention_weight_out_list.append(feat_attention.detach().clone()) 1428 | global_feat = torch.sum(data_feat * feat_attention, dim=0, keepdim=True) 1429 | wsi_feat_scale_gloabl_list.append(global_feat) 1430 | else: 1431 | """global representation of 3 scale images""" 1432 | wsi_feat_scale_gloabl_list = [] 1433 | for feat in wsi_ft_feat_list: 1434 | wsi_feat_scale_gloabl_list.append(torch.mean(feat, dim=0, keepdim=True)) 1435 | 1436 | 1437 | """mmtm""" 1438 | tab_feat_mmtm, wsi_feat1_gloabl, wsi_feat_scale1_gate, wsi_feat2_gloabl, wsi_feat_scale2_gate, wsi_feat3_gloabl, wsi_feat_scale3_gate = self.mmtm(tab_feat, *wsi_feat_scale_gloabl_list) 1439 | 1440 | """instance selection of 3 scales""" 1441 | wsi_feat_agg_list = [] 1442 | for wsi_feat_at_scale, wsi_feat_gate_at_scale, wsi_global_rep, instance_gate in zip( 1443 | wsi_ft_feat_list, 1444 | [wsi_feat_scale1_gate, wsi_feat_scale2_gate, wsi_feat_scale3_gate], 1445 | wsi_feat_scale_gloabl_list, 1446 | [self.instance_gate1, self.instance_gate2, self.instance_gate3] 1447 | ): 1448 | # 1449 | bs_at_scale = wsi_feat_at_scale.shape[0] 1450 | wsi_feat_at_scale = wsi_feat_at_scale * wsi_feat_gate_at_scale 1451 | wsi_global_rep_repeat = wsi_feat_gate_at_scale.detach().repeat(bs_at_scale, 1) 1452 | 1453 | # N * 1 1454 | instance_attention_weight = instance_gate(wsi_feat_at_scale, wsi_global_rep_repeat) 1455 | # 1 * N 1456 | instance_attention_weight = torch.transpose(instance_attention_weight, 1, 0) 1457 | 1458 | instance_attention_weight = torch.softmax(instance_attention_weight, dim=1) 1459 | 1460 | 1461 | # instance aggregate 1462 | wsi_feat_agg = torch.mm(instance_attention_weight, wsi_feat_at_scale) 1463 | attention_weight_out_list.append(instance_attention_weight.detach().clone()) 1464 | 1465 | # second time forward 1466 | if self.use_k_agg: 1467 | if wsi_feat_scale1.shape[0] < self.k_agg: 1468 | pad_size = self.k_agg - wsi_feat_scale1.shape[0] 1469 | zero_size = (pad_size, *wsi_feat_scale1.shape[1:]) 1470 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale1.device) 1471 | wsi_feat_scale1 = torch.cat([wsi_feat_scale1, pad_tensor]) 1472 | if wsi_feat_scale2.shape[0] < self.k_agg: 1473 | pad_size = self.k_agg - wsi_feat_scale2.shape[0] 1474 | zero_size = (pad_size, *wsi_feat_scale2.shape[1:]) 1475 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale2.device) 1476 | wsi_feat_scale2 = torch.cat([wsi_feat_scale2, pad_tensor]) 1477 | if wsi_feat_scale3.shape[0] < self.k_agg: 1478 | pad_size = self.k_agg - wsi_feat_scale3.shape[0] 1479 | zero_size = (pad_size, *wsi_feat_scale3.shape[1:]) 1480 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale3.device) 1481 | wsi_feat_scale3 = torch.cat([wsi_feat_scale3, pad_tensor]) 1482 | 1483 | """fine-tuning 3 scales""" 1484 | wsi_ft_feat_list = [] 1485 | for ft_conv, wsi_feat in zip( 1486 | [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3], 1487 | [wsi_feat_scale1, wsi_feat_scale2, wsi_feat_scale3], 1488 | ): 1489 | wsi_ft_feat_list.append(ft_conv(wsi_feat)) 1490 | 1491 | if self.use_k_agg: 1492 | agg_feat_list = [] 1493 | for data_feat, score_fc in zip(wsi_ft_feat_list, self.score_fc): 1494 | agg_feat_list.append(self.agg_k_cluster_by_score(data_feat, score_fc)) 1495 | wsi_ft_feat_list = agg_feat_list 1496 | 1497 | wsi_feat_scale_gloabl_list = [] 1498 | for data_feat, score_fc in zip(agg_feat_list, self.score_fc): 1499 | feat_score = score_fc(data_feat) 1500 | feat_attention = torch.sigmoid(feat_score) 1501 | 1502 | attention_weight_out_list.append(feat_attention.detach().clone()) 1503 | global_feat = torch.sum(data_feat * feat_attention, dim=0, keepdim=True) 1504 | wsi_feat_scale_gloabl_list.append(global_feat) 1505 | else: 1506 | """global representation of 3 scales""" 1507 | wsi_feat_scale_gloabl_list = [] 1508 | for idx, feat in enumerate(wsi_ft_feat_list): 1509 | current_global_feature = torch.mm(attention_weight_out_list[idx], feat) 1510 | wsi_feat_scale_gloabl_list.append(current_global_feature) 1511 | 1512 | attention_weight_out_list = [] 1513 | """mmtm""" 1514 | tab_feat_mmtm, wsi_feat1_gloabl, wsi_feat_scale1_gate, wsi_feat2_gloabl, wsi_feat_scale2_gate, wsi_feat3_gloabl, wsi_feat_scale3_gate = self.mmtm( 1515 | tab_feat, *wsi_feat_scale_gloabl_list) 1516 | 1517 | """instance selection of 3 scales""" 1518 | wsi_feat_agg_list = [] 1519 | for wsi_feat_at_scale, wsi_feat_gate_at_scale, wsi_global_rep, instance_gate in zip( 1520 | wsi_ft_feat_list, 1521 | [wsi_feat_scale1_gate, wsi_feat_scale2_gate, wsi_feat_scale3_gate], 1522 | wsi_feat_scale_gloabl_list, 1523 | [self.instance_gate1, self.instance_gate2, self.instance_gate3] 1524 | ): 1525 | # 1526 | bs_at_scale = wsi_feat_at_scale.shape[0] 1527 | wsi_feat_at_scale = wsi_feat_at_scale * wsi_feat_gate_at_scale 1528 | wsi_global_rep_repeat = wsi_feat_gate_at_scale.detach().repeat(bs_at_scale, 1) 1529 | 1530 | # N * 1 1531 | instance_attention_weight = instance_gate(wsi_feat_at_scale, wsi_global_rep_repeat) 1532 | # 1 * N 1533 | instance_attention_weight = torch.transpose(instance_attention_weight, 1, 0) 1534 | 1535 | instance_attention_weight = torch.softmax(instance_attention_weight, dim=1) 1536 | 1537 | # instance aggregate 1538 | wsi_feat_agg = torch.mm(instance_attention_weight, wsi_feat_at_scale) 1539 | 1540 | attention_weight_out_list.append(instance_attention_weight.detach().clone()) 1541 | wsi_feat_agg_list.append(wsi_feat_agg) 1542 | 1543 | """tab feat ft""" 1544 | tab_feat_ft = self.table_feature_ft(tab_feat_mmtm) 1545 | 1546 | final_feat = torch.cat([tab_feat_ft, *wsi_feat_agg_list, wsi_feat1_gloabl, wsi_feat2_gloabl, wsi_feat3_gloabl], dim=1) 1547 | 1548 | out = self.classifier(final_feat) 1549 | 1550 | 1551 | pass 1552 | y = y.view(-1, 1).float() 1553 | loss = F.binary_cross_entropy_with_logits(out, y) + \ 1554 | tab_loss_weight * F.binary_cross_entropy_with_logits(tab_logit, y) - \ 1555 | self.lambda_sparse * M_loss 1556 | 1557 | return out, loss, attention_weight_out_list 1558 | 1559 | def get_params(self, base_lr): 1560 | ret = [] 1561 | 1562 | if self.tabnet is not None: 1563 | tabnet_params = [] 1564 | for param in self.tabnet.parameters(): 1565 | tabnet_params.append(param) 1566 | ret.append({ 1567 | 'params': tabnet_params, 1568 | 'lr': base_lr 1569 | }) 1570 | 1571 | cls_learning_rate_rate=100 1572 | if self.classifier is not None: 1573 | classifier_params = [] 1574 | for param in self.classifier.parameters(): 1575 | classifier_params.append(param) 1576 | ret.append({ 1577 | 'params': classifier_params, 1578 | 'lr': base_lr / cls_learning_rate_rate, 1579 | }) 1580 | 1581 | 1582 | tab_learning_rate_rate = 100 1583 | if self.table_feature_ft is not None: 1584 | misc_params = [] 1585 | for param in self.table_feature_ft.parameters(): 1586 | misc_params.append(param) 1587 | ret.append({ 1588 | 'params': misc_params, 1589 | 'lr': base_lr / tab_learning_rate_rate, 1590 | }) 1591 | 1592 | mil_learning_rate_rate = 1000 1593 | misc_params = [] 1594 | for part in [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3, 1595 | self.instance_gate1, self.instance_gate2, self.instance_gate3, 1596 | self.wsi_select_gate, 1597 | self.score_fc]: 1598 | if part is not None: 1599 | for param in part.parameters(): 1600 | misc_params.append(param) 1601 | ret.append({ 1602 | 'params': misc_params, 1603 | 'lr': base_lr / mil_learning_rate_rate, 1604 | }) 1605 | 1606 | misc_learning_rate_rate = 100 1607 | misc_params = [] 1608 | for part in [self.mmtm, ]: 1609 | if part is not None: 1610 | for param in part.parameters(): 1611 | misc_params.append(param) 1612 | ret.append({ 1613 | 'params': misc_params, 1614 | 'lr': base_lr / misc_learning_rate_rate, 1615 | }) 1616 | 1617 | return ret 1618 | 1619 | 1620 | 1621 | class InstanceAttentionGateAdd(nn.Module): 1622 | def __init__(self, feat_dim): 1623 | super(InstanceAttentionGateAdd, self).__init__() 1624 | self.trans = nn.Sequential( 1625 | nn.Linear(feat_dim, feat_dim), 1626 | nn.LeakyReLU(), 1627 | nn.Linear(feat_dim, 1), 1628 | ) 1629 | 1630 | def forward(self, instance_feature, global_feature): 1631 | 1632 | feat = instance_feature + global_feature 1633 | attention = self.trans(feat) 1634 | return attention 1635 | 1636 | 1637 | class MILFusionAdd(nn.Module): 1638 | def __init__(self, img_feat_input_dim=512, tab_feat_input_dim=32, 1639 | img_feat_rep_layers=4, 1640 | num_modal=2, 1641 | use_tabnet=False, 1642 | tab_indim=0, 1643 | local_rank=0, 1644 | cat_idxs=None, 1645 | cat_dims=None, 1646 | lambda_sparse=1e-3, 1647 | fusion='mmtm', 1648 | use_k_agg=False, 1649 | k_agg=10, 1650 | ): 1651 | super(MILFusionAdd, self).__init__() 1652 | self.num_modal = num_modal 1653 | self.local_rank = local_rank 1654 | self.use_tabnet = use_tabnet 1655 | self.tab_indim = tab_indim 1656 | self.lambda_sparse = lambda_sparse 1657 | # define K mean agg 1658 | self.use_k_agg = use_k_agg 1659 | self.k_agg = k_agg 1660 | 1661 | self.fusion_method = fusion 1662 | if self.use_tabnet: 1663 | self.tabnet = TabNet(input_dim=tab_indim, output_dim=1, 1664 | n_d=32, n_a=32, n_steps=5, 1665 | gamma=1.5, n_independent=2, n_shared=2, 1666 | momentum=0.3, 1667 | cat_idxs=cat_idxs, cat_dims=cat_dims) 1668 | else: 1669 | self.tabnet = None 1670 | 1671 | if self.use_tabnet and num_modal == 1: 1672 | self.only_tabnet = True 1673 | else: 1674 | self.only_tabnet = False 1675 | 1676 | """ 1677 | Control tabnet 1678 | """ 1679 | if self.only_tabnet: 1680 | self.feature_fine_tuning = None 1681 | else: 1682 | """pretrained feature fine tune""" 1683 | feature_fine_tuning_layers = [] 1684 | for _ in range(img_feat_rep_layers): 1685 | feature_fine_tuning_layers.extend([ 1686 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1687 | nn.LeakyReLU(), 1688 | ]) 1689 | self.feature_fine_tuning = nn.Sequential(*feature_fine_tuning_layers) 1690 | 1691 | # 3 为三个图像模态 1692 | if self.num_modal == 4 or self.num_modal == 3: 1693 | self.feature_fine_tuning2 = nn.Sequential(*feature_fine_tuning_layers) 1694 | self.feature_fine_tuning3 = nn.Sequential(*feature_fine_tuning_layers) 1695 | else: 1696 | self.feature_fine_tuning2 = None 1697 | self.feature_fine_tuning3 = None 1698 | 1699 | if self.only_tabnet or self.num_modal == 3: 1700 | self.table_feature_ft = None 1701 | else: 1702 | """tab feature fine tuning""" 1703 | self.table_feature_ft = nn.Sequential( 1704 | nn.Linear(tab_feat_input_dim, tab_feat_input_dim) 1705 | ) 1706 | 1707 | # k agg score 1708 | self.score_fc = nn.ModuleList() 1709 | if self.use_k_agg: 1710 | for _ in range(self.num_modal - 1): 1711 | self.score_fc.append( 1712 | nn.Sequential( 1713 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1714 | nn.LeakyReLU(), 1715 | nn.Linear(img_feat_input_dim, img_feat_input_dim), 1716 | nn.LeakyReLU(), 1717 | nn.Linear(img_feat_input_dim, 1), 1718 | nn.Sigmoid() 1719 | ) 1720 | ) 1721 | 1722 | 1723 | """modal fusion""" 1724 | self.wsi_select_gate = None 1725 | # define modal fusion related output feature dimension and fusion module 1726 | if self.only_tabnet: 1727 | self.mmtm = None 1728 | elif self.fusion_method == 'concat': 1729 | self.fusion_out_dim = tab_feat_input_dim + img_feat_input_dim 1730 | self.wsi_select_gate = nn.Sequential( 1731 | nn.Linear(img_feat_input_dim, 1), 1732 | nn.Sigmoid() 1733 | ) 1734 | self.mmtm = nn.Linear(self.fusion_out_dim, self.fusion_out_dim) 1735 | elif self.fusion_method == 'bilinear': 1736 | self.wsi_select_gate = nn.Sequential( 1737 | nn.Linear(img_feat_input_dim, 1), 1738 | nn.Sigmoid() 1739 | ) 1740 | self.fusion_out_dim = tab_feat_input_dim + img_feat_input_dim 1741 | self.mmtm = nn.Bilinear(tab_feat_input_dim, img_feat_input_dim, self.fusion_out_dim) 1742 | elif self.fusion_method == 'add': 1743 | self.wsi_select_gate = nn.Sequential( 1744 | nn.Linear(img_feat_input_dim, 1), 1745 | nn.Sigmoid() 1746 | ) 1747 | self.fusion_out_dim = tab_feat_input_dim 1748 | self.mmtm = nn.Linear(img_feat_input_dim * (num_modal - 1), tab_feat_input_dim) 1749 | elif self.fusion_method == 'gate': 1750 | self.wsi_select_gate = nn.Sequential( 1751 | nn.Linear(img_feat_input_dim, 1), 1752 | nn.Sigmoid() 1753 | ) 1754 | self.fusion_out_dim = 96 1755 | self.mmtm = BilinearFusion(dim1=tab_feat_input_dim, dim2=img_feat_input_dim, mmhid=self.fusion_out_dim) 1756 | elif self.num_modal == 2 and self.fusion_method == 'mmtm': 1757 | self.fusion_out_dim = (img_feat_input_dim * 2) * (num_modal - 1) + tab_feat_input_dim 1758 | self.mmtm = MMTMBi(dim_tab=tab_feat_input_dim, dim_img=img_feat_input_dim) 1759 | elif self.num_modal == 3 and self.fusion_method == 'mmtm': 1760 | self.fusion_out_dim = (img_feat_input_dim * 2) * 3 1761 | self.mmtm = MMTMTri(dim_img=img_feat_input_dim) 1762 | elif self.num_modal == 4 and self.fusion_method == 'mmtm': 1763 | self.fusion_out_dim = (img_feat_input_dim * 2) * (num_modal - 1) + tab_feat_input_dim 1764 | self.mmtm = MMTMQuad(dim_tab=tab_feat_input_dim, dim_img=img_feat_input_dim) 1765 | else: 1766 | raise NotImplementedError(f'num_modal {num_modal} not implemented') 1767 | 1768 | """instance selection""" 1769 | if self.only_tabnet or self.fusion_method in ['concat', 'add', 'bilinear', 'gate']: 1770 | self.instance_gate1 = None 1771 | else: 1772 | self.instance_gate1 = InstanceAttentionGateAdd(img_feat_input_dim) 1773 | 1774 | if (self.num_modal == 4 or self.num_modal == 3)and self.fusion_method == 'mmtm': 1775 | self.instance_gate2 = InstanceAttentionGateAdd(img_feat_input_dim) 1776 | self.instance_gate3 = InstanceAttentionGateAdd(img_feat_input_dim) 1777 | else: 1778 | self.instance_gate2 = None 1779 | self.instance_gate3 = None 1780 | 1781 | """classifier layer""" 1782 | if self.only_tabnet: 1783 | self.classifier = None 1784 | else: 1785 | self.classifier = nn.Sequential( 1786 | nn.Linear(self.fusion_out_dim, self.fusion_out_dim), 1787 | nn.Dropout(0.5), 1788 | nn.Linear(self.fusion_out_dim, 1) 1789 | ) 1790 | 1791 | def agg_k_cluster_by_score(self, data: torch.Tensor, score_fc: nn.Module): 1792 | num_elements = data.shape[0] 1793 | score = score_fc(data) 1794 | 1795 | """ 1796 | >>> score = torch.rand(4,1) 1797 | >>> top_score, top_idx = torch.topk(score, k=num_elements, dim=0) 1798 | >>> top_score, top_idx 1799 | (tensor([[0.3963], 1800 | [0.0856], 1801 | [0.0704], 1802 | [0.0247]]), 1803 | tensor([[1], 1804 | [0], 1805 | [3], 1806 | [2]])) 1807 | """ 1808 | top_score, top_idx = torch.topk(score, k=num_elements, dim=0) 1809 | """ 1810 | >>> data 1811 | tensor([[0.0672, 0.9001, 0.5660, 0.0522, 0.1543], 1812 | [0.1965, 0.7711, 0.9737, 0.5269, 0.9255], 1813 | [0.6761, 0.5801, 0.4687, 0.1683, 0.8136], 1814 | [0.2091, 0.9620, 0.8105, 0.8210, 0.3391]]) 1815 | >>> top_idx[:, 0] 1816 | tensor([1, 0, 3, 2]) 1817 | >>> data_sorted 1818 | tensor([[0.1965, 0.7711, 0.9737, 0.5269, 0.9255], 1819 | [0.0672, 0.9001, 0.5660, 0.0522, 0.1543], 1820 | [0.2091, 0.9620, 0.8105, 0.8210, 0.3391], 1821 | [0.6761, 0.5801, 0.4687, 0.1683, 0.8136]]) 1822 | """ 1823 | data_sorted = torch.zeros_like(data) 1824 | data_sorted.index_copy_(dim=0, index=top_idx[:, 0], source=data) 1825 | 1826 | # Batch set as feature dim 1827 | data_sorted = torch.transpose(data_sorted, 1, 0) 1828 | data_sorted = data_sorted.unsqueeze(1) 1829 | 1830 | agg_result = nn.functional.adaptive_max_pool1d(data_sorted, self.k_agg) 1831 | 1832 | agg_result = agg_result.squeeze(1) 1833 | agg_result = torch.transpose(agg_result, 1, 0) 1834 | return agg_result 1835 | 1836 | def forward(self, data): 1837 | 1838 | attention_weight_out_list = [] 1839 | if self.use_tabnet: 1840 | tab_data = data['tab_data'].cuda(self.local_rank) 1841 | if self.only_tabnet: 1842 | tab_logit, M_loss = self.tabnet(tab_data) 1843 | else: 1844 | tab_logit, tab_feat, M_loss = self.tabnet(tab_data) 1845 | 1846 | tab_loss_weight = 1. 1847 | else: 1848 | tab_feat = data['tab_feat'].cuda(self.local_rank) 1849 | tab_logit = torch.zeros((1, 1)).cuda(self.local_rank) 1850 | M_loss = 0. 1851 | tab_loss_weight = 0. 1852 | 1853 | y = data['label'].cuda(self.local_rank) 1854 | wsi_feat_scale1 = data['wsi_feat_scale1'].cuda(self.local_rank) 1855 | if len(wsi_feat_scale1.size()) == 3: 1856 | # 1 #instance #feat 1857 | wsi_feat_scale1 = wsi_feat_scale1.squeeze(0) 1858 | scale1_bs = wsi_feat_scale1.shape[0] 1859 | 1860 | if True: 1861 | """ 1862 | Fusion of 4 modalities 1863 | """ 1864 | wsi_feat_scale2 = data['wsi_feat_scale2'].cuda(self.local_rank) 1865 | wsi_feat_scale3 = data['wsi_feat_scale3'].cuda(self.local_rank) 1866 | if len(wsi_feat_scale2.size()) == 3: 1867 | # 1 #instance #feat 1868 | wsi_feat_scale2 = wsi_feat_scale2.squeeze(0) 1869 | if len(wsi_feat_scale3.size()) == 3: 1870 | # 1 #instance #feat 1871 | wsi_feat_scale3 = wsi_feat_scale3.squeeze(0) 1872 | 1873 | if self.use_k_agg: 1874 | if wsi_feat_scale1.shape[0] < self.k_agg: 1875 | pad_size = self.k_agg - wsi_feat_scale1.shape[0] 1876 | zero_size = (pad_size, *wsi_feat_scale1.shape[1:]) 1877 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale1.device) 1878 | wsi_feat_scale1 = torch.cat([wsi_feat_scale1, pad_tensor]) 1879 | if wsi_feat_scale2.shape[0] < self.k_agg: 1880 | pad_size = self.k_agg - wsi_feat_scale2.shape[0] 1881 | zero_size = (pad_size, *wsi_feat_scale2.shape[1:]) 1882 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale2.device) 1883 | wsi_feat_scale2 = torch.cat([wsi_feat_scale2, pad_tensor]) 1884 | if wsi_feat_scale3.shape[0] < self.k_agg: 1885 | pad_size = self.k_agg - wsi_feat_scale3.shape[0] 1886 | zero_size = (pad_size, *wsi_feat_scale3.shape[1:]) 1887 | pad_tensor = torch.zeros(zero_size).to(wsi_feat_scale3.device) 1888 | wsi_feat_scale3 = torch.cat([wsi_feat_scale3, pad_tensor]) 1889 | 1890 | """fine-tuning 3 scales features""" 1891 | wsi_ft_feat_list = [] 1892 | for ft_conv, wsi_feat in zip( 1893 | [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3], 1894 | [wsi_feat_scale1, wsi_feat_scale2, wsi_feat_scale3], 1895 | ): 1896 | wsi_ft_feat_list.append(ft_conv(wsi_feat)) 1897 | 1898 | if self.use_k_agg: 1899 | agg_feat_list = [] 1900 | for data_feat, score_fc in zip(wsi_ft_feat_list, self.score_fc): 1901 | agg_feat_list.append(self.agg_k_cluster_by_score(data_feat, score_fc)) 1902 | wsi_ft_feat_list = agg_feat_list 1903 | 1904 | wsi_feat_scale_gloabl_list = [] 1905 | for data_feat, score_fc in zip(agg_feat_list, self.score_fc): 1906 | feat_score = score_fc(data_feat) 1907 | # feat_attention = torch.softmax(feat_score, dim=0) 1908 | feat_attention = torch.sigmoid(feat_score) 1909 | 1910 | attention_weight_out_list.append(feat_attention.detach().clone()) 1911 | global_feat = torch.sum(data_feat * feat_attention, dim=0, keepdim=True) 1912 | wsi_feat_scale_gloabl_list.append(global_feat) 1913 | else: 1914 | """global representation of 3 scales""" 1915 | wsi_feat_scale_gloabl_list = [] 1916 | for feat in wsi_ft_feat_list: 1917 | wsi_feat_scale_gloabl_list.append(torch.mean(feat, dim=0, keepdim=True)) 1918 | 1919 | 1920 | """mmtm""" 1921 | tab_feat_mmtm, wsi_feat1_gloabl, wsi_feat_scale1_gate, wsi_feat2_gloabl, wsi_feat_scale2_gate, wsi_feat3_gloabl, wsi_feat_scale3_gate = self.mmtm(tab_feat, *wsi_feat_scale_gloabl_list) 1922 | 1923 | """instance selection of 3 scales""" 1924 | wsi_feat_agg_list = [] 1925 | for wsi_feat_at_scale, wsi_feat_gate_at_scale, wsi_global_rep, instance_gate in zip( 1926 | wsi_ft_feat_list, 1927 | [wsi_feat_scale1_gate, wsi_feat_scale2_gate, wsi_feat_scale3_gate], 1928 | wsi_feat_scale_gloabl_list, 1929 | [self.instance_gate1, self.instance_gate2, self.instance_gate3] 1930 | ): 1931 | # 1932 | bs_at_scale = wsi_feat_at_scale.shape[0] 1933 | wsi_feat_at_scale = wsi_feat_at_scale * wsi_feat_gate_at_scale 1934 | wsi_global_rep_repeat = wsi_feat_gate_at_scale.detach().repeat(bs_at_scale, 1) 1935 | 1936 | # N * 1 1937 | instance_attention_weight = instance_gate(wsi_feat_at_scale, wsi_global_rep_repeat) 1938 | # 1 * N 1939 | instance_attention_weight = torch.transpose(instance_attention_weight, 1, 0) 1940 | 1941 | instance_attention_weight = torch.softmax(instance_attention_weight, dim=1) 1942 | 1943 | 1944 | # instance aggregate 1945 | wsi_feat_agg = torch.mm(instance_attention_weight, wsi_feat_at_scale) 1946 | attention_weight_out_list.append(instance_attention_weight.detach().clone()) 1947 | wsi_feat_agg_list.append(wsi_feat_agg) 1948 | 1949 | """tab feat ft""" 1950 | tab_feat_ft = self.table_feature_ft(tab_feat_mmtm) 1951 | 1952 | final_feat = torch.cat([tab_feat_ft, *wsi_feat_agg_list, wsi_feat1_gloabl, wsi_feat2_gloabl, wsi_feat3_gloabl], dim=1) 1953 | 1954 | out = self.classifier(final_feat) 1955 | 1956 | 1957 | pass 1958 | y = y.view(-1, 1).float() 1959 | loss = F.binary_cross_entropy_with_logits(out, y) + \ 1960 | tab_loss_weight * F.binary_cross_entropy_with_logits(tab_logit, y) - \ 1961 | self.lambda_sparse * M_loss 1962 | 1963 | return out, loss, attention_weight_out_list 1964 | 1965 | def get_params(self, base_lr): 1966 | ret = [] 1967 | 1968 | if self.tabnet is not None: 1969 | tabnet_params = [] 1970 | for param in self.tabnet.parameters(): 1971 | tabnet_params.append(param) 1972 | ret.append({ 1973 | 'params': tabnet_params, 1974 | 'lr': base_lr 1975 | }) 1976 | 1977 | cls_learning_rate_rate=100 1978 | if self.classifier is not None: 1979 | classifier_params = [] 1980 | for param in self.classifier.parameters(): 1981 | classifier_params.append(param) 1982 | ret.append({ 1983 | 'params': classifier_params, 1984 | 'lr': base_lr / cls_learning_rate_rate, 1985 | }) 1986 | 1987 | tab_learning_rate_rate = 100 1988 | if self.table_feature_ft is not None: 1989 | misc_params = [] 1990 | for param in self.table_feature_ft.parameters(): 1991 | misc_params.append(param) 1992 | ret.append({ 1993 | 'params': misc_params, 1994 | 'lr': base_lr / tab_learning_rate_rate, 1995 | }) 1996 | 1997 | mil_learning_rate_rate = 1000 1998 | misc_params = [] 1999 | for part in [self.feature_fine_tuning, self.feature_fine_tuning2, self.feature_fine_tuning3, 2000 | self.instance_gate1, self.instance_gate2, self.instance_gate3, 2001 | self.wsi_select_gate, 2002 | self.score_fc]: 2003 | if part is not None: 2004 | for param in part.parameters(): 2005 | misc_params.append(param) 2006 | ret.append({ 2007 | 'params': misc_params, 2008 | 'lr': base_lr / mil_learning_rate_rate, 2009 | }) 2010 | 2011 | misc_learning_rate_rate = 100 2012 | misc_params = [] 2013 | for part in [self.mmtm, ]: 2014 | if part is not None: 2015 | for param in part.parameters(): 2016 | misc_params.append(param) 2017 | ret.append({ 2018 | 'params': misc_params, 2019 | 'lr': base_lr / misc_learning_rate_rate, 2020 | }) 2021 | 2022 | return ret 2023 | 2024 | 2025 | 2026 | 2027 | 2028 | 2029 | 2030 | 2031 | -------------------------------------------------------------------------------- /models/tabnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzon/Multi-modal-Multi-instance-Learning/c30051bc87c8d40ec93d29ca216ae088816ac6b5/models/tabnet/__init__.py -------------------------------------------------------------------------------- /models/tabnet/sparsemax.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch.nn.functional as F 4 | 5 | import torch 6 | 7 | def _make_ix_like(input, dim=0): 8 | d = input.size(dim) 9 | rho = torch.arange(1, d + 1, device=input.device, dtype=input.dtype) 10 | view = [1] * input.dim() 11 | view[0] = -1 12 | return rho.view(view).transpose(0, dim) 13 | 14 | 15 | class SparsemaxFunction(Function): 16 | """ 17 | An implementation of sparsemax (Martins & Astudillo, 2016). See 18 | :cite:`DBLP:journals/corr/MartinsA16` for detailed description. 19 | By Ben Peters and Vlad Niculae 20 | """ 21 | 22 | @staticmethod 23 | def forward(ctx, input, dim=-1): 24 | """sparsemax: normalizing sparse transform (a la softmax) 25 | Parameters: 26 | input (Tensor): any shape 27 | dim: dimension along which to apply sparsemax 28 | Returns: 29 | output (Tensor): same shape as input 30 | """ 31 | ctx.dim = dim 32 | max_val, _ = input.max(dim=dim, keepdim=True) 33 | input -= max_val # same numerical stability trick as for softmax 34 | tau, supp_size = SparsemaxFunction._threshold_and_support(input, dim=dim) 35 | output = torch.clamp(input - tau, min=0) 36 | ctx.save_for_backward(supp_size, output) 37 | return output 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | supp_size, output = ctx.saved_tensors 42 | dim = ctx.dim 43 | grad_input = grad_output.clone() 44 | grad_input[output == 0] = 0 45 | 46 | v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() 47 | v_hat = v_hat.unsqueeze(dim) 48 | grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) 49 | return grad_input, None 50 | 51 | @staticmethod 52 | def _threshold_and_support(input, dim=-1): 53 | """Sparsemax building block: compute the threshold 54 | Args: 55 | input: any dimension 56 | dim: dimension along which to apply the sparsemax 57 | Returns: 58 | the threshold value 59 | """ 60 | 61 | input_srt, _ = torch.sort(input, descending=True, dim=dim) 62 | input_cumsum = input_srt.cumsum(dim) - 1 63 | rhos = _make_ix_like(input, dim) 64 | support = rhos * input_srt > input_cumsum 65 | 66 | support_size = support.sum(dim=dim).unsqueeze(dim) 67 | tau = input_cumsum.gather(dim, support_size - 1) 68 | tau /= support_size.to(input.dtype) 69 | return tau, support_size 70 | 71 | 72 | sparsemax = SparsemaxFunction.apply 73 | 74 | 75 | class Sparsemax(nn.Module): 76 | 77 | def __init__(self, dim=-1): 78 | self.dim = dim 79 | super(Sparsemax, self).__init__() 80 | 81 | def forward(self, input): 82 | return sparsemax(input, self.dim) 83 | 84 | 85 | class Entmax15Function(Function): 86 | 87 | 88 | @staticmethod 89 | def forward(ctx, input, dim=-1): 90 | ctx.dim = dim 91 | 92 | max_val, _ = input.max(dim=dim, keepdim=True) 93 | input = input - max_val # same numerical stability trick as for softmax 94 | input = input / 2 # divide by 2 to solve actual Entmax 95 | 96 | tau_star, _ = Entmax15Function._threshold_and_support(input, dim) 97 | output = torch.clamp(input - tau_star, min=0) ** 2 98 | ctx.save_for_backward(output) 99 | return output 100 | 101 | @staticmethod 102 | def backward(ctx, grad_output): 103 | Y, = ctx.saved_tensors 104 | gppr = Y.sqrt() 105 | dX = grad_output * gppr 106 | q = dX.sum(ctx.dim) / gppr.sum(ctx.dim) 107 | q = q.unsqueeze(ctx.dim) 108 | dX -= q * gppr 109 | return dX, None 110 | 111 | @staticmethod 112 | def _threshold_and_support(input, dim=-1): 113 | Xsrt, _ = torch.sort(input, descending=True, dim=dim) 114 | 115 | rho = _make_ix_like(input, dim) 116 | mean = Xsrt.cumsum(dim) / rho 117 | mean_sq = (Xsrt ** 2).cumsum(dim) / rho 118 | ss = rho * (mean_sq - mean ** 2) 119 | delta = (1 - ss) / rho 120 | 121 | # NOTE this is not exactly the same as in reference algo 122 | # Fortunately it seems the clamped values never wrongly 123 | # get selected by tau <= sorted_z. Prove this! 124 | delta_nz = torch.clamp(delta, 0) 125 | tau = mean - torch.sqrt(delta_nz) 126 | 127 | support_size = (tau <= Xsrt).sum(dim).unsqueeze(dim) 128 | tau_star = tau.gather(dim, support_size - 1) 129 | return tau_star, support_size 130 | 131 | 132 | class Entmoid15(Function): 133 | """ A highly optimized equivalent of labda x: Entmax15([x, 0]) """ 134 | 135 | @staticmethod 136 | def forward(ctx, input): 137 | output = Entmoid15._forward(input) 138 | ctx.save_for_backward(output) 139 | return output 140 | 141 | @staticmethod 142 | def _forward(input): 143 | input, is_pos = abs(input), input >= 0 144 | tau = (input + torch.sqrt(F.relu(8 - input ** 2))) / 2 145 | tau.masked_fill_(tau <= input, 2.0) 146 | y_neg = 0.25 * F.relu(tau - input, inplace=True) ** 2 147 | return torch.where(is_pos, 1 - y_neg, y_neg) 148 | 149 | @staticmethod 150 | def backward(ctx, grad_output): 151 | return Entmoid15._backward(ctx.saved_tensors[0], grad_output) 152 | 153 | @staticmethod 154 | def _backward(output, grad_output): 155 | gppr0, gppr1 = output.sqrt(), (1 - output).sqrt() 156 | grad_input = grad_output * gppr0 157 | q = grad_input / (gppr0 + gppr1) 158 | grad_input -= q * gppr0 159 | return grad_input 160 | 161 | 162 | entmax15 = Entmax15Function.apply 163 | entmoid15 = Entmoid15.apply 164 | 165 | 166 | class Entmax15(nn.Module): 167 | 168 | def __init__(self, dim=-1): 169 | self.dim = dim 170 | super(Entmax15, self).__init__() 171 | 172 | def forward(self, input): 173 | return entmax15(input, self.dim) 174 | 175 | -------------------------------------------------------------------------------- /models/tabnet/tab_network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear, BatchNorm1d, ReLU 3 | import numpy as np 4 | 5 | from models.tabnet import sparsemax 6 | 7 | def initialize_non_glu(module, input_dim, output_dim): 8 | gain_value = np.sqrt((input_dim+output_dim)/np.sqrt(4*input_dim)) 9 | torch.nn.init.xavier_normal_(module.weight, gain=gain_value) 10 | return 11 | 12 | 13 | def initialize_glu(module, input_dim, output_dim): 14 | gain_value = np.sqrt((input_dim+output_dim)/np.sqrt(input_dim)) 15 | torch.nn.init.xavier_normal_(module.weight, gain=gain_value) 16 | # torch.nn.init.zeros_(module.bias) 17 | return 18 | 19 | 20 | class GBN(torch.nn.Module): 21 | """ 22 | Ghost Batch Normalization 23 | """ 24 | 25 | def __init__(self, input_dim, virtual_batch_size=128, momentum=0.01): 26 | super(GBN, self).__init__() 27 | 28 | self.input_dim = input_dim 29 | self.virtual_batch_size = virtual_batch_size 30 | self.bn = BatchNorm1d(self.input_dim, momentum=momentum) 31 | 32 | def forward(self, x): 33 | chunks = x.chunk(int(np.ceil(x.shape[0] / self.virtual_batch_size)), 0) 34 | res = [self.bn(x_) for x_ in chunks] 35 | 36 | return torch.cat(res, dim=0) 37 | 38 | 39 | class TabNetNoEmbeddings(torch.nn.Module): 40 | def __init__(self, input_dim, output_dim, 41 | n_d=8, n_a=8, 42 | n_steps=3, gamma=1.3, 43 | n_independent=2, n_shared=2, epsilon=1e-15, 44 | virtual_batch_size=128, momentum=0.02, 45 | mask_type="sparsemax"): 46 | """ 47 | Defines main part of the TabNet network without the embedding layers. 48 | 49 | Parameters 50 | ---------- 51 | - input_dim : int 52 | Number of features 53 | - output_dim : int or list of int for multi task classification 54 | Dimension of network output 55 | examples : one for regression, 2 for binary classification etc... 56 | - n_d : int 57 | Dimension of the prediction layer (usually between 4 and 64) 58 | - n_a : int 59 | Dimension of the attention layer (usually between 4 and 64) 60 | - n_steps: int 61 | Number of sucessive steps in the newtork (usually betwenn 3 and 10) 62 | - gamma : float 63 | Float above 1, scaling factor for attention updates (usually betwenn 1.0 to 2.0) 64 | - momentum : float 65 | Float value between 0 and 1 which will be used for momentum in all batch norm 66 | - n_independent : int 67 | Number of independent GLU layer in each GLU block (default 2) 68 | - n_shared : int 69 | Number of independent GLU layer in each GLU block (default 2) 70 | - epsilon: float 71 | Avoid log(0), this should be kept very low 72 | - mask_type: str 73 | Either "sparsemax" or "entmax" : this is the masking function to use 74 | """ 75 | super(TabNetNoEmbeddings, self).__init__() 76 | self.input_dim = input_dim 77 | self.output_dim = output_dim 78 | self.is_multi_task = isinstance(output_dim, list) 79 | self.n_d = n_d 80 | self.n_a = n_a 81 | self.n_steps = n_steps 82 | self.gamma = gamma 83 | self.epsilon = epsilon 84 | self.n_independent = n_independent 85 | self.n_shared = n_shared 86 | self.virtual_batch_size = virtual_batch_size 87 | self.mask_type = mask_type 88 | self.initial_bn = BatchNorm1d(self.input_dim, momentum=0.01) 89 | 90 | if self.n_shared > 0: 91 | shared_feat_transform = torch.nn.ModuleList() 92 | for i in range(self.n_shared): 93 | if i == 0: 94 | shared_feat_transform.append(Linear(self.input_dim, 95 | 2*(n_d + n_a), 96 | bias=False)) 97 | else: 98 | shared_feat_transform.append(Linear(n_d + n_a, 2*(n_d + n_a), bias=False)) 99 | 100 | else: 101 | shared_feat_transform = None 102 | 103 | self.initial_splitter = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform, 104 | n_glu_independent=self.n_independent, 105 | virtual_batch_size=self.virtual_batch_size, 106 | momentum=momentum) 107 | 108 | self.feat_transformers = torch.nn.ModuleList() 109 | self.att_transformers = torch.nn.ModuleList() 110 | 111 | for step in range(n_steps): 112 | transformer = FeatTransformer(self.input_dim, n_d+n_a, shared_feat_transform, 113 | n_glu_independent=self.n_independent, 114 | virtual_batch_size=self.virtual_batch_size, 115 | momentum=momentum) 116 | attention = AttentiveTransformer(n_a, self.input_dim, 117 | virtual_batch_size=self.virtual_batch_size, 118 | momentum=momentum, 119 | mask_type=self.mask_type) 120 | self.feat_transformers.append(transformer) 121 | self.att_transformers.append(attention) 122 | 123 | if self.is_multi_task: 124 | self.multi_task_mappings = torch.nn.ModuleList() 125 | for task_dim in output_dim: 126 | task_mapping = Linear(n_d, task_dim, bias=False) 127 | initialize_non_glu(task_mapping, n_d, task_dim) 128 | self.multi_task_mappings.append(task_mapping) 129 | else: 130 | self.final_mapping = Linear(n_d, output_dim, bias=False) 131 | initialize_non_glu(self.final_mapping, n_d, output_dim) 132 | 133 | def forward(self, x): 134 | res = 0 135 | x = self.initial_bn(x) 136 | 137 | prior = torch.ones(x.shape).to(x.device) 138 | M_loss = 0 139 | att = self.initial_splitter(x)[:, self.n_d:] 140 | 141 | for step in range(self.n_steps): 142 | M = self.att_transformers[step](prior, att) 143 | M_loss += torch.mean(torch.sum(torch.mul(M, torch.log(M+self.epsilon)), 144 | dim=1)) 145 | # update prior 146 | prior = torch.mul(self.gamma - M, prior) 147 | # output 148 | masked_x = torch.mul(M, x) 149 | out = self.feat_transformers[step](masked_x) 150 | d = ReLU()(out[:, :self.n_d]) 151 | res = torch.add(res, d) 152 | # update attention 153 | att = out[:, self.n_d:] 154 | 155 | M_loss /= self.n_steps 156 | 157 | if self.is_multi_task: 158 | # Result will be in list format 159 | out = [] 160 | for task_mapping in self.multi_task_mappings: 161 | out.append(task_mapping(res)) 162 | else: 163 | out = self.final_mapping(res) 164 | return out, res, M_loss 165 | 166 | def forward_masks(self, x): 167 | x = self.initial_bn(x) 168 | 169 | prior = torch.ones(x.shape).to(x.device) 170 | M_explain = torch.zeros(x.shape).to(x.device) 171 | att = self.initial_splitter(x)[:, self.n_d:] 172 | masks = {} 173 | 174 | for step in range(self.n_steps): 175 | M = self.att_transformers[step](prior, att) 176 | masks[step] = M 177 | # update prior 178 | prior = torch.mul(self.gamma - M, prior) 179 | # output 180 | masked_x = torch.mul(M, x) 181 | out = self.feat_transformers[step](masked_x) 182 | d = ReLU()(out[:, :self.n_d]) 183 | # explain 184 | step_importance = torch.sum(d, dim=1) 185 | M_explain += torch.mul(M, step_importance.unsqueeze(dim=1)) 186 | # update attention 187 | att = out[:, self.n_d:] 188 | 189 | return M_explain, masks 190 | 191 | 192 | class TabNet(torch.nn.Module): 193 | def __init__(self, input_dim, output_dim, n_d=8, n_a=8, 194 | n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, 195 | n_independent=2, n_shared=2, epsilon=1e-15, 196 | virtual_batch_size=128, momentum=0.02, device_name='auto', 197 | mask_type="sparsemax"): 198 | """ 199 | Defines TabNet network 200 | 201 | Parameters 202 | ---------- 203 | - input_dim : int 204 | Initial number of features 205 | - output_dim : int 206 | Dimension of network output 207 | examples : one for regression, 2 for binary classification etc... 208 | - n_d : int 209 | Dimension of the prediction layer (usually between 4 and 64) 210 | - n_a : int 211 | Dimension of the attention layer (usually between 4 and 64) 212 | - n_steps: int 213 | Number of sucessive steps in the newtork (usually betwenn 3 and 10) 214 | - gamma : float 215 | Float above 1, scaling factor for attention updates (usually betwenn 1.0 to 2.0) 216 | - cat_idxs : list of int 217 | Index of each categorical column in the dataset 218 | - cat_dims : list of int 219 | Number of categories in each categorical column 220 | - cat_emb_dim : int or list of int 221 | Size of the embedding of categorical features 222 | if int, all categorical features will have same embedding size 223 | if list of int, every corresponding feature will have specific size 224 | - momentum : float 225 | Float value between 0 and 1 which will be used for momentum in all batch norm 226 | - n_independent : int 227 | Number of independent GLU layer in each GLU block (default 2) 228 | - n_shared : int 229 | Number of independent GLU layer in each GLU block (default 2) 230 | - mask_type: str 231 | Either "sparsemax" or "entmax" : this is the masking function to use 232 | - epsilon: float 233 | Avoid log(0), this should be kept very low 234 | """ 235 | super(TabNet, self).__init__() 236 | self.cat_idxs = cat_idxs or [] 237 | self.cat_dims = cat_dims or [] 238 | self.cat_emb_dim = cat_emb_dim 239 | 240 | self.input_dim = input_dim 241 | self.output_dim = output_dim 242 | self.n_d = n_d 243 | self.n_a = n_a 244 | self.n_steps = n_steps 245 | self.gamma = gamma 246 | self.epsilon = epsilon 247 | self.n_independent = n_independent 248 | self.n_shared = n_shared 249 | self.mask_type = mask_type 250 | 251 | if self.n_steps <= 0: 252 | raise ValueError("n_steps should be a positive integer.") 253 | if self.n_independent == 0 and self.n_shared == 0: 254 | raise ValueError("n_shared and n_independant can't be both zero.") 255 | 256 | self.virtual_batch_size = virtual_batch_size 257 | self.embedder = EmbeddingGenerator(input_dim, cat_dims, cat_idxs, cat_emb_dim) 258 | self.post_embed_dim = self.embedder.post_embed_dim 259 | self.tabnet = TabNetNoEmbeddings(self.post_embed_dim, output_dim, n_d, n_a, n_steps, 260 | gamma, n_independent, n_shared, epsilon, 261 | virtual_batch_size, momentum, mask_type) 262 | 263 | def forward(self, x): 264 | x = self.embedder(x) 265 | return self.tabnet(x) 266 | 267 | def forward_masks(self, x): 268 | x = self.embedder(x) 269 | return self.tabnet.forward_masks(x) 270 | 271 | 272 | class AttentiveTransformer(torch.nn.Module): 273 | def __init__(self, input_dim, output_dim, 274 | virtual_batch_size=128, 275 | momentum=0.02, 276 | mask_type="sparsemax"): 277 | """ 278 | Initialize an attention transformer. 279 | 280 | Parameters 281 | ---------- 282 | - input_dim : int 283 | Input size 284 | - output_dim : int 285 | Outpu_size 286 | - momentum : float 287 | Float value between 0 and 1 which will be used for momentum in batch norm 288 | - mask_type: str 289 | Either "sparsemax" or "entmax" : this is the masking function to use 290 | """ 291 | super(AttentiveTransformer, self).__init__() 292 | self.fc = Linear(input_dim, output_dim, bias=False) 293 | initialize_non_glu(self.fc, input_dim, output_dim) 294 | self.bn = GBN(output_dim, virtual_batch_size=virtual_batch_size, 295 | momentum=momentum) 296 | 297 | if mask_type == "sparsemax": 298 | # Sparsemax 299 | self.selector = sparsemax.Sparsemax(dim=-1) 300 | elif mask_type == "entmax": 301 | # Entmax 302 | self.selector = sparsemax.Entmax15(dim=-1) 303 | else: 304 | raise NotImplementedError("Please choose either sparsemax" + 305 | "or entmax as masktype") 306 | 307 | def forward(self, priors, processed_feat): 308 | x = self.fc(processed_feat) 309 | x = self.bn(x) 310 | x = torch.mul(x, priors) 311 | x = self.selector(x) 312 | return x 313 | 314 | 315 | class FeatTransformer(torch.nn.Module): 316 | def __init__(self, input_dim, output_dim, shared_layers, n_glu_independent, 317 | virtual_batch_size=128, momentum=0.02): 318 | super(FeatTransformer, self).__init__() 319 | """ 320 | Initialize a feature transformer. 321 | 322 | Parameters 323 | ---------- 324 | - input_dim : int 325 | Input size 326 | - output_dim : int 327 | Outpu_size 328 | - n_glu_independant 329 | - shared_blocks : torch.nn.ModuleList 330 | The shared block that should be common to every step 331 | - momentum : float 332 | Float value between 0 and 1 which will be used for momentum in batch norm 333 | """ 334 | 335 | params = { 336 | 'n_glu': n_glu_independent, 337 | 'virtual_batch_size': virtual_batch_size, 338 | 'momentum': momentum 339 | } 340 | 341 | if shared_layers is None: 342 | # no shared layers 343 | self.shared = torch.nn.Identity() 344 | is_first = True 345 | else: 346 | self.shared = GLU_Block(input_dim, output_dim, 347 | first=True, 348 | shared_layers=shared_layers, 349 | n_glu=len(shared_layers), 350 | virtual_batch_size=virtual_batch_size, 351 | momentum=momentum) 352 | is_first = False 353 | 354 | if n_glu_independent == 0: 355 | # no independent layers 356 | self.specifics = torch.nn.Identity() 357 | else: 358 | spec_input_dim = input_dim if is_first else output_dim 359 | self.specifics = GLU_Block(spec_input_dim, output_dim, 360 | first=is_first, 361 | **params) 362 | 363 | def forward(self, x): 364 | x = self.shared(x) 365 | x = self.specifics(x) 366 | return x 367 | 368 | 369 | class GLU_Block(torch.nn.Module): 370 | """ 371 | Independant GLU block, specific to each step 372 | """ 373 | 374 | def __init__(self, input_dim, output_dim, n_glu=2, first=False, shared_layers=None, 375 | virtual_batch_size=128, momentum=0.02): 376 | super(GLU_Block, self).__init__() 377 | self.first = first 378 | self.shared_layers = shared_layers 379 | self.n_glu = n_glu 380 | self.glu_layers = torch.nn.ModuleList() 381 | 382 | params = { 383 | 'virtual_batch_size': virtual_batch_size, 384 | 'momentum': momentum 385 | } 386 | 387 | fc = shared_layers[0] if shared_layers else None 388 | self.glu_layers.append(GLU_Layer(input_dim, output_dim, 389 | fc=fc, 390 | **params)) 391 | for glu_id in range(1, self.n_glu): 392 | fc = shared_layers[glu_id] if shared_layers else None 393 | self.glu_layers.append(GLU_Layer(output_dim, output_dim, 394 | fc=fc, 395 | **params)) 396 | 397 | def forward(self, x): 398 | scale = torch.sqrt(torch.FloatTensor([0.5]).to(x.device)) 399 | if self.first: # the first layer of the block has no scale multiplication 400 | x = self.glu_layers[0](x) 401 | layers_left = range(1, self.n_glu) 402 | else: 403 | layers_left = range(self.n_glu) 404 | 405 | for glu_id in layers_left: 406 | x = torch.add(x, self.glu_layers[glu_id](x)) 407 | x = x*scale 408 | return x 409 | 410 | 411 | class GLU_Layer(torch.nn.Module): 412 | def __init__(self, input_dim, output_dim, fc=None, 413 | virtual_batch_size=128, momentum=0.02): 414 | super(GLU_Layer, self).__init__() 415 | 416 | self.output_dim = output_dim 417 | if fc: 418 | self.fc = fc 419 | else: 420 | self.fc = Linear(input_dim, 2*output_dim, bias=False) 421 | initialize_glu(self.fc, input_dim, 2*output_dim) 422 | 423 | self.bn = GBN(2*output_dim, virtual_batch_size=virtual_batch_size, 424 | momentum=momentum) 425 | 426 | def forward(self, x): 427 | x = self.fc(x) 428 | x = self.bn(x) 429 | out = torch.mul(x[:, :self.output_dim], torch.sigmoid(x[:, self.output_dim:])) 430 | return out 431 | 432 | 433 | class EmbeddingGenerator(torch.nn.Module): 434 | """ 435 | Classical embeddings generator 436 | """ 437 | 438 | def __init__(self, input_dim, cat_dims, cat_idxs, cat_emb_dim): 439 | """ This is an embedding module for an entier set of features 440 | 441 | Parameters 442 | ---------- 443 | input_dim : int 444 | Number of features coming as input (number of columns) 445 | cat_dims : list of int 446 | Number of modalities for each categorial features 447 | If the list is empty, no embeddings will be done 448 | cat_idxs : list of int 449 | Positional index for each categorical features in inputs 450 | cat_emb_dim : int or list of int 451 | Embedding dimension for each categorical features 452 | If int, the same embdeding dimension will be used for all categorical features 453 | """ 454 | super(EmbeddingGenerator, self).__init__() 455 | if cat_dims == [] or cat_idxs == []: 456 | self.skip_embedding = True 457 | self.post_embed_dim = input_dim 458 | return 459 | 460 | self.skip_embedding = False 461 | if isinstance(cat_emb_dim, int): 462 | self.cat_emb_dims = [cat_emb_dim]*len(cat_idxs) 463 | else: 464 | self.cat_emb_dims = cat_emb_dim 465 | 466 | # check that all embeddings are provided 467 | if len(self.cat_emb_dims) != len(cat_dims): 468 | msg = """ cat_emb_dim and cat_dims must be lists of same length, got {len(self.cat_emb_dims)} 469 | and {len(cat_dims)}""" 470 | raise ValueError(msg) 471 | self.post_embed_dim = int(input_dim + np.sum(self.cat_emb_dims) - len(self.cat_emb_dims)) 472 | 473 | self.embeddings = torch.nn.ModuleList() 474 | 475 | # Sort dims by cat_idx 476 | sorted_idxs = np.argsort(cat_idxs) 477 | cat_dims = [cat_dims[i] for i in sorted_idxs] 478 | self.cat_emb_dims = [self.cat_emb_dims[i] for i in sorted_idxs] 479 | 480 | for cat_dim, emb_dim in zip(cat_dims, self.cat_emb_dims): 481 | self.embeddings.append(torch.nn.Embedding(cat_dim, emb_dim)) 482 | 483 | # record continuous indices 484 | self.continuous_idx = torch.ones(input_dim, dtype=torch.bool) 485 | self.continuous_idx[cat_idxs] = 0 486 | 487 | def forward(self, x): 488 | """ 489 | Apply embdeddings to inputs 490 | Inputs should be (batch_size, input_dim) 491 | Outputs will be of size (batch_size, self.post_embed_dim) 492 | """ 493 | if self.skip_embedding: 494 | # no embeddings required 495 | return x 496 | 497 | cols = [] 498 | cat_feat_counter = 0 499 | for feat_init_idx, is_continuous in enumerate(self.continuous_idx): 500 | # Enumerate through continuous idx boolean mask to apply embeddings 501 | if is_continuous: 502 | cols.append(x[:, feat_init_idx].float().view(-1, 1)) 503 | else: 504 | cols.append(self.embeddings[cat_feat_counter](x[:, feat_init_idx].long())) 505 | cat_feat_counter += 1 506 | # concat 507 | post_embeddings = torch.cat(cols, dim=1) 508 | return post_embeddings 509 | -------------------------------------------------------------------------------- /models/tabnet/utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from torch.utils.data import DataLoader, WeightedRandomSampler 3 | import torch 4 | import numpy as np 5 | import scipy 6 | 7 | 8 | class TorchDataset(Dataset): 9 | """ 10 | Format for numpy array 11 | 12 | Parameters 13 | ---------- 14 | X: 2D array 15 | The input matrix 16 | y: 2D array 17 | The one-hot encoded target 18 | """ 19 | 20 | def __init__(self, x, y): 21 | self.x = x 22 | self.y = y 23 | 24 | def __len__(self): 25 | return len(self.x) 26 | 27 | def __getitem__(self, index): 28 | x, y = self.x[index], self.y[index] 29 | return x, y 30 | 31 | 32 | class PredictDataset(Dataset): 33 | """ 34 | Format for numpy array 35 | 36 | Parameters 37 | ---------- 38 | X: 2D array 39 | The input matrix 40 | """ 41 | 42 | def __init__(self, x): 43 | self.x = x 44 | 45 | def __len__(self): 46 | return len(self.x) 47 | 48 | def __getitem__(self, index): 49 | x = self.x[index] 50 | return x 51 | 52 | 53 | def create_dataloaders( 54 | X_train, y_train, eval_set, weights, batch_size, num_workers, drop_last 55 | ): 56 | """ 57 | Create dataloaders with or wihtout subsampling depending on weights and balanced. 58 | 59 | Parameters 60 | ---------- 61 | X_train: np.ndarray 62 | Training data 63 | y_train: np.array 64 | Mapped Training targets 65 | X_valid: np.ndarray 66 | Validation data 67 | y_valid: np.array 68 | Mapped Validation targets 69 | weights : either 0, 1, dict or iterable 70 | if 0 (default) : no weights will be applied 71 | if 1 : classification only, will balanced class with inverse frequency 72 | if dict : keys are corresponding class values are sample weights 73 | if iterable : list or np array must be of length equal to nb elements 74 | in the training set 75 | Returns 76 | ------- 77 | train_dataloader, valid_dataloader : torch.DataLoader, torch.DataLoader 78 | Training and validation dataloaders 79 | """ 80 | 81 | if isinstance(weights, int): 82 | if weights == 0: 83 | need_shuffle = True 84 | sampler = None 85 | elif weights == 1: 86 | need_shuffle = False 87 | class_sample_count = np.array( 88 | [len(np.where(y_train == t)[0]) for t in np.unique(y_train)] 89 | ) 90 | 91 | weights = 1.0 / class_sample_count 92 | 93 | samples_weight = np.array([weights[t] for t in y_train]) 94 | 95 | samples_weight = torch.from_numpy(samples_weight) 96 | samples_weight = samples_weight.double() 97 | sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) 98 | else: 99 | raise ValueError("Weights should be either 0, 1, dictionnary or list.") 100 | elif isinstance(weights, dict): 101 | # custom weights per class 102 | need_shuffle = False 103 | samples_weight = np.array([weights[t] for t in y_train]) 104 | sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) 105 | else: 106 | # custom weights 107 | if len(weights) != len(y_train): 108 | raise ValueError("Custom weights should match number of train samples.") 109 | need_shuffle = False 110 | samples_weight = np.array(weights) 111 | sampler = WeightedRandomSampler(samples_weight, len(samples_weight)) 112 | 113 | train_dataloader = DataLoader( 114 | TorchDataset(X_train, y_train), 115 | batch_size=batch_size, 116 | sampler=sampler, 117 | shuffle=need_shuffle, 118 | num_workers=num_workers, 119 | drop_last=drop_last, 120 | pin_memory=True 121 | ) 122 | 123 | valid_dataloaders = [] 124 | for X, y in eval_set: 125 | valid_dataloaders.append( 126 | DataLoader( 127 | TorchDataset(X, y), 128 | batch_size=batch_size, 129 | shuffle=False, 130 | num_workers=num_workers, 131 | pin_memory=True 132 | ) 133 | ) 134 | 135 | return train_dataloader, valid_dataloaders 136 | 137 | 138 | def create_explain_matrix(input_dim, cat_emb_dim, cat_idxs, post_embed_dim): 139 | """ 140 | This is a computational trick. 141 | In order to rapidly sum importances from same embeddings 142 | to the initial index. 143 | 144 | Parameters 145 | ---------- 146 | input_dim: int 147 | Initial input dim 148 | cat_emb_dim : int or list of int 149 | if int : size of embedding for all categorical feature 150 | if list of int : size of embedding for each categorical feature 151 | cat_idxs : list of int 152 | Initial position of categorical features 153 | post_embed_dim : int 154 | Post embedding inputs dimension 155 | 156 | Returns 157 | ------- 158 | reducing_matrix : np.array 159 | Matrix of dim (post_embed_dim, input_dim) to performe reduce 160 | """ 161 | 162 | if isinstance(cat_emb_dim, int): 163 | all_emb_impact = [cat_emb_dim - 1] * len(cat_idxs) 164 | else: 165 | all_emb_impact = [emb_dim - 1 for emb_dim in cat_emb_dim] 166 | 167 | acc_emb = 0 168 | nb_emb = 0 169 | indices_trick = [] 170 | for i in range(input_dim): 171 | if i not in cat_idxs: 172 | indices_trick.append([i + acc_emb]) 173 | else: 174 | indices_trick.append( 175 | range(i + acc_emb, i + acc_emb + all_emb_impact[nb_emb] + 1) 176 | ) 177 | acc_emb += all_emb_impact[nb_emb] 178 | nb_emb += 1 179 | 180 | reducing_matrix = np.zeros((post_embed_dim, input_dim)) 181 | for i, cols in enumerate(indices_trick): 182 | reducing_matrix[cols, i] = 1 183 | 184 | return scipy.sparse.csc_matrix(reducing_matrix) 185 | 186 | 187 | def filter_weights(weights): 188 | """ 189 | This function makes sure that weights are in correct format for 190 | regression and multitask TabNet 191 | 192 | Parameters 193 | ---------- 194 | weights: int, dict or list 195 | Initial weights parameters given by user 196 | Returns 197 | ------- 198 | None : This function will only throw an error if format is wrong 199 | """ 200 | err_msg = "Please provide a list of weights for regression or multitask : " 201 | if isinstance(weights, int): 202 | if weights == 1: 203 | raise ValueError(err_msg + "1 given.") 204 | if isinstance(weights, dict): 205 | raise ValueError(err_msg + "Dict given.") 206 | return 207 | 208 | 209 | def validate_eval_set(eval_set, eval_name, X_train, y_train): 210 | """Check if the shapes of eval_set are compatible with (X_train, y_train). 211 | 212 | Parameters 213 | ---------- 214 | eval_set: list of tuple 215 | List of eval tuple set (X, y). 216 | The last one is used for early stopping 217 | eval_names: list of str 218 | List of eval set names. 219 | X_train: np.ndarray 220 | Train owned products 221 | y_train : np.array 222 | Train targeted products 223 | 224 | Returns 225 | ------- 226 | eval_names : list of str 227 | Validated list of eval_names. 228 | eval_set : list of tuple 229 | Validated list of eval_set. 230 | 231 | """ 232 | eval_name = eval_name or [f"val_{i}" for i in range(len(eval_set))] 233 | 234 | assert len(eval_set) == len( 235 | eval_name 236 | ), "eval_set and eval_name have not the same length" 237 | if len(eval_set) > 0: 238 | assert all( 239 | len(elem) == 2 for elem in eval_set 240 | ), "Each tuple of eval_set need to have two elements" 241 | for name, (X, y) in zip(eval_name, eval_set): 242 | check_nans(X) 243 | check_nans(y) 244 | msg = ( 245 | f"Number of columns is different between X_{name} " 246 | + f"({X.shape[1]}) and X_train ({X_train.shape[1]})" 247 | ) 248 | assert X.shape[1] == X_train.shape[1], msg 249 | if len(y_train.shape) == 2: 250 | msg = ( 251 | f"Number of columns is different between y_{name} " 252 | + f"({y.shape[1]}) and y_train ({y_train.shape[1]})" 253 | ) 254 | assert y.shape[1] == y_train.shape[1], msg 255 | msg = ( 256 | f"You need the same number of rows between X_{name} " 257 | + f"({X.shape[0]}) and y_{name} ({y.shape[0]})" 258 | ) 259 | assert X.shape[0] == y.shape[0], msg 260 | 261 | return eval_name, eval_set 262 | 263 | 264 | def check_nans(array): 265 | if np.isnan(array).any(): 266 | raise ValueError("NaN were found, TabNet does not allow nans.") 267 | if np.isinf(array).any(): 268 | raise ValueError("Infinite values were found, TabNet does not allow inf.") 269 | 270 | 271 | def define_device(device_name): 272 | """ 273 | Define the device to use during training and inference. 274 | If auto it will detect automatically whether to use cuda or cpu 275 | Parameters 276 | ---------- 277 | - device_name : str 278 | Either "auto", "cpu" or "cuda" 279 | Returns 280 | ------- 281 | - str 282 | Either "cpu" or "cuda" 283 | """ 284 | if device_name == "auto": 285 | if torch.cuda.is_available(): 286 | return "cuda" 287 | else: 288 | return "cpu" 289 | else: 290 | return device_name 291 | -------------------------------------------------------------------------------- /preprocessing/extract_feat_with_tta.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os, sys, inspect 4 | current_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 5 | parent_dir = os.path.dirname(current_dir) 6 | sys.path.insert(0, parent_dir) 7 | 8 | 9 | import os.path as osp 10 | import torch 11 | import pandas as pd 12 | from glob import glob 13 | import torch.utils.data as data_utils 14 | 15 | from tqdm import tqdm 16 | 17 | import pickle 18 | import numpy as np 19 | from multiprocessing.pool import Pool 20 | import cv2 21 | from albumentations.pytorch import ToTensorV2 22 | import albumentations as alb 23 | import random 24 | 25 | import queue 26 | import threading 27 | from models.effnet import EffNet 28 | from concurrent.futures import ThreadPoolExecutor 29 | import argparse 30 | 31 | tr_trans = alb.Compose([ 32 | alb.Resize(512, 512), 33 | alb.RandomRotate90(), 34 | alb.RandomBrightnessContrast(), 35 | alb.HueSaturationValue(), 36 | alb.HorizontalFlip(), 37 | alb.VerticalFlip(), 38 | alb.CoarseDropout(max_holes=4), 39 | alb.Normalize(), 40 | ToTensorV2(), 41 | ]) 42 | 43 | val_trans = alb.Compose([ 44 | alb.Resize(512, 512), 45 | alb.Normalize(), 46 | ToTensorV2(), 47 | ]) 48 | 49 | # Test time augmentation 50 | TTA_TIMES = 1 51 | 52 | 53 | class MILPatchDataset(data_utils.Dataset): 54 | def __init__(self, img_fp_list): 55 | self.img_fp_list = img_fp_list 56 | 57 | def __len__(self): 58 | return len(self.img_fp_list) 59 | 60 | def __getitem__(self, idx): 61 | img_fp = self.img_fp_list[idx] 62 | img = cv2.imread(img_fp)[:, :, ::-1] 63 | pid = osp.basename(osp.dirname(img_fp)) 64 | img_bname = osp.basename(img_fp).rsplit('.', 1)[0] 65 | 66 | val_img = val_trans(image=img)['image'] 67 | 68 | tr_img_compose = [] 69 | for i in range(TTA_TIMES): 70 | aug_img = tr_trans(image=img)['image'] 71 | tr_img_compose.append(aug_img) 72 | 73 | tr_ret = torch.stack(tr_img_compose) 74 | 75 | return pid, img_bname, val_img, tr_ret 76 | 77 | 78 | gt_csv_fp = "./data/dataset1_with_wsi.csv" 79 | 80 | 81 | def merge_feat_to_bag(feat_dir): 82 | """ 83 | all patch feature were saved in their pid dir, 84 | this function extract feature from pid dir and merge them into a single bag file 85 | Args: 86 | feat_dir: instance level feature save path 87 | Returns: 88 | """ 89 | dataset_df = pd.read_csv(gt_csv_fp) 90 | pid_label_mp = dict(zip(dataset_df['pid'], dataset_df['target'])) 91 | pid_dirs = glob(osp.join(feat_dir, '*')) 92 | pid_dirs = [d for d in pid_dirs if osp.isdir(d)] 93 | for pid_d in pid_dirs: 94 | feats = [] 95 | f_names = [] 96 | feat_fps = glob((osp.join(pid_d, '*.pkl'))) 97 | for feat_fp in feat_fps: 98 | with open(feat_fp, 'rb') as infile: 99 | try: 100 | feat_data = pickle.load(infile) 101 | except: 102 | continue 103 | feats.append(feat_data) 104 | f_names.append(osp.basename(feat_fp).rsplit('.', 1)[0]) 105 | 106 | feats = np.stack(feats) 107 | pid = osp.basename(pid_d) 108 | save_fp = osp.join(feat_dir, f'{pid}.pkl') 109 | 110 | print(f'Save features bag to: {save_fp} with size: {feats.shape}') 111 | 112 | if pid not in pid_label_mp.keys(): 113 | continue 114 | with open(save_fp, 'wb') as outfile: 115 | pickle.dump({ 116 | 'feat_bag': feats, 117 | 'feat_name': f_names, 118 | 'bag_label': pid_label_mp[pid] 119 | }, 120 | outfile) 121 | 122 | 123 | def save_feat_in_thread(batch_pid, batch_img_bname, batch_val_feat, tr_feat, batch_tr_ret): 124 | for b_idx, (pid, img_bname, val_feat) in enumerate(zip(batch_pid, batch_img_bname, batch_val_feat)): 125 | feat_save_dir = osp.join(save_dir, pid) 126 | os.makedirs(feat_save_dir, exist_ok=True) 127 | feat_save_name = osp.join(feat_save_dir, f'{img_bname}.pkl') 128 | save_dict = {} 129 | save_dict['val'] = val_feat 130 | tr_aug_feat = [] 131 | for aug_time in batch_tr_ret.shape[1]: 132 | tr_aug_feat.append(tr_feat[aug_time][b_idx]) 133 | save_dict['tr'] = tr_aug_feat 134 | with open(feat_save_name, 'wb') as outfile: 135 | pickle.dump(save_dict, outfile) 136 | 137 | 138 | def pred_and_save_with_dataloader(model, img_fp_list, local_rank): 139 | random.seed(42) 140 | model.cuda(local_rank) 141 | model.eval() 142 | executor = ThreadPoolExecutor(max_workers=16) 143 | dl = torch.utils.data.DataLoader( 144 | MILPatchDataset(img_fp_list), 145 | batch_size=1, 146 | num_workers=0, 147 | shuffle=False 148 | ) 149 | for batch in dl: 150 | 151 | batch_pid, batch_img_bname, batch_val_img, batch_tr_ret = batch 152 | 153 | batch_val_img = batch_val_img.cuda(local_rank) 154 | batch_tr_ret = batch_tr_ret.cuda(local_rank) 155 | 156 | with torch.no_grad(): 157 | batch_val_feat = model(batch_val_img) 158 | tr_feat = [] 159 | for aug_time in batch_tr_ret.shape[1]: 160 | tr_feat.append(model(batch_tr_ret[:, aug_time])) 161 | executor.submit(save_feat_in_thread, batch_pid, batch_img_bname, batch_val_feat, tr_feat, batch_tr_ret) 162 | 163 | 164 | 165 | img_queue = queue.Queue(maxsize=128) 166 | img_fp_queue = queue.Queue() 167 | 168 | 169 | def read_worker(): 170 | while True: 171 | imgfp = img_fp_queue.get() 172 | if imgfp is None: 173 | break 174 | try: 175 | img = cv2.imread(imgfp)[:, :, ::-1] 176 | except: 177 | img = np.zeros((512, 512, 3), dtype='uint8') 178 | aug_img_list = [] 179 | for i in range(TTA_TIMES): 180 | aug_img = tr_trans(image=img)['image'] 181 | aug_img_list.append(aug_img) 182 | img_queue.put((imgfp, img, aug_img_list)) 183 | 184 | 185 | def save_pkl_file(feat_save_name, save_dict): 186 | if osp.exists(feat_save_name): 187 | try: 188 | with open(feat_save_name, 'rb') as infile: 189 | old_save_dict = pickle.load(infile) 190 | except: 191 | old_save_dict = {} 192 | else: 193 | old_save_dict = {} 194 | 195 | tr_aug_feat = save_dict['tr'] 196 | if 'tr' in old_save_dict.keys(): 197 | try: 198 | save_dict['tr'] = np.concatenate([tr_aug_feat, old_save_dict['tr']]) 199 | except: 200 | save_dict['tr'] = tr_aug_feat 201 | with open(feat_save_name, 'wb') as outfile: 202 | pickle.dump(save_dict, outfile) 203 | 204 | 205 | def pred_and_save(model, img_fp_list, local_rank): 206 | total = len(img_fp_list) 207 | read_cnt = 0 208 | for img_fp in img_fp_list: 209 | img_fp_queue.put(img_fp) 210 | 211 | threads = [] 212 | num_worker_threads = 36 213 | for _ in range(num_worker_threads): 214 | t = threading.Thread(target=read_worker) 215 | t.start() 216 | threads.append(t) 217 | 218 | random.seed(42) 219 | model.cuda(local_rank) 220 | model.eval() 221 | 222 | executor = ThreadPoolExecutor(max_workers=16) 223 | 224 | while read_cnt < total: 225 | img_fp, img, aug_img_list = img_queue.get() 226 | read_cnt += 1 227 | pid = osp.basename(osp.dirname(img_fp)) 228 | 229 | img_bname = osp.basename(img_fp).rsplit('.', 1)[0] 230 | 231 | feat_save_dir = osp.join(save_dir, pid) 232 | os.makedirs(feat_save_dir, exist_ok=True) 233 | 234 | feat_save_name = osp.join(feat_save_dir, f'{img_bname}.pkl') 235 | save_dict = {} 236 | with torch.no_grad(): 237 | 238 | val_img = val_trans(image=img)['image'].cuda(local_rank) 239 | 240 | val_feat = model(val_img.unsqueeze(0)).detach().cpu().numpy() 241 | save_dict['val'] = val_feat 242 | 243 | tr_aug_feat = [] 244 | for aug_img in aug_img_list: 245 | aug_img = aug_img.cuda(local_rank) 246 | tr_feat = model(aug_img.unsqueeze(0)).detach().cpu().numpy() 247 | 248 | tr_aug_feat.append(tr_feat) 249 | 250 | tr_aug_feat = np.stack(tr_aug_feat) 251 | if 'tr' in save_dict.keys(): 252 | save_dict['tr'] = np.concatenate([tr_aug_feat, save_dict['tr']]) 253 | else: 254 | save_dict['tr'] = tr_aug_feat 255 | 256 | executor.submit(save_pkl_file, feat_save_name, save_dict) 257 | 258 | for i in range(num_worker_threads): 259 | img_fp_queue.put(None) 260 | 261 | for t in threads: 262 | t.join() 263 | 264 | 265 | def main(): 266 | """ 267 | extract patch feature from WSI 268 | save each patch feature into pid dir 269 | then merge them into a single file 270 | """ 271 | print(f'Load dataset...') 272 | 273 | model = EffNet() 274 | 275 | img_fp_list = [] 276 | print(f'Working on {patch_root_dir}') 277 | bag_fp_list = glob(osp.join(patch_root_dir, '*')) 278 | for bag_fp in bag_fp_list: 279 | img_files = glob(osp.join(bag_fp, '*.png')) 280 | img_fp_list.extend(img_files) 281 | 282 | print(f'Len of img {len(img_fp_list)}') 283 | 284 | 285 | img_fp_list = sorted(img_fp_list) 286 | 287 | np.random.shuffle(img_fp_list) 288 | num_processes = 8 289 | num_train_images = len(img_fp_list) 290 | images_per_process = num_train_images / num_processes 291 | 292 | tasks = [] 293 | for num_process in range(1, num_processes + 1): 294 | start_index = (num_process - 1) * images_per_process + 1 295 | end_index = num_process * images_per_process 296 | start_index = int(start_index) 297 | end_index = int(end_index) 298 | tasks.append((model, img_fp_list[start_index:end_index], (num_process - 1) % 4)) 299 | if start_index == end_index: 300 | print("Task #" + str(num_process) + 301 | ": Process slide " + str(start_index)) 302 | else: 303 | print("Task #" + str(num_process) + ": Process slides " + 304 | str(start_index) + " to " + str(end_index)) 305 | 306 | with Pool(num_processes) as p: 307 | for _ in tqdm(p.starmap(pred_and_save, tasks), total=len(tasks)): 308 | pass 309 | 310 | 311 | 312 | if __name__ == '__main__': 313 | load_and_save_path = { 314 | 'x5': [ 315 | 'path_to_WSI_patch_image_files', 316 | 'path_to_saved_features'], 317 | 'x10': [ 318 | 'path_to_WSI_patch_image_files', 319 | 'path_to_saved_features'], 320 | 'x20': [ 321 | 'path_to_WSI_patch_image_files', 322 | 'path_to_saved_features'] 323 | } 324 | 325 | parser = argparse.ArgumentParser() 326 | parser.add_argument('--level', type=str, default='x5') 327 | arg = parser.parse_args() 328 | 329 | select = load_and_save_path[arg.level] 330 | 331 | patch_root_dir, save_dir = select 332 | main() 333 | -------------------------------------------------------------------------------- /preprocessing/merge_patch_feat.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import os 5 | import os.path as osp 6 | import pickle 7 | from glob import glob 8 | from multiprocessing import Pool 9 | 10 | from rich import print 11 | from rich.progress import track 12 | 13 | """ 14 | extract_patch_with_tta.py Multiple features will be extracted for each Patch via TTA 15 | - WSI id1 16 | - patch1 feat.pkl 17 | - patch2 feat.pkl 18 | - WSI id2 19 | - patch1 feat.pkl 20 | - patch2 feat.pkl 21 | This file combines the features extracted by offline within the same WSI into a single file 22 | - WSI id1.pkl 23 | - WSI id2.pkl 24 | """ 25 | 26 | feat_save_dirx10 = 'path_to_extracted_features_at_x5' 27 | merge_feat_save_dirx10 = 'path_to_combined_features_at_x5' 28 | 29 | feat_save_dirx20 = 'path_to_extracted_features_at_x20' 30 | merge_feat_save_dirx20 = 'path_to_combined_features_at_x20' 31 | 32 | feat_save_dirx5 = 'path_to_extracted_features_at_x5' 33 | merge_feat_save_dirx5 = 'path_to_combined_features_at_x5' 34 | 35 | 36 | def merge_wsi_feat(wsi_feat_dir) -> None: 37 | """ 38 | Args: 39 | wsi_feat_dir: 40 | 41 | Returns: 42 | 43 | """ 44 | 45 | 46 | files = glob(osp.join(wsi_feat_dir, '*.pkl')) 47 | 48 | save_obj = [] 49 | for fp in files: 50 | # 51 | try: 52 | with open(fp, 'rb') as infile: 53 | obj = pickle.load(infile) 54 | 55 | # add patch file name 56 | obj['feat_name'] = osp.basename(fp).rsplit('.', 1)[0] 57 | save_obj.append(obj) 58 | except Exception as e: 59 | print(f'Error in {fp} as {e}') 60 | continue 61 | 62 | bname = osp.basename(wsi_feat_dir).lower() # wsi id 63 | save_fp = osp.join(merge_feat_save_dir, f'{bname}.pkl') 64 | with open(save_fp, 'wb') as outfile: 65 | pickle.dump(save_obj, outfile) 66 | 67 | 68 | if __name__ == '__main__': 69 | for feat_save_dir, merge_feat_save_dir in zip( 70 | [feat_save_dirx20, feat_save_dirx10, feat_save_dirx5, ], 71 | [merge_feat_save_dirx20, merge_feat_save_dirx10, merge_feat_save_dirx5] 72 | ): 73 | print(f'Save to {merge_feat_save_dir}') 74 | os.makedirs(merge_feat_save_dir, exist_ok=True) 75 | wsi_dirs = glob(osp.join(feat_save_dir, '*')) 76 | 77 | with Pool(160) as p: 78 | for _ in track(p.imap_unordered(merge_wsi_feat, wsi_dirs), total=len(wsi_dirs)): 79 | pass 80 | -------------------------------------------------------------------------------- /sampledata.7z: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yfzon/Multi-modal-Multi-instance-Learning/c30051bc87c8d40ec93d29ca216ae088816ac6b5/sampledata.7z -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import fnmatch 4 | from rich.logging import RichHandler 5 | 6 | 7 | def setup_logger(distributed_rank=0, filename="log.txt"): 8 | FORMAT = "%(message)s" 9 | logging.basicConfig(level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) 10 | 11 | logger = logging.getLogger("Logger") 12 | logger.setLevel(logging.DEBUG) 13 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 14 | # don't log results for the non-master process 15 | if distributed_rank > 0: 16 | return logger 17 | 18 | fh = logging.FileHandler(filename) 19 | fh.setLevel(logging.DEBUG) 20 | fh.setFormatter(logging.Formatter(fmt)) 21 | logger.addHandler(fh) 22 | 23 | return logger 24 | 25 | 26 | def find_recursive(root_dir, ext='.jpg'): 27 | files = [] 28 | for root, dirnames, filenames in os.walk(root_dir): 29 | for filename in fnmatch.filter(filenames, '*' + ext): 30 | files.append(os.path.join(root, filename)) 31 | return files 32 | 33 | 34 | class AverageMeter(object): 35 | """Computes and stores the average and current value""" 36 | 37 | def __init__(self): 38 | self.initialized = False 39 | self.val = None 40 | self.avg = None 41 | self.sum = None 42 | self.count = None 43 | 44 | def initialize(self, val, weight): 45 | self.val = val 46 | self.avg = val 47 | self.sum = val * weight 48 | self.count = weight 49 | self.initialized = True 50 | 51 | def update(self, val, weight=1): 52 | if not self.initialized: 53 | self.initialize(val, weight) 54 | else: 55 | self.add(val, weight) 56 | 57 | def add(self, val, weight): 58 | self.val = val 59 | self.sum += val * weight 60 | self.count += weight 61 | self.avg = self.sum / self.count 62 | 63 | def value(self): 64 | return self.val 65 | 66 | def average(self): 67 | return self.avg 68 | 69 | 70 | --------------------------------------------------------------------------------