├── model ├── __init__.py ├── DCCNN.py ├── LPDNet.py ├── HQSNet.py ├── ISTANet_plus.py └── BasicModule.py ├── requirements.txt ├── log ├── dc-cnn_acc_5_bs_1_lr_0.001 │ └── events.out.tfevents.1643940254.dionysos.cs.rutgers.edu ├── dc-cnn_acc_10_bs_1_lr_0.001 │ └── events.out.tfevents.1643940425.dionysos.cs.rutgers.edu ├── hqs-net_acc_10_bs_1_lr_0.001 │ └── events.out.tfevents.1643940431.dionysos.cs.rutgers.edu ├── hqs-net_acc_5_bs_1_lr_0.001 │ └── events.out.tfevents.1643940276.dionysos.cs.rutgers.edu ├── lpd-net_acc_10_bs_1_lr_0.001 │ └── events.out.tfevents.1643940447.dionysos.cs.rutgers.edu ├── lpd-net_acc_5_bs_1_lr_0.001 │ └── events.out.tfevents.1643940293.dionysos.cs.rutgers.edu ├── hqs-net-unet_acc_10_bs_1_lr_0.001 │ └── events.out.tfevents.1644031479.dionysos.cs.rutgers.edu ├── hqs-net-unet_acc_5_bs_1_lr_0.001 │ ├── events.out.tfevents.1644031446.dionysos.cs.rutgers.edu │ ├── events.out.tfevents.1644095553.dionysos.cs.rutgers.edu │ ├── events.out.tfevents.1644096279.dionysos.cs.rutgers.edu │ ├── events.out.tfevents.1644104403.dionysos.cs.rutgers.edu │ ├── events.out.tfevents.1644106432.dionysos.cs.rutgers.edu │ ├── events.out.tfevents.1644106637.dionysos.cs.rutgers.edu │ ├── events.out.tfevents.1644112741.dionysos.cs.rutgers.edu │ └── events.out.tfevents.1644126829.dionysos.cs.rutgers.edu ├── ista-net-plus_acc_5_bs_1_lr_0.001 │ └── events.out.tfevents.1643940287.dionysos.cs.rutgers.edu └── ista-net-plus_acc_10_bs_1_lr_0.001 │ └── events.out.tfevents.1643940441.dionysos.cs.rutgers.edu ├── run_sh ├── acc_10 │ ├── test │ │ ├── test_dc_10.sh │ │ ├── test_hqs_10.sh │ │ ├── test_lpd_10.sh │ │ ├── test_ista_10.sh │ │ └── test_hqs_unet_10.sh │ └── train │ │ ├── train_dc_10.sh │ │ ├── train_hqs_10.sh │ │ ├── train_lpd_10.sh │ │ ├── train_ista_10.sh │ │ └── train_hqs_unet_10.sh └── acc_5 │ ├── test │ ├── test_dc_5.sh │ ├── test_hqs_5.sh │ ├── test_lpd_5.sh │ ├── test_ista_5.sh │ └── test_hqs_unet_5.sh │ └── train │ ├── train_dc_5.sh │ ├── train_hqs_5.sh │ ├── train_lpd_5.sh │ ├── train_ista_5.sh │ └── train_hqs_unet_5.sh ├── cal_model_flops.py ├── main.py ├── README.md ├── read_data.py ├── preprocess_ocmr.py ├── utils.py ├── loss.py └── Solver.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ismrmrd 2 | matplotlib 3 | pandas 4 | tqdm 5 | scipy 6 | torchvision 7 | torch >= 1.10 8 | tensorboardX 9 | fvcore -------------------------------------------------------------------------------- /log/dc-cnn_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940254.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/dc-cnn_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940254.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/dc-cnn_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940425.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/dc-cnn_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940425.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- 
/log/hqs-net_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940431.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940431.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940276.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940276.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/lpd-net_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940447.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/lpd-net_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940447.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/lpd-net_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940293.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/lpd-net_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940293.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_10_bs_1_lr_0.001/events.out.tfevents.1644031479.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_10_bs_1_lr_0.001/events.out.tfevents.1644031479.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644031446.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644031446.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644095553.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644095553.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644096279.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644096279.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644104403.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644104403.dionysos.cs.rutgers.edu 
-------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644106432.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644106432.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644106637.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644106637.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644112741.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644112741.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644126829.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/hqs-net-unet_acc_5_bs_1_lr_0.001/events.out.tfevents.1644126829.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/ista-net-plus_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940287.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/ista-net-plus_acc_5_bs_1_lr_0.001/events.out.tfevents.1643940287.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /log/ista-net-plus_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940441.dionysos.cs.rutgers.edu: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hellopipu/HQS-Net/HEAD/log/ista-net-plus_acc_10_bs_1_lr_0.001/events.out.tfevents.1643940441.dionysos.cs.rutgers.edu -------------------------------------------------------------------------------- /run_sh/acc_10/test/test_dc_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'dc-cnn' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/test/test_dc_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'dc-cnn' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/test/test_hqs_5.sh: 
-------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'hqs-net' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/test/test_lpd_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'lpd-net' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/train/train_dc_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'dc-cnn' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/test/test_hqs_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'hqs-net' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/test/test_lpd_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'lpd-net' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/train/train_dc_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'dc-cnn' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/train/train_hqs_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'hqs-net' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 101 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/train/train_lpd_10.sh: -------------------------------------------------------------------------------- 1 | 
CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'lpd-net' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/test/test_ista_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'ista-net-plus' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/train/train_hqs_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'hqs-net' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/train/train_lpd_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'lpd-net' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/test/test_ista_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'ista-net-plus' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/test/test_hqs_unet_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' \ 3 | --model 'hqs-net-unet' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/train/train_ista_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'ista-net-plus' \ 4 | --acc 5 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/test/test_hqs_unet_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'test' 
\ 3 | --model 'hqs-net-unet' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/train/train_ista_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'ista-net-plus' \ 4 | --acc 10 \ 5 | --batch_size 1 \ 6 | --lr 1e-3 \ 7 | --val_on_epochs 2 \ 8 | --num_epoch 300 \ 9 | --train_path "data/fs_train.npy" \ 10 | --val_path "data/fs_val.npy" \ 11 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_5/train/train_hqs_unet_5.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'hqs-net-unet' \ 4 | --acc 5 \ 5 | --resume 0 \ 6 | --batch_size 1 \ 7 | --lr 1e-3 \ 8 | --val_on_epochs 2 \ 9 | --num_epoch 100 \ 10 | --train_path "data/fs_train.npy" \ 11 | --val_path "data/fs_val.npy" \ 12 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /run_sh/acc_10/train/train_hqs_unet_10.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --mode 'train' \ 3 | --model 'hqs-net-unet' \ 4 | --acc 10 \ 5 | --resume 0 \ 6 | --batch_size 1 \ 7 | --lr 1e-3 \ 8 | --val_on_epochs 2 \ 9 | --num_epoch 100 \ 10 | --train_path "data/fs_train.npy" \ 11 | --val_path "data/fs_val.npy" \ 12 | --test_path "data/fs_test.npy" -------------------------------------------------------------------------------- /cal_model_flops.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | import torch 5 | from fvcore.nn import FlopCountAnalysis 6 | 7 | from model.ISTANet_plus import ISTANetplus 8 | from model.DCCNN import DCCNN 9 | from model.HQSNet import HQSNet 10 | from model.LPDNet import LPDNet 11 | 12 | net1 = DCCNN(n_iter=8) 13 | net2 = ISTANetplus(n_iter=8) 14 | net3 = LPDNet(n_iter=8) 15 | net4 = HQSNet(block_type='cnn', buffer_size=5, n_iter=8) 16 | net5 = HQSNet(block_type='unet', n_iter=10) 17 | net = [net1, net2, net3, net4, net5] 18 | model_name = ['dc-cnn', 'ista-net-plus', 'lpd-net', 'hqs-net', 'hqs-net-unet'] 19 | 20 | im_A_und = torch.randn((1, 2, 192, 160)).cuda() 21 | k_A_und = torch.randn((1, 2, 192, 160)).cuda() 22 | mask = torch.randn((1, 2, 192, 160)).cuda() 23 | for i in range(len(net)): 24 | flops = FlopCountAnalysis(net[i].cuda().eval(), (im_A_und, k_A_und, mask)) 25 | ## suppress warnings about unsupported operations when counting flops 26 | flops._enable_warn_unsupported_ops = False 27 | print('--Information of ' + model_name[i] + ': ') 28 | print(' Total # of params: %.5fM' % (sum(p.numel() for p in net[i].parameters()) / 10. ** 6)) 29 | print(' Total # of FLOPs: %.5fG' % ((flops.total()) / 10. 
** 9)) 30 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | import argparse 5 | from Solver import Solver 6 | 7 | 8 | def main(args): 9 | print(args) 10 | solver = Solver(args) 11 | if args.mode == 'test': 12 | solver.test() 13 | elif args.mode == 'train': 14 | solver.train() 15 | 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | ############################### experiment settings ########################## 20 | parser.add_argument('--mode', default='train', choices=['train', 'test'], 21 | help='mode for the program') 22 | parser.add_argument('--model', default='hqs-net', 23 | choices=['dc-cnn', 'lpd-net', 'hqs-net', 'hqs-net-unet', 'ista-net-plus'], 24 | help='model used for reconstruction') 25 | parser.add_argument('--acc', type=int, default=5, 26 | help='Acceleration factor for k-space sampling') 27 | ############################### dataset setting ############################### 28 | 29 | parser.add_argument('--train_path', default="data/fs_train.npy", 30 | help='train_path') 31 | parser.add_argument('--val_path', default="data/fs_val.npy", 32 | help='val_path') 33 | parser.add_argument('--test_path', default="data/fs_test.npy", 34 | help='test_path') 35 | 36 | ############################### model training settings ######################## 37 | parser.add_argument('--num_epoch', type=int, default=300, 38 | help='number of training epochs') 39 | parser.add_argument('--val_on_epochs', type=int, default=1, 40 | help='validate once every n epochs') 41 | parser.add_argument('--batch_size', type=int, default=1, 42 | help='batch size, 1, 4, 8, 16, ...') 43 | parser.add_argument('--lr', type=float, default=1e-3, 44 | help='learning rate for training') 45 | parser.add_argument('--resume', type=int, default=0, choices=[0, 1], 46 | help='resume training') 47 | 48 | args = parser.parse_args() 49 | 50 | main(args) 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/hellopipu/434860fccdc7cd67278f079c41bbc7e1/demo_hqsnet_single_coil_cardiac_mr_reconstruction.ipynb) 2 | 3 | ## HQS-Net 4 | 5 | A PyTorch implementation of the paper **Learned Half-Quadratic Splitting Network for Magnetic Resonance Image 6 | Reconstruction** (https://openreview.net/pdf?id=h7rXUbALijU) 7 | 8 | ### Install 9 | 10 | Python>=3.7.11 is required, with all packages in requirements.txt installed, including pytorch>=1.10.0 11 | 12 | ```shell 13 | git clone https://github.com/hellopipu/HQS-Net.git 14 | cd HQS-Net 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Prepare dataset 19 | 20 | You can find more information about the OCMR dataset at https://ocmr.info/ 21 | 22 | ```shell 23 | ## download dataset 24 | wget -nc https://ocmr.s3.amazonaws.com/data/ocmr_cine.tar.gz -P data/ 25 | ## download dataset attributes csv file 26 | wget -nc https://raw.githubusercontent.com/MRIOSU/OCMR/master/ocmr_data_attributes.csv -P data/ 27 | ## untar dataset 28 | tar -xzvf data/ocmr_cine.tar.gz -C data/ 29 | ## preprocess and split the dataset; this takes several hours 30 | python preprocess_ocmr.py 31 | ``` 32 | 33 | Or you can directly download the preprocessed dataset 
[here](https://github.com/hellopipu/HQS-Net/releases/tag/v0.0), 34 | and then put the files in the `data/` folder. 35 | 36 | ### Training 37 | 38 | Training and testing scripts for all experiments in the paper can be found in the `run_sh` folder. For example, to 39 | train HQS-Net at an acceleration factor of 5x, run: 40 | 41 | ```shell 42 | sh run_sh/acc_5/train/train_hqs_5.sh 43 | ``` 44 | 45 | or, to train the U-Net-based HQS-Net at an acceleration factor of 10x, run: 46 | 47 | ```shell 48 | sh run_sh/acc_10/train/train_hqs_unet_10.sh 49 | ``` 50 | 51 | ### Testing 52 | 53 | For example, to test HQS-Net at an acceleration factor of 5x, run: 54 | 55 | ```shell 56 | sh run_sh/acc_5/test/test_hqs_5.sh 57 | ``` 58 | 59 | All pretrained models in the paper can be downloaded [here](https://github.com/hellopipu/HQS-Net/releases/tag/v0.0); 60 | put them in the `weight/` folder. 61 | 62 | We also provide a Colab 63 | demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/hellopipu/434860fccdc7cd67278f079c41bbc7e1/demo_hqsnet_single_coil_cardiac_mr_reconstruction.ipynb). 64 | 65 | ### Tensorboard 66 | 67 | Run TensorBoard to monitor the training curves: 68 | 69 | ```shell 70 | tensorboard --logdir log 71 | ``` -------------------------------------------------------------------------------- /model/DCCNN.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | import torch 4 | import torch.nn as nn 5 | from model.BasicModule import conv_block 6 | 7 | 8 | class DCCNN(nn.Module): 9 | def __init__(self, n_iter=8, n_convs=6, n_filters=64, norm='ortho'): 10 | ''' 11 | DC-CNN modified from paper " A Deep Cascade of Convolutional Neural Networks for Dynamic MR Image Reconstruction " 12 | ( https://arxiv.org/pdf/1704.02422.pdf ) ( https://github.com/js3611/Deep-MRI-Reconstruction ) 13 | :param n_iter: num of iterations 14 | :param n_convs: num of convs in each block 15 | :param n_filters: num of feature channels in intermediate features 16 | :param norm: 'ortho' norm for fft 17 | ''' 18 | super(DCCNN, self).__init__() 19 | channel_in = 2 20 | rec_blocks = [] 21 | self.norm = norm 22 | self.mu = nn.Parameter(torch.Tensor([0.5])) 23 | self.n_iter = n_iter 24 | for i in range(n_iter): 25 | rec_blocks.append(conv_block('dc-cnn', channel_in, n_filters=n_filters, n_convs=n_convs)) 26 | 27 | self.rec_blocks = nn.ModuleList(rec_blocks) 28 | 29 | def dc_operation(self, x_rec, k_un, mask): 30 | x_rec = x_rec.permute(0, 2, 3, 1) 31 | mask = mask.permute(0, 2, 3, 1) 32 | k_un = k_un.permute(0, 2, 3, 1) 33 | k_rec = torch.fft.fft2(torch.view_as_complex(x_rec.contiguous()), norm=self.norm) 34 | 35 | k_rec = torch.view_as_real(k_rec) 36 | # noiseless 37 | k_out = k_rec + (k_un - k_rec) * mask 38 | 39 | k_out = torch.view_as_complex(k_out) 40 | x_out = torch.view_as_real(torch.fft.ifft2(k_out, norm=self.norm)) 41 | x_out = x_out.permute(0, 3, 1, 2) 42 | return x_out 43 | 44 | def _forward_operation(self, img, mask): 45 | 46 | k = torch.fft.fft2(torch.view_as_complex(img.permute(0, 2, 3, 1).contiguous()), 47 | norm=self.norm) 48 | k = torch.view_as_real(k).permute(0, 3, 1, 2).contiguous() 49 | k = mask * k 50 | return k 51 | 52 | def _backward_operation(self, k, mask): 53 | 54 | k = mask * k 55 | img = torch.fft.ifft2(torch.view_as_complex(k.permute(0, 2, 3, 1).contiguous()), norm=self.norm) 56 | img = 
torch.view_as_real(img).permute(0, 3, 1, 2).contiguous() 57 | return img 58 | 59 | def update_opration(self, f_1, k, mask): 60 | h_1 = k - self._forward_operation(f_1, mask) 61 | update = f_1 + self.mu * self._backward_operation(h_1, mask) 62 | return update 63 | 64 | def forward(self, x, k, m): 65 | for i in range(self.n_iter): 66 | # x = self.update_opration(x, k, m) 67 | x_cnn = self.rec_blocks[i](x) 68 | x = x + x_cnn 69 | x = self.update_opration(x, k, m) 70 | return x 71 | -------------------------------------------------------------------------------- /model/LPDNet.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | import torch 5 | from torch import nn 6 | from model.BasicModule import conv_block 7 | 8 | 9 | class LPDNet(nn.Module): 10 | def __init__(self, n_primal=5, n_dual=5, n_iter=8, n_convs=6, n_filters=64, norm='ortho'): 11 | ''' 12 | LPD-Net modified from paper " Learned primal-dual reconstruction " 13 | ( https://arxiv.org/abs/1707.06474 ) ( https://github.com/adler-j/learned_primal_dual ) 14 | :param n_primal: buffer size for primal space ( image space ) 15 | :param n_dual: buffer size for dual space ( k-space ) 16 | :param n_iter: num of iterations 17 | :param n_convs: num of convs in the block 18 | :param n_filters: num of feature channels in intermediate features 19 | :param norm: 'ortho' norm for fft 20 | ''' 21 | super().__init__() 22 | self.norm = norm 23 | self.n_primal = n_primal 24 | self.n_dual = n_dual 25 | self.n_iter = n_iter 26 | image_net_block = [] 27 | kspace_net_block = [] 28 | 29 | for i in range(self.n_iter): 30 | image_net_block.append( 31 | conv_block('prim-net', channel_in=2 * (self.n_primal + 1), n_convs=n_convs, n_filters=n_filters)) 32 | self.primal_net = nn.ModuleList(image_net_block) 33 | 34 | for i in range(self.n_iter): 35 | kspace_net_block.append( 36 | conv_block('dual-net', channel_in=2 * (self.n_dual + 2), n_convs=n_convs, n_filters=n_filters)) 37 | self.dual_net = nn.ModuleList(kspace_net_block) 38 | 39 | def _forward_operation(self, img, mask): 40 | 41 | k = torch.fft.fft2(torch.view_as_complex(img.permute(0, 2, 3, 1).contiguous()), 42 | norm=self.norm) 43 | k = torch.view_as_real(k).permute(0, 3, 1, 2).contiguous() 44 | k = mask * k 45 | 46 | return k 47 | 48 | def _backward_operation(self, k, mask): 49 | 50 | k = mask * k 51 | img = torch.fft.ifft2(torch.view_as_complex(k.permute(0, 2, 3, 1).contiguous()), norm=self.norm) 52 | img = torch.view_as_real(img).permute(0, 3, 1, 2).contiguous() 53 | 54 | return img 55 | 56 | def forward(self, img, k, mask): 57 | 58 | dual_buffer = torch.cat([k] * self.n_dual, 1).to(k.device) 59 | primal_buffer = torch.cat([img] * self.n_primal, 1).to(k.device) 60 | 61 | for i in range(self.n_iter): # 62 | # kspace (dual) 63 | f_2 = primal_buffer[:, 2:4].clone() 64 | dual_buffer = dual_buffer + self.dual_net[i]( 65 | torch.cat([dual_buffer, self._forward_operation(f_2, mask), k], 1) 66 | ) 67 | h_1 = dual_buffer[:, 0:2].clone() 68 | # image space (primal) 69 | primal_buffer = primal_buffer + self.primal_net[i]( 70 | torch.cat([primal_buffer, self._backward_operation(h_1, mask)], 1) 71 | ) 72 | 73 | return primal_buffer[:, 0:2] 74 | -------------------------------------------------------------------------------- /model/HQSNet.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | import torch 4 | from torch import nn 5 | 
from model.BasicModule import conv_block 6 | from model.BasicModule import UNetRes 7 | 8 | 9 | class HQSNet(nn.Module): 10 | def __init__(self, buffer_size=5, n_iter=8, n_convs=6, n_filters=64, block_type='cnn', norm='ortho'): 11 | ''' 12 | HQS-Net from paper " Learned Half-Quadratic Splitting Network for MR Image Reconstruction " 13 | ( https://openreview.net/pdf?id=h7rXUbALijU ) ( https://github.com/hellopipu/HQS-Net ) 14 | :param buffer_size: buffer size m 15 | :param n_iter: num of iterations n 16 | :param n_convs: num of convolutions in each reconstruction block 17 | :param n_filters: num of output channels for convolutions 18 | :param block_type: 'cnn' or 'unet' 19 | :param norm: 'ortho' norm for fft 20 | ''' 21 | 22 | super().__init__() 23 | self.norm = norm 24 | self.m = buffer_size 25 | self.n_iter = n_iter 26 | ## the initialization of mu may influence the final accuracy 27 | self.mu = nn.Parameter(0.5 * torch.ones((1, 1))) 28 | self.block_type = block_type 29 | if self.block_type == 'cnn': 30 | rec_blocks = [] 31 | for i in range(self.n_iter): 32 | rec_blocks.append( 33 | conv_block('hqs-net', channel_in=2 * (self.m + 1), n_convs=n_convs, 34 | n_filters=n_filters)) 35 | self.rec_blocks = nn.ModuleList(rec_blocks) 36 | elif self.block_type == 'unet': 37 | self.rec_blocks = UNetRes(in_nc=2 * (self.m + 1), out_nc=2 * self.m, nc=[64, 128, 256, 512], nb=4, 38 | act_mode='R', 39 | downsample_mode="strideconv", upsample_mode="convtranspose") 40 | 41 | def _forward_operation(self, img, mask): 42 | 43 | k = torch.fft.fft2(torch.view_as_complex(img.permute(0, 2, 3, 1).contiguous()), 44 | norm=self.norm) 45 | k = torch.view_as_real(k).permute(0, 3, 1, 2).contiguous() 46 | k = mask * k 47 | return k 48 | 49 | def _backward_operation(self, k, mask): 50 | k = mask * k 51 | img = torch.fft.ifft2(torch.view_as_complex(k.permute(0, 2, 3, 1).contiguous()), norm=self.norm) 52 | img = torch.view_as_real(img).permute(0, 3, 1, 2).contiguous() 53 | return img 54 | 55 | def update_opration(self, f_1, k, mask): 56 | h_1 = k - self._forward_operation(f_1, mask) 57 | update = f_1 + self.mu * self._backward_operation(h_1, mask) 58 | return update 59 | 60 | def forward(self, img, k, mask): 61 | ''' 62 | :param img: zero-filled images, (batch,2,h,w) 63 | :param k: corresponding undersampled k-space data, (batch,2,h,w) 64 | :param mask: uncentered sampling mask, (batch,2,h,w) 65 | :return: reconstructed img 66 | ''' 67 | 68 | ## initialize buffer f : the concatenation of m copies of the complex-valued zero-filled images 69 | f = torch.cat([img] * self.m, 1).to(img.device) 70 | 71 | ## n reconstruction blocks 72 | for i in range(self.n_iter): 73 | f_1 = f[:, 0:2].clone() 74 | updated_f_1 = self.update_opration(f_1, k, mask) 75 | if self.block_type == 'cnn': 76 | f = f + self.rec_blocks[i](torch.cat([f, updated_f_1], 1)) 77 | elif self.block_type == 'unet': 78 | f = f + self.rec_blocks(torch.cat([f, updated_f_1], 1)) 79 | return f[:, 0:2] 80 | -------------------------------------------------------------------------------- /read_data.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | import torch 5 | import numpy as np 6 | from torch.utils.data import Dataset 7 | from torchvision import transforms 8 | from torchvision.transforms import RandomApply, RandomRotation, ToTensor, RandomResizedCrop, \ 9 | Compose, RandomAffine, RandomHorizontalFlip, RandomVerticalFlip, RandomPerspective 10 | from utils import undersample, 
cartesian_mask 11 | 12 | 13 | class MyData(Dataset): 14 | def __init__(self, imageDir, acc=5, img_size=256, is_training='train'): 15 | super().__init__() 16 | 17 | self.img_size = img_size ## used in transform 18 | self.is_training = is_training 19 | self.acc = acc 20 | 21 | self.images = np.load(imageDir) 22 | self.len = len(self.images) 23 | self.custom_transform = [ToTensor()] 24 | if self.is_training == 'train': 25 | ## random image augmentation when training 26 | self.custom_transform += [ 27 | RandomApply(torch.nn.ModuleList([RandomResizedCrop(self.img_size, scale=(0.9, 1.0), ratio=(0.9, 1.1))]), 28 | p=0.3), 29 | RandomApply(torch.nn.ModuleList([RandomAffine(20, translate=(0.1, 0.1), scale=(0.9, 1.1), 30 | shear=(-5, 5, -5, 5), 31 | interpolation=transforms.InterpolationMode.BILINEAR)]), 32 | p=0.3), 33 | # RandomHorizontalFlip(p=0.3), 34 | # RandomVerticalFlip(p=0.3), 35 | # RandomPerspective(0.05, 0.3), 36 | ] 37 | else: 38 | ## generate a fixed mask for validating and testing 39 | mask = cartesian_mask(self.img_size, acc=self.acc, centred=False, sample_random=False) 40 | self.mask = torch.from_numpy(mask) 41 | 42 | def transform(self, img_A): 43 | ''' 44 | 45 | :param img_A: numpy array, (2,H,W) 46 | :return: torch tensor, complex, (H,W) 47 | ''' 48 | 49 | img_A = img_A.transpose(1, 2, 0) 50 | for t in self.custom_transform: 51 | img_A = t(img_A) 52 | 53 | ## normalize to [0,1] 54 | # img_A = img_A / img_A.max() 55 | ## 2 channel real to complex 56 | img_A = img_A[0] + 1j * img_A[1] 57 | 58 | return img_A 59 | 60 | def get_sample(self, index, mask): 61 | 62 | image_A = self.images[index] 63 | ## data norm 64 | image_A_abs = (image_A[0] ** 2 + image_A[1] ** 2) ** 0.5 65 | image_A = image_A / np.percentile(image_A_abs, 99) 66 | ########################### image preprocessing ########################## 67 | # transform 68 | image_A = self.transform(image_A) 69 | # generate zero-filled image x_und, k_und, k 70 | image_A_und, k_A_und, k_A = undersample(image_A, mask) 71 | 72 | ########################## complex to 2 channel ########################## 73 | im_A = torch.view_as_real(image_A).permute(2, 0, 1).contiguous() 74 | im_A_und = torch.view_as_real(image_A_und).permute(2, 0, 1).contiguous() 75 | k_A_und = torch.view_as_real(k_A_und).permute(2, 0, 1).contiguous() 76 | 77 | return im_A, im_A_und, k_A_und 78 | 79 | def __getitem__(self, i): 80 | if self.is_training == 'train': 81 | ## generate random masks for training 82 | mask = cartesian_mask(self.img_size, acc=self.acc, centred=False, sample_random=True) 83 | mask = torch.from_numpy(mask) 84 | else: 85 | ## use fixed mask for validation and test 86 | mask = self.mask 87 | ## generate samples 88 | im_A, im_A_und, k_A_und = self.get_sample(i, mask) 89 | mask = torch.view_as_real(mask * (1. 
+ 1.j)).permute(2, 0, 1).contiguous() 90 | 91 | return {'im_A': im_A, 'im_A_und': im_A_und, 'k_A_und': k_A_und, 'mask_A': mask} 92 | 93 | def __len__(self): 94 | return self.len 95 | -------------------------------------------------------------------------------- /model/ISTANet_plus.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn import init 7 | 8 | 9 | # Define ISTA-Net-plus Block 10 | class BasicBlock(torch.nn.Module): 11 | def __init__(self, norm='ortho'): 12 | super(BasicBlock, self).__init__() 13 | self.norm = norm 14 | self.lambda_step = nn.Parameter(torch.Tensor([0.5])) 15 | self.soft_thr = nn.Parameter(torch.Tensor([0.01])) 16 | 17 | num_filter = 64 18 | 19 | self.conv_D = nn.Parameter(init.xavier_normal_(torch.Tensor(num_filter, 2, 3, 3))) 20 | 21 | self.conv1_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(num_filter, num_filter, 3, 3))) 22 | self.conv2_forward = nn.Parameter(init.xavier_normal_(torch.Tensor(num_filter, num_filter, 3, 3))) 23 | self.conv1_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(num_filter, num_filter, 3, 3))) 24 | self.conv2_backward = nn.Parameter(init.xavier_normal_(torch.Tensor(num_filter, num_filter, 3, 3))) 25 | 26 | self.conv_G = nn.Parameter(init.xavier_normal_(torch.Tensor(2, num_filter, 3, 3))) 27 | 28 | def _forward_operation(self, img, mask): 29 | 30 | k = torch.fft.fft2(torch.view_as_complex(img.permute(0, 2, 3, 1).contiguous()), 31 | norm=self.norm) 32 | k = torch.view_as_real(k).permute(0, 3, 1, 2).contiguous() 33 | k = mask * k 34 | return k 35 | 36 | def _backward_operation(self, k, mask): 37 | 38 | k = mask * k 39 | img = torch.fft.ifft2(torch.view_as_complex(k.permute(0, 2, 3, 1).contiguous()), norm=self.norm) 40 | img = torch.view_as_real(img).permute(0, 3, 1, 2).contiguous() 41 | return img 42 | 43 | def update_opration(self, f_1, k, mask): 44 | h_1 = k - self._forward_operation(f_1, mask) 45 | update = f_1 + self.lambda_step * self._backward_operation(h_1, mask) 46 | return update 47 | 48 | def forward(self, x, k, m): 49 | # x = x - self.lambda_step * fft_forback(x, mask) 50 | # x = x + self.lambda_step * PhiTb 51 | x = self.update_opration(x, k, m) 52 | x_input = x 53 | 54 | x_D = F.conv2d(x_input, self.conv_D, padding=1) 55 | 56 | x = F.conv2d(x_D, self.conv1_forward, padding=1) 57 | x = F.relu(x) 58 | x_forward = F.conv2d(x, self.conv2_forward, padding=1) 59 | 60 | x = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.soft_thr)) 61 | 62 | x = F.conv2d(x, self.conv1_backward, padding=1) 63 | x = F.relu(x) 64 | x_backward = F.conv2d(x, self.conv2_backward, padding=1) 65 | 66 | x_G = F.conv2d(x_backward, self.conv_G, padding=1) 67 | 68 | x_pred = x_input + x_G 69 | 70 | if self.training: 71 | x = F.conv2d(x_forward, self.conv1_backward, padding=1) 72 | x = F.relu(x) 73 | x_D_est = F.conv2d(x, self.conv2_backward, padding=1) 74 | symloss = x_D_est - x_D 75 | return x_pred, symloss 76 | else: 77 | return x_pred, None 78 | 79 | 80 | class ISTANetplus(nn.Module): 81 | def __init__(self, n_iter=8, n_convs=5, n_filters=64, norm='ortho'): 82 | ''' 83 | ISTANetplus modified from paper " ISTA-Net: Interpretable Optimization-Inspired Deep Network for Image 84 | Compressive Sensing " 85 | ( https://arxiv.org/pdf/1706.07929.pdf ) ( https://github.com/jianzhangcs/ISTA-Net-PyTorch ) 86 | :param n_iter: num of iterations 87 | 
:param n_convs: num of convs in each block 88 | :param n_filters: num of feature channels in intermediate features 89 | :param norm: 'ortho' norm for fft 90 | ''' 91 | super(ISTANetplus, self).__init__() 92 | channel_in = 2 93 | rec_blocks = [] 94 | self.norm = norm 95 | self.n_iter = n_iter 96 | for i in range(n_iter): 97 | rec_blocks.append(BasicBlock(norm=self.norm)) 98 | self.rec_blocks = nn.ModuleList(rec_blocks) 99 | 100 | def forward(self, x, k, m): 101 | layers_sym = [] # for computing symmetric loss 102 | for i in range(self.n_iter): 103 | x, layer_sym = self.rec_blocks[i](x, k, m) 104 | layers_sym.append(layer_sym) 105 | if self.training: 106 | return x, layers_sym 107 | else: 108 | return x 109 | -------------------------------------------------------------------------------- /preprocess_ocmr.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | ''' 5 | File function: 6 | 1. Generate emulated single-coil data from the OCMR dataset using the method in 7 | https://arxiv.org/pdf/1811.08026.pdf. This procedure takes several hours. 8 | 2. Split the data into train, val, and test sets of 1874, 544, and 1104 slices, respectively 9 | ''' 10 | 11 | import os 12 | import pandas 13 | from tqdm import tqdm 14 | from utils import read_ocmr 15 | import numpy as np 16 | from numpy.fft import fftshift, ifftshift, ifft2 17 | import math 18 | from scipy.optimize import minimize 19 | from utils import pad_crop 20 | 21 | 22 | def get_data(path_dir, csv_file, scn): 23 | ''' 24 | 25 | :param path_dir: data folder path 26 | :param csv_file: path to the ocmr_data_attributes.csv file 27 | :param scn: scanner model string used to filter scans, e.g. '"15avan"' 28 | :return: list of emulated single-coil image arrays, one per scan 29 | ''' 30 | print('------- Emulating single-coil for scn = {} -------'.format(scn)) 31 | array_list = [] 32 | ## read csv files 33 | df = pandas.read_csv(csv_file) 34 | ## Cleanup empty rows and columns 35 | df.dropna(how='all', axis=0, inplace=True) 36 | df.dropna(how='all', axis=1, inplace=True) 37 | ## filter files 38 | selected_df = df.query('`file name`.str.contains("fs_") and fov=="noa" and scn==' + scn, engine='python') 39 | list_name = selected_df['file name'] 40 | 41 | for filename in tqdm(list_name): 42 | data_path = os.path.join(path_dir, filename) 43 | kData, param = read_ocmr(data_path) 44 | dim_kData = kData.shape 45 | CH = dim_kData[3] 46 | 47 | ## average the k-space if average > 1 48 | kData_tmp = np.mean(kData, axis=8); 49 | 50 | ## Coil images are combined using SOS (sum of squares). 
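## In other words, after the 2D IFFT below, the coil images (coils on axis 3) are combined as im_sos = np.sqrt((np.abs(im_coil) ** 2).sum(axis=3)):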
51 | im_coil = fftshift(ifft2(ifftshift(kData_tmp, (0, 1)), axes=(0, 1), norm='ortho'), (0, 1)) # IFFT (2D image) 52 | im_sos = np.sqrt(np.sum(np.abs(im_coil) ** 2, 3)) # sum of squares 53 | 54 | ## Remove ReadOut oversampling 55 | RO = im_sos.shape[0] 56 | image = im_sos[math.floor(RO / 4):math.floor(RO / 4 * 3), :, :] # Remove RO oversampling 57 | im_coil_ = im_coil[math.floor(RO / 4):math.floor(RO / 4 * 3), :, :] 58 | image = image.reshape((image.shape[0], image.shape[1], -1)) 59 | im_coil_ = im_coil_.reshape((image.shape[0], image.shape[1], CH, -1)) 60 | 61 | ## pad or crop to fixed size 62 | image = pad_crop(image, (192, 160, image.shape[2])) 63 | im_coil_ = pad_crop(im_coil_, (192, 160, im_coil_.shape[2], im_coil_.shape[3])) 64 | 65 | ## emulate a single-coil image from the multi-coil data by fitting per-coil weights with BFGS 66 | def error_func(x): 67 | kk = np.matmul(im_coil_.transpose(0, 1, 3, 2), x) 68 | error = np.sum((np.abs(kk) ** 0.5 - np.abs(image) ** 0.5) ** 2) ** 0.5 69 | print('emulating error: ', error, end="\r") 70 | return error 71 | 72 | x0 = np.ones((CH, 1)) / CH 73 | res = minimize(error_func, x0, method='BFGS') 74 | esc = np.matmul(im_coil_.transpose(0, 1, 3, 2), res.x) 75 | 76 | ### save 77 | esc_comp = np.stack([np.real(esc), np.imag(esc)], axis=-1).astype(np.float32) 78 | array_list.append(esc_comp) 79 | 80 | return array_list 81 | 82 | 83 | if __name__ == '__main__': 84 | os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" 85 | 86 | ### 1. simulating the single-coil MRI from multi-coil data using the method in https://arxiv.org/pdf/1811.08026.pdf 87 | ocmr_data_attributes_location = 'data/ocmr_data_attributes.csv' 88 | ocmr_data_location = 'data/OCMR_data' 89 | 90 | im_avan = get_data(ocmr_data_location, ocmr_data_attributes_location, '"15avan"') 91 | im_30pris = get_data(ocmr_data_location, ocmr_data_attributes_location, '"30pris"') 92 | im_15sola = get_data(ocmr_data_location, ocmr_data_attributes_location, '"15sola"') 93 | 94 | ## 2. split dataset 95 | trainset = im_avan[0:6] + im_30pris[0:22] + im_15sola[0:13] 96 | valset = im_avan[6:8] + im_30pris[22:29] + im_15sola[13:17] 97 | testset = im_avan[8::] + im_30pris[29::] + im_15sola[17::] 98 | 99 | trainset = np.concatenate(trainset, axis=2).transpose(2, 3, 0, 1) 100 | valset = np.concatenate(valset, axis=2).transpose(2, 3, 0, 1) 101 | testset = np.concatenate(testset, axis=2).transpose(2, 3, 0, 1) 102 | 103 | ## 3. save to numpy 104 | np.save("data/fs_train.npy", trainset) 105 | np.save("data/fs_val.npy", valset) 106 | np.save("data/fs_test.npy", testset) 107 | 108 | print('slices of trainset: ', len(trainset)) 109 | print('slices of valset: ', len(valset)) 110 | print('slices of testset: ', len(testset)) 111 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | import os 5 | import time 6 | import operator 7 | import torch 8 | import ismrmrd 9 | import ismrmrd.xsd 10 | import numpy as np 11 | from numpy.lib.stride_tricks import as_strided 12 | from torch.fft import fft2, ifft2 13 | 14 | 15 | def is_num(a): 16 | return isinstance(a, int) or isinstance(a, float) 17 | 18 | 19 | def delta(x1, x2): 20 | delta_ = x2 - x1 21 | return delta_ // 2, delta_ - delta_ // 2 22 | 23 | 24 | def get_padding_width(o_shape, d_shape): 25 | if is_num(o_shape): 26 | o_shape, d_shape = [o_shape], [d_shape] 27 | assert len(o_shape) == len(d_shape), 'Length mismatched!' 
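# worked example: get_padding_width(3, 7) -> [2, 2]; get_padding_width((3, 4), (7, 9)) -> [2, 2, 2, 3] (before/after widths per axis)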
28 | borders = [] 29 | for o, d in zip(o_shape, d_shape): 30 | borders.extend(delta(o, d)) 31 | return borders 32 | 33 | 34 | def get_crop_width(o_shape, d_shape): 35 | return get_padding_width(d_shape, o_shape) 36 | 37 | 38 | def get_padding_shape_with_stride(o_shape, stride): 39 | assert isinstance(o_shape, list) or isinstance(o_shape, tuple) or isinstance(o_shape, np.ndarray) 40 | o_shape = np.array(o_shape) 41 | d_shape = np.ceil(o_shape / stride) * stride 42 | return d_shape.astype(np.int32) 43 | 44 | 45 | def pad(arr, d_shape, mode='constant', value=0, strict=True): 46 | """ 47 | pad numpy array, tested! 48 | :param arr: numpy array 49 | :param d_shape: array shape after padding or minimum shape 50 | :param mode: padding mode, 51 | :param value: padding value 52 | :param strict: if True, d_shape must be greater than arr shape and output shape is d_shape. if False, d_shape is minimum shape and output shape is np.maximum(arr.shape, d_shape) 53 | :return: padded arr with expected shape 54 | """ 55 | assert arr.ndim == len(d_shape), 'Dimension mismatched!' 56 | if not strict: 57 | d_shape = np.maximum(arr.shape, d_shape) 58 | else: 59 | assert np.all(np.array(d_shape) >= np.array(arr.shape)), 'Padding shape must be greater than arr shape' 60 | borders = np.array(get_padding_width(arr.shape, d_shape)) 61 | before = borders[list(range(0, len(borders), 2))] 62 | after = borders[list(range(1, len(borders), 2))] 63 | padding_borders = tuple(zip([int(x) for x in before], [int(x) for x in after])) 64 | # print(padding_borders) 65 | if mode == 'constant': 66 | return np.pad(arr, padding_borders, mode=mode, constant_values=value) 67 | else: 68 | return np.pad(arr, padding_borders, mode=mode) 69 | 70 | 71 | def crop(arr, d_shape, strict=True): 72 | """ 73 | central crop numpy array, tested! 74 | :param arr: numpy array 75 | :param d_shape: expected shape 76 | :return: cropped array with expected array 77 | """ 78 | assert arr.ndim == len(d_shape), 'Dimension mismatched!' 79 | if not strict: 80 | d_shape = np.minimum(arr.shape, d_shape) 81 | else: 82 | assert np.all(np.array(d_shape) <= np.array(arr.shape)), 'Crop shape must be smaller than arr shape' 83 | borders = np.array(get_crop_width(arr.shape, d_shape)) 84 | start = borders[list(range(0, len(borders), 2))] 85 | # end = - borders[list(range(1, len(borders), 2))] 86 | end = map(operator.add, start, d_shape) 87 | slices = tuple(map(slice, start, end)) 88 | return arr[slices] 89 | 90 | 91 | def pad_crop(arr, d_shape, mode='constant', value=0): 92 | """ 93 | pad or crop numpy array to expected shape, tested! 94 | :param arr: numpy array 95 | :param d_shape: expected shape 96 | :param mode: padding mode, 97 | :param value: padding value 98 | :return: padded and cropped array 99 | """ 100 | assert arr.ndim == len(d_shape), 'Dimension mismatched!' 101 | arr = pad(arr, d_shape, mode, value, strict=False) 102 | return crop(arr, d_shape) 103 | 104 | 105 | def undersample(image, mask, norm='ortho'): 106 | assert image.shape == mask.shape 107 | # the standard way to get k-space from image 108 | # should be like fftshift(fft2(ifftshift(img,dim=(-1,-2)),norm=norm),dim=(-1,-2)), 109 | # we omit fftshift/ifftshift for simplicity, and now k is uncentered. 
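# That is, the DC component sits at index (0, 0) rather than at the array center,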
so we also use an uncentered mask 110 | k = fft2(image, norm=norm) 111 | k_und = mask * k 112 | x_und = ifft2(k_und, norm=norm) 113 | 114 | return x_und, k_und, k 115 | 116 | 117 | def output2complex(im_tensor): 118 | ''' 119 | param: im_tensor : [B, 2, W, H] 120 | return : [B,W,H] magnitude of complex value 121 | ''' 122 | ## combine the 2-channel representation into a complex tensor and take its magnitude 123 | im_tensor = torch.view_as_complex(im_tensor.permute(0, 2, 3, 1).contiguous()).abs() 124 | 125 | return im_tensor 126 | 127 | 128 | def normal_pdf(length, sensitivity): 129 | return np.exp(-sensitivity * (np.arange(length) - length / 2) ** 2) 130 | 131 | 132 | ''' 133 | modified from https://github.com/js3611/Deep-MRI-Reconstruction/blob/master/utils/compressed_sensing.py 134 | ''' 135 | 136 | 137 | def cartesian_mask(shape, acc, centred=False, 138 | sample_random=True): 139 | """ 140 | Sampling density estimated from implementation of kt FOCUSS 141 | 142 | shape: tuple - of form (..., nx, ny) 143 | acc: float - doesn't have to be integer 4, 8, etc.. 144 | centred: if False, return an uncentered (ifftshifted) mask 145 | sample_random: if True, generate a random mask 146 | """ 147 | shape = shape[:-2] + (shape[-1], shape[-2]) 148 | # only acc = 5 or 10 is supported for now; you can modify this yourself 149 | if acc == 5: 150 | center_fraction = 0.08 151 | elif acc == 10: 152 | center_fraction = 0.04 153 | 154 | N, Nx, Ny = int(np.prod(shape[:-2])), shape[-2], shape[-1] 155 | # sample_n: num of lines in low frequency to be sampled 156 | sample_n = int(round(Nx * center_fraction)) 157 | pdf_x = normal_pdf(Nx, 0.5 / (Nx / 10.) ** 2) 158 | lmda = Nx / (2. * acc) 159 | n_lines = int(Nx / acc) 160 | 161 | # add uniform distribution 162 | pdf_x += lmda * 1. / Nx 163 | 164 | if sample_n: 165 | pdf_x[Nx // 2 - sample_n // 2:Nx // 2 + sample_n - sample_n // 2] = 0 166 | pdf_x /= np.sum(pdf_x) 167 | n_lines -= sample_n 168 | 169 | mask = np.zeros((N, Nx)) 170 | ##################### modifications to enable random mask and fixed mask ######################### 171 | # set fixed seed 172 | if not sample_random: 173 | np.random.seed(233) 174 | ## set sampling lines 175 | for i in range(N): 176 | idx = np.random.choice(Nx, n_lines, False, pdf_x) 177 | mask[i, idx] = 1 178 | ## cancel seed when finish 179 | if not sample_random: 180 | t = 1000 * time.time() # current time in milliseconds 181 | np.random.seed(int(t) % 2 ** 32) 182 | ################################################################################################## 183 | 184 | if sample_n: 185 | mask[:, Nx // 2 - sample_n // 2:Nx // 2 + sample_n - sample_n // 2] = 1 186 | 187 | size = mask.itemsize 188 | mask = as_strided(mask, (N, Nx, Ny), (size * Nx, size, 0)) 189 | 190 | mask = mask.reshape(shape) 191 | 192 | if not centred: 193 | mask = np.fft.ifftshift(mask, axes=(-1, -2)) 194 | 195 | return mask.transpose((-1, -2)) 196 | 197 | 198 | ''' 199 | borrowed from https://github.com/MRIOSU/OCMR/blob/master/Python/read_ocmr.py 200 | ''' 201 | 202 | 203 | def read_ocmr(filename): 204 | # Before running the code, install ismrmrd-python and ismrmrd-python-tools: 205 | # https://github.com/ismrmrd/ismrmrd-python 206 | # https://github.com/ismrmrd/ismrmrd-python-tools 207 | # Last modified: 06-12-2020 by Chong Chen (Chong.Chen@osumc.edu) 208 | # 209 | # Input: *.h5 file name 210 | # Output: all_data k-space data, organized as {'kx' 'ky' 'kz' 'coil' 'phase' 'set' 'slice' 'rep' 'avg'} 211 | # param some parameters of the scan 212 | # 213 | 214 | # This is a function to read 
K-space from ISMRMRD *.h5 data 215 | # Modified by Chong Chen (Chong.Chen@osumc.edu) based on the python script 216 | # from https://github.com/ismrmrd/ismrmrd-python-tools/blob/master/recon_ismrmrd_dataset.py 217 | 218 | if not os.path.isfile(filename): 219 | print("%s is not a valid file" % filename) 220 | raise SystemExit 221 | dset = ismrmrd.Dataset(filename, 'dataset', create_if_needed=False) 222 | header = ismrmrd.xsd.CreateFromDocument(dset.read_xml_header()) 223 | enc = header.encoding[0] 224 | 225 | # Matrix size 226 | eNx = enc.encodedSpace.matrixSize.x 227 | # eNy = enc.encodedSpace.matrixSize.y 228 | eNz = enc.encodedSpace.matrixSize.z 229 | eNy = (enc.encodingLimits.kspace_encoding_step_1.maximum + 1); # no zero padding along Ny direction 230 | 231 | # Field of View 232 | eFOVx = enc.encodedSpace.fieldOfView_mm.x 233 | eFOVy = enc.encodedSpace.fieldOfView_mm.y 234 | eFOVz = enc.encodedSpace.fieldOfView_mm.z 235 | 236 | # Save the parameters 237 | param = dict(); 238 | param['TRes'] = str(header.sequenceParameters.TR) 239 | param['FOV'] = [eFOVx, eFOVy, eFOVz] 240 | param['TE'] = str(header.sequenceParameters.TE) 241 | param['TI'] = str(header.sequenceParameters.TI) 242 | param['echo_spacing'] = str(header.sequenceParameters.echo_spacing) 243 | param['flipAngle_deg'] = str(header.sequenceParameters.flipAngle_deg) 244 | param['sequence_type'] = header.sequenceParameters.sequence_type 245 | 246 | # Read number of Slices, Reps, Contrasts, etc. 247 | nCoils = header.acquisitionSystemInformation.receiverChannels 248 | try: 249 | nSlices = enc.encodingLimits.slice.maximum + 1 250 | except: 251 | nSlices = 1 252 | 253 | try: 254 | nReps = enc.encodingLimits.repetition.maximum + 1 255 | except: 256 | nReps = 1 257 | 258 | try: 259 | nPhases = enc.encodingLimits.phase.maximum + 1 260 | except: 261 | nPhases = 1; 262 | 263 | try: 264 | nSets = enc.encodingLimits.set.maximum + 1; 265 | except: 266 | nSets = 1; 267 | 268 | try: 269 | nAverage = enc.encodingLimits.average.maximum + 1; 270 | except: 271 | nAverage = 1; 272 | 273 | # TODO loop through the acquisitions looking for noise scans 274 | firstacq = 0 275 | for acqnum in range(dset.number_of_acquisitions()): 276 | acq = dset.read_acquisition(acqnum) 277 | 278 | # TODO: Currently ignoring noise scans 279 | if acq.isFlagSet(ismrmrd.ACQ_IS_NOISE_MEASUREMENT): 280 | # print("Found noise scan at acq ", acqnum) 281 | continue 282 | else: 283 | firstacq = acqnum 284 | # print("Imaging acquisition starts acq ", acqnum) 285 | break 286 | 287 | # asymmetric echo 288 | kx_prezp = 0; 289 | acq_first = dset.read_acquisition(firstacq) 290 | if acq_first.center_sample * 2 < eNx: 291 | kx_prezp = eNx - acq_first.number_of_samples 292 | 293 | # Initialize a storage array 294 | param['kspace_dim'] = {'kx ky kz coil phase set slice rep avg'}; 295 | all_data = np.zeros((eNx, eNy, eNz, nCoils, nPhases, nSets, nSlices, nReps, nAverage), dtype=np.complex64) 296 | 297 | # Loop through the rest of the acquisitions and stuff 298 | for acqnum in range(firstacq, dset.number_of_acquisitions()): 299 | acq = dset.read_acquisition(acqnum) 300 | 301 | # Stuff into the buffer 302 | y = acq.idx.kspace_encode_step_1 303 | z = acq.idx.kspace_encode_step_2 304 | phase = acq.idx.phase; 305 | set = acq.idx.set; 306 | slice = acq.idx.slice; 307 | rep = acq.idx.repetition; 308 | avg = acq.idx.average; 309 | all_data[kx_prezp:, y, z, :, phase, set, slice, rep, avg] = np.transpose(acq.data) 310 | 311 | return all_data, param 312 | 313 | 314 | if __name__ == '__main__': 
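## quick visual sanity check: build the fixed 5x uncentered mask used for validation/testing and plot it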
314 | if __name__ == '__main__': 315 | import matplotlib.pyplot as plt 316 | 317 | a = cartesian_mask((192, 160), acc=5, centred=False, sample_random=False) 318 | plt.imshow(a) 319 | plt.show() 320 | print(a.shape, a.dtype) 321 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | ## MS-SSIM loss is modified from https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py 5 | 6 | import warnings 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | def _fspecial_gauss_1d(size, sigma): 13 | r"""Create 1-D gauss kernel 14 | Args: 15 | size (int): the size of gauss kernel 16 | sigma (float): sigma of normal distribution 17 | 18 | Returns: 19 | torch.Tensor: 1D kernel (1 x 1 x size) 20 | """ 21 | coords = torch.arange(size).to(dtype=torch.float) 22 | coords -= size // 2 23 | 24 | g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) 25 | g /= g.sum() 26 | 27 | return g.unsqueeze(0).unsqueeze(0) 28 | 29 | 30 | def gaussian_filter(input, win): 31 | r""" Blur input with 1-D kernel 32 | Args: 33 | input (torch.Tensor): a batch of tensors to be blurred 34 | win (torch.Tensor): 1-D gauss kernel 35 | 36 | Returns: 37 | torch.Tensor: blurred tensors 38 | """ 39 | assert all([ws == 1 for ws in win.shape[1:-1]]), win.shape 40 | if len(input.shape) == 4: 41 | conv = F.conv2d 42 | elif len(input.shape) == 5: 43 | conv = F.conv3d 44 | else: 45 | raise NotImplementedError(input.shape) 46 | 47 | C = input.shape[1] 48 | out = input 49 | for i, s in enumerate(input.shape[2:]): 50 | if s >= win.shape[-1]: 51 | out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C) 52 | else: 53 | warnings.warn( 54 | f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}" 55 | ) 56 | 57 | return out 58 | 59 | 60 | def _ssim(X, Y, data_range, win, size_average=True, K=(0.01, 0.03)): 61 | r""" Calculate ssim index for X and Y 62 | 63 | Args: 64 | X (torch.Tensor): images 65 | Y (torch.Tensor): images 66 | win (torch.Tensor): 1-D gauss kernel 67 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 68 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 69 | 70 | Returns: 71 | torch.Tensor: ssim results. 72 | """
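# Editor's note: the block below implements the standard SSIM of Wang et al.
# (2004) from local Gaussian statistics mu/sigma:
#   cs_map   = (2*sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
#   ssim_map = ((2*mu1*mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
# with C1 = (K1*data_range)**2 and C2 = (K2*data_range)**2.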
73 | K1, K2 = K 74 | # batch, channel, [depth,] height, width = X.shape 75 | compensation = 1.0 76 | 77 | C1 = (K1 * data_range) ** 2 78 | C2 = (K2 * data_range) ** 2 79 | 80 | win = win.to(X.device, dtype=X.dtype) 81 | 82 | mu1 = gaussian_filter(X, win) 83 | mu2 = gaussian_filter(Y, win) 84 | 85 | mu1_sq = mu1.pow(2) 86 | mu2_sq = mu2.pow(2) 87 | mu1_mu2 = mu1 * mu2 88 | 89 | sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq) 90 | sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq) 91 | sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2) 92 | 93 | cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) # set alpha=beta=gamma=1 94 | ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map 95 | 96 | ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1) 97 | cs = torch.flatten(cs_map, 2).mean(-1) 98 | return ssim_per_channel, cs 99 | 100 | 101 | def ssim( 102 | X, 103 | Y, 104 | data_range=255, 105 | size_average=True, 106 | win_size=11, 107 | win_sigma=1.5, 108 | win=None, 109 | K=(0.01, 0.03), 110 | nonnegative_ssim=False, 111 | ): 112 | r""" interface of ssim 113 | Args: 114 | X (torch.Tensor): a batch of images, (N,C,H,W) 115 | Y (torch.Tensor): a batch of images, (N,C,H,W) 116 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 117 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 118 | win_size: (int, optional): the size of gauss kernel 119 | win_sigma: (float, optional): sigma of normal distribution 120 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 121 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results. 122 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu 123 | 124 | Returns: 125 | torch.Tensor: ssim results 126 | """ 127 | if not X.shape == Y.shape: 128 | raise ValueError("Input images should have the same dimensions.") 129 | 130 | for d in range(len(X.shape) - 1, 1, -1): 131 | X = X.squeeze(dim=d) 132 | Y = Y.squeeze(dim=d) 133 | 134 | if len(X.shape) not in (4, 5): 135 | raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 136 | 137 | if not X.type() == Y.type(): 138 | raise ValueError("Input images should have the same dtype.") 139 | 140 | if win is not None: # set win_size 141 | win_size = win.shape[-1] 142 | 143 | if not (win_size % 2 == 1): 144 | raise ValueError("Window size should be odd.") 145 | 146 | if win is None: 147 | win = _fspecial_gauss_1d(win_size, win_sigma) 148 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 149 | 150 | ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K) 151 | if nonnegative_ssim: 152 | ssim_per_channel = torch.relu(ssim_per_channel) 153 | 154 | if size_average: 155 | return ssim_per_channel.mean() 156 | else: 157 | return ssim_per_channel.mean(1) 158 | 159 | 
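# Editor's note: ms_ssim below follows the multi-scale SSIM of Wang et al.
# (2003). The contrast/structure term cs is collected at each of the 5 default
# scales (each scale is a 2x average-pooled version of the previous one), the
# full SSIM is evaluated only at the coarsest scale, and the final value is the
# weighted product prod_i(value_i ** weights[i]).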
160 | def ms_ssim( 161 | X, Y, data_range=255, size_average=True, win_size=11, win_sigma=1.5, win=None, weights=None, K=(0.01, 0.03) 162 | ): 163 | r""" interface of ms-ssim 164 | Args: 165 | X (torch.Tensor): a batch of images, (N,C,[T,]H,W) 166 | Y (torch.Tensor): a batch of images, (N,C,[T,]H,W) 167 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 168 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 169 | win_size: (int, optional): the size of gauss kernel 170 | win_sigma: (float, optional): sigma of normal distribution 171 | win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma 172 | weights (list, optional): weights for different levels 173 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results. 174 | Returns: 175 | torch.Tensor: ms-ssim results 176 | """ 177 | if not X.shape == Y.shape: 178 | raise ValueError("Input images should have the same dimensions.") 179 | 180 | for d in range(len(X.shape) - 1, 1, -1): 181 | X = X.squeeze(dim=d) 182 | Y = Y.squeeze(dim=d) 183 | 184 | if not X.type() == Y.type(): 185 | raise ValueError("Input images should have the same dtype.") 186 | 187 | if len(X.shape) == 4: 188 | avg_pool = F.avg_pool2d 189 | elif len(X.shape) == 5: 190 | avg_pool = F.avg_pool3d 191 | else: 192 | raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}") 193 | 194 | if win is not None: # set win_size 195 | win_size = win.shape[-1] 196 | 197 | if not (win_size % 2 == 1): 198 | raise ValueError("Window size should be odd.") 199 | 200 | smaller_side = min(X.shape[-2:]) 201 | assert smaller_side > (win_size - 1) * ( 202 | 2 ** 4 203 | ), "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4)) 204 | 205 | if weights is None: 206 | weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] 207 | weights = torch.FloatTensor(weights).to(X.device, dtype=X.dtype) 208 | 209 | if win is None: 210 | win = _fspecial_gauss_1d(win_size, win_sigma) 211 | win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1)) 212 | 213 | levels = weights.shape[0] 214 | mcs = [] 215 | for i in range(levels): 216 | ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K) 217 | 218 | if i < levels - 1: 219 | mcs.append(torch.relu(cs)) 220 | padding = [s % 2 for s in X.shape[2:]] 221 | X = avg_pool(X, kernel_size=2, padding=padding) 222 | Y = avg_pool(Y, kernel_size=2, padding=padding) 223 | 224 | ssim_per_channel = torch.relu(ssim_per_channel) # (batch, channel) 225 | mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0) # (level, batch, channel) 226 | ms_ssim_val = torch.prod(mcs_and_ssim ** weights.view(-1, 1, 1), dim=0) 227 | 228 | if size_average: 229 | return ms_ssim_val.mean() 230 | else: 231 | return ms_ssim_val.mean(1) 232 | 233 | 
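# Editor's note, an illustrative sketch of the functional interface (tensor
# sizes are arbitrary; with the default win_size=11, ms_ssim needs both
# spatial sides to exceed (11 - 1) * 2**4 = 160):
#
#   x = torch.rand(1, 1, 192, 192)
#   y = x.clamp(0.1, 0.9)
#   print(ssim(x, y, data_range=1.))     # scalar, since size_average=True
#   print(ms_ssim(x, y, data_range=1.))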
234 | class SSIM(torch.nn.Module): 235 | def __init__( 236 | self, 237 | size_average=True, 238 | win_size=11, 239 | win_sigma=1.5, 240 | channel=3, 241 | spatial_dims=2, 242 | K=(0.01, 0.03), 243 | nonnegative_ssim=False, 244 | ): 245 | r""" class for ssim 246 | Args: 247 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 248 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 249 | win_size: (int, optional): the size of gauss kernel 250 | win_sigma: (float, optional): sigma of normal distribution 251 | channel (int, optional): input channels (default: 3) 252 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results. 253 | nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu. 254 | """ 255 | 256 | super(SSIM, self).__init__() 257 | self.win_size = win_size 258 | self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims) 259 | self.size_average = size_average 260 | self.K = K 261 | self.nonnegative_ssim = nonnegative_ssim 262 | 263 | def forward(self, X, Y, data_range): 264 | return ssim( 265 | X, 266 | Y, 267 | data_range=data_range, 268 | size_average=self.size_average, 269 | win=self.win, 270 | K=self.K, 271 | nonnegative_ssim=self.nonnegative_ssim, 272 | ) 273 | 274 | 275 | class MS_SSIM(torch.nn.Module): 276 | def __init__( 277 | self, 278 | size_average=True, 279 | win_size=11, 280 | win_sigma=1.5, 281 | channel=3, 282 | spatial_dims=2, 283 | weights=None, 284 | K=(0.01, 0.03), 285 | ): 286 | r""" class for ms-ssim 287 | Args: 288 | data_range (float or int, optional): value range of input images. (usually 1.0 or 255) 289 | size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar 290 | win_size: (int, optional): the size of gauss kernel 291 | win_sigma: (float, optional): sigma of normal distribution 292 | channel (int, optional): input channels (default: 3) 293 | weights (list, optional): weights for different levels 294 | K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get negative or NaN results. 295 | """ 296 | 297 | super(MS_SSIM, self).__init__() 298 | self.win_size = win_size 299 | self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims) 300 | self.size_average = size_average 301 | self.weights = weights 302 | self.K = K 303 | 304 | def forward(self, X, Y, data_range=1.): 305 | return ms_ssim( 306 | X, 307 | Y, 308 | data_range=data_range, 309 | size_average=self.size_average, 310 | win=self.win, 311 | weights=self.weights, 312 | K=self.K, 313 | ) 314 | 315 | 316 | class CompoundLoss(nn.Module): 317 | 318 | def __init__(self, ssim_type='ssim'): 319 | super().__init__() 320 | self.l1loss = nn.L1Loss() 321 | if ssim_type == 'ssim': 322 | self.msssim = SSIM(win_size=7, size_average=True, channel=1, K=(0.01, 0.03)) 323 | elif ssim_type == 'ms-ssim': 324 | self.msssim = MS_SSIM(win_size=7, size_average=True, channel=1, K=(0.01, 0.03)) 325 | self.alpha = 0.84 # L1 / SSIM mixing weight, following Zhao et al., "Loss Functions for Image Restoration with Neural Networks" 326 | 327 | def forward(self, pred, target, data_range=1.): 328 | l1_loss = self.l1loss(pred, target) 329 | ssim_loss = 1 - self.msssim(pred.unsqueeze(1), target.unsqueeze(1), data_range) 330 | return (1 - self.alpha) * l1_loss + self.alpha * ssim_loss 331 | -------------------------------------------------------------------------------- /model/BasicModule.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | 4 | import torch 5 | import torch.nn as nn 6 | from collections import OrderedDict 7 | 8 | 9 | def conv_block(model_name='hqs-net', channel_in=22, n_convs=3, n_filters=32): 10 | ''' 11 | reconstruction blocks in DC-CNN; 12 | primal(image)-net blocks and dual(k)-space-net blocks in LPD-Net; 13 | regular cnn reconstruction blocks in HQS-Net 14 | :param model_name: 'dc-cnn', 'prim-net', 'dual-net', or 'hqs-net' 15 | :param channel_in: number of input channels 16 | :param n_filters: number of filters in the intermediate conv layers 17 | :param n_convs: total number of conv layers 18 | :return: nn.Sequential of the block's layers 19 | ''' 20 | layers = [] 21 | if model_name == 'dc-cnn': 22 | channel_out = channel_in 23 | elif model_name == 'prim-net' or model_name == 'hqs-net': 24 | channel_out = channel_in - 2 25 | elif model_name == 'dual-net': 26 | channel_out = channel_in - 4 27 | 
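# editor's note: the first (n_convs - 1) convs are each followed by a
# LeakyReLU; a final conv maps the features back to channel_out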
28 | for i in range(n_convs - 1): 29 | if i == 0: 30 | layers.append(nn.Conv2d(channel_in, n_filters, kernel_size=3, stride=1, padding=1)) 31 | else: 32 | layers.append(nn.Conv2d(n_filters, n_filters, kernel_size=3, stride=1, padding=1)) 33 | 34 | layers.append(nn.LeakyReLU(0.2, inplace=True)) 35 | layers.append(nn.Conv2d(n_filters, channel_out, kernel_size=3, stride=1, padding=1)) 36 | 37 | return nn.Sequential(*layers) 38 | 39 | 40 | ########################################## below is external code from other repos ####################################### 41 | 42 | ''' 43 | # -------------------------------------------- 44 | # Advanced nn.Sequential 45 | # https://github.com/xinntao/BasicSR 46 | # -------------------------------------------- 47 | ''' 48 | 49 | 50 | def sequential(*args): 51 | """Advanced nn.Sequential. 52 | 53 | Args: 54 | nn.Sequential, nn.Module 55 | 56 | Returns: 57 | nn.Sequential 58 | """ 59 | if len(args) == 1: 60 | if isinstance(args[0], OrderedDict): 61 | raise NotImplementedError('sequential does not support OrderedDict input.') 62 | return args[0] # No sequential is needed. 63 | modules = [] 64 | for module in args: 65 | if isinstance(module, nn.Sequential): 66 | for submodule in module.children(): 67 | modules.append(submodule) 68 | elif isinstance(module, nn.Module): 69 | modules.append(module) 70 | return nn.Sequential(*modules) 71 | 72 | 73 | ''' 74 | # -------------------------------------------- 75 | # Useful blocks 76 | # https://github.com/xinntao/BasicSR 77 | # -------------------------------------------- 78 | # conv + normalization + relu (conv) 79 | # resblock (ResBlock) 80 | # -------------------------------------------- 81 | ''' 82 | 83 | 84 | # -------------------------------------------- 85 | # return nn.Sequential of (Conv + BN + ReLU) 86 | # -------------------------------------------- 87 | def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CBR', 88 | negative_slope=0.2): 89 | L = [] 90 | for t in mode: 91 | if t == 'C': 92 | L.append( 93 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 94 | padding=padding, bias=bias)) 95 | elif t == 'T': 96 | L.append(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, 97 | stride=stride, padding=padding, bias=bias)) 98 | elif t == 'B': 99 | L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True)) 100 | elif t == 'I': 101 | L.append(nn.InstanceNorm2d(out_channels, affine=True)) 102 | elif t == 'R': 103 | L.append(nn.ReLU(inplace=True)) 104 | elif t == 'r': 105 | L.append(nn.ReLU(inplace=False)) 106 | elif t == 'L': 107 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True)) 108 | elif t == 'l': 109 | L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False)) 110 | elif t == '2': 111 | L.append(nn.PixelShuffle(upscale_factor=2)) 112 | elif t == '3': 113 | L.append(nn.PixelShuffle(upscale_factor=3)) 114 | elif t == '4': 115 | L.append(nn.PixelShuffle(upscale_factor=4)) 116 | elif t == 'U': 117 | L.append(nn.Upsample(scale_factor=2, mode='nearest')) 118 | elif t == 'u': 119 | L.append(nn.Upsample(scale_factor=3, mode='nearest')) 120 | elif t == 'v': 121 | L.append(nn.Upsample(scale_factor=4, mode='nearest')) 122 | elif t == 'M': 123 | L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0)) 124 | elif t == 'A': 125 | L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0)) 126 | else: 127 | raise NotImplementedError('Undefined type: {:s}'.format(t)) 128 | return sequential(*L) 129 | 130 | 
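# Editor's note, illustrative examples of the mode-string builder above:
#   conv(2, 64, mode='CR')     -> Conv2d(2, 64, 3) + ReLU
#   conv(64, 64, mode='CBR')   -> Conv2d + BatchNorm2d + ReLU
#   conv(64, 256, mode='C2R')  -> Conv2d + PixelShuffle(2) + ReLU (2x upsampling)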
131 | # -------------------------------------------- 132 | # Res Block: x + conv(relu(conv(x))) 133 | # -------------------------------------------- 134 | class ResBlock(nn.Module): 135 | def __init__(self, in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='CRC', 136 | negative_slope=0.2): 137 | super(ResBlock, self).__init__() 138 | 139 | assert in_channels == out_channels, 'Only support in_channels==out_channels.' 140 | if mode[0] in ['R', 'L']: 141 | mode = mode[0].lower() + mode[1:] 142 | 143 | self.res = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 144 | 145 | def forward(self, x): 146 | # res = self.res(x) 147 | return x + self.res(x) 148 | 149 | 150 | """ 151 | # -------------------------------------------- 152 | # Upsampler 153 | # Kai Zhang, https://github.com/cszn/KAIR 154 | # -------------------------------------------- 155 | # upsample_pixelshuffle 156 | # upsample_upconv 157 | # upsample_convtranspose 158 | # -------------------------------------------- 159 | """ 160 | 161 | 162 | # -------------------------------------------- 163 | # conv + subp (+ relu) 164 | # -------------------------------------------- 165 | def upsample_pixelshuffle(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', 166 | negative_slope=0.2): 167 | assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' 168 | up1 = conv(in_channels, out_channels * (int(mode[0]) ** 2), kernel_size, stride, padding, bias, mode='C' + mode, 169 | negative_slope=negative_slope) 170 | return up1 171 | 172 | 173 | # -------------------------------------------- 174 | # nearest_upsample + conv (+ R) 175 | # -------------------------------------------- 176 | def upsample_upconv(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', 177 | negative_slope=0.2): 178 | assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR' 179 | if mode[0] == '2': 180 | uc = 'UC' 181 | elif mode[0] == '3': 182 | uc = 'uC' 183 | elif mode[0] == '4': 184 | uc = 'vC' 185 | mode = mode.replace(mode[0], uc) 186 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode, negative_slope=negative_slope) 187 | return up1 188 | 189 | 190 | # -------------------------------------------- 191 | # convTranspose (+ relu) 192 | # -------------------------------------------- 193 | def upsample_convtranspose(in_channels=64, out_channels=3, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', 194 | negative_slope=0.2): 195 | assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' 
196 | kernel_size = int(mode[0]) 197 | stride = int(mode[0]) 198 | mode = mode.replace(mode[0], 'T') 199 | up1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 200 | return up1 201 | 202 | 203 | ''' 204 | # -------------------------------------------- 205 | # Downsampler 206 | # Kai Zhang, https://github.com/cszn/KAIR 207 | # -------------------------------------------- 208 | # downsample_strideconv 209 | # downsample_maxpool 210 | # downsample_avgpool 211 | # -------------------------------------------- 212 | ''' 213 | 214 | 215 | # -------------------------------------------- 216 | # strideconv (+ relu) 217 | # -------------------------------------------- 218 | def downsample_strideconv(in_channels=64, out_channels=64, kernel_size=2, stride=2, padding=0, bias=True, mode='2R', 219 | negative_slope=0.2): 220 | assert len(mode) < 4 and mode[0] in ['2', '3', '4'], 'mode examples: 2, 2R, 2BR, 3, ..., 4BR.' 221 | kernel_size = int(mode[0]) 222 | stride = int(mode[0]) 223 | mode = mode.replace(mode[0], 'C') 224 | down1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode, negative_slope) 225 | return down1 226 | 227 | 228 | # -------------------------------------------- 229 | # maxpooling + conv (+ relu) 230 | # -------------------------------------------- 231 | def downsample_maxpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=True, mode='2R', 232 | negative_slope=0.2): 233 | assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' 234 | kernel_size_pool = int(mode[0]) 235 | stride_pool = int(mode[0]) 236 | mode = mode.replace(mode[0], 'MC') 237 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) 238 | pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], 239 | negative_slope=negative_slope) 240 | return sequential(pool, pool_tail) 241 | 242 | 243 | # -------------------------------------------- 244 | # averagepooling + conv (+ relu) 245 | # -------------------------------------------- 246 | def downsample_avgpool(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True, mode='2R', 247 | negative_slope=0.2): 248 | assert len(mode) < 4 and mode[0] in ['2', '3'], 'mode examples: 2, 2R, 2BR, 3, ..., 3BR.' 
249 | kernel_size_pool = int(mode[0]) 250 | stride_pool = int(mode[0]) 251 | mode = mode.replace(mode[0], 'AC') 252 | pool = conv(kernel_size=kernel_size_pool, stride=stride_pool, mode=mode[0], negative_slope=negative_slope) 253 | pool_tail = conv(in_channels, out_channels, kernel_size, stride, padding, bias, mode=mode[1:], 254 | negative_slope=negative_slope) 255 | return sequential(pool, pool_tail) 256 | 257 | 258 | ## UnetRes is taken from https://github.com/cszn/DPIR/blob/master/models/network_unet.py 259 | ## used as modified Unet reconstruction blocks in HQS-Net 260 | class UNetRes(nn.Module): 261 | def __init__(self, in_nc=1, out_nc=1, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', 262 | upsample_mode='convtranspose'): 263 | super(UNetRes, self).__init__() 264 | 265 | self.m_head = conv(in_nc, nc[0], bias=False, mode='C') 266 | 267 | # downsample 268 | if downsample_mode == 'avgpool': 269 | downsample_block = downsample_avgpool 270 | elif downsample_mode == 'maxpool': 271 | downsample_block = downsample_maxpool 272 | elif downsample_mode == 'strideconv': 273 | downsample_block = downsample_strideconv 274 | else: 275 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 276 | 277 | self.m_down1 = sequential( 278 | *[ResBlock(nc[0], nc[0], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)], 279 | downsample_block(nc[0], nc[1], bias=False, mode='2')) 280 | self.m_down2 = sequential( 281 | *[ResBlock(nc[1], nc[1], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)], 282 | downsample_block(nc[1], nc[2], bias=False, mode='2')) 283 | self.m_down3 = sequential( 284 | *[ResBlock(nc[2], nc[2], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)], 285 | downsample_block(nc[2], nc[3], bias=False, mode='2')) 286 | 287 | self.m_body = sequential( 288 | *[ResBlock(nc[3], nc[3], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)]) 289 | 290 | # upsample 291 | if upsample_mode == 'upconv': 292 | upsample_block = upsample_upconv 293 | elif upsample_mode == 'pixelshuffle': 294 | upsample_block = upsample_pixelshuffle 295 | elif upsample_mode == 'convtranspose': 296 | upsample_block = upsample_convtranspose 297 | else: 298 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 299 | 300 | self.m_up3 = sequential(upsample_block(nc[3], nc[2], bias=False, mode='2'), 301 | *[ResBlock(nc[2], nc[2], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)]) 302 | self.m_up2 = sequential(upsample_block(nc[2], nc[1], bias=False, mode='2'), 303 | *[ResBlock(nc[1], nc[1], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)]) 304 | self.m_up1 = sequential(upsample_block(nc[1], nc[0], bias=False, mode='2'), 305 | *[ResBlock(nc[0], nc[0], bias=False, mode='C' + act_mode + 'C') for _ in range(nb)]) 306 | 307 | self.m_tail = conv(nc[0], out_nc, bias=False, mode='C') 308 | 309 | def forward(self, x0): 310 | x1 = self.m_head(x0) 311 | x2 = self.m_down1(x1) 312 | x3 = self.m_down2(x2) 313 | x4 = self.m_down3(x3) 314 | x = self.m_body(x4) 315 | x = self.m_up3(x + x4) 316 | x = self.m_up2(x + x3) 317 | x = self.m_up1(x + x2) 318 | x = self.m_tail(x + x1) 319 | 320 | return x 321 | -------------------------------------------------------------------------------- /Solver.py: -------------------------------------------------------------------------------- 1 | # @author : Bingyu Xin 2 | # @Institute : CS@Rutgers 3 | import os 4 | from os.path import join 5 | import time 6 | from tqdm import tqdm 7 
| import torch 8 | import torch.optim as optim 9 | import torch.utils.data as Data 10 | from tensorboardX import SummaryWriter 11 | from skimage.metrics import structural_similarity as cal_ssim 12 | from skimage.metrics import peak_signal_noise_ratio as cal_psnr 13 | from skimage.metrics import normalized_root_mse as cal_nrmse 14 | 15 | from loss import CompoundLoss 16 | from utils import output2complex 17 | from read_data import MyData 18 | from model.DCCNN import DCCNN 19 | from model.LPDNet import LPDNet 20 | from model.HQSNet import HQSNet 21 | from model.ISTANet_plus import ISTANetplus 22 | import numpy as np 23 | 24 | 25 | class Solver(): 26 | def __init__(self, args): 27 | torch.autograd.set_detect_anomaly(True) 28 | self.args = args 29 | ################ experiment settings ################ 30 | self.model_name = self.args.model 31 | self.acc = self.args.acc 32 | self.imageDir_train = self.args.train_path # train path 33 | self.imageDir_val = self.args.val_path # val path while training 34 | self.imageDir_test = self.args.test_path # test path 35 | self.num_epoch = self.args.num_epoch # training epochs 36 | self.batch_size = self.args.batch_size # batch size 37 | self.val_on_epochs = self.args.val_on_epochs # validate every val_on_epochs epochs 38 | self.resume = self.args.resume # resume training 39 | ## settings for optimizer 40 | self.lr = self.args.lr 41 | ## settings for data preprocessing 42 | self.img_size = (192, 160) 43 | self.saveDir = 'weight' # model save path while training 44 | if not os.path.isdir(self.saveDir): 45 | os.makedirs(self.saveDir) 46 | 47 | self.task_name = self.model_name + '_acc_' + str(self.acc) + '_bs_' + str(self.batch_size) \ 48 | + '_lr_' + str(self.lr) 49 | print('task_name: ', self.task_name) 50 | self.model_path = 'weight/' + self.task_name + '_best.pth' # model load path for test 51 | 52 | ############################################ Specify network ############################################ 53 | if self.model_name == 'dc-cnn': 54 | self.net = DCCNN(n_iter=8) 55 | elif self.model_name == 'ista-net-plus': 56 | self.net = ISTANetplus(n_iter=8) 57 | elif self.model_name == 'lpd-net': 58 | self.net = LPDNet(n_iter=8) 59 | elif self.model_name == 'hqs-net': 60 | self.net = HQSNet(block_type='cnn', buffer_size=5, n_iter=8) 61 | elif self.model_name == 'hqs-net-unet': 62 | # HQS-Net-Unet targets the best reconstruction quality, so the model is enlarged; this is not a fair comparison with the other models 63 | self.net = HQSNet(block_type='unet', buffer_size=5, n_iter=10) 64 | else: 65 | raise ValueError('wrong model name: %s' % self.model_name) # note: a bare assert on a non-empty string would never fail 66 | print('Total # of model params: %.5fM' % (sum(p.numel() for p in self.net.parameters()) / 10. ** 6)) 67 | self.net.cuda() 68 | 69 | def train(self): 70 | 71 | ############################################ Specify loss ############################################ 72 | ## Notice: 73 | ## there is an unknown backward gradient bug when training HQS-Net-Unet, which may interrupt the training; 74 | ## you can simply resume the training by passing --resume 1 in the scripts. 
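# editor's note: CompoundLoss (defined in loss.py) mixes the two terms as
#   loss = (1 - alpha) * L1 + alpha * (1 - (MS-)SSIM), with alpha = 0.84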
75 | self.criterion = CompoundLoss('ms-ssim') 76 | 77 | ############################################ Specify optimizer ######################################## 78 | 79 | self.optimizer_G = optim.Adam(self.net.parameters(), lr=self.lr, eps=1e-3, weight_decay=1e-10) 80 | 81 | ############################################ load data ############################################ 82 | 83 | dataset_train = MyData(self.imageDir_train, self.acc, self.img_size, is_training='train') 84 | dataset_val = MyData(self.imageDir_val, self.acc, self.img_size, is_training='val') 85 | 86 | num_workers = 4 87 | use_pin_memory = True 88 | loader_train = Data.DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=True, 89 | num_workers=num_workers, pin_memory=use_pin_memory) 90 | loader_val = Data.DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False, drop_last=False, 91 | num_workers=num_workers, pin_memory=use_pin_memory) 92 | self.slices_val = len(dataset_val) 93 | print("slices of 2d train data: ", len(dataset_train)) 94 | print("slices of 2d validation data: ", len(dataset_val)) 95 | 96 | ############################################ setting for tensorboard ################################### 97 | self.writer = SummaryWriter('log/' + self.task_name) 98 | 99 | ############################################ start to run epochs ####################################### 100 | 101 | start_epoch = 0 102 | best_val_psnr = 0 103 | if self.resume: 104 | best_name = self.task_name + '_best.pth' 105 | checkpoint = torch.load(join(self.saveDir, best_name)) 106 | self.net.load_state_dict(checkpoint['net']) 107 | start_epoch = checkpoint['epoch'] + 1 108 | best_val_psnr = checkpoint['val_psnr'] 109 | print('loaded pretrained model, starting at epoch ', start_epoch, ', best val PSNR: ', best_val_psnr) 110 | for epoch in range(start_epoch, self.num_epoch): 111 | ####################### 1. training ####################### 112 | 113 | loss_g = self._train_cnn(loader_train) 114 | ####################### 2. validate ####################### 115 | if epoch == start_epoch: 116 | base_psnr, base_ssim = self._validate_base(loader_val) 117 | if epoch % self.val_on_epochs == 0: 118 | val_psnr, val_ssim = self._validate(loader_val) 119 | ########################## 3. 
print and tensorboard ######################## 120 | print("Epoch {}/{}".format(epoch + 1, self.num_epoch)) 121 | print(" base PSNR:\t\t{:.6f}".format(base_psnr)) 122 | print(" test PSNR:\t\t{:.6f}".format(val_psnr)) 123 | print(" base SSIM:\t\t{:.6f}".format(base_ssim)) 124 | print(" test SSIM:\t\t{:.6f}".format(val_ssim)) 125 | ## write to tensorboard 126 | self.writer.add_scalar("loss/train_loss", loss_g, epoch) 127 | self.writer.add_scalar("metric/base_psnr", base_psnr, epoch) 128 | self.writer.add_scalar("metric/val_psnr", val_psnr, epoch) 129 | self.writer.add_scalar("metric/base_ssim", base_ssim, epoch) 130 | self.writer.add_scalar("metric/val_ssim", val_ssim, epoch) 131 | ## save the best model according to validation psnr 132 | if best_val_psnr < val_psnr: 133 | best_val_psnr = val_psnr 134 | best_name = self.task_name + '_best.pth' 135 | state = {'net': self.net.state_dict(), 'epoch': epoch, 'val_psnr': val_psnr, 'val_ssim': val_ssim} 136 | torch.save(state, join(self.saveDir, best_name)) 137 | self.writer.close() 138 | 139 | def test(self): 140 | 141 | ############################################ load data ################################ 142 | 143 | dataset_val = MyData(self.imageDir_test, self.acc, self.img_size, is_training='test') 144 | 145 | loader_val = Data.DataLoader(dataset_val, batch_size=self.batch_size, shuffle=False, drop_last=False, 146 | num_workers=2, pin_memory=False) 147 | len_data = len(dataset_val) 148 | print("slices of 2d test data: ", len_data) 149 | checkpoint = torch.load(self.model_path) 150 | 151 | print("best epoch at : {}, val_psnr: {:.6f}, val_ssim: {:.6f}" \ 152 | .format(checkpoint['epoch'], checkpoint['val_psnr'], checkpoint['val_ssim'])) 153 | 154 | self.net.load_state_dict(checkpoint['net']) 155 | self.net.cuda() 156 | self.net.eval() 157 | 158 | base_psnr = [] 159 | test_psnr = [] 160 | base_ssim = [] 161 | test_ssim = [] 162 | base_nrmse = [] 163 | test_nrmse = [] 164 | with torch.no_grad(): 165 | time_0 = time.time() 166 | for data_dict in tqdm(loader_val): 167 | im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict['im_A_und'].float().cuda(), \ 168 | data_dict['k_A_und'].float().cuda(), \ 169 | data_dict['mask_A'].float().cuda() 170 | T1 = self.net(im_A_und, k_A_und, mask) 171 | ############## convert model output to complex value in the original range 172 | 173 | T1 = output2complex(T1) 174 | im_A = output2complex(im_A) 175 | im_A_und = output2complex(im_A_und) 176 | 177 | ########################### calculate metrics ################################### 178 | for T1_i, im_A_i, im_A_und_i in zip(T1.cpu().numpy(), im_A.cpu().numpy(), im_A_und.cpu().numpy()): 179 | ## for skimage.metrics, input is (im_true,im_pred) 180 | base_nrmse.append(cal_nrmse(im_A_i, im_A_und_i)) 181 | test_nrmse.append(cal_nrmse(im_A_i, T1_i)) 182 | base_ssim.append(cal_ssim(im_A_i, im_A_und_i, data_range=im_A_i.max())) 183 | test_ssim.append(cal_ssim(im_A_i, T1_i, data_range=im_A_i.max())) 184 | base_psnr.append(cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max())) 185 | test_psnr.append(cal_psnr(im_A_i, T1_i, data_range=im_A_i.max())) 186 | 187 | time_1 = time.time() 188 | ## comment out the metric calculation above for a more precise inference-speed measurement 189 | print('inference speed: {:.5f} ms/slice'.format(1000 * (time_1 - time_0) / len_data)) 190 | 191 | print(" base PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_psnr), np.std(base_psnr))) 192 | print(" test PSNR:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_psnr), np.std(test_psnr))) 193 | print(" base 
SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_ssim), np.std(base_ssim))) 194 | print(" test SSIM:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_ssim), np.std(test_ssim))) 195 | print(" base NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(base_nrmse), np.std(base_nrmse))) 196 | print(" test NRMSE:\t\t{:.6f}, std: {:.6f}".format(np.mean(test_nrmse), np.std(test_nrmse))) 197 | 198 | def _train_cnn(self, loader_train): 199 | self.net.train() 200 | for data_dict in tqdm(loader_train): 201 | im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[ 202 | 'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict['mask_A'].float().cuda() 203 | if self.model_name == 'ista-net-plus': 204 | T1, loss_layers_sym = self.net(im_A_und, k_A_und, mask) 205 | else: 206 | T1 = self.net(im_A_und, k_A_und, mask) 207 | 208 | T1 = output2complex(T1) 209 | im_A = output2complex(im_A) 210 | ############################################# 1.2 update generator ############################################# 211 | 212 | loss_g = self.criterion(T1, im_A, data_range=im_A.max()) 213 | if self.model_name == 'ista-net-plus': 214 | loss_constraint = torch.mean(torch.pow(loss_layers_sym[0], 2)) 215 | for k in range(len(loss_layers_sym) - 1): 216 | loss_constraint += torch.mean(torch.pow(loss_layers_sym[k + 1], 2)) 217 | loss_g = loss_g + 0.01 * loss_constraint 218 | 219 | self.optimizer_G.zero_grad() 220 | loss_g.backward() 221 | self.optimizer_G.step() 222 | 223 | return loss_g 224 | 225 | def _validate_base(self, loader_val): 226 | 227 | base_psnr = 0 228 | base_ssim = 0 229 | 230 | for data_dict in loader_val: 231 | im_A, im_A_und, = data_dict['im_A'].float().cuda(), data_dict['im_A_und'].float().cuda() 232 | ############## convert tensors to complex magnitude in the original range 233 | im_A = output2complex(im_A) 234 | im_A_und = output2complex(im_A_und) 235 | ########################### calculate metrics ################################### 236 | for im_A_i, im_A_und_i in zip(im_A.cpu().numpy(), 237 | im_A_und.cpu().numpy()): 238 | ## for skimage.metrics, input is (im_true,im_pred) 239 | base_ssim += cal_ssim(im_A_i, im_A_und_i, data_range=im_A_i.max()) 240 | base_psnr += cal_psnr(im_A_i, im_A_und_i, data_range=im_A_i.max()) 241 | base_psnr /= self.slices_val 242 | base_ssim /= self.slices_val 243 | return base_psnr, base_ssim 244 | 245 | def _validate(self, loader_val): 246 | 247 | test_psnr = 0 248 | test_ssim = 0 249 | 250 | self.net.eval() 251 | with torch.no_grad(): 252 | for data_dict in tqdm(loader_val): 253 | 254 | im_A, im_A_und, k_A_und, mask = data_dict['im_A'].float().cuda(), data_dict[ 255 | 'im_A_und'].float().cuda(), data_dict['k_A_und'].float().cuda(), data_dict[ 256 | 'mask_A'].float().cuda() 257 | T1 = self.net(im_A_und, k_A_und, mask) 258 | ############## convert model output to complex value in the original range 259 | T1 = output2complex(T1) 260 | im_A = output2complex(im_A) 261 | 262 | ########################### calculate metrics ################################### 263 | for T1_i, im_A_i in zip(T1.cpu().numpy(), im_A.cpu().numpy()): 264 | ## for skimage.metrics, input is (im_true,im_pred) 265 | test_ssim += cal_ssim(im_A_i, T1_i, data_range=im_A_i.max()) 266 | test_psnr += cal_psnr(im_A_i, T1_i, data_range=im_A_i.max()) 267 | 268 | test_psnr /= self.slices_val 269 | test_ssim /= self.slices_val 270 | return test_psnr, test_ssim 271 | --------------------------------------------------------------------------------
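# ------------------------------------------------------------------------------
# Editor's note, an illustrative usage sketch of Solver. The attribute names
# are the ones Solver.__init__ reads; the concrete values are hypothetical
# (the real values come from the project's argument parsing; see main.py and
# the scripts under run_sh/):
#
#   from argparse import Namespace
#   from Solver import Solver
#   args = Namespace(model='hqs-net', acc=5, train_path='data/train',
#                    val_path='data/val', test_path='data/test',
#                    num_epoch=50, batch_size=1, val_on_epochs=1,
#                    resume=0, lr=1e-3)
#   Solver(args).train()   # or Solver(args).test()
# ------------------------------------------------------------------------------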