├── .gitignore
├── LICENSE
├── README.md
├── config
│   ├── acdc
│   │   ├── attention_unet_2d.yaml
│   │   ├── attention_unet_3d.yaml
│   │   ├── medformer_2d.yaml
│   │   ├── medformer_3d.yaml
│   │   ├── nnformer_3d.yaml
│   │   ├── resunet_2d.yaml
│   │   ├── resunet_3d.yaml
│   │   ├── swinunet_2d.yaml
│   │   ├── transunet_2d.yaml
│   │   ├── unet++_2d.yaml
│   │   ├── unet++_3d.yaml
│   │   ├── unet_2d.yaml
│   │   ├── unet_3d.yaml
│   │   ├── unetr_3d.yaml
│   │   ├── vnet_3d.yaml
│   │   └── vtunet_3d.yaml
│   ├── amos_ct
│   │   ├── attention_unet_3d.yaml
│   │   ├── medformer_3d.yaml
│   │   ├── resunet_3d.yaml
│   │   └── vtunet_3d.yaml
│   ├── amos_mr
│   │   ├── attention_unet_3d.yaml
│   │   ├── medformer_3d.yaml
│   │   ├── resunet_3d.yaml
│   │   └── vtunet_3d.yaml
│   ├── bcv
│   │   ├── attention_unet_3d.yaml
│   │   ├── medformer_3d.yaml
│   │   ├── nnformer_3d.yaml
│   │   ├── resunet_3d.yaml
│   │   ├── swin_unetr_3d.yaml
│   │   ├── unetr_3d.yaml
│   │   └── vtunet_3d.yaml
│   ├── kits
│   │   ├── attention_unet_3d.yaml
│   │   ├── medformer_3d.yaml
│   │   ├── nnformer_3d.yaml
│   │   ├── resunet_3d.yaml
│   │   ├── swin_unetr_3d.yaml
│   │   ├── unetr_3d.yaml
│   │   └── vtunet_3d.yaml
│   └── lits
│       ├── attention_unet_3d.yaml
│       ├── medformer_3d.yaml
│       ├── nnformer_3d.yaml
│       ├── resunet_3d.yaml
│       ├── swin_unetr_3d.yaml
│       ├── unetr_3d.yaml
│       └── vtunet_3d.yaml
├── dataset_conversion
│   ├── __init__.py
│   ├── acdc_2d.py
│   ├── acdc_3d.py
│   ├── amos_3d.py
│   ├── bcv_3d.py
│   ├── kits_3d.py
│   ├── lits_3d.py
│   └── utils.py
├── docs
│   ├── change.md
│   └── tutorial.md
├── inference
│   ├── inference2d.py
│   ├── inference3d.py
│   └── utils.py
├── metric
│   ├── lookup_tables.py
│   ├── metrics.py
│   └── utils.py
├── model
│   ├── __init__.py
│   ├── dim2
│   │   ├── __init__.py
│   │   ├── attention_unet.py
│   │   ├── attention_unet_utils.py
│   │   ├── conv_layers.py
│   │   ├── dual_attention_unet.py
│   │   ├── dual_attention_utils.py
│   │   ├── medformer.py
│   │   ├── medformer_utils.py
│   │   ├── swin_unet.py
│   │   ├── trans_layers.py
│   │   ├── transunet.py
│   │   ├── unet.py
│   │   ├── unet_utils.py
│   │   ├── unetpp.py
│   │   └── utils.py
│   ├── dim3
│   │   ├── __init__.py
│   │   ├── attention_unet.py
│   │   ├── attention_unet_utils.py
│   │   ├── conv_layers.py
│   │   ├── medformer.py
│   │   ├── medformer_utils.py
│   │   ├── medformer_utils_v2.py
│   │   ├── nnformer.py
│   │   ├── nnformer_utils.py
│   │   ├── swin_unetr.py
│   │   ├── trans_layers.py
│   │   ├── unet.py
│   │   ├── unet_utils.py
│   │   ├── unetpp.py
│   │   ├── unetr.py
│   │   ├── utils.py
│   │   ├── vnet.py
│   │   ├── vtunet.py
│   │   └── vtunet_utils.py
│   └── utils.py
├── prediction.py
├── requirements.txt
├── train.py
├── train_ddp.py
├── training
│   ├── __init__.py
│   ├── augmentation.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dim2
│   │   │   ├── __init__.py
│   │   │   └── dataset_acdc.py
│   │   ├── dim3
│   │   │   ├── dataset_acdc.py
│   │   │   ├── dataset_amos_ct.py
│   │   │   ├── dataset_amos_mr.py
│   │   │   ├── dataset_bcv.py
│   │   │   ├── dataset_kits.py
│   │   │   └── dataset_lits.py
│   │   └── utils.py
│   ├── losses.py
│   ├── utils.py
│   └── validation.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
*.pyc
*_bkp.py
flops.py
checkpoint/
log*/
exp/
initmodel/
result/
*.swp
test*
--------------------------------------------------------------------------------

/config/acdc/attention_unet_2d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


#MODEL
arch: attention_unet
in_chan: 1
base_chan: 32


#TRAIN
epochs: 150
training_size: [256, 256] # training crop size
start_epoch: 0
num_workers: 4 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5

#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: False


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------
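Note on `split_seed` and `k_fold`: the case list is shuffled with a fixed seed before one cross-validation fold is held out, so folds are reproducible across runs. A minimal sketch of such a seeded split is below; the function name and splitting strategy are illustrative stand-ins, not the repo's actual API (the real logic lives in the training/dataset modules).

import random

def kfold_split(case_ids, k_fold, fold_idx, split_seed):
    # deterministic shuffle, then deal cases into k folds round-robin
    ids = sorted(case_ids)
    random.Random(split_seed).shuffle(ids)
    folds = [ids[i::k_fold] for i in range(k_fold)]
    val = folds[fold_idx]
    train = [c for f in folds if f is not val for c in f]
    return train, val

train_ids, val_ids = kfold_split([f"patient{i:03d}" for i in range(100)],
                                 k_fold=5, fold_idx=0, split_seed=0)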
/config/acdc/attention_unet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: attention_unet
in_chan: 1
base_chan: 32
down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 200
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]


print_freq: 5
iter_per_epoch: 200


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]

# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------

/config/acdc/medformer_2d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


#MODEL
arch: medformer
in_chan: 1
base_chan: 32
conv_block: 'BasicBlock'
conv_num: [2,0,0,0, 0,0,2,2]
trans_num: [0,2,2,2, 2,2,0,0]
num_heads: [1,4,8,16, 8,4,1,1]
map_size: 3
expansion: 2
fusion_depth: 2
fusion_dim: 512
fusion_heads: 16
proj_type: 'depthwise'
attn_drop: 0.
proj_drop: 0.


#TRAIN
epochs: 200
training_size: [256, 256] # training crop size
start_epoch: 0
aux_loss: True
aux_weight: [0.5, 0.5]
num_workers: 4 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10 # evaluate every val_freq epochs

#INFERENCE
sliding_window: False


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------
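Note on `aux_loss` / `aux_weight` (deep supervision): when enabled, the model emits auxiliary predictions at lower resolutions in addition to the final one, and each auxiliary output adds a weighted term to the loss. A hedged sketch of one plausible wiring is below; it assumes the model returns `(final_pred, [aux1, aux2])` and that labels are downsampled to each auxiliary resolution, which may differ from the exact handling in train.py.

import torch
import torch.nn.functional as F

def deep_supervision_loss(pred, aux_preds, target, aux_weight, criterion):
    loss = criterion(pred, target)
    for w, aux in zip(aux_weight, aux_preds):
        # downsample the label map to the auxiliary prediction's spatial size
        aux_target = F.interpolate(target.unsqueeze(1).float(),
                                   size=aux.shape[2:], mode='nearest')
        loss = loss + w * criterion(aux, aux_target.squeeze(1).long())
    return loss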
/config/acdc/medformer_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: medformer
in_chan: 1
base_chan: 32
conv_block: 'BasicBlock'

down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]] # corresponds to down1, down2, down3, down4
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]] # corresponds to inconv, down1, down2, down3, down4
norm: in
act: relu
map_size: [2, 6, 6]
conv_num: [2,0,0,0, 0,0,2,2]
trans_num: [0,2,2,2, 2,2,0,0]
num_heads: [1,4,4,4, 4,4,1,1]
expansion: 4
fusion_depth: 2
fusion_dim: 256
fusion_heads: 4
attn_drop: 0.
proj_drop: 0.
proj_type: 'depthwise'
rel_pos: False
se: True


#TRAIN
epochs: 150
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

aux_loss: True
aux_weight: [0.5, 0.5]

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

print_freq: 1
iter_per_epoch: 200


scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023
--------------------------------------------------------------------------------
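Note on `weight` and `rlt`: `weight` gives per-class weights for the cross-entropy term, and `rlt` sets the ratio between the CE and Dice terms. The sketch below assumes the composition `loss = rlt * CE + Dice`, which is one plausible reading of the comments; the repo's actual definition lives in training/losses.py.

import torch
import torch.nn.functional as F

def ce_dice_loss(logits, target, class_weight, rlt=1.0, eps=1e-5):
    ce = F.cross_entropy(logits, target, weight=class_weight)
    prob = torch.softmax(logits, dim=1)
    onehot = F.one_hot(target, num_classes=logits.shape[1])
    onehot = onehot.permute(0, 3, 1, 2).float()   # 2D case: NHWC -> NCHW
    dims = (0, 2, 3)
    inter = (prob * onehot).sum(dims)
    union = prob.sum(dims) + onehot.sum(dims)
    dice = 1 - ((2 * inter + eps) / (union + eps)).mean()
    return rlt * ce + dice

logits = torch.randn(2, 4, 256, 256)              # matches classes: 4
target = torch.randint(0, 4, (2, 256, 256))
loss = ce_dice_loss(logits, target, torch.tensor([0.5, 1., 1., 1.]))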
/config/acdc/nnformer_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: nnformer
in_chan: 1
base_chan: 24

#TRAIN
epochs: 150
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

aux_loss: True
aux_weight: [0.2, 0.3, 0.5]

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0004
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0., 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
affine_pad_size: [0, 40, 40]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.5]

print_freq: 5
iter_per_epoch: 200


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null


# dataloader
num_threads: 4
--------------------------------------------------------------------------------
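Note on `ema` / `ema_alpha`: validation is run against an exponential moving average of the model weights rather than the raw weights, which usually gives smoother and slightly better scores. The update rule below is the standard one and is only a sketch; the repo's version (and how it treats buffers) may differ in detail.

import copy
import torch

@torch.no_grad()
def update_ema(ema_model, model, alpha=0.99):
    # ema = alpha * ema + (1 - alpha) * current, applied per parameter
    for e, p in zip(ema_model.parameters(), model.parameters()):
        e.mul_(alpha).add_(p, alpha=1 - alpha)

model = torch.nn.Conv2d(1, 4, 3)
ema_model = copy.deepcopy(model)
update_ema(ema_model, model, alpha=0.99)   # call once per training step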
/config/acdc/resunet_2d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


#MODEL
arch: resunet
in_chan: 1
base_chan: 32
block: BasicBlock


#TRAIN
epochs: 150
training_size: [256, 256] # training crop size
start_epoch: 0
num_workers: 4
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10

#INFERENCE
sliding_window: False

# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------

/config/acdc/resunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: resunet
in_chan: 1
base_chan: 32
down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 150
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]


print_freq: 5
iter_per_epoch: 200


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023
--------------------------------------------------------------------------------
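Note on `sliding_window` / `window_size`: the 3D configs infer on whole volumes by tiling them with overlapping crops the size of the training patch and averaging the logits. The sketch below assumes 50% overlap and uniform averaging and that the volume is at least window-sized; the repo's implementation is inference/inference3d.py and may differ (e.g. Gaussian weighting).

import torch

@torch.no_grad()
def sliding_window_predict(model, vol, window, classes, stride=None):
    # vol: (1, C, D, H, W)
    stride = stride or [w // 2 for w in window]
    shape = vol.shape[2:]
    out = torch.zeros((1, classes, *shape))
    cnt = torch.zeros((1, 1, *shape))
    starts = [list(range(0, max(s - w, 0) + 1, st)) + [max(s - w, 0)]
              for s, w, st in zip(shape, window, stride)]
    for z in starts[0]:
        for y in starts[1]:
            for x in starts[2]:
                sl = (..., slice(z, z + window[0]),
                      slice(y, y + window[1]), slice(x, x + window[2]))
                out[sl] += model(vol[sl])
                cnt[sl] += 1
    return out / cnt   # averaged logits over all overlapping windows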
/config/acdc/swinunet_2d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


#MODEL
arch: swinunet
init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/swin_tiny_patch4_window7_224.pth'
base_chan: 48

#TRAIN
epochs: 400
training_size: [224, 224] # training crop size
start_epoch: 0
num_workers: 4 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation


optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5

#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10 # evaluate every val_freq epochs

#INFERENCE
sliding_window: False


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------

/config/acdc/transunet_2d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


#MODEL
arch: transunet
init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/R50+ViT-B_16.npz'

#TRAIN
epochs: 150
training_size: [256, 256] # training crop size
start_epoch: 0
num_workers: 4
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10

#INFERENCE
sliding_window: False


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------
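Note on `init_model`: swinunet and transunet start from ImageNet-pretrained encoder checkpoints rather than random weights. The general pattern for loading such a checkpoint is sketched below, keeping only tensors whose names and shapes match; it is an assumption-level sketch (transunet's .npz checkpoint in particular needs its own model-specific converter).

import torch

def load_pretrained(model, ckpt_path):
    state = torch.load(ckpt_path, map_location='cpu')
    if isinstance(state, dict) and 'model' in state:
        state = state['model']                   # some checkpoints are nested
    own = model.state_dict()
    kept = {k: v for k, v in state.items()
            if k in own and v.shape == own[k].shape}
    own.update(kept)
    model.load_state_dict(own)
    print(f'loaded {len(kept)}/{len(own)} tensors from {ckpt_path}')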
/config/acdc/unet++_2d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


#MODEL
arch: unet++
in_chan: 1
base_chan: 32


#TRAIN
epochs: 150
training_size: [256, 256] # training crop size
start_epoch: 0
num_workers: 4 # modify this if I/O or augmentation is slow
aug_device: 'cpu'


split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5

#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: False


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------

/config/acdc/unet++_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: unet++
in_chan: 1
base_chan: 32
down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 200
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5
iter_per_epoch: 200


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023
--------------------------------------------------------------------------------
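Note on the intensity-augmentation keys (`gaussian_noise_std`, `additive_brightness_std`, `gamma_range`): these name the usual additive-noise, global-brightness and gamma transforms. The sketch below shows what such transforms typically do, assuming z-score-normalized inputs; the repo's actual versions live in training/augmentation.py and may differ in sampling details.

import torch

def intensity_aug(img, noise_std=0.02, brightness_std=0.7, gamma_range=(0.5, 1.6)):
    img = img + torch.randn_like(img) * noise_std        # additive gaussian noise
    img = img + torch.randn(1).item() * brightness_std   # global brightness shift
    gamma = float(torch.empty(1).uniform_(*gamma_range))
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo + 1e-8)                  # map to [0, 1]
    img = img.pow(gamma) * (hi - lo) + lo                # gamma, restore range
    return img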
/config/acdc/unet_2d.yaml:
--------------------------------------------------------------------------------
# DATA
data_root: /research/cbim/medical/yg397/ACDC_2d
classes: 4
modality: mri


# MODEL
arch: unet
in_chan: 1
base_chan: 32
block: SingleConv


# TRAIN
epochs: 150
training_size: [256, 256] # training crop size
start_epoch: 0
num_workers: 4 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.0005
betas: [0.9, 0.999]
weight_decay: 0.05
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: 0.3 # scale for data augmentation
rotate: 180 # rotation angle for data augmentation
translate: 0
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5


# VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10 # evaluate every val_freq epochs


# INFERENCE
sliding_window: False


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # make sure the port here is the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023 # use any seed you want, or use 'null' to disable deterministic behavior
--------------------------------------------------------------------------------
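Note on the anisotropic `down_scale` / `kernel_size` lists in the 3D configs (e.g. unet_3d.yaml below): ACDC volumes are much thinner along the depth axis, so the first stages downsample only in-plane ([1,2,2]) with matching [1,3,3] kernels, which keeps the 16-slice axis from collapsing too early. A worked check of the resulting stage sizes:

def stage_sizes(crop, down_scale):
    sizes = [list(crop)]
    for s in down_scale:
        sizes.append([d // f for d, f in zip(sizes[-1], s)])
    return sizes

print(stage_sizes([16, 192, 192],
                  [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2]]))
# [[16, 192, 192], [16, 96, 96], [16, 48, 48], [8, 24, 24], [4, 12, 12]]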
/config/acdc/unet_3d.yaml:
--------------------------------------------------------------------------------
# DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


# MODEL
arch: unet
in_chan: 1
base_chan: 32
down_scale: [[1,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [2,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: SingleConv
norm: in

# TRAIN
epochs: 150
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2 # modify this if I/O or augmentation is slow
aug_device: 'cpu'

split_seed: 0 # random seed for train/test split (shuffle) before setting cross validation fold
k_fold: 5 # number of folds in cross validation

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 5
iter_per_epoch: 200


# VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10 # evaluate every val_freq epochs


# INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: 2023
--------------------------------------------------------------------------------

/config/acdc/unetr_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: unetr
in_chan: 1
norm: in
init_model: /research/cbim/vast/yg397/ConvFormer/ConvFormer/initmodel/UNETR_model_best_acc.pth

#TRAIN
epochs: 400
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2 # modify this if I/O or augmentation is slow
aug_device: 'cpu'


split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0001
betas: [0.9, 0.999]
weight_decay: 0.00005 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 1
iter_per_epoch: 200


#VALIDATION
ema: False
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: null
--------------------------------------------------------------------------------

/config/acdc/vnet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: vnet
in_chan: 1
base_chan: 16
downsample_scale: [[1,2,2], [2,2,2], [2,2,2], [2,2,2]]

#TRAIN
epochs: 250
training_size: [16, 192, 192] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.1, 0.3, 0.3] # scale for data augmentation
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 1
iter_per_epoch: 200


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 192, 192]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: null
--------------------------------------------------------------------------------
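Note on the DDP block (`dist_url`, `dist_backend`, `world_size`, `rank`, `multiprocessing_distributed`): these are the standard PyTorch process-group settings; the port in `dist_url` must match `port`. A sketch of how such keys are typically consumed is below; the real entry point is train_ddp.py, and the spawn/rank arithmetic here follows the common single-node pattern only.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(proc_idx, cfg, model_fn):
    # common single-node pattern: one process per GPU
    rank = cfg['rank'] * torch.cuda.device_count() + proc_idx
    dist.init_process_group(backend=cfg['dist_backend'],   # "nccl"
                            init_method=cfg['dist_url'],   # carries the same port as `port`
                            world_size=cfg['world_size'],  # total number of processes
                            rank=rank)
    torch.cuda.set_device(proc_idx)
    model = model_fn().cuda(proc_idx)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[proc_idx])
    # ... training loop ...

# mp.spawn(worker, nprocs=torch.cuda.device_count(), args=(cfg, model_fn))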
/config/acdc/vtunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/acdc/acdc_3d
classes: 4
modality: mri


#MODEL
arch: vtunet
init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/swin_tiny_patch4_window7_224.pth'
in_chan: 1
patch_size: [1, 4, 4]

#TRAIN
epochs: 400
training_size: [16, 128, 128] # training crop size
start_epoch: 0
aux_loss: False
aux_weight: [1, 0.1, 0.1, 0.1]
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.001
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses


scale: [0.1, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 0, 0] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02
additive_brightness_std: 0.7
gamma_range: [0.5, 1.6]

print_freq: 1
iter_per_epoch: 200


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [16, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true # whether to use PyTorch DDP for multi-gpu training
reproduce_seed: null
--------------------------------------------------------------------------------

/config/amos_ct/attention_unet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_ct_3d
classes: 16
modality: CT


#MODEL
arch: attention_unet
in_chan: 1
base_chan: 32
down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 400
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 1

warmup_epoch: 5
optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
affine_pad_size: [40, 40, 40]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 500


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
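Note on `warmup_epoch`: the AMOS configs ramp the learning rate up over the first few epochs before the main schedule takes over. The sketch below assumes linear warmup followed by polynomial decay, a common choice in medical segmentation; the decay form is an assumption, and the repo's schedule is defined in its training utilities.

def get_lr(epoch, base_lr, warmup_epoch, total_epochs, power=0.9):
    if epoch < warmup_epoch:
        return base_lr * (epoch + 1) / warmup_epoch      # linear warmup
    t = (epoch - warmup_epoch) / max(total_epochs - warmup_epoch, 1)
    return base_lr * (1 - t) ** power                    # polynomial decay

for epoch in range(400):
    lr = get_lr(epoch, base_lr=0.0006, warmup_epoch=5, total_epochs=400)
    # for g in optimizer.param_groups: g['lr'] = lr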
/config/amos_ct/medformer_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_ct_3d
classes: 16
modality: CT


#MODEL
arch: medformer
in_chan: 1
base_chan: 32
conv_block: 'BasicBlock'

down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]]
chan_num: [64, 128, 256, 320, 256, 128, 64, 32]
norm: in
act: relu
map_size: [4, 4, 4]
conv_num: [2,1,0,0, 0,1,2,2]
trans_num: [0,1,4,6, 4,1,0,0]
num_heads: [1,4,8,10, 8,4,1,1]
expansion: 4
fusion_depth: 2
fusion_dim: 320
fusion_heads: 10
attn_drop: 0.
proj_drop: 0.
proj_type: 'depthwise'


#TRAIN
epochs: 400
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu' # 'cpu' or 'gpu'. 'gpu' augmentation consumes more GPU memory, but is much faster for 3D inputs

aux_loss: True
aux_weight: [0.5, 0.5]

split_seed: 0
k_fold: 1

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses


scale: [0.3, 0.3, 0.3] # scale for data augmentation
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
affine_pad_size: [40, 40, 40]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 500


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10003' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: False
reproduce_seed: null
--------------------------------------------------------------------------------
/config/amos_ct/resunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_ct_3d
classes: 16
modality: CT


#MODEL
arch: resunet
in_chan: 1
base_chan: 32
down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 400
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu' # 'cpu' or 'gpu' for both the PyTorch and DALI dataloaders. 'gpu' augmentation consumes more GPU memory, but is much faster for 3D inputs

split_seed: 0
k_fold: 1

warmup_epoch: 5
optimizer: adamw
base_lr: 0.0006
momentum: 0.99
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
affine_pad_size: [40, 40, 40]
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 500


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 20


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10001' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------

/config/amos_ct/vtunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_ct_3d
classes: 16
modality: CT


#MODEL
arch: vtunet
init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/swin_tiny_patch4_window7_224.pth'
in_chan: 1
patch_size: [4, 4, 4]

#TRAIN
epochs: 400
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'gpu'

aux_loss: False
aux_weight: [1, 0.1, 0.1, 0.1]

split_seed: 0
k_fold: 1

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
#momentum: 0.9 # momentum of SGD optimizer
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses


scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
affine_pad_size: [40, 40, 40]
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 500


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
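Note on `iter_per_epoch`: 3D training samples random crops from the volumes, so there is no natural "one pass over the dataset"; an epoch is simply a fixed number of sampled batches (80 for the small AMOS-MR split, 500 for the larger CT one). A minimal sketch of that convention, with loader/step details purely illustrative:

def run_epoch(loader, step_fn, iter_per_epoch):
    it = iter(loader)
    for _ in range(iter_per_epoch):
        try:
            batch = next(it)
        except StopIteration:     # restart the loader if it runs out early
            it = iter(loader)
            batch = next(it)
        step_fn(batch)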
/config/amos_mr/attention_unet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_mr_3d
classes: 16
modality: MR


#MODEL
arch: attention_unet
in_chan: 1
base_chan: 32
down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 200
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 1

warmup_epoch: 5
optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
affine_pad_size: [40, 40, 40]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 80


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 20


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10001' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
/config/amos_mr/medformer_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_mr_3d
classes: 16
modality: MR


#MODEL
arch: medformer
in_chan: 1
base_chan: 32
conv_block: 'BasicBlock'

down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]]
chan_num: [64, 128, 256, 320, 256, 128, 64, 32]
norm: in
act: relu
map_size: [4, 4, 4]
conv_num: [2,1,0,0, 0,1,2,2]
trans_num: [0,1,4,6, 4,1,0,0]
num_heads: [1,4,8,10, 8,4,1,1]
expansion: 4
fusion_depth: 2
fusion_dim: 320
fusion_heads: 10
attn_drop: 0.
proj_drop: 0.
proj_type: 'depthwise'


#TRAIN
epochs: 200
training_size: [48, 160, 224] # training crop size
start_epoch: 0
num_workers: 1
aug_device: 'cpu'

aux_loss: True
aux_weight: [0.5, 0.5]

split_seed: 0
k_fold: 1

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses


scale: [0.3, 0.3, 0.3] # scale for data augmentation
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
affine_pad_size: [40, 40, 40]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 80


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 40


#INFERENCE
sliding_window: True
window_size: [48, 160, 224]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10003' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
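Note on `modality`: CT intensities are calibrated (Hounsfield units), so they can be clipped to a fixed window and normalized with fixed statistics, while MR intensities are arbitrary and are typically normalized per volume; this is why amos_ct and amos_mr carry separate configs. The clip window and normalization below are illustrative only, not the repo's actual numbers (see the dataset_conversion/ scripts for those).

import numpy as np

def normalize(vol, modality):
    if modality.lower() == 'ct':
        vol = np.clip(vol, -175, 250)    # example abdominal HU window
    # CT after clipping, and MR/MRI always: per-volume z-score
    return (vol - vol.mean()) / (vol.std() + 1e-8)

vol = np.random.randn(64, 256, 256).astype(np.float32)
print(normalize(vol, 'MR').std())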
/config/amos_mr/resunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_mr_3d
classes: 16
modality: MR


#MODEL
arch: resunet
in_chan: 1
base_chan: 32
down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]]
kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 200
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu' # 'cpu' or 'gpu'. 'gpu' augmentation consumes more GPU memory, but is much faster for 3D inputs

split_seed: 0
k_fold: 1

optimizer: adamw
base_lr: 0.0006
warmup_epoch: 5
momentum: 0.99
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
affine_pad_size: [40, 40, 40] # crop trick to reduce computation for affine augmentation
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 80


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 40


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10008' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------

/config/amos_mr/vtunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/amos/cbim/amos_mr_3d
classes: 16
modality: MR


#MODEL
arch: vtunet
init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/swin_tiny_patch4_window7_224.pth'
in_chan: 1
patch_size: [4, 4, 4]

#TRAIN
epochs: 300
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

aux_loss: False
aux_weight: [1, 0.1, 0.1, 0.1]

split_seed: 0
k_fold: 1

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
#momentum: 0.9 # momentum of SGD optimizer
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses


scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
affine_pad_size: [40, 40, 40]
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 80


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 100


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
/config/bcv/attention_unet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/bcv/bcv_3d
classes: 14
modality: CT


#MODEL
arch: attention_unet
in_chan: 1
base_chan: 32
down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 150
training_size: [32, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 300


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [32, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------

/config/bcv/medformer_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/bcv/bcv_3d
classes: 14
modality: CT


#MODEL
arch: medformer
in_chan: 1
base_chan: 32
conv_block: 'BasicBlock'

down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
chan_num: [64, 128, 256, 320, 256, 128, 64, 32]
norm: in
act: relu
map_size: [3,3,3]
conv_num: [2,0,0,0, 0,0,2,2]
trans_num: [0,2,4,6, 4,2,0,0]
num_heads: [1,4,8,10, 8,4,1,1]
expansion: 4
fusion_depth: 2
fusion_dim: 320
fusion_heads: 10
attn_drop: 0
proj_drop: 0
proj_type: 'depthwise'


#TRAIN
epochs: 150
training_size: [32, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

aux_loss: True
aux_weight: [0.5, 0.5]

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
#momentum: 0.9 # momentum of SGD optimizer
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses


scale: [0.3, 0.3, 0.3] # scale for data augmentation 0 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation 30
translate: [0, 0, 0]
affine_pad_size: [40, 40, 40]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 300


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [32, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
/config/bcv/nnformer_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/bcv/bcv_3d
classes: 14
modality: CT


#MODEL
arch: nnformer
in_chan: 1
base_chan: 24

#TRAIN
epochs: 150
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

aux_loss: True
aux_weight: [0.2,0.3,0.5]

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 300


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------

/config/bcv/resunet_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/bcv/bcv_3d
classes: 14
modality: CT


#MODEL
arch: resunet
in_chan: 1
base_chan: 32
down_scale: [[1,2,2], [1,2,2], [2,2,2], [2,2,2]]
kernel_size: [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]]
block: BasicBlock
norm: in

#TRAIN
epochs: 150
training_size: [32, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 300


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 10


#INFERENCE
sliding_window: True
window_size: [32, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
/config/bcv/swin_unetr_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/bcv/bcv_3d
classes: 14
modality: CT


#MODEL
arch: swin_unetr
in_chan: 1
base_chan: 48

#TRAIN
epochs: 150
training_size: [128, 128, 128] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0006
betas: [0.9, 0.999]
weight_decay: 0.05 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] # weight of each class in the loss function
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 300


#VALIDATION
ema: True
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [128, 128, 128]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------

/config/bcv/unetr_3d.yaml:
--------------------------------------------------------------------------------
#DATA
data_root: /filer/tmp1/yg397/dataset/bcv/bcv_3d
classes: 14
modality: CT


#MODEL
arch: unetr
in_chan: 1
norm: in
init_model: /research/cbim/vast/yg397/ConvFormer/ConvFormer/initmodel/UNETR_model_best_acc.pth

#TRAIN
epochs: 300
training_size: [96, 96, 96] # training crop size
start_epoch: 0
num_workers: 2
aug_device: 'cpu'

split_seed: 0
k_fold: 5

optimizer: adamw
base_lr: 0.0001
betas: [0.9, 0.999]
weight_decay: 0.00005 # weight decay of the optimizer
weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
rlt: 1 # relative weight between the CE and Dice losses

scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3
rotate: [30, 30, 30] # rotation angle for data augmentation
translate: [0, 0, 0]
gaussian_noise_std: 0.02

print_freq: 5
iter_per_epoch: 300


#VALIDATION
ema: False
ema_alpha: 0.99
val_freq: 50


#INFERENCE
sliding_window: True
window_size: [96, 96, 96]


# DDP
world_size: 1
proc_idx: 0
rank: 0
port: 10000
dist_url: 'tcp://localhost:10000' # the port number here should be the same as `port` above
dist_backend: "nccl"
multiprocessing_distributed: true
reproduce_seed: null
--------------------------------------------------------------------------------
optimizer 31 | weight: [0.5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] 32 | rlt: 1 # relation between CE and Dice loss 33 | 34 | 35 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 36 | rotate: [30, 30, 30] # rotation angle for data augmentation 37 | translate: [0, 0, 0] 38 | affine_pad_size: [40, 40, 40] 39 | gaussian_noise_std: 0.02 40 | 41 | print_freq: 5 42 | iter_per_epoch: 300 43 | 44 | 45 | 46 | 47 | #VALIDATION 48 | ema: True 49 | ema_alpha: 0.99 50 | val_freq: 10 51 | 52 | 53 | 54 | #INFERENCE 55 | sliding_window: True 56 | window_size: [64, 128, 128] 57 | 58 | 59 | 60 | # DDP 61 | world_size: 1 62 | proc_idx: 0 63 | rank: 0 64 | port: 10000 65 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 66 | dist_backend: "nccl" 67 | multiprocessing_distributed: true 68 | reproduce_seed: null 69 | 70 | -------------------------------------------------------------------------------- /config/kits/attention_unet_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: attention_resunet 9 | in_chan: 1 10 | base_chan: 32 11 | down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]] 12 | kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]] 13 | block: BasicBlock 14 | norm: in 15 | 16 | #TRAIN 17 | epochs: 300 18 | training_size: [128, 128, 128] # training crop size 19 | start_epoch: 0 20 | num_workers: 2 21 | aug_device: 'cpu' 22 | 23 | split_seed: 0 24 | k_fold: 5 25 | 26 | optimizer: adamw 27 | base_lr: 0.0004 28 | betas: [0.9, 0.999] 29 | weight_decay: 0.05 # weight decay of the optimizer 30 | weight: [0.5, 1, 2] # weight of each class in the loss function 31 | rlt: 1 # relation between CE and Dice loss 32 | 33 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 34 | rotate: [30, 30, 30] # rotation angle for data augmentation 35 | translate: [0, 0, 0] 36 | affine_pad_size: [40, 40, 40] 37 | gaussian_noise_std: 0.02 38 | 39 | print_freq: 5 40 | iter_per_epoch: 500 41 | 42 | 43 | #VALIDATION 44 | ema: True 45 | ema_alpha: 0.99 46 | val_freq: 20 47 | 48 | 49 | 50 | #INFERENCE 51 | sliding_window: True 52 | window_size: [128, 128, 128] 53 | 54 | 55 | 56 | # DDP 57 | world_size: 1 58 | proc_idx: 0 59 | rank: 0 60 | port: 10000 61 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 62 | dist_backend: "nccl" 63 | multiprocessing_distributed: true 64 | reproduce_seed: null 65 | 66 | -------------------------------------------------------------------------------- /config/kits/medformer_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: medformer 9 | in_chan: 1 10 | base_chan: 32 11 | conv_block: 'BasicBlock' 12 | 13 | down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]] 14 | kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]] 15 | chan_num: [64, 128, 256, 320, 256, 128, 64, 32] 16 | norm: in 17 | act: relu 18 | map_size: [4, 4, 4] 19 | conv_num: [2,0,0,0, 0,0,2,2] 20 | trans_num: [0,2,4,6, 4,2,0,0] 21 | num_heads: [1,4,8,10, 8,4,1,1] 22 | expansion: 4 23 | fusion_depth: 2 24 | fusion_dim: 320 25 | fusion_heads: 10 26 | attn_drop: 0. 27 | proj_drop: 0.
28 | proj_type: 'depthwise' 29 | 30 | 31 | #TRAIN 32 | epochs: 300 33 | training_size: [128, 128, 128] # training crop size 34 | start_epoch: 0 35 | num_workers: 2 36 | aug_device: 'cpu' 37 | 38 | aux_loss: True 39 | aux_weight: [0.5, 0.5] 40 | 41 | split_seed: 0 42 | k_fold: 5 43 | 44 | optimizer: adamw 45 | base_lr: 0.0004 46 | betas: [0.9, 0.999] 47 | weight_decay: 0.05 # weight decay of the optimizer 48 | weight: [0.5, 1, 2] 49 | rlt: 1 # relation between CE and Dice loss 50 | 51 | 52 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 53 | rotate: [30, 30, 30] # rotation angle for data augmentation 54 | translate: [0, 0, 0] 55 | affine_pad_size: [40, 40, 40] 56 | gaussian_noise_std: 0.02 57 | 58 | print_freq: 5 59 | iter_per_epoch: 500 60 | 61 | 62 | 63 | 64 | #VALIDATION 65 | ema: True 66 | ema_alpha: 0.99 67 | val_freq: 20 68 | 69 | 70 | 71 | #INFERENCE 72 | sliding_window: True 73 | window_size: [128, 128, 128] 74 | 75 | 76 | # DDP 77 | world_size: 1 78 | proc_idx: 0 79 | rank: 0 80 | port: 10000 81 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 82 | dist_backend: "nccl" 83 | multiprocessing_distributed: true 84 | reproduce_seed: null 85 | 86 | -------------------------------------------------------------------------------- /config/kits/nnformer_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: nnformer 9 | in_chan: 1 10 | base_chan: 24 11 | 12 | #TRAIN 13 | epochs: 300 14 | training_size: [128, 128, 128] # training crop size 15 | start_epoch: 0 16 | num_workers: 2 17 | aug_device: 'cpu' 18 | 19 | aux_loss: True 20 | aux_weight: [0.2, 0.3, 0.5] 21 | 22 | split_seed: 0 23 | k_fold: 5 24 | 25 | optimizer: adamw 26 | base_lr: 0.0004 27 | betas: [0.9, 0.999] 28 | weight_decay: 0.05 # weight decay of the optimizer 29 | weight: [0.5, 1, 2] # weight of each class in the loss function 30 | rlt: 1 # relation between CE and Dice loss 31 | 32 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 33 | rotate: [30, 30, 30] # rotation angle for data augmentation 34 | translate: [0, 0, 0] 35 | affine_pad_size: [40, 40, 40] 36 | gaussian_noise_std: 0.02 37 | 38 | print_freq: 5 39 | iter_per_epoch: 500 40 | 41 | 42 | #VALIDATION 43 | ema: True 44 | ema_alpha: 0.99 45 | val_freq: 20 46 | 47 | 48 | 49 | #INFERENCE 50 | sliding_window: True 51 | window_size: [128, 128, 128] 52 | 53 | 54 | 55 | 56 | # DDP 57 | world_size: 1 58 | proc_idx: 0 59 | rank: 0 60 | port: 10000 61 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 62 | dist_backend: "nccl" 63 | multiprocessing_distributed: true 64 | reproduce_seed: null 65 | 66 | -------------------------------------------------------------------------------- /config/kits/resunet_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: resunet 9 | in_chan: 1 10 | base_chan: 32 11 | down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]] 12 | kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]] 13 | block: BasicBlock 14 | norm: in 15 | 16 | #TRAIN 17 | epochs: 300 18 | training_size: [128, 128, 128] # training crop size 19 | start_epoch: 0 20 | num_workers: 2 21 | aug_device: 'cpu' 22 | 23 | 24 | split_seed: 0 25 | k_fold: 5 26 | 27
| optimizer: adamw 28 | base_lr: 0.0004 29 | betas: [0.9, 0.999] 30 | weight_decay: 0.05 # weight decay of the optimizer 31 | weight: [0.5, 1, 2] # weight of each class in the loss function 32 | rlt: 1 # relation between CE and Dice loss 33 | 34 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 35 | rotate: [30, 30, 30] # rotation angle for data augmentation 36 | translate: [0, 0, 0] 37 | affine_pad_size: [40, 40, 40] 38 | gaussian_noise_std: 0.02 39 | 40 | print_freq: 5 41 | iter_per_epoch: 500 42 | 43 | 44 | #VALIDATION 45 | ema: True 46 | ema_alpha: 0.99 47 | val_freq: 20 48 | 49 | 50 | 51 | #INFERENCE 52 | sliding_window: True 53 | window_size: [128, 128, 128] 54 | 55 | 56 | 57 | 58 | # DDP 59 | world_size: 1 60 | proc_idx: 0 61 | rank: 0 62 | port: 10000 63 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 64 | dist_backend: "nccl" 65 | multiprocessing_distributed: true 66 | reproduce_seed: null 67 | -------------------------------------------------------------------------------- /config/kits/swin_unetr_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: swin_unetr 9 | in_chan: 1 10 | base_chan: 48 11 | 12 | #TRAIN 13 | epochs: 300 14 | training_size: [128, 128, 128] # training crop size 15 | start_epoch: 0 16 | num_workers: 2 17 | aug_device: 'cpu' 18 | 19 | split_seed: 0 20 | k_fold: 5 21 | 22 | optimizer: adamw 23 | base_lr: 0.0004 24 | betas: [0.9, 0.999] 25 | weight_decay: 0.05 # weight decay of the optimizer 26 | weight: [0.5, 1, 2] # weight of each class in the loss function 27 | rlt: 1 # relation between CE and Dice loss 28 | 29 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 30 | rotate: [30, 30, 30] # rotation angle for data augmentation 31 | translate: [0, 0, 0] 32 | affine_pad_size: [40, 40, 40] 33 | gaussian_noise_std: 0.02 34 | 35 | print_freq: 5 36 | iter_per_epoch: 500 37 | 38 | 39 | #VALIDATION 40 | ema: True 41 | ema_alpha: 0.99 42 | val_freq: 20 43 | 44 | 45 | 46 | #INFERENCE 47 | sliding_window: True 48 | window_size: [128, 128, 128] 49 | 50 | 51 | 52 | 53 | # DDP 54 | world_size: 1 55 | proc_idx: 0 56 | rank: 0 57 | port: 10000 58 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 59 | dist_backend: "nccl" 60 | multiprocessing_distributed: true 61 | reproduce_seed: null 62 | 63 | 64 | -------------------------------------------------------------------------------- /config/kits/unetr_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: unetr 9 | in_chan: 1 10 | norm: in 11 | init_model: /research/cbim/vast/yg397/ConvFormer/ConvFormer/initmodel/UNETR_model_best_acc.pth 12 | 13 | #TRAIN 14 | epochs: 400 15 | training_size: [96, 96, 96] # training crop size 16 | start_epoch: 0 17 | num_workers: 2 18 | aug_device: 'cpu' 19 | 20 | split_seed: 0 21 | k_fold: 5 22 | 23 | optimizer: adamw 24 | base_lr: 0.0001 25 | betas: [0.9, 0.999] 26 | weight_decay: 0.00005 # weight decay of the optimizer 27 | weight: [0.5, 1, 2] 28 | rlt: 1 # relation between CE and Dice loss 29 | 30 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 31 | rotate: [30, 30, 30] # rotation angle for data augmentation 32 | translate: [0, 0,
0] 33 | affine_pad_size: [40, 40, 40] 34 | gaussian_noise_std: 0.02 35 | 36 | print_freq: 5 37 | iter_per_epoch: 500 38 | 39 | 40 | #VALIDATION 41 | ema: False 42 | ema_alpha: 0.99 43 | val_freq: 20 44 | 45 | 46 | 47 | #INFERENCE 48 | sliding_window: True 49 | window_size: [96, 96, 96] 50 | 51 | 52 | # DDP 53 | world_size: 1 54 | proc_idx: 0 55 | rank: 0 56 | port: 10000 57 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 58 | dist_backend: "nccl" 59 | multiprocessing_distributed: true 60 | reproduce_seed: null 61 | 62 | 63 | -------------------------------------------------------------------------------- /config/kits/vtunet_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/kits/kits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: vtunet 9 | init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/swin_tiny_patch4_window7_224.pth' 10 | in_chan: 1 11 | patch_size: [4, 4, 4] 12 | 13 | #TRAIN 14 | epochs: 300 15 | training_size: [128, 128, 128] # training crop size 16 | start_epoch: 0 17 | num_workers: 2 18 | aug_device: 'cpu' 19 | 20 | split_seed: 0 21 | k_fold: 5 22 | 23 | optimizer: adamw 24 | base_lr: 0.0004 #0.001 25 | betas: [0.9, 0.999] 26 | #momentum: 0.9 # momentum of SGD optimizer 27 | weight_decay: 0.05 # weight decay of the optimizer 28 | weight: [0.5, 1, 2] 29 | rlt: 1 # relation between CE and Dice loss 30 | 31 | 32 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 33 | rotate: [30, 30, 30] # rotation angle for data augmentation 34 | affine_pad_size: [40, 40, 40] 35 | translate: [0, 0, 0] 36 | gaussian_noise_std: 0.02 37 | 38 | print_freq: 5 39 | iter_per_epoch: 500 40 | 41 | 42 | 43 | 44 | #VALIDATION 45 | ema: True 46 | ema_alpha: 0.99 47 | val_freq: 20 48 | 49 | 50 | 51 | #INFERENCE 52 | sliding_window: True 53 | window_size: [128, 128, 128] 54 | 55 | 56 | 57 | 58 | 59 | # DDP 60 | world_size: 1 61 | proc_idx: 0 62 | rank: 0 63 | port: 10000 64 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 65 | dist_backend: "nccl" 66 | multiprocessing_distributed: true 67 | reproduce_seed: null 68 | 69 | -------------------------------------------------------------------------------- /config/lits/attention_unet_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: attention_resunet 9 | in_chan: 1 10 | base_chan: 32 11 | down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]] 12 | kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]] 13 | block: BasicBlock 14 | norm: in 15 | 16 | #TRAIN 17 | epochs: 200 18 | training_size: [128, 128, 128] # training crop size 19 | start_epoch: 0 20 | num_workers: 2 21 | aug_device: 'cpu' 22 | 23 | split_seed: 0 24 | k_fold: 5 25 | 26 | optimizer: adamw 27 | base_lr: 0.001 28 | betas: [0.9, 0.999] 29 | weight_decay: 0.05 # weight decay of the optimizer 30 | weight: [0.5, 1, 3] # weight of each class in the loss function 31 | rlt: 1 # relation between CE and Dice loss 32 | 33 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 34 | rotate: [30, 30, 30] # rotation angle for data augmentation 35 | translate: [0, 0, 0] 36 | gaussian_noise_std: 0.02 37 | 38 | print_freq: 5 39 | iter_per_epoch: 500 40 | 41 | 42 | #VALIDATION 43 | ema: True 44 | ema_alpha: 0.99 45 |
val_freq: 10 46 | 47 | 48 | 49 | #INFERENCE 50 | sliding_window: True 51 | window_size: [128, 128, 128] 52 | 53 | 54 | 55 | # DDP 56 | world_size: 1 57 | proc_idx: 0 58 | rank: 0 59 | port: 10000 60 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 61 | dist_backend: "nccl" 62 | multiprocessing_distributed: true 63 | reproduce_seed: null 64 | -------------------------------------------------------------------------------- /config/lits/medformer_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: medformer 9 | in_chan: 1 10 | base_chan: 32 11 | conv_block: 'BasicBlock' 12 | 13 | down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]] 14 | kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]] 15 | chan_num: [64, 128, 256, 320, 256, 128, 64, 32] 16 | norm: in 17 | act: relu 18 | map_size: [4, 4, 4] 19 | conv_num: [2,0,0,0, 0,0,2,2] 20 | trans_num: [0,2,4,6, 4,2,0,0] 21 | num_heads: [1,1,1,1, 1,1,1,1] 22 | expansion: 4 23 | fusion_depth: 2 24 | fusion_dim: 320 25 | fusion_heads: 10 26 | attn_drop: 0. 27 | proj_drop: 0. 28 | proj_type: 'depthwise' 29 | 30 | 31 | #TRAIN 32 | epochs: 200 33 | training_size: [128, 128, 128] # training crop size 34 | start_epoch: 0 35 | num_workers: 2 36 | aug_device: 'cpu' 37 | 38 | aux_loss: False 39 | aux_weight: [0.5, 0.5] 40 | 41 | split_seed: 0 42 | k_fold: 5 43 | 44 | optimizer: adamw 45 | base_lr: 0.0004 46 | betas: [0.9, 0.999] 47 | weight_decay: 0.05 # weight decay of the optimizer 48 | weight: [0.5, 1, 3] 49 | rlt: 1 # relation between CE and Dice loss 50 | 51 | 52 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 53 | rotate: [30, 30, 30] # rotation angle for data augmentation 54 | translate: [0, 0, 0] 55 | affine_pad_size: [50, 50, 50] 56 | gaussian_noise_std: 0.02 57 | 58 | print_freq: 5 59 | iter_per_epoch: 500 60 | 61 | 62 | 63 | 64 | #VALIDATION 65 | ema: True 66 | ema_alpha: 0.99 67 | val_freq: 10 68 | 69 | 70 | 71 | #INFERENCE 72 | sliding_window: True 73 | window_size: [128, 128, 128] 74 | 75 | 76 | # DDP 77 | world_size: 1 78 | proc_idx: 0 79 | rank: 0 80 | port: 10000 81 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 82 | dist_backend: "nccl" 83 | multiprocessing_distributed: true 84 | reproduce_seed: null 85 | 86 | 87 | -------------------------------------------------------------------------------- /config/lits/nnformer_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: nnformer 9 | in_chan: 1 10 | base_chan: 48 11 | 12 | #TRAIN 13 | epochs: 200 14 | training_size: [128, 128, 128] # training crop size 15 | start_epoch: 0 16 | num_workers: 2 17 | aug_device: 'cpu' 18 | 19 | aux_loss: True 20 | aux_weight: [0.2, 0.3, 0.5] 21 | 22 | split_seed: 0 23 | k_fold: 5 24 | 25 | optimizer: adamw 26 | base_lr: 0.0004 27 | betas: [0.9, 0.999] 28 | weight_decay: 0.05 # weight decay of the optimizer 29 | weight: [0.5, 1, 3] # weight of each class in the loss function 30 | rlt: 1 # relation between CE and Dice loss 31 | 32 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 33 | rotate: [30, 30, 30] # rotation angle for data augmentation 34 | translate: [0, 0, 0] 35 | affine_pad_size: [50, 50, 50] 36 |
gaussian_noise_std: 0.02 37 | 38 | print_freq: 5 39 | iter_per_epoch: 500 40 | 41 | 42 | #VALIDATION 43 | ema: True 44 | ema_alpha: 0.99 45 | val_freq: 50 46 | 47 | 48 | 49 | #INFERENCE 50 | sliding_window: True 51 | window_size: [128, 128, 128] 52 | 53 | 54 | 55 | 56 | # DDP 57 | world_size: 1 58 | proc_idx: 0 59 | rank: 0 60 | port: 10000 61 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 62 | dist_backend: "nccl" 63 | multiprocessing_distributed: true 64 | reproduce_seed: null 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /config/lits/resunet_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: resunet 9 | in_chan: 1 10 | base_chan: 32 11 | down_scale: [[2,2,2], [2,2,2], [2,2,2], [2,2,2]] 12 | kernel_size: [[3,3,3], [3,3,3], [3,3,3], [3,3,3], [3,3,3]] 13 | block: BasicBlock 14 | norm: in 15 | 16 | #TRAIN 17 | epochs: 200 18 | training_size: [128, 128, 128] # training crop size 19 | start_epoch: 0 20 | num_workers: 2 21 | aug_device: 'cpu' 22 | 23 | split_seed: 0 24 | k_fold: 5 25 | 26 | optimizer: adamw 27 | base_lr: 0.0004 28 | betas: [0.9, 0.999] 29 | weight_decay: 0.05 # weight decay of the optimizer 30 | weight: [0.5, 1, 3] # weight of each class in the loss function 31 | rlt: 1 # relation between CE and Dice loss 32 | 33 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 34 | rotate: [30, 30, 30] # rotation angle for data augmentation 35 | translate: [0, 0, 0] 36 | affine_pad_size: [50, 50, 50] 37 | gaussian_noise_std: 0.02 38 | 39 | print_freq: 5 40 | iter_per_epoch: 500 41 | 42 | 43 | #VALIDATION 44 | ema: True 45 | ema_alpha: 0.99 46 | val_freq: 10 47 | 48 | 49 | 50 | #INFERENCE 51 | sliding_window: True 52 | window_size: [128, 128, 128] 53 | 54 | 55 | 56 | 57 | # DDP 58 | world_size: 1 59 | proc_idx: 0 60 | rank: 0 61 | port: 10000 62 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 63 | dist_backend: "nccl" 64 | multiprocessing_distributed: true 65 | reproduce_seed: null 66 | 67 | -------------------------------------------------------------------------------- /config/lits/swin_unetr_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: swin_unetr 9 | in_chan: 1 10 | base_chan: 48 11 | 12 | #TRAIN 13 | epochs: 200 14 | training_size: [128, 128, 128] # training crop size 15 | start_epoch: 0 16 | num_workers: 2 17 | aug_device: 'cpu' 18 | 19 | split_seed: 0 20 | k_fold: 5 21 | 22 | optimizer: adamw 23 | base_lr: 0.0004 24 | betas: [0.9, 0.999] 25 | weight_decay: 0.05 # weight decay of the optimizer 26 | weight: [0.5, 1, 3] # weight of each class in the loss function 27 | rlt: 1 # relation between CE and Dice loss 28 | 29 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 30 | rotate: [30, 30, 30] # rotation angle for data augmentation 31 | translate: [0, 0, 0] 32 | affine_pad_size: [50, 50, 50] 33 | gaussian_noise_std: 0.02 34 | 35 | print_freq: 5 36 | iter_per_epoch: 500 37 | 38 | 39 | #VALIDATION 40 | ema: True 41 | ema_alpha: 0.99 42 | val_freq: 50 43 | 44 | 45 | 46 | #INFERENCE 47 | sliding_window: True 48 | window_size: [128, 128, 128] 49 | 50 | 51 | 52 | 53 | # DDP 54 |
world_size: 1 55 | proc_idx: 0 56 | rank: 0 57 | port: 10000 58 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 59 | dist_backend: "nccl" 60 | multiprocessing_distributed: true 61 | reproduce_seed: null 62 | 63 | -------------------------------------------------------------------------------- /config/lits/unetr_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: unetr 9 | in_chan: 1 10 | norm: in 11 | init_model: /research/cbim/vast/yg397/ConvFormer/ConvFormer/initmodel/UNETR_model_best_acc.pth 12 | 13 | #TRAIN 14 | epochs: 400 15 | training_size: [96, 96, 96] # training crop size 16 | start_epoch: 0 17 | num_workers: 2 18 | aug_device: 'cpu' 19 | 20 | split_seed: 0 21 | k_fold: 5 22 | 23 | optimizer: adamw 24 | base_lr: 0.0001 25 | betas: [0.9, 0.999] 26 | weight_decay: 0.00005 # weight decay of the optimizer 27 | weight: [0.5, 1, 3] 28 | rlt: 1 # relation between CE and Dice loss 29 | 30 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 31 | rotate: [30, 30, 30] # rotation angle for data augmentation 32 | translate: [0, 0, 0] 33 | gaussian_noise_std: 0.02 34 | 35 | print_freq: 5 36 | iter_per_epoch: 500 37 | 38 | 39 | #VALIDATION 40 | ema: False 41 | ema_alpha: 0.99 42 | val_freq: 10 43 | 44 | 45 | 46 | #INFERENCE 47 | sliding_window: True 48 | window_size: [96, 96, 96] 49 | 50 | 51 | # DDP 52 | world_size: 1 53 | proc_idx: 0 54 | rank: 0 55 | port: 10000 56 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 57 | dist_backend: "nccl" 58 | multiprocessing_distributed: true 59 | reproduce_seed: null 60 | -------------------------------------------------------------------------------- /config/lits/vtunet_3d.yaml: -------------------------------------------------------------------------------- 1 | #DATA 2 | data_root: /filer/tmp1/yg397/dataset/lits/lits_3d 3 | classes: 3 4 | modality: CT 5 | 6 | 7 | #MODEL 8 | arch: vtunet 9 | init_model: '/research/cbim/vast/yg397/github/UTNet/initmodel/swin_tiny_patch4_window7_224.pth' 10 | in_chan: 1 11 | patch_size: [4, 4, 4] 12 | 13 | #TRAIN 14 | epochs: 400 15 | training_size: [128, 128, 128] # training crop size 16 | start_epoch: 0 17 | num_workers: 2 18 | aug_device: 'cpu' 19 | 20 | split_seed: 0 21 | k_fold: 5 22 | 23 | optimizer: adamw 24 | base_lr: 0.0004 #0.001 25 | betas: [0.9, 0.999] 26 | #momentum: 0.9 # momentum of SGD optimizer 27 | weight_decay: 0.05 # weight decay of the optimizer 28 | weight: [0.5, 1, 3] 29 | rlt: 1 # relation between CE and Dice loss 30 | 31 | 32 | scale: [0.3, 0.3, 0.3] # scale for data augmentation 0.1 0.3 0.3 33 | rotate: [30, 30, 30] # rotation angle for data augmentation 34 | translate: [0, 0, 0] 35 | gaussian_noise_std: 0.02 36 | 37 | print_freq: 5 38 | iter_per_epoch: 500 39 | 40 | 41 | 42 | 43 | #VALIDATION 44 | ema: True 45 | ema_alpha: 0.99 46 | val_freq: 10 47 | 48 | 49 | 50 | #INFERENCE 51 | sliding_window: True 52 | window_size: [128, 128, 128] 53 | 54 | 55 | 56 | 57 | 58 | # DDP 59 | world_size: 1 60 | proc_idx: 0 61 | rank: 0 62 | port: 10000 63 | dist_url: 'tcp://localhost:10000' # the port number here should be the same as the previous one 64 | dist_backend: "nccl" 65 | multiprocessing_distributed: true 66 | reproduce_seed: null 67 | --------------------------------------------------------------------------------
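All of the configs above follow one shared schema, so a single loader covers them. The snippet below is a minimal illustrative sketch (assuming PyYAML; how *train.py* actually loads and merges these options may differ) of turning such a YAML file into attribute-style options:

```python
import argparse
import yaml

def load_config(path):
    # Parse a training config (e.g. config/lits/resunet_3d.yaml) into a
    # namespace, so options read as attributes: args.arch, args.base_lr, ...
    with open(path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    return argparse.Namespace(**cfg)

args = load_config('config/lits/resunet_3d.yaml')
print(args.arch, args.training_size, args.window_size)
```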
/dataset_conversion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhygao/CBIM-Medical-Image-Segmentation/7c26979a96eb9fe057320e1db38680bae33786b8/dataset_conversion/__init__.py -------------------------------------------------------------------------------- /dataset_conversion/acdc_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from utils import ResampleXYZAxis, ResampleLabelToRef 4 | import os 5 | import random 6 | import yaml 7 | import pdb 8 | 9 | def ResampleCMRImage(imImage, imLabel, save_path, patient_name, count, target_spacing=(1., 1., 1.)): 10 | 11 | assert imImage.GetSpacing() == imLabel.GetSpacing() 12 | assert imImage.GetSize() == imLabel.GetSize() 13 | 14 | 15 | spacing = imImage.GetSpacing() 16 | origin = imImage.GetOrigin() 17 | 18 | 19 | npimg = sitk.GetArrayFromImage(imImage) 20 | nplab = sitk.GetArrayFromImage(imLabel) 21 | z, y, x = npimg.shape 22 | 23 | if not os.path.exists('%s'%(save_path)): 24 | os.mkdir('%s'%(save_path)) 25 | 26 | 27 | re_img = ResampleXYZAxis(imImage, space=(target_spacing[0], target_spacing[1], spacing[2]), interp=sitk.sitkBSpline) 28 | re_lab = ResampleLabelToRef(imLabel, re_img, interp=sitk.sitkNearestNeighbor) 29 | 30 | 31 | sitk.WriteImage(re_img, '%s/%s_%d.nii.gz'%(save_path, patient_name, count)) 32 | sitk.WriteImage(re_lab, '%s/%s_%d_gt.nii.gz'%(save_path, patient_name, count)) 33 | 34 | 35 | 36 | if __name__ == '__main__': 37 | 38 | 39 | src_path = '/research/cbim/medical/medical-share/public/ACDC/raw/training/' 40 | tgt_path = '/research/cbim/medical/yg397/tgt_dir/' 41 | 42 | 43 | 44 | 45 | patient_list = list(range(1, 101)) 46 | 47 | name_list = [] 48 | for idx in patient_list: 49 | name_list.append('patient%.3d'%idx) 50 | 51 | if not os.path.exists(tgt_path+'list'): 52 | os.mkdir('%slist'%(tgt_path)) 53 | with open("%slist/dataset.yaml"%tgt_path, "w",encoding="utf-8") as f: 54 | yaml.dump(name_list, f) 55 | 56 | os.chdir(src_path) 57 | for name in os.listdir('.'): 58 | os.chdir(name) 59 | 60 | count = 0 61 | for i in os.listdir('.'): 62 | if 'gt' in i: 63 | tmp = i.split('_') 64 | img_name = tmp[0] + '_' + tmp[1] 65 | patient_name = tmp[0] 66 | 67 | img = sitk.ReadImage('%s.nii.gz'%img_name) 68 | lab = sitk.ReadImage('%s_gt.nii.gz'%img_name) 69 | 70 | ResampleCMRImage(img, lab, tgt_path, patient_name, count, (1.5625, 1.5625)) 71 | count += 1 72 | print(name, '%d'%count, 'done') 73 | 74 | os.chdir('..') 75 | 76 | 77 | -------------------------------------------------------------------------------- /dataset_conversion/acdc_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from utils import ResampleXYZAxis, ResampleLabelToRef 4 | import os 5 | import random 6 | import yaml 7 | 8 | def ResampleCMRImage(imImage, imLabel, save_path, patient_name, count, target_spacing=(1., 1., 1.)): 9 | 10 | assert imImage.GetSpacing() == imLabel.GetSpacing() 11 | assert imImage.GetSize() == imLabel.GetSize() 12 | 13 | 14 | spacing = imImage.GetSpacing() 15 | origin = imImage.GetOrigin() 16 | 17 | 18 | npimg = sitk.GetArrayFromImage(imImage) 19 | nplab = sitk.GetArrayFromImage(imLabel) 20 | z, y, x = npimg.shape 21 | 22 | if not os.path.exists('%s'%(save_path)): 23 | os.mkdir('%s'%(save_path)) 24 | 25 | re_img_xy = ResampleXYZAxis(imImage, space=(target_spacing[0], target_spacing[1], 
spacing[2]), interp=sitk.sitkBSpline) 26 | re_lab_xy = ResampleLabelToRef(imLabel, re_img_xy, interp=sitk.sitkNearestNeighbor) 27 | 28 | re_img_xyz = ResampleXYZAxis(re_img_xy, space=(target_spacing[0], target_spacing[1], target_spacing[2]), interp=sitk.sitkNearestNeighbor) 29 | re_lab_xyz = ResampleLabelToRef(re_lab_xy, re_img_xyz, interp=sitk.sitkNearestNeighbor) 30 | 31 | 32 | sitk.WriteImage(re_img_xyz, '%s/%s_%d.nii.gz'%(save_path, patient_name, count)) 33 | sitk.WriteImage(re_lab_xyz, '%s/%s_%d_gt.nii.gz'%(save_path, patient_name, count)) 34 | 35 | 36 | 37 | if __name__ == '__main__': 38 | 39 | 40 | src_path = '/research/cbim/medical/medical-share/public/ACDC/raw/training/' 41 | tgt_path = '/research/cbim/medical/yg397/tgt_dir/' 42 | 43 | 44 | 45 | 46 | patient_list = list(range(1, 101)) 47 | 48 | name_list = [] 49 | 50 | for idx in patient_list: 51 | name_list.append('patient%.3d'%idx) 52 | 53 | 54 | if not os.path.exists(tgt_path+'list'): 55 | os.mkdir('%slist'%(tgt_path)) 56 | with open("%slist/dataset.yaml"%tgt_path, "w",encoding="utf-8") as f: 57 | yaml.dump(name_list, f) 58 | 59 | os.chdir(src_path) 60 | for name in os.listdir('.'): 61 | os.chdir(name) 62 | 63 | count = 0 64 | for i in os.listdir('.'): 65 | if 'gt' in i: 66 | tmp = i.split('_') 67 | img_name = tmp[0] + '_' + tmp[1] 68 | patient_name = tmp[0] 69 | 70 | img = sitk.ReadImage('%s.nii.gz'%img_name) 71 | lab = sitk.ReadImage('%s_gt.nii.gz'%img_name) 72 | 73 | ResampleCMRImage(img, lab, tgt_path, patient_name, count, (1.5625, 1.5625, 5.0)) 74 | count += 1 75 | print(name, 'done') 76 | 77 | os.chdir('..') 78 | 79 | 80 | -------------------------------------------------------------------------------- /dataset_conversion/amos_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from utils import ResampleXYZAxis, ResampleLabelToRef, CropForeground, ITKReDirection 4 | import os 5 | import random 6 | import yaml 7 | import copy 8 | import pdb 9 | 10 | from matplotlib import image 11 | 12 | def ResampleImage(imImage, imLabel, save_path, name, target_spacing=(1., 1., 1.)): 13 | 14 | assert round(imImage.GetSpacing()[0], 2) == round(imLabel.GetSpacing()[0], 2) 15 | assert round(imImage.GetSpacing()[1], 2) == round(imLabel.GetSpacing()[1], 2) 16 | assert round(imImage.GetSpacing()[2], 2) == round(imLabel.GetSpacing()[2], 2) 17 | 18 | assert imImage.GetSize() == imLabel.GetSize() 19 | 20 | 21 | imLabel.CopyInformation(imImage) 22 | 23 | imImage.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 24 | imLabel.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 25 | 26 | 27 | spacing = imImage.GetSpacing() 28 | origin = imImage.GetOrigin() 29 | 30 | 31 | npimg = sitk.GetArrayFromImage(imImage).astype(np.int32) 32 | nplab = sitk.GetArrayFromImage(imLabel).astype(np.uint8) 33 | z, y, x = npimg.shape 34 | 35 | if not os.path.exists('%s'%(save_path)): 36 | os.mkdir('%s'%(save_path)) 37 | 38 | 39 | re_img_yz = ResampleXYZAxis(imImage, space=(target_spacing[0], target_spacing[1], spacing[2]), interp=sitk.sitkBSpline) 40 | re_lab_yz = ResampleLabelToRef(imLabel, re_img_yz, interp=sitk.sitkNearestNeighbor) 41 | 42 | re_img_xyz = ResampleXYZAxis(re_img_yz, space=(target_spacing[0], target_spacing[1], target_spacing[2]), interp=sitk.sitkNearestNeighbor) 43 | re_lab_xyz = ResampleLabelToRef(re_lab_yz, re_img_xyz, interp=sitk.sitkNearestNeighbor) 44 | 45 | 46 | cropped_img, cropped_lab = CropForeground(re_img_xyz, re_lab_xyz, 
context_size=[30, 30, 30]) # z, y, x 47 | 48 | sitk.WriteImage(cropped_img, '%s/%s.nii.gz'%(save_path, name)) 49 | sitk.WriteImage(cropped_lab, '%s/%s_gt.nii.gz'%(save_path, name)) 50 | 51 | 52 | if __name__ == '__main__': 53 | 54 | 55 | src_path = '/filer/tmp1/yg397/dataset/amos/amos22/' 56 | ct_tgt_path = '/filer/tmp1/yg397/dataset/amos/cbim/amos_ct_3d/' 57 | mr_tgt_path = '/filer/tmp1/yg397/dataset/amos/cbim/amos_mr_3d/' 58 | 59 | 60 | print('Start processing training set') 61 | ct_name_list = [] 62 | mr_name_list = [] 63 | for name in os.listdir(f"{src_path}imagesTr/"): 64 | if not name.endswith('nii.gz'): 65 | continue 66 | print(name) 67 | idx = name.split('.')[0] 68 | idx = int(idx.split('_')[1]) 69 | if idx < 500: 70 | ct_name_list.append(idx) 71 | else: 72 | mr_name_list.append(idx) 73 | 74 | 75 | if not os.path.exists(ct_tgt_path+'list'): 76 | os.mkdir('%slist'%(ct_tgt_path)) 77 | with open("%slist/dataset.yaml"%ct_tgt_path, "w",encoding="utf-8") as f: 78 | yaml.dump(ct_name_list, f) 79 | 80 | if not os.path.exists(mr_tgt_path+'list'): 81 | os.mkdir('%slist'%(mr_tgt_path)) 82 | with open("%slist/dataset.yaml"%mr_tgt_path, "w",encoding="utf-8") as f: 83 | yaml.dump(mr_name_list, f) 84 | 85 | os.chdir(src_path) 86 | 87 | for name in ct_name_list: 88 | img = sitk.ReadImage(src_path+f"imagesTr/amos_{name:04d}.nii.gz") 89 | lab = sitk.ReadImage(src_path+f"labelsTr/amos_{name:04d}.nii.gz") 90 | 91 | ResampleImage(img, lab, ct_tgt_path, name, (0.68825, 0.68825, 2.0)) 92 | print(name, 'done') 93 | 94 | for name in mr_name_list: 95 | img = sitk.ReadImage(src_path+f"imagesTr/amos_{name:04d}.nii.gz") 96 | lab = sitk.ReadImage(src_path+f"labelsTr/amos_{name:04d}.nii.gz") 97 | 98 | ResampleImage(img, lab, mr_tgt_path, name, (1.1875, 1.1875, 2.0)) 99 | print(name, 'done') 100 | 101 | print('Start processing validation set') 102 | ct_name_list = [] 103 | mr_name_list = [] 104 | for name in os.listdir(f"{src_path}imagesVa/"): 105 | if not name.endswith('nii.gz'): 106 | continue 107 | print(name) 108 | idx = name.split('.')[0] 109 | idx = int(idx.split('_')[1]) 110 | if idx < 500: 111 | ct_name_list.append(idx) 112 | else: 113 | mr_name_list.append(idx) 114 | 115 | for name in ct_name_list: 116 | img = sitk.ReadImage(src_path+f"imagesVa/amos_{name:04d}.nii.gz") 117 | lab = sitk.ReadImage(src_path+f"labelsVa/amos_{name:04d}.nii.gz") 118 | 119 | ResampleImage(img, lab, ct_tgt_path, name, (0.68825, 0.68825, 2.0)) 120 | print(name, 'done') 121 | 122 | for name in mr_name_list: 123 | img = sitk.ReadImage(src_path+f"imagesVa/amos_{name:04d}.nii.gz") 124 | lab = sitk.ReadImage(src_path+f"labelsVa/amos_{name:04d}.nii.gz") 125 | 126 | ResampleImage(img, lab, mr_tgt_path, name, (1.1875, 1.1875, 2.0)) 127 | print(name, 'done') 128 | 129 | -------------------------------------------------------------------------------- /dataset_conversion/bcv_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from utils import ResampleXYZAxis, ResampleLabelToRef, CropForeground 4 | import os 5 | import random 6 | import yaml 7 | import copy 8 | import pdb 9 | 10 | def ResampleImage(imImage, imLabel, save_path, name, target_spacing=(1., 1., 1.)): 11 | 12 | assert imImage.GetSpacing() == imLabel.GetSpacing() 13 | assert imImage.GetSize() == imLabel.GetSize() 14 | 15 | 16 | spacing = imImage.GetSpacing() 17 | origin = imImage.GetOrigin() 18 | 19 | 20 | npimg = sitk.GetArrayFromImage(imImage).astype(np.int32) 21 | nplab = 
sitk.GetArrayFromImage(imLabel).astype(np.uint8) 22 | z, y, x = npimg.shape 23 | 24 | if not os.path.exists('%s'%(save_path)): 25 | os.mkdir('%s'%(save_path)) 26 | 27 | imImage.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 28 | imLabel.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 29 | 30 | 31 | re_img_xy = ResampleXYZAxis(imImage, space=(target_spacing[0], target_spacing[1], spacing[2]), interp=sitk.sitkBSpline) 32 | re_lab_xy = ResampleLabelToRef(imLabel, re_img_xy, interp=sitk.sitkNearestNeighbor) 33 | 34 | re_img_xyz = ResampleXYZAxis(re_img_xy, space=(target_spacing[0], target_spacing[1], target_spacing[2]), interp=sitk.sitkNearestNeighbor) 35 | re_lab_xyz = ResampleLabelToRef(re_lab_xy, re_img_xyz, interp=sitk.sitkNearestNeighbor) 36 | 37 | 38 | cropped_img, cropped_lab = CropForeground(re_img_xyz, re_lab_xyz, context_size=[5, 20, 20]) 39 | 40 | sitk.WriteImage(cropped_img, '%s/%s.nii.gz'%(save_path, name)) 41 | sitk.WriteImage(cropped_lab, '%s/%s_gt.nii.gz'%(save_path, name)) 42 | 43 | 44 | if __name__ == '__main__': 45 | 46 | 47 | src_path = '/filer/tmp1/yg397/dataset/bcv/Abdomen/RawData/Training/' 48 | tgt_path = '/filer/tmp1/yg397/dataset/bcv/bcv_3d/' 49 | 50 | 51 | name_list = os.listdir(src_path + 'img') 52 | name_list = [name.split('.')[0] for name in name_list] 53 | 54 | if not os.path.exists(tgt_path+'list'): 55 | os.mkdir('%slist'%(tgt_path)) 56 | with open("%slist/dataset.yaml"%tgt_path, "w",encoding="utf-8") as f: 57 | yaml.dump(name_list, f) 58 | 59 | os.chdir(src_path) 60 | 61 | for name in name_list: 62 | img_name = name + '.nii.gz' 63 | lab_name = img_name.replace('img', 'label') 64 | 65 | img = sitk.ReadImage(src_path+'img/%s'%img_name) 66 | lab = sitk.ReadImage(src_path+'label/%s'%lab_name) 67 | 68 | ResampleImage(img, lab, tgt_path, name, (0.75, 0.75, 3.0)) 69 | print(name, 'done') 70 | 71 | 72 | -------------------------------------------------------------------------------- /dataset_conversion/kits_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from utils import ResampleXYZAxis, ResampleLabelToRef, CropForeground 4 | import os 5 | import random 6 | import yaml 7 | import copy 8 | import pdb 9 | 10 | def ResampleImage(imImage, imLabel, save_path, name, target_spacing=(1., 1., 1.)): 11 | 12 | assert round(imImage.GetSpacing()[0], 2) == round(imLabel.GetSpacing()[0], 2) 13 | assert round(imImage.GetSpacing()[1], 2) == round(imLabel.GetSpacing()[1], 2) 14 | assert round(imImage.GetSpacing()[2], 2) == round(imLabel.GetSpacing()[2], 2) 15 | 16 | assert imImage.GetSize() == imLabel.GetSize() 17 | 18 | 19 | spacing = imImage.GetSpacing() 20 | origin = imImage.GetOrigin() 21 | 22 | imLabel.CopyInformation(imImage) 23 | 24 | npimg = sitk.GetArrayFromImage(imImage).astype(np.int32) 25 | nplab = sitk.GetArrayFromImage(imLabel).astype(np.uint8) 26 | z, y, x = npimg.shape 27 | 28 | if not os.path.exists('%s'%(save_path)): 29 | os.mkdir('%s'%(save_path)) 30 | 31 | imImage.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 32 | imLabel.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 33 | 34 | 35 | re_img_yz = ResampleXYZAxis(imImage, space=(spacing[0], target_spacing[1], target_spacing[2]), interp=sitk.sitkBSpline) 36 | re_lab_yz = ResampleLabelToRef(imLabel, re_img_yz, interp=sitk.sitkNearestNeighbor) 37 | 38 | re_img_xyz = ResampleXYZAxis(re_img_yz, space=(target_spacing[0], target_spacing[1], target_spacing[2]), 
interp=sitk.sitkNearestNeighbor) 39 | re_lab_xyz = ResampleLabelToRef(re_lab_yz, re_img_xyz, interp=sitk.sitkNearestNeighbor) 40 | 41 | 42 | 43 | 44 | cropped_img, cropped_lab = CropForeground(re_img_xyz, re_lab_xyz, context_size=[30, 30, 30]) # z, y, x 45 | 46 | sitk.WriteImage(cropped_img, '%s/%s.nii.gz'%(save_path, name)) 47 | sitk.WriteImage(cropped_lab, '%s/%s_gt.nii.gz'%(save_path, name)) 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | 53 | src_path = '/filer/tmp1/yg397/dataset/kits/kits19/data/' 54 | tgt_path = '/filer/tmp1/yg397/dataset/kits/kits_3d/' 55 | 56 | 57 | name_list = [] 58 | for i in range(0, 210): 59 | name_list.append(i) 60 | 61 | if not os.path.exists(tgt_path+'list'): 62 | os.mkdir('%slist'%(tgt_path)) 63 | with open("%slist/dataset.yaml"%tgt_path, "w",encoding="utf-8") as f: 64 | yaml.dump(name_list, f) 65 | 66 | os.chdir(src_path) 67 | 68 | for name in name_list: 69 | img = sitk.ReadImage(src_path+f"case_00{name:03d}/imaging.nii.gz") 70 | lab = sitk.ReadImage(src_path+f"case_00{name:03d}/segmentation.nii.gz") 71 | 72 | ResampleImage(img, lab, tgt_path, name, (0.781625, 0.781625, 0.781625)) 73 | #ResampleImage(img, lab, tgt_path, name, (1, 1, 1)) 74 | print(name, 'done') 75 | 76 | 77 | -------------------------------------------------------------------------------- /dataset_conversion/lits_3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from utils import ResampleXYZAxis, ResampleLabelToRef, CropForeground 4 | import os 5 | import random 6 | import yaml 7 | import copy 8 | import pdb 9 | 10 | def ResampleImage(imImage, imLabel, save_path, name, target_spacing=(1., 1., 1.)): 11 | 12 | imLabel.CopyInformation(imImage) 13 | 14 | assert imImage.GetSize() == imLabel.GetSize() 15 | 16 | 17 | spacing = imImage.GetSpacing() 18 | origin = imImage.GetOrigin() 19 | 20 | 21 | npimg = sitk.GetArrayFromImage(imImage).astype(np.int32) 22 | nplab = sitk.GetArrayFromImage(imLabel).astype(np.uint8) 23 | z, y, x = npimg.shape 24 | 25 | if not os.path.exists('%s'%(save_path)): 26 | os.mkdir('%s'%(save_path)) 27 | 28 | imImage.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 29 | imLabel.SetDirection((1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)) 30 | 31 | 32 | re_img_xy = ResampleXYZAxis(imImage, space=(target_spacing[0], target_spacing[1], spacing[2]), interp=sitk.sitkBSpline) 33 | re_lab_xy = ResampleLabelToRef(imLabel, re_img_xy, interp=sitk.sitkNearestNeighbor) 34 | 35 | re_img_xyz = ResampleXYZAxis(re_img_xy, space=(target_spacing[0], target_spacing[1], target_spacing[2]), interp=sitk.sitkNearestNeighbor) 36 | re_lab_xyz = ResampleLabelToRef(re_lab_xy, re_img_xyz, interp=sitk.sitkNearestNeighbor) 37 | 38 | 39 | 40 | 41 | cropped_img, cropped_lab = CropForeground(re_img_xyz, re_lab_xyz, context_size=[10, 30, 30]) 42 | 43 | sitk.WriteImage(cropped_img, '%s/%s.nii.gz'%(save_path, name)) 44 | sitk.WriteImage(cropped_lab, '%s/%s_gt.nii.gz'%(save_path, name)) 45 | 46 | 47 | if __name__ == '__main__': 48 | 49 | 50 | src_path = '/filer/tmp1/yg397/dataset/lits/media/nas/01_Datasets/CT/LITS/' 51 | tgt_path = '/filer/tmp1/yg397/dataset/lits/lits_3d/' 52 | 53 | 54 | name_list = [] 55 | for i in range(0, 131): 56 | name_list.append(i) 57 | 58 | if not os.path.exists(tgt_path+'list'): 59 | os.mkdir('%slist'%(tgt_path)) 60 | with open("%slist/dataset.yaml"%tgt_path, "w",encoding="utf-8") as f: 61 | yaml.dump(name_list, f) 62 | 63 | os.chdir(src_path) 64 | 65 | for name in name_list: 66 | 
img_name = 'volume-%d.nii'%name 67 | lab_name = 'segmentation-%d.nii'%name 68 | 69 | img = sitk.ReadImage(src_path+img_name) 70 | lab = sitk.ReadImage(src_path+lab_name) 71 | 72 | ResampleImage(img, lab, tgt_path, name, (0.767578125, 0.767578125, 1.0)) 73 | #ResampleImage(img, lab, tgt_path, name, (1, 1, 1)) 74 | print(name, 'done') 75 | 76 | 77 | -------------------------------------------------------------------------------- /dataset_conversion/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import SimpleITK as sitk 3 | from skimage.measure import regionprops 4 | 5 | import os 6 | 7 | def ResampleXYZAxis(imImage, space=(1., 1., 1.), interp=sitk.sitkLinear): 8 | identity1 = sitk.Transform(3, sitk.sitkIdentity) 9 | sp1 = imImage.GetSpacing() 10 | sz1 = imImage.GetSize() 11 | 12 | sz2 = (int(round(sz1[0]*sp1[0]*1.0/space[0])), int(round(sz1[1]*sp1[1]*1.0/space[1])), int(round(sz1[2]*sp1[2]*1.0/space[2]))) 13 | 14 | imRefImage = sitk.Image(sz2, imImage.GetPixelIDValue()) 15 | imRefImage.SetSpacing(space) 16 | imRefImage.SetOrigin(imImage.GetOrigin()) 17 | imRefImage.SetDirection(imImage.GetDirection()) 18 | 19 | imOutImage = sitk.Resample(imImage, imRefImage, identity1, interp) 20 | 21 | return imOutImage 22 | 23 | def ResampleLabelToRef(imLabel, imRef, interp=sitk.sitkNearestNeighbor): 24 | identity1 = sitk.Transform(3, sitk.sitkIdentity) 25 | 26 | imRefImage = sitk.Image(imRef.GetSize(), imLabel.GetPixelIDValue()) 27 | imRefImage.SetSpacing(imRef.GetSpacing()) 28 | imRefImage.SetOrigin(imRef.GetOrigin()) 29 | imRefImage.SetDirection(imRef.GetDirection()) 30 | 31 | ResampledLabel = sitk.Resample(imLabel, imRefImage, identity1, interp) 32 | 33 | return ResampledLabel 34 | 35 | 36 |
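# Illustrative usage of the two helpers above (the file names here are
# hypothetical; the conversion scripts in this directory follow this pattern):
#
#   img = sitk.ReadImage('volume.nii.gz')
#   lab = sitk.ReadImage('volume_gt.nii.gz')
#   re_img = ResampleXYZAxis(img, space=(0.75, 0.75, 3.0), interp=sitk.sitkBSpline)
#   re_lab = ResampleLabelToRef(lab, re_img, interp=sitk.sitkNearestNeighbor)
#
# Resampling the label to the resampled image keeps both on the same grid.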
37 | def ITKReDirection(itkimg, target_direction=(1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)): 38 | # target direction should be orthogonal, i.e. (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0) 39 | 40 | # permute axis 41 | tmp_target_direction = np.abs(np.round(np.array(target_direction))).reshape(3,3).T 42 | current_direction = np.abs(np.round(itkimg.GetDirection())).reshape(3,3).T 43 | 44 | permute_order = [] 45 | if not np.array_equal(tmp_target_direction, current_direction): 46 | for i in range(3): 47 | for j in range(3): 48 | if np.array_equal(tmp_target_direction[i], current_direction[j]): 49 | permute_order.append(j) 50 | #print(i, j) 51 | #print(permute_order) 52 | break 53 | redirect_img = sitk.PermuteAxes(itkimg, permute_order) 54 | else: 55 | redirect_img = itkimg 56 | # flip axis 57 | current_direction = np.round(np.array(redirect_img.GetDirection())).reshape(3,3).T 58 | current_direction = np.max(current_direction, axis=1) 59 | 60 | tmp_target_direction = np.array(target_direction).reshape(3,3).T 61 | tmp_target_direction = np.max(tmp_target_direction, axis=1) 62 | flip_order = ((tmp_target_direction * current_direction) != 1) 63 | flipped_img = sitk.Flip(redirect_img, [bool(flip_order[0]), bool(flip_order[1]), bool(flip_order[2])]) 64 | return flipped_img 65 | 66 | 67 | def CropForeground(imImage, imLabel, context_size=[10, 30, 30]): 68 | # the context_size is in numpy index order: z, y, x 69 | # Note that SimpleITK uses the index order of: x, y, z 70 | 71 | npImg = sitk.GetArrayFromImage(imImage) 72 | npLab = sitk.GetArrayFromImage(imLabel) 73 | 74 | mask = (npLab>0).astype(np.uint8) # foreground mask 75 | 76 | regions = regionprops(mask) 77 | assert len(regions) == 1 78 | 79 | zz, yy, xx = npImg.shape 80 | 81 | z, y, x = regions[0].centroid 82 | 83 | z_min, y_min, x_min, z_max, y_max, x_max = regions[0].bbox 84 | print('foreground size:', z_max-z_min, y_max-y_min, x_max-x_min) 85 | 86 | z, y, x = int(z), int(y), int(x) 87 | 88 | z_min = max(0, z_min-context_size[0]) 89 | z_max = min(zz, z_max+context_size[0]) 90 | y_min = max(0, y_min-context_size[1]) # context_size is (z, y, x): y uses index 1 91 | y_max = min(yy, y_max+context_size[1]) 92 | x_min = max(0, x_min-context_size[2]) # and x uses index 2 93 | x_max = min(xx, x_max+context_size[2]) 94 | 95 | img = npImg[z_min:z_max, y_min:y_max, x_min:x_max] 96 | lab = npLab[z_min:z_max, y_min:y_max, x_min:x_max] 97 | 98 | croppedImage = sitk.GetImageFromArray(img) 99 | croppedLabel = sitk.GetImageFromArray(lab) 100 | 101 | 102 | croppedImage.SetSpacing(imImage.GetSpacing()) 103 | croppedLabel.SetSpacing(imImage.GetSpacing()) 104 | 105 | croppedImage.SetDirection(imImage.GetDirection()) 106 | croppedLabel.SetDirection(imImage.GetDirection()) 107 | 108 | return croppedImage, croppedLabel 109 | 110 | 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /docs/change.md: -------------------------------------------------------------------------------- 1 | ## Recent Changes 2 | 3 | #### Mar. 1, 2023 4 | - Support AMOS CT and MR datasets 5 | - Support using the GPU for data augmentation 6 | - The affine transformation (rotation, scaling, translation, shearing) for 3D images is computationally intensive. Previously, we used multiple CPU workers to perform augmentation, which is slow (5-6 seconds for a 160\*160\*160 image) 7 | - Now, we support two ways to use the GPU to accelerate augmentation (0.1-0.3 s for a 160\*160\*160 image, at the cost of 1-2 GB more GPU memory) 8 | - We support PyTorch CUDA operations. You can simply activate this by setting 'aug\_device' in the config to 'gpu'; a sketch of this style of augmentation follows this list. 9 | - We support NVIDIA DALI to perform augmentation. You can set 'dataloader' to 'dali' to use DALI operations for augmentation. We also provide commonly used DALI augmentation functions in *training/augmentation\_dali.py*. 10 | - In my own experience, the PyTorch CUDA operations already provide a huge acceleration. DALI has more advantages in CPU mode; in GPU mode its advantages are limited, and its APIs take a lot of time to learn.
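Below is a minimal sketch of the PyTorch-CUDA style of augmentation (illustrative only, not the repository's actual implementation): a random rotation applied to a whole batch directly on the GPU.

```python
import torch
import torch.nn.functional as F

def random_rotate_gpu(img, lab, max_angle=30.0):
    # img: (B, C, D, H, W) float tensor on the GPU; lab: (B, 1, D, H, W) integer labels
    ang = torch.deg2rad(torch.empty(img.shape[0], device=img.device).uniform_(-max_angle, max_angle))
    cos, sin = ang.cos(), ang.sin()
    zero, one = torch.zeros_like(ang), torch.ones_like(ang)
    # rotation about the depth axis; a full pipeline would compose scaling,
    # translation, and shearing into the same affine matrix
    theta = torch.stack([
        torch.stack([one, zero, zero, zero], dim=1),
        torch.stack([zero, cos, -sin, zero], dim=1),
        torch.stack([zero, sin, cos, zero], dim=1),
    ], dim=1)  # (B, 3, 4) affine matrices
    grid = F.affine_grid(theta, img.shape, align_corners=False)
    img_aug = F.grid_sample(img, grid, mode='bilinear', align_corners=False)
    lab_aug = F.grid_sample(lab.float(), grid, mode='nearest', align_corners=False).long()
    return img_aug, lab_aug
```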
11 | - Add inference code 12 | - We provide *prediction.py* to make predictions on new testing images and save the predicted labels into .nii.gz files. 13 | - *prediction.py* includes pre-processing, ensembled prediction, and post-processing functions. **You need to modify the target\_spacing and the pre-processing function to make sure the testing images are resampled and normalized the same way as in training.** 14 | 15 | 16 | #### Feb. 10, 2023 17 | - Support KiTS19 kidney and tumor CT segmentation dataset 18 | 19 | 20 | #### Dec. 19, 2022 21 | - Support distributed training with PyTorch DDP 22 | - We provide a new training script, *train\_ddp.py*, which supports PyTorch distributed training 23 | - The original single-process training script, *train.py*, is still preserved. 24 | - As debugging multi-process code is hard, you can develop your algorithm with *train.py*, then use *train\_ddp.py* for faster training or a larger batch size 25 | 26 | - Support Automatic Mixed Precision (AMP) 27 | - We provide an option for half-precision training in *train\_ddp* 28 | - We have not benchmarked the speed with and without AMP yet, but we find that AMP greatly reduces GPU memory consumption. So if you want to train large 3D models, AMP is an option. 29 | 30 | - We made several improvements to code quality and readability 31 | - Use Python logging instead of print 32 | - Save all log information to a .txt file 33 | - Save the configuration of each training run with the log 34 | - Use a better log format 35 | -------------------------------------------------------------------------------- /inference/inference2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import split_idx 5 | import pdb 6 | 7 | 8 | def inference_whole_image(net, img, args=None): 9 | ''' 10 | img: torch tensor, B, C, H, W 11 | return: prob (after softmax), B, classes, H, W 12 | 13 | Use this function to inference if whole image can be put into GPU without memory issue 14 | Better to be consistent with the training window size 15 | ''' 16 | net.eval() 17 | 18 | with torch.no_grad(): 19 | pred = net(img) 20 | 21 | if isinstance(pred, tuple) or isinstance(pred, list): 22 | pred = pred[0] 23 | 24 | return F.softmax(pred, dim=1) 25 | 26 | 27 | def inference_sliding_window(net, img, args): 28 | ''' 29 | img: torch tensor, B, C, H, W 30 | return: prob (after softmax), B, classes, H, W 31 | 32 | The overlap of two windows will be half the window size 33 | Use this function to inference if out-of-memory occurs when whole image inferencing 34 | Better to be consistent with the training window size 35 | ''' 36 | net.eval() 37 | 38 | B, C, H, W = img.shape 39 | 40 | win_h, win_w = args.window_size 41 | 42 | half_win_h = win_h // 2 43 | half_win_w = win_w // 2 44 | 45 | pred_output = torch.zeros((B, args.classes, H, W)).to(img.device) 46 | 47 | counter = torch.zeros((B, 1, H, W)).to(img.device) 48 | one_count = torch.ones((B, 1, win_h, win_w)).to(img.device) 49 | 50 | with torch.no_grad(): 51 | for i in range(H
// half_win_h): 52 | for j in range(W // half_win_w): 53 | 54 | h_start_idx, h_end_idx = split_idx(half_win_h, H, i) 55 | w_start_idx, w_end_idx = split_idx(half_win_w, W, j) 56 | 57 | input_tensor = img[:, :, h_start_idx:h_end_idx, w_start_idx:w_end_idx] 58 | 59 | pred = net(input_tensor) 60 | 61 | if isinstance(pred, tuple) or isinstance(pred, list): 62 | pred = pred[0] 63 | 64 | pred = F.softmax(pred, dim=1) 65 | 66 | pred_output[:, :, h_start_idx:h_end_idx, w_start_idx:w_end_idx] += pred 67 | counter[:, :, h_start_idx:h_end_idx, w_start_idx:w_end_idx] += one_count 68 | 69 | pred_output /= counter 70 | 71 | return pred_output 72 | 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /inference/inference3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import split_idx 5 | import pdb 6 | 7 | 8 | def inference_whole_image(net, img, args=None): 9 | ''' 10 | img: torch tensor, B, C, D, H, W 11 | return: prob (after softmax), B, classes, D, H, W 12 | 13 | Use this function to inference if whole image can be put into GPU without memory issue 14 | Better to be consistent with the training window size 15 | ''' 16 | 17 | net.eval() 18 | 19 | with torch.no_grad(): 20 | pred = net(img) 21 | 22 | if isinstance(pred, tuple) or isinstance(pred, list): 23 | pred = pred[0] 24 | 25 | return F.softmax(pred, dim=1) 26 | 27 | 28 | def inference_sliding_window(net, img, args): 29 | ''' 30 | img: torch tensor, B, C, D, H, W 31 | return: prob (after softmax), B, classes, D, H, W 32 | 33 | The overlap of two windows will be half the window size 34 | 35 | Use this function to inference if out-of-memory occurs when whole image inferencing 36 | Better to be consistent with the training window size 37 | ''' 38 | net.eval() 39 | 40 | B, C, D, H, W = img.shape 41 | 42 | win_d, win_h, win_w = args.window_size 43 | 44 | flag = False 45 | if D < win_d or H < win_h or W < win_w: 46 | flag = True 47 | diff_D = max(0, win_d-D) 48 | diff_H = max(0, win_h-H) 49 | diff_W = max(0, win_w-W) 50 | 51 | img = F.pad(img, (0, diff_W, 0, diff_H, 0, diff_D)) 52 | 53 | origin_D, origin_H, origin_W = D, H, W 54 | B, C, D, H, W = img.shape 55 | 56 | 57 | half_win_d = win_d // 2 58 | half_win_h = win_h // 2 59 | half_win_w = win_w // 2 60 | 61 | pred_output = torch.zeros((B, args.classes, D, H, W)).to(img.device) 62 | 63 | counter = torch.zeros((B, 1, D, H, W)).to(img.device) 64 | one_count = torch.ones((B, 1, win_d, win_h, win_w)).to(img.device) 65 | 66 | with torch.no_grad(): 67 | for i in range(D // half_win_d): 68 | for j in range(H // half_win_h): 69 | for k in range(W // half_win_w): 70 | 71 | d_start_idx, d_end_idx = split_idx(half_win_d, D, i) 72 | h_start_idx, h_end_idx = split_idx(half_win_h, H, j) 73 | w_start_idx, w_end_idx = split_idx(half_win_w, W, k) 74 | 75 | input_tensor = img[:, :, d_start_idx:d_end_idx, h_start_idx:h_end_idx, w_start_idx:w_end_idx] 76 | 77 | pred = net(input_tensor) 78 | 79 | if isinstance(pred, tuple) or isinstance(pred, list): 80 | pred = pred[0] 81 | 82 | pred = F.softmax(pred, dim=1) 83 | 84 | pred_output[:, :, d_start_idx:d_end_idx, h_start_idx:h_end_idx, w_start_idx:w_end_idx] += pred 85 | 86 | counter[:, :, d_start_idx:d_end_idx, h_start_idx:h_end_idx, w_start_idx:w_end_idx] += one_count 87 | 88 | pred_output /= counter 89 | if flag: 90 | pred_output = pred_output[:, :, :origin_D, :origin_H, :origin_W] 91 | 92 | 
return pred_output 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /inference/utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | def get_inference(args): 5 | if args.dimension == '2d': 6 | if args.sliding_window: 7 | from .inference2d import inference_sliding_window 8 | return inference_sliding_window 9 | else: 10 | from .inference2d import inference_whole_image 11 | return inference_whole_image 12 | 13 | elif args.dimension == '3d': 14 | if args.sliding_window: 15 | from .inference3d import inference_sliding_window 16 | return inference_sliding_window 17 | 18 | else: 19 | from .inference3d import inference_whole_image 20 | return inference_whole_image 21 | 22 | 23 | 24 | else: 25 | raise ValueError('Error in image dimension') 26 | 27 | 28 | 29 | def split_idx(half_win, size, i): 30 | ''' 31 | half_win: The size of half window 32 | size: img size along one axis 33 | i: the patch index 34 | ''' 35 | 36 | start_idx = half_win * i 37 | end_idx = start_idx + half_win*2 38 | 39 | if end_idx > size: 40 | start_idx = size - half_win*2 41 | end_idx = size 42 | 43 | return start_idx, end_idx 44 | 45 | -------------------------------------------------------------------------------- /metric/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from . import metrics 5 | import numpy as np 6 | import pdb 7 | 8 | def calculate_distance(label_pred, label_true, spacing, C, percentage=95): 9 | # the input args are torch tensors 10 | if label_pred.is_cuda: 11 | label_pred = label_pred.cpu() 12 | label_true = label_true.cpu() 13 | 14 | label_pred = label_pred.numpy() 15 | label_true = label_true.numpy() 16 | spacing = spacing.numpy() 17 | 18 | ASD_list = np.zeros(C-1) 19 | HD_list = np.zeros(C-1) 20 | 21 | for i in range(C-1): 22 | tmp_surface = metrics.compute_surface_distances(label_true==(i+1), label_pred==(i+1), spacing) 23 | dis_gt_to_pred, dis_pred_to_gt = metrics.compute_average_surface_distance(tmp_surface) 24 | ASD_list[i] = (dis_gt_to_pred + dis_pred_to_gt) / 2 25 | 26 | HD = metrics.compute_robust_hausdorff(tmp_surface, percentage) 27 | HD_list[i] = HD 28 | 29 | return ASD_list, HD_list 30 | 31 | 32 | 33 | def calculate_dice_split(pred, target, C, block_size=64*64*64): 34 | # compute Dice in fixed-size blocks to bound memory usage 35 | assert pred.shape[0] == target.shape[0] 36 | N = pred.shape[0] 37 | total_sum = torch.zeros(C).to(pred.device) 38 | total_intersection = torch.zeros(C).to(pred.device) 39 | 40 | split_num = N // block_size 41 | for i in range(split_num): 42 | dice, intersection, summ = calculate_dice(pred[i*block_size:(i+1)*block_size, :], target[i*block_size:(i+1)*block_size, :], C) 43 | total_intersection += intersection 44 | total_sum += summ 45 | if N % block_size != 0: 46 | dice, intersection, summ = calculate_dice(pred[split_num*block_size:, :], target[split_num*block_size:, :], C) # index the tail with split_num so this also works when N < block_size 47 | total_intersection += intersection 48 | total_sum += summ 49 | 50 | dice = 2 * total_intersection / (total_sum + 1e-5) 51 | 52 | return dice, total_intersection, total_sum 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | def calculate_dice(pred, target, C): 63 | # pred and target are torch tensors of shape (N, 1) 64 | target = target.long() 65 | pred = pred.long() 66 | N = pred.shape[0] 67 | target_mask = target.data.new(N, C).fill_(0) 68 | target_mask.scatter_(1, target, 1.) # one-hot encode the labels via scatter
62 | def calculate_dice(pred, target, C): 63 | # pred and target are torch tensors holding hard labels, shape (N, 1) 64 | target = target.long() 65 | pred = pred.long() 66 | N = pred.shape[0] 67 | target_mask = target.data.new(N, C).fill_(0) 68 | target_mask.scatter_(1, target, 1.) 69 | 70 | pred_mask = pred.data.new(N, C).fill_(0) 71 | pred_mask.scatter_(1, pred, 1.) 72 | 73 | intersection = pred_mask * target_mask 74 | summ = pred_mask + target_mask 75 | 76 | intersection = intersection.sum(0).type(torch.float32) 77 | summ = summ.sum(0).type(torch.float32) 78 | 79 | summ += 1e-5 # avoid division by zero for classes absent from both masks 80 | dice = 2 * intersection / summ 81 | 82 | return dice, intersection, summ 83 | 84 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhygao/CBIM-Medical-Image-Segmentation/7c26979a96eb9fe057320e1db38680bae33786b8/model/__init__.py -------------------------------------------------------------------------------- /model/dim2/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet import UNet 2 | from .unetpp import UNetPlusPlus 3 | from .attention_unet import AttentionUNet 4 | from .dual_attention_unet import DAUNet 5 | from .transunet import VisionTransformer 6 | from .swin_unet import SwinUnet 7 | from .medformer import MedFormer -------------------------------------------------------------------------------- /model/dim2/attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .unet_utils import inconv, down_block 5 | from .utils import get_block, get_norm 6 | from .attention_unet_utils import attention_up_block 7 | 8 | class AttentionUNet(nn.Module): 9 | def __init__(self, in_ch, num_classes, base_ch=32, block='SingleConv', pool=True): 10 | super().__init__() 11 | 12 | num_block = 2 13 | block = get_block(block) 14 | 15 | self.inc = inconv(in_ch, base_ch, block=block) 16 | 17 | self.down1 = down_block(base_ch, 2*base_ch, num_block=num_block, block=block, pool=pool) 18 | self.down2 = down_block(2*base_ch, 4*base_ch, num_block=num_block, block=block, pool=pool) 19 | self.down3 = down_block(4*base_ch, 8*base_ch, num_block=num_block, block=block, pool=pool) 20 | self.down4 = down_block(8*base_ch, 16*base_ch, num_block=num_block, block=block, pool=pool) 21 | 22 | self.up1 = attention_up_block(16*base_ch, 8*base_ch, num_block=num_block, block=block) 23 | self.up2 = attention_up_block(8*base_ch, 4*base_ch, num_block=num_block, block=block) 24 | self.up3 = attention_up_block(4*base_ch, 2*base_ch, num_block=num_block, block=block) 25 | self.up4 = attention_up_block(2*base_ch, base_ch, num_block=num_block, block=block) 26 | 27 | self.outc = nn.Conv2d(base_ch, num_classes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | 31 | x1 = self.inc(x) 32 | x2 = self.down1(x1) 33 | x3 = self.down2(x2) 34 | x4 = self.down3(x3) 35 | x5 = self.down4(x4) 36 | 37 | out = self.up1(x5, x4) 38 | out = self.up2(out, x3) 39 | out = self.up3(out, x2) 40 | out = self.up4(out, x1) 41 | out = self.outc(out) 42 | 43 | return out 44 | 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /model/dim2/attention_unet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, SingleConv 5 | 6 | class AttentionBlock(nn.Module): 7 | def __init__(self, g_ch, l_ch, int_ch): 8 | super().__init__() 9 | 10 | self.W_g = nn.Sequential( 11 | nn.Conv2d(g_ch, int_ch, kernel_size=1,
stride=1, padding=0, bias=False), 12 | nn.BatchNorm2d(int_ch) 13 | ) 14 | self.W_x = nn.Sequential( 15 | nn.Conv2d(l_ch, int_ch, kernel_size=1, stride=1, padding=0, bias=False), 16 | nn.BatchNorm2d(int_ch) 17 | ) 18 | self.psi = nn.Sequential( 19 | nn.Conv2d(int_ch, 1, kernel_size=1, stride=1, padding=0, bias=False), 20 | nn.BatchNorm2d(1), 21 | nn.Sigmoid() 22 | ) 23 | 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, g, x): 27 | g1 = self.W_g(g) 28 | x1 = self.W_x(x) 29 | 30 | psi = self.relu(g1 + x1) 31 | psi = self.psi(psi) 32 | 33 | return x * psi 34 | 35 | 36 | class attention_up_block(nn.Module): 37 | def __init__(self, in_ch, out_ch, num_block, block=BasicBlock, norm=nn.BatchNorm2d): 38 | super().__init__() 39 | 40 | self.attn = AttentionBlock(in_ch, out_ch, out_ch//2) 41 | 42 | block_list = [] 43 | block_list.append(block(in_ch+out_ch, out_ch)) 44 | 45 | for i in range(num_block-1): 46 | block_list.append(block(out_ch, out_ch)) 47 | 48 | self.conv = nn.Sequential(*block_list) 49 | 50 | def forward(self, x1, x2): 51 | x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=True) 52 | 53 | x2 = self.attn(x1, x2) 54 | 55 | out = torch.cat([x2, x1], dim=1) 56 | 57 | out = self.conv(out) 58 | 59 | return out 60 | 61 | -------------------------------------------------------------------------------- /model/dim2/dual_attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import get_block, get_norm 5 | from .unet_utils import inconv, down_block, up_block 6 | from .dual_attention_utils import DAHead 7 | 8 | class DAUNet(nn.Module): 9 | 10 | def __init__(self, in_ch, num_classes, base_ch=32, block='BasicBlock', pool=True): 11 | super().__init__() 12 | 13 | block = get_block(block) 14 | nb = 2 # num_block 15 | 16 | self.inc = inconv(in_ch, base_ch, block=block) 17 | 18 | self.down1 = down_block(base_ch, 2*base_ch, num_block=nb, block=block, pool=pool) 19 | self.down2 = down_block(2*base_ch, 4*base_ch, num_block=nb, block=block, pool=pool) 20 | self.down3 = down_block(4*base_ch, 8*base_ch, num_block=nb, block=block, pool=pool) 21 | self.down4 = down_block(8*base_ch, 16*base_ch, num_block=nb, block=block, pool=pool) 22 | 23 | self.DAModule = DAHead(16*base_ch, num_classes) 24 | 25 | self.up1 = up_block(16*base_ch, 8*base_ch, num_block=nb, block=block) 26 | self.up2 = up_block(8*base_ch, 4*base_ch, num_block=nb, block=block) 27 | self.up3 = up_block(4*base_ch, 2*base_ch, num_block=nb, block=block) 28 | self.up4 = up_block(2*base_ch, base_ch, num_block=nb, block=block) 29 | 30 | self.outc = nn.Conv2d(base_ch, num_classes, kernel_size=1) 31 | 32 | def forward(self, x): 33 | x1 = self.inc(x) 34 | x2 = self.down1(x1) 35 | x3 = self.down2(x2) 36 | x4 = self.down3(x3) 37 | x5 = self.down4(x4) 38 | 39 | feat_fuse, sasc_pred, sa_pred, sc_pred = self.DAModule(x5) 40 | 41 | out = self.up1(feat_fuse, x4) 42 | out = self.up2(out, x3) 43 | out = self.up3(out, x2) 44 | out = self.up4(out, x1) 45 | out = self.outc(out) 46 | 47 | return out 48 | -------------------------------------------------------------------------------- /model/dim2/dual_attention_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | 6 | 7 | class DAHead(nn.Module): 8 | def __init__(self, in_channels, n_classes): 9 | super(DAHead, self).__init__() 
10 | inter_channels = in_channels // 4 11 | self.conv_a = nn.Sequential( 12 | nn.BatchNorm2d(in_channels), 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False) 15 | ) 16 | self.conv_c = nn.Sequential( 17 | nn.BatchNorm2d(in_channels), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False) 20 | ) 21 | 22 | self.sa = PAM_Module(inter_channels) 23 | self.sc = CAM_Module(inter_channels) 24 | 25 | 26 | self.conv_a_1 = nn.Sequential( 27 | nn.BatchNorm2d(inter_channels), 28 | nn.ReLU(inplace=True), 29 | nn.Conv2d(inter_channels, in_channels, 1) 30 | ) 31 | 32 | 33 | self.conv_c_1 = nn.Sequential( 34 | nn.BatchNorm2d(inter_channels), 35 | nn.ReLU(inplace=True), 36 | nn.Conv2d(inter_channels, in_channels, 1) 37 | ) 38 | 39 | self.conv_a_out = nn.Sequential( 40 | nn.Dropout(0.1), 41 | nn.Conv2d(in_channels, n_classes, 1) 42 | ) 43 | 44 | self.conv_c_out = nn.Sequential( 45 | nn.Dropout(0.1), 46 | nn.Conv2d(in_channels, n_classes, 1) 47 | ) 48 | 49 | self.fuse_out = nn.Sequential( 50 | nn.Dropout(0.1), 51 | nn.Conv2d(in_channels, n_classes, 1) 52 | ) 53 | 54 | 55 | def forward(self, x): 56 | sa_feat = self.conv_a(x) 57 | sa_feat = self.sa(sa_feat) 58 | sa_feat = self.conv_a_1(sa_feat) 59 | 60 | 61 | sc_feat = self.conv_c(x) 62 | sc_feat = self.sc(sc_feat) 63 | sc_feat = self.conv_c_1(sc_feat) 64 | 65 | feat_fusion = sa_feat + sc_feat 66 | 67 | sa_out = self.conv_a_out(sa_feat) 68 | sc_out = self.conv_c_out(sc_feat) 69 | sasc_out = self.fuse_out(feat_fusion) 70 | 71 | 72 | return feat_fusion, sasc_out, sa_out, sc_out 73 | 74 | 75 | class PAM_Module(nn.Module): 76 | """ Position attention module""" 77 | 78 | def __init__(self, in_dim, reduction=8): 79 | super(PAM_Module, self).__init__() 80 | 81 | self.chanel_in = in_dim 82 | self.reduction = reduction 83 | 84 | self.query_conv = nn.Conv2d(in_dim, in_dim//self.reduction, kernel_size=1) 85 | self.key_conv = nn.Conv2d(in_dim, in_dim//self.reduction, kernel_size=1) 86 | self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1) 87 | self.gamma = nn.Parameter(torch.zeros(1)) 88 | 89 | self.softmax = nn.Softmax(dim=-1) 90 | 91 | def forward(self, x): 92 | """ 93 | inputs: 94 | x: input feature maps (B * C * H * W) 95 | returns: 96 | out: attention value + input feature 97 | 98 | """ 99 | 100 | m_batchsize, C, height, width = x.shape 101 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 102 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 103 | energy = torch.bmm(proj_query, proj_key) 104 | attention = self.softmax(energy) 105 | proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) 106 | 107 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 108 | out = out.view(m_batchsize, C, height, width) 109 | 110 | out = self.gamma * out + x 111 | 112 | return out 113 | 114 | 115 | class CAM_Module(nn.Module): 116 | """ Channel attention module""" 117 | 118 | def __init__(self, in_dim): 119 | super(CAM_Module, self).__init__() 120 | self.chanel_in = in_dim 121 | 122 | self.gamma = nn.Parameter(torch.zeros(1)) 123 | self.softmax = nn.Softmax(dim=-1) 124 | 125 | def forward(self, x): 126 | """ 127 | inputs: 128 | x: input feature maps (B * C * H * W) 129 | returns: 130 | out: attention value + input feature 131 | """ 132 | m_batchsize, C, height, width = x.shape 133 | proj_query = x.view(m_batchsize, C, -1) 134 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 135 | energy = torch.bmm(proj_query, 
proj_key) 136 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy 137 | #energy_new = torch.max(energy, -1, keepdim=True)[0] 138 | #energy_new = energy_new.expand_as(energy) 139 | #energy_new -= energy 140 | attention = self.softmax(energy_new) 141 | proj_value = x.view(m_batchsize, C, -1) 142 | 143 | out = torch.bmm(attention, proj_value) 144 | out = out.view(m_batchsize, C, height, width) 145 | 146 | out = self.gamma * out + x 147 | return out 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | if __name__ == '__main__': 158 | 159 | import pdb 160 | cam = CAM_Module(in_dim=24) 161 | pdb.set_trace() 162 | arr = torch.randn((8, 24, 120, 120)) 163 | out = cam(arr) 164 | -------------------------------------------------------------------------------- /model/dim2/medformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .utils import get_block 6 | from .medformer_utils import down_block, up_block, inconv, SemanticMapFusion 7 | import pdb 8 | 9 | 10 | class MedFormer(nn.Module): 11 | 12 | def __init__(self, in_chan, num_classes, base_chan=32, map_size=8, conv_block='BasicBlock', conv_num=[2,1,0,0, 0,1,2,2], trans_num=[0,1,2,2, 2,1,0,0], num_heads=[1,4,8,16, 8,4,1,1], fusion_depth=2, fusion_dim=512, fusion_heads=16, expansion=4, attn_drop=0., proj_drop=0., proj_type='depthwise', norm=nn.BatchNorm2d, act=nn.ReLU, aux_loss=False): 13 | super().__init__() 14 | 15 | 16 | chan_num = [2*base_chan, 4*base_chan, 8*base_chan, 16*base_chan, 17 | 8*base_chan, 4*base_chan, 2*base_chan, base_chan] 18 | dim_head = [chan_num[i]//num_heads[i] for i in range(8)] 19 | conv_block = get_block(conv_block) 20 | 21 | # self.inc and self.down1 forms the conv stem 22 | self.inc = inconv(in_chan, base_chan, norm=norm, act=act) 23 | self.down1 = down_block(base_chan, chan_num[0], conv_num[0], trans_num[0], conv_block, norm=norm, act=act, map_generate=False) 24 | 25 | # down2 down3 down4 apply the B-MHA blocks 26 | self.down2 = down_block(chan_num[0], chan_num[1], conv_num[1], trans_num[1], conv_block, heads=num_heads[1], dim_head=dim_head[1], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_generate=True) 27 | self.down3 = down_block(chan_num[1], chan_num[2], conv_num[2], trans_num[2], conv_block, heads=num_heads[2], dim_head=dim_head[2], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_generate=True) 28 | self.down4 = down_block(chan_num[2], chan_num[3], conv_num[3], trans_num[3], conv_block, heads=num_heads[3], dim_head=dim_head[3], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_generate=True) 29 | 30 | 31 | self.map_fusion = SemanticMapFusion(chan_num[1:4], fusion_dim, fusion_heads, depth=fusion_depth, norm=norm) 32 | 33 | 34 | self.up1 = up_block(chan_num[3], chan_num[4], conv_num[4], trans_num[4], conv_block, heads=num_heads[4], dim_head=dim_head[4], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_shortcut=True) 35 | self.up2 = up_block(chan_num[4], chan_num[5], conv_num[5], trans_num[5], conv_block, heads=num_heads[5], dim_head=dim_head[5], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, 
norm=norm, act=act, map_shortcut=True) 36 | 37 | # up3 up4 form the conv decoder 38 | self.up3 = up_block(chan_num[5], chan_num[6], conv_num[6], trans_num[6], conv_block, norm=norm, act=act, map_shortcut=False) 39 | self.up4 = up_block(chan_num[6], chan_num[7], conv_num[7], trans_num[7], conv_block, norm=norm, act=act, map_shortcut=False) 40 | 41 | 42 | self.outc = nn.Conv2d(chan_num[7], num_classes, kernel_size=1) 43 | 44 | self.aux_loss = aux_loss 45 | if aux_loss: 46 | self.aux_out = nn.Conv2d(chan_num[5], num_classes, kernel_size=1) 47 | 48 | def forward(self, x): 49 | 50 | x0 = self.inc(x) 51 | x1, _ = self.down1(x0) 52 | x2, map2 = self.down2(x1) 53 | x3, map3 = self.down3(x2) 54 | x4, map4 = self.down4(x3) 55 | 56 | map_list = [map2, map3, map4] 57 | map_list = self.map_fusion(map_list) 58 | 59 | out, semantic_map = self.up1(x4, x3, map_list[2], map_list[1]) 60 | out, semantic_map = self.up2(out, x2, semantic_map, map_list[0]) 61 | 62 | if self.aux_loss: 63 | aux_out = self.aux_out(out) 64 | aux_out = F.interpolate(aux_out, size=x.shape[-2:], mode='bilinear', align_corners=True) 65 | 66 | out, semantic_map = self.up3(out, x1, semantic_map, None) 67 | out, semantic_map = self.up4(out, x0, semantic_map, None) 68 | 69 | out = self.outc(out) 70 | 71 | if self.aux_loss: 72 | return [out, aux_out] 73 | else: 74 | return out 75 | 76 | -------------------------------------------------------------------------------- /model/dim2/trans_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | import pdb 6 | 7 | 8 | __all__ = [ 9 | 'Mlp', 10 | 'Attention', 11 | 'TransformerBlock', 12 | ] 13 | 14 | class Mlp(nn.Module): 15 | def __init__(self, in_dim, hid_dim=None, out_dim=None, act=nn.GELU, drop=0.): 16 | super().__init__() 17 | out_dim = out_dim or in_dim 18 | hid_dim = hid_dim or in_dim 19 | self.fc1 = nn.Linear(in_dim, hid_dim) 20 | self.act = act() 21 | self.fc2 = nn.Linear(hid_dim, out_dim) 22 | self.drop = nn.Dropout(drop) 23 | 24 | def forward(self, x): 25 | x = self.fc1(x) 26 | x = self.act(x) 27 | x = self.drop(x) 28 | x = self.fc2(x) 29 | x = self.drop(x) 30 | 31 | return x 32 | 33 | class PreNorm(nn.Module): 34 | def __init__(self, dim, fn): 35 | super().__init__() 36 | self.norm = nn.LayerNorm(dim) 37 | self.fn = fn 38 | def forward(self, x, **kwargs): 39 | return self.fn(self.norm(x), **kwargs) 40 | 41 | 42 | 43 | class Attention(nn.Module): 44 | def __init__(self, dim, heads, dim_head, attn_drop=0., proj_drop=0.): 45 | super().__init__() 46 | 47 | inner_dim = dim_head * heads 48 | 49 | self.heads = heads 50 | self.scale = dim_head ** -0.5 51 | 52 | self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False) 53 | 54 | self.to_out = nn.Linear(inner_dim, dim) 55 | 56 | self.proj_drop = nn.Dropout(proj_drop) 57 | 58 | def forward(self, x): 59 | # x: B, L, C. 
Batch, sequence length, dim 60 | q, k, v = self.to_qkv(x).chunk(3, dim=-1) 61 | 62 | q, k, v = map(lambda t: rearrange(t, 'b l (heads dim_head) -> b heads l dim_head', heads=self.heads), [q, k, v]) 63 | attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 64 | 65 | attn = F.softmax(attn, dim=-1) 66 | 67 | attned = torch.einsum('bhij,bhjd->bhid', attn, v) 68 | attned = rearrange(attned, 'b heads l dim_head -> b l (dim_head heads)') 69 | 70 | attned = self.to_out(attned) 71 | 72 | return attned 73 | 74 | 75 | class TransformerBlock(nn.Module): 76 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, attn_drop=0., proj_drop=0.): 77 | super().__init__() 78 | 79 | self.layers = nn.ModuleList([]) 80 | 81 | for i in range(depth): 82 | self.layers.append(nn.ModuleList([ 83 | PreNorm(dim, Attention(dim, heads, dim_head, attn_drop, proj_drop)), 84 | PreNorm(dim, Mlp(dim, mlp_dim, dim, drop=proj_drop)) 85 | ])) 86 | def forward(self, x): 87 | 88 | for attn, ffn in self.layers: 89 | x = attn(x) + x 90 | x = ffn(x) + x 91 | 92 | return x 93 | 94 | 95 | -------------------------------------------------------------------------------- /model/dim2/unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import get_block, get_norm 5 | from .unet_utils import inconv, down_block, up_block 6 | 7 | class UNet(nn.Module): 8 | 9 | def __init__(self, in_ch, num_classes, base_ch=32, block='SingleConv', pool=True): 10 | super().__init__() 11 | 12 | block = get_block(block) 13 | nb = 2 # num_block 14 | 15 | self.inc = inconv(in_ch, base_ch, block=block) 16 | 17 | self.down1 = down_block(base_ch, 2*base_ch, num_block=nb, block=block, pool=pool) 18 | self.down2 = down_block(2*base_ch, 4*base_ch, num_block=nb, block=block, pool=pool) 19 | self.down3 = down_block(4*base_ch, 8*base_ch, num_block=nb, block=block, pool=pool) 20 | self.down4 = down_block(8*base_ch, 16*base_ch, num_block=nb, block=block, pool=pool) 21 | 22 | self.up1 = up_block(16*base_ch, 8*base_ch, num_block=nb, block=block) 23 | self.up2 = up_block(8*base_ch, 4*base_ch, num_block=nb, block=block) 24 | self.up3 = up_block(4*base_ch, 2*base_ch, num_block=nb, block=block) 25 | self.up4 = up_block(2*base_ch, base_ch, num_block=nb, block=block) 26 | 27 | self.outc = nn.Conv2d(base_ch, num_classes, kernel_size=1) 28 | 29 | def forward(self, x): 30 | x1 = self.inc(x) 31 | x2 = self.down1(x1) 32 | x3 = self.down2(x2) 33 | x4 = self.down3(x3) 34 | x5 = self.down4(x4) 35 | 36 | 37 | out = self.up1(x5, x4) 38 | out = self.up2(out, x3) 39 | out = self.up3(out, x2) 40 | out = self.up4(out, x1) 41 | out = self.outc(out) 42 | 43 | return out 44 | -------------------------------------------------------------------------------- /model/dim2/unet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, MBConv, FusedMBConv, ConvNeXtBlock 5 | 6 | class inconv(nn.Module): 7 | def __init__(self, in_ch, out_ch, block=BasicBlock): 8 | super().__init__() 9 | self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False) 10 | 11 | self.conv2 = block(out_ch, out_ch) 12 | 13 | def forward(self, x): 14 | out = self.conv1(x) 15 | out = self.conv2(out) 16 | 17 | return out 18 | 19 | 20 | class down_block(nn.Module): 21 | def __init__(self, in_ch, out_ch, num_block, 
block=BasicBlock, pool=True): 22 | super().__init__() 23 | 24 | block_list = [] 25 | 26 | 27 | if pool: 28 | block_list.append(nn.MaxPool2d(2)) 29 | block_list.append(block(in_ch, out_ch)) 30 | else: 31 | block_list.append(block(in_ch, out_ch, stride=2)) 32 | 33 | for i in range(num_block-1): 34 | block_list.append(block(out_ch, out_ch, stride=1)) 35 | 36 | self.conv = nn.Sequential(*block_list) 37 | def forward(self, x): 38 | return self.conv(x) 39 | 40 | class up_block(nn.Module): 41 | def __init__(self, in_ch, out_ch, num_block, block=BasicBlock): 42 | super().__init__() 43 | 44 | self.conv_ch = nn.Conv2d(in_ch, out_ch, kernel_size=1) 45 | 46 | block_list = [] 47 | block_list.append(block(2*out_ch, out_ch)) 48 | 49 | 50 | for i in range(num_block-1): 51 | block_list.append(block(out_ch, out_ch)) 52 | 53 | self.conv = nn.Sequential(*block_list) 54 | 55 | def forward(self, x1, x2): 56 | x1 = F.interpolate(x1, scale_factor=2, mode='bilinear', align_corners=True) 57 | x1 = self.conv_ch(x1) 58 | 59 | out = torch.cat([x2, x1], dim=1) 60 | out = self.conv(out) 61 | 62 | return out 63 | 64 | 65 | -------------------------------------------------------------------------------- /model/dim2/unetpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .utils import get_block, get_norm 5 | 6 | 7 | class UNetPlusPlus(nn.Module): 8 | def __init__(self, in_ch, num_classes, base_ch=32, block='SingleConv'): 9 | super().__init__() 10 | 11 | num_block = 2 12 | block = get_block(block) 13 | 14 | n_ch = [base_ch, base_ch*2, base_ch*4, base_ch*8, base_ch*16] 15 | 16 | self.pool = nn.MaxPool2d(2, 2) 17 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 18 | 19 | 20 | self.conv0_0 = self.make_layer(in_ch, n_ch[0], num_block, block) 21 | self.conv1_0 = self.make_layer(n_ch[0], n_ch[1], num_block, block) 22 | self.conv2_0 = self.make_layer(n_ch[1], n_ch[2], num_block, block) 23 | self.conv3_0 = self.make_layer(n_ch[2], n_ch[3], num_block, block) 24 | self.conv4_0 = self.make_layer(n_ch[3], n_ch[4], num_block, block) 25 | self.conv0_1 = self.make_layer(n_ch[0]+n_ch[1], n_ch[0], num_block, block) 26 | self.conv1_1 = self.make_layer(n_ch[1]+n_ch[2], n_ch[1], num_block, block) 27 | self.conv2_1 = self.make_layer(n_ch[2]+n_ch[3], n_ch[2], num_block, block) 28 | self.conv3_1 = self.make_layer(n_ch[3]+n_ch[4], n_ch[3], num_block, block) 29 | 30 | self.conv0_2 = self.make_layer(n_ch[0]*2+n_ch[1], n_ch[0], num_block, block) 31 | self.conv1_2 = self.make_layer(n_ch[1]*2+n_ch[2], n_ch[1], num_block, block) 32 | self.conv2_2 = self.make_layer(n_ch[2]*2+n_ch[3], n_ch[2], num_block, block) 33 | 34 | self.conv0_3 = self.make_layer(n_ch[0]*3+n_ch[1], n_ch[0], num_block, block) 35 | self.conv1_3 = self.make_layer(n_ch[1]*3+n_ch[2], n_ch[1], num_block, block) 36 | 37 | 38 | self.conv0_4 = self.make_layer(n_ch[0]*4+n_ch[1], n_ch[0], num_block, block) 39 | 40 | 41 | self.output = nn.Conv2d(n_ch[0], num_classes, kernel_size=1) 42 | 43 | 44 | def forward(self, x): 45 | 46 | x0_0 = self.conv0_0(x) 47 | x1_0 = self.conv1_0(self.pool(x0_0)) 48 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) 49 | 50 | x2_0 = self.conv2_0(self.pool(x1_0)) 51 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1)) 52 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1)) 53 | 54 | x3_0 = self.conv3_0(self.pool(x2_0)) 55 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1)) 56 | x1_2 = 
self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1)) 57 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1)) 58 | 59 | x4_0 = self.conv4_0(self.pool(x3_0)) 60 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1)) 61 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1)) 62 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1)) 63 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1)) 64 | 65 | output = self.output(x0_4) 66 | 67 | return output 68 | 69 | 70 | def make_layer(self, in_ch, out_ch, num_block, block): 71 | blocks = [] 72 | blocks.append(block(in_ch, out_ch)) 73 | 74 | for i in range(num_block-1): 75 | blocks.append(block(out_ch, out_ch)) 76 | 77 | return nn.Sequential(*blocks) 78 | 79 | 80 | -------------------------------------------------------------------------------- /model/dim2/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, SingleConv, MBConv, FusedMBConv, ConvNeXtBlock 5 | 6 | def get_block(name): 7 | block_map = { 8 | 'SingleConv': SingleConv, 9 | 'BasicBlock': BasicBlock, 10 | 'Bottleneck': Bottleneck, 11 | 'MBConv': MBConv, 12 | 'FusedMBConv': FusedMBConv, 13 | 'ConvNeXtBlock': ConvNeXtBlock 14 | } 15 | return block_map[name] 16 | 17 | def get_norm(name): 18 | norm_map = {'bn': nn.BatchNorm2d, # 2D norm layers for the dim2 package (the 3D variants here were a copy-paste slip) 19 | 'in': nn.InstanceNorm2d 20 | } 21 | 22 | return norm_map[name] 23 | 24 | -------------------------------------------------------------------------------- /model/dim3/__init__.py: -------------------------------------------------------------------------------- 1 | from .vnet import VNet 2 | from .unet import UNet 3 | from .unetpp import UNetPlusPlus 4 | from .attention_unet import AttentionUNet 5 | from .unetr import UNETR 6 | from .vtunet import VTUNet 7 | from .medformer import MedFormer 8 | from .swin_unetr import SwinUNETR 9 | from .nnformer import nnFormer 10 | -------------------------------------------------------------------------------- /model/dim3/attention_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .unet_utils import inconv, down_block 5 | from .utils import get_block, get_norm 6 | from .attention_unet_utils import attention_up_block 7 | 8 | class AttentionUNet(nn.Module): 9 | def __init__(self, in_ch, base_ch, scale, kernel_size, num_classes=1, block='SingleConv', pool=True, norm='bn'): 10 | super().__init__() 11 | 12 | num_block = 2 13 | block = get_block(block) 14 | norm = get_norm(norm) 15 | 16 | self.inc = inconv(in_ch, base_ch, block=block, kernel_size=kernel_size[0], norm=norm) 17 | 18 | self.down1 = down_block(base_ch, 2*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[0], kernel_size=kernel_size[1], norm=norm) 19 | self.down2 = down_block(2*base_ch, 4*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[1], kernel_size=kernel_size[2], norm=norm) 20 | self.down3 = down_block(4*base_ch, 8*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[2], kernel_size=kernel_size[3], norm=norm) 21 | self.down4 = down_block(8*base_ch, 10*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[3], kernel_size=kernel_size[4], norm=norm) 22 | 23 | self.up1 = attention_up_block(10*base_ch, 8*base_ch, num_block=num_block,
block=block, up_scale=scale[3], kernel_size=kernel_size[3], norm=norm) 24 | self.up2 = attention_up_block(8*base_ch, 4*base_ch, num_block=num_block, block=block, up_scale=scale[2], kernel_size=kernel_size[2], norm=norm) 25 | self.up3 = attention_up_block(4*base_ch, 2*base_ch, num_block=num_block, block=block, up_scale=scale[1], kernel_size=kernel_size[1], norm=norm) 26 | self.up4 = attention_up_block(2*base_ch, base_ch, num_block=num_block, block=block, up_scale=scale[0], kernel_size=kernel_size[0], norm=norm) 27 | 28 | self.outc = nn.Conv3d(base_ch, num_classes, kernel_size=1) 29 | 30 | def forward(self, x): 31 | 32 | x1 = self.inc(x) 33 | x2 = self.down1(x1) 34 | x3 = self.down2(x2) 35 | x4 = self.down3(x3) 36 | x5 = self.down4(x4) 37 | 38 | out = self.up1(x5, x4) 39 | out = self.up2(out, x3) 40 | out = self.up3(out, x2) 41 | out = self.up4(out, x1) 42 | out = self.outc(out) 43 | 44 | return out 45 | 46 | 47 | 48 | -------------------------------------------------------------------------------- /model/dim3/attention_unet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, ConvNormAct 5 | 6 | class AttentionBlock(nn.Module): 7 | def __init__(self, g_ch, l_ch, int_ch): 8 | super().__init__() 9 | 10 | self.W_g = nn.Sequential( 11 | nn.Conv3d(g_ch, int_ch, kernel_size=1, stride=1, padding=0, bias=False), 12 | nn.InstanceNorm3d(int_ch) 13 | ) 14 | self.W_x = nn.Sequential( 15 | nn.Conv3d(l_ch, int_ch, kernel_size=1, stride=1, padding=0, bias=False), 16 | nn.InstanceNorm3d(int_ch) 17 | ) 18 | self.psi = nn.Sequential( 19 | nn.Conv3d(int_ch, 1, kernel_size=1, stride=1, padding=0, bias=False), 20 | nn.InstanceNorm3d(1), 21 | nn.Sigmoid() 22 | ) 23 | 24 | self.relu = nn.ReLU(inplace=True) 25 | 26 | def forward(self, g, x): 27 | # g: input low-res feature 28 | # x: high-res feature from encoder 29 | g1 = self.W_g(g) 30 | x1 = self.W_x(x) 31 | 32 | psi = self.relu(g1 + x1) 33 | psi = self.psi(psi) 34 | 35 | return x * psi 36 | 37 | class attention_up_block(nn.Module): 38 | def __init__(self, in_ch, out_ch, num_block, block=BasicBlock, kernel_size=[3,3,3], up_scale=[2,2,2], norm=nn.BatchNorm3d): 39 | super().__init__() 40 | 41 | self.conv_ch = nn.Conv3d(in_ch, out_ch, kernel_size=1) 42 | 43 | self.up_scale = up_scale 44 | 45 | self.attn = AttentionBlock(in_ch, out_ch, out_ch//2) 46 | 47 | block_list = [] 48 | block_list.append(block(in_ch+out_ch, out_ch, kernel_size=kernel_size, norm=norm)) 49 | 50 | for i in range(num_block-1): 51 | block_list.append(block(out_ch, out_ch, kernel_size=kernel_size, norm=norm)) 52 | 53 | self.conv = nn.Sequential(*block_list) 54 | 55 | def forward(self, x1, x2): 56 | x1 = F.interpolate(x1, size=x2.shape[2:], mode='trilinear', align_corners=True) 57 | 58 | x2 = self.attn(x1, x2) 59 | 60 | 61 | out = torch.cat([x2, x1], dim=1) 62 | out = self.conv(out) 63 | 64 | return out 65 | 66 | 67 | -------------------------------------------------------------------------------- /model/dim3/medformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .utils import get_block, get_norm, get_act 6 | from .medformer_utils import down_block, up_block, inconv, SemanticMapFusion 7 | import pdb 8 | 9 | 10 | 11 | class MedFormer(nn.Module): 12 | 13 | def __init__(self, 14 | in_chan, 15 | 
num_classes, 16 | base_chan=32, 17 | map_size=[4,8,8], 18 | conv_block='BasicBlock', 19 | conv_num=[2,1,0,0, 0,1,2,2], 20 | trans_num=[0,1,2,2, 2,1,0,0], 21 | chan_num=[64,128,256,320,256,128,64,32], 22 | num_heads=[1,4,8,16, 8,4,1,1], 23 | fusion_depth=2, 24 | fusion_dim=320, 25 | fusion_heads=4, 26 | expansion=4, attn_drop=0., 27 | proj_drop=0., 28 | proj_type='depthwise', 29 | norm='in', 30 | act='gelu', 31 | kernel_size=[3,3,3,3], 32 | scale=[2,2,2,2], 33 | aux_loss=False 34 | ): 35 | super().__init__() 36 | 37 | if conv_block == 'BasicBlock': 38 | dim_head = [chan_num[i]//num_heads[i] for i in range(8)] 39 | 40 | 41 | conv_block = get_block(conv_block) 42 | norm = get_norm(norm) 43 | act = get_act(act) 44 | 45 | # self.inc and self.down1 forms the conv stem 46 | self.inc = inconv(in_chan, base_chan, block=conv_block, kernel_size=kernel_size[0], norm=norm, act=act) 47 | self.down1 = down_block(base_chan, chan_num[0], conv_num[0], trans_num[0], conv_block=conv_block, kernel_size=kernel_size[1], down_scale=scale[0], norm=norm, act=act, map_generate=False) 48 | 49 | # down2 down3 down4 apply the B-MHA blocks 50 | self.down2 = down_block(chan_num[0], chan_num[1], conv_num[1], trans_num[1], conv_block=conv_block, kernel_size=kernel_size[2], down_scale=scale[1], heads=num_heads[1], dim_head=dim_head[1], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_generate=True) 51 | 52 | self.down3 = down_block(chan_num[1], chan_num[2], conv_num[2], trans_num[2], conv_block=conv_block, kernel_size=kernel_size[3], down_scale=scale[2], heads=num_heads[2], dim_head=dim_head[2], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_generate=True) 53 | 54 | self.down4 = down_block(chan_num[2], chan_num[3], conv_num[3], trans_num[3], conv_block=conv_block, kernel_size=kernel_size[4], down_scale=scale[3], heads=num_heads[3], dim_head=dim_head[3], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_generate=True) 55 | 56 | 57 | self.map_fusion = SemanticMapFusion(chan_num[1:4], fusion_dim, fusion_heads, depth=fusion_depth, norm=norm) 58 | 59 | self.up1 = up_block(chan_num[3], chan_num[4], conv_num[4], trans_num[4], conv_block=conv_block, kernel_size=kernel_size[3], up_scale=scale[3], heads=num_heads[4], dim_head=dim_head[4], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_shortcut=True) 60 | 61 | self.up2 = up_block(chan_num[4], chan_num[5], conv_num[5], trans_num[5], conv_block=conv_block, kernel_size=kernel_size[2], up_scale=scale[2], heads=num_heads[5], dim_head=dim_head[5], expansion=expansion, attn_drop=attn_drop, proj_drop=proj_drop, map_size=map_size, proj_type=proj_type, norm=norm, act=act, map_shortcut=True, no_map_out=True) 62 | 63 | self.up3 = up_block(chan_num[5], chan_num[6], conv_num[6], trans_num[6], conv_block=conv_block, kernel_size=kernel_size[1], up_scale=scale[1], norm=norm, act=act, map_shortcut=False) 64 | 65 | self.up4 = up_block(chan_num[6], chan_num[7], conv_num[7], trans_num[7], conv_block=conv_block, kernel_size=kernel_size[0], up_scale=scale[0], norm=norm, act=act, map_shortcut=False) 66 | 67 | self.aux_loss = aux_loss 68 | if aux_loss: 69 | self.aux_out = nn.Conv3d(chan_num[5], num_classes, kernel_size=1) 70 | 71 | self.outc = nn.Conv3d(chan_num[7], num_classes, kernel_size=1) 72 | 
73 | def forward(self, x): 74 | 75 | x0 = self.inc(x) 76 | x1, _ = self.down1(x0) 77 | x2, map2 = self.down2(x1) 78 | x3, map3 = self.down3(x2) 79 | x4, map4 = self.down4(x3) 80 | 81 | 82 | map_list = [map2, map3, map4] 83 | map_list = self.map_fusion(map_list) 84 | 85 | 86 | out, semantic_map = self.up1(x4, x3, map_list[2], map_list[1]) 87 | out, semantic_map = self.up2(out, x2, semantic_map, map_list[0]) 88 | 89 | if self.aux_loss: 90 | aux_out = self.aux_out(out) 91 | aux_out = F.interpolate(aux_out, size=x.shape[-3:], mode='trilinear', align_corners=True) 92 | 93 | out, semantic_map = self.up3(out, x1, semantic_map, None) 94 | out, semantic_map = self.up4(out, x0, semantic_map, None) 95 | 96 | out = self.outc(out) 97 | 98 | if self.aux_loss: 99 | return [out, aux_out] 100 | else: 101 | return out 102 | 103 | -------------------------------------------------------------------------------- /model/dim3/trans_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | import pdb 6 | 7 | 8 | __all__ = [ 9 | 'Mlp', 10 | 'Attention', 11 | 'TransformerBlock', 12 | 'LayerNorm', 13 | ] 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_dim, hid_dim=None, out_dim=None, act=nn.GELU, drop=0.): 18 | super().__init__() 19 | out_dim = out_dim or in_dim 20 | hid_dim = hid_dim or in_dim 21 | self.fc1 = nn.Linear(in_dim, hid_dim) 22 | self.act = act() 23 | self.fc2 = nn.Linear(hid_dim, out_dim) 24 | self.drop = nn.Dropout(drop) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | x = self.act(x) 29 | x = self.drop(x) 30 | x = self.fc2(x) 31 | x = self.drop(x) 32 | 33 | return x 34 | 35 | class PreNorm(nn.Module): 36 | def __init__(self, dim, fn): 37 | super().__init__() 38 | self.norm = nn.LayerNorm(dim) 39 | self.fn = fn 40 | def forward(self, x, **kwargs): 41 | return self.fn(self.norm(x), **kwargs) 42 | 43 | 44 | 45 | class Attention(nn.Module): 46 | def __init__(self, dim, heads, dim_head, attn_drop=0., proj_drop=0.): 47 | super().__init__() 48 | 49 | inner_dim = dim_head * heads 50 | 51 | self.heads = heads 52 | self.scale = dim_head ** -0.5 53 | 54 | self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False) 55 | 56 | self.to_out = nn.Linear(inner_dim, dim) 57 | 58 | self.proj_drop = nn.Dropout(proj_drop) 59 | 60 | def rearrange1(self, x, heads): 61 | # rearrange is not supported by pytorch2.0 torch.compile 62 | # 'b l (heads dim_head) -> b heads l dim_head' 63 | b, l, n = x.shape 64 | dim_head = int(n / heads) 65 | x = x.view(b, l, heads, dim_head).contiguous() 66 | x = x.permute(0, 2, 1, 3).contiguous() 67 | 68 | return x 69 | 70 | def rearrange2(self, x): 71 | # 'b heads l dim_head -> b l (dim_head heads)') 72 | b, heads, l, dim_head = x.shape 73 | x = x.permute(0, 2, 1, 3).contiguous() 74 | x = x.view(b, l, -1).contiguous() 75 | 76 | return x 77 | 78 | 79 | def forward(self, x): 80 | # x: B, L, C. 
Batch, sequence length, dim 81 | # 'b l (heads dim_head) -> b heads l dim_head', 82 | q, k, v = self.to_qkv(x).chunk(3, dim=-1) 83 | 84 | #q, k, v = map(lambda t: rearrange(t, 'b l (heads dim_head) -> b heads l dim_head', heads=self.heads), [q, k, v]) 85 | q, k, v = map(lambda t: self.rearrange1(t, heads=self.heads), [q, k, v]) 86 | 87 | 88 | attn = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale 89 | 90 | attn = F.softmax(attn, dim=-1) 91 | 92 | attned = torch.einsum('bhij,bhjd->bhid', attn, v) 93 | #attned = rearrange(attned, 'b heads l dim_head -> b l (dim_head heads)') 94 | attned = self.rearrange2(attned) 95 | 96 | attned = self.to_out(attned) 97 | 98 | return attned 99 | 100 | 101 | class TransformerBlock(nn.Module): 102 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, attn_drop=0., proj_drop=0.): 103 | super().__init__() 104 | 105 | self.layers = nn.ModuleList([]) 106 | 107 | for i in range(depth): 108 | self.layers.append(nn.ModuleList([ 109 | PreNorm(dim, Attention(dim, heads, dim_head, attn_drop, proj_drop)), 110 | PreNorm(dim, Mlp(dim, mlp_dim, dim, drop=proj_drop)) 111 | ])) 112 | def forward(self, x): 113 | 114 | for attn, ffn in self.layers: 115 | x = attn(x) + x 116 | x = ffn(x) + x 117 | 118 | return x 119 | 120 | class LayerNorm(nn.Module): 121 | r""" LayerNorm that supports two data formats: channels_first (the default here) or channels_last, 122 | referring to the ordering of the dimensions in the input. channels_last corresponds to inputs with 123 | shape (batch_size, ..., channels) and falls back to F.layer_norm, while channels_first corresponds to 124 | the 5D inputs of the 3D models, with shape (batch_size, channels, depth, height, width). 125 | """ 126 | 127 | def __init__(self, normalized_shape, eps=1e-5, data_format="channels_first"): 128 | super().__init__() 129 | 130 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 131 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 132 | self.eps = eps 133 | self.data_format = data_format 134 | 135 | if self.data_format not in ["channels_last", "channels_first"]: 136 | raise NotImplementedError 137 | self.normalized_shape = (normalized_shape, ) 138 | 139 | def forward(self, x): 140 | 141 | if self.data_format == "channels_last": 142 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 143 | 144 | elif self.data_format == "channels_first": 145 | u = x.mean(1, keepdim=True) 146 | s = (x - u).pow(2).mean(1, keepdim=True) 147 | x = (x - u) / torch.sqrt(s + self.eps) 148 | x = self.weight[None, :, None, None, None] * x + self.bias[None, :, None, None, None] 149 | return x 150 | 151 | 152 |
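A small sanity-check sketch for the channels_first branch above (illustrative only): it should agree with F.layer_norm applied after moving the channel dimension last.

import torch
import torch.nn.functional as F
from model.dim3.trans_layers import LayerNorm

ln = LayerNorm(24, data_format="channels_first")
x = torch.randn(2, 24, 4, 8, 8)  # B, C, D, H, W

out = ln(x)
# reference: normalize over the channel dim by permuting it to the end
ref = F.layer_norm(x.permute(0, 2, 3, 4, 1), (24,), ln.weight, ln.bias, ln.eps).permute(0, 4, 1, 2, 3)
print(torch.allclose(out, ref, atol=1e-5))  # True, up to numerical tolerance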
-------------------------------------------------------------------------------- /model/dim3/unet.py: -------------------------------------------------------------------------------- 1 | # original U-Net 2 | # Modified from https://github.com/milesial/Pytorch-UNet 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from .unet_utils import inconv, down_block, up_block 8 | from .utils import get_block, get_norm 9 | import pdb 10 | 11 | 12 | class UNet(nn.Module): 13 | def __init__(self, in_ch, base_ch, scale=[2,2,2,2], kernel_size=[3,3,3,3,3], num_classes=1, block='ConvNormAct', pool=True, norm='bn'): # kernel_size takes five entries, one per resolution level; a four-entry default would crash at kernel_size[4] below 14 | super().__init__() 15 | ''' 16 | Args: 17 | in_ch: the number of input channels 18 | base_ch: the number of channels in the entry level 19 | scale: a list with the downsample scale of each of the four downsampling levels, 20 | e.g. [1, 1, 2, 2] to use the same scale on every axis of a level, 21 | or [[1,2,2], [2,2,2], [2,2,2], [2,2,2]] for a different scale on each axis 22 | kernel_size: the 3D kernel size of each of the five resolution levels, 23 | e.g. [3,3,3,3,3] or [[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]] 24 | num_classes: the number of target classes 25 | block: 'ConvNormAct' for the original UNet, 'BasicBlock' for ResUNet 26 | pool: use max pooling if True, otherwise a strided conv for downsampling 27 | norm: the norm layer type, bn or in 28 | 29 | ''' 30 | 31 | num_block = 2 32 | block = get_block(block) 33 | norm = get_norm(norm) 34 | 35 | self.inc = inconv(in_ch, base_ch, block=block, kernel_size=kernel_size[0], norm=norm) 36 | 37 | self.down1 = down_block(base_ch, 2*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[0], kernel_size=kernel_size[1], norm=norm) 38 | self.down2 = down_block(2*base_ch, 4*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[1], kernel_size=kernel_size[2], norm=norm) 39 | self.down3 = down_block(4*base_ch, 8*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[2], kernel_size=kernel_size[3], norm=norm) 40 | self.down4 = down_block(8*base_ch, 10*base_ch, num_block=num_block, block=block, pool=pool, down_scale=scale[3], kernel_size=kernel_size[4], norm=norm) 41 | 42 | self.up1 = up_block(10*base_ch, 8*base_ch, num_block=num_block, block=block, up_scale=scale[3], kernel_size=kernel_size[3], norm=norm) 43 | self.up2 = up_block(8*base_ch, 4*base_ch, num_block=num_block, block=block, up_scale=scale[2], kernel_size=kernel_size[2], norm=norm) 44 | self.up3 = up_block(4*base_ch, 2*base_ch, num_block=num_block, block=block, up_scale=scale[1], kernel_size=kernel_size[1], norm=norm) 45 | self.up4 = up_block(2*base_ch, base_ch, num_block=num_block, block=block, up_scale=scale[0], kernel_size=kernel_size[0], norm=norm) 46 | 47 | self.outc = nn.Conv3d(base_ch, num_classes, kernel_size=1) 48 | 49 | 50 | def forward(self, x): 51 | 52 | x1 = self.inc(x) 53 | x2 = self.down1(x1) 54 | x3 = self.down2(x2) 55 | x4 = self.down3(x3) 56 | x5 = self.down4(x4) 57 | 58 | out = self.up1(x5, x4) 59 | out = self.up2(out, x3) 60 | out = self.up3(out, x2) 61 | out = self.up4(out, x1) 62 | out = self.outc(out) 63 | 64 | return out 65 | 66 | 67 | 68 |
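An instantiation sketch matching the anisotropic ACDC-style configuration (the values mirror the config conventions above; treat them as an example, not shipped defaults):

import torch
from model.dim3.unet import UNet

# anisotropic setup: keep the depth axis at full resolution in the first two levels
net = UNet(in_ch=1, base_ch=32, num_classes=4, block='BasicBlock', norm='in',
           scale=[[1,2,2], [1,2,2], [2,2,2], [2,2,2]],
           kernel_size=[[1,3,3], [1,3,3], [3,3,3], [3,3,3], [3,3,3]])

x = torch.randn(1, 1, 16, 192, 192)  # B, C, D, H, W training crop
print(net(x).shape)  # torch.Size([1, 4, 16, 192, 192])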
-------------------------------------------------------------------------------- /model/dim3/unet_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, ConvNormAct 5 | import pdb 6 | 7 | class inconv(nn.Module): 8 | def __init__(self, in_ch, out_ch, kernel_size=[3,3,3], block=BasicBlock, norm=nn.BatchNorm3d): 9 | super().__init__() 10 | 11 | if isinstance(kernel_size, int): 12 | kernel_size = [kernel_size] * 3 13 | pad_size = [i//2 for i in kernel_size] 14 | self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=kernel_size, padding=pad_size, bias=False) 15 | self.conv2 = block(out_ch, out_ch, kernel_size=kernel_size, norm=norm) 16 | 17 | def forward(self, x): 18 | out = self.conv1(x) 19 | out = self.conv2(out) 20 | 21 | return out 22 | 23 | 24 | class down_block(nn.Module): 25 | def __init__(self, in_ch, out_ch, num_block, block=BasicBlock, kernel_size=[3,3,3], down_scale=[2,2,2], pool=True, norm=nn.BatchNorm3d): 26 | super().__init__() 27 | 28 | if isinstance(kernel_size, int): 29 | kernel_size = [kernel_size] * 3 30 | if isinstance(down_scale, int): 31 | down_scale = [down_scale] * 3 32 | 33 | block_list = [] 34 | 35 | if pool: 36 | block_list.append(nn.MaxPool3d(down_scale)) 37 | block_list.append(block(in_ch, out_ch, kernel_size=kernel_size, norm=norm)) 38 | else: 39 | block_list.append(block(in_ch, out_ch, stride=down_scale, kernel_size=kernel_size, norm=norm)) 40 | 41 | for i in range(num_block-1): 42 | block_list.append(block(out_ch, out_ch, stride=1, kernel_size=kernel_size, norm=norm)) 43 | 44 | self.conv = nn.Sequential(*block_list) 45 | def forward(self, x): 46 | return self.conv(x) 47 | 48 | class up_block(nn.Module): 49 | def __init__(self, in_ch, out_ch, num_block, block=BasicBlock, kernel_size=[3,3,3], up_scale=[2,2,2], norm=nn.BatchNorm3d): 50 | super().__init__() 51 | 52 | if isinstance(kernel_size, int): 53 | kernel_size = [kernel_size] * 3 54 | if isinstance(up_scale, int): 55 | up_scale = [up_scale] * 3 56 | 57 | self.up_scale = up_scale # kept for reference; forward() resizes to the skip connection's spatial size instead 58 | 59 | 60 | block_list = [] 61 | 62 | block_list.append(block(in_ch+out_ch, out_ch, kernel_size=kernel_size, norm=norm)) 63 | for i in range(num_block-1): 64 | block_list.append(block(out_ch, out_ch, kernel_size=kernel_size, norm=norm)) 65 | 66 | self.conv = nn.Sequential(*block_list) 67 | 68 | def forward(self, x1, x2): 69 | x1 = F.interpolate(x1, size=x2.shape[2:], mode='trilinear', align_corners=True) 70 | 71 | out = torch.cat([x2, x1], dim=1) 72 | 73 | out = self.conv(out) 74 | 75 | return out 76 | 77 | 78 | 79 |
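A toy shape check for the blocks above (it assumes BasicBlock behaves like a standard residual conv block accepting the kernel_size/norm arguments, as used throughout this file):

import torch
from model.dim3.unet_utils import down_block, up_block

down = down_block(32, 64, num_block=2, pool=True, down_scale=[1, 2, 2])
up = up_block(64, 32, num_block=2, up_scale=[1, 2, 2])

x = torch.randn(1, 32, 8, 64, 64)
y = down(x)                  # torch.Size([1, 64, 8, 32, 32]); the depth axis is untouched by the [1,2,2] scale
print(up(y, x).shape)        # torch.Size([1, 32, 8, 64, 64]); upsampled back to the skip connection's size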
-------------------------------------------------------------------------------- /model/dim3/unetpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, ConvNormAct 5 | from .utils import get_block, get_norm 6 | 7 | 8 | class UNetPlusPlus(nn.Module): 9 | def __init__(self, in_ch, base_ch, scale, kernel_size, num_classes=1, block='SingleConv', norm='bn'): 10 | super().__init__() 11 | 12 | num_block = 2 13 | block = get_block(block) 14 | norm = get_norm(norm) 15 | 16 | n_ch = [base_ch, base_ch*2, base_ch*4, base_ch*8, base_ch*10] 17 | 18 | self.pool0 = nn.MaxPool3d(scale[0]) 19 | self.up0 = nn.Upsample(scale_factor=tuple(scale[0]), mode='trilinear', align_corners=True) 20 | self.pool1 = nn.MaxPool3d(scale[1]) 21 | self.up1 = nn.Upsample(scale_factor=tuple(scale[1]), mode='trilinear', align_corners=True) 22 | self.pool2 = nn.MaxPool3d(scale[2]) 23 | self.up2 = nn.Upsample(scale_factor=tuple(scale[2]), mode='trilinear', align_corners=True) 24 | self.pool3 = nn.MaxPool3d(scale[3]) 25 | self.up3 = nn.Upsample(scale_factor=tuple(scale[3]), mode='trilinear', align_corners=True) 26 | 27 | 28 | self.conv0_0 = self.make_layer(in_ch, n_ch[0], num_block, block, kernel_size=kernel_size[0], norm=norm) 29 | self.conv1_0 = self.make_layer(n_ch[0], n_ch[1], num_block, block, kernel_size=kernel_size[1], norm=norm) 30 | self.conv2_0 = self.make_layer(n_ch[1], n_ch[2], num_block, block, kernel_size=kernel_size[2], norm=norm) 31 | self.conv3_0 = self.make_layer(n_ch[2], n_ch[3], num_block, block, kernel_size=kernel_size[3], norm=norm) 32 | self.conv4_0 = self.make_layer(n_ch[3], n_ch[4], num_block, block, kernel_size=kernel_size[4], norm=norm) 33 | 34 | self.conv0_1 = self.make_layer(n_ch[0]+n_ch[1], n_ch[0], num_block, block, kernel_size=kernel_size[0], norm=norm) 35 | self.conv1_1 = self.make_layer(n_ch[1]+n_ch[2], n_ch[1], num_block, block, kernel_size=kernel_size[1], norm=norm) 36 | self.conv2_1 = self.make_layer(n_ch[2]+n_ch[3], n_ch[2], num_block, block, kernel_size=kernel_size[2], norm=norm) 37 | self.conv3_1 = self.make_layer(n_ch[3]+n_ch[4], n_ch[3], num_block, block, kernel_size=kernel_size[3], norm=norm) 38 | 39 | self.conv0_2 = self.make_layer(n_ch[0]*2+n_ch[1], n_ch[0], num_block, block, kernel_size=kernel_size[0], norm=norm) 40 | self.conv1_2 = self.make_layer(n_ch[1]*2+n_ch[2], n_ch[1], num_block, block, kernel_size=kernel_size[1], norm=norm) 41 | self.conv2_2 = self.make_layer(n_ch[2]*2+n_ch[3], n_ch[2], num_block, block, kernel_size=kernel_size[2], norm=norm) 42 | 43 | self.conv0_3 = self.make_layer(n_ch[0]*3+n_ch[1], n_ch[0], num_block, block, kernel_size=kernel_size[0], norm=norm) 44 | self.conv1_3 = self.make_layer(n_ch[1]*3+n_ch[2], n_ch[1], num_block, block, kernel_size=kernel_size[1], norm=norm) 45 | 46 | 47 | self.conv0_4 = self.make_layer(n_ch[0]*4+n_ch[1], n_ch[0], num_block, block, kernel_size=kernel_size[0], norm=norm) 48 | 49 | 50 | self.output = nn.Conv3d(n_ch[0], num_classes, kernel_size=1) 51 | 52 | def forward(self, x): 53 | 54 | x0_0 = self.conv0_0(x) 55 | x1_0 = self.conv1_0(self.pool0(x0_0)) 56 | x0_1 = self.conv0_1(torch.cat([x0_0, self.up0(x1_0)], 1)) 57 | 58 | x2_0 = self.conv2_0(self.pool1(x1_0)) 59 | x1_1 = self.conv1_1(torch.cat([x1_0, self.up1(x2_0)], 1)) 60 | x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up0(x1_1)], 1)) 61 | 62 | x3_0 = self.conv3_0(self.pool2(x2_0)) 63 | x2_1 = self.conv2_1(torch.cat([x2_0, self.up2(x3_0)], 1)) 64 | x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up1(x2_1)], 1)) 65 | x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up0(x1_2)], 1)) 66 | 67 | x4_0 = self.conv4_0(self.pool3(x3_0)) 68 | x3_1 = self.conv3_1(torch.cat([x3_0, self.up3(x4_0)], 1)) 69 | x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up2(x3_1)], 1)) 70 | x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up1(x2_2)], 1)) 71 | x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up0(x1_3)], 1)) 72 | 73 | output = self.output(x0_4) 74 | 75 | return output 76 | 77 | 78 | def make_layer(self, in_ch, out_ch, num_block, block, kernel_size, norm): 79 | blocks = [] 80 | blocks.append(block(in_ch, out_ch, kernel_size=kernel_size, norm=norm)) 81 | 82 | for i in range(num_block-1): 83 | blocks.append(block(out_ch, out_ch, kernel_size=kernel_size, norm=norm)) 84 | 85 | return nn.Sequential(*blocks) 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /model/dim3/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .conv_layers import BasicBlock, Bottleneck, SingleConv, ConvNormAct 5 | from .trans_layers import LayerNorm 6 | 7 | def get_block(name): 8 | block_map = { 9 | 'SingleConv': SingleConv, 10 | 'BasicBlock': BasicBlock, 11 | 'Bottleneck': Bottleneck, 12 | 'ConvNormAct': ConvNormAct, # unet.py names this as its default block, so it needs an entry here 13 | } 14 | return block_map[name] 15 | 16 | def get_norm(name): 17 | norm_map = {'bn': nn.BatchNorm3d, 18 | 'in': nn.InstanceNorm3d, 19 | 'ln': LayerNorm 20 | } 21 | 22 | return norm_map[name] 23 | 24 | def get_act(name): 25 | act_map = { 26 | 'relu': nn.ReLU, 27 | 'lrelu': nn.LeakyReLU, 28 | 'gelu': nn.GELU, 29 | 'swish': nn.SiLU 30 | } 31 | return act_map[name] 32 | -------------------------------------------------------------------------------- /model/dim3/vnet.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/mattmacy/vnet.pytorch 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def passthrough(x,
**kwargs): 9 | 10 | return x 11 | 12 | def ELUCons(elu, nchan): 13 | if elu: 14 | return nn.ELU(inplace=True) 15 | else: 16 | return nn.PReLU(nchan) 17 | 18 | 19 | # normalization between sub-volumes is necessary 20 | # for good performance 21 | 22 | class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm): 23 | def _check_input_dim(self, input): 24 | if input.dim() != 5: 25 | raise ValueError('expected 5D input (got {}D input)' 26 | .format(input.dim())) 27 | def forward(self, input): 28 | self._check_input_dim(input) 29 | 30 | return F.batch_norm( 31 | input, self.running_mean, self.running_var, self.weight, self.bias, 32 | True, self.momentum, self.eps) 33 | 34 | 35 | class LUConv(nn.Module): 36 | def __init__(self, nchan, elu): 37 | super(LUConv, self).__init__() 38 | self.relu1 = ELUCons(elu, nchan) 39 | self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2) 40 | self.bn1 = ContBatchNorm3d(nchan) 41 | 42 | def forward(self, x): 43 | out = self.relu1(self.bn1(self.conv1(x))) 44 | 45 | return out 46 | 47 | 48 | def _make_nConv(nchan, depth, elu): 49 | layers = [] 50 | for _ in range(depth): 51 | layers.append(LUConv(nchan, elu)) 52 | 53 | return nn.Sequential(*layers) 54 | 55 | 56 | 57 | class InputTransition(nn.Module): 58 | def __init__(self, inChans, outChans, elu): 59 | super(InputTransition, self).__init__() 60 | self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2) 61 | self.bn1 = ContBatchNorm3d(outChans) 62 | self.relu1 = ELUCons(elu, outChans) 63 | self.inChans = inChans 64 | self.outChans = outChans 65 | 66 | def forward(self, x): 67 | # do we want a PRELU here as well? 68 | out = self.bn1(self.conv1(x)) 69 | # split input in to 16 channels 70 | num = int(self.outChans / self.inChans) 71 | x16 = x.repeat(1, num, 1, 1, 1) 72 | #x16 = torch.cat((x, x, x, x, x, x, x, x, 73 | # x, x, x, x, x, x, x, x), 0) 74 | 75 | out = self.relu1(torch.add(out, x16)) 76 | 77 | return out 78 | 79 | class DownTransition(nn.Module): 80 | def __init__(self, inChans, nConvs, elu, scale=2, dropout=False): 81 | super(DownTransition, self).__init__() 82 | 83 | outChans = 2*inChans 84 | self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=scale, stride=scale) 85 | self.bn1 = ContBatchNorm3d(outChans) 86 | self.do1 = passthrough 87 | self.relu1 = ELUCons(elu, outChans) 88 | self.relu2 = ELUCons(elu, outChans) 89 | 90 | if dropout: 91 | self.do1 = nn.Dropout3d() 92 | 93 | self.ops = _make_nConv(outChans, nConvs, elu) 94 | 95 | def forward(self, x): 96 | down = self.relu1(self.bn1(self.down_conv(x))) 97 | out = self.do1(down) 98 | out = self.ops(out) 99 | out = self.relu2(torch.add(out, down)) 100 | 101 | return out 102 | 103 | 104 | class UpTransition(nn.Module): 105 | def __init__(self, inChans, outChans, nConvs, elu, scale=2, dropout=False): 106 | super(UpTransition, self).__init__() 107 | 108 | self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=scale, stride=scale) 109 | self.bn1 = ContBatchNorm3d(outChans // 2) 110 | self.do1 = passthrough 111 | self.do2 = nn.Dropout3d() 112 | self.relu1 = ELUCons(elu, outChans // 2) 113 | self.relu2 = ELUCons(elu, outChans) 114 | 115 | if dropout: 116 | self.do1 = nn.Dropout3d() 117 | self.ops = _make_nConv(outChans, nConvs, elu) 118 | 119 | def forward(self, x, skipx): 120 | 121 | out = self.do1(x) 122 | skipxdo = self.do2(skipx) 123 | out = self.relu1(self.bn1(self.up_conv(out))) 124 | xcat = torch.cat((out, skipxdo), 1) 125 | out = self.ops(xcat) 126 | out = self.relu2(torch.add(out, xcat)) 127 | 128 | return out 129 | 
130 | 131 | class OutputTransition(nn.Module): 132 | def __init__(self, inChans, outChans, elu, nll): 133 | super(OutputTransition, self).__init__() 134 | 135 | self.conv1 = nn.Conv3d(inChans, outChans, kernel_size=5, padding=2) 136 | self.bn1 = ContBatchNorm3d(outChans) 137 | self.conv2 = nn.Conv3d(outChans, outChans, kernel_size=1) 138 | self.relu1 = ELUCons(elu, outChans) 139 | 140 | def forward(self, x): 141 | # convolve 32 down to 2 channels 142 | out = self.relu1(self.bn1(self.conv1(x))) 143 | out = self.conv2(out) 144 | 145 | return out 146 | 147 | 148 | class VNet(nn.Module): 149 | # the number of convolutions in each layer corresponds 150 | # to what is in the actual prototxt, not the intent 151 | def __init__(self, inChans, outChans, scale, baseChans=16, elu=True, nll=False): 152 | super(VNet, self).__init__() 153 | 154 | self.in_tr = InputTransition(inChans, baseChans, elu) 155 | self.down_tr32 = DownTransition(baseChans, 1, elu, scale=scale[0]) 156 | self.down_tr64 = DownTransition(baseChans*2, 2, elu, scale=scale[1]) 157 | self.down_tr128 = DownTransition(baseChans*4, 3, elu, dropout=True, scale=scale[2]) 158 | self.down_tr256 = DownTransition(baseChans*8, 2, elu, dropout=True, scale=scale[3]) 159 | 160 | self.up_tr256 = UpTransition(baseChans*16, baseChans*16, 2, elu, dropout=True, scale=scale[3]) 161 | self.up_tr128 = UpTransition(baseChans*16, baseChans*8, 2, elu, dropout=True, scale=scale[2]) 162 | self.up_tr64 = UpTransition(baseChans*8, baseChans*4, 1, elu, scale=scale[1]) 163 | self.up_tr32 = UpTransition(baseChans*4, baseChans*2, 1, elu, scale=scale[0]) 164 | self.out_tr = OutputTransition(baseChans*2, outChans, elu, nll) 165 | 166 | 167 | def forward(self, x): 168 | 169 | out16 = self.in_tr(x) 170 | out32 = self.down_tr32(out16) 171 | out64 = self.down_tr64(out32) 172 | out128 = self.down_tr128(out64) 173 | out256 = self.down_tr256(out128) 174 | 175 | out = self.up_tr256(out256, out128) 176 | out = self.up_tr128(out, out64) 177 | out = self.up_tr64(out, out32) 178 | out = self.up_tr32(out, out16) 179 | 180 | out = self.out_tr(out) 181 | 182 | return out 183 | -------------------------------------------------------------------------------- /model/dim3/vtunet.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | 3 | from __future__ import absolute_import 4 | 5 | from __future__ import division 6 | 7 | from __future__ import print_function 8 | 9 | 10 | 11 | import copy 12 | 13 | import logging 14 | 15 | 16 | 17 | import torch 18 | 19 | import torch.nn as nn 20 | 21 | 22 | 23 | from .vtunet_utils import SwinTransformerSys3D 24 | 25 | 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | 30 | 31 | 32 | class VTUNet(nn.Module): 33 | 34 | def __init__(self, config, num_classes=3, zero_head=False, embed_dim=96, win_size=7): 35 | 36 | super(VTUNet, self).__init__() 37 | 38 | self.num_classes = num_classes 39 | 40 | self.zero_head = zero_head 41 | 42 | self.config = config 43 | 44 | self.embed_dim = embed_dim 45 | 46 | self.win_size = win_size 47 | 48 | self.win_size = (self.win_size,self.win_size,self.win_size) 49 | 50 | 51 | 52 | self.swin_unet = SwinTransformerSys3D(img_size=config.training_size, 53 | 54 | patch_size=config.patch_size, 55 | 56 | in_chans=config.in_chan, 57 | 58 | num_classes=self.num_classes, 59 | 60 | embed_dim=self.embed_dim, 61 | 62 | depths=[2, 2, 2, 1], 63 | 64 | depths_decoder=[1, 2, 2, 2], 65 | 66 | num_heads=[3, 6, 12, 24], 67 | 68 | window_size=self.win_size, 69 | 70 | mlp_ratio=4., 71 | 72 | 
qkv_bias=True, 73 | 74 | qk_scale=None, 75 | 76 | drop_rate=0., 77 | 78 | attn_drop_rate=0., 79 | 80 | drop_path_rate=0.1, 81 | 82 | norm_layer=nn.LayerNorm, 83 | 84 | patch_norm=True, 85 | 86 | use_checkpoint=False, 87 | 88 | frozen_stages=-1, 89 | 90 | final_upsample="expand_first") 91 | 92 | 93 | 94 | def forward(self, x): 95 | 96 | logits = self.swin_unet(x) 97 | 98 | return logits 99 | 100 | 101 | 102 | def load_from(self, config): 103 | 104 | pretrained_path = config.init_model 105 | 106 | if pretrained_path is not None: 107 | 108 | print("pretrained_path:{}".format(pretrained_path)) 109 | 110 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 111 | 112 | pretrained_dict = torch.load(pretrained_path, map_location=device) 113 | 114 | if "model" not in pretrained_dict: 115 | 116 | print("---start load pretrained model by splitting---") 117 | 118 | pretrained_dict = {k[17:]: v for k, v in pretrained_dict.items()} # strip a fixed 17-character prefix from the checkpoint keys 119 | 120 | for k in list(pretrained_dict.keys()): 121 | 122 | if "output" in k: 123 | 124 | print("delete key:{}".format(k)) 125 | 126 | del pretrained_dict[k] 127 | 128 | self.swin_unet.load_state_dict(pretrained_dict, strict=False) 129 | 130 | 131 | 132 | return 133 | 134 | pretrained_dict = pretrained_dict['model'] 135 | 136 | print("---start load pretrained model of swin encoder---") 137 | 138 | 139 | 140 | model_dict = self.swin_unet.state_dict() 141 | 142 | full_dict = copy.deepcopy(pretrained_dict) 143 | 144 | for k, v in pretrained_dict.items(): 145 | 146 | if "layers." in k: 147 | 148 | current_layer_num = 3 - int(k[7:8]) 149 | 150 | current_k = "layers_up." + str(current_layer_num) + k[8:] 151 | 152 | full_dict.update({current_k: v}) 153 | 154 | for k in list(full_dict.keys()): 155 | 156 | if k in model_dict: 157 | 158 | if full_dict[k].shape != model_dict[k].shape: 159 | 160 | print("delete:{};shape pretrain:{};shape model:{}".format(k, full_dict[k].shape, model_dict[k].shape)) 161 | 162 | del full_dict[k] 163 | 164 | 165 | 166 | self.swin_unet.load_state_dict(full_dict, strict=False) 167 | 168 | else: 169 | 170 | print("no pretrained model path provided") 171 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import pdb 5 | 6 | def get_model(args, pretrain=False): 7 | 8 | if args.dimension == '2d': 9 | if args.model == 'unet': 10 | from .dim2 import UNet 11 | if pretrain: 12 | raise ValueError('No pretrain model available') 13 | return UNet(args.in_chan, args.classes, args.base_chan, block=args.block) 14 | if args.model == 'unet++': 15 | from .dim2 import UNetPlusPlus 16 | if pretrain: 17 | raise ValueError('No pretrain model available') 18 | return UNetPlusPlus(args.in_chan, args.classes, args.base_chan) 19 | if args.model == 'attention_unet': 20 | from .dim2 import AttentionUNet 21 | if pretrain: 22 | raise ValueError('No pretrain model available') 23 | return AttentionUNet(args.in_chan, args.classes, args.base_chan) 24 | 25 | elif args.model == 'resunet': 26 | from .dim2 import UNet 27 | if pretrain: 28 | raise ValueError('No pretrain model available') 29 | return UNet(args.in_chan, args.classes, args.base_chan, block=args.block) 30 | elif args.model == 'daunet': 31 | from .dim2 import DAUNet 32 | if pretrain: 33 | raise ValueError('No pretrain model available') 34 | return DAUNet(args.in_chan, args.classes, args.base_chan, block=args.block) 35 | 36 | elif args.model in
['medformer']: 37 | from .dim2 import MedFormer 38 | if pretrain: 39 | raise ValueError('No pretrain model available') 40 | return MedFormer(args.in_chan, args.classes, args.base_chan, conv_block=args.conv_block, conv_num=args.conv_num, trans_num=args.trans_num, num_heads=args.num_heads, fusion_depth=args.fusion_depth, fusion_dim=args.fusion_dim, fusion_heads=args.fusion_heads, map_size=args.map_size, proj_type=args.proj_type, act=nn.ReLU, expansion=args.expansion, attn_drop=args.attn_drop, proj_drop=args.proj_drop, aux_loss=args.aux_loss) 41 | 42 | 43 | elif args.model == 'transunet': 44 | from .dim2 import VisionTransformer as ViT_seg 45 | from .dim2.transunet import CONFIGS as CONFIGS_ViT_seg 46 | config_vit = CONFIGS_ViT_seg['R50-ViT-B_16'] 47 | config_vit.n_classes = args.classes 48 | config_vit.n_skip = 3 49 | config_vit.patches.grid = (int(args.training_size[0]/16), int(args.training_size[1]/16)) 50 | net = ViT_seg(config_vit, img_size=args.training_size[0], num_classes=args.classes) 51 | 52 | if pretrain: 53 | net.load_from(weights=np.load(args.init_model)) 54 | 55 | return net 56 | 57 | elif args.model == 'swinunet': 58 | from .dim2 import SwinUnet 59 | from .dim2.swin_unet import SwinUnet_config 60 | config = SwinUnet_config() 61 | net = SwinUnet(config, img_size=224, num_classes=args.classes) 62 | 63 | if pretrain: 64 | net.load_from(args.init_model) 65 | 66 | return net 67 | 68 | 69 | 70 | elif args.dimension == '3d': 71 | if args.model == 'vnet': 72 | from .dim3 import VNet 73 | if pretrain: 74 | raise ValueError('No pretrain model available') 75 | return VNet(args.in_chan, args.classes, scale=args.downsample_scale, baseChans=args.base_chan) 76 | elif args.model == 'resunet': 77 | from .dim3 import UNet 78 | if pretrain: 79 | raise ValueError('No pretrain model available') 80 | return UNet(args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block) 81 | 82 | elif args.model == 'unet': 83 | from .dim3 import UNet 84 | return UNet(args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block) 85 | elif args.model == 'unet++': 86 | from .dim3 import UNetPlusPlus 87 | return UNetPlusPlus(args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block) 88 | elif args.model == 'attention_unet': 89 | from .dim3 import AttentionUNet 90 | return AttentionUNet(args.in_chan, args.base_chan, num_classes=args.classes, scale=args.down_scale, norm=args.norm, kernel_size=args.kernel_size, block=args.block) 91 | 92 | elif args.model == 'medformer': 93 | from .dim3 import MedFormer 94 | 95 | return MedFormer(args.in_chan, args.classes, args.base_chan, map_size=args.map_size, conv_block=args.conv_block, conv_num=args.conv_num, trans_num=args.trans_num, num_heads=args.num_heads, fusion_depth=args.fusion_depth, fusion_dim=args.fusion_dim, fusion_heads=args.fusion_heads, expansion=args.expansion, attn_drop=args.attn_drop, proj_drop=args.proj_drop, proj_type=args.proj_type, norm=args.norm, act=args.act, kernel_size=args.kernel_size, scale=args.down_scale, aux_loss=args.aux_loss) 96 | 97 | elif args.model == 'unetr': 98 | from .dim3 import UNETR 99 | model = UNETR(args.in_chan, args.classes, args.training_size, feature_size=16, hidden_size=768, mlp_dim=3072, num_heads=12, pos_embed='perceptron', norm_name='instance', res_block=True) 100 | 101 | return model 102 | 
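# Usage sketch (hedged): get_model only needs the config fields the chosen arch actually reads; the
# args namespace below is hypothetical, in real runs it is parsed from the YAML configs.
#
#     from types import SimpleNamespace
#     from model.utils import get_model
#     args = SimpleNamespace(dimension='2d', model='unet', in_chan=1, classes=4, base_chan=32, block='BasicBlock')
#     net = get_model(args)                     # -> dim2 UNet
#     out = net(torch.randn(2, 1, 256, 256))    # expect logits of shape (2, 4, 256, 256)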
elif args.model == 'vtunet': 103 | from .dim3 import VTUNet 104 | model = VTUNet(args, args.classes) 105 | 106 | if pretrain: 107 | model.load_from(args) 108 | return model 109 | elif args.model == 'swin_unetr': 110 | from .dim3 import SwinUNETR 111 | model = SwinUNETR(args.window_size, args.in_chan, args.classes, feature_size=args.base_chan) 112 | 113 | if args.pretrain: 114 | weights = torch.load('/research/cbim/vast/yg397/ConvFormer/ConvFormer/initmodel/model_swinvit.pt') 115 | model.load_from(weights=weights) 116 | 117 | return model 118 | elif args.model == 'nnformer': 119 | from .dim3 import nnFormer 120 | model = nnFormer(args.window_size, input_channels=args.in_chan, num_classes=args.classes, deep_supervision=args.aux_loss) 121 | 122 | return model 123 | else: 124 | raise ValueError('Invalid dimension, should be \'2d\' or \'3d\'') 125 | 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | addict==2.4.0 3 | apex==0.1 4 | astunparse==1.6.3 5 | batchgenerators==0.25 6 | cachetools==5.3.0 7 | certifi @ file:///croot/certifi_1671487769961/work/certifi 8 | charset-normalizer==2.1.1 9 | cmake==3.25.0 10 | contextlib2==21.6.0 11 | contourpy==1.0.7 12 | cycler==0.11.0 13 | einops==0.6.0 14 | filelock==3.9.0 15 | flatbuffers==23.3.3 16 | fonttools==4.39.2 17 | future==0.18.3 18 | gast==0.4.0 19 | google-auth==2.16.2 20 | google-auth-oauthlib==0.4.6 21 | google-pasta==0.2.0 22 | grpcio==1.51.3 23 | h5py==3.8.0 24 | huggingface-hub==0.13.2 25 | idna==3.4 26 | imageio==2.26.0 27 | jax==0.4.6 28 | Jinja2==3.1.2 29 | joblib==1.2.0 30 | keras==2.12.0 31 | kiwisolver==1.4.4 32 | lazy_loader==0.1 33 | libclang==16.0.0 34 | linecache2==1.0.0 35 | lit==15.0.7 36 | Markdown==3.4.1 37 | MarkupSafe==2.1.2 38 | matplotlib==3.7.1 39 | ml-collections==0.1.1 40 | mmcv==1.7.1 41 | monai==1.1.0 42 | mpmath==1.2.1 43 | networkx==3.0 44 | ninja==1.11.1 45 | numpy==1.23.5 46 | oauthlib==3.2.2 47 | opencv-python==4.7.0.72 48 | opt-einsum==3.3.0 49 | packaging==23.0 50 | Pillow==9.3.0 51 | protobuf==4.22.1 52 | pyasn1==0.4.8 53 | pyasn1-modules==0.2.8 54 | pyparsing==3.0.9 55 | python-dateutil==2.8.2 56 | PyWavelets==1.4.1 57 | PyYAML==6.0 58 | requests==2.28.1 59 | requests-oauthlib==1.3.1 60 | rsa==4.9 61 | scikit-image==0.20.0 62 | scikit-learn==1.2.2 63 | scipy==1.10.1 64 | SimpleITK==2.2.1 65 | six==1.16.0 66 | sympy==1.11.1 67 | tabulate==0.9.0 68 | tensorboard==2.12.0 69 | tensorboard-data-server==0.7.0 70 | tensorboard-plugin-wit==1.8.1 71 | tensorflow-estimator==2.12.0 72 | tensorflow-io-gcs-filesystem==0.31.0 73 | termcolor==2.2.0 74 | threadpoolctl==3.1.0 75 | tifffile==2023.3.15 76 | timm==0.6.12 77 | torch==2.0.0+cu118 78 | torchaudio==2.0.1+cu118 79 | torchvision==0.15.1+cu118 80 | tqdm==4.65.0 81 | traceback2==1.4.0 82 | triton==2.0.0 83 | typing_extensions==4.4.0 84 | unittest2==1.1.0 85 | urllib3==1.26.13 86 | Werkzeug==2.2.3 87 | wrapt==1.14.1 88 | yapf==0.32.0 89 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhygao/CBIM-Medical-Image-Segmentation/7c26979a96eb9fe057320e1db38680bae33786b8/training/__init__.py -------------------------------------------------------------------------------- /training/dataset/__init__.py: 
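All dataset classes below share one deterministic cross-validation recipe: shuffle the case list with random.Random(split_seed), then take the k-th contiguous chunk as the test fold. A minimal, hedged sketch of that split logic (the function name is hypothetical), handy for checking which cases land in each fold:

import random

def kfold_split(names, k_fold=5, k=0, seed=0):
    # same recipe as the Dataset classes: seeded shuffle, k-th chunk is the test fold
    names = list(names)
    random.Random(seed).shuffle(names)
    fold = len(names) // k_fold
    test = names[k * fold : (k + 1) * fold]
    train = [n for n in names if n not in set(test)]  # the repo uses set difference, which also drops order
    return train, test

train, test = kfold_split(['case%03d' % i for i in range(20)], k_fold=5, k=1)
print(len(train), len(test))  # 16 4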
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhygao/CBIM-Medical-Image-Segmentation/7c26979a96eb9fe057320e1db38680bae33786b8/training/dataset/__init__.py -------------------------------------------------------------------------------- /training/dataset/dim2/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /training/dataset/dim2/dataset_acdc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import SimpleITK as sitk 8 | import yaml 9 | import math 10 | import random 11 | import pdb 12 | from training import augmentation 13 | import logging 14 | import copy 15 | 16 | 17 | class CMRDataset(Dataset): 18 | def __init__(self, args, mode='train', k_fold=5, k=0, seed=0): 19 | 20 | self.mode = mode 21 | self.args = args 22 | 23 | assert mode in ['train', 'test'] 24 | 25 | with open(os.path.join(args.data_root, 'list', 'dataset.yaml'), 'r') as f: 26 | img_name_list = yaml.load(f, Loader=yaml.SafeLoader) 27 | 28 | random.Random(seed).shuffle(img_name_list) 29 | 30 | length = len(img_name_list) 31 | test_name_list = img_name_list[k*(length//k_fold):(k+1)*(length//k_fold)] 32 | train_name_list = img_name_list 33 | train_name_list = list(set(img_name_list) - set(test_name_list)) 34 | 35 | if mode == 'train': 36 | img_name_list = train_name_list 37 | else: 38 | img_name_list = test_name_list 39 | 40 | logging.info(f"Start loading {self.mode} data") 41 | 42 | path = args.data_root 43 | 44 | img_list = [] 45 | lab_list = [] 46 | spacing_list = [] 47 | 48 | for name in img_name_list: 49 | for idx in [0, 1]: 50 | 51 | img_name = name + '_%d.nii.gz'%idx 52 | lab_name = name + '_%d_gt.nii.gz'%idx 53 | 54 | itk_img = sitk.ReadImage(os.path.join(path, img_name)) 55 | itk_lab = sitk.ReadImage(os.path.join(path, lab_name)) 56 | 57 | spacing = np.array(itk_lab.GetSpacing()).tolist() 58 | spacing_list.append(spacing[::-1]) 59 | 60 | assert itk_img.GetSize() == itk_lab.GetSize() 61 | 62 | img, lab = self.preprocess(itk_img, itk_lab) 63 | 64 | img_list.append(img) 65 | lab_list.append(lab) 66 | 67 | self.img_slice_list = [] 68 | self.lab_slice_list = [] 69 | if self.mode == 'train': 70 | for i in range(len(img_list)): 71 | 72 | z, x, y = img_list[i].shape 73 | 74 | for j in range(z): 75 | self.img_slice_list.append(copy.deepcopy(img_list[i][j])) 76 | self.lab_slice_list.append(copy.deepcopy(lab_list[i][j])) 77 | del img_list 78 | del lab_list 79 | else: 80 | self.img_slice_list = img_list 81 | self.lab_slice_list = lab_list 82 | self.spacing_list = spacing_list 83 | 84 | 85 | logging.info(f"Load done, length of dataset: {len(self.img_slice_list)}") 86 | 87 | 88 | 89 | def __len__(self): 90 | return len(self.img_slice_list) 91 | 92 | def preprocess(self, itk_img, itk_lab): 93 | 94 | img = sitk.GetArrayFromImage(itk_img) 95 | lab = sitk.GetArrayFromImage(itk_lab) 96 | 97 | max98 = np.percentile(img, 98) 98 | img = np.clip(img, 0, max98) 99 | 100 | z, y, x = img.shape 101 | if x < self.args.training_size[0]: 102 | diff = (self.args.training_size[0] + 10 - x) // 2 103 | img = np.pad(img, ((0,0), (0,0), (diff, diff))) 104 | lab = np.pad(lab, ((0,0), (0,0), (diff,diff))) 105 | if y < self.args.training_size[1]: 106 | diff = 
(self.args.training_size[1] + 10 - y) // 2 107 | img = np.pad(img, ((0,0), (diff, diff), (0,0))) 108 | lab = np.pad(lab, ((0,0), (diff, diff), (0,0))) 109 | 110 | img = img / max98 111 | 112 | img = img.astype(np.float32) 113 | lab = lab.astype(np.uint8) 114 | 115 | tensor_img = torch.from_numpy(img).float() 116 | tensor_lab = torch.from_numpy(lab).long() 117 | 118 | return tensor_img, tensor_lab 119 | 120 | 121 | def __getitem__(self, idx): 122 | 123 | 124 | tensor_img = self.img_slice_list[idx] 125 | tensor_lab = self.lab_slice_list[idx] 126 | 127 | 128 | if self.mode == 'train': 129 | tensor_img = tensor_img.unsqueeze(0).unsqueeze(0) 130 | tensor_lab = tensor_lab.unsqueeze(0).unsqueeze(0) 131 | 132 | # Gaussian Noise 133 | tensor_img = augmentation.gaussian_noise(tensor_img, std=self.args.gaussian_noise_std) 134 | # Additive brightness 135 | tensor_img = augmentation.brightness_additive(tensor_img, std=self.args.additive_brightness_std) 136 | # gamma 137 | tensor_img = augmentation.gamma(tensor_img, gamma_range=self.args.gamma_range, retain_stats=True) 138 | 139 | tensor_img, tensor_lab = augmentation.random_scale_rotate_translate_2d(tensor_img, tensor_lab, self.args.scale, self.args.rotate, self.args.translate) 140 | tensor_img, tensor_lab = augmentation.crop_2d(tensor_img, tensor_lab, self.args.training_size, mode='random') 141 | 142 | tensor_img, tensor_lab = tensor_img.squeeze(0), tensor_lab.squeeze(0) 143 | else: 144 | tensor_img, tensor_lab = self.center_crop(tensor_img, tensor_lab) 145 | 146 | assert tensor_img.shape == tensor_lab.shape 147 | 148 | if self.mode == 'train': 149 | return tensor_img, tensor_lab 150 | else: 151 | return tensor_img, tensor_lab, np.array(self.spacing_list[idx]) 152 | 153 | 154 | def center_crop(self, img, label): 155 | D, H, W = img.shape 156 | 157 | diff_H = H - self.args.training_size[0] 158 | diff_W = W - self.args.training_size[1] 159 | 160 | start_h = diff_H // 2 161 | start_w = diff_W // 2 162 | # crop the height with training_size[0] and the width with training_size[1] 163 | cropped_img = img[:, start_h:start_h+self.args.training_size[0], start_w:start_w+self.args.training_size[1]] 164 | cropped_lab = label[:, start_h:start_h+self.args.training_size[0], start_w:start_w+self.args.training_size[1]] 165 | 166 | return cropped_img, cropped_lab 167 | -------------------------------------------------------------------------------- /training/dataset/dim3/dataset_acdc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import SimpleITK as sitk 8 | import yaml 9 | import math 10 | import random 11 | import pdb 12 | from training import augmentation 13 | import logging 14 | import copy 15 | 16 | 17 | class CMRDataset(Dataset): 18 | def __init__(self, args, mode='train', k_fold=5, k=0, seed=0): 19 | 20 | self.mode = mode 21 | self.args = args 22 | 23 | assert mode in ['train', 'test'] 24 | 25 | with open(os.path.join(args.data_root, 'list', 'dataset.yaml'), 'r') as f: 26 | img_name_list = yaml.load(f, Loader=yaml.SafeLoader) 27 | 28 | 29 | random.Random(seed).shuffle(img_name_list) 30 | 31 | length = len(img_name_list) 32 | test_name_list = img_name_list[k*(length//k_fold) : (k+1)*(length//k_fold)] 33 | train_name_list = list(set(img_name_list) - set(test_name_list)) 34 | 35 | if mode == 'train': 36 | img_name_list = train_name_list 37 | else: 38 | img_name_list = test_name_list 39 | 40 | 41 | logging.info(f'Start loading {self.mode} data') 42 | 43 | path
= args.data_root 44 | 45 | self.img_list = [] 46 | self.lab_list = [] 47 | self.spacing_list = [] 48 | 49 | for name in img_name_list: 50 | for idx in [0, 1]: 51 | 52 | img_name = name + '_%d.nii.gz'%idx 53 | lab_name = name + '_%d_gt.nii.gz'%idx 54 | 55 | itk_img = sitk.ReadImage(os.path.join(path, img_name)) 56 | itk_lab = sitk.ReadImage(os.path.join(path, lab_name)) 57 | 58 | spacing = np.array(itk_lab.GetSpacing()).tolist() 59 | self.spacing_list.append(spacing[::-1]) # itk axis order is inverse of numpy axis order 60 | 61 | assert itk_img.GetSize() == itk_lab.GetSize() 62 | 63 | img, lab = self.preprocess(itk_img, itk_lab) 64 | 65 | self.img_list.append(img) 66 | self.lab_list.append(lab) 67 | 68 | 69 | logging.info(f"Load done, length of dataset: {len(self.img_list)}") 70 | 71 | def __len__(self): 72 | if self.mode == 'train': 73 | return len(self.img_list) * 100000 # effectively infinite; the trainer's iter_per_epoch setting controls the real epoch length 74 | else: 75 | return len(self.img_list) 76 | 77 | def preprocess(self, itk_img, itk_lab): 78 | 79 | img = sitk.GetArrayFromImage(itk_img).astype(np.float32) 80 | lab = sitk.GetArrayFromImage(itk_lab).astype(np.uint8) 81 | 82 | max98 = np.percentile(img, 98) 83 | img = np.clip(img, 0, max98) 84 | 85 | z, y, x = img.shape 86 | 87 | # pad if the image size is smaller than training size 88 | if z < self.args.training_size[0]: 89 | diff = (self.args.training_size[0]+2 - z) // 2 90 | img = np.pad(img, ((diff, diff), (0,0), (0,0))) 91 | lab = np.pad(lab, ((diff, diff), (0,0), (0,0))) 92 | if y < self.args.training_size[1]: 93 | diff = (self.args.training_size[1]+2 - y) // 2 94 | img = np.pad(img, ((0,0), (diff,diff), (0,0))) 95 | lab = np.pad(lab, ((0,0), (diff, diff), (0,0))) 96 | if x < self.args.training_size[2]: 97 | diff = (self.args.training_size[2]+2 - x) // 2 98 | img = np.pad(img, ((0,0), (0,0), (diff, diff))) 99 | lab = np.pad(lab, ((0,0), (0,0), (diff, diff))) 100 | 101 | img = img / max98 102 | def remove_background(img, lab, size=256): 103 | z, y, x = img.shape 104 | if y > size: 105 | img = img[:, y//2-size//2:y//2+size//2, :] 106 | lab = lab[:, y//2-size//2:y//2+size//2, :] 107 | if x > size: 108 | img = img[:, :, x//2-size//2:x//2+size//2] 109 | lab = lab[:, :, x//2-size//2:x//2+size//2] 110 | 111 | return img, lab 112 | img, lab = remove_background(img, lab, size=256) 113 | 114 | 115 | img = img.astype(np.float32) 116 | lab = lab.astype(np.uint8) 117 | 118 | tensor_img = torch.from_numpy(img).float() 119 | tensor_lab = torch.from_numpy(lab).long() 120 | 121 | return tensor_img, tensor_lab 122 | 123 | def __getitem__(self, idx): 124 | 125 | idx = idx % len(self.img_list) 126 | 127 | tensor_img = self.img_list[idx] 128 | tensor_lab = self.lab_list[idx] 129 | 130 | tensor_img = tensor_img.unsqueeze(0).unsqueeze(0) 131 | tensor_lab = tensor_lab.unsqueeze(0).unsqueeze(0) 132 | # 1, C, D, H, W 133 | 134 | 135 | if self.mode == 'train': 136 | 137 | if self.args.aug_device == 'gpu': 138 | tensor_img = tensor_img.cuda(self.args.proc_idx) 139 | tensor_lab = tensor_lab.cuda(self.args.proc_idx) 140 | 141 | # Gaussian Noise 142 | tensor_img = augmentation.gaussian_noise(tensor_img, std=self.args.gaussian_noise_std) 143 | # Additive brightness 144 | tensor_img = augmentation.brightness_additive(tensor_img, std=self.args.additive_brightness_std) 145 | # gamma 146 | tensor_img = augmentation.gamma(tensor_img, gamma_range=self.args.gamma_range, retain_stats=True) 147 | 148 | tensor_img, tensor_lab = augmentation.random_scale_rotate_translate_3d(tensor_img, tensor_lab, self.args.scale, self.args.rotate, self.args.translate) 149
| tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, self.args.training_size, mode='random') 150 | 151 | #else: 152 | # tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab,self.args.training_size, mode='center') 153 | 154 | tensor_img = tensor_img.squeeze(0) 155 | tensor_lab = tensor_lab.squeeze(0) 156 | 157 | assert tensor_img.shape == tensor_lab.shape 158 | 159 | if self.mode == 'train': 160 | return tensor_img, tensor_lab.to(torch.int8) 161 | else: 162 | return tensor_img, tensor_lab, np.array(self.spacing_list[idx]) 163 | 164 | 165 | -------------------------------------------------------------------------------- /training/dataset/dim3/dataset_bcv.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import SimpleITK as sitk 8 | import yaml 9 | import math 10 | import random 11 | import pdb 12 | from training import augmentation 13 | 14 | 15 | class BCVDataset(Dataset): 16 | def __init__(self, args, mode='train', k_fold=5, k=0, seed=0): 17 | 18 | self.mode = mode 19 | self.args = args 20 | 21 | assert mode in ['train', 'test'] 22 | 23 | with open(os.path.join(args.data_root, 'list', 'dataset.yaml'), 'r') as f: 24 | img_name_list = yaml.load(f, Loader=yaml.SafeLoader) 25 | 26 | 27 | random.Random(seed).shuffle(img_name_list) 28 | 29 | length = len(img_name_list) 30 | test_name_list = img_name_list[k*(length//k_fold) : (k+1)*(length//k_fold)] 31 | train_name_list = list(set(img_name_list) - set(test_name_list)) 32 | 33 | if mode == 'train': 34 | img_name_list = train_name_list 35 | else: 36 | img_name_list = test_name_list 37 | 38 | print(img_name_list) 39 | print('Start loading %s data'%self.mode) 40 | 41 | path = args.data_root 42 | 43 | self.img_list = [] 44 | self.lab_list = [] 45 | self.spacing_list = [] 46 | 47 | for name in img_name_list: 48 | 49 | img_name = name + '.nii.gz' 50 | lab_name = name + '_gt.nii.gz' 51 | 52 | itk_img = sitk.ReadImage(os.path.join(path, img_name)) 53 | itk_lab = sitk.ReadImage(os.path.join(path, lab_name)) 54 | 55 | spacing = np.array(itk_lab.GetSpacing()).tolist() 56 | self.spacing_list.append(spacing[::-1]) # itk axis order is inverse of numpy axis order 57 | 58 | assert itk_img.GetSize() == itk_lab.GetSize() 59 | 60 | img, lab = self.preprocess(itk_img, itk_lab) 61 | 62 | self.img_list.append(img) 63 | self.lab_list.append(lab) 64 | 65 | 66 | print('Load done, length of dataset:', len(self.img_list)) 67 | 68 | def __len__(self): 69 | if self.mode == 'train': 70 | return len(self.img_list) * 100000 71 | else: 72 | return len(self.img_list) 73 | 74 | def preprocess(self, itk_img, itk_lab): 75 | 76 | img = sitk.GetArrayFromImage(itk_img).astype(np.float32) 77 | lab = sitk.GetArrayFromImage(itk_lab).astype(np.uint8) 78 | 79 | img = np.clip(img, -958, 327) # clip CT intensities, then z-score with precomputed dataset statistics 80 | img -= 82.92 81 | img /= 136.97 82 | 83 | z, y, x = img.shape 84 | 85 | # pad if the image size is smaller than training size 86 | if z < self.args.training_size[0]: 87 | diff = int(math.ceil((self.args.training_size[0] - z) / 2)) 88 | img = np.pad(img, ((diff, diff), (0,0), (0,0))) 89 | lab = np.pad(lab, ((diff, diff), (0,0), (0,0))) 90 | if y < self.args.training_size[1]: 91 | diff = int(math.ceil((self.args.training_size[1]+2 - y) / 2)) 92 | img = np.pad(img, ((0,0), (diff,diff), (0,0))) 93 | lab = np.pad(lab, ((0,0), (diff, diff), (0,0))) 94 | if x <
self.args.training_size[2]: 95 | diff = int(math.ceil((self.args.training_size[2]+2 - x) / 2)) 96 | img = np.pad(img, ((0,0), (0,0), (diff, diff))) 97 | lab = np.pad(lab, ((0,0), (0,0), (diff, diff))) 98 | 99 | tensor_img = torch.from_numpy(img).float() 100 | tensor_lab = torch.from_numpy(lab).long() 101 | 102 | assert tensor_img.shape == tensor_lab.shape 103 | 104 | return tensor_img, tensor_lab 105 | 106 | def __getitem__(self, idx): 107 | 108 | idx = idx % len(self.img_list) 109 | 110 | tensor_img = self.img_list[idx] 111 | tensor_lab = self.lab_list[idx] 112 | 113 | tensor_img = tensor_img.unsqueeze(0).unsqueeze(0) 114 | tensor_lab = tensor_lab.unsqueeze(0).unsqueeze(0) 115 | # 1, C, D, H, W 116 | 117 | 118 | if self.mode == 'train': 119 | if self.args.aug_device == 'gpu': 120 | tensor_img = tensor_img.cuda(self.args.proc_idx) 121 | tensor_lab = tensor_lab.cuda(self.args.proc_idx) 122 | 123 | d, h, w = self.args.training_size 124 | 125 | if np.random.random() < 0.5: 126 | 127 | tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, [d+15, h+65, w+65], mode='random') 128 | tensor_img, tensor_lab = augmentation.random_scale_rotate_translate_3d(tensor_img, tensor_lab, self.args.scale, self.args.rotate, self.args.translate) 129 | 130 | tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, self.args.training_size, mode='center') 131 | 132 | else: 133 | tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, self.args.training_size, mode='random') 134 | 135 | tensor_img, tensor_lab = tensor_img.contiguous(), tensor_lab.contiguous() 136 | 137 | if np.random.random() < 0.2: 138 | tensor_img = augmentation.brightness_multiply(tensor_img, multiply_range=[0.7, 1.3]) 139 | if np.random.random() < 0.2: 140 | tensor_img = augmentation.brightness_additive(tensor_img, std=0.1) 141 | if np.random.random() < 0.2: 142 | tensor_img = augmentation.gamma(tensor_img, gamma_range=[0.7, 1.5]) 143 | if np.random.random() < 0.2: 144 | tensor_img = augmentation.contrast(tensor_img, contrast_range=[0.7, 1.3]) 145 | if np.random.random() < 0.2: 146 | tensor_img = augmentation.gaussian_blur(tensor_img, sigma_range=[0.5, 1.0]) 147 | if np.random.random() < 0.2: 148 | std = np.random.random() * 0.1 149 | tensor_img = augmentation.gaussian_noise(tensor_img, std=std) 150 | 151 | 152 | 153 | 154 | tensor_img = tensor_img.squeeze(0) 155 | tensor_lab = tensor_lab.squeeze(0) 156 | 157 | assert tensor_img.shape == tensor_lab.shape 158 | 159 | if self.mode == 'train': 160 | return tensor_img, tensor_lab 161 | else: 162 | return tensor_img, tensor_lab, np.array(self.spacing_list[idx]) 163 | 164 | -------------------------------------------------------------------------------- /training/dataset/dim3/dataset_lits.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | import SimpleITK as sitk 8 | import yaml 9 | import math 10 | import random 11 | import pdb 12 | from training import augmentation 13 | 14 | class LiverDataset(Dataset): 15 | def __init__(self, args, mode='train', k_fold=5, k=0, seed=0): 16 | 17 | self.mode = mode 18 | self.args = args 19 | 20 | assert mode in ['train', 'test'] 21 | 22 | with open(os.path.join(args.data_root, 'list', 'dataset.yaml'), 'r') as f: 23 | img_name_list = yaml.load(f, Loader=yaml.SafeLoader) 24 | 25 | 26 | random.Random(seed).shuffle(img_name_list) 27 | 28 | 
length = len(img_name_list) 29 | test_name_list = img_name_list[k*(length//k_fold) : (k+1)*(length//k_fold)] 30 | train_name_list = list(set(img_name_list) - set(test_name_list)) 31 | 32 | if mode == 'train': 33 | img_name_list = train_name_list 34 | else: 35 | img_name_list = test_name_list 36 | 37 | 38 | print('Start loading %s data'%self.mode) 39 | print(img_name_list) 40 | path = args.data_root 41 | 42 | self.img_list = [] 43 | self.lab_list = [] 44 | self.spacing_list = [] 45 | 46 | for name in img_name_list: 47 | img_name = '%d.nii.gz'%name 48 | lab_name = '%d_gt.nii.gz'%name 49 | 50 | itk_img = sitk.ReadImage(os.path.join(path, img_name)) 51 | itk_lab = sitk.ReadImage(os.path.join(path, lab_name)) 52 | 53 | spacing = np.array(itk_lab.GetSpacing()).tolist() 54 | self.spacing_list.append(spacing[::-1]) # itk axis order is inverse of numpy axis order 55 | 56 | assert itk_img.GetSize() == itk_lab.GetSize() 57 | 58 | img, lab = self.preprocess(itk_img, itk_lab) 59 | 60 | self.img_list.append(img) 61 | self.lab_list.append(lab) 62 | 63 | print('Load done, length of dataset:', len(self.img_list)) 64 | 65 | def __len__(self): 66 | if self.mode == 'train': 67 | return len(self.img_list) * 100000 68 | else: 69 | return len(self.img_list) 70 | 71 | def preprocess(self, itk_img, itk_lab): 72 | 73 | img = sitk.GetArrayFromImage(itk_img).astype(np.float32) 74 | lab = sitk.GetArrayFromImage(itk_lab).astype(np.uint8) 75 | 76 | img = np.clip(img, -17, 201) 77 | img -= 99.40 78 | img /= 39.39 79 | 80 | z, y, x = img.shape 81 | 82 | # pad if the image size is smaller than training size 83 | if z < self.args.training_size[0]: 84 | diff = int(math.ceil((self.args.training_size[0] - z) / 2)) 85 | img = np.pad(img, ((diff, diff), (0,0), (0,0))) 86 | lab = np.pad(lab, ((diff, diff), (0,0), (0,0))) 87 | if y < self.args.training_size[1]: 88 | diff = int(math.ceil((self.args.training_size[1]+2 - y) / 2)) 89 | img = np.pad(img, ((0,0), (diff,diff), (0,0))) 90 | lab = np.pad(lab, ((0,0), (diff, diff), (0,0))) 91 | if x < self.args.training_size[2]: 92 | diff = int(math.ceil((self.args.training_size[2]+2 - x) / 2)) 93 | img = np.pad(img, ((0,0), (0,0), (diff, diff))) 94 | lab = np.pad(lab, ((0,0), (0,0), (diff, diff))) 95 | 96 | tensor_img = torch.from_numpy(img).float() 97 | tensor_lab = torch.from_numpy(lab).to(torch.int8) 98 | 99 | assert tensor_img.shape == tensor_lab.shape 100 | 101 | return tensor_img, tensor_lab 102 | 103 | def __getitem__(self, idx): 104 | 105 | idx = idx % len(self.img_list) 106 | 107 | tensor_img = self.img_list[idx] 108 | tensor_lab = self.lab_list[idx] 109 | 110 | tensor_img = tensor_img.unsqueeze(0).unsqueeze(0) 111 | tensor_lab = tensor_lab.unsqueeze(0).unsqueeze(0) 112 | # 1, C, D, H, W 113 | 114 | 115 | if self.mode == 'train': 116 | if self.args.aug_device == 'gpu': 117 | tensor_img = tensor_img.cuda(self.args.proc_idx) 118 | tensor_lab = tensor_lab.cuda(self.args.proc_idx) 119 | 120 | d, h, w = self.args.training_size 121 | 122 | if np.random.random() < 0.2: 123 | # crop trick for faster augmentation 124 | # crop a sub volume for scaling and rotation 125 | # instead of scaling and rotating the whole image 126 | tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, [d+70, h+70, w+70], mode='random') 127 | tensor_img, tensor_lab = augmentation.random_scale_rotate_translate_3d(tensor_img, tensor_lab, self.args.scale, self.args.rotate, self.args.translate) 128 | tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, self.args.training_size,
mode='center') 129 | else: 130 | tensor_img, tensor_lab = augmentation.crop_3d(tensor_img, tensor_lab, self.args.training_size, mode='random') 131 | 132 | tensor_img, tensor_lab = tensor_img.contiguous(), tensor_lab.contiguous() 133 | 134 | if np.random.random() < 0.15: 135 | std = np.random.random() * 0.1 136 | tensor_img = augmentation.gaussian_noise(tensor_img, std=std) 137 | 138 | if np.random.random() < 0.15: 139 | tensor_img = augmentation.brightness_multiply(tensor_img, multiply_range=[0.7, 1.3]) 140 | if np.random.random() < 0.15: 141 | tensor_img = augmentation.gamma(tensor_img, gamma_range=[0.7, 1.5]) 142 | if np.random.random() < 0.15: 143 | tensor_img = augmentation.contrast(tensor_img, contrast_range=[0.65, 1.5]) 144 | if np.random.random() < 0.3: 145 | tensor_img = augmentation.mirror(tensor_img, axis=2) 146 | tensor_lab = augmentation.mirror(tensor_lab, axis=2) 147 | if np.random.random() < 0.2: 148 | tensor_img = augmentation.mirror(tensor_img, axis=1) 149 | tensor_lab = augmentation.mirror(tensor_lab, axis=1) 150 | if np.random.random() < 0.05: # rarely mirror along the depth axis as well 151 | tensor_img = augmentation.mirror(tensor_img, axis=0) 152 | tensor_lab = augmentation.mirror(tensor_lab, axis=0) 153 | 154 | 155 | tensor_img = tensor_img.squeeze(0) 156 | tensor_lab = tensor_lab.squeeze(0) 157 | 158 | assert tensor_img.shape == tensor_lab.shape 159 | 160 | if self.mode == 'train': 161 | return tensor_img, tensor_lab 162 | else: 163 | return tensor_img, tensor_lab, np.array(self.spacing_list[idx]) 164 | -------------------------------------------------------------------------------- /training/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_dataset(args, mode, **kwargs): 4 | 5 | if args.dimension == '2d': 6 | if args.dataset == 'acdc': 7 | from .dim2.dataset_acdc import CMRDataset 8 | 9 | return CMRDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 10 | 11 | else: 12 | if args.dataset == 'acdc': 13 | from .dim3.dataset_acdc import CMRDataset 14 | 15 | return CMRDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 16 | elif args.dataset == 'lits': 17 | from .dim3.dataset_lits import LiverDataset 18 | 19 | return LiverDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 20 | 21 | elif args.dataset == 'bcv': 22 | from .dim3.dataset_bcv import BCVDataset 23 | 24 | return BCVDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 25 | 26 | elif args.dataset == 'kits': 27 | from .dim3.dataset_kits import KidneyDataset 28 | 29 | return KidneyDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 30 | 31 | elif args.dataset == 'amos_ct': 32 | from .dim3.dataset_amos_ct import AMOSDataset 33 | 34 | return AMOSDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 35 | 36 | elif args.dataset == 'amos_mr': 37 | from .dim3.dataset_amos_mr import AMOSDataset 38 | 39 | return AMOSDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 40 | 41 | elif args.dataset == 'msd_lung': 42 | from .dim3.dataset_msd_lung import LungDataset # note: dataset_msd_lung.py is not included in training/dataset/dim3 in this repository 43 | 44 | return LungDataset(args, mode=mode, k_fold=args.k_fold, k=kwargs['fold_idx'], seed=args.split_seed) 45 | 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /training/losses.py:
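For reference, a minimal sketch of wiring get_dataset to a DataLoader. The SimpleNamespace below is a hypothetical stand-in for the YAML-driven config (field values taken from the acdc 3D config; the data path is a placeholder):

from types import SimpleNamespace
from torch.utils.data import DataLoader
from training.dataset.utils import get_dataset

args = SimpleNamespace(
    dimension='3d', dataset='acdc', data_root='/path/to/acdc_3d',  # placeholder path
    k_fold=5, split_seed=0, training_size=[16, 192, 192],
    scale=[0.1, 0.3, 0.3], rotate=[30, 0, 0], translate=[0, 0, 0],
    gaussian_noise_std=0.02, additive_brightness_std=0.7, gamma_range=[0.5, 1.6],
    aug_device='cpu', proc_idx=0,
)
train_set = get_dataset(args, mode='train', fold_idx=0)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)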
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | import pdb 7 | 8 | class DiceLoss(nn.Module): 9 | 10 | def __init__(self, alpha=0.5, beta=0.5, size_average=True, reduce=True): 11 | super(DiceLoss, self).__init__() 12 | self.alpha = alpha 13 | self.beta = beta 14 | 15 | self.size_average = size_average 16 | self.reduce = reduce 17 | 18 | def forward(self, preds, targets): 19 | N = preds.size(0) 20 | C = preds.size(1) 21 | 22 | 23 | P = F.softmax(preds, dim=1) 24 | smooth = torch.zeros(C, dtype=torch.float32).fill_(0.00001) 25 | 26 | class_mask = torch.zeros(preds.shape).to(preds.device) 27 | class_mask.scatter_(1, targets, 1.) 28 | 29 | ones = torch.ones(preds.shape).to(preds.device) 30 | P_ = ones - P 31 | class_mask_ = ones - class_mask 32 | 33 | TP = P * class_mask 34 | FP = P * class_mask_ 35 | FN = P_ * class_mask 36 | 37 | smooth = smooth.to(preds.device) 38 | self.alpha = FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) / ((FP.transpose(0, 1).reshape(C, -1).sum(dim=(1)) + FN.transpose(0, 1).reshape(C, -1).sum(dim=(1))) + smooth) 39 | 40 | self.alpha = torch.clamp(self.alpha, min=0.2, max=0.8) 41 | #print('alpha:', self.alpha) 42 | self.beta = 1 - self.alpha 43 | num = torch.sum(TP.transpose(0, 1).reshape(C, -1), dim=(1)).float() 44 | den = num + self.alpha * torch.sum(FP.transpose(0, 1).reshape(C, -1), dim=(1)).float() + self.beta * torch.sum(FN.transpose(0, 1).reshape(C, -1), dim=(1)).float() 45 | 46 | dice = num / (den + smooth) 47 | 48 | if not self.reduce: 49 | loss = torch.ones(C).to(dice.device) - dice 50 | return loss 51 | 52 | loss = 1 - dice 53 | loss = loss.sum() 54 | 55 | if self.size_average: 56 | loss /= C 57 | 58 | return loss 59 | 60 | class FocalLoss(nn.Module): 61 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True): 62 | super(FocalLoss, self).__init__() 63 | 64 | if alpha is None: 65 | self.alpha = torch.ones(class_num) 66 | else: 67 | self.alpha = alpha 68 | 69 | self.gamma = gamma 70 | self.size_average = size_average 71 | 72 | def forward(self, preds, targets): 73 | N = preds.size(0) 74 | C = preds.size(1) 75 | 76 | targets = targets.unsqueeze(1) 77 | P = F.softmax(preds, dim=1) 78 | log_P = F.log_softmax(preds, dim=1) 79 | 80 | class_mask = torch.zeros(preds.shape).to(preds.device) 81 | class_mask.scatter_(1, targets, 1.) 
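        # scatter_ turns the index tensor into a one-hot mask:
        # class_mask[n, targets[n, ...], ...] = 1, so targets must carry a
        # channel dim of size 1 at this point (added by the unsqueeze above
        # and removed again below before indexing self.alpha).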
82 | 83 | if targets.size(1) == 1: 84 | # squeeze the channel dim of the target 85 | targets = targets.squeeze(1) 86 | alpha = self.alpha[targets.data].to(preds.device) 87 | 88 | probs = (P * class_mask).sum(1) 89 | log_probs = (log_P * class_mask).sum(1) 90 | 91 | batch_loss = -alpha * (1 - probs).pow(self.gamma) * log_probs 92 | 93 | if self.size_average: 94 | loss = batch_loss.mean() 95 | else: 96 | loss = batch_loss.sum() 97 | 98 | return loss 99 | 100 | if __name__ == '__main__': 101 | 102 | DL = DiceLoss() 103 | FL = FocalLoss(10) 104 | 105 | pred = torch.randn(2, 10, 128, 128) 106 | target = torch.zeros((2, 1, 128, 128)).long() 107 | 108 | dl_loss = DL(pred, target) 109 | fl_loss = FL(pred, target.squeeze(1)) # FocalLoss expects index targets without a channel dim 110 | 111 | print('2D:', dl_loss.item(), fl_loss.item()) 112 | 113 | pred = torch.randn(2, 10, 64, 128, 128) 114 | target = torch.zeros(2, 1, 64, 128, 128).long() 115 | 116 | dl_loss = DL(pred, target) 117 | fl_loss = FL(pred, target.squeeze(1)) 118 | 119 | print('3D:', dl_loss.item(), fl_loss.item()) 120 | 121 | 122 | -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | from torch import optim 6 | 7 | 8 | def get_optimizer(args, net): 9 | if args.optimizer == 'sgd': 10 | return optim.SGD(net.parameters(), lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay) 11 | elif args.optimizer == 'adam': 12 | return optim.Adam(net.parameters(), lr=args.base_lr, betas=args.betas, weight_decay=args.weight_decay) 13 | elif args.optimizer == 'adamw': 14 | return optim.AdamW(net.parameters(), lr=args.base_lr, betas=args.betas, weight_decay=args.weight_decay, eps=1e-5) # larger eps has better stability during AMP training 15 | 16 | 17 | def log_evaluation_result(writer, dice_list, ASD_list, HD_list, name, epoch, args): 18 | C = dice_list.shape[0] 19 | 20 | writer.add_scalar('Dice/%s_AVG'%name, dice_list.mean(), epoch+1) 21 | for idx in range(C): 22 | writer.add_scalar('Dice/%s_Dice%d'%(name, idx+1), dice_list[idx], epoch+1) 23 | writer.add_scalar('ASD/%s_AVG'%name, ASD_list.mean(), epoch+1) 24 | for idx in range(C): 25 | writer.add_scalar('ASD/%s_ASD%d'%(name, idx+1), ASD_list[idx], epoch+1) 26 | writer.add_scalar('HD/%s_AVG'%name, HD_list.mean(), epoch+1) 27 | for idx in range(C): 28 | writer.add_scalar('HD/%s_HD%d'%(name, idx+1), HD_list[idx], epoch+1) 29 | 30 | def unwrap_model_checkpoint(net, ema_net, args): 31 | net_state_dict = net.module if args.distributed else net 32 | net_state_dict = net_state_dict._orig_mod.state_dict() if args.torch_compile else net_state_dict.state_dict() 33 | if args.ema: 34 | if args.distributed: 35 | ema_net_state_dict = ema_net.module.state_dict() 36 | else: 37 | ema_net_state_dict = ema_net.state_dict() 38 | else: 39 | ema_net_state_dict = None 40 | 41 | return net_state_dict, ema_net_state_dict 42 | 43 | def filter_validation_results(dice_list, ASD_list, HD_list, args): 44 | if args.dataset == 'amos_mr': 45 | # the validation set of amos_mr doesn't have the last two organs, so eliminate them 46 | dice_list, ASD_list, HD_list = dice_list[:-2], ASD_list[:-2], HD_list[:-2] 47 | 48 | return dice_list, ASD_list, HD_list 49 | 50 | def multistep_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, lr_decay_epoch, max_epoch, gamma=0.1): 51 | 52 | if epoch >= 0 and epoch <= warmup_epoch: 53 | lr = init_lr * 2.718 **
(10*(float(epoch) / float(warmup_epoch) - 1.)) 54 | if epoch == warmup_epoch: 55 | lr = init_lr 56 | for param_group in optimizer.param_groups: 57 | param_group['lr'] = lr 58 | 59 | return lr 60 | 61 | flag = False 62 | for i in range(len(lr_decay_epoch)): 63 | if epoch == lr_decay_epoch[i]: 64 | flag = True 65 | break 66 | 67 | if flag: 68 | lr = init_lr * gamma**(i+1) 69 | for param_group in optimizer.param_groups: 70 | param_group['lr'] = lr 71 | 72 | else: 73 | return optimizer.param_groups[0]['lr'] 74 | 75 | return lr 76 | 77 | def exp_lr_scheduler_with_warmup(optimizer, init_lr, epoch, warmup_epoch, max_epoch): 78 | 79 | if epoch >= 0 and epoch <= warmup_epoch and warmup_epoch != 0: 80 | lr = init_lr * 2.718 ** (10*(float(epoch) / float(warmup_epoch) - 1.)) # 2.718 ≈ e: exponential warmup that ends at init_lr 81 | if epoch == warmup_epoch: 82 | lr = init_lr 83 | for param_group in optimizer.param_groups: 84 | param_group['lr'] = lr 85 | 86 | return lr 87 | 88 | else: 89 | lr = init_lr * (1 - epoch / max_epoch)**0.9 90 | for param_group in optimizer.param_groups: 91 | param_group['lr'] = lr 92 | 93 | return lr 94 | 95 | 96 | 97 | 98 | def update_ema_variables(model, ema_model, alpha, global_step): 99 | 100 | alpha = min((1 - 1 / (global_step + 1)), alpha) # ramp the EMA momentum up early in training 101 | for ema_param, param in zip(ema_model.parameters(), model.parameters()): 102 | ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha) 103 | 104 | for ema_buffer, m_buffer in zip(ema_model.buffers(), model.buffers()): 105 | ema_buffer.copy_(m_buffer) 106 | 107 | 108 | 109 | @torch.no_grad() 110 | def concat_all_gather(tensor): 111 | """ 112 | Performs all_gather operation on the provided tensor 113 | *** Warning ***: torch.distributed.all_gather has no gradient. 114 | """ 115 | tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())] 116 | dist.all_gather(tensors_gather, tensor, async_op=False) 117 | 118 | output = torch.cat(tensors_gather, dim=0) 119 | return output 120 | 121 | 122 | @torch.no_grad() 123 | def remove_wrap_arounds(tensor, ranks): 124 | """ 125 | Because DistributedSampler pads the dataset so samples divide evenly 126 | across GPUs, the padded samples must be removed for a correct 127 | evaluation. The dataloader's shuffle must be set to False.
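    Example (hedged): with 10 samples on 4 GPUs the sampler pads to 12; call
    this with ranks = 10 % 4 = 2 and the last two per-rank chunks each drop
    their final (padded) element, restoring the original 10 samples.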
128 | """ 129 | if ranks == 0: 130 | return tensor 131 | 132 | world_size = dist.get_world_size() 133 | single_length = len(tensor) // world_size 134 | output = [] 135 | 136 | for rank in range(world_size): 137 | sub_tensor = tensor[rank * single_length : (rank+1) * single_length] 138 | if rank >= ranks: 139 | output.append(sub_tensor[:-1]) 140 | else: 141 | output.append(sub_tensor) 142 | 143 | output = torch.cat(output) 144 | 145 | return output 146 | 147 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import torch 4 | import torch.distributed as dist 5 | import pdb 6 | 7 | LOG_FORMAT = "[%(levelname)s] %(asctime)s %(filename)s:%(lineno)s %(message)s" 8 | LOG_DATEFMT = "%Y-%m-%d %H:%M:%S" 9 | 10 | def configure_logger(rank, log_path=None): 11 | if log_path: 12 | log_dir = os.path.dirname(log_path) 13 | os.makedirs(log_dir, exist_ok=True) 14 | 15 | # only master process will print & write 16 | level = logging.INFO if rank in {-1, 0} else logging.WARNING 17 | handlers = [logging.StreamHandler()] 18 | if rank in {0, -1} and log_path: 19 | handlers.append(logging.FileHandler(log_path, "w")) 20 | 21 | logging.basicConfig( 22 | level=level, 23 | format=LOG_FORMAT, 24 | datefmt=LOG_DATEFMT, 25 | handlers=handlers, 26 | force=True, 27 | ) 28 | 29 | 30 | def save_configure(args): 31 | if hasattr(args, "distributed"): 32 | if (args.distributed and is_master(args)) or (not args.distributed): 33 | with open(f"{args.cp_dir}/config.txt", 'w') as f: 34 | for name in args.__dict__: 35 | f.write(f"{name}: {getattr(args, name)}\n") 36 | else: 37 | with open(f"{args.cp_dir}/config.txt", 'w') as f: 38 | for name in args.__dict__: 39 | f.write(f"{name}: {getattr(args, name)}\n") 40 | 41 | def resume_load_optimizer_checkpoint(optimizer, args): 42 | assert args.load != False, "Please specify the load path with --load" 43 | 44 | checkpoint = torch.load(args.load) 45 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 46 | 47 | def resume_load_model_checkpoint(net, ema_net, args): 48 | assert args.load != False, "Please specify the load path with --load" 49 | 50 | checkpoint = torch.load(args.load) 51 | net.load_state_dict(checkpoint['model_state_dict']) 52 | args.start_epoch = checkpoint['epoch'] 53 | 54 | if args.ema: 55 | ema_net.load_state_dict(checkpoint['ema_model_state_dict']) 56 | 57 | 58 | 59 | class AverageMeter(object): 60 | """ Computes and stores the average and current value """ 61 | 62 | def __init__(self, name, fmt=":f"): 63 | self.name = name 64 | self.fmt = fmt 65 | self.reset() 66 | 67 | def reset(self): 68 | self.val = 0 69 | self.avg = 0 70 | self.sum = 0 71 | self.count = 0 72 | 73 | def update(self, val, n=1): 74 | self.val = val 75 | self.sum += val * n 76 | self.count += n 77 | self.avg = self.sum / self.count 78 | 79 | def __str__(self): 80 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 81 | return fmtstr.format(**self.__dict__) 82 | 83 | 84 | 85 | class ProgressMeter(object): 86 | def __init__(self, num_batches, meters, prefix=""): 87 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 88 | self.meters = meters 89 | self.prefix = prefix 90 | 91 | def display(self, batch): 92 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 93 | entries += [str(meter) for meter in self.meters] 94 | logging.info("\t".join(entries)) 95 | 96 | def _get_batch_fmtstr(self, num_batches): 97 | 
num_digits = len(str(num_batches)) 98 | fmt = "{:" + str(num_digits) + "d}" 99 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 100 | 101 | 102 | def is_master(args): 103 | return args.rank % args.ngpus_per_node == 0 104 | --------------------------------------------------------------------------------
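The exponential-warmup / polynomial-decay schedule in training/utils.py is easiest to see numerically. A minimal sketch (toy model and hypothetical hyperparameters) that prints the learning rate produced by exp_lr_scheduler_with_warmup over a 200-epoch run:

import torch
from torch import optim
from training.utils import exp_lr_scheduler_with_warmup

net = torch.nn.Linear(4, 2)            # toy model, just to own some parameters
opt = optim.SGD(net.parameters(), lr=0.001)

for epoch in [0, 2, 5, 50, 100, 150, 199]:
    lr = exp_lr_scheduler_with_warmup(opt, init_lr=0.001, epoch=epoch, warmup_epoch=5, max_epoch=200)
    print(f'epoch {epoch:3d}: lr = {lr:.8f}')

# Warmup ramps from init_lr * e^-10 at epoch 0 up to exactly init_lr at epoch 5,
# then the rate decays polynomially as init_lr * (1 - epoch / max_epoch) ** 0.9.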