├── .DS_Store ├── AWESOME.md ├── FULLSCRIPT.md ├── INSTALL.md ├── LICENSE ├── README.md ├── TumorGenerated ├── TumorGenerated.py ├── __init__.py └── utils.py ├── datafolds ├── datafold_read.py ├── healthy.json ├── lits.json ├── mix_0.json ├── mix_1.json ├── mix_2.json ├── mix_3.json ├── mix_4.json ├── real_0.json ├── real_1.json ├── real_2.json ├── real_3.json └── real_4.json ├── documents ├── .DS_Store ├── FAQ.md ├── hu2023label.pdf ├── poster.pdf ├── poster_cvpr23.pdf └── slides.pdf ├── external └── surface-distance │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── setup.py │ ├── surface_distance │ ├── __init__.py │ ├── lookup_tables.py │ └── metrics.py │ └── surface_distance_test.py ├── figures ├── Examples.gif └── VisualTuringTest.png ├── main.py ├── monai_trainer.py ├── networks ├── __init__.py ├── basicunetplusplus.py ├── mlp.py ├── patchembedding.py ├── selfattention.py ├── swin3d_unetr.py ├── swin3d_unetrv2.py ├── swin_transformer_3d.py ├── transformerblock.py ├── unetr.py ├── unetr_block.py └── vit.py ├── networks2 ├── __init__.py ├── mlp.py ├── patchembedding.py ├── selfattention.py ├── transformerblock.py ├── unetr.py ├── unetr_block.py └── vit.py ├── optimizers ├── __init__.py └── lr_scheduler.py ├── requirements.txt ├── resnet_flex.py ├── transfer_label.py └── validation.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/.DS_Store -------------------------------------------------------------------------------- /AWESOME.md: -------------------------------------------------------------------------------- 1 | # Awesome Synthetic Tumors in Medical Imaging [![Awesome](https://awesome.re/badge.svg)](https://awesome.re) 2 | 3 | [![MIT License](https://img.shields.io/badge/license-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | 5 | - ❤ We provide a comprehensive list of tumor synthesis in medical imaging. 6 | - 🔥 Welcome to share the paper and code through the [issues](https://github.com/MrGiovanni/SyntheticTumors/issues/1). 7 | 8 | ## Paper 9 | 10 | **Self-improving generative foundation model for synthetic medical image generation and clinical applications** 11 | *Jinzhuo Wang, Kai Wang, Yunfang Yu, Yuxing Lu, et al.* 12 | Nature Medicine | 11 Dec 2024 13 | [paper](https://www.nature.com/articles/s41591-024-03359-y) 14 | [![GitHub stars](https://img.shields.io/github/stars/WithStomach/MINIM.svg?logo=github&label=Stars)](https://github.com/WithStomach/MINIM) 15 | 16 | **GenerateCT: Text-Conditional Generation of 3D Chest CT Volumes** 17 | *Ibrahim Ethem Hamamci, Sezgin Er, et al.* 18 | ECCV | 11 Mar 2024 19 | [paper](https://arxiv.org/pdf/2305.16037v4) 20 | [![GitHub stars](https://img.shields.io/github/stars/ibrahimethemhamamci/GenerateCT.svg?logo=github&label=Stars)](https://github.com/ibrahimethemhamamci/GenerateCT) 21 | 22 | **Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models** 23 | *Nicholas Konz, Yuwen Chen, Haoyu Dong, Maciej A. 
Mazurowski* 24 | MICCAI | 19 Jun 2024 25 | [paper](https://arxiv.org/pdf/2402.05210) 26 | 27 | **FreeTumor: Advance Tumor Segmentation via Large-Scale Tumor Synthesis** 28 | *Linshan Wu, Jiaxin Zhuang, Xuefeng Ni, Hao Chen* 29 | arXiv | 3 Jun 2024 30 | [paper](https://arxiv.org/pdf/2406.01264) 31 | 32 | **Generative Enhancement for 3D Medical Images** 33 | *Lingting Zhu1, Noel Codella, Dongdong Chen, Zhenchao Jin, Lu Yuan, Lequan Yu* 34 | arXiv | 24 May 2024 35 | [paper](https://arxiv.org/pdf/2403.12852) 36 | 37 | **LeFusion: Synthesizing Myocardial Pathology on Cardiac MRI via Lesion-Focus Diffusion Models** 38 | *Hantao Zhang, Jiancheng Yang, Shouhong Wan, Pascal Fua* 39 | arXiv | 21 Mar 2024 40 | [paper](https://arxiv.org/pdf/2403.14066) 41 | [![GitHub stars](https://img.shields.io/github/stars/M3DV/LeFusion.svg?logo=github&label=Stars)](https://github.com/M3DV/LeFusion) 42 | 43 | **From Pixel to Cancer: Cellular Automata in Computed Tomography** 44 | *Yuxiang Lai, Xiaoxi Chen, Angtian Wang, Alan Yuille, Zongwei Zhou* 45 | MICCAI | 13 May 2024 46 | [paper](https://www.cs.jhu.edu/~alanlab/Pubs24/lai2024pixel.pdf) [![GitHub stars](https://img.shields.io/github/stars/MrGiovanni/Pixel2Cancer.svg?logo=github&label=Stars)](https://github.com/MrGiovanni/Pixel2Cancer) 47 | 48 | **Towards Generalizable Tumor Synthesis** 49 | *Qi Chen, Xiaoxi Chen, Haorui Song, Zhiwei Xiong, Alan Yuille, Chen Wei, Zongwei Zhou* 50 | CVPR | 29 Feb 2024 51 | [paper](https://arxiv.org/pdf/2402.19470.pdf) [![GitHub stars](https://img.shields.io/github/stars/MrGiovanni/DiffTumor.svg?logo=github&label=Stars)](https://github.com/MrGiovanni/DiffTumor) 52 | 53 | **Virtual elastography ultrasound via generative adversarial network for breast cancer diagnosis** 54 | *Zhao Yao, Ting Luo, et al.* 55 | Nature Communications | 11 Feb 2023 56 | [paper](https://www.nature.com/articles/s41467-023-36102-1) [![GitHub stars](https://img.shields.io/github/stars/yyyzzzhao/VEUS.svg?logo=github&label=Stars)](https://github.com/yyyzzzhao/VEUS) 57 | 58 | **SynFundus: Generating a synthetic fundus images dataset with millions of samples and multi-disease annotations** 59 | *Fangxin Shang, Jie Fu, Yehui Yang, Lei Ma* 60 | arXiv | 1 Dec 2023 61 | [paper](https://arxiv.org/abs/2312.00377) 62 | 63 | **Human brain responses are modulated when exposed to optimized natural images or synthetically generated images** 64 | *Zijin Gu, Keith Jamison, Mert R. Sabuncu, Amy Kuceyeski* 65 | Communications Biology | 23 Oct 2023 66 | [paper](https://www.nature.com/articles/s42003-023-05440-7) 67 | 68 | **Synthetically Enhanced: Unveiling Synthetic Data's Potential in Medical Imaging Research** 69 | *Bardia Khosravi, Frank Li, Theo Dapamede, Pouria Rouzrokh, et al.* 70 | arXiv | 15 Nov 2023 71 | [paper](https://arxiv.org/abs/2311.09402) 72 | 73 | **Privacy Distillation: Reducing Re-identification Risk of Diffusion Models** 74 | *Fernandez, Virginia and Sanchez, Pedro and Pinaya, Walter Hugo Lopez and Jacenkow, Grzegorz and Tsaftaris, Sotirios A and Cardoso, Jorge* 75 | arXiv | 2 Jun 2023 76 | [paper](https://arxiv.org/abs/2306.01322) 77 | 78 | **SynthSeg: Segmentation of brain MRI scans of any contrast and resolution without retraining** 79 | *Benjamin Billot, Douglas N. Greve, Oula Puonti, Axel Thielscher, Koen Van Leemput, Bruce Fischl, Adrian V. 
Dalca, Juan Eugenio Iglesias, for the ADNI* 80 | Medical Image Analysis | 25 Feb 2023 81 | [paper](https://www.sciencedirect.com/science/article/pii/S1361841523000506) [![GitHub stars](https://img.shields.io/github/stars/BBillot/SynthSeg.svg?logo=github&label=Stars)](https://github.com/BBillot/SynthSeg) 82 | 83 | **Image Synthesis with Disentangled Attributes for Chest X-Ray Nodule Augmentation and Detection** 84 | *Zhenrong Shen, Xi Ouyang, Bin Xiao, Jie-Zhi Cheng, Dinggang Shen, Qian Wang* 85 | MEDIA | Feb 2023 86 | [paper](https://arxiv.org/abs/2207.09389) 87 | 88 | **Unsupervised Liver Tumor Segmentation with Pseudo Anomaly Synthesis** 89 | *Zhaoxiang Zhang, Hanqiu Deng, Xingyu Li* 90 | SASHIMI Workshop on Simulation and Synthesis in Medical Imaging | 7 Oct 2023 91 | [paper](https://link.springer.com/chapter/10.1007/978-3-031-44689-4_9) [![GitHub stars](https://img.shields.io/github/stars/nono-zz/LiTs-Segmentation.svg?logo=github&label=Stars)](https://github.com/nono-zz/LiTs-Segmentation) 92 | 93 | **Early Detection and Localization of Pancreatic Cancer by Label-Free Tumor Synthesis** 94 | Bowen Li, Yu-Cheng Chou, Shuwen Sun, Hualin Qiao, Alan Yuille, Zongwei Zhou 95 | MICCAI Workshop on Big Task Small Data | 30 Sep 2023 96 | [paper](https://browse.arxiv.org/pdf/2308.03008.pdf) [![GitHub stars](https://img.shields.io/github/stars/MrGiovanni/SyntheticTumors.svg?logo=github&label=Stars)](https://github.com/MrGiovanni/SyntheticTumors) 97 | 98 | **How Good Are Synthetic Medical Images? An Empirical Study with Lung Ultrasound** 99 | *Menghan Yu, Sourabh Kulhare, Courosh Mehanian, Charles B Delahunt, Daniel E Shea, Zohreh Laverriere, Ishan Shah, Matthew P Horning* 100 | [paper](https://browse.arxiv.org/pdf/2310.03608.pdf) [![GitHub stars](https://img.shields.io/github/stars/global-health-labs/us-dcgan.svg?logo=github&label=Stars)](https://github.com/global-health-labs/us-dcgan) 101 | 102 | **You Don't Have to Be Perfect to Be Amazing: Unveil the Utility of Synthetic Images** 103 | *Xiaodan Xiang, Federico Felder, Yang Nan, Giorgos Papanastasiou, Walsh Simon, Guang Yang* 104 | MICCAI | 25 May 2023 105 | [paper](https://arxiv.org/abs/2305.18337) [![GitHub stars](https://img.shields.io/github/stars/ayanglab/MedSynAnalyzer.svg?logo=github&label=Stars)](https://github.com/ayanglab/MedSynAnalyzer) 106 | 107 | **Label-Free Liver Tumor Segmentation** 108 | *Qixin Hu, Yixiong Chen, Junfei Xiao, Shuwen Sun, Jieneng Chen, Alan Yuille, Zongwei Zhou* 109 | CVPR | 27 March 2023 110 | [paper](https://arxiv.org/abs/2303.14869) [![GitHub stars](https://img.shields.io/github/stars/MrGiovanni/SyntheticTumors.svg?logo=github&label=Stars)](https://github.com/MrGiovanni/SyntheticTumors) 111 | 112 | **Synthetic data accelerates the development of generalizable learning-based algorithms for X-ray image analysis** 113 | *Cong Gao, Benjamin D. Killeen, Yicheng Hu, Robert B. Grupp, Russell H. 
Taylor, Mehran Armand, Mathias Unberath* 114 | Nature Machine Intelligence | 20 March 2023 115 | [paper](https://www.nature.com/articles/s42256-023-00629-1) | [dataset](https://doi.org/10.7281/T1/2PGJQU) [![GitHub stars](https://img.shields.io/github/stars/arcadelab/SyntheX.svg?logo=github&label=Stars)](https://github.com/arcadelab/SyntheX) 116 | 117 | **Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation** 118 | *Fei Lyu, Mang Ye, Jonathan Frederik Carlsen, Kenny Erleben, Sune Darkner, Pong C Yuen* 119 | IEEE TMI | 02 March 2023 120 | [paper](https://pubmed.ncbi.nlm.nih.gov/36288236) [![GitHub stars](https://img.shields.io/github/stars/FeiLyu/SASSL.svg?logo=github&label=Stars)](https://github.com/FeiLyu/SASSL) 121 | 122 | **Image Turing test and its applications on synthetic chest radiographs by using the progressive growing generative adversarial network** 123 | *Miso Jang, Hyun-jin Bae, Minjee Kim, Seo Young Park, A-yeon Son, Se Jin Choi, Jooae Choe, Hye Young Choi, Hye Jeon Hwang, Han Na Noh, Joon Beom Seo, Sang Min Lee & Namkug Kim* 124 | Scientific Reports | 09 February 2023 125 | [paper](https://www.nature.com/articles/s41598-023-28175-1) 126 | 127 | **Self-supervised Tumor Segmentation with Sim2Real Adaptation** 128 | *Xiaoman Zhang, Weidi Xie, Chaoqin Huang, Ya Zhang, Xin Chen, Qi Tian, Yanfeng Wang* 129 | IEEE BHI | 31 January 2023 130 | [paper](https://ieeexplore.ieee.org/document/10032792) [![GitHub stars](https://img.shields.io/github/stars/xiaoman-zhang/Layer-Decomposition.svg?logo=github&label=Stars)](https://github.com/xiaoman-zhang/Layer-Decomposition) 131 | 132 | **Synthetic Tumors Make AI Segment Tumors Better** 133 | *Qixin Hu, Junfei Xiao, Yixiong Chen, Shuwen Sun, Jie-Neng Chen, Alan Yuille, Zongwei Zhou* 134 | Medical Imaging Meets NeurIPS | 26 October 2022 135 | [paper](https://arxiv.org/pdf/2210.14845.pdf) [![GitHub stars](https://img.shields.io/github/stars/MrGiovanni/SyntheticTumors.svg?logo=github&label=Stars)](https://github.com/MrGiovanni/SyntheticTumors) 136 | 137 | **Pancreatic Image Augmentation Based on Local Region Texture Synthesis for Tumor Segmentation** 138 | *Zihan Wei, Yizhou Chen, Qiu Guan, et al.* 139 | ICANN | 07 September 2022 140 | [paper](https://drive.google.com/file/d/16GQqAv384QQyJ9YhXAIbDnzcfvqjLbEu) 141 | 142 | **Hierarchical Amortized GAN for 3D High Resolution Medical Image Synthesis** 143 | *Li Sun, Junxiang Chen, Yanwu Xu, Mingming Gong, Ke Yu, Kayhan Batmanghelich* 144 | IEEE JBHI | 9 August 2022 145 | [paper](https://pubmed.ncbi.nlm.nih.gov/35522642/) [![GitHub stars](https://img.shields.io/github/stars/batmanlab/HA-GAN.svg?logo=github&label=Stars)](https://github.com/batmanlab/HA-GAN) 146 | 147 | **Anomaly segmentation in retinal images with poisson-blending data augmentation** 148 | *Hualin Wang, Yuhong Zhou, Jiong Zhang, Jianqin Lei, Dongke Sun, Feng Xu, Xiayu Xu* 149 | Medical Image Analysis | 10 July 2022 150 | [paper](https://www.sciencedirect.com/science/article/pii/S1361841522001815) 151 | 152 | **Learning From Synthetic CT Images via Test-Time Training for Liver Tumor Segmentation** 153 | *Fei Lyu, Mang Ye, Andy J. Ma, Terry Cheuk-Fung Yip, Grace Lai-Hung Wong, Pong C. 
Yuen* 154 | IEEE TMI | 11 April 2022 155 | [paper](https://ieeexplore.ieee.org/abstract/document/9754550) [![GitHub stars](https://img.shields.io/github/stars/FeiLyu/SR-TTT.svg?logo=github&label=Stars)](https://github.com/FeiLyu/SR-TTT) 156 | 157 | **Virtual reality for synergistic surgical training and data generation** 158 | *Adnan Munawar, Zhaoshuo Li, et al.* 159 | AE-CAI MICCAI | 15 November 2021 160 | [paper](https://arxiv.org/pdf/2111.08097.pdf) [![GitHub stars](https://img.shields.io/github/stars/LCSR-SICKKIDS/volumetric_drilling.svg?logo=github&label=Stars)](https://github.com/LCSR-SICKKIDS/volumetric_drilling) 161 | 162 | **Label-Free Segmentation of COVID-19 Lesions in Lung CT** 163 | *Qingsong Yao, Li Xiao, Peihang Liu, S Kevin Zhou* 164 | IEEE TMI | 24 March 2021 165 | [paper](https://pubmed.ncbi.nlm.nih.gov/33760731) 166 | 167 | **Free-form tumor synthesis in computed tomography images via richer generative adversarial network** 168 | *Qiangguo Jin, Hui Cui, Changming Sun, Zhaopeng Meng, Ran Su* 169 | Knowledge-Based Systems | 5 January 2021 170 | [paper](https://www.sciencedirect.com/science/article/pii/S0950705121000162) [![GitHub stars](https://img.shields.io/github/stars/qgking/FRGAN.svg?logo=github&label=Stars)](https://github.com/qgking/FRGAN) 171 | 172 | **Time-series Generative Adversarial Networks** 173 | *Jinsung Yoon, Daniel Jarrett, Mihaela van der Schaar* 174 | NeurIPS | 8 December 2019 175 | [paper](https://papers.nips.cc/paper_files/paper/2019/file/c9efe5f26cd17ba6216bbe2a7d26d490-Paper.pdf)[![GitHub stars](https://img.shields.io/github/stars/jsyoon0823/TimeGAN.svg?logo=github&label=Stars)](https://github.com/jsyoon0823/TimeGAN) 176 | 177 | **Generating large labeled data sets for laparoscopic image processing tasks using unpaired image-to-image translation** 178 | *Micha Pfeiffer, Isabel Funke, et al.* 179 | MICCAI | 5 July 2019 180 | [paper](https://arxiv.org/pdf/1907.02882.pdf) | [dataset](http://opencas.dkfz.de/image2image/) 181 | 182 | **Abnormal colon polyp image synthesis using conditional adversarial networks for improved detection performance** 183 | *Younghak Shin; Hemin Ali Qadir; Ilangko Balasingham* 184 | IEEE Access | 27 June 2019 185 | [paper](https://ieeexplore.ieee.org/abstract/document/8478237) 186 | 187 | **Synthesizing diverse lung nodules wherever massively: 3d multiconditional gan-based ct image augmentation for object detection** 188 | *Changhee Han, Yoshiro Kitamura, Akira Kudo, Akimichi Ichinose, Leonardo Rundo, Yujiro Furukawa, Kazuki Umemoto, Yuanzhong Li, Hideki Nakayama* 189 | 3DV | 12 June 2019 190 | [paper](https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8886112) 191 | 192 | 193 | **PATE-GAN: Generating synthetic data with differential privacy guarantees** 194 | *James Jordon, Jinsung Yoon, Mihaela van der Schaar* 195 | ICLR | 21 December 2018 196 | [paper](https://openreview.net/pdf?id=S1zk9iRqF7) [![GitHub stars](https://img.shields.io/github/stars/vanderschaarlab/mlforhealthlabpub.svg?logo=github&label=Stars)](https://github.com/vanderschaarlab/mlforhealthlabpub/tree/main/alg/pategan) 197 | -------------------------------------------------------------------------------- /FULLSCRIPT.md: -------------------------------------------------------------------------------- 1 | ```bash 2 | git clone https://github.com/MrGiovanni/SyntheticTumors.git 3 | cd SyntheticTumors 4 | wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt 5 | 6 | ##### ccvl25 7 | source 
/data/zzhou82/environment/syn/bin/activate 8 | cd /mnt/medical_data/Users/zzhou82/project/SyntheticTumors/ 9 | datapath=/mnt/ccvl15/zzhou82/PublicAbdominalData/ 10 | 11 | ##### ccvl26 12 | source /data/zzhou82/environments/syn/bin/activate 13 | datapath=/mnt/zzhou82/PublicAbdominalData/ 14 | cd /medical_backup/Users/zzhou82/project/SyntheticTumors/ 15 | ``` 16 | 17 | ## 1. Train and evaluate segmentation models using synthetic tumors 18 | 19 | #### UNET (no.pretrain) 20 | ```bash 21 | CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=8 --lrschedule=warmup_cosine --optim_name=adamw --model_name=unet --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12215 --cache_num=120 --val_overlap=0.75 --syn --logdir="runs/synt.no_pretrain.unet" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 22 | 23 | CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=unet --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.unet --save_dir outs 24 | ``` 25 | 26 | #### Swin-UNETR-Base (pretrain) 27 | ```bash 28 | CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12231 --cache_num=120 --val_overlap=0.75 --syn --logdir="runs/synt.pretrain.swin_unetrv2_base" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json --use_pretrained 29 | 30 | CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.pretrain.swin_unetrv2_base --save_dir outs 31 | ``` 32 | 33 | #### Swin-UNETR-Base (no.pretrain) 34 | ```bash 35 | CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12231 --cache_num=120 --val_overlap=0.75 --syn --logdir="runs/synt.no_pretrain.swin_unetrv2_base" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 36 | 37 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.swin_unetrv2_base --save_dir outs 38 | ``` 39 | 40 | #### Swin-UNETR-Small (no.pretrain) 41 | ```bash 42 | CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=small --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12233 --cache_num=120 --val_overlap=0.75 --syn --logdir="runs/synt.no_pretrain.swin_unetrv2_small" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 43 | 44 | CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=small --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.swin_unetrv2_small --save_dir outs 45 | ``` 46 | 47 | #### Swin-UNETR-Tiny (no.pretrain) 48 | ```bash 49 | CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 
--batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=tiny --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12234 --cache_num=120 --val_overlap=0.75 --syn --logdir="runs/synt.no_pretrain.swin_unetrv2_tiny" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 50 | 51 | CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=tiny --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.swin_unetrv2_tiny --save_dir outs 52 | ``` 53 | ## 2. Train and evaluate segmentation models using real tumors (for comparison) 54 | 55 | #### UNET (no.pretrain) 56 | ```bash 57 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python main.py --optim_lr=4e-4 --batch_size=8 --lrschedule=warmup_cosine --optim_name=adamw --model_name=unet --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12231 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/real_fold$fold.no_pretrain.unet" --json_dir datafolds/real_$fold.json; done 58 | 59 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=unet --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/real_fold$fold.no_pretrain.unet" --save_dir outs; done 60 | ``` 61 | 62 | #### Swin-UNETR-Base (pretrain) 63 | ```bash 64 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12240 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/real_fold$fold.pretrain.swin_unetrv2_base" --json_dir datafolds/real_$fold.json --use_pretrained; done 65 | 66 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/real_fold$fold.pretrain.swin_unetrv2_base" --save_dir outs; done 67 | ``` 68 | 69 | #### Swin-UNETR-Base (no.pretrain) 70 | ```bash 71 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12240 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/real_fold$fold.no_pretrain.swin_unetrv2_base" --json_dir datafolds/real_$fold.json; done 72 | 73 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/real_fold$fold.no_pretrain.swin_unetrv2_base" --save_dir outs; done 74 | ``` 75 | 76 | #### Swin-UNETR-Small (no.pretrain) 77 | ```bash 78 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=small --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12241 
--cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/real_fold$fold.no_pretrain.swin_unetrv2_small" --json_dir datafolds/real_$fold.json; done 79 | 80 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=small --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/real_fold$fold.no_pretrain.swin_unetrv2_small" --save_dir outs; done 81 | ``` 82 | 83 | #### Swin-UNETR-Tiny (no.pretrain) 84 | ```bash 85 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=tiny --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12242 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/real_fold$fold.no_pretrain.swin_unetrv2_tiny" --json_dir datafolds/real_$fold.json; done 86 | 87 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=tiny --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/real_fold$fold.no_pretrain.swin_unetrv2_tiny" --save_dir outs; done 88 | ``` 89 | 90 | ## 3. Train segmentation models using both real and synthetic tumors (for comparison) 91 | 92 | #### UNET (no.pretrain) 93 | ```bash 94 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python main.py --optim_lr=4e-4 --batch_size=8 --lrschedule=warmup_cosine --optim_name=adamw --model_name=unet --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12222 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/mix_fold$fold.no_pretrain.unet" --json_dir datafolds/mix_$fold.json; done 95 | 96 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=unet --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/mix_fold$fold.no_pretrain.unet" --save_dir outs; done 97 | ``` 98 | #### Swin-UNETR-Base (pretrain) 99 | ```bash 100 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12242 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/mix_fold$fold.pretrain.swin_unetrv2_base" --json_dir datafolds/mix_$fold.json --use_pretrained; done 101 | 102 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/mix_fold$fold.pretrain.swin_unetrv2_base" --save_dir outs; done 103 | ``` 104 | 105 | #### Swin-UNETR-Base (no.pretrain) 106 | ```bash 107 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12242 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/mix_fold$fold.no_pretrain.swin_unetrv2_base" --json_dir 
datafolds/mix_$fold.json; done 108 | 109 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/mix_fold$fold.no_pretrain.swin_unetrv2_base" --save_dir outs; done 110 | ``` 111 | 112 | #### Swin-UNETR-Small (no.pretrain) 113 | ```bash 114 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=small --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12241 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/mix_fold$fold.no_pretrain.swin_unetrv2_small" --json_dir datafolds/mix_$fold.json; done 115 | 116 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=small --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/mix_fold$fold.no_pretrain.swin_unetrv2_small" --save_dir outs; done 117 | ``` 118 | 119 | #### Swin-UNETR-Tiny (no.pretrain) 120 | ```bash 121 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore main.py --optim_lr=4e-4 --batch_size=4 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=tiny --val_every=100 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12220 --cache_num=120 --val_overlap=0.75 --train_dir $datapath --val_dir $datapath --logdir="runs/mix_fold$fold.no_pretrain.swin_unetrv2_tiny" --json_dir datafolds/mix_$fold.json; done 122 | 123 | for fold in {0..4}; do CUDA_VISIBLE_DEVICES=7 python -W ignore validation.py --model=swin_unetrv2 --swin_type=tiny --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/real_$fold.json --log_dir="runs/mix_fold$fold.no_pretrain.swin_unetrv2_tiny" --save_dir outs; done 124 | ``` 125 | 126 | 127 | ## Generalizability to different segmentation model backbones 128 | 129 | ##### Training 130 | ```bash 131 | # Train models on real tumors 132 | datapath=/mnt/zzhou82/PublicAbdominalData/04_LiTS 133 | for backbone in unetpp segresnet dints; do CUDA_VISIBLE_DEVICES=0 python main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=$backbone --val_every=200 --max_epochs=4000 --save_checkpoint --workers=24 --noamp --cache_num=200 --val_overlap=0.5 --train_dir $datapath --val_dir $datapath --logdir="runs/real.no_pretrain.$backbone" --json_dir datafolds/lits_split.json; done 134 | 135 | # Train models on synthetic tumors 136 | coming soon 137 | ``` 138 | 139 | ##### Testing 140 | ```bash 141 | datapath=/mnt/zzhou82/PublicAbdominalData/04_LiTS 142 | for backbone in unetpp segresnet dints; do CUDA_VISIBLE_DEVICES=1 python -W ignore validation_model.py --model=$backbone --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits_split.json --log_dir="runs/real.no_pretrain.$backbone" --save_dir outs; done 143 | ``` 144 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | #### Dataset 4 | 5 | Please download these datasets and save them to `` (user-defined).
6 | 7 | - 01 [Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge (BTCV)](https://www.synapse.org/#!Synapse:syn3193805/wiki/89480) 8 | - 02 [Pancreas-CT TCIA](https://wiki.cancerimagingarchive.net/display/Public/Pancreas-CT) 9 | - 03 [Combined Healthy Abdominal Organ Segmentation (CHAOS)](https://chaos.grand-challenge.org/) 10 | - 04 [Liver Tumor Segmentation Challenge (LiTS)](https://competitions.codalab.org/competitions/17094) 11 | 12 | ```bash 13 | wget https://www.dropbox.com/s/jnv74utwh99ikus/01_Multi-Atlas_Labeling.tar.gz # 01 Multi-Atlas_Labeling.tar.gz (1.53 GB) 14 | wget https://www.dropbox.com/s/5yzdzb7el9r3o9i/02_TCIA_Pancreas-CT.tar.gz # 02 TCIA_Pancreas-CT.tar.gz (7.51 GB) 15 | wget https://www.dropbox.com/s/lzrhirei2t2vuwg/03_CHAOS.tar.gz # 03 CHAOS.tar.gz (925.3 MB) 16 | wget https://www.dropbox.com/s/2i19kuw7qewzo6q/04_LiTS.tar.gz # 04 LiTS.tar.gz (17.42 GB) 17 | ``` 18 | 19 | #### Dependency 20 | The code is tested with `Python 3.8, PyTorch 1.11`. 21 | ```bash 22 | conda create -n syn python=3.8 23 | source activate syn # or: conda activate syn 24 | cd SyntheticTumors 25 | pip install external/surface-distance 26 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | #### Label 31 | 32 | Our synthesis algorithm requires the labels ``0: background, 1: liver``; you need to transfer the labels before training the AI model. 33 | 34 | ```bash 35 | python transfer_label.py --data_path # --data_path is the user-defined path where the datasets are saved 36 | ``` 37 | Or you can simply download the transferred labels: 38 | ``` 39 | wget https://www.dropbox.com/s/8e3hlza16vor05s/label.zip 40 | ``` 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | © The Johns Hopkins University. This work is openly licensed via CC BY-NC-ND. 2 | Creative Commons Attribution-NonCommercial-NoDerivatives 4.0 3 | International Public License 4 | 5 | By exercising the Licensed Rights (defined below), You accept and agree 6 | to be bound by the terms and conditions of this Creative Commons 7 | Attribution-NonCommercial-NoDerivatives 4.0 International Public 8 | License ("Public License"). To the extent this Public License may be 9 | interpreted as a contract, You are granted the Licensed Rights in 10 | consideration of Your acceptance of these terms and conditions, and the 11 | Licensor grants You such rights in consideration of benefits the 12 | Licensor receives from making the Licensed Material available under 13 | these terms and conditions. 14 | 15 | 16 | Section 1 -- Definitions. 17 | 18 | a. Adapted Material means material subject to Copyright and Similar 19 | Rights that is derived from or based upon the Licensed Material 20 | and in which the Licensed Material is translated, altered, 21 | arranged, transformed, or otherwise modified in a manner requiring 22 | permission under the Copyright and Similar Rights held by the 23 | Licensor. For purposes of this Public License, where the Licensed 24 | Material is a musical work, performance, or sound recording, 25 | Adapted Material is always produced where the Licensed Material is 26 | synched in timed relation with a moving image. 27 | 28 | b.
Copyright and Similar Rights means copyright and/or similar rights 29 | closely related to copyright including, without limitation, 30 | performance, broadcast, sound recording, and Sui Generis Database 31 | Rights, without regard to how the rights are labeled or 32 | categorized. For purposes of this Public License, the rights 33 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 34 | Rights. 35 | 36 | c. Effective Technological Measures means those measures that, in the 37 | absence of proper authority, may not be circumvented under laws 38 | fulfilling obligations under Article 11 of the WIPO Copyright 39 | Treaty adopted on December 20, 1996, and/or similar international 40 | agreements. 41 | 42 | d. Exceptions and Limitations means fair use, fair dealing, and/or 43 | any other exception or limitation to Copyright and Similar Rights 44 | that applies to Your use of the Licensed Material. 45 | 46 | e. Licensed Material means the artistic or literary work, database, 47 | or other material to which the Licensor applied this Public 48 | License. 49 | 50 | f. Licensed Rights means the rights granted to You subject to the 51 | terms and conditions of this Public License, which are limited to 52 | all Copyright and Similar Rights that apply to Your use of the 53 | Licensed Material and that the Licensor has authority to license. 54 | 55 | g. Licensor means the individual(s) or entity(ies) granting rights 56 | under this Public License. 57 | 58 | h. NonCommercial means not primarily intended for or directed towards 59 | commercial advantage or monetary compensation. For purposes of 60 | this Public License, the exchange of the Licensed Material for 61 | other material subject to Copyright and Similar Rights by digital 62 | file-sharing or similar means is NonCommercial provided there is 63 | no payment of monetary compensation in connection with the 64 | exchange. 65 | 66 | i. Share means to provide material to the public by any means or 67 | process that requires permission under the Licensed Rights, such 68 | as reproduction, public display, public performance, distribution, 69 | dissemination, communication, or importation, and to make material 70 | available to the public including in ways that members of the 71 | public may access the material from a place and at a time 72 | individually chosen by them. 73 | 74 | j. Sui Generis Database Rights means rights other than copyright 75 | resulting from Directive 96/9/EC of the European Parliament and of 76 | the Council of 11 March 1996 on the legal protection of databases, 77 | as amended and/or succeeded, as well as other essentially 78 | equivalent rights anywhere in the world. 79 | 80 | k. You means the individual or entity exercising the Licensed Rights 81 | under this Public License. Your has a corresponding meaning. 82 | 83 | 84 | Section 2 -- Scope. 85 | 86 | a. License grant. 87 | 88 | 1. Subject to the terms and conditions of this Public License, 89 | the Licensor hereby grants You a worldwide, royalty-free, 90 | non-sublicensable, non-exclusive, irrevocable license to 91 | exercise the Licensed Rights in the Licensed Material to: 92 | 93 | a. reproduce and Share the Licensed Material, in whole or 94 | in part, for NonCommercial purposes only; and 95 | 96 | b. produce and reproduce, but not Share, Adapted Material 97 | for NonCommercial purposes only. 98 | 99 | 2. Exceptions and Limitations. 
For the avoidance of doubt, where 100 | Exceptions and Limitations apply to Your use, this Public 101 | License does not apply, and You do not need to comply with 102 | its terms and conditions. 103 | 104 | 3. Term. The term of this Public License is specified in Section 105 | 6(a). 106 | 107 | 4. Media and formats; technical modifications allowed. The 108 | Licensor authorizes You to exercise the Licensed Rights in 109 | all media and formats whether now known or hereafter created, 110 | and to make technical modifications necessary to do so. The 111 | Licensor waives and/or agrees not to assert any right or 112 | authority to forbid You from making technical modifications 113 | necessary to exercise the Licensed Rights, including 114 | technical modifications necessary to circumvent Effective 115 | Technological Measures. For purposes of this Public License, 116 | simply making modifications authorized by this Section 2(a) 117 | (4) never produces Adapted Material. 118 | 119 | 5. Downstream recipients. 120 | 121 | a. Offer from the Licensor -- Licensed Material. Every 122 | recipient of the Licensed Material automatically 123 | receives an offer from the Licensor to exercise the 124 | Licensed Rights under the terms and conditions of this 125 | Public License. 126 | 127 | b. No downstream restrictions. You may not offer or impose 128 | any additional or different terms or conditions on, or 129 | apply any Effective Technological Measures to, the 130 | Licensed Material if doing so restricts exercise of the 131 | Licensed Rights by any recipient of the Licensed 132 | Material. 133 | 134 | 6. No endorsement. Nothing in this Public License constitutes or 135 | may be construed as permission to assert or imply that You 136 | are, or that Your use of the Licensed Material is, connected 137 | with, or sponsored, endorsed, or granted official status by, 138 | the Licensor or others designated to receive attribution as 139 | provided in Section 3(a)(1)(A)(i). 140 | 141 | b. Other rights. 142 | 143 | 1. Moral rights, such as the right of integrity, are not 144 | licensed under this Public License, nor are publicity, 145 | privacy, and/or other similar personality rights; however, to 146 | the extent possible, the Licensor waives and/or agrees not to 147 | assert any such rights held by the Licensor to the limited 148 | extent necessary to allow You to exercise the Licensed 149 | Rights, but not otherwise. 150 | 151 | 2. Patent and trademark rights are not licensed under this 152 | Public License. 153 | 154 | 3. To the extent possible, the Licensor waives any right to 155 | collect royalties from You for the exercise of the Licensed 156 | Rights, whether directly or through a collecting society 157 | under any voluntary or waivable statutory or compulsory 158 | licensing scheme. In all other cases the Licensor expressly 159 | reserves any right to collect such royalties, including when 160 | the Licensed Material is used other than for NonCommercial 161 | purposes. 162 | 163 | 164 | Section 3 -- License Conditions. 165 | 166 | Your exercise of the Licensed Rights is expressly made subject to the 167 | following conditions. 168 | 169 | a. Attribution. 170 | 171 | 1. If You Share the Licensed Material, You must: 172 | 173 | a. retain the following if it is supplied by the Licensor 174 | with the Licensed Material: 175 | 176 | i. 
identification of the creator(s) of the Licensed 177 | Material and any others designated to receive 178 | attribution, in any reasonable manner requested by 179 | the Licensor (including by pseudonym if 180 | designated); 181 | 182 | ii. a copyright notice; 183 | 184 | iii. a notice that refers to this Public License; 185 | 186 | iv. a notice that refers to the disclaimer of 187 | warranties; 188 | 189 | v. a URI or hyperlink to the Licensed Material to the 190 | extent reasonably practicable; 191 | 192 | b. indicate if You modified the Licensed Material and 193 | retain an indication of any previous modifications; and 194 | 195 | c. indicate the Licensed Material is licensed under this 196 | Public License, and include the text of, or the URI or 197 | hyperlink to, this Public License. 198 | 199 | For the avoidance of doubt, You do not have permission under 200 | this Public License to Share Adapted Material. 201 | 202 | 2. You may satisfy the conditions in Section 3(a)(1) in any 203 | reasonable manner based on the medium, means, and context in 204 | which You Share the Licensed Material. For example, it may be 205 | reasonable to satisfy the conditions by providing a URI or 206 | hyperlink to a resource that includes the required 207 | information. 208 | 209 | 3. If requested by the Licensor, You must remove any of the 210 | information required by Section 3(a)(1)(A) to the extent 211 | reasonably practicable. 212 | 213 | 214 | Section 4 -- Sui Generis Database Rights. 215 | 216 | Where the Licensed Rights include Sui Generis Database Rights that 217 | apply to Your use of the Licensed Material: 218 | 219 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 220 | to extract, reuse, reproduce, and Share all or a substantial 221 | portion of the contents of the database for NonCommercial purposes 222 | only and provided You do not Share Adapted Material; 223 | 224 | b. if You include all or a substantial portion of the database 225 | contents in a database in which You have Sui Generis Database 226 | Rights, then the database in which You have Sui Generis Database 227 | Rights (but not its individual contents) is Adapted Material; and 228 | 229 | c. You must comply with the conditions in Section 3(a) if You Share 230 | all or a substantial portion of the contents of the database. 231 | 232 | For the avoidance of doubt, this Section 4 supplements and does not 233 | replace Your obligations under this Public License where the Licensed 234 | Rights include other Copyright and Similar Rights. 235 | 236 | 237 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 238 | 239 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 240 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 241 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 242 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 243 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 244 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 245 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 246 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 247 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 248 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 249 | 250 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 251 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 252 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 253 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 254 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 255 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 256 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 257 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 258 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 259 | 260 | c. The disclaimer of warranties and limitation of liability provided 261 | above shall be interpreted in a manner that, to the extent 262 | possible, most closely approximates an absolute disclaimer and 263 | waiver of all liability. 264 | 265 | 266 | Section 6 -- Term and Termination. 267 | 268 | a. This Public License applies for the term of the Copyright and 269 | Similar Rights licensed here. However, if You fail to comply with 270 | this Public License, then Your rights under this Public License 271 | terminate automatically. 272 | 273 | b. Where Your right to use the Licensed Material has terminated under 274 | Section 6(a), it reinstates: 275 | 276 | 1. automatically as of the date the violation is cured, provided 277 | it is cured within 30 days of Your discovery of the 278 | violation; or 279 | 280 | 2. upon express reinstatement by the Licensor. 281 | 282 | For the avoidance of doubt, this Section 6(b) does not affect any 283 | right the Licensor may have to seek remedies for Your violations 284 | of this Public License. 285 | 286 | c. For the avoidance of doubt, the Licensor may also offer the 287 | Licensed Material under separate terms or conditions or stop 288 | distributing the Licensed Material at any time; however, doing so 289 | will not terminate this Public License. 290 | 291 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 292 | License. 293 | 294 | 295 | Section 7 -- Other Terms and Conditions. 296 | 297 | a. The Licensor shall not be bound by any additional or different 298 | terms or conditions communicated by You unless expressly agreed. 299 | 300 | b. Any arrangements, understandings, or agreements regarding the 301 | Licensed Material not stated herein are separate from and 302 | independent of the terms and conditions of this Public License. 303 | 304 | 305 | Section 8 -- Interpretation. 306 | 307 | a. For the avoidance of doubt, this Public License does not, and 308 | shall not be interpreted to, reduce, limit, restrict, or impose 309 | conditions on any use of the Licensed Material that could lawfully 310 | be made without permission under this Public License. 311 | 312 | b. To the extent possible, if any provision of this Public License is 313 | deemed unenforceable, it shall be automatically reformed to the 314 | minimum extent necessary to make it enforceable. If the provision 315 | cannot be reformed, it shall be severed from this Public License 316 | without affecting the enforceability of the remaining terms and 317 | conditions. 318 | 319 | c. No term or condition of this Public License will be waived and no 320 | failure to comply consented to unless expressly agreed to by the 321 | Licensor. 322 | 323 | d. 
Nothing in this Public License constitutes or may be interpreted 324 | as a limitation upon, or waiver of, any privileges and immunities 325 | that apply to the Licensor or You, including from the legal 326 | processes of any jurisdiction or authority. 327 | 328 | 329 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

SynTumor

2 | 3 |
4 | 5 | ![visitors](https://visitor-badge.laobi.icu/badge?page_id=MrGiovanni/SyntheticTumors) 6 | [![GitHub Repo stars](https://img.shields.io/github/stars/MrGiovanni/SyntheticTumors?style=social)](https://github.com/MrGiovanni/SyntheticTumors/stargazers) 7 | 8 | Follow on Twitter 9 |
10 | **Subscribe to our mailing list: https://groups.google.com/u/2/g/bodymaps** 11 | 12 |
13 | 14 | This repository provides extensive examples of synthetic liver tumors generated by our novel strategies. See whether you can tell which tumors are real and which are synthetic. More importantly, our synthetic tumors can be used to train AI models, and models trained this way have proven to achieve similar (actually, *better*) performance on real-tumor segmentation than models trained on real tumors. 15 | 16 | **Amazing**, right? 17 | 18 |

19 |

20 | 21 | ## Paper 22 | 23 | Label-Free Liver Tumor Segmentation
24 | [Qixin Hu](https://scholar.google.com/citations?user=EqD5GP8AAAAJ&hl=en)1, [Yixiong Chen](https://scholar.google.com/citations?hl=en&user=bVHYVXQAAAAJ)2, [Junfei Xiao](https://lambert-x.github.io/)3, Shuwen Sun4, [Jieneng Chen](https://scholar.google.com/citations?hl=en&user=yLYj88sAAAAJ)3, [Alan L. Yuille](https://www.cs.jhu.edu/~ayuille/)3, and [Zongwei Zhou](https://www.zongweiz.com/)3,*
25 | 1 Huazhong University of Science and Technology,
26 | 2 The Chinese University of Hong Kong -- Shenzhen,
27 | 3 Johns Hopkins University,
28 | 4 The First Affiliated Hospital of Nanjing Medical University
29 | CVPR, 2023
30 | [paper](https://arxiv.org/pdf/2303.14869.pdf) | [code](https://github.com/MrGiovanni/SyntheticTumors) | [talk](https://www.youtube.com/watch?v=Alqr3vQSDro) (by Alan Yuille) | [talk](https://youtu.be/DhZzAp7gxxw) | [slides](https://github.com/MrGiovanni/SyntheticTumors/blob/main/documents/slides_cvpr23.pdf) | [poster](https://github.com/MrGiovanni/SyntheticTumors/blob/main/documents/poster_cvpr23.pdf) 31 | 32 | Synthetic Tumors Make AI Segment Tumors Better
33 | [Qixin Hu](https://scholar.google.com/citations?user=EqD5GP8AAAAJ&hl=en)1, [Junfei Xiao](https://lambert-x.github.io/)2, [Yixiong Chen](https://scholar.google.com/citations?hl=en&user=bVHYVXQAAAAJ)3, Shuwen Sun4, [Jieneng Chen](https://scholar.google.com/citations?hl=en&user=yLYj88sAAAAJ)2, [Alan L. Yuille](https://www.cs.jhu.edu/~ayuille/)2, and [Zongwei Zhou](https://www.zongweiz.com/)2,*
34 | 1 Huazhong University of Science and Technology,
35 | 2 Johns Hopkins University,
36 | 3 The Chinese University of Hong Kong -- Shenzhen,
37 | 4 The First Affiliated Hospital of Nanjing Medical University
38 | Medical Imaging Meets NeurIPS, 2022
39 | [paper](https://arxiv.org/pdf/2210.14845.pdf) | [code](https://github.com/MrGiovanni/SyntheticTumors) | [slides](https://github.com/MrGiovanni/SyntheticTumors/blob/main/documents/slides.pdf) | [poster](https://github.com/MrGiovanni/SyntheticTumors/blob/main/documents/poster.pdf) | demo | [talk](https://www.youtube.com/watch?v=bJpI9tCTsuA) 40 | 41 | Early Detection and Localization of Pancreatic Cancer by Label-Free Tumor Synthesis
42 | [Bowen Li](https://scholar.google.com/citations?user=UfINwO0AAAAJ&hl=en)1, [Yu-Cheng Chou](https://sites.google.com/view/yu-cheng-chou)1, Shuwen Sun2, Hualin Qiao3, [Alan L. Yuille](https://www.cs.jhu.edu/~ayuille/)1, and [Zongwei Zhou](https://www.zongweiz.com/)1,*
43 | 1 Johns Hopkins University,
44 | 2 The First Affiliated Hospital of Nanjing Medical University,
45 | 3 Sepax Technologies
46 | Big Task Small Data, 1001-AI, MICCAI Workshop, 2023
47 | [paper](https://arxiv.org/pdf/2308.03008.pdf) | [code](https://github.com/MrGiovanni/SyntheticTumors) 48 | 49 | **We have documented common questions for the paper in [Frequently Asked Questions (FAQ)](documents/FAQ.md).** 50 | 51 | **We have also provided a list of publications related to tumor synthesis in [Awesome Synthetic Tumors](AWESOME.md) [![Awesome](https://awesome.re/badge.svg)](https://awesome.re).** 52 | 53 | ## Model 54 | 55 | | Tumor | Model | Pre-trained? | Download | 56 | | ---- | ---- | ---- | ---- | 57 | | real | unet | no | [link](https://www.dropbox.com/s/8jfu22zgz4a8qjk/model.pt) | 58 | | real | swin_unetrv2_base | yes | [link](https://www.dropbox.com/s/cmh3uvwtfd3jlvp/model.pt) | 59 | | real | swin_unetrv2_base | no | [link](https://www.dropbox.com/s/vf3se1yns18t7qm/model.pt) | 60 | | real | swin_unetrv2_small | no | [link](https://www.dropbox.com/s/337uz6484zyzjty/model.pt) | 61 | | real | swin_unetrv2_tiny | no | [link](https://www.dropbox.com/s/leh04qh5hvfq9i5/model.pt) | 62 | | synt | unet | no | [link](https://www.dropbox.com/s/gb01obtxlbktmrp/model.pt) | 63 | | synt | swin_unetrv2_base | yes | [link](https://www.dropbox.com/s/pxvb6qfmeaha2va/model.pt) | 64 | | synt | swin_unetrv2_base | no | [link](https://www.dropbox.com/s/idxwss85bmx3ejo/model.pt) | 65 | | synt | swin_unetrv2_small | no | [link](https://www.dropbox.com/s/nkb64yo4jmscoy6/model.pt) | 66 | | synt | swin_unetrv2_tiny | no | [link](https://www.dropbox.com/s/ddej5duj03ioh49/model.pt) | 67 | 68 | **Use the following command to download everything.** 69 | ```bash 70 | wget https://www.dropbox.com/s/jys4tt2ttmr7ig1/runs.tar.gz 71 | tar -xzvf runs.tar.gz 72 | ``` 73 | 74 | ## 0. Installation 75 | 76 | ```bash 77 | git clone https://github.com/MrGiovanni/SyntheticTumors.git 78 | cd SyntheticTumors 79 | wget https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/model_swinvit.pt 80 | ``` 81 | 82 | See [installation instructions](https://github.com/MrGiovanni/SyntheticTumors/blob/main/INSTALL.md). 83 | 84 | ## 1. 
Train segmentation models using synthetic tumors 85 | 86 | ``` 87 | datapath=/mnt/zzhou82/PublicAbdominalData/ 88 | 89 | # UNET (no.pretrain) 90 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=unet --val_every=200 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12235 --cache_num=200 --val_overlap=0.5 --syn --logdir="runs/synt.no_pretrain.unet" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 91 | # Swin-UNETR-Base (pretrain) 92 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=200 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12231 --cache_num=200 --val_overlap=0.5 --syn --logdir="runs/synt.pretrain.swin_unetrv2_base" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json --use_pretrained 93 | # Swin-UNETR-Base (no.pretrain) 94 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=200 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12231 --cache_num=200 --val_overlap=0.5 --syn --logdir="runs/synt.no_pretrain.swin_unetrv2_base" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 95 | # Swin-UNETR-Small (no.pretrain) 96 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=small --val_every=200 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12233 --cache_num=200 --val_overlap=0.5 --syn --logdir="runs/synt.no_pretrain.swin_unetrv2_small" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 97 | # Swin-UNETR-Tiny (no.pretrain) 98 | CUDA_VISIBLE_DEVICES=0 python -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=tiny --val_every=200 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12234 --cache_num=200 --val_overlap=0.5 --syn --logdir="runs/synt.no_pretrain.swin_unetrv2_tiny" --train_dir $datapath --val_dir $datapath --json_dir datafolds/healthy.json 99 | ``` 100 | 101 | ## 2. 
Train segmentation models using real tumors (for comparison) 102 | 103 | ``` 104 | datapath=/mnt/zzhou82/PublicAbdominalData/ 105 | 106 | # UNET (no.pretrain) 107 | CUDA_VISIBLE_DEVICES=0 python -W ignore -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=unet --val_every=200 --val_overlap=0.5 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12235 --cache_num=200 --logdir="runs/real.no_pretrain.unet" --train_dir $datapath --val_dir $datapath --json_dir datafolds/lits.json 108 | # Swin-UNETR-Base (pretrain) 109 | CUDA_VISIBLE_DEVICES=0 python -W ignore -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=200 --val_overlap=0.5 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12231 --cache_num=200 --logdir="runs/real.pretrain.swin_unetrv2_base" --train_dir $datapath --val_dir $datapath --json_dir datafolds/lits.json --use_pretrained 110 | # Swin-UNETR-Base (no.pretrain) 111 | CUDA_VISIBLE_DEVICES=0 python -W ignore -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=base --val_every=200 --val_overlap=0.5 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12232 --cache_num=200 --logdir="runs/real.no_pretrain.swin_unetrv2_base" --train_dir $datapath --val_dir $datapath --json_dir datafolds/lits.json 112 | # Swin-UNETR-Small (no.pretrain) 113 | CUDA_VISIBLE_DEVICES=0 python -W ignore -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=small --val_every=200 --val_overlap=0.5 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12233 --cache_num=200 --logdir="runs/real.no_pretrain.swin_unetrv2_small" --train_dir $datapath --val_dir $datapath --json_dir datafolds/lits.json 114 | # Swin-UNETR-Tiny (no.pretrain) 115 | CUDA_VISIBLE_DEVICES=0 python -W ignore -W ignore main.py --optim_lr=4e-4 --batch_size=2 --lrschedule=warmup_cosine --optim_name=adamw --model_name=swin_unetrv2 --swin_type=tiny --val_every=200 --val_overlap=0.5 --max_epochs=4000 --save_checkpoint --workers=2 --noamp --distributed --dist-url=tcp://127.0.0.1:12234 --cache_num=200 --logdir="runs/real.no_pretrain.swin_unetrv2_tiny" --train_dir $datapath --val_dir $datapath --json_dir datafolds/lits.json 116 | ``` 117 | 118 | ## 3. 
Evaluation 119 | 120 | #### AI model trained by synthetic tumors 121 | 122 | ``` 123 | datapath=/mnt/zzhou82/PublicAbdominalData/ 124 | 125 | # UNET (no.pretrain) 126 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=unet --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.unet --save_dir out 127 | # Swin-UNETR-Base (pretrain) 128 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.pretrain.swin_unetrv2_base --save_dir out 129 | # Swin-UNETR-Base (no.pretrain) 130 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.swin_unetrv2_base --save_dir out 131 | # Swin-UNETR-Small (no.pretrain) 132 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=small --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.swin_unetrv2_small --save_dir out 133 | # Swin-UNETR-Tiny (no.pretrain) 134 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=tiny --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/synt.no_pretrain.swin_unetrv2_tiny --save_dir out 135 | ``` 136 | 137 | #### AI model trained by real tumors 138 | 139 | ``` 140 | datapath=/mnt/zzhou82/PublicAbdominalData/ 141 | 142 | # UNET (no.pretrain) 143 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=unet --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/real.no_pretrain.unet --save_dir out 144 | # Swin-UNETR-Base (pretrain) 145 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/real.pretrain.swin_unetrv2_base --save_dir out 146 | # Swin-UNETR-Base (no.pretrain) 147 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=base --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/real.no_pretrain.swin_unetrv2_base --save_dir out 148 | # Swin-UNETR-Small (no.pretrain) 149 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=small --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/real.no_pretrain.swin_unetrv2_small --save_dir out 150 | # Swin-UNETR-Tiny (no.pretrain) 151 | CUDA_VISIBLE_DEVICES=0 python -W ignore validation.py --model=swin_unetrv2 --swin_type=tiny --val_overlap=0.75 --val_dir $datapath --json_dir datafolds/lits.json --log_dir runs/real.no_pretrain.swin_unetrv2_tiny --save_dir out 152 | ``` 153 | 154 | 155 | ## TODO 156 | 157 | - [x] Upload the paper to arxiv 158 | - [ ] Make a video about Visual Turing Test (will appear in YouTube) 159 | - [ ] Make an online app for Visual Turing Test 160 | - [x] Apply for a US patent 161 | - [ ] Upload the evaluation code for small tumors 162 | - [ ] Upload the evaluation code for the false-positive study 163 | - [ ] Make a Jupyter Notebook for tumor synthesis 164 | 165 | ## Citation 166 | 167 | ``` 168 | @inproceedings{hu2023label, 169 | title={Label-free liver tumor segmentation}, 170 | author={Hu, Qixin and Chen, Yixiong and Xiao, Junfei and Sun, Shuwen and Chen, Jieneng and Yuille, Alan L and Zhou, Zongwei}, 
171 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 172 | pages={7422--7432}, 173 | year={2023} 174 | } 175 | 176 | @article{hu2022synthetic, 177 | title={Synthetic Tumors Make AI Segment Tumors Better}, 178 | author={Hu, Qixin and Xiao, Junfei and Chen, Yixiong and Sun, Shuwen and Chen, Jie-Neng and Yuille, Alan and Zhou, Zongwei}, 179 | journal={NeurIPS Workshop on Medical Imaging meets NeurIPS}, 180 | year={2022} 181 | } 182 | 183 | @article{li2023early, 184 | title={Early Detection and Localization of Pancreatic Cancer by Label-Free Tumor Synthesis}, 185 | author={Li, Bowen and Chou, Yu-Cheng and Sun, Shuwen and Qiao, Hualin and Yuille, Alan and Zhou, Zongwei}, 186 | journal={arXiv preprint arXiv:2308.03008}, 187 | year={2023} 188 | } 189 | ``` 190 | 191 | ## Acknowledgement 192 | 193 | This work was supported by the Lustgarten Foundation for Pancreatic Cancer Research and the McGovern Foundation. The segmentation backbone is based on [Swin UNETR](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_btcv_segmentation_3d.ipynb); we appreciate the effort of the [MONAI Team](https://monai.io/) to provide and maintain open-source code to the community. We thank Camille Torrico and Alexa Delaney for improving the writing of this paper. Paper content is covered by patents pending. 194 | -------------------------------------------------------------------------------- /TumorGenerated/TumorGenerated.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Hashable, Mapping, Dict 3 | 4 | from monai.config import KeysCollection 5 | from monai.config.type_definitions import NdarrayOrTensor 6 | from monai.transforms.transform import MapTransform, RandomizableTransform 7 | 8 | from .utils import SynthesisTumor, get_predefined_texture 9 | import numpy as np 10 | 11 | class TumorGenerated(RandomizableTransform, MapTransform): 12 | def __init__(self, 13 | keys: KeysCollection, 14 | prob: float = 0.1, 15 | tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], 16 | allow_missing_keys: bool = False 17 | ) -> None: 18 | MapTransform.__init__(self, keys, allow_missing_keys) 19 | RandomizableTransform.__init__(self, prob) 20 | random.seed(0) 21 | np.random.seed(0) 22 | 23 | self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] 24 | 25 | assert len(tumor_prob) == 5 26 | self.tumor_prob = np.array(tumor_prob) 27 | # texture shape: 420, 300, 320 28 | # self.textures = pre_define 10 texture 29 | self.textures = [] 30 | sigma_as = [3, 6, 9, 12, 15] 31 | sigma_bs = [4, 7] 32 | predefined_texture_shape = (420, 300, 320) 33 | for sigma_a in sigma_as: 34 | for sigma_b in sigma_bs: 35 | texture = get_predefined_texture(predefined_texture_shape, sigma_a, sigma_b) 36 | self.textures.append(texture) 37 | print("All predefined texture have generated.") 38 | 39 | 40 | def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: 41 | d = dict(data) 42 | self.randomize(None) 43 | 44 | if self._do_transform and (np.max(d['label']) <= 1): 45 | tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) 46 | texture = random.choice(self.textures) 47 | d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type, texture) 48 | # print(tumor_type, d['image'].shape, np.max(d['label'])) 49 | return d 50 | -------------------------------------------------------------------------------- /TumorGenerated/__init__.py: 
-------------------------------------------------------------------------------- 1 | ### Online Version TumorGeneration ### 2 | 3 | from .TumorGenerated import TumorGenerated 4 | 5 | from .utils import SynthesisTumor -------------------------------------------------------------------------------- /datafolds/datafold_read.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | def datafold_read(datalist, basedir, fold=0, key='training'): 6 | ''' 7 | 8 | :param datalist: json file (filename only) with the list of data 9 | :param basedir: directory of json file 10 | :param fold: which fold to hold out as validation (entries whose 'fold' value equals this index go to the validation list) 11 | :param key: usually 'training', but can be 'validation' or 'testing' to get the list of data without labels (used in challenges) 12 | :return: two lists (training, validation) 13 | ''' 14 | 15 | with open(datalist) as f: 16 | json_data = json.load(f) 17 | 18 | json_data = json_data[key] 19 | 20 | for d in json_data: 21 | for k, v in d.items(): 22 | if isinstance(d[k], list): 23 | d[k] = [os.path.join(basedir, iv) for iv in d[k]] 24 | elif isinstance(d[k], str): 25 | d[k] = os.path.join(basedir, d[k]) if len(d[k]) > 0 else d[k] 26 | 27 | tr=[] 28 | val=[] 29 | for d in json_data: 30 | if 'fold' in d and d['fold'] == fold: 31 | val.append(d) 32 | else: 33 | tr.append(d) 34 | 35 | return tr, val 36 | -------------------------------------------------------------------------------- /documents/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/documents/.DS_Store -------------------------------------------------------------------------------- /documents/FAQ.md: -------------------------------------------------------------------------------- 1 | # Frequently Asked Questions (FAQ) 2 | 3 | - **The method's generalizability?** 4 | 5 | It was validated in two ways. (1) The model was trained on three datasets and tested on LiTS and three other datasets. Our setting is much more comprehensive than the standard AI evaluation protocol ([Bilic et al., MedIA'23](https://www.sciencedirect.com/science/article/pii/S1361841522003085)). (2) The hyper-parameters of the tumor synthesis were determined on the three datasets (i.e., CHAOS, BTCV, Pancreas-CT), and then the synthetic tumors, created by the same set of hyper-parameters, were mixed with the real tumors in LiTS (no overlapping with the three datasets) for the visual assessment. Therefore, the proposed method should generalize to other CT scans. 6 | 7 | - **Compare tumor segmentation with LiTS leaderboard performance** 8 | 9 | In Table 5, the baseline results (5-fold cross-validation on real tumors) were provided by the winner of the MSD challenge ([Tang et al., CVPR'22](https://openaccess.thecvf.com/content/CVPR2022/papers/Tang_Self-Supervised_Pre-Training_of_Swin_Transformers_for_3D_Medical_Image_Analysis_CVPR_2022_paper.pdf)), whose Task03-Liver was LiTS. Our synthetic tumors achieved similar performance to Tang et al. It is unfair to directly compare results between cross-validation and the leaderboard because (1) the data used for testing are different and (2) the leaderboard performance is heavily overfitted by cherry-picking predictions based on the case-by-case DSC scores. 10 | 11 | - **The generalization to different architectures?** 12 | 13 | This has been tested on the classic U-Net architecture. 
U-Net trained by synthetic tumors achieved a DSC of 57.3%; U-Net trained by real tumors obtained a DSC of 56.4%. The conclusion is consistent with our paper (using Swin UNETR). 14 | 15 | - **Clinical knowledge for new tumor analysis tasks** 16 | 17 | It is needed in two aspects. First, the tumor development process. In our work, hepatocellular carcinoma (HCC) has a predominantly expansive (rather than invasive) growth pattern. HCC generally spreads by infiltrating adjacent liver parenchyma or vessels, reflecting unclear tumor margins or satellite lesions. Rarely, multicentric HCC might be confused with intrahepatic metastases of similar size to the main HCC. In addition, the presence of a capsule appearance is a specific feature of HCC. Second, the imaging characteristics. For instance, "wash-in" and "wash-out" are evaluated on imaging as a comparison with liver parenchyma, appearing as higher or lower density, so that we can generate HCC with the appropriate HU values. 18 | 19 | - **The accuracy of vessel segmentation** 20 | 21 | Frankly, this performance does not matter in this work. AI achieves a DSC of 52.8% and 52.9% with and without vessel segmentation. Vessel segmentation, collision detection, mass effect, and capsule appearance were not used for model training due to their high computational cost (but used for Visual Turing Test). 22 | 23 | - **Does the ratio of real/synt data matter?** 24 | 25 | It matters to some extent based on the table below (104 CT scans in total). But our main contribution is that with no real tumor scans annotated, AI can achieve comparable performance to the fully supervised counterpart. 26 | 27 | | real/synt | 1/9 | 3/7 | 5/5 | 7/3 | 9/1 | 28 | | ---- | ---- | ---- | ---- | ---- | ---- | 29 | | DSC (%) | 51.2 | 55.3 | 56.6 | 51.5 | 53.2 | 30 | 31 | - **Generalizability to other organs?** 32 | 33 | A tumor synthesis strategy that is universally effective for a variety of organs is certainly an attractive topic and is the Holy Grail of unsupervised tumor segmentation. However, previous syntheses of tumors, reviewed in Related Works, were designed specifically for a single type of abnormality. We are pioneering in demonstrating that training purely on synthetic tumors can achieve performance that is comparable to training on real liver tumors. Adapting our method to other organs requires a deep understanding of the biology and pathology of the specific tumor. We anticipate that the generalizability of our method can be enhanced through the utilization of automated methods such as GANs, Diffusion Models, and NeRF for generating representative imaging characteristics of various types of tumors in multiple organs. 34 | 35 | - **Will an increased amount of real/synt data further improve the performance?** 36 | 37 | At the moment, all publicly accessible CT scans with annotated liver tumors have been used for training our baseline (ranking \#1 in LiTS/MSD). More annotated data or new architectures are needed to overcome the bottleneck of real-tumor training. In contrast, synthetic data allow us to significantly expand the training data without the need for manual annotation efforts. We are currently training models on 2,000 healthy CT scans (which are easier to collect) incorporating synthetic tumors. 38 | 39 | - **Why not train AI using every transformation?** 40 | 41 | We did not use all the proposed transformations because it takes 5s to generate a synthetic tumor when using everything, and the performance is similar. 
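For reference, the on-the-fly tumor synthesis discussed above is implemented in this repository as a MONAI dictionary transform (`TumorGenerated`, see `TumorGenerated/TumorGenerated.py`). The sketch below shows one way such a transform could sit in a training pipeline; the surrounding transforms, intensity window, and `prob` value are illustrative assumptions rather than the exact settings used in `monai_trainer.py`.

```python
# Minimal sketch (assumed pipeline, not the exact training configuration):
# synthesize tumors on the fly in healthy CT scans during data loading.
from monai import transforms
from TumorGenerated import TumorGenerated

train_transform = transforms.Compose([
    transforms.LoadImaged(keys=["image", "label"]),
    transforms.EnsureChannelFirstd(keys=["image", "label"]),
    # Inject a synthetic tumor into the (healthy) scan with some probability;
    # the default in TumorGenerated.py is prob=0.1, the value here is an assumption.
    TumorGenerated(keys=["image", "label"], prob=0.9),
    transforms.ScaleIntensityRanged(keys=["image"], a_min=-200, a_max=250,
                                    b_min=0.0, b_max=1.0, clip=True),
    transforms.RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                                      spatial_size=(96, 96, 96), pos=1, neg=1,
                                      num_samples=2, image_key="image"),
])
```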
42 | 43 | - **Why not use an offline generator?** 44 | 45 | It is because the synthetic tumors were not diverse enough if pre-generated and saved to disk. This led to a downgraded performance: DSC = 43.5% (offline) vs. 52.9% (on-the-fly). 46 | 47 | - **Why does the HU value follow a Gaussian distribution?** 48 | 49 | This was supported by the statistics reported in [Bilic et al., MedIA'23](https://www.sciencedirect.com/science/article/pii/S1361841522003085). The random scattering of light quanta in space makes the HU intensity at any specific point follow a Gaussian distribution ([Alpert et al., TMI'82](https://ieeexplore.ieee.org/abstract/document/4307561)). 50 | -------------------------------------------------------------------------------- /documents/hu2023label.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/documents/hu2023label.pdf -------------------------------------------------------------------------------- /documents/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/documents/poster.pdf -------------------------------------------------------------------------------- /documents/poster_cvpr23.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/documents/poster_cvpr23.pdf -------------------------------------------------------------------------------- /documents/slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/documents/slides.pdf -------------------------------------------------------------------------------- /external/surface-distance/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. 4 | 5 | ## Contributor License Agreement 6 | 7 | Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to to see your current agreements on file or to sign a new one. 8 | 9 | You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. 10 | 11 | ## Code reviews 12 | 13 | All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. 14 | 15 | ## Community Guidelines 16 | 17 | This project follows [Google's Open Source Community Guidelines](https://opensource.google.com/conduct/). 
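The `external/surface-distance` package whose files follow provides the surface-based segmentation metrics (average surface distance, robust Hausdorff, surface Dice, volumetric Dice) bundled with this repository. As a quick orientation before the sources, here is a minimal usage sketch based on the calls exercised in `surface_distance_test.py`; the mask contents and voxel spacing are illustrative assumptions.

```python
# Minimal sketch (assumed inputs): compute surface-distance metrics for two 3D masks.
import numpy as np
import surface_distance

mask_gt = np.zeros((128, 128, 128), dtype=bool)    # ground-truth segmentation
mask_pred = np.zeros((128, 128, 128), dtype=bool)  # predicted segmentation
mask_gt[50, 60, 70] = True
mask_pred[50, 60, 72] = True

# Surface distances are computed once and reused by the individual metrics.
distances = surface_distance.compute_surface_distances(
    mask_gt, mask_pred, spacing_mm=(3, 2, 1))

asd = surface_distance.compute_average_surface_distance(distances)    # (gt->pred, pred->gt)
hd95 = surface_distance.compute_robust_hausdorff(distances, 95)       # robust Hausdorff (95th percentile)
nsd = surface_distance.compute_surface_dice_at_tolerance(distances, tolerance_mm=1)
dice = surface_distance.compute_dice_coefficient(mask_gt, mask_pred)  # volumetric Dice
```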
-------------------------------------------------------------------------------- /external/surface-distance/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /external/surface-distance/README.md: -------------------------------------------------------------------------------- 1 | # Surface distance metrics 2 | 3 | ## Summary 4 | When comparing multiple image segmentations, performance metrics that assess how closely the surfaces align can be a useful difference measure. This group of surface distance based measures computes the closest distances from all surface points on one segmentation to the points on another surface, and returns performance metrics between the two. This distance can be used alongside other metrics to compare segmented regions against a ground truth. 5 | 6 | Surfaces are represented using surface elements with corresponding area, allowing for more consistent approximation of surface measures. 7 | 8 | ## Metrics included 9 | This library computes the following performance metrics for segmentation: 10 | 11 | - Average surface distance (see `compute_average_surface_distance`) 12 | - Hausdorff distance (see `compute_robust_hausdorff`) 13 | - Surface overlap (see `compute_surface_overlap_at_tolerance`) 14 | - Surface dice (see `compute_surface_dice_at_tolerance`) 15 | - Volumetric dice (see `compute_dice_coefficient`) 16 | 17 | ## Installation 18 | First clone the repo, then install the dependencies and `surface-distance` 19 | package via pip: 20 | 21 | ```shell 22 | $ git clone https://github.com/deepmind/surface-distance.git 23 | $ pip install surface-distance/ 24 | ``` 25 | 26 | ## Usage 27 | For simple usage examples, see `surface_distance_test.py`. 28 | -------------------------------------------------------------------------------- /external/surface-distance/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Surface distance module: https://github.com/deepmind/surface-distance .""" 15 | 16 | from .surface_distance import * # pylint: disable=wildcard-import 17 | -------------------------------------------------------------------------------- /external/surface-distance/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """PyPI package definition.""" 15 | 16 | from setuptools import setup 17 | 18 | setup(name="Surface Distance Based Measures", 19 | version="0.1", 20 | description=( 21 | "Library containing utilities to compute performance metrics for " 22 | "segmentation"), 23 | url="https://github.com/deepmind/surface-distance", 24 | author="DeepMind", 25 | license="Apache License, Version 2.0", 26 | packages=["surface_distance"], 27 | install_requires=["numpy", "scipy", "absl-py"]) 28 | -------------------------------------------------------------------------------- /external/surface-distance/surface_distance/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Surface distance module: https://github.com/deepmind/surface-distance .""" 15 | 16 | from .metrics import * # pylint: disable=wildcard-import 17 | __version__ = "0.1" 18 | -------------------------------------------------------------------------------- /external/surface-distance/surface_distance_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS-IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Simple tests for surface metric computations.""" 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import math 21 | import google3 22 | from absl.testing import absltest 23 | from absl.testing import parameterized 24 | import numpy as np 25 | import surface_distance 26 | from surface_distance.surface_distance import metrics 27 | 28 | 29 | class SurfaceDistanceTest(parameterized.TestCase, absltest.TestCase): 30 | 31 | def _assert_almost_equal(self, expected, actual, places): 32 | """Assertion wrapper correctly handling NaN equality.""" 33 | if np.isnan(expected) and np.isnan(actual): 34 | return 35 | self.assertAlmostEqual(expected, actual, places) 36 | 37 | def _assert_metrics(self, 38 | surface_distances, mask_gt, mask_pred, 39 | expected_average_surface_distance, 40 | expected_hausdorff_100, 41 | expected_hausdorff_95, 42 | expected_surface_overlap_at_1mm, 43 | expected_surface_dice_at_1mm, 44 | expected_volumetric_dice, 45 | places=3): 46 | actual_average_surface_distance = ( 47 | surface_distance.compute_average_surface_distance(surface_distances)) 48 | for i in range(2): 49 | self._assert_almost_equal( 50 | expected_average_surface_distance[i], 51 | actual_average_surface_distance[i], 52 | places=places) 53 | 54 | self._assert_almost_equal( 55 | expected_hausdorff_100, 56 | surface_distance.compute_robust_hausdorff(surface_distances, 100), 57 | places=places) 58 | 59 | self._assert_almost_equal( 60 | expected_hausdorff_95, 61 | surface_distance.compute_robust_hausdorff(surface_distances, 95), 62 | places=places) 63 | 64 | actual_surface_overlap_at_1mm = ( 65 | surface_distance.compute_surface_overlap_at_tolerance( 66 | surface_distances, tolerance_mm=1)) 67 | for i in range(2): 68 | self._assert_almost_equal( 69 | expected_surface_overlap_at_1mm[i], 70 | actual_surface_overlap_at_1mm[i], 71 | places=places) 72 | 73 | self._assert_almost_equal( 74 | expected_surface_dice_at_1mm, 75 | surface_distance.compute_surface_dice_at_tolerance( 76 | surface_distances, tolerance_mm=1), 77 | places=places) 78 | 79 | self._assert_almost_equal( 80 | expected_volumetric_dice, 81 | surface_distance.compute_dice_coefficient(mask_gt, mask_pred), 82 | places=places) 83 | 84 | @parameterized.parameters(( 85 | np.zeros([2, 2, 2], dtype=np.bool), 86 | np.zeros([2, 2], dtype=np.bool), 87 | [1, 1], 88 | ), ( 89 | np.zeros([2, 2], dtype=np.bool), 90 | np.zeros([2, 2, 2], dtype=np.bool), 91 | [1, 1], 92 | ), ( 93 | np.zeros([2, 2], dtype=np.bool), 94 | np.zeros([2, 2], dtype=np.bool), 95 | [1, 1, 1], 96 | )) 97 | def test_compute_surface_distances_raises_on_incompatible_shapes( 98 | self, mask_gt, mask_pred, spacing_mm): 99 | with self.assertRaisesRegex(ValueError, 100 | 'The arguments must be of compatible shape'): 101 | surface_distance.compute_surface_distances(mask_gt, mask_pred, spacing_mm) 102 | 103 | @parameterized.parameters(( 104 | np.zeros([2], dtype=np.bool), 105 | np.zeros([2], dtype=np.bool), 106 | [1], 107 | ), ( 108 | np.zeros([2, 2, 2, 2], dtype=np.bool), 109 | np.zeros([2, 2, 2, 2], dtype=np.bool), 110 | [1, 1, 1, 1], 111 | )) 112 | def test_compute_surface_distances_raises_on_invalid_shapes( 113 | self, mask_gt, mask_pred, spacing_mm): 114 | with self.assertRaisesRegex(ValueError, 115 | 'Only 2D and 3D masks are supported'): 116 | surface_distance.compute_surface_distances(mask_gt, mask_pred, spacing_mm) 117 | 118 | 119 | class SurfaceDistance2DTest(SurfaceDistanceTest, parameterized.TestCase): 120 
| 121 | def test_on_2_pixels_2mm_away(self): 122 | mask_gt = np.zeros((128, 128), np.bool) 123 | mask_pred = np.zeros((128, 128), np.bool) 124 | mask_gt[50, 70] = 1 125 | mask_pred[50, 72] = 1 126 | surface_distances = surface_distance.compute_surface_distances( 127 | mask_gt, mask_pred, spacing_mm=(2, 1)) 128 | 129 | diag = 0.5 * math.sqrt(2**2 + 1**2) 130 | expected_distances = { 131 | 'surfel_areas_gt': np.asarray([diag, diag, diag, diag]), 132 | 'surfel_areas_pred': np.asarray([diag, diag, diag, diag]), 133 | 'distances_gt_to_pred': np.asarray([1., 1., 2., 2.]), 134 | 'distances_pred_to_gt': np.asarray([1., 1., 2., 2.]), 135 | } 136 | self.assertEqual(len(expected_distances), len(surface_distances)) 137 | for key, expected_value in expected_distances.items(): 138 | np.testing.assert_array_equal(expected_value, surface_distances[key]) 139 | 140 | self._assert_metrics( 141 | surface_distances, 142 | mask_gt, 143 | mask_pred, 144 | expected_average_surface_distance=(1.5, 1.5), 145 | expected_hausdorff_100=2.0, 146 | expected_hausdorff_95=2.0, 147 | expected_surface_overlap_at_1mm=(0.5, 0.5), 148 | expected_surface_dice_at_1mm=0.5, 149 | expected_volumetric_dice=0.0) 150 | 151 | def test_two_squares_shifted_by_one_pixel(self): 152 | # We make sure we do not have active pixels on the border of the image, 153 | # because this will add additional 2D surfaces on the border of the image 154 | # because the image is padded with background. 155 | mask_gt = np.asarray( 156 | [ 157 | [0, 0, 0, 0, 0, 0], 158 | [0, 1, 1, 0, 0, 0], 159 | [0, 1, 1, 0, 0, 0], 160 | [0, 0, 0, 0, 0, 0], 161 | [0, 0, 0, 0, 0, 0], 162 | [0, 0, 0, 0, 0, 0], 163 | ], 164 | dtype=np.bool) 165 | 166 | mask_pred = np.asarray( 167 | [ 168 | [0, 0, 0, 0, 0, 0], 169 | [0, 1, 1, 0, 0, 0], 170 | [0, 1, 1, 0, 0, 0], 171 | [0, 1, 1, 0, 0, 0], 172 | [0, 0, 0, 0, 0, 0], 173 | [0, 0, 0, 0, 0, 0], 174 | ], 175 | dtype=np.bool) 176 | 177 | vertical = 2 178 | horizontal = 1 179 | diag = 0.5 * math.sqrt(horizontal**2 + vertical**2) 180 | surface_distances = surface_distance.compute_surface_distances( 181 | mask_gt, mask_pred, spacing_mm=(vertical, horizontal)) 182 | 183 | # We go from top left corner, clockwise to describe the surfaces and 184 | # distances. 
The 2 surfaces are: 185 | # 186 | # /-\ /-\ 187 | # | | | | 188 | # \-/ | | 189 | # \-/ 190 | expected_surfel_areas_gt = np.asarray( 191 | [diag, horizontal, diag, vertical, diag, horizontal, diag, vertical]) 192 | expected_surfel_areas_pred = np.asarray([ 193 | diag, horizontal, diag, vertical, vertical, diag, horizontal, diag, 194 | vertical, vertical 195 | ]) 196 | expected_distances_gt_to_pred = np.asarray([0] * 5 + [horizontal] + [0] * 2) 197 | expected_distances_pred_to_gt = np.asarray([0] * 5 + [vertical] * 3 + 198 | [0] * 2) 199 | 200 | # We sort these using the same sorting algorithm 201 | (expected_distances_gt_to_pred, expected_surfel_areas_gt) = ( 202 | metrics._sort_distances_surfels(expected_distances_gt_to_pred, 203 | expected_surfel_areas_gt)) 204 | (expected_distances_pred_to_gt, expected_surfel_areas_pred) = ( 205 | metrics._sort_distances_surfels(expected_distances_pred_to_gt, 206 | expected_surfel_areas_pred)) 207 | 208 | expected_distances = { 209 | 'surfel_areas_gt': expected_surfel_areas_gt, 210 | 'surfel_areas_pred': expected_surfel_areas_pred, 211 | 'distances_gt_to_pred': expected_distances_gt_to_pred, 212 | 'distances_pred_to_gt': expected_distances_pred_to_gt, 213 | } 214 | 215 | self.assertEqual(len(expected_distances), len(surface_distances)) 216 | for key, expected_value in expected_distances.items(): 217 | np.testing.assert_array_equal(expected_value, surface_distances[key]) 218 | 219 | self._assert_metrics( 220 | surface_distances, 221 | mask_gt, 222 | mask_pred, 223 | expected_average_surface_distance=( 224 | surface_distance.compute_average_surface_distance( 225 | expected_distances)), 226 | expected_hausdorff_100=(surface_distance.compute_robust_hausdorff( 227 | expected_distances, 100)), 228 | expected_hausdorff_95=surface_distance.compute_robust_hausdorff( 229 | expected_distances, 95), 230 | expected_surface_overlap_at_1mm=( 231 | surface_distance.compute_surface_overlap_at_tolerance( 232 | expected_distances, tolerance_mm=1)), 233 | expected_surface_dice_at_1mm=( 234 | surface_distance.compute_surface_dice_at_tolerance( 235 | surface_distances, tolerance_mm=1)), 236 | expected_volumetric_dice=(surface_distance.compute_dice_coefficient( 237 | mask_gt, mask_pred))) 238 | 239 | def test_empty_prediction_mask(self): 240 | mask_gt = np.zeros((128, 128), np.bool) 241 | mask_pred = np.zeros((128, 128), np.bool) 242 | mask_gt[50, 60] = 1 243 | surface_distances = surface_distance.compute_surface_distances( 244 | mask_gt, mask_pred, spacing_mm=(3, 2)) 245 | self._assert_metrics( 246 | surface_distances, 247 | mask_gt, 248 | mask_pred, 249 | expected_average_surface_distance=(np.inf, np.nan), 250 | expected_hausdorff_100=np.inf, 251 | expected_hausdorff_95=np.inf, 252 | expected_surface_overlap_at_1mm=(0.0, np.nan), 253 | expected_surface_dice_at_1mm=0.0, 254 | expected_volumetric_dice=0.0) 255 | 256 | def test_empty_ground_truth_mask(self): 257 | mask_gt = np.zeros((128, 128), np.bool) 258 | mask_pred = np.zeros((128, 128), np.bool) 259 | mask_pred[50, 60] = 1 260 | surface_distances = surface_distance.compute_surface_distances( 261 | mask_gt, mask_pred, spacing_mm=(3, 2)) 262 | self._assert_metrics( 263 | surface_distances, 264 | mask_gt, 265 | mask_pred, 266 | expected_average_surface_distance=(np.nan, np.inf), 267 | expected_hausdorff_100=np.inf, 268 | expected_hausdorff_95=np.inf, 269 | expected_surface_overlap_at_1mm=(np.nan, 0.0), 270 | expected_surface_dice_at_1mm=0.0, 271 | expected_volumetric_dice=0.0) 272 | 273 | def test_both_empty_masks(self): 274 
| mask_gt = np.zeros((128, 128), np.bool) 275 | mask_pred = np.zeros((128, 128), np.bool) 276 | surface_distances = surface_distance.compute_surface_distances( 277 | mask_gt, mask_pred, spacing_mm=(3, 2)) 278 | self._assert_metrics( 279 | surface_distances, 280 | mask_gt, 281 | mask_pred, 282 | expected_average_surface_distance=(np.nan, np.nan), 283 | expected_hausdorff_100=np.inf, 284 | expected_hausdorff_95=np.inf, 285 | expected_surface_overlap_at_1mm=(np.nan, np.nan), 286 | expected_surface_dice_at_1mm=np.nan, 287 | expected_volumetric_dice=np.nan) 288 | 289 | 290 | class SurfaceDistance3DTest(SurfaceDistanceTest): 291 | 292 | def test_on_2_pixels_2mm_away(self): 293 | mask_gt = np.zeros((128, 128, 128), np.bool) 294 | mask_pred = np.zeros((128, 128, 128), np.bool) 295 | mask_gt[50, 60, 70] = 1 296 | mask_pred[50, 60, 72] = 1 297 | surface_distances = surface_distance.compute_surface_distances( 298 | mask_gt, mask_pred, spacing_mm=(3, 2, 1)) 299 | self._assert_metrics(surface_distances, mask_gt, mask_pred, 300 | expected_average_surface_distance=(1.5, 1.5), 301 | expected_hausdorff_100=2.0, 302 | expected_hausdorff_95=2.0, 303 | expected_surface_overlap_at_1mm=(0.5, 0.5), 304 | expected_surface_dice_at_1mm=0.5, 305 | expected_volumetric_dice=0.0) 306 | 307 | def test_two_cubes_shifted_by_one_pixel(self): 308 | mask_gt = np.zeros((100, 100, 100), np.bool) 309 | mask_pred = np.zeros((100, 100, 100), np.bool) 310 | mask_gt[0:50, :, :] = 1 311 | mask_pred[0:51, :, :] = 1 312 | surface_distances = surface_distance.compute_surface_distances( 313 | mask_gt, mask_pred, spacing_mm=(2, 1, 1)) 314 | self._assert_metrics( 315 | surface_distances, mask_gt, mask_pred, 316 | expected_average_surface_distance=(0.322, 0.339), 317 | expected_hausdorff_100=2.0, 318 | expected_hausdorff_95=2.0, 319 | expected_surface_overlap_at_1mm=(0.842, 0.830), 320 | expected_surface_dice_at_1mm=0.836, 321 | expected_volumetric_dice=0.990) 322 | 323 | def test_empty_prediction_mask(self): 324 | mask_gt = np.zeros((128, 128, 128), np.bool) 325 | mask_pred = np.zeros((128, 128, 128), np.bool) 326 | mask_gt[50, 60, 70] = 1 327 | surface_distances = surface_distance.compute_surface_distances( 328 | mask_gt, mask_pred, spacing_mm=(3, 2, 1)) 329 | self._assert_metrics( 330 | surface_distances, mask_gt, mask_pred, 331 | expected_average_surface_distance=(np.inf, np.nan), 332 | expected_hausdorff_100=np.inf, 333 | expected_hausdorff_95=np.inf, 334 | expected_surface_overlap_at_1mm=(0.0, np.nan), 335 | expected_surface_dice_at_1mm=0.0, 336 | expected_volumetric_dice=0.0) 337 | 338 | def test_empty_ground_truth_mask(self): 339 | mask_gt = np.zeros((128, 128, 128), np.bool) 340 | mask_pred = np.zeros((128, 128, 128), np.bool) 341 | mask_pred[50, 60, 72] = 1 342 | surface_distances = surface_distance.compute_surface_distances( 343 | mask_gt, mask_pred, spacing_mm=(3, 2, 1)) 344 | self._assert_metrics( 345 | surface_distances, mask_gt, mask_pred, 346 | expected_average_surface_distance=(np.nan, np.inf), 347 | expected_hausdorff_100=np.inf, 348 | expected_hausdorff_95=np.inf, 349 | expected_surface_overlap_at_1mm=(np.nan, 0.0), 350 | expected_surface_dice_at_1mm=0.0, 351 | expected_volumetric_dice=0.0) 352 | 353 | def test_both_empty_masks(self): 354 | mask_gt = np.zeros((128, 128, 128), np.bool) 355 | mask_pred = np.zeros((128, 128, 128), np.bool) 356 | surface_distances = surface_distance.compute_surface_distances( 357 | mask_gt, mask_pred, spacing_mm=(3, 2, 1)) 358 | self._assert_metrics( 359 | surface_distances, mask_gt, 
mask_pred, 360 | expected_average_surface_distance=(np.nan, np.nan), 361 | expected_hausdorff_100=np.inf, 362 | expected_hausdorff_95=np.inf, 363 | expected_surface_overlap_at_1mm=(np.nan, np.nan), 364 | expected_surface_dice_at_1mm=np.nan, 365 | expected_volumetric_dice=np.nan) 366 | 367 | 368 | if __name__ == '__main__': 369 | absltest.main() 370 | -------------------------------------------------------------------------------- /figures/Examples.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/figures/Examples.gif -------------------------------------------------------------------------------- /figures/VisualTuringTest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/figures/VisualTuringTest.png -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/networks/__init__.py -------------------------------------------------------------------------------- /networks/basicunetplusplus.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from monai.networks.layers.factories import Conv 7 | from monai.networks.nets.basic_unet import Down, TwoConv, UpCat 8 | from monai.utils import ensure_tuple_rep 9 | 10 | __all__ = ["BasicUnetPlusPlus", "BasicunetPlusPlus", "basicunetplusplus", "BasicUNetPlusPlus"] 11 | 12 | 13 | class BasicUNetPlusPlus(nn.Module): 14 | def __init__( 15 | self, 16 | spatial_dims: int = 3, 17 | in_channels: int = 1, 18 | out_channels: int = 2, 19 | features: Sequence[int] = (32, 32, 64, 128, 256, 32), 20 | deep_supervision: bool = False, 21 | act: Union[str, tuple] = ("LeakyReLU", {"negative_slope": 0.1, "inplace": True}), 22 | norm: Union[str, tuple] = ("instance", {"affine": True}), 23 | bias: bool = True, 24 | dropout: Union[float, tuple] = 0.0, 25 | upsample: str = "deconv", 26 | ): 27 | """ 28 | A UNet++ implementation with 1D/2D/3D supports. 29 | 30 | Based on: 31 | 32 | Zhou et al. "UNet++: A Nested U-Net Architecture for Medical Image 33 | Segmentation". 4th Deep Learning in Medical Image Analysis (DLMIA) 34 | Workshop, DOI: https://doi.org/10.48550/arXiv.1807.10165 35 | 36 | 37 | Args: 38 | spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs. 39 | in_channels: number of input channels. Defaults to 1. 40 | out_channels: number of output channels. Defaults to 2. 41 | features: six integers as numbers of features. 42 | Defaults to ``(32, 32, 64, 128, 256, 32)``, 43 | 44 | - the first five values correspond to the five-level encoder feature sizes. 45 | - the last value corresponds to the feature size after the last upsampling. 46 | 47 | deep_supervision: whether to prune the network at inference time. Defaults to False. If true, returns a list, 48 | whose elements correspond to outputs at different nodes. 49 | act: activation type and arguments. Defaults to LeakyReLU. 50 | norm: feature normalization type and arguments. Defaults to instance norm. 51 | bias: whether to have a bias term in convolution blocks. Defaults to True. 
52 | According to `Performance Tuning Guide `_, 53 | if a conv layer is directly followed by a batch norm layer, bias should be False. 54 | dropout: dropout ratio. Defaults to no dropout. 55 | upsample: upsampling mode, available options are 56 | ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. 57 | 58 | Examples:: 59 | 60 | # for spatial 2D 61 | >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128)) 62 | 63 | # for spatial 2D, with deep supervision enabled 64 | >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), deep_supervision=True) 65 | 66 | # for spatial 2D, with group norm 67 | >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4})) 68 | 69 | # for spatial 3D 70 | >>> net = BasicUNetPlusPlus(spatial_dims=3, features=(32, 32, 64, 128, 256, 32)) 71 | 72 | See Also 73 | - :py:class:`monai.networks.nets.BasicUNet` 74 | - :py:class:`monai.networks.nets.DynUNet` 75 | - :py:class:`monai.networks.nets.UNet` 76 | 77 | """ 78 | super().__init__() 79 | 80 | self.deep_supervision = deep_supervision 81 | 82 | fea = ensure_tuple_rep(features, 6) 83 | print(f"BasicUNetPlusPlus features: {fea}.") 84 | 85 | self.conv_0_0 = TwoConv(spatial_dims, in_channels, fea[0], act, norm, bias, dropout) 86 | self.conv_1_0 = Down(spatial_dims, fea[0], fea[1], act, norm, bias, dropout) 87 | self.conv_2_0 = Down(spatial_dims, fea[1], fea[2], act, norm, bias, dropout) 88 | self.conv_3_0 = Down(spatial_dims, fea[2], fea[3], act, norm, bias, dropout) 89 | self.conv_4_0 = Down(spatial_dims, fea[3], fea[4], act, norm, bias, dropout) 90 | 91 | self.upcat_0_1 = UpCat(spatial_dims, fea[1], fea[0], fea[0], act, norm, bias, dropout, upsample, halves=False) 92 | self.upcat_1_1 = UpCat(spatial_dims, fea[2], fea[1], fea[1], act, norm, bias, dropout, upsample) 93 | self.upcat_2_1 = UpCat(spatial_dims, fea[3], fea[2], fea[2], act, norm, bias, dropout, upsample) 94 | self.upcat_3_1 = UpCat(spatial_dims, fea[4], fea[3], fea[3], act, norm, bias, dropout, upsample) 95 | 96 | self.upcat_0_2 = UpCat( 97 | spatial_dims, fea[1], fea[0] * 2, fea[0], act, norm, bias, dropout, upsample, halves=False 98 | ) 99 | self.upcat_1_2 = UpCat(spatial_dims, fea[2], fea[1] * 2, fea[1], act, norm, bias, dropout, upsample) 100 | self.upcat_2_2 = UpCat(spatial_dims, fea[3], fea[2] * 2, fea[2], act, norm, bias, dropout, upsample) 101 | 102 | self.upcat_0_3 = UpCat( 103 | spatial_dims, fea[1], fea[0] * 3, fea[0], act, norm, bias, dropout, upsample, halves=False 104 | ) 105 | self.upcat_1_3 = UpCat(spatial_dims, fea[2], fea[1] * 3, fea[1], act, norm, bias, dropout, upsample) 106 | 107 | self.upcat_0_4 = UpCat( 108 | spatial_dims, fea[1], fea[0] * 4, fea[5], act, norm, bias, dropout, upsample, halves=False 109 | ) 110 | 111 | self.final_conv_0_1 = Conv["conv", spatial_dims](fea[0], out_channels, kernel_size=1) 112 | self.final_conv_0_2 = Conv["conv", spatial_dims](fea[0], out_channels, kernel_size=1) 113 | self.final_conv_0_3 = Conv["conv", spatial_dims](fea[0], out_channels, kernel_size=1) 114 | self.final_conv_0_4 = Conv["conv", spatial_dims](fea[5], out_channels, kernel_size=1) 115 | 116 | 117 | def forward(self, x: torch.Tensor): 118 | """ 119 | Args: 120 | x: input should have spatially N dimensions 121 | ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `dimensions`. 122 | It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have 123 | even edge lengths. 
124 | 125 | Returns: 126 | A torch Tensor of "raw" predictions in shape 127 | ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``. 128 | """ 129 | x_0_0 = self.conv_0_0(x) 130 | x_1_0 = self.conv_1_0(x_0_0) 131 | x_0_1 = self.upcat_0_1(x_1_0, x_0_0) 132 | 133 | x_2_0 = self.conv_2_0(x_1_0) 134 | x_1_1 = self.upcat_1_1(x_2_0, x_1_0) 135 | x_0_2 = self.upcat_0_2(x_1_1, torch.cat([x_0_0, x_0_1], dim=1)) 136 | 137 | x_3_0 = self.conv_3_0(x_2_0) 138 | x_2_1 = self.upcat_2_1(x_3_0, x_2_0) 139 | x_1_2 = self.upcat_1_2(x_2_1, torch.cat([x_1_0, x_1_1], dim=1)) 140 | x_0_3 = self.upcat_0_3(x_1_2, torch.cat([x_0_0, x_0_1, x_0_2], dim=1)) 141 | 142 | x_4_0 = self.conv_4_0(x_3_0) 143 | x_3_1 = self.upcat_3_1(x_4_0, x_3_0) 144 | x_2_2 = self.upcat_2_2(x_3_1, torch.cat([x_2_0, x_2_1], dim=1)) 145 | x_1_3 = self.upcat_1_3(x_2_2, torch.cat([x_1_0, x_1_1, x_1_2], dim=1)) 146 | x_0_4 = self.upcat_0_4(x_1_3, torch.cat([x_0_0, x_0_1, x_0_2, x_0_3], dim=1)) 147 | 148 | output_0_1 = self.final_conv_0_1(x_0_1) 149 | output_0_2 = self.final_conv_0_2(x_0_2) 150 | output_0_3 = self.final_conv_0_3(x_0_3) 151 | output_0_4 = self.final_conv_0_4(x_0_4) 152 | 153 | if self.deep_supervision: 154 | output = [output_0_1, output_0_2, output_0_3, output_0_4] 155 | else: 156 | output = output_0_4 157 | 158 | return output 159 | 160 | 161 | 162 | BasicUnetPlusPlus = BasicunetPlusPlus = basicunetplusplus = BasicUNetPlusPlus -------------------------------------------------------------------------------- /networks/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import torch.nn as nn 13 | 14 | 15 | class MLPBlock(nn.Module): 16 | """ 17 | A multi-layer perceptron block, based on: "Dosovitskiy et al., 18 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 19 | """ 20 | 21 | def __init__( 22 | self, 23 | hidden_size: int, 24 | mlp_dim: int, 25 | dropout_rate: float = 0.0, 26 | ) -> None: 27 | """ 28 | Args: 29 | hidden_size: dimension of hidden layer. 30 | mlp_dim: dimension of feedforward layer. 31 | dropout_rate: faction of the input units to drop. 
32 | 33 | """ 34 | 35 | super().__init__() 36 | 37 | if not (0 <= dropout_rate <= 1): 38 | raise AssertionError("dropout_rate should be between 0 and 1.") 39 | 40 | self.linear1 = nn.Linear(hidden_size, mlp_dim) 41 | self.linear2 = nn.Linear(mlp_dim, hidden_size) 42 | self.fn = nn.GELU() 43 | self.drop1 = nn.Dropout(dropout_rate) 44 | self.drop2 = nn.Dropout(dropout_rate) 45 | 46 | def forward(self, x): 47 | x = self.fn(self.linear1(x)) 48 | x = self.drop1(x) 49 | x = self.linear2(x) 50 | x = self.drop2(x) 51 | return x 52 | -------------------------------------------------------------------------------- /networks/patchembedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | import math 14 | from typing import Tuple, Union 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from monai.utils import optional_import 20 | 21 | Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") 22 | 23 | 24 | class PatchEmbeddingBlock(nn.Module): 25 | """ 26 | A patch embedding block, based on: "Dosovitskiy et al., 27 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 28 | """ 29 | 30 | def __init__( 31 | self, 32 | in_channels: int, 33 | img_size: Tuple[int, int, int], 34 | patch_size: Tuple[int, int, int], 35 | hidden_size: int, 36 | num_heads: int, 37 | pos_embed: str, 38 | dropout_rate: float = 0.0, 39 | ) -> None: 40 | """ 41 | Args: 42 | in_channels: dimension of input channels. 43 | img_size: dimension of input image. 44 | patch_size: dimension of patch size. 45 | hidden_size: dimension of hidden layer. 46 | num_heads: number of attention heads. 47 | pos_embed: position embedding layer type. 48 | dropout_rate: faction of the input units to drop. 
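# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of patchembedding.py): for a 96^3
# volume split into 16^3 patches, the block defined in this file produces
# (96 // 16)^3 = 216 tokens, each projected to hidden_size channels; both
# the "conv" and "perceptron" modes yield the same token grid.  The
# "perceptron" mode requires einops (see the optional_import guard above).
import torch
from networks.patchembedding import PatchEmbeddingBlock

embed = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(96, 96, 96),
    patch_size=(16, 16, 16),
    hidden_size=768,
    num_heads=12,
    pos_embed="perceptron",
    dropout_rate=0.0,
)
x = torch.randn(2, 1, 96, 96, 96)       # (batch, channels, D, H, W)
tokens = embed(x)
assert tokens.shape == (2, 216, 768)    # (batch, n_patches, hidden_size)
# ----------------------------------------------------------------------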
49 | 50 | """ 51 | 52 | super().__init__() 53 | 54 | if not (0 <= dropout_rate <= 1): 55 | raise AssertionError("dropout_rate should be between 0 and 1.") 56 | 57 | if hidden_size % num_heads != 0: 58 | raise AssertionError("hidden size should be divisible by num_heads.") 59 | 60 | for m, p in zip(img_size, patch_size): 61 | if m < p: 62 | raise AssertionError("patch_size should be smaller than img_size.") 63 | 64 | if pos_embed not in ["conv", "perceptron"]: 65 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 66 | 67 | if pos_embed == "perceptron": 68 | if img_size[0] % patch_size[0] != 0: 69 | raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.") 70 | 71 | self.n_patches = ( 72 | (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2]) 73 | ) 74 | self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2] 75 | 76 | self.pos_embed = pos_embed 77 | self.patch_embeddings: Union[nn.Conv3d, nn.Sequential] 78 | if self.pos_embed == "conv": 79 | self.patch_embeddings = nn.Conv3d( 80 | in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size 81 | ) 82 | elif self.pos_embed == "perceptron": 83 | self.patch_embeddings = nn.Sequential( 84 | Rearrange( 85 | "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)", 86 | p1=patch_size[0], 87 | p2=patch_size[1], 88 | p3=patch_size[2], 89 | ), 90 | nn.Linear(self.patch_dim, hidden_size), 91 | ) 92 | self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) 93 | self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) 94 | self.dropout = nn.Dropout(dropout_rate) 95 | self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) 96 | self.apply(self._init_weights) 97 | 98 | def _init_weights(self, m): 99 | if isinstance(m, nn.Linear): 100 | self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) 101 | if isinstance(m, nn.Linear) and m.bias is not None: 102 | nn.init.constant_(m.bias, 0) 103 | elif isinstance(m, nn.LayerNorm): 104 | nn.init.constant_(m.bias, 0) 105 | nn.init.constant_(m.weight, 1.0) 106 | 107 | def trunc_normal_(self, tensor, mean, std, a, b): 108 | # From PyTorch official master until it's in a few official releases - RW 109 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 110 | def norm_cdf(x): 111 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 112 | 113 | with torch.no_grad(): 114 | l = norm_cdf((a - mean) / std) 115 | u = norm_cdf((b - mean) / std) 116 | tensor.uniform_(2 * l - 1, 2 * u - 1) 117 | tensor.erfinv_() 118 | tensor.mul_(std * math.sqrt(2.0)) 119 | tensor.add_(mean) 120 | tensor.clamp_(min=a, max=b) 121 | return tensor 122 | 123 | def forward(self, x): 124 | if self.pos_embed == "conv": 125 | x = self.patch_embeddings(x) 126 | x = x.flatten(2) 127 | x = x.transpose(-1, -2) 128 | elif self.pos_embed == "perceptron": 129 | x = self.patch_embeddings(x) 130 | embeddings = x + self.position_embeddings 131 | embeddings = self.dropout(embeddings) 132 | return embeddings 133 | -------------------------------------------------------------------------------- /networks/selfattention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
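# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): the trunc_normal_
# helper in networks/patchembedding.py above draws a uniform sample inside
# [cdf(a), cdf(b)] and maps it back through the inverse error function, so
# every value lands in [a, b] while keeping an approximately normal shape.
# The recipe below reproduces that logic with the values it is called with
# (mean=0.0, std=0.02, a=-2.0, b=2.0) and checks the bounds numerically.
import math
import torch

mean, std, a, b = 0.0, 0.02, -2.0, 2.0

def norm_cdf(x):
    # standard normal CDF
    return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

lo, hi = norm_cdf((a - mean) / std), norm_cdf((b - mean) / std)
t = torch.empty(100_000)
t.uniform_(2 * lo - 1, 2 * hi - 1)      # uniform inside the truncated CDF range
t.erfinv_()                             # back through the inverse error function
t.mul_(std * math.sqrt(2.0)).add_(mean).clamp_(min=a, max=b)

assert float(t.min()) >= a and float(t.max()) <= b
assert abs(float(t.std()) - std) < 1e-3  # ~N(0, 0.02): a and b sit far in the tails
# ----------------------------------------------------------------------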
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from monai.utils import optional_import 16 | 17 | einops, has_einops = optional_import("einops") 18 | 19 | 20 | class SABlock(nn.Module): 21 | """ 22 | A self-attention block, based on: "Dosovitskiy et al., 23 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 24 | """ 25 | 26 | def __init__( 27 | self, 28 | hidden_size: int, 29 | num_heads: int, 30 | dropout_rate: float = 0.0, 31 | ) -> None: 32 | """ 33 | Args: 34 | hidden_size: dimension of hidden layer. 35 | num_heads: number of attention heads. 36 | dropout_rate: faction of the input units to drop. 37 | 38 | """ 39 | 40 | super().__init__() 41 | 42 | if not (0 <= dropout_rate <= 1): 43 | raise AssertionError("dropout_rate should be between 0 and 1.") 44 | 45 | if hidden_size % num_heads != 0: 46 | raise AssertionError("hidden size should be divisible by num_heads.") 47 | 48 | self.num_heads = num_heads 49 | self.out_proj = nn.Linear(hidden_size, hidden_size) 50 | self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False) 51 | self.drop_output = nn.Dropout(dropout_rate) 52 | self.drop_weights = nn.Dropout(dropout_rate) 53 | self.head_dim = hidden_size // num_heads 54 | self.scale = self.head_dim ** -0.5 55 | if has_einops: 56 | self.rearrange = einops.rearrange 57 | else: 58 | raise ValueError('"Requires einops.') 59 | 60 | def forward(self, x): 61 | q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) 62 | att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) 63 | att_mat = self.drop_weights(att_mat) 64 | x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) 65 | x = self.rearrange(x, "b h l d -> b l (h d)") 66 | x = self.out_proj(x) 67 | x = self.drop_output(x) 68 | return x 69 | -------------------------------------------------------------------------------- /networks/swin3d_unetr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
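# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): the SABlock defined
# in networks/selfattention.py above is standard multi-head self-attention.
# hidden_size is split across num_heads (head_dim = hidden_size // num_heads)
# and the attention logits are scaled by head_dim ** -0.5, so the output
# keeps the input shape.  Requires einops, as guarded in that file.
import torch
from networks.selfattention import SABlock

attn = SABlock(hidden_size=768, num_heads=12, dropout_rate=0.0)
tokens = torch.randn(2, 216, 768)       # (batch, sequence, hidden_size)
out = attn(tokens)
assert out.shape == tokens.shape        # self-attention is shape-preserving
# ----------------------------------------------------------------------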
11 | 12 | from typing import Tuple, Union 13 | 14 | import torch.nn as nn 15 | 16 | from monai.networks.blocks.dynunet_block import UnetOutBlock 17 | from networks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 18 | from networks.swin_transformer_3d import SwinTransformer3D 19 | import pdb 20 | 21 | class SwinUNETR(nn.Module): 22 | """ 23 | UNETR based on: "Hatamizadeh et al., 24 | UNETR: Transformers for 3D Medical Image Segmentation " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: int, 31 | img_size: Tuple[int, int, int], 32 | feature_size: int = 48, 33 | patch_size: int = 2, 34 | depths: Tuple[int, int, int, int] = [2, 2, 2, 2], 35 | num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], 36 | window_size: Tuple[int, int, int] = [7, 7, 7], 37 | norm_name: Union[Tuple, str] = "instance", 38 | conv_block: bool = False, 39 | res_block: bool = True, 40 | dropout_rate: float = 0.0, 41 | ) -> None: 42 | """ 43 | Args: 44 | in_channels: dimension of input channels. 45 | out_channels: dimension of output channels. 46 | img_size: dimension of input image. 47 | feature_size: dimension of network feature size. 48 | hidden_size: dimension of hidden layer. 49 | mlp_dim: dimension of feedforward layer. 50 | num_heads: number of attention heads. 51 | pos_embed: position embedding layer type. 52 | norm_name: feature normalization type and arguments. 53 | conv_block: bool argument to determine if convolutional block is used. 54 | res_block: bool argument to determine if residual block is used. 55 | dropout_rate: faction of the input units to drop. 56 | 57 | Examples:: 58 | 59 | # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm 60 | >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') 61 | 62 | # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm 63 | >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') 64 | 65 | """ 66 | 67 | super().__init__() 68 | 69 | if not (0 <= dropout_rate <= 1): 70 | raise AssertionError("dropout_rate should be between 0 and 1.") 71 | 72 | self.swinViT = SwinTransformer3D( 73 | pretrained=None, 74 | pretrained2d=False, 75 | patch_size=(patch_size, patch_size, patch_size), 76 | in_chans=in_channels, 77 | embed_dim=feature_size, 78 | depths = depths, 79 | num_heads=num_heads, 80 | window_size=window_size, 81 | mlp_ratio=4., 82 | qkv_bias=True, 83 | qk_scale=None, 84 | drop_rate=0., 85 | attn_drop_rate=0., 86 | drop_path_rate=0.0, 87 | norm_layer=nn.LayerNorm, 88 | ) 89 | self.encoder1 = UnetrBasicBlock( 90 | spatial_dims=3, 91 | in_channels=in_channels, 92 | out_channels=feature_size, 93 | kernel_size=3, 94 | stride=1, 95 | norm_name=norm_name, 96 | res_block=res_block, 97 | ) 98 | 99 | self.encoder10 = UnetrBasicBlock( 100 | spatial_dims=3, 101 | in_channels=16*feature_size, 102 | out_channels=16*feature_size, 103 | kernel_size=3, 104 | stride=1, 105 | norm_name=norm_name, 106 | res_block=res_block) 107 | 108 | self.decoder5 = UnetrUpBlock( 109 | spatial_dims=3, 110 | in_channels=16*feature_size, 111 | out_channels=8*feature_size, 112 | stride=1, 113 | kernel_size=3, 114 | upsample_kernel_size=2, 115 | norm_name=norm_name, 116 | res_block=res_block, 117 | ) 118 | 119 | self.decoder4 = UnetrUpBlock( 120 | spatial_dims=3, 121 | in_channels=feature_size * 8, 122 | out_channels=feature_size * 4, 123 | 
stride=1, 124 | kernel_size=3, 125 | upsample_kernel_size=2, 126 | norm_name=norm_name, 127 | res_block=res_block, 128 | ) 129 | 130 | self.decoder3 = UnetrUpBlock( 131 | spatial_dims=3, 132 | in_channels=feature_size * 4, 133 | out_channels=feature_size * 2, 134 | stride=1, 135 | kernel_size=3, 136 | upsample_kernel_size=2, 137 | norm_name=norm_name, 138 | res_block=res_block, 139 | ) 140 | self.decoder2 = UnetrUpBlock( 141 | spatial_dims=3, 142 | in_channels=feature_size * 2, 143 | out_channels=feature_size, 144 | stride=1, 145 | kernel_size=3, 146 | upsample_kernel_size=2, 147 | norm_name=norm_name, 148 | res_block=res_block, 149 | ) 150 | 151 | self.decoder1 = UnetrUpBlock( 152 | spatial_dims=3, 153 | in_channels=feature_size, 154 | out_channels=feature_size, 155 | stride=1, 156 | kernel_size=3, 157 | upsample_kernel_size=2, 158 | norm_name=norm_name, 159 | res_block=res_block, 160 | ) 161 | 162 | self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore 163 | 164 | 165 | def proj_feat(self, x, hidden_size, feat_size): 166 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 167 | x = x.permute(0, 4, 1, 2, 3).contiguous() 168 | return x 169 | 170 | def forward(self, x_in): 171 | hidden_states_out = self.swinViT(x_in) 172 | enc0 = self.encoder1(x_in) 173 | x1 = hidden_states_out[0] 174 | enc1 = x1 175 | x2 = hidden_states_out[1] 176 | enc2 = x2 177 | x3 = hidden_states_out[2] 178 | enc3 = x3 179 | x4 = hidden_states_out[3] 180 | enc4 = x4 181 | dec4 = hidden_states_out[4] 182 | dec4 = self.encoder10(dec4) 183 | dec3 = self.decoder5(dec4, enc4) 184 | dec2 = self.decoder4(dec3, enc3) 185 | dec1 = self.decoder3(dec2, enc2) 186 | dec0 = self.decoder2(dec1, enc1) 187 | out = self.decoder1(dec0, enc0) 188 | logits = self.out(out) 189 | return logits 190 | -------------------------------------------------------------------------------- /networks/swin3d_unetrv2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
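# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): constructing the
# SwinUNETR defined in networks/swin3d_unetr.py above.  feature_size=48 and
# a 96x96x96 crop follow the defaults of that class; the decoder restores
# the input resolution, so the logits are expected to have shape
# (batch, out_channels, 96, 96, 96).  Requires the repository's
# networks/swin_transformer_3d.py backbone and its dependencies.
import torch
from networks.swin3d_unetr import SwinUNETR

model = SwinUNETR(
    in_channels=1,
    out_channels=2,             # e.g. background + tumor
    img_size=(96, 96, 96),
    feature_size=48,
)
x = torch.randn(1, 1, 96, 96, 96)
with torch.no_grad():
    logits = model(x)
print(logits.shape)             # expected: torch.Size([1, 2, 96, 96, 96])
# ----------------------------------------------------------------------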
11 | 12 | from typing import Tuple, Union 13 | import torch 14 | import torch.nn as nn 15 | 16 | from monai.networks.blocks.dynunet_block import UnetOutBlock 17 | from networks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 18 | from networks.swin_transformer_3d import SwinTransformer3D 19 | import pdb 20 | 21 | class SwinUNETR(nn.Module): 22 | """ 23 | UNETR based on: "Hatamizadeh et al., 24 | UNETR: Transformers for 3D Medical Image Segmentation " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: int, 31 | img_size: Tuple[int, int, int], 32 | feature_size: int = 48, 33 | patch_size: int = 2, 34 | depths: Tuple[int, int, int, int] = [2, 2, 2, 2], 35 | num_heads: Tuple[int, int, int, int] = [3, 6, 12, 24], 36 | window_size: Tuple[int, int, int] = [7, 7, 7], 37 | norm_name: Union[Tuple, str] = "instance", 38 | conv_block: bool = False, 39 | res_block: bool = True, 40 | dropout_rate: float = 0.0, 41 | ) -> None: 42 | """ 43 | Args: 44 | in_channels: dimension of input channels. 45 | out_channels: dimension of output channels. 46 | img_size: dimension of input image. 47 | feature_size: dimension of network feature size. 48 | hidden_size: dimension of hidden layer. 49 | mlp_dim: dimension of feedforward layer. 50 | num_heads: number of attention heads. 51 | pos_embed: position embedding layer type. 52 | norm_name: feature normalization type and arguments. 53 | conv_block: bool argument to determine if convolutional block is used. 54 | res_block: bool argument to determine if residual block is used. 55 | dropout_rate: faction of the input units to drop. 56 | 57 | Examples:: 58 | 59 | # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm 60 | >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') 61 | 62 | # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm 63 | >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') 64 | 65 | """ 66 | 67 | super().__init__() 68 | 69 | if not (0 <= dropout_rate <= 1): 70 | raise AssertionError("dropout_rate should be between 0 and 1.") 71 | 72 | self.swinViT = SwinTransformer3D( 73 | pretrained=None, 74 | pretrained2d=False, 75 | patch_size=(patch_size, patch_size, patch_size), 76 | in_chans=in_channels, 77 | embed_dim=feature_size, 78 | depths = depths, 79 | num_heads=num_heads, 80 | window_size=window_size, 81 | mlp_ratio=4., 82 | qkv_bias=True, 83 | qk_scale=None, 84 | drop_rate=0., 85 | attn_drop_rate=0., 86 | drop_path_rate=0.0, 87 | norm_layer=nn.LayerNorm, 88 | ) 89 | self.encoder1 = UnetrBasicBlock( 90 | spatial_dims=3, 91 | in_channels=in_channels, 92 | out_channels=feature_size, 93 | kernel_size=3, 94 | stride=1, 95 | norm_name=norm_name, 96 | res_block=res_block, 97 | ) 98 | 99 | self.encoder2 = UnetrBasicBlock( 100 | spatial_dims=3, 101 | in_channels=feature_size, 102 | out_channels=feature_size, 103 | kernel_size=3, 104 | stride=1, 105 | norm_name=norm_name, 106 | res_block=res_block, 107 | ) 108 | 109 | self.encoder3 = UnetrBasicBlock( 110 | spatial_dims=3, 111 | in_channels=2 * feature_size, 112 | out_channels=2 * feature_size, 113 | kernel_size=3, 114 | stride=1, 115 | norm_name=norm_name, 116 | res_block=res_block, 117 | ) 118 | 119 | self.encoder4 = UnetrBasicBlock( 120 | spatial_dims=3, 121 | in_channels=4 * feature_size, 122 | out_channels=4 * feature_size, 123 | 
kernel_size=3, 124 | stride=1, 125 | norm_name=norm_name, 126 | res_block=res_block, 127 | ) 128 | 129 | self.encoder10 = UnetrBasicBlock( 130 | spatial_dims=3, 131 | in_channels=16*feature_size, 132 | out_channels=16*feature_size, 133 | kernel_size=3, 134 | stride=1, 135 | norm_name=norm_name, 136 | res_block=res_block) 137 | 138 | self.decoder5 = UnetrUpBlock( 139 | spatial_dims=3, 140 | in_channels=16*feature_size, 141 | out_channels=8*feature_size, 142 | stride=1, 143 | kernel_size=3, 144 | upsample_kernel_size=2, 145 | norm_name=norm_name, 146 | res_block=res_block, 147 | ) 148 | 149 | self.decoder4 = UnetrUpBlock( 150 | spatial_dims=3, 151 | in_channels=feature_size * 8, 152 | out_channels=feature_size * 4, 153 | stride=1, 154 | kernel_size=3, 155 | upsample_kernel_size=2, 156 | norm_name=norm_name, 157 | res_block=res_block, 158 | ) 159 | 160 | self.decoder3 = UnetrUpBlock( 161 | spatial_dims=3, 162 | in_channels=feature_size * 4, 163 | out_channels=feature_size * 2, 164 | stride=1, 165 | kernel_size=3, 166 | upsample_kernel_size=2, 167 | norm_name=norm_name, 168 | res_block=res_block, 169 | ) 170 | self.decoder2 = UnetrUpBlock( 171 | spatial_dims=3, 172 | in_channels=feature_size * 2, 173 | out_channels=feature_size, 174 | stride=1, 175 | kernel_size=3, 176 | upsample_kernel_size=2, 177 | norm_name=norm_name, 178 | res_block=res_block, 179 | ) 180 | 181 | self.decoder1 = UnetrUpBlock( 182 | spatial_dims=3, 183 | in_channels=feature_size, 184 | out_channels=feature_size, 185 | stride=1, 186 | kernel_size=3, 187 | upsample_kernel_size=2, 188 | norm_name=norm_name, 189 | res_block=res_block, 190 | ) 191 | 192 | self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore 193 | 194 | 195 | def proj_feat(self, x, hidden_size, feat_size): 196 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 197 | x = x.permute(0, 4, 1, 2, 3).contiguous() 198 | return x 199 | 200 | def load_from(self, weights): 201 | with torch.no_grad(): 202 | res_weight = weights 203 | # copy weights from patch embedding 204 | for i in weights['state_dict']: 205 | print(i) 206 | self.swinViT.patch_embed.proj.weight.copy_(weights['state_dict']['module.patch_embed.proj.weight']) 207 | self.swinViT.patch_embed.proj.bias.copy_(weights['state_dict']['module.patch_embed.proj.bias']) 208 | 209 | # layer1 210 | for bname, block in self.swinViT.layers1[0].blocks.named_children(): 211 | block.loadFrom(weights, n_block=bname, layer='layers1') 212 | self.swinViT.layers1[0].downsample.reduction.weight.copy_(weights['state_dict']['module.layers1.0.downsample.reduction.weight']) 213 | self.swinViT.layers1[0].downsample.norm.weight.copy_(weights['state_dict']['module.layers1.0.downsample.norm.weight']) 214 | self.swinViT.layers1[0].downsample.norm.bias.copy_(weights['state_dict']['module.layers1.0.downsample.norm.bias']) 215 | # layer2 216 | for bname, block in self.swinViT.layers2[0].blocks.named_children(): 217 | block.loadFrom(weights, n_block=bname, layer='layers2') 218 | self.swinViT.layers2[0].downsample.reduction.weight.copy_(weights['state_dict']['module.layers2.0.downsample.reduction.weight']) 219 | self.swinViT.layers2[0].downsample.norm.weight.copy_(weights['state_dict']['module.layers2.0.downsample.norm.weight']) 220 | self.swinViT.layers2[0].downsample.norm.bias.copy_(weights['state_dict']['module.layers2.0.downsample.norm.bias']) 221 | # layer3 222 | for bname, block in self.swinViT.layers3[0].blocks.named_children(): 223 | 
block.loadFrom(weights, n_block=bname, layer='layers3') 224 | self.swinViT.layers3[0].downsample.reduction.weight.copy_(weights['state_dict']['module.layers3.0.downsample.reduction.weight']) 225 | self.swinViT.layers3[0].downsample.norm.weight.copy_(weights['state_dict']['module.layers3.0.downsample.norm.weight']) 226 | self.swinViT.layers3[0].downsample.norm.bias.copy_(weights['state_dict']['module.layers3.0.downsample.norm.bias']) 227 | # layer4 228 | for bname, block in self.swinViT.layers4[0].blocks.named_children(): 229 | block.loadFrom(weights, n_block=bname, layer='layers4') 230 | self.swinViT.layers4[0].downsample.reduction.weight.copy_(weights['state_dict']['module.layers4.0.downsample.reduction.weight']) 231 | self.swinViT.layers4[0].downsample.norm.weight.copy_(weights['state_dict']['module.layers4.0.downsample.norm.weight']) 232 | self.swinViT.layers4[0].downsample.norm.bias.copy_(weights['state_dict']['module.layers4.0.downsample.norm.bias']) 233 | 234 | 235 | # last norm layer of transformer 236 | self.swinViT.norm.weight.copy_(weights['state_dict']['module.norm.weight']) 237 | self.swinViT.norm.bias.copy_(weights['state_dict']['module.norm.bias']) 238 | 239 | def forward(self, x_in): 240 | hidden_states_out = self.swinViT(x_in) 241 | enc0 = self.encoder1(x_in) 242 | x1 = hidden_states_out[0] 243 | enc1 = self.encoder2(x1) 244 | x2 = hidden_states_out[1] 245 | enc2 = self.encoder3(x2) 246 | x3 = hidden_states_out[2] 247 | enc3 = self.encoder4(x3) 248 | x4 = hidden_states_out[3] 249 | enc4 = x4 250 | dec4 = hidden_states_out[4] 251 | dec4 = self.encoder10(dec4) 252 | dec3 = self.decoder5(dec4, enc4) 253 | dec2 = self.decoder4(dec3, enc3) 254 | dec1 = self.decoder3(dec2, enc2) 255 | dec0 = self.decoder2(dec1, enc1) 256 | out = self.decoder1(dec0, enc0) 257 | logits = self.out(out) 258 | return logits 259 | -------------------------------------------------------------------------------- /networks/transformerblock.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from networks.mlp import MLPBlock 16 | from networks.selfattention import SABlock 17 | 18 | 19 | class TransformerBlock(nn.Module): 20 | """ 21 | A transformer block, based on: "Dosovitskiy et al., 22 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 23 | """ 24 | 25 | def __init__( 26 | self, 27 | hidden_size: int, 28 | mlp_dim: int, 29 | num_heads: int, 30 | dropout_rate: float = 0.0, 31 | ) -> None: 32 | """ 33 | Args: 34 | hidden_size: dimension of hidden layer. 35 | mlp_dim: dimension of feedforward layer. 36 | num_heads: number of attention heads. 37 | dropout_rate: faction of the input units to drop. 
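# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): the TransformerBlock
# defined in this file is a pre-norm transformer layer -- LayerNorm ->
# self-attention -> residual add, then LayerNorm -> MLP -> residual add --
# so input and output shapes match.  hidden_size must be divisible by
# num_heads; sizes below are arbitrary examples.
import torch
from networks.transformerblock import TransformerBlock

block = TransformerBlock(hidden_size=768, mlp_dim=3072, num_heads=12, dropout_rate=0.0)
tokens = torch.randn(2, 216, 768)
out = block(tokens)
assert out.shape == tokens.shape
# ----------------------------------------------------------------------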
38 | 39 | """ 40 | 41 | super().__init__() 42 | 43 | if not (0 <= dropout_rate <= 1): 44 | raise AssertionError("dropout_rate should be between 0 and 1.") 45 | 46 | if hidden_size % num_heads != 0: 47 | raise AssertionError("hidden size should be divisible by num_heads.") 48 | 49 | self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) 50 | self.norm1 = nn.LayerNorm(hidden_size) 51 | self.attn = SABlock(hidden_size, num_heads, dropout_rate) 52 | self.norm2 = nn.LayerNorm(hidden_size) 53 | 54 | def loadFrom(self, weights, n_block): 55 | ROOT = f"module.transformer.blocks.{n_block}." 56 | block_names = ['mlp.linear1.weight', 'mlp.linear1.bias', 'mlp.linear2.weight', 'mlp.linear2.bias', 'norm1.weight',\ 57 | 'norm1.bias', 'attn.out_proj.weight', 'attn.out_proj.bias', 'attn.qkv.weight', 'norm2.weight',\ 58 | 'norm2.bias'] 59 | with torch.no_grad(): 60 | self.mlp.linear1.weight.copy_(weights['state_dict'][ROOT+block_names[0]]) 61 | self.mlp.linear1.bias.copy_(weights['state_dict'][ROOT+block_names[1]]) 62 | self.mlp.linear2.weight.copy_(weights['state_dict'][ROOT+block_names[2]]) 63 | self.mlp.linear2.bias.copy_(weights['state_dict'][ROOT+block_names[3]]) 64 | self.norm1.weight.copy_(weights['state_dict'][ROOT+block_names[4]]) 65 | self.norm1.bias.copy_(weights['state_dict'][ROOT+block_names[5]]) 66 | self.attn.out_proj.weight.copy_(weights['state_dict'][ROOT+block_names[6]]) 67 | self.attn.out_proj.bias.copy_(weights['state_dict'][ROOT+block_names[7]]) 68 | self.attn.qkv.weight.copy_(weights['state_dict'][ROOT+block_names[8]]) 69 | self.norm2.weight.copy_(weights['state_dict'][ROOT+block_names[9]]) 70 | self.norm2.bias.copy_(weights['state_dict'][ROOT+block_names[10]]) 71 | 72 | 73 | def forward(self, x): 74 | x = x + self.attn(self.norm1(x)) 75 | x = x + self.mlp(self.norm2(x)) 76 | return x 77 | -------------------------------------------------------------------------------- /networks/unetr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Tuple, Union 13 | import torch 14 | import torch.nn as nn 15 | 16 | from monai.networks.blocks.dynunet_block import UnetOutBlock 17 | from networks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 18 | from networks.vit import ViT 19 | 20 | 21 | class UNETR(nn.Module): 22 | """ 23 | UNETR based on: "Hatamizadeh et al., 24 | UNETR: Transformers for 3D Medical Image Segmentation " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: int, 31 | img_size: Tuple[int, int, int], 32 | feature_size: int = 16, 33 | hidden_size: int = 768, 34 | mlp_dim: int = 3072, 35 | num_heads: int = 12, 36 | pos_embed: str = "perceptron", 37 | norm_name: Union[Tuple, str] = "instance", 38 | conv_block: bool = False, 39 | res_block: bool = True, 40 | dropout_rate: float = 0.0, 41 | ) -> None: 42 | """ 43 | Args: 44 | in_channels: dimension of input channels. 
45 | out_channels: dimension of output channels. 46 | img_size: dimension of input image. 47 | feature_size: dimension of network feature size. 48 | hidden_size: dimension of hidden layer. 49 | mlp_dim: dimension of feedforward layer. 50 | num_heads: number of attention heads. 51 | pos_embed: position embedding layer type. 52 | norm_name: feature normalization type and arguments. 53 | conv_block: bool argument to determine if convolutional block is used. 54 | res_block: bool argument to determine if residual block is used. 55 | dropout_rate: faction of the input units to drop. 56 | 57 | Examples:: 58 | 59 | # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm 60 | >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') 61 | 62 | # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm 63 | >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') 64 | 65 | """ 66 | 67 | super().__init__() 68 | 69 | if not (0 <= dropout_rate <= 1): 70 | raise AssertionError("dropout_rate should be between 0 and 1.") 71 | 72 | if hidden_size % num_heads != 0: 73 | raise AssertionError("hidden size should be divisible by num_heads.") 74 | 75 | if pos_embed not in ["conv", "perceptron"]: 76 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 77 | 78 | self.num_layers = 12 79 | self.patch_size = (16, 16, 16) 80 | self.feat_size = ( 81 | img_size[0] // self.patch_size[0], 82 | img_size[1] // self.patch_size[1], 83 | img_size[2] // self.patch_size[2], 84 | ) 85 | self.hidden_size = hidden_size 86 | self.classification = False 87 | self.vit = ViT( 88 | in_channels=in_channels, 89 | img_size=img_size, 90 | patch_size=self.patch_size, 91 | hidden_size=hidden_size, 92 | mlp_dim=mlp_dim, 93 | num_layers=self.num_layers, 94 | num_heads=num_heads, 95 | pos_embed=pos_embed, 96 | classification=self.classification, 97 | dropout_rate=dropout_rate, 98 | ) 99 | self.encoder1 = UnetrBasicBlock( 100 | spatial_dims=3, 101 | in_channels=in_channels, 102 | out_channels=feature_size, 103 | kernel_size=3, 104 | stride=1, 105 | norm_name=norm_name, 106 | res_block=res_block, 107 | ) 108 | self.encoder2 = UnetrPrUpBlock( 109 | spatial_dims=3, 110 | in_channels=hidden_size, 111 | out_channels=feature_size * 2, 112 | num_layer=2, 113 | kernel_size=3, 114 | stride=1, 115 | upsample_kernel_size=2, 116 | norm_name=norm_name, 117 | conv_block=conv_block, 118 | res_block=res_block, 119 | ) 120 | self.encoder3 = UnetrPrUpBlock( 121 | spatial_dims=3, 122 | in_channels=hidden_size, 123 | out_channels=feature_size * 4, 124 | num_layer=1, 125 | kernel_size=3, 126 | stride=1, 127 | upsample_kernel_size=2, 128 | norm_name=norm_name, 129 | conv_block=conv_block, 130 | res_block=res_block, 131 | ) 132 | self.encoder4 = UnetrPrUpBlock( 133 | spatial_dims=3, 134 | in_channels=hidden_size, 135 | out_channels=feature_size * 8, 136 | num_layer=0, 137 | kernel_size=3, 138 | stride=1, 139 | upsample_kernel_size=2, 140 | norm_name=norm_name, 141 | conv_block=conv_block, 142 | res_block=res_block, 143 | ) 144 | self.decoder5 = UnetrUpBlock( 145 | spatial_dims=3, 146 | in_channels=hidden_size, 147 | out_channels=feature_size * 8, 148 | stride=1, 149 | kernel_size=3, 150 | upsample_kernel_size=2, 151 | norm_name=norm_name, 152 | res_block=res_block, 153 | ) 154 | self.decoder4 = UnetrUpBlock( 155 | spatial_dims=3, 156 
| in_channels=feature_size * 8, 157 | out_channels=feature_size * 4, 158 | stride=1, 159 | kernel_size=3, 160 | upsample_kernel_size=2, 161 | norm_name=norm_name, 162 | res_block=res_block, 163 | ) 164 | self.decoder3 = UnetrUpBlock( 165 | spatial_dims=3, 166 | in_channels=feature_size * 4, 167 | out_channels=feature_size * 2, 168 | stride=1, 169 | kernel_size=3, 170 | upsample_kernel_size=2, 171 | norm_name=norm_name, 172 | res_block=res_block, 173 | ) 174 | self.decoder2 = UnetrUpBlock( 175 | spatial_dims=3, 176 | in_channels=feature_size * 2, 177 | out_channels=feature_size, 178 | stride=1, 179 | kernel_size=3, 180 | upsample_kernel_size=2, 181 | norm_name=norm_name, 182 | res_block=res_block, 183 | ) 184 | self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore 185 | 186 | def proj_feat(self, x, hidden_size, feat_size): 187 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 188 | x = x.permute(0, 4, 1, 2, 3).contiguous() 189 | return x 190 | 191 | def load_from(self, weights): 192 | with torch.no_grad(): 193 | res_weight = weights 194 | # copy weights from patch embedding 195 | for i in weights['state_dict']: 196 | print(i) 197 | self.vit.patch_embedding.position_embeddings.copy_(weights['state_dict']['module.transformer.patch_embedding.position_embeddings_3d']) 198 | self.vit.patch_embedding.cls_token.copy_(weights['state_dict']['module.transformer.patch_embedding.cls_token']) 199 | self.vit.patch_embedding.patch_embeddings[1].weight.copy_(weights['state_dict']['module.transformer.patch_embedding.patch_embeddings.1.weight']) 200 | self.vit.patch_embedding.patch_embeddings[1].bias.copy_(weights['state_dict']['module.transformer.patch_embedding.patch_embeddings.1.bias']) 201 | 202 | # copy weights from encoding blocks (default: num of blocks: 12) 203 | for bname, block in self.vit.blocks.named_children(): 204 | print(block) 205 | block.loadFrom(weights, n_block=bname) 206 | # last norm layer of transformer 207 | self.vit.norm.weight.copy_(weights['state_dict']['module.transformer.norm.weight']) 208 | self.vit.norm.bias.copy_(weights['state_dict']['module.transformer.norm.bias']) 209 | 210 | 211 | def forward(self, x_in): 212 | x, hidden_states_out = self.vit(x_in) 213 | enc1 = self.encoder1(x_in) 214 | x2 = hidden_states_out[3] 215 | enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) 216 | x3 = hidden_states_out[6] 217 | enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) 218 | x4 = hidden_states_out[9] 219 | enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) 220 | dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) 221 | dec3 = self.decoder5(dec4, enc4) 222 | dec2 = self.decoder4(dec3, enc3) 223 | dec1 = self.decoder3(dec2, enc2) 224 | out = self.decoder2(dec1, enc1) 225 | logits = self.out(out) 226 | return logits 227 | -------------------------------------------------------------------------------- /networks/unetr_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
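# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): warm-starting the
# UNETR defined in networks/unetr.py above from a pre-trained ViT checkpoint
# via load_from().  The checkpoint is assumed to be a dict with a
# 'state_dict' entry whose keys start with 'module.transformer.' (the key
# names used inside load_from); the path below is a placeholder, not a file
# shipped with this repository.
import torch
from networks.unetr import UNETR

model = UNETR(in_channels=1, out_channels=2, img_size=(96, 96, 96))
ckpt = torch.load("path/to/pretrained_vit_weights.pt", map_location="cpu")  # hypothetical path
model.load_from(weights=ckpt)   # copies patch embedding, all 12 blocks and the final norm
# ----------------------------------------------------------------------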
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | from typing import Sequence, Tuple, Union 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer 19 | 20 | 21 | class UnetrUpBlock(nn.Module): 22 | """ 23 | An upsampling module that can be used for UNETR: "Hatamizadeh et al., 24 | UNETR: Transformers for 3D Medical Image Segmentation " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spatial_dims: int, 30 | in_channels: int, 31 | out_channels: int, # type: ignore 32 | kernel_size: Union[Sequence[int], int], 33 | stride: Union[Sequence[int], int], 34 | upsample_kernel_size: Union[Sequence[int], int], 35 | norm_name: Union[Tuple, str], 36 | res_block: bool = False, 37 | ) -> None: 38 | """ 39 | Args: 40 | spatial_dims: number of spatial dimensions. 41 | in_channels: number of input channels. 42 | out_channels: number of output channels. 43 | kernel_size: convolution kernel size. 44 | stride: convolution stride. 45 | upsample_kernel_size: convolution kernel size for transposed convolution layers. 46 | norm_name: feature normalization type and arguments. 47 | res_block: bool argument to determine if residual block is used. 48 | 49 | """ 50 | 51 | super(UnetrUpBlock, self).__init__() 52 | upsample_stride = upsample_kernel_size 53 | self.transp_conv = get_conv_layer( 54 | spatial_dims, 55 | in_channels, 56 | out_channels, 57 | kernel_size=upsample_kernel_size, 58 | stride=upsample_stride, 59 | conv_only=True, 60 | is_transposed=True, 61 | ) 62 | 63 | if res_block: 64 | self.conv_block = UnetResBlock( 65 | spatial_dims, 66 | out_channels + out_channels, 67 | out_channels, 68 | kernel_size=kernel_size, 69 | stride=1, 70 | norm_name=norm_name, 71 | ) 72 | else: 73 | self.conv_block = UnetBasicBlock( # type: ignore 74 | spatial_dims, 75 | out_channels + out_channels, 76 | out_channels, 77 | kernel_size=kernel_size, 78 | stride=1, 79 | norm_name=norm_name, 80 | ) 81 | 82 | def forward(self, inp, skip): 83 | # number of channels for skip should equals to out_channels 84 | out = self.transp_conv(inp) 85 | out = torch.cat((out, skip), dim=1) 86 | out = self.conv_block(out) 87 | return out 88 | 89 | 90 | class UnetrPrUpBlock(nn.Module): 91 | """ 92 | A projection upsampling module that can be used for UNETR: "Hatamizadeh et al., 93 | UNETR: Transformers for 3D Medical Image Segmentation " 94 | """ 95 | 96 | def __init__( 97 | self, 98 | spatial_dims: int, 99 | in_channels: int, 100 | out_channels: int, 101 | num_layer: int, 102 | kernel_size: Union[Sequence[int], int], 103 | stride: Union[Sequence[int], int], 104 | upsample_kernel_size: Union[Sequence[int], int], 105 | norm_name: Union[Tuple, str], 106 | conv_block: bool = False, 107 | res_block: bool = False, 108 | ) -> None: 109 | """ 110 | Args: 111 | spatial_dims: number of spatial dimensions. 112 | in_channels: number of input channels. 113 | out_channels: number of output channels. 114 | num_layer: number of upsampling blocks. 115 | kernel_size: convolution kernel size. 116 | stride: convolution stride. 
117 | upsample_kernel_size: convolution kernel size for transposed convolution layers. 118 | norm_name: feature normalization type and arguments. 119 | conv_block: bool argument to determine if convolutional block is used. 120 | res_block: bool argument to determine if residual block is used. 121 | 122 | """ 123 | 124 | super().__init__() 125 | 126 | upsample_stride = upsample_kernel_size 127 | self.transp_conv_init = get_conv_layer( 128 | spatial_dims, 129 | in_channels, 130 | out_channels, 131 | kernel_size=upsample_kernel_size, 132 | stride=upsample_stride, 133 | conv_only=True, 134 | is_transposed=True, 135 | ) 136 | if conv_block: 137 | if res_block: 138 | self.blocks = nn.ModuleList( 139 | [ 140 | nn.Sequential( 141 | get_conv_layer( 142 | spatial_dims, 143 | out_channels, 144 | out_channels, 145 | kernel_size=upsample_kernel_size, 146 | stride=upsample_stride, 147 | conv_only=True, 148 | is_transposed=True, 149 | ), 150 | UnetResBlock( 151 | spatial_dims=3, 152 | in_channels=out_channels, 153 | out_channels=out_channels, 154 | kernel_size=kernel_size, 155 | stride=stride, 156 | norm_name=norm_name, 157 | ), 158 | ) 159 | for i in range(num_layer) 160 | ] 161 | ) 162 | else: 163 | self.blocks = nn.ModuleList( 164 | [ 165 | nn.Sequential( 166 | get_conv_layer( 167 | spatial_dims, 168 | out_channels, 169 | out_channels, 170 | kernel_size=upsample_kernel_size, 171 | stride=upsample_stride, 172 | conv_only=True, 173 | is_transposed=True, 174 | ), 175 | UnetBasicBlock( 176 | spatial_dims=3, 177 | in_channels=out_channels, 178 | out_channels=out_channels, 179 | kernel_size=kernel_size, 180 | stride=stride, 181 | norm_name=norm_name, 182 | ), 183 | ) 184 | for i in range(num_layer) 185 | ] 186 | ) 187 | else: 188 | self.blocks = nn.ModuleList( 189 | [ 190 | get_conv_layer( 191 | spatial_dims, 192 | out_channels, 193 | out_channels, 194 | kernel_size=upsample_kernel_size, 195 | stride=upsample_stride, 196 | conv_only=True, 197 | is_transposed=True, 198 | ) 199 | for i in range(num_layer) 200 | ] 201 | ) 202 | 203 | def forward(self, x): 204 | x = self.transp_conv_init(x) 205 | for blk in self.blocks: 206 | x = blk(x) 207 | return x 208 | 209 | 210 | class UnetrBasicBlock(nn.Module): 211 | """ 212 | A CNN module that can be used for UNETR, based on: "Hatamizadeh et al., 213 | UNETR: Transformers for 3D Medical Image Segmentation " 214 | """ 215 | 216 | def __init__( 217 | self, 218 | spatial_dims: int, 219 | in_channels: int, 220 | out_channels: int, 221 | kernel_size: Union[Sequence[int], int], 222 | stride: Union[Sequence[int], int], 223 | norm_name: Union[Tuple, str], 224 | res_block: bool = False, 225 | ) -> None: 226 | """ 227 | Args: 228 | spatial_dims: number of spatial dimensions. 229 | in_channels: number of input channels. 230 | out_channels: number of output channels. 231 | kernel_size: convolution kernel size. 232 | stride: convolution stride. 233 | norm_name: feature normalization type and arguments. 234 | res_block: bool argument to determine if residual block is used. 
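# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): one decoder step of
# the UnetrUpBlock defined above.  A transposed convolution doubles the
# spatial size of the coarse input, the result is concatenated with the skip
# tensor (which must already have out_channels channels at the doubled
# resolution), and a conv block fuses the two.
import torch
from networks.unetr_block import UnetrUpBlock

up = UnetrUpBlock(
    spatial_dims=3,
    in_channels=64,
    out_channels=32,
    kernel_size=3,
    stride=1,
    upsample_kernel_size=2,
    norm_name="instance",
    res_block=True,
)
low = torch.randn(1, 64, 24, 24, 24)    # coarse decoder feature
skip = torch.randn(1, 32, 48, 48, 48)   # encoder skip at 2x the resolution
out = up(low, skip)
assert out.shape == (1, 32, 48, 48, 48)
# ----------------------------------------------------------------------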
235 | 236 | """ 237 | 238 | super().__init__() 239 | 240 | if res_block: 241 | self.layer = UnetResBlock( 242 | spatial_dims=spatial_dims, 243 | in_channels=in_channels, 244 | out_channels=out_channels, 245 | kernel_size=kernel_size, 246 | stride=stride, 247 | norm_name=norm_name, 248 | ) 249 | else: 250 | self.layer = UnetBasicBlock( # type: ignore 251 | spatial_dims=spatial_dims, 252 | in_channels=in_channels, 253 | out_channels=out_channels, 254 | kernel_size=kernel_size, 255 | stride=stride, 256 | norm_name=norm_name, 257 | ) 258 | 259 | def forward(self, inp): 260 | out = self.layer(inp) 261 | return out 262 | -------------------------------------------------------------------------------- /networks/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | from typing import Tuple 14 | 15 | import torch.nn as nn 16 | 17 | from networks.patchembedding import PatchEmbeddingBlock 18 | from networks.transformerblock import TransformerBlock 19 | 20 | 21 | class ViT(nn.Module): 22 | """ 23 | Vision Transformer (ViT), based on: "Dosovitskiy et al., 24 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | img_size: Tuple[int, int, int], 31 | patch_size: Tuple[int, int, int], 32 | hidden_size: int = 768, 33 | mlp_dim: int = 3072, 34 | num_layers: int = 12, 35 | num_heads: int = 12, 36 | pos_embed: str = "perceptron", 37 | classification: bool = False, 38 | num_classes: int = 2, 39 | dropout_rate: float = 0.0, 40 | ) -> None: 41 | """ 42 | Args: 43 | in_channels: dimension of input channels. 44 | img_size: dimension of input image. 45 | patch_size: dimension of patch size. 46 | hidden_size: dimension of hidden layer. 47 | mlp_dim: dimension of feedforward layer. 48 | num_layers: number of transformer blocks. 49 | num_heads: number of attention heads. 50 | pos_embed: position embedding layer type. 51 | classification: bool argument to determine if classification is used. 52 | num_classes: number of classes if classification is used. 53 | dropout_rate: faction of the input units to drop. 
54 | 55 | Examples:: 56 | 57 | # for single channel input with patch size of (96,96,96), conv position embedding and segmentation backbone 58 | >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') 59 | 60 | # for 3-channel with patch size of (128,128,128), 24 layers and classification backbone 61 | >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification= True) 62 | 63 | """ 64 | 65 | super().__init__() 66 | 67 | if not (0 <= dropout_rate <= 1): 68 | raise AssertionError("dropout_rate should be between 0 and 1.") 69 | 70 | if hidden_size % num_heads != 0: 71 | raise AssertionError("hidden size should be divisible by num_heads.") 72 | 73 | if pos_embed not in ["conv", "perceptron"]: 74 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 75 | 76 | self.classification = classification 77 | self.patch_embedding = PatchEmbeddingBlock( 78 | in_channels, img_size, patch_size, hidden_size, num_heads, pos_embed, dropout_rate 79 | ) 80 | self.blocks = nn.ModuleList( 81 | [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] 82 | ) 83 | self.norm = nn.LayerNorm(hidden_size) 84 | if self.classification: 85 | self.classification_head = nn.Linear(hidden_size, num_classes) 86 | 87 | def forward(self, x): 88 | x = self.patch_embedding(x) 89 | hidden_states_out = [] 90 | for blk in self.blocks: 91 | x = blk(x) 92 | hidden_states_out.append(x) 93 | x = self.norm(x) 94 | if self.classification: 95 | x = self.classification_head(x[:, 0]) 96 | return x, hidden_states_out 97 | -------------------------------------------------------------------------------- /networks2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/networks2/__init__.py -------------------------------------------------------------------------------- /networks2/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import torch.nn as nn 13 | 14 | 15 | class MLPBlock(nn.Module): 16 | """ 17 | A multi-layer perceptron block, based on: "Dosovitskiy et al., 18 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 19 | """ 20 | 21 | def __init__( 22 | self, 23 | hidden_size: int, 24 | mlp_dim: int, 25 | dropout_rate: float = 0.0, 26 | ) -> None: 27 | """ 28 | Args: 29 | hidden_size: dimension of hidden layer. 30 | mlp_dim: dimension of feedforward layer. 31 | dropout_rate: faction of the input units to drop. 
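# ----------------------------------------------------------------------
# Editor's illustrative sketch (not part of this file): the ViT in
# networks/vit.py above returns both the final token sequence and the list
# of per-block hidden states; UNETR (networks/unetr.py) taps entries 3, 6
# and 9 of that list and reshapes each (B, N, hidden) token tensor into a
# 3D feature map with proj_feat.  The same reshape is reproduced below.
import torch
from networks.vit import ViT

vit = ViT(in_channels=1, img_size=(96, 96, 96), patch_size=(16, 16, 16))
x = torch.randn(1, 1, 96, 96, 96)
tokens, hidden_states = vit(x)
assert len(hidden_states) == 12          # one entry per transformer block
assert tokens.shape == (1, 216, 768)     # (96 // 16) ** 3 = 216 patches

# the reshape UNETR.proj_feat applies to a tapped hidden state:
feat = hidden_states[9].view(1, 6, 6, 6, 768).permute(0, 4, 1, 2, 3)
assert feat.shape == (1, 768, 6, 6, 6)   # (B, hidden, D/16, H/16, W/16)
# ----------------------------------------------------------------------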
32 | 33 | """ 34 | 35 | super().__init__() 36 | 37 | if not (0 <= dropout_rate <= 1): 38 | raise AssertionError("dropout_rate should be between 0 and 1.") 39 | 40 | self.linear1 = nn.Linear(hidden_size, mlp_dim) 41 | self.linear2 = nn.Linear(mlp_dim, hidden_size) 42 | self.fn = nn.GELU() 43 | self.drop1 = nn.Dropout(dropout_rate) 44 | self.drop2 = nn.Dropout(dropout_rate) 45 | 46 | def forward(self, x): 47 | x = self.fn(self.linear1(x)) 48 | x = self.drop1(x) 49 | x = self.linear2(x) 50 | x = self.drop2(x) 51 | return x 52 | -------------------------------------------------------------------------------- /networks2/patchembedding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | import math 14 | from typing import Tuple, Union 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | from monai.utils import optional_import 20 | 21 | Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") 22 | 23 | 24 | class PatchEmbeddingBlock(nn.Module): 25 | """ 26 | A patch embedding block, based on: "Dosovitskiy et al., 27 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 28 | """ 29 | 30 | def __init__( 31 | self, 32 | in_channels: int, 33 | img_size: Tuple[int, int, int], 34 | patch_size: Tuple[int, int, int], 35 | hidden_size: int, 36 | num_heads: int, 37 | pos_embed: str, 38 | dropout_rate: float = 0.0, 39 | ) -> None: 40 | """ 41 | Args: 42 | in_channels: dimension of input channels. 43 | img_size: dimension of input image. 44 | patch_size: dimension of patch size. 45 | hidden_size: dimension of hidden layer. 46 | num_heads: number of attention heads. 47 | pos_embed: position embedding layer type. 48 | dropout_rate: faction of the input units to drop. 
49 | 50 | """ 51 | 52 | super().__init__() 53 | 54 | if not (0 <= dropout_rate <= 1): 55 | raise AssertionError("dropout_rate should be between 0 and 1.") 56 | 57 | if hidden_size % num_heads != 0: 58 | raise AssertionError("hidden size should be divisible by num_heads.") 59 | 60 | for m, p in zip(img_size, patch_size): 61 | if m < p: 62 | raise AssertionError("patch_size should be smaller than img_size.") 63 | 64 | if pos_embed not in ["conv", "perceptron"]: 65 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 66 | 67 | if pos_embed == "perceptron": 68 | if img_size[0] % patch_size[0] != 0: 69 | raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.") 70 | 71 | self.n_patches = ( 72 | (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2]) 73 | ) 74 | self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2] 75 | 76 | self.pos_embed = pos_embed 77 | self.patch_embeddings: Union[nn.Conv3d, nn.Sequential] 78 | if self.pos_embed == "conv": 79 | self.patch_embeddings = nn.Conv3d( 80 | in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size 81 | ) 82 | elif self.pos_embed == "perceptron": 83 | self.patch_embeddings = nn.Sequential( 84 | Rearrange( 85 | "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)", 86 | p1=patch_size[0], 87 | p2=patch_size[1], 88 | p3=patch_size[2], 89 | ), 90 | nn.Linear(self.patch_dim, hidden_size), 91 | ) 92 | self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) 93 | self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) 94 | self.dropout = nn.Dropout(dropout_rate) 95 | self.trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0) 96 | self.apply(self._init_weights) 97 | 98 | def _init_weights(self, m): 99 | if isinstance(m, nn.Linear): 100 | self.trunc_normal_(m.weight, mean=0.0, std=0.02, a=-2.0, b=2.0) 101 | if isinstance(m, nn.Linear) and m.bias is not None: 102 | nn.init.constant_(m.bias, 0) 103 | elif isinstance(m, nn.LayerNorm): 104 | nn.init.constant_(m.bias, 0) 105 | nn.init.constant_(m.weight, 1.0) 106 | 107 | def trunc_normal_(self, tensor, mean, std, a, b): 108 | # From PyTorch official master until it's in a few official releases - RW 109 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 110 | def norm_cdf(x): 111 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 112 | 113 | with torch.no_grad(): 114 | l = norm_cdf((a - mean) / std) 115 | u = norm_cdf((b - mean) / std) 116 | tensor.uniform_(2 * l - 1, 2 * u - 1) 117 | tensor.erfinv_() 118 | tensor.mul_(std * math.sqrt(2.0)) 119 | tensor.add_(mean) 120 | tensor.clamp_(min=a, max=b) 121 | return tensor 122 | 123 | def forward(self, x): 124 | if self.pos_embed == "conv": 125 | x = self.patch_embeddings(x) 126 | x = x.flatten(2) 127 | x = x.transpose(-1, -2) 128 | elif self.pos_embed == "perceptron": 129 | x = self.patch_embeddings(x) 130 | embeddings = x + self.position_embeddings 131 | embeddings = self.dropout(embeddings) 132 | return embeddings 133 | -------------------------------------------------------------------------------- /networks2/selfattention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from monai.utils import optional_import 16 | 17 | einops, has_einops = optional_import("einops") 18 | 19 | 20 | class SABlock(nn.Module): 21 | """ 22 | A self-attention block, based on: "Dosovitskiy et al., 23 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 24 | """ 25 | 26 | def __init__( 27 | self, 28 | hidden_size: int, 29 | num_heads: int, 30 | dropout_rate: float = 0.0, 31 | ) -> None: 32 | """ 33 | Args: 34 | hidden_size: dimension of hidden layer. 35 | num_heads: number of attention heads. 36 | dropout_rate: faction of the input units to drop. 37 | 38 | """ 39 | 40 | super().__init__() 41 | 42 | if not (0 <= dropout_rate <= 1): 43 | raise AssertionError("dropout_rate should be between 0 and 1.") 44 | 45 | if hidden_size % num_heads != 0: 46 | raise AssertionError("hidden size should be divisible by num_heads.") 47 | 48 | self.num_heads = num_heads 49 | self.out_proj = nn.Linear(hidden_size, hidden_size) 50 | self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False) 51 | self.drop_output = nn.Dropout(dropout_rate) 52 | self.drop_weights = nn.Dropout(dropout_rate) 53 | self.head_dim = hidden_size // num_heads 54 | self.scale = self.head_dim ** -0.5 55 | if has_einops: 56 | self.rearrange = einops.rearrange 57 | else: 58 | raise ValueError('"Requires einops.') 59 | 60 | def forward(self, x): 61 | q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) 62 | att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) 63 | att_mat = self.drop_weights(att_mat) 64 | x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) 65 | x = self.rearrange(x, "b h l d -> b l (h d)") 66 | x = self.out_proj(x) 67 | x = self.drop_output(x) 68 | return x 69 | -------------------------------------------------------------------------------- /networks2/transformerblock.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 
11 | 12 | import torch.nn as nn 13 | 14 | from monai.networks.blocks.mlp import MLPBlock 15 | from monai.networks.blocks.selfattention import SABlock 16 | 17 | 18 | class TransformerBlock(nn.Module): 19 | """ 20 | A transformer block, based on: "Dosovitskiy et al., 21 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 22 | """ 23 | 24 | def __init__( 25 | self, 26 | hidden_size: int, 27 | mlp_dim: int, 28 | num_heads: int, 29 | dropout_rate: float = 0.0, 30 | ) -> None: 31 | """ 32 | Args: 33 | hidden_size: dimension of hidden layer. 34 | mlp_dim: dimension of feedforward layer. 35 | num_heads: number of attention heads. 36 | dropout_rate: fraction of the input units to drop. 37 | 38 | """ 39 | 40 | super().__init__() 41 | 42 | if not (0 <= dropout_rate <= 1): 43 | raise AssertionError("dropout_rate should be between 0 and 1.") 44 | 45 | if hidden_size % num_heads != 0: 46 | raise AssertionError("hidden size should be divisible by num_heads.") 47 | 48 | self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate) 49 | self.norm1 = nn.LayerNorm(hidden_size) 50 | self.attn = SABlock(hidden_size, num_heads, dropout_rate) 51 | self.norm2 = nn.LayerNorm(hidden_size) 52 | 53 | def forward(self, x): 54 | x = x + self.attn(self.norm1(x)) 55 | x = x + self.mlp(self.norm2(x)) 56 | return x 57 | -------------------------------------------------------------------------------- /networks2/unetr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Tuple, Union 13 | 14 | import torch.nn as nn 15 | 16 | from monai.networks.blocks.dynunet_block import UnetOutBlock 17 | from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock 18 | from monai.networks.nets.vit import ViT 19 | 20 | 21 | class UNETR(nn.Module): 22 | """ 23 | UNETR based on: "Hatamizadeh et al., 24 | UNETR: Transformers for 3D Medical Image Segmentation " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | out_channels: int, 31 | img_size: Tuple[int, int, int], 32 | feature_size: int = 16, 33 | hidden_size: int = 768, 34 | mlp_dim: int = 3072, 35 | num_heads: int = 12, 36 | pos_embed: str = "perceptron", 37 | norm_name: Union[Tuple, str] = "instance", 38 | conv_block: bool = False, 39 | res_block: bool = True, 40 | dropout_rate: float = 0.0, 41 | ) -> None: 42 | """ 43 | Args: 44 | in_channels: dimension of input channels. 45 | out_channels: dimension of output channels. 46 | img_size: dimension of input image. 47 | feature_size: dimension of network feature size. 48 | hidden_size: dimension of hidden layer. 49 | mlp_dim: dimension of feedforward layer. 50 | num_heads: number of attention heads. 51 | pos_embed: position embedding layer type. 52 | norm_name: feature normalization type and arguments. 53 | conv_block: bool argument to determine if convolutional block is used.
54 | res_block: bool argument to determine if residual block is used. 55 | dropout_rate: fraction of the input units to drop. 56 | 57 | Examples:: 58 | 59 | # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm 60 | >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') 61 | 62 | # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm 63 | >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') 64 | 65 | """ 66 | 67 | super().__init__() 68 | 69 | if not (0 <= dropout_rate <= 1): 70 | raise AssertionError("dropout_rate should be between 0 and 1.") 71 | 72 | if hidden_size % num_heads != 0: 73 | raise AssertionError("hidden size should be divisible by num_heads.") 74 | 75 | if pos_embed not in ["conv", "perceptron"]: 76 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 77 | 78 | self.num_layers = 12 79 | self.patch_size = (16, 16, 16) 80 | self.feat_size = ( 81 | img_size[0] // self.patch_size[0], 82 | img_size[1] // self.patch_size[1], 83 | img_size[2] // self.patch_size[2], 84 | ) 85 | self.hidden_size = hidden_size 86 | self.classification = False 87 | self.vit = ViT( 88 | in_channels=in_channels, 89 | img_size=img_size, 90 | patch_size=self.patch_size, 91 | hidden_size=hidden_size, 92 | mlp_dim=mlp_dim, 93 | num_layers=self.num_layers, 94 | num_heads=num_heads, 95 | pos_embed=pos_embed, 96 | classification=self.classification, 97 | dropout_rate=dropout_rate, 98 | ) 99 | self.encoder1 = UnetrBasicBlock( 100 | spatial_dims=3, 101 | in_channels=in_channels, 102 | out_channels=feature_size, 103 | kernel_size=3, 104 | stride=1, 105 | norm_name=norm_name, 106 | res_block=res_block, 107 | ) 108 | self.encoder2 = UnetrPrUpBlock( 109 | spatial_dims=3, 110 | in_channels=hidden_size, 111 | out_channels=feature_size * 2, 112 | num_layer=2, 113 | kernel_size=3, 114 | stride=1, 115 | upsample_kernel_size=2, 116 | norm_name=norm_name, 117 | conv_block=conv_block, 118 | res_block=res_block, 119 | ) 120 | self.encoder3 = UnetrPrUpBlock( 121 | spatial_dims=3, 122 | in_channels=hidden_size, 123 | out_channels=feature_size * 4, 124 | num_layer=1, 125 | kernel_size=3, 126 | stride=1, 127 | upsample_kernel_size=2, 128 | norm_name=norm_name, 129 | conv_block=conv_block, 130 | res_block=res_block, 131 | ) 132 | self.encoder4 = UnetrPrUpBlock( 133 | spatial_dims=3, 134 | in_channels=hidden_size, 135 | out_channels=feature_size * 8, 136 | num_layer=0, 137 | kernel_size=3, 138 | stride=1, 139 | upsample_kernel_size=2, 140 | norm_name=norm_name, 141 | conv_block=conv_block, 142 | res_block=res_block, 143 | ) 144 | self.decoder5 = UnetrUpBlock( 145 | spatial_dims=3, 146 | in_channels=hidden_size, 147 | out_channels=feature_size * 8, 148 | stride=1, 149 | kernel_size=3, 150 | upsample_kernel_size=2, 151 | norm_name=norm_name, 152 | res_block=res_block, 153 | ) 154 | self.decoder4 = UnetrUpBlock( 155 | spatial_dims=3, 156 | in_channels=feature_size * 8, 157 | out_channels=feature_size * 4, 158 | stride=1, 159 | kernel_size=3, 160 | upsample_kernel_size=2, 161 | norm_name=norm_name, 162 | res_block=res_block, 163 | ) 164 | self.decoder3 = UnetrUpBlock( 165 | spatial_dims=3, 166 | in_channels=feature_size * 4, 167 | out_channels=feature_size * 2, 168 | stride=1, 169 | kernel_size=3, 170 | upsample_kernel_size=2, 171 | norm_name=norm_name, 172 | res_block=res_block, 173 | )
174 | self.decoder2 = UnetrUpBlock( 175 | spatial_dims=3, 176 | in_channels=feature_size * 2, 177 | out_channels=feature_size, 178 | stride=1, 179 | kernel_size=3, 180 | upsample_kernel_size=2, 181 | norm_name=norm_name, 182 | res_block=res_block, 183 | ) 184 | self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size, out_channels=out_channels) # type: ignore 185 | 186 | def proj_feat(self, x, hidden_size, feat_size): 187 | x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) 188 | x = x.permute(0, 4, 1, 2, 3).contiguous() 189 | return x 190 | 191 | def forward(self, x_in): 192 | x, hidden_states_out = self.vit(x_in) 193 | enc1 = self.encoder1(x_in) 194 | x2 = hidden_states_out[3] 195 | enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) 196 | x3 = hidden_states_out[6] 197 | enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) 198 | x4 = hidden_states_out[9] 199 | enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) 200 | dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) 201 | dec3 = self.decoder5(dec4, enc4) 202 | dec2 = self.decoder4(dec3, enc3) 203 | dec1 = self.decoder3(dec2, enc2) 204 | out = self.decoder2(dec1, enc1) 205 | logits = self.out(out) 206 | return logits 207 | -------------------------------------------------------------------------------- /networks2/unetr_block.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | from typing import Sequence, Tuple, Union 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer 19 | 20 | 21 | class UnetrUpBlock(nn.Module): 22 | """ 23 | An upsampling module that can be used for UNETR: "Hatamizadeh et al., 24 | UNETR: Transformers for 3D Medical Image Segmentation " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | spatial_dims: int, 30 | in_channels: int, 31 | out_channels: int, # type: ignore 32 | kernel_size: Union[Sequence[int], int], 33 | stride: Union[Sequence[int], int], 34 | upsample_kernel_size: Union[Sequence[int], int], 35 | norm_name: Union[Tuple, str], 36 | res_block: bool = False, 37 | ) -> None: 38 | """ 39 | Args: 40 | spatial_dims: number of spatial dimensions. 41 | in_channels: number of input channels. 42 | out_channels: number of output channels. 43 | kernel_size: convolution kernel size. 44 | stride: convolution stride. 45 | upsample_kernel_size: convolution kernel size for transposed convolution layers. 46 | norm_name: feature normalization type and arguments. 47 | res_block: bool argument to determine if residual block is used. 
48 | 49 | """ 50 | 51 | super(UnetrUpBlock, self).__init__() 52 | upsample_stride = upsample_kernel_size 53 | self.transp_conv = get_conv_layer( 54 | spatial_dims, 55 | in_channels, 56 | out_channels, 57 | kernel_size=upsample_kernel_size, 58 | stride=upsample_stride, 59 | conv_only=True, 60 | is_transposed=True, 61 | ) 62 | 63 | if res_block: 64 | self.conv_block = UnetResBlock( 65 | spatial_dims, 66 | out_channels + out_channels, 67 | out_channels, 68 | kernel_size=kernel_size, 69 | stride=1, 70 | norm_name=norm_name, 71 | ) 72 | else: 73 | self.conv_block = UnetBasicBlock( # type: ignore 74 | spatial_dims, 75 | out_channels + out_channels, 76 | out_channels, 77 | kernel_size=kernel_size, 78 | stride=1, 79 | norm_name=norm_name, 80 | ) 81 | 82 | def forward(self, inp, skip): 83 | # the number of channels of skip should equal out_channels 84 | out = self.transp_conv(inp) 85 | out = torch.cat((out, skip), dim=1) 86 | out = self.conv_block(out) 87 | return out 88 | 89 | 90 | class UnetrPrUpBlock(nn.Module): 91 | """ 92 | A projection upsampling module that can be used for UNETR: "Hatamizadeh et al., 93 | UNETR: Transformers for 3D Medical Image Segmentation " 94 | """ 95 | 96 | def __init__( 97 | self, 98 | spatial_dims: int, 99 | in_channels: int, 100 | out_channels: int, 101 | num_layer: int, 102 | kernel_size: Union[Sequence[int], int], 103 | stride: Union[Sequence[int], int], 104 | upsample_kernel_size: Union[Sequence[int], int], 105 | norm_name: Union[Tuple, str], 106 | conv_block: bool = False, 107 | res_block: bool = False, 108 | ) -> None: 109 | """ 110 | Args: 111 | spatial_dims: number of spatial dimensions. 112 | in_channels: number of input channels. 113 | out_channels: number of output channels. 114 | num_layer: number of upsampling blocks. 115 | kernel_size: convolution kernel size. 116 | stride: convolution stride. 117 | upsample_kernel_size: convolution kernel size for transposed convolution layers. 118 | norm_name: feature normalization type and arguments. 119 | conv_block: bool argument to determine if convolutional block is used. 120 | res_block: bool argument to determine if residual block is used.
121 | 122 | """ 123 | 124 | super().__init__() 125 | 126 | upsample_stride = upsample_kernel_size 127 | self.transp_conv_init = get_conv_layer( 128 | spatial_dims, 129 | in_channels, 130 | out_channels, 131 | kernel_size=upsample_kernel_size, 132 | stride=upsample_stride, 133 | conv_only=True, 134 | is_transposed=True, 135 | ) 136 | if conv_block: 137 | if res_block: 138 | self.blocks = nn.ModuleList( 139 | [ 140 | nn.Sequential( 141 | get_conv_layer( 142 | spatial_dims, 143 | out_channels, 144 | out_channels, 145 | kernel_size=upsample_kernel_size, 146 | stride=upsample_stride, 147 | conv_only=True, 148 | is_transposed=True, 149 | ), 150 | UnetResBlock( 151 | spatial_dims=3, 152 | in_channels=out_channels, 153 | out_channels=out_channels, 154 | kernel_size=kernel_size, 155 | stride=stride, 156 | norm_name=norm_name, 157 | ), 158 | ) 159 | for i in range(num_layer) 160 | ] 161 | ) 162 | else: 163 | self.blocks = nn.ModuleList( 164 | [ 165 | nn.Sequential( 166 | get_conv_layer( 167 | spatial_dims, 168 | out_channels, 169 | out_channels, 170 | kernel_size=upsample_kernel_size, 171 | stride=upsample_stride, 172 | conv_only=True, 173 | is_transposed=True, 174 | ), 175 | UnetBasicBlock( 176 | spatial_dims=3, 177 | in_channels=out_channels, 178 | out_channels=out_channels, 179 | kernel_size=kernel_size, 180 | stride=stride, 181 | norm_name=norm_name, 182 | ), 183 | ) 184 | for i in range(num_layer) 185 | ] 186 | ) 187 | else: 188 | self.blocks = nn.ModuleList( 189 | [ 190 | get_conv_layer( 191 | spatial_dims, 192 | out_channels, 193 | out_channels, 194 | kernel_size=upsample_kernel_size, 195 | stride=upsample_stride, 196 | conv_only=True, 197 | is_transposed=True, 198 | ) 199 | for i in range(num_layer) 200 | ] 201 | ) 202 | 203 | def forward(self, x): 204 | x = self.transp_conv_init(x) 205 | for blk in self.blocks: 206 | x = blk(x) 207 | return x 208 | 209 | 210 | class UnetrBasicBlock(nn.Module): 211 | """ 212 | A CNN module that can be used for UNETR, based on: "Hatamizadeh et al., 213 | UNETR: Transformers for 3D Medical Image Segmentation " 214 | """ 215 | 216 | def __init__( 217 | self, 218 | spatial_dims: int, 219 | in_channels: int, 220 | out_channels: int, 221 | kernel_size: Union[Sequence[int], int], 222 | stride: Union[Sequence[int], int], 223 | norm_name: Union[Tuple, str], 224 | res_block: bool = False, 225 | ) -> None: 226 | """ 227 | Args: 228 | spatial_dims: number of spatial dimensions. 229 | in_channels: number of input channels. 230 | out_channels: number of output channels. 231 | kernel_size: convolution kernel size. 232 | stride: convolution stride. 233 | norm_name: feature normalization type and arguments. 234 | res_block: bool argument to determine if residual block is used. 
235 | 236 | """ 237 | 238 | super().__init__() 239 | 240 | if res_block: 241 | self.layer = UnetResBlock( 242 | spatial_dims=spatial_dims, 243 | in_channels=in_channels, 244 | out_channels=out_channels, 245 | kernel_size=kernel_size, 246 | stride=stride, 247 | norm_name=norm_name, 248 | ) 249 | else: 250 | self.layer = UnetBasicBlock( # type: ignore 251 | spatial_dims=spatial_dims, 252 | in_channels=in_channels, 253 | out_channels=out_channels, 254 | kernel_size=kernel_size, 255 | stride=stride, 256 | norm_name=norm_name, 257 | ) 258 | 259 | def forward(self, inp): 260 | out = self.layer(inp) 261 | return out 262 | -------------------------------------------------------------------------------- /networks2/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | 13 | from typing import Tuple 14 | 15 | import torch.nn as nn 16 | 17 | from monai.networks.blocks.patchembedding import PatchEmbeddingBlock 18 | from monai.networks.blocks.transformerblock import TransformerBlock 19 | 20 | 21 | class ViT(nn.Module): 22 | """ 23 | Vision Transformer (ViT), based on: "Dosovitskiy et al., 24 | An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale " 25 | """ 26 | 27 | def __init__( 28 | self, 29 | in_channels: int, 30 | img_size: Tuple[int, int, int], 31 | patch_size: Tuple[int, int, int], 32 | hidden_size: int = 768, 33 | mlp_dim: int = 3072, 34 | num_layers: int = 12, 35 | num_heads: int = 12, 36 | pos_embed: str = "perceptron", 37 | classification: bool = False, 38 | num_classes: int = 2, 39 | dropout_rate: float = 0.0, 40 | ) -> None: 41 | """ 42 | Args: 43 | in_channels: dimension of input channels. 44 | img_size: dimension of input image. 45 | patch_size: dimension of patch size. 46 | hidden_size: dimension of hidden layer. 47 | mlp_dim: dimension of feedforward layer. 48 | num_layers: number of transformer blocks. 49 | num_heads: number of attention heads. 50 | pos_embed: position embedding layer type. 51 | classification: bool argument to determine if classification is used. 52 | num_classes: number of classes if classification is used. 53 | dropout_rate: fraction of the input units to drop.
54 | 55 | Examples:: 56 | 57 | # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone 58 | >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv') 59 | 60 | # for 3-channel input with image size of (128,128,128), conv position embedding and classification backbone 61 | >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True) 62 | 63 | """ 64 | 65 | super().__init__() 66 | 67 | if not (0 <= dropout_rate <= 1): 68 | raise AssertionError("dropout_rate should be between 0 and 1.") 69 | 70 | if hidden_size % num_heads != 0: 71 | raise AssertionError("hidden size should be divisible by num_heads.") 72 | 73 | if pos_embed not in ["conv", "perceptron"]: 74 | raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") 75 | 76 | self.classification = classification 77 | self.patch_embedding = PatchEmbeddingBlock( 78 | in_channels, img_size, patch_size, hidden_size, num_heads, pos_embed, dropout_rate 79 | ) 80 | self.blocks = nn.ModuleList( 81 | [TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate) for i in range(num_layers)] 82 | ) 83 | self.norm = nn.LayerNorm(hidden_size) 84 | if self.classification: 85 | self.classification_head = nn.Linear(hidden_size, num_classes) 86 | 87 | def forward(self, x): 88 | x = self.patch_embedding(x) 89 | hidden_states_out = [] 90 | for blk in self.blocks: 91 | x = blk(x) 92 | hidden_states_out.append(x) 93 | x = self.norm(x) 94 | if self.classification: 95 | x = self.classification_head(x[:, 0]) 96 | return x, hidden_states_out 97 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MrGiovanni/SyntheticTumors/983dbeef8fb6f5d15ec4373a5241fc198b92d765/optimizers/__init__.py -------------------------------------------------------------------------------- /optimizers/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | import math 13 | import warnings 14 | from typing import List 15 | 16 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 17 | from torch import nn as nn 18 | from torch.optim import Adam, Optimizer 19 | from torch.optim.lr_scheduler import _LRScheduler 20 | 21 | 22 | __all__ = ["LinearLR", "ExponentialLR"] 23 | 24 | 25 | class _LRSchedulerMONAI(_LRScheduler): 26 | """Base class for increasing the learning rate between two boundaries over a number 27 | of iterations""" 28 | 29 | def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1) -> None: 30 | """ 31 | Args: 32 | optimizer: wrapped optimizer. 33 | end_lr: the final learning rate. 34 | num_iter: the number of iterations over which the test occurs. 35 | last_epoch: the index of last epoch.
36 | Returns: 37 | None 38 | """ 39 | self.end_lr = end_lr 40 | self.num_iter = num_iter 41 | super(_LRSchedulerMONAI, self).__init__(optimizer, last_epoch) 42 | 43 | 44 | class LinearLR(_LRSchedulerMONAI): 45 | """Linearly increases the learning rate between two boundaries over a number of 46 | iterations. 47 | """ 48 | 49 | def get_lr(self): 50 | r = self.last_epoch / (self.num_iter - 1) 51 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 52 | 53 | 54 | class ExponentialLR(_LRSchedulerMONAI): 55 | """Exponentially increases the learning rate between two boundaries over a number of 56 | iterations. 57 | """ 58 | 59 | def get_lr(self): 60 | r = self.last_epoch / (self.num_iter - 1) 61 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 62 | 63 | 64 | class WarmupCosineSchedule(LambdaLR): 65 | """Linear warmup and then cosine decay. 66 | Based on https://huggingface.co/ implementation. 67 | """ 68 | 69 | def __init__( 70 | self, optimizer: Optimizer, warmup_steps: int, t_total: int, cycles: float = 0.5, last_epoch: int = -1 71 | ) -> None: 72 | """ 73 | Args: 74 | optimizer: wrapped optimizer. 75 | warmup_steps: number of warmup iterations. 76 | t_total: total number of training iterations. 77 | cycles: cosine cycles parameter. 78 | last_epoch: the index of last epoch. 79 | Returns: 80 | None 81 | """ 82 | self.warmup_steps = warmup_steps 83 | self.t_total = t_total 84 | self.cycles = cycles 85 | super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch) 86 | 87 | def lr_lambda(self, step): 88 | if step < self.warmup_steps: 89 | return float(step) / float(max(1.0, self.warmup_steps)) 90 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 91 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 92 | 93 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 94 | 95 | def __init__( 96 | self, 97 | optimizer: Optimizer, 98 | warmup_epochs: int, 99 | max_epochs: int, 100 | warmup_start_lr: float = 0.0, 101 | eta_min: float = 0.0, 102 | last_epoch: int = -1, 103 | ) -> None: 104 | """ 105 | Args: 106 | optimizer (Optimizer): Wrapped optimizer. 107 | warmup_epochs (int): Maximum number of iterations for linear warmup 108 | max_epochs (int): Maximum number of iterations 109 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 110 | eta_min (float): Minimum learning rate. Default: 0. 111 | last_epoch (int): The index of last epoch. Default: -1. 
112 | """ 113 | self.warmup_epochs = warmup_epochs 114 | self.max_epochs = max_epochs 115 | self.warmup_start_lr = warmup_start_lr 116 | self.eta_min = eta_min 117 | 118 | super(LinearWarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch) 119 | 120 | def get_lr(self) -> List[float]: 121 | """ 122 | Compute learning rate using chainable form of the scheduler 123 | """ 124 | if not self._get_lr_called_within_step: 125 | warnings.warn( 126 | "To get the last learning rate computed by the scheduler, " 127 | "please use `get_last_lr()`.", 128 | UserWarning, 129 | ) 130 | 131 | if self.last_epoch == 0: 132 | return [self.warmup_start_lr] * len(self.base_lrs) 133 | elif self.last_epoch < self.warmup_epochs: 134 | return [ 135 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 136 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 137 | ] 138 | elif self.last_epoch == self.warmup_epochs: 139 | return self.base_lrs 140 | elif (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 141 | return [ 142 | group["lr"] + (base_lr - self.eta_min) * 143 | (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 144 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 145 | ] 146 | 147 | return [ 148 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) / 149 | ( 150 | 1 + 151 | math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)) 152 | ) * (group["lr"] - self.eta_min) + self.eta_min for group in self.optimizer.param_groups 153 | ] 154 | 155 | def _get_closed_form_lr(self) -> List[float]: 156 | """ 157 | Called when epoch is passed as a param to the `step` function of the scheduler. 158 | """ 159 | if self.last_epoch < self.warmup_epochs: 160 | return [ 161 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 162 | for base_lr in self.base_lrs 163 | ] 164 | 165 | return [ 166 | self.eta_min + 0.5 * (base_lr - self.eta_min) * 167 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 168 | for base_lr in self.base_lrs 169 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | monai==0.9.0 2 | nibabel 3 | tensorboard 4 | tensorboardX 5 | ml-collections 6 | einops 7 | timm 8 | opencv-python 9 | elasticdeform 10 | numpy 11 | nibabel 12 | scipy 13 | tqdm 14 | matplotlib 15 | pandas 16 | glob2 -------------------------------------------------------------------------------- /transfer_label.py: -------------------------------------------------------------------------------- 1 | import os, shutil, argparse 2 | import nibabel as nib 3 | import numpy as np 4 | from glob import glob 5 | 6 | parser = argparse.ArgumentParser(description='main') 7 | parser.add_argument('--data_path', default=None, type=str, required=True, help="Path to the dataset directory") 8 | 9 | args = parser.parse_args() 10 | def main(data_set): 11 | # Our synthesis code only deals with labels: "0: background, 1:liver" 12 | # Different datasets use a different index for the liver (e.g. 
TCIA dataset liver: 6) 13 | # This code can convert other labels into "0: background, 1:liver" 14 | data_path = args.data_path 15 | 16 | # make label dir 17 | save_dir = os.path.join(data_path, 'label') 18 | if not os.path.exists(save_dir): 19 | os.makedirs(save_dir) 20 | 21 | if data_set == 'btcv': 22 | label_root = os.path.join(data_path, '01_Multi-Atlas_Labeling/label/') 23 | seg_files = sorted(glob(os.path.join(label_root, '*.nii.gz'))) 24 | for seg_file in seg_files: 25 | mask = nib.load(seg_file) 26 | mask_scan = mask.get_fdata() 27 | mask_scan = (mask_scan == 6).astype(np.uint8) 28 | new_mask = nib.nifti1.Nifti1Image(mask_scan, affine=mask.affine, header=mask.header) 29 | name = seg_file.split('/')[-1] 30 | nib.save(new_mask, os.path.join(save_dir, f'multi-atlas-{name}')) 31 | print(seg_file, ' done') 32 | 33 | elif data_set == 'tcia': 34 | label_root = os.path.join(data_path, '02_TCIA_Pancreas-CT/multiorgan_label/') 35 | seg_files = sorted(glob(os.path.join(label_root, '*.nii.gz'))) 36 | for seg_file in seg_files: 37 | mask = nib.load(seg_file) 38 | mask_scan = mask.get_fdata() 39 | mask_scan = (mask_scan == 6).astype(np.uint8) 40 | new_mask = nib.nifti1.Nifti1Image(mask_scan, affine=mask.affine, header=mask.header) 41 | name = seg_file.split('/')[-1] 42 | nib.save(new_mask, os.path.join(save_dir, f'tcia-{name}')) 43 | print(seg_file, ' done') 44 | 45 | elif data_set == 'chaos': 46 | label_root = os.path.join(data_path, '03_CHAOS/ct/liver_label/') 47 | seg_files = sorted(glob(os.path.join(label_root, '*.nii.gz'))) 48 | for seg_file in seg_files: 49 | name = seg_file.split('/')[-1] 50 | shutil.copy(seg_file, os.path.join(save_dir, f'chaos-{name}')) 51 | 52 | elif data_set == 'lits': 53 | label_root = os.path.join(data_path, '04_LiTS/label') 54 | healthy_list = [32, 34, 38, 41, 47, 89, 91, 105, 106, 114, 115] 55 | for person_id in healthy_list: 56 | name = f'liver_{person_id}.nii.gz' 57 | shutil.copy(os.path.join(label_root, name), os.path.join(save_dir, f'lits-{name}')) 58 | else: 59 | raise ValueError('Unsupported dataset ' + str(data_set)) 60 | 61 | 62 | if __name__ == "__main__": 63 | datasets = ['btcv', 'tcia', 'chaos', 'lits'] 64 | for data_set in datasets: 65 | main(data_set) -------------------------------------------------------------------------------- /validation.py: -------------------------------------------------------------------------------- 1 | import os, time, csv 2 | import numpy as np 3 | import torch 4 | from sklearn.metrics import confusion_matrix 5 | from scipy import ndimage 6 | from scipy.ndimage import label 7 | from functools import partial 8 | from surface_distance import compute_surface_distances,compute_surface_dice_at_tolerance 9 | import monai 10 | from monai.inferers import sliding_window_inference 11 | from monai.data import load_decathlon_datalist 12 | from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged 13 | from monai import transforms, data 14 | from networks.swin3d_unetrv2 import SwinUNETR as SwinUNETR_v2 15 | import nibabel as nib 16 | 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | import argparse 21 | parser = argparse.ArgumentParser(description='liver tumor validation') 22 | 23 | # file dir 24 | parser.add_argument('--val_dir', default=None, type=str) 25 | parser.add_argument('--json_dir', default=None, type=str) 26 | parser.add_argument('--save_dir', default='out', type=str) 27 | parser.add_argument('--checkpoint', 
action='store_true') 28 | 29 | parser.add_argument('--log_dir', default=None, type=str) 30 | parser.add_argument('--feature_size', default=16, type=int) 31 | parser.add_argument('--val_overlap', default=0.5, type=float) 32 | parser.add_argument('--num_classes', default=3, type=int) 33 | 34 | parser.add_argument('--model', default='unet', type=str) 35 | parser.add_argument('--swin_type', default='tiny', type=str) 36 | 37 | def denoise_pred(pred: np.ndarray): 38 | """ 39 | # 0: background, 1: liver, 2: tumor. 40 | pred.shape: (3, H, W, D) 41 | """ 42 | denoise_pred = np.zeros_like(pred) 43 | 44 | liver_channel = pred[1, ...] 45 | labels, nb = label(liver_channel) 46 | max_sum = -1 47 | choice_idx = -1 48 | for idx in range(1, nb+1): 49 | component = (labels == idx) 50 | if np.sum(component) > max_sum: 51 | choice_idx = idx 52 | max_sum = np.sum(component) 53 | component = (labels == choice_idx) 54 | denoise_pred[1, ...] = component 55 | 56 | # dilate the liver mask, then discard tumor predictions outside the dilated liver region 57 | liver_dilation = ndimage.binary_dilation(denoise_pred[1, ...], iterations=30).astype(bool) 58 | denoise_pred[2,...] = pred[2,...].astype(bool) * liver_dilation 59 | 60 | denoise_pred[0,...] = 1 - np.logical_or(denoise_pred[1,...], denoise_pred[2,...]) 61 | 62 | return denoise_pred 63 | 64 | def cal_dice(pred, true): 65 | intersection = np.sum(pred[true==1]) * 2.0 66 | dice = intersection / (np.sum(pred) + np.sum(true)) 67 | return dice 68 | 69 | def cal_dice_nsd(pred, truth, spacing_mm=(1,1,1), tolerance=2): 70 | dice = cal_dice(pred, truth) 71 | # cal nsd 72 | surface_distances = compute_surface_distances(truth.astype(bool), pred.astype(bool), spacing_mm=spacing_mm) 73 | nsd = compute_surface_dice_at_tolerance(surface_distances, tolerance) 74 | 75 | return (dice, nsd) 76 | 77 | 78 | def _get_model(args): 79 | inf_size = [96, 96, 96] 80 | print(args.model) 81 | if args.model == 'swin_unetrv2': 82 | if args.swin_type == 'tiny': 83 | feature_size=12 84 | elif args.swin_type == 'small': 85 | feature_size=24 86 | elif args.swin_type == 'base': 87 | feature_size=48 88 | 89 | model = SwinUNETR_v2(in_channels=1, 90 | out_channels=3, 91 | img_size=(96, 96, 96), 92 | feature_size=feature_size, 93 | patch_size=2, 94 | depths=[2, 2, 2, 2], 95 | num_heads=[3, 6, 12, 24], 96 | window_size=[7, 7, 7]) 97 | 98 | elif args.model == 'unet': 99 | from monai.networks.nets import UNet 100 | model = UNet( 101 | spatial_dims=3, 102 | in_channels=1, 103 | out_channels=3, 104 | channels=(16, 32, 64, 128, 256), 105 | strides=(2, 2, 2, 2), 106 | num_res_units=2, 107 | ) 108 | 109 | else: 110 | raise ValueError('Unsupported model ' + str(args.model)) 111 | 112 | 113 | if args.checkpoint: 114 | checkpoint = torch.load(os.path.join(args.log_dir, 'model.pt'), map_location='cpu') 115 | 116 | from collections import OrderedDict 117 | new_state_dict = OrderedDict() 118 | for k, v in checkpoint['state_dict'].items(): 119 | new_state_dict[k.replace('backbone.','')] = v 120 | # load params 121 | model.load_state_dict(new_state_dict, strict=False) 122 | print('Use logdir weights') 123 | else: 124 | model_dict = torch.load(os.path.join(args.log_dir, 'model.pt')) 125 | model.load_state_dict(model_dict['state_dict']) 126 | print('Use logdir weights') 127 | 128 | model = model.cuda() 129 | model_inferer = partial(sliding_window_inference, roi_size=inf_size, sw_batch_size=1, predictor=model, overlap=args.val_overlap, mode='gaussian') 130 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 131 | print('Total parameters count', 
pytorch_total_params) 132 | 133 | return model, model_inferer 134 | 135 | def _get_loader(args): 136 | val_data_dir = args.val_dir 137 | datalist_json = args.json_dir 138 | val_org_transform = transforms.Compose( 139 | [ 140 | transforms.LoadImaged(keys=["image", "label"]), 141 | transforms.AddChanneld(keys=["image", "label"]), 142 | transforms.Orientationd(keys=["image"], axcodes="RAS"), 143 | transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear")), 144 | transforms.ScaleIntensityRanged(keys=["image"], a_min=-21, a_max=189, b_min=0.0, b_max=1.0, clip=True), 145 | transforms.SpatialPadd(keys=["image"], mode="minimum", spatial_size=[96, 96, 96]), 146 | transforms.ToTensord(keys=["image", "label"]), 147 | ] 148 | ) 149 | val_files = load_decathlon_datalist(datalist_json, True, "validation", base_dir=val_data_dir) 150 | val_org_ds = data.Dataset(val_files, transform=val_org_transform) 151 | val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) 152 | 153 | post_transforms = Compose([ 154 | Invertd( 155 | keys="pred", 156 | transform=val_org_transform, 157 | orig_keys="image", 158 | meta_keys="pred_meta_dict", 159 | orig_meta_keys="image_meta_dict", 160 | meta_key_postfix="meta_dict", 161 | nearest_interp=False, 162 | to_tensor=True, 163 | ), 164 | # AsDiscreted(keys="pred", argmax=True, to_onehot=3), 165 | AsDiscreted(keys="pred", argmax=True, to_onehot=3), 166 | AsDiscreted(keys="label", to_onehot=3), 167 | # SaveImaged(keys="pred", meta_keys="pred_meta_dict", output_dir=output_dir, output_postfix="seg", resample=False,output_dtype=np.uint8,separate_folder=False), 168 | ]) 169 | 170 | return val_org_loader, post_transforms 171 | 172 | def main(): 173 | args = parser.parse_args() 174 | model_name = args.log_dir.split('/')[-1] 175 | args.model_name = model_name 176 | print("MAIN Argument values:") 177 | for k, v in vars(args).items(): 178 | print(k, '=>', v) 179 | print('-----------------') 180 | 181 | torch.cuda.set_device(0) # use this default device (same as args.device if not distributed) 182 | torch.backends.cudnn.benchmark = True 183 | 184 | ## loader and post_transform 185 | val_loader, post_transforms = _get_loader(args) 186 | 187 | ## NETWORK 188 | model, model_inferer = _get_model(args) 189 | 190 | liver_dice = [] 191 | liver_nsd = [] 192 | tumor_dice = [] 193 | tumor_nsd = [] 194 | header = ['name', 'liver_dice', 'liver_nsd', 'tumor_dice', 'tumor_nsd'] 195 | rows = [] 196 | 197 | model.eval() 198 | start_time = time.time() 199 | with torch.no_grad(): 200 | for idx, val_data in enumerate(val_loader): 201 | val_inputs = val_data["image"].cuda() 202 | name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] 203 | original_affine = val_data["label_meta_dict"]["affine"][0].numpy() 204 | pixdim = val_data['label_meta_dict']['pixdim'].cpu().numpy() 205 | spacing_mm = tuple(pixdim[0][1:4]) 206 | 207 | val_data["pred"] = model_inferer(val_inputs) 208 | val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] 209 | # val_outputs, val_labels = from_engine(["pred", "label"])(val_data) 210 | val_outputs, val_labels = val_data[0]['pred'], val_data[0]['label'] 211 | 212 | # val_outputs.shape == val_labels.shape == (3, H, W, D) 213 | val_outputs, val_labels = val_outputs.detach().cpu().numpy(), val_labels.detach().cpu().numpy() 214 | 215 | # denoise the outputs 216 | val_outputs = denoise_pred(val_outputs) 217 | 218 | current_liver_dice, current_liver_nsd = 
cal_dice_nsd(val_outputs[1,...], val_labels[1,...], spacing_mm=spacing_mm) 219 | current_tumor_dice, current_tumor_nsd = cal_dice_nsd(val_outputs[2,...], val_labels[2,...], spacing_mm=spacing_mm) 220 | 221 | 222 | liver_dice.append(current_liver_dice) 223 | liver_nsd.append(current_liver_nsd) 224 | tumor_dice.append(current_tumor_dice) 225 | tumor_nsd.append(current_tumor_nsd) 226 | 227 | row = [name, current_liver_dice, current_liver_nsd, current_tumor_dice, current_tumor_nsd] 228 | rows.append(row) 229 | 230 | print(name, val_outputs[0].shape, \ 231 | 'dice: [{:.3f} {:.3f}]; nsd: [{:.3f} {:.3f}]'.format(current_liver_dice, current_tumor_dice, current_liver_nsd, current_tumor_nsd), \ 232 | 'time {:.2f}s'.format(time.time() - start_time)) 233 | 234 | # save the prediction 235 | output_dir = os.path.join(args.save_dir, args.model_name, str(args.val_overlap), 'pred') 236 | if not os.path.exists(output_dir): 237 | os.makedirs(output_dir) 238 | val_outputs = np.argmax(val_outputs, axis=0) 239 | 240 | nib.save( 241 | nib.Nifti1Image(val_outputs.astype(np.uint8), original_affine), os.path.join(output_dir, f'{name}.nii.gz') 242 | ) 243 | 244 | 245 | print("liver dice:", np.mean(liver_dice)) 246 | print("liver nsd:", np.mean(liver_nsd)) 247 | print("tumor dice:", np.mean(tumor_dice)) 248 | print("tumor nsd:", np.mean(tumor_nsd)) 249 | 250 | # save metrics to csv file 251 | csv_save = os.path.join(args.save_dir, args.model_name, str(args.val_overlap)) 252 | if not os.path.exists(csv_save): 253 | os.makedirs(csv_save) 254 | csv_name = os.path.join(csv_save, 'metrics.csv') 255 | with open(csv_name, 'w', encoding='UTF8', newline='') as f: 256 | writer = csv.writer(f) 257 | writer.writerow(header) 258 | writer.writerows(rows) 259 | 260 | # save path: save_dir/log_dir_name/str(args.val_overlap)/pred/ 261 | if __name__ == "__main__": 262 | main() 263 | --------------------------------------------------------------------------------
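Note on the metrics (not part of the repository): the per-case numbers printed by validation.py come from the cal_dice and cal_dice_nsd helpers above. The minimal sketch below exercises the same formulas on two toy 8x8x8 binary masks; the cube positions and the 1 mm isotropic spacing are illustrative assumptions, while the 2 mm tolerance matches the default tolerance of cal_dice_nsd.

import numpy as np
# the surface_distance package is vendored under external/surface-distance
from surface_distance import compute_surface_distances, compute_surface_dice_at_tolerance

def cal_dice(pred, true):
    # same formula as validation.py: 2 * |pred AND true| / (|pred| + |true|)
    intersection = np.sum(pred[true == 1]) * 2.0
    return intersection / (np.sum(pred) + np.sum(true))

# two overlapping toy cubes standing in for a predicted mask and a ground-truth mask
pred = np.zeros((8, 8, 8), dtype=np.uint8)
true = np.zeros((8, 8, 8), dtype=np.uint8)
pred[2:6, 2:6, 2:6] = 1
true[3:7, 3:7, 3:7] = 1

dice = cal_dice(pred, true)
surface_distances = compute_surface_distances(true.astype(bool), pred.astype(bool), spacing_mm=(1, 1, 1))
nsd = compute_surface_dice_at_tolerance(surface_distances, 2)
print('dice: {:.3f}, nsd: {:.3f}'.format(dice, nsd))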