├── hemorrhage
│   ├── dataset.py
│   ├── test.zip
│   ├── test.txt
│   ├── README.md
│   └── config.py
├── unet
│   ├── train.sh
│   ├── test.sh
│   ├── README.md
│   ├── args.py
│   ├── unet_model.py
│   └── main.py
├── MICCAI
│   ├── mask
│   │   ├── Poisson2D
│   │   │   ├── PoissonDistributionMask_1.mat
│   │   │   ├── PoissonDistributionMask_10.mat
│   │   │   ├── PoissonDistributionMask_20.mat
│   │   │   ├── PoissonDistributionMask_30.mat
│   │   │   ├── PoissonDistributionMask_40.mat
│   │   │   ├── PoissonDistributionMask_5.mat
│   │   │   └── PoissonDistributionMask_50.mat
│   │   ├── Gaussian1D
│   │   │   ├── GaussianDistribution1DMask_1.mat
│   │   │   ├── GaussianDistribution1DMask_5.mat
│   │   │   ├── GaussianDistribution1DMask_10.mat
│   │   │   ├── GaussianDistribution1DMask_20.mat
│   │   │   ├── GaussianDistribution1DMask_30.mat
│   │   │   ├── GaussianDistribution1DMask_40.mat
│   │   │   └── GaussianDistribution1DMask_50.mat
│   │   └── Gaussian2D
│   │       ├── GaussianDistribution2DMask_1.mat
│   │       ├── GaussianDistribution2DMask_5.mat
│   │       ├── GaussianDistribution2DMask_10.mat
│   │       ├── GaussianDistribution2DMask_20.mat
│   │       ├── GaussianDistribution2DMask_30.mat
│   │       ├── GaussianDistribution2DMask_40.mat
│   │       └── GaussianDistribution2DMask_50.mat
│   ├── README.md
│   ├── DAGAN_Testing_datasets.txt
│   ├── config.py
│   ├── dataset.py
│   ├── DAGAN_Training_datasets.txt
│   └── prepare_data.py
├── .gitignore
├── fastMRI
│   ├── README.md
│   ├── config.py
│   ├── mri_data.py
│   ├── subsample.py
│   ├── dataset.py
│   └── transforms.py
├── README.md
├── model.py
├── utils.py
├── env.py
├── pixel_wise_a2c.py
├── test.py
└── train.py

/hemorrhage/dataset.py:
--------------------------------------------------------------------------------
1 | ../MICCAI/dataset.py
--------------------------------------------------------------------------------
/hemorrhage/test.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/hemorrhage/test.zip
--------------------------------------------------------------------------------
/unet/train.sh:
--------------------------------------------------------------------------------
1 | #CUDA_VISIBLE_DEVICES=1
2 | python main.py --batch-size=16 --challenge singlecoil --data-path None --data-parallel
3 | 
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_1.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_10.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_10.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_20.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_20.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_30.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_30.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_40.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_40.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_5.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_5.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Poisson2D/PoissonDistributionMask_50.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Poisson2D/PoissonDistributionMask_50.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_1.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_5.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_5.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_1.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_1.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_5.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_5.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_10.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_10.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_20.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_20.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_30.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_30.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_40.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_40.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_50.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_50.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_10.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_10.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_20.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_20.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_30.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_30.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_40.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_40.mat
--------------------------------------------------------------------------------
/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_50.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wentianli/MRI_RL/HEAD/MICCAI/mask/Gaussian2D/GaussianDistribution2DMask_50.mat
--------------------------------------------------------------------------------
/unet/test.sh:
--------------------------------------------------------------------------------
1 | #CUDA_VISIBLE_DEVICES=1
2 | python main.py --batch-size 10 --resume --checkpoint checkpoints/model.pt --challenge singlecoil --data-path None --data-parallel --test
3 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | */__pycache__/
3 | MICCAI/data
4 | *.jpg
5 | *.png
6 | *.bmp
7 | *.log
8 | model/
9 | logs/
10 | *.swp
11 | unet/checkpoints/
12 | *.py_
13 | *.pth
14 | *.pt
15 | 
--------------------------------------------------------------------------------
/hemorrhage/test.txt:
--------------------------------------------------------------------------------
1 | 1.jpg
2 | 2.jpg
3 | 3.jpg
4 | 4.jpg
5 | 5.jpg
6 | 6.jpg
7 | 7.jpg
8 | 8.jpg
9 | 9.jpg
10 | 10.jpg
11 | 11.jpg
12 | 12.jpg
13 | 13.jpg
14 | 14.jpg
15 | 15.jpg
16 | 16.jpg
17 | 17.jpg
18 | 18.jpg
19 | 19.jpg
20 | 20.jpg
21 | 21.jpg
22 | 22.jpg
23 | 23.jpg
24 | 24.jpg
25 | 
--------------------------------------------------------------------------------
/MICCAI/README.md:
--------------------------------------------------------------------------------
1 | ### Notes
2 | 
3 | The same `.nii.gz` files and the same random seed as in
4 | [DAGAN](https://github.com/nebulaV/DAGAN/blob/master/data/MICCAI13_SegChallenge/dataset_name_list.txt) are used to select the training and test images.
5 | 
6 | However, I found that the number of training images,
7 | as well as the PSNR and NMSE of the zero-filled test images,
8 | are slightly different from what was stated in the [DAGAN paper](https://ieeexplore.ieee.org/abstract/document/8233175).
9 | 
--------------------------------------------------------------------------------
/fastMRI/README.md:
--------------------------------------------------------------------------------
1 | ### Notes
2 | 
3 | Different from the [fastMRI official code](https://github.com/facebookresearch/fastMRI/blob/38bbfe2117905f5a246714739e7d6bedbdaba649/models/unet/train_unet.py),
4 | we do down-sampling after cropping the central region of the image, which allows a data-consistency step.
5 | 
6 | The original normalization of the data `x` is
7 | ```
8 | x = (x - x.mean()) / (x.std() + eps)
9 | x = x.clip(-6, 6)
10 | ```
11 | After this, we scale the image into the range [0, 1]:
12 | ```
13 | m = min(float(x.min()), 0)
14 | x = (x - m) / (6 - m)
15 | ```
16 | 
--------------------------------------------------------------------------------
/unet/README.md:
--------------------------------------------------------------------------------
1 | ### Notes
2 | 
3 | This implementation of Unet is based on the [fastMRI official code](https://github.com/facebookresearch/fastMRI/tree/master/models/unet).
4 | 
5 | The `F.interpolate` function was replaced with `F.upsample` for compatibility with older PyTorch versions.
6 | The `Dataset` class is modified so that down-sampling happens after cropping.
7 | Other settings are almost identical to the original implementation.
8 | 
9 | For training, run
10 | ```
11 | sh train.sh
12 | ```
13 | For testing, run
14 | ```
15 | sh test.sh
16 | ```
17 | During testing, the output is scaled into the range [0, 255] to make a fair comparison with our proposed model.
18 | 
19 | A strange `segmentation fault` can occur, especially when testing the model during training; I haven't figured out why.
20 | 
--------------------------------------------------------------------------------
/hemorrhage/README.md:
--------------------------------------------------------------------------------
1 | ### Notes
2 | To test the model on a custom dataset, a `.txt` file listing all the images is needed (see the helper sketch after these notes).
3 | 
4 | Put all images inside `test/`.
5 | 
6 | To test on this dataset of axial T1 images exhibiting subacute intraparenchymal hemorrhage, please unzip `test.zip` and run
7 | ```
8 | python test.py --dataset hemorrhage --model path_to_the_model
9 | ```
10 | 
11 | 
12 | The images are downloaded from
13 | [Radiopaedia.org](https://radiopaedia.org/cases/early-and-late-subacute-intracerebral-haemorrhage-on-mri-and-ct?lang=us) (rID: 55641).
14 | 
15 | We use the model trained on MICCAI with the 30% mask.
16 | Our model is able to reconstruct (improving PSNR and SSIM) all the images,
17 | and the regions exhibiting pathology are successfully restored,
18 | although the aliasing cannot be fully removed
19 | (partly due to the domain gap between the training data and the test images).
20 | 
21 | This is an example of testing on an out-of-distribution dataset. The output action distribution is very different from that on MICCAI.
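For a custom dataset, the list file can be generated with a short helper. The snippet below is a hypothetical sketch, not a script shipped in this repo; it assumes the layout used here (images under `hemorrhage/test/`, list written to `hemorrhage/test.txt`, run from the repository root):

```python
# Hypothetical helper: write one image name per line, as MICCAI/dataset.py expects.
import os

image_dir = 'hemorrhage/test'
names = sorted(f for f in os.listdir(image_dir)
               if f.lower().endswith(('.jpg', '.png', '.bmp')))
with open('hemorrhage/test.txt', 'w') as fw:
    fw.write('\n'.join(names) + '\n')
```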
22 | -------------------------------------------------------------------------------- /MICCAI/DAGAN_Testing_datasets.txt: -------------------------------------------------------------------------------- 1 | 1006_3x1104_3Warped.nii.gz 1007_3x1004_3Warped.nii.gz 1007_3x1019_3Warped.nii.gz 1007_3x1119_3Warped.nii.gz 2 | 1006_3x1107_3Warped.nii.gz 1007_3x1005_3Warped.nii.gz 1007_3x1023_3Warped.nii.gz 1007_3x1122_3Warped.nii.gz 3 | 1006_3x1110_3Warped.nii.gz 1007_3x1006_3Warped.nii.gz 1007_3x1024_3Warped.nii.gz 1007_3x1125_3Warped.nii.gz 4 | 1006_3x1113_3Warped.nii.gz 1007_3x1008_3Warped.nii.gz 1007_3x1025_3Warped.nii.gz 1007_3x1128_3Warped.nii.gz 5 | 1006_3x1116_3Warped.nii.gz 1007_3x1009_3Warped.nii.gz 1007_3x1036_3Warped.nii.gz 1008_3x1000_3Warped.nii.gz 6 | 1006_3x1119_3Warped.nii.gz 1007_3x1010_3Warped.nii.gz 1007_3x1038_3Warped.nii.gz 1008_3x1001_3Warped.nii.gz 7 | 1006_3x1122_3Warped.nii.gz 1007_3x1011_3Warped.nii.gz 1007_3x1039_3Warped.nii.gz 1008_3x1002_3Warped.nii.gz 8 | 1006_3x1125_3Warped.nii.gz 1007_3x1012_3Warped.nii.gz 1007_3x1101_3Warped.nii.gz 1008_3x1003_3Warped.nii.gz 9 | 1006_3x1128_3Warped.nii.gz 1007_3x1013_3Warped.nii.gz 1007_3x1104_3Warped.nii.gz 1008_3x1004_3Warped.nii.gz 10 | 1007_3x1000_3Warped.nii.gz 1007_3x1014_3Warped.nii.gz 1007_3x1107_3Warped.nii.gz 1008_3x1005_3Warped.nii.gz 11 | 1007_3x1001_3Warped.nii.gz 1007_3x1015_3Warped.nii.gz 1007_3x1110_3Warped.nii.gz 1008_3x1006_3Warped.nii.gz 12 | 1007_3x1002_3Warped.nii.gz 1007_3x1017_3Warped.nii.gz 1007_3x1113_3Warped.nii.gz 13 | 1007_3x1003_3Warped.nii.gz 1007_3x1018_3Warped.nii.gz 1007_3x1116_3Warped.nii.gz 14 | -------------------------------------------------------------------------------- /MICCAI/config.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class config: 4 | sampling_ratio = 30 5 | #-------------learning_related--------------------# 6 | batch_size = 12 7 | workers = 4 8 | iter_size = 2 9 | num_episodes = 20000 10 | test_episodes = 500 11 | save_episodes = 20000 12 | resume_model = '' #'model/8_28_21_16000.pth' 13 | display = 100 14 | #-------------rl_related--------------------# 15 | pi_loss_coeff = 1.0 16 | v_loss_coeff = 0.25 17 | beta = 0.1 18 | c_loss_coeff = 0.5 # 0.005 19 | switch = 4 20 | warm_up_episodes = 1000 21 | episode_len = 3 22 | gamma = 1 23 | reward_method = 'abs' 24 | noise_scale = 0.2 #0.5 25 | #-------------continuous parameters--------------------# 26 | actions = { 27 | 'box': 1, 28 | 'bilateral': 2, 29 | 'median': 3, 30 | 'Gaussian': 4, 31 | 'Laplace': 5, 32 | 'Sobel_v1': 6, 33 | 'Sobel_v2': 7, 34 | 'Sobel_h1': 8, 35 | 'Sobel_h2': 9, 36 | 'unsharp': 10, 37 | 'subtraction': 11, 38 | } 39 | num_actions = len(actions) + 1 40 | 41 | parameters_scale = { 42 | 'Laplace': 0.2, 43 | 'Sobel_v1': 0.2, 44 | 'Sobel_v2': 0.2, 45 | 'Sobel_h1': 0.2, 46 | 'Sobel_h2': 0.2, 47 | 'unsharp': 1.0, 48 | } 49 | 50 | #-------------lr_policy--------------------# 51 | base_lr = 0.001 52 | # poly 53 | lr_policy = 'poly' 54 | policy_parameter = { 55 | 'power': 1, 56 | 'max_iter' : 40000, 57 | } 58 | 59 | #-------------folder--------------------# 60 | dataset = 'MICCAI' 61 | root = 'MICCAI/data/' 62 | -------------------------------------------------------------------------------- /hemorrhage/config.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class config: 4 | sampling_ratio = 30 5 | #-------------learning_related--------------------# 6 | batch_size = 12 7 | workers = 4 8 | iter_size = 
2 9 | num_episodes = 20000 10 | test_episodes = 500 11 | save_episodes = 20000 12 | resume_model = '' #'model/8_28_21_16000.pth' 13 | display = 100 14 | #-------------rl_related--------------------# 15 | pi_loss_coeff = 1.0 16 | v_loss_coeff = 0.25 17 | beta = 0.1 18 | c_loss_coeff = 0.5 # 0.005 19 | switch = 4 20 | warm_up_episodes = 1000 21 | episode_len = 3 22 | gamma = 1 23 | reward_method = 'abs' 24 | noise_scale = 0.2 #0.5 25 | #-------------continuous parameters--------------------# 26 | actions = { 27 | 'box': 1, 28 | 'bilateral': 2, 29 | 'median': 3, 30 | 'Gaussian': 4, 31 | 'Laplace': 5, 32 | 'Sobel_v1': 6, 33 | 'Sobel_v2': 7, 34 | 'Sobel_h1': 8, 35 | 'Sobel_h2': 9, 36 | 'unsharp': 10, 37 | 'subtraction': 11, 38 | } 39 | num_actions = len(actions) + 1 40 | 41 | parameters_scale = { 42 | 'Laplace': 0.2, 43 | 'Sobel_v1': 0.2, 44 | 'Sobel_v2': 0.2, 45 | 'Sobel_h1': 0.2, 46 | 'Sobel_h2': 0.2, 47 | 'unsharp': 1.0, 48 | } 49 | 50 | #-------------lr_policy--------------------# 51 | base_lr = 0.001 52 | # poly 53 | lr_policy = 'poly' 54 | policy_parameter = { 55 | 'power': 1, 56 | 'max_iter' : 40000, 57 | } 58 | 59 | #-------------folder--------------------# 60 | dataset = 'MICCAI' 61 | root = 'hemorrhage/' 62 | -------------------------------------------------------------------------------- /fastMRI/config.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class config: 4 | sampling_scheme = ([0.128], [2.5]) # retain 12.8% low freq, speed-up 2.5 (40% sampling ratio) 5 | resolution = 320 # 320x320 image 6 | #-------------learning_related--------------------# 7 | batch_size = 6 8 | workers = 4 9 | iter_size = 4 10 | num_episodes = 40000 11 | test_episodes = 1000 12 | save_episodes = 40000 13 | resume_model = '' #'model/8_28_21_16000.pth' 14 | display = 100 15 | #-------------rl_related--------------------# 16 | pi_loss_coeff = 1.0 17 | v_loss_coeff = 0.25 18 | beta = 0.1 19 | c_loss_coeff = 0.5 # 0.005 20 | switch = 8 21 | warm_up_episodes = 2000 22 | episode_len = 3 23 | gamma = 1 24 | reward_method = 'abs' 25 | noise_scale = 0.5 26 | #-------------continuous parameters--------------------# 27 | actions = { 28 | 'box': 1, 29 | 'bilateral': 2, 30 | 'median': 3, 31 | 'Gaussian': 4, 32 | 'Laplace': 5, 33 | 'Sobel_v1': 6, 34 | 'Sobel_v2': 7, 35 | 'Sobel_h1': 8, 36 | 'Sobel_h2': 9, 37 | 'unsharp': 10, 38 | 'subtraction': 11, 39 | } 40 | num_actions = len(actions) + 1 41 | 42 | parameters_scale = { 43 | 'Laplace': 0.2, 44 | 'Sobel_v1': 0.2, 45 | 'Sobel_v2': 0.2, 46 | 'Sobel_h1': 0.2, 47 | 'Sobel_h2': 0.2, 48 | 'unsharp': 1.0, 49 | } 50 | 51 | #-------------lr_policy--------------------# 52 | base_lr = 0.001 53 | # poly 54 | lr_policy = 'poly' 55 | policy_parameter = { 56 | 'power': 1, 57 | 'max_iter' : 80000, 58 | } 59 | 60 | #-------------folder--------------------# 61 | dataset = 'fastMRI' 62 | root = '/home/lwt/' 63 | -------------------------------------------------------------------------------- /unet/args.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import argparse 9 | import pathlib 10 | 11 | 12 | class Args(argparse.ArgumentParser): 13 | """ 14 | Defines global default arguments. 
15 | """ 16 | 17 | def __init__(self, **overrides): 18 | """ 19 | Args: 20 | **overrides (dict, optional): Keyword arguments used to override default argument values 21 | """ 22 | 23 | super().__init__(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | 25 | self.add_argument('--seed', default=42, type=int, help='Seed for random number generators') 26 | self.add_argument('--resolution', default=320, type=int, help='Resolution of images') 27 | 28 | # Data parameters 29 | self.add_argument('--challenge', choices=['singlecoil', 'multicoil'], required=True, 30 | help='Which challenge') 31 | self.add_argument('--data-path', type=pathlib.Path, required=True, 32 | help='Path to the dataset') 33 | self.add_argument('--sample-rate', type=float, default=1., 34 | help='Fraction of total volumes to include') 35 | 36 | # Mask parameters 37 | self.add_argument('--accelerations', nargs='+', default=[4, 8], type=int, 38 | help='Ratio of k-space columns to be sampled. If multiple values are ' 39 | 'provided, then one of those is chosen uniformly at random for ' 40 | 'each volume.') 41 | self.add_argument('--center-fractions', nargs='+', default=[0.08, 0.04], type=float, 42 | help='Fraction of low-frequency k-space columns to be sampled. Should ' 43 | 'have the same length as accelerations') 44 | 45 | # Override defaults with passed overrides 46 | self.set_defaults(**overrides) 47 | -------------------------------------------------------------------------------- /fastMRI/mri_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import pathlib 9 | import random 10 | 11 | import h5py 12 | from torch.utils.data import Dataset 13 | 14 | 15 | class SliceData(Dataset): 16 | """ 17 | A PyTorch Dataset that provides access to MR image slices. 18 | """ 19 | 20 | def __init__(self, root, transform, challenge, sample_rate=1): 21 | """ 22 | Args: 23 | root (pathlib.Path): Path to the dataset. 24 | transform (callable): A callable object that pre-processes the raw data into 25 | appropriate form. The transform function should take 'kspace', 'target', 26 | 'attributes', 'filename', and 'slice' as inputs. 'target' may be null 27 | for test data. 28 | challenge (str): "singlecoil" or "multicoil" depending on which challenge to use. 29 | sample_rate (float, optional): A float between 0 and 1. This controls what fraction 30 | of the volumes should be loaded. 
31 | """ 32 | if challenge not in ('singlecoil', 'multicoil'): 33 | raise ValueError('challenge should be either "singlecoil" or "multicoil"') 34 | 35 | self.transform = transform 36 | self.recons_key = 'reconstruction_esc' if challenge == 'singlecoil' \ 37 | else 'reconstruction_rss' 38 | 39 | self.examples = [] 40 | files = list(pathlib.Path(root).iterdir()) 41 | if sample_rate < 1: 42 | random.shuffle(files) 43 | num_files = round(len(files) * sample_rate) 44 | files = files[:num_files] 45 | for fname in sorted(files): 46 | kspace = h5py.File(fname, 'r')['kspace'] 47 | num_slices = kspace.shape[0] 48 | self.examples += [(fname, slice) for slice in range(num_slices)] 49 | 50 | def __len__(self): 51 | return len(self.examples) 52 | 53 | def __getitem__(self, i): 54 | fname, slice = self.examples[i] 55 | with h5py.File(fname, 'r') as data: 56 | kspace = data['kspace'][slice] 57 | target = data[self.recons_key][slice] if self.recons_key in data else None 58 | return self.transform(kspace, target, data.attrs, fname.name, slice) 59 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MRI_RL 2 | 3 | This is the implementation of our AAAI 2020 [paper](https://ojs.aaai.org/index.php/AAAI/article/view/5423):
4 | MRI Reconstruction with Interpretable Pixel-Wise Operations Using Reinforcement Learning
5 | 
6 | ```
7 | @inproceedings{li2020mri,
8 |   title={MRI Reconstruction with Interpretable Pixel-Wise Operations Using Reinforcement Learning},
9 |   author={Li, Wentian and Feng, Xidong and An, Haotian and Ng, Xiang Yao and Zhang, Yu-Jin},
10 |   booktitle={AAAI},
11 |   year={2020}
12 | }
13 | ```
14 | 
15 | Parts of the code are borrowed from other repos, including [pixelRL](https://github.com/rfuruta/pixelRL) (for the A2C algorithm), [DAGAN](https://github.com/nebulaV/DAGAN/) (for the MICCAI MRI dataset), [fastMRI](https://github.com/facebookresearch/fastMRI) (for the fastMRI dataset and Unet), and some others.
16 | 
17 | ## Environment
18 | 
19 | I used Python 3.6.1, PyTorch 0.3.1.post2, torchvision 0.2.0, numpy 1.14.2, and tensorboardX 1.7.
20 | The code usually works fine on my machine with two GeForce GTX 1080 GPUs,
21 | but some weird bugs appeared occasionally (a ValueError from numpy and a segmentation fault when running Unet).
22 | 
23 | ## Data Preparation
24 | 
25 | For the [MICCAI 2013 Grand Challenge](https://my.vanderbilt.edu/masi/workshops/) dataset, please download the data and extract the images by running `prepare_data.py` in `MICCAI/`.
26 | 
27 | For the [fastMRI](https://fastmri.med.nyu.edu/) dataset, the h5 files are read directly.
28 | 
29 | ## Training
30 | 
31 | To set the data path properly, modify the variables `dataset` and `root` in the `config.py` file in `MICCAI/` or `fastMRI/` accordingly. The hyper-parameters are also set in `config.py`.
32 | 
33 | For MICCAI, run
34 | ```
35 | python train.py --dataset MICCAI
36 | ```
37 | 
38 | For fastMRI, run
39 | ```
40 | python train.py --dataset fastMRI
41 | ```
42 | 
43 | To train Unet on fastMRI, go to `unet/` and run
44 | ```
45 | sh train.sh
46 | ```
47 | See `unet/` for more details.
48 | 
49 | ## Trained models
50 | 
51 | We provide our model trained on MICCAI with the 30% mask, as well as our model and Unet trained on fastMRI with the 40% mask.
52 | 
53 | The trained models can be downloaded here: [Baidu Netdisk](https://pan.baidu.com/s/1y2OXdERwmeYZEGDsI-r2UQ) or [Google Drive](https://drive.google.com/folderview?id=1-5F_qoX25HY1oxpZtsMXv7qTZUCY6Eo_).
54 | 
55 | 
56 | ## Testing
57 | 
58 | Run
59 | ```
60 | python test.py --dataset MICCAI_or_fastMRI --model path_to_the_model
61 | ```
62 | 
63 | We also provide an example of testing on a custom dataset. Please see `hemorrhage/`.
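As a quick sanity check of the data preparation, the sketch below (not part of the repo) computes the zero-filled baseline for one MICCAI test image. It assumes `prepare_data.py` has written the images to `MICCAI/data/`, that it is run from the repository root with the package versions listed above, and that the file name and the 30% Gaussian1D mask are just examples:

```python
# Zero-filled reconstruction of a single MICCAI test image (hypothetical paths).
import cv2
from scipy.io import loadmat

from utils import Downsample, PSNR

mask = loadmat('MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_30.mat')['maskRS1']
x = cv2.imread('MICCAI/data/test/0.bmp', cv2.IMREAD_GRAYSCALE) / 255.
zero_filled, _, _ = Downsample(x, mask)  # mask the k-space, then inverse FFT
print('zero-filled PSNR: %.2f dB' % PSNR(x, zero_filled))
```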
64 | -------------------------------------------------------------------------------- /MICCAI/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import torch.utils.data as data 5 | import cv2 6 | import numpy as np 7 | import PIL 8 | from torchvision.transforms import Compose, RandomHorizontalFlip, RandomRotation, RandomResizedCrop, ColorJitter, ToPILImage, ToTensor 9 | 10 | from utils import Downsample 11 | 12 | def load_mask(sampling_ratio): 13 | assert sampling_ratio in [10, 20, 30, 40, 50] 14 | from scipy.io import loadmat 15 | mask = loadmat('MICCAI/mask/Gaussian1D/GaussianDistribution1DMask_{}.mat'.format(sampling_ratio)) 16 | mask = mask['maskRS1'] 17 | print('mask:', np.mean(mask)) 18 | return mask 19 | 20 | class MRIDataset(data.Dataset): 21 | def __init__(self, image_set, transform, config): 22 | self.root = config.root 23 | self.image_set = image_set 24 | self.transform = transform 25 | self.ids = [i.strip() for i in open(self.root + self.image_set + '.txt').readlines()] 26 | 27 | self.mask = load_mask(config.sampling_ratio) 28 | #self.DAGAN = 'SegChallenge' in root and image_set == 'test' 29 | 30 | def __getitem__(self, index): 31 | x = cv2.imread(os.path.join(self.root, self.image_set, self.ids[index]), cv2.IMREAD_GRAYSCALE) 32 | if x.shape != (256, 256): 33 | x = cv2.resize(x, (256, 256)) 34 | 35 | # data augmentation 36 | if self.transform: 37 | transformations = Compose( 38 | [ToPILImage(), 39 | RandomRotation(degrees=10, resample=PIL.Image.BICUBIC), 40 | #RandomAffine(degrees=10, translate=(-25, 25), scale=(0.90, 1.10), resample=PIL.Image.BILINEAR), 41 | RandomHorizontalFlip(), 42 | RandomResizedCrop(size=256, scale=(0.90, 1), ratio=(0.95, 1.05), interpolation=PIL.Image.BICUBIC), 43 | #ColorJitter(brightness=0.05), 44 | #CenterCrop(size=(256, 256)), 45 | ToTensor(), 46 | ]) 47 | x = x[..., np.newaxis] 48 | x = transformations(x).float().numpy() * 255 49 | x = x[0] 50 | 51 | image, _, _ = Downsample(x, self.mask) 52 | 53 | x = x / 255. 54 | image = image / 255. 
55 | 56 | target = torch.from_numpy(x).float().unsqueeze(0) 57 | image = torch.from_numpy(image).float().unsqueeze(0) 58 | mask = [0] # return something to be compatible with fastMRI dataset 59 | return target, image, mask 60 | 61 | def __len__(self): 62 | return len(self.ids) 63 | -------------------------------------------------------------------------------- /MICCAI/DAGAN_Training_datasets.txt: -------------------------------------------------------------------------------- 1 | 1003_3x1110_3Warped.nii.gz 1004_3x1023_3Warped.nii.gz 1005_3x1010_3Warped.nii.gz 1006_3x1000_3Warped.nii.gz 2 | 1003_3x1113_3Warped.nii.gz 1004_3x1024_3Warped.nii.gz 1005_3x1011_3Warped.nii.gz 1006_3x1001_3Warped.nii.gz 3 | 1003_3x1116_3Warped.nii.gz 1004_3x1025_3Warped.nii.gz 1005_3x1012_3Warped.nii.gz 1006_3x1002_3Warped.nii.gz 4 | 1003_3x1119_3Warped.nii.gz 1004_3x1036_3Warped.nii.gz 1005_3x1013_3Warped.nii.gz 1006_3x1003_3Warped.nii.gz 5 | 1003_3x1122_3Warped.nii.gz 1004_3x1038_3Warped.nii.gz 1005_3x1014_3Warped.nii.gz 1006_3x1004_3Warped.nii.gz 6 | 1003_3x1125_3Warped.nii.gz 1004_3x1039_3Warped.nii.gz 1005_3x1015_3Warped.nii.gz 1006_3x1005_3Warped.nii.gz 7 | 1003_3x1128_3Warped.nii.gz 1004_3x1101_3Warped.nii.gz 1005_3x1017_3Warped.nii.gz 1006_3x1007_3Warped.nii.gz 8 | 1004_3x1000_3Warped.nii.gz 1004_3x1104_3Warped.nii.gz 1005_3x1018_3Warped.nii.gz 1006_3x1008_3Warped.nii.gz 9 | 1004_3x1001_3Warped.nii.gz 1004_3x1107_3Warped.nii.gz 1005_3x1019_3Warped.nii.gz 1006_3x1009_3Warped.nii.gz 10 | 1004_3x1002_3Warped.nii.gz 1004_3x1110_3Warped.nii.gz 1005_3x1023_3Warped.nii.gz 1006_3x1010_3Warped.nii.gz 11 | 1004_3x1003_3Warped.nii.gz 1004_3x1113_3Warped.nii.gz 1005_3x1024_3Warped.nii.gz 1006_3x1011_3Warped.nii.gz 12 | 1004_3x1005_3Warped.nii.gz 1004_3x1116_3Warped.nii.gz 1005_3x1025_3Warped.nii.gz 1006_3x1012_3Warped.nii.gz 13 | 1004_3x1006_3Warped.nii.gz 1004_3x1119_3Warped.nii.gz 1005_3x1036_3Warped.nii.gz 1006_3x1013_3Warped.nii.gz 14 | 1004_3x1007_3Warped.nii.gz 1004_3x1122_3Warped.nii.gz 1005_3x1038_3Warped.nii.gz 1006_3x1014_3Warped.nii.gz 15 | 1004_3x1008_3Warped.nii.gz 1004_3x1125_3Warped.nii.gz 1005_3x1039_3Warped.nii.gz 1006_3x1015_3Warped.nii.gz 16 | 1004_3x1009_3Warped.nii.gz 1004_3x1128_3Warped.nii.gz 1005_3x1101_3Warped.nii.gz 1006_3x1017_3Warped.nii.gz 17 | 1004_3x1010_3Warped.nii.gz 1005_3x1000_3Warped.nii.gz 1005_3x1104_3Warped.nii.gz 1006_3x1018_3Warped.nii.gz 18 | 1004_3x1011_3Warped.nii.gz 1005_3x1001_3Warped.nii.gz 1005_3x1107_3Warped.nii.gz 1006_3x1019_3Warped.nii.gz 19 | 1004_3x1012_3Warped.nii.gz 1005_3x1002_3Warped.nii.gz 1005_3x1110_3Warped.nii.gz 1006_3x1023_3Warped.nii.gz 20 | 1004_3x1013_3Warped.nii.gz 1005_3x1003_3Warped.nii.gz 1005_3x1113_3Warped.nii.gz 1006_3x1024_3Warped.nii.gz 21 | 1004_3x1014_3Warped.nii.gz 1005_3x1004_3Warped.nii.gz 1005_3x1116_3Warped.nii.gz 1006_3x1025_3Warped.nii.gz 22 | 1004_3x1015_3Warped.nii.gz 1005_3x1006_3Warped.nii.gz 1005_3x1119_3Warped.nii.gz 1006_3x1036_3Warped.nii.gz 23 | 1004_3x1017_3Warped.nii.gz 1005_3x1007_3Warped.nii.gz 1005_3x1122_3Warped.nii.gz 1006_3x1038_3Warped.nii.gz 24 | 1004_3x1018_3Warped.nii.gz 1005_3x1008_3Warped.nii.gz 1005_3x1125_3Warped.nii.gz 1006_3x1039_3Warped.nii.gz 25 | 1004_3x1019_3Warped.nii.gz 1005_3x1009_3Warped.nii.gz 1005_3x1128_3Warped.nii.gz 1006_3x1101_3Warped.nii.gz 26 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.random import randn 3 | import torch 4 | 
from torch.nn import Conv2d 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | class MyFcn(torch.nn.Module): 9 | def __init__(self, config): 10 | super(MyFcn, self).__init__() 11 | 12 | self.noise_scale = config.noise_scale 13 | self.num_parameters = len(config.parameters_scale) 14 | 15 | self.conv1 = Conv2d(1, 64, kernel_size=3, stride=1, padding=1) 16 | self.conv2 = Conv2d(64, 64, kernel_size=3, stride=1, padding=2, dilation=2) 17 | self.conv3 = Conv2d(64, 64, kernel_size=3, stride=1, padding=3, dilation=3) 18 | self.conv4 = Conv2d(64, 64, kernel_size=3, stride=1, padding=4, dilation=4) 19 | 20 | self.conv5_pi = Conv2d(64, 64, kernel_size=3, stride=1, padding=3, dilation=3) 21 | self.conv6_pi = Conv2d(64, 64, kernel_size=3, stride=1, padding=2, dilation=2) 22 | self.conv7_pi = Conv2d(64, config.num_actions, kernel_size=3, stride=1, padding=1) 23 | 24 | self.conv5_V = Conv2d(64, 64, kernel_size=3, stride=1, padding=3, dilation=3) 25 | self.conv6_V = Conv2d(64 + self.num_parameters, 64, kernel_size=3, stride=1, padding=2, dilation=2) 26 | self.conv7_V = Conv2d(64, 1, kernel_size=3, stride=1, padding=1) 27 | 28 | self.conv5_p = Conv2d(64, 64, kernel_size=3, stride=1, padding=3, dilation=3) 29 | self.conv6_p = Conv2d(64, 64, kernel_size=3, stride=1, padding=2, dilation=2) 30 | self.conv7_p = Conv2d(64, self.num_parameters, kernel_size=3, stride=1, padding=1) 31 | 32 | def parse_p(self, u_out): 33 | p = torch.mean(u_out.view(u_out.shape[0], u_out.shape[1], -1), dim=2) 34 | return p 35 | 36 | def forward(self, x, flag_a2c=True, add_noise=False): 37 | h = F.relu(self.conv1(x)) 38 | h = F.relu(self.conv2(h)) 39 | h = F.relu(self.conv3(h)) 40 | h = F.relu(self.conv4(h)) 41 | if not flag_a2c: 42 | h = h.detach() 43 | 44 | # pi branch 45 | h_pi = F.relu(self.conv5_pi(h)) 46 | h_pi = F.relu(self.conv6_pi(h_pi)) 47 | pi_out = F.softmax(self.conv7_pi(h_pi), dim=1) 48 | 49 | # p branch 50 | p_out = F.relu(self.conv5_p(h)) 51 | p_out = F.relu(self.conv6_p(p_out)) 52 | p_out = self.conv7_p(p_out) 53 | if flag_a2c: 54 | if add_noise: 55 | p_out = p_out.data + torch.from_numpy(randn(p_out.shape[0], p_out.shape[1], 1, 1).astype(np.float32)).cuda() * self.noise_scale 56 | p_out = Variable(p_out) 57 | else: 58 | p_out = p_out.detach() 59 | p_out = F.sigmoid(p_out) 60 | 61 | # V branch 62 | h_v = F.relu(self.conv5_V(h)) 63 | h_v = torch.cat((h_v, p_out), dim=1) 64 | h_v = F.relu(self.conv6_V(h_v)) 65 | v_out = self.conv7_V(h_v) 66 | 67 | return pi_out, v_out, self.parse_p(p_out) 68 | -------------------------------------------------------------------------------- /MICCAI/prepare_data.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tensorlayer as tl 3 | import numpy as np 4 | import os 5 | import nibabel as nib 6 | import cv2 7 | 8 | training_data_path = "../../MICCAI-2013-SATA-Challenge-Data-Std-Reg/diencephalon/training-training/warped-images/" 9 | testing_data_path = "../../MICCAI-2013-SATA-Challenge-Data-Std-Reg/diencephalon/training-training/warped-images/" 10 | val_ratio = 0.3 11 | seed = 100 12 | preserving_ratio = 0.1 # filter out 2d images containing < 10% non-zeros 13 | data_saving_path = 'data/' 14 | tl.files.exists_or_mkdir(data_saving_path) 15 | 16 | # dump training images 17 | f_train_all = [] 18 | for line in open('DAGAN_Training_datasets.txt').readlines(): 19 | for l in line.split(): 20 | print(l) 21 | f_train_all.append(l) 22 | 23 | 24 | train_all_num = len(f_train_all) 25 | val_num = 
int(train_all_num * val_ratio) 26 | 27 | f_train = [] 28 | f_val = [] 29 | 30 | val_idex = tl.utils.get_random_int(min=0, 31 | max=train_all_num - 1, 32 | number=val_num, 33 | seed=seed) 34 | for i in range(train_all_num): 35 | if i in val_idex: 36 | f_val.append(f_train_all[i]) 37 | else: 38 | f_train.append(f_train_all[i]) 39 | 40 | train_3d_num, val_3d_num = len(f_train), len(f_val) 41 | print('number of training volumes: ', train_3d_num) 42 | 43 | X_train = [] 44 | count = 0 45 | train_image_path = data_saving_path + '/train/' 46 | tl.files.exists_or_mkdir(train_image_path) 47 | fw = open(data_saving_path+'/train.txt', 'w') 48 | for fi, f in enumerate(f_train): 49 | print("processing [{}/{}] 3d image ({}) for training set ...".format(fi + 1, train_3d_num, f)) 50 | img_path = os.path.join(training_data_path, f) 51 | img = nib.load(img_path).get_data() 52 | img_3d_max = np.max(img) 53 | img = img / img_3d_max * 255 54 | for i in range(img.shape[2]): 55 | img_2d = img[:, :, i] 56 | # filter out 2d images containing < 10% non-zeros 57 | if float(np.count_nonzero(img_2d)) / img_2d.size >= preserving_ratio: 58 | img_2d = np.transpose(img_2d, (1, 0)) 59 | img_name = str(fi) + '_' + str(i) + '.bmp' 60 | cv2.imwrite(train_image_path + img_name, img_2d) 61 | fw.write(img_name+'\n') 62 | count += 1 63 | print('number of training images: ', count) 64 | 65 | # dump test images 66 | f_test = [] 67 | for line in open('DAGAN_Testing_datasets.txt').readlines(): 68 | for l in line.split(): 69 | print(l) 70 | f_test.append(l) 71 | test_3d_num = len(f_test) 72 | 73 | X_test = [] 74 | test_image_path = data_saving_path + '/test/' 75 | tl.files.exists_or_mkdir(test_image_path) 76 | fw = open(data_saving_path + '/test.txt', 'w') 77 | for fi, f in enumerate(f_test): 78 | print("processing [{}/{}] 3d image ({}) for test set ...".format(fi + 1, test_3d_num, f)) 79 | img_path = os.path.join(testing_data_path, f) 80 | img = nib.load(img_path).get_data() 81 | img_3d_max = np.max(img) 82 | img = img / img_3d_max * 255 83 | for i in range(img.shape[2]): 84 | img_2d = img[:, :, i] 85 | # filter out 2d images containing < 10% non-zeros 86 | if float(np.count_nonzero(img_2d)) / img_2d.size >= preserving_ratio: 87 | img_2d = np.transpose(img_2d, (1, 0)) 88 | X_test.append(img_2d) 89 | 90 | X_test = np.asarray(X_test) 91 | X_test = X_test[:, :, :, np.newaxis] 92 | idex = tl.utils.get_random_int(min=0, max=len(X_test) - 1, number=50, seed=100) 93 | X = X_test[idex] 94 | for i in range(X.shape[0]): 95 | cv2.imwrite(test_image_path + str(i) + '.bmp', X[i, :, :, 0]) 96 | fw.write(str(i) + '.bmp\n') 97 | -------------------------------------------------------------------------------- /fastMRI/subsample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | 12 | class MaskFunc: 13 | """ 14 | MaskFunc creates a sub-sampling mask of a given shape. 15 | 16 | The mask selects a subset of columns from the input k-space data. If the k-space data has N 17 | columns, the mask picks out: 18 | 1. N_low_freqs = (N * center_fraction) columns in the center corresponding to 19 | low-frequencies 20 | 2. The other columns are selected uniformly at random with a probability equal to: 21 | prob = (N / acceleration - N_low_freqs) / (N - N_low_freqs). 
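(Worked example: with N = 320 columns, acceleration = 4, and center_fraction = 0.08, we get N_low_freqs = round(320 * 0.08) = 26 and prob = (80 - 26) / (320 - 26) ≈ 0.184, so on average 26 + 0.184 * 294 = 80 = N / 4 columns are selected.)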
22 | This ensures that the expected number of columns selected is equal to (N / acceleration) 23 | 24 | It is possible to use multiple center_fractions and accelerations, in which case one possible 25 | (center_fraction, acceleration) is chosen uniformly at random each time the MaskFunc object is 26 | called. 27 | 28 | For example, if accelerations = [4, 8] and center_fractions = [0.08, 0.04], then there 29 | is a 50% probability that 4-fold acceleration with 8% center fraction is selected and a 50% 30 | probability that 8-fold acceleration with 4% center fraction is selected. 31 | """ 32 | 33 | def __init__(self, center_fractions, accelerations): 34 | """ 35 | Args: 36 | center_fractions (List[float]): Fraction of low-frequency columns to be retained. 37 | If multiple values are provided, then one of these numbers is chosen uniformly 38 | each time. 39 | 40 | accelerations (List[int]): Amount of under-sampling. This should have the same length 41 | as center_fractions. If multiple values are provided, then one of these is chosen 42 | uniformly each time. An acceleration of 4 retains 25% of the columns, but they may 43 | not be spaced evenly. 44 | """ 45 | if len(center_fractions) != len(accelerations): 46 | raise ValueError('Number of center fractions should match number of accelerations') 47 | 48 | self.center_fractions = center_fractions 49 | self.accelerations = accelerations 50 | self.rng = np.random.RandomState() 51 | 52 | def __call__(self, shape, seed=None): 53 | """ 54 | Args: 55 | shape (iterable[int]): The shape of the mask to be created. The shape should have 56 | at least 3 dimensions. Samples are drawn along the second last dimension. 57 | seed (int, optional): Seed for the random number generator. Setting the seed 58 | ensures the same mask is generated each time for the same shape. 59 | Returns: 60 | torch.Tensor: A mask of the specified shape. 
61 | """ 62 | if len(shape) < 3: 63 | raise ValueError('Shape should have 3 or more dimensions') 64 | 65 | self.rng.seed(seed) 66 | num_cols = shape[-2] 67 | 68 | choice = self.rng.randint(0, len(self.accelerations)) 69 | center_fraction = self.center_fractions[choice] 70 | acceleration = self.accelerations[choice] 71 | 72 | # Create the mask 73 | num_low_freqs = int(round(num_cols * center_fraction)) 74 | prob = (num_cols / acceleration - num_low_freqs) / (num_cols - num_low_freqs) 75 | mask = self.rng.uniform(size=num_cols) < prob 76 | pad = (num_cols - num_low_freqs + 1) // 2 77 | mask[pad:pad + num_low_freqs] = True 78 | 79 | # Reshape the mask 80 | mask_shape = [1 for _ in shape] 81 | mask_shape[-2] = num_cols 82 | mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32)) 83 | 84 | return mask 85 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.measure 3 | import scipy 4 | 5 | 6 | def fft_shift(x): 7 | fft = scipy.fftpack.fft2(x) 8 | fft = scipy.fftpack.fftshift(fft) 9 | return fft 10 | 11 | 12 | def shift_ifft(fft): 13 | fft = scipy.fftpack.ifftshift(fft) 14 | x = scipy.fftpack.ifft2(fft) 15 | return x 16 | 17 | 18 | def Downsample(x, mask): 19 | fft = scipy.fftpack.fft2(x) 20 | fft_good = scipy.fftpack.fftshift(fft) 21 | fft_bad = fft_good * mask 22 | fft = scipy.fftpack.ifftshift(fft_bad) 23 | x = scipy.fftpack.ifft2(fft) 24 | # x = np.abs(x) 25 | x = np.real(x) 26 | return x, fft_good, fft_bad 27 | 28 | 29 | def SSIM(x_good, x_bad): 30 | assert len(x_good.shape) == 2 31 | ssim_res = skimage.measure.compare_ssim(x_good, x_bad) 32 | return ssim_res 33 | 34 | 35 | def PSNR(x_good, x_bad): 36 | assert len(x_good.shape) == 2 37 | psnr_res = skimage.measure.compare_psnr(x_good, x_bad) 38 | return psnr_res 39 | 40 | 41 | def NMSE(x_good, x_bad): 42 | assert len(x_good.shape) == 2 43 | nmse_a_0_1 = np.sum((x_good - x_bad) ** 2) 44 | nmse_b_0_1 = np.sum(x_good ** 2) 45 | # this is DAGAN implementation, which is wrong 46 | nmse_a_0_1, nmse_b_0_1 = np.sqrt(nmse_a_0_1), np.sqrt(nmse_b_0_1) 47 | nmse_0_1 = nmse_a_0_1 / nmse_b_0_1 48 | return nmse_0_1 49 | 50 | 51 | def computePSNR(o_, p_, i_): 52 | return PSNR(o_, p_), PSNR(o_, i_) 53 | 54 | 55 | def computeSSIM(o_, p_, i_): 56 | return SSIM(o_, p_), SSIM(o_, i_) 57 | 58 | 59 | def computeNMSE(o_, p_, i_): 60 | return NMSE(o_, p_), NMSE(o_, i_) 61 | 62 | 63 | def DC(x_good, x_rec, mask): 64 | fft_good = fft_shift(x_good) 65 | fft_rec = fft_shift(x_rec) 66 | fft = fft_good * mask + fft_rec * (1 - mask) 67 | x = shift_ifft(fft) 68 | x = np.real(x) 69 | #x = np.abs(x) 70 | return x 71 | 72 | 73 | def adjust_learning_rate(optimizer, iters, base_lr, policy_parameter, policy='step', multiple=[1]): 74 | ''' 75 | source: https://github.com/last-one/Pytorch_Realtime_Multi-Person_Pose_Estimation/blob/master/utils.py 76 | ''' 77 | if policy == 'fixed': 78 | lr = base_lr 79 | elif policy == 'step': 80 | lr = base_lr * (policy_parameter['gamma'] ** (iters // policy_parameter['step_size'])) 81 | elif policy == 'exp': 82 | lr = base_lr * (policy_parameter['gamma'] ** iters) 83 | elif policy == 'inv': 84 | lr = base_lr * ((1 + policy_parameter['gamma'] * iters) ** (-policy_parameter['power'])) 85 | elif policy == 'multistep': 86 | lr = base_lr 87 | for stepvalue in policy_parameter['stepvalue']: 88 | if iters >= stepvalue: 89 | lr *= policy_parameter['gamma'] 90 | else: 91 | break 
92 |     elif policy == 'poly':
93 |         lr = base_lr * ((1 - iters * 1.0 / policy_parameter['max_iter']) ** policy_parameter['power'])
94 |     elif policy == 'sigmoid':
95 |         lr = base_lr * (1.0 / (1 + np.exp(-policy_parameter['gamma'] * (iters - policy_parameter['stepsize']))))  # np.exp: `math` is not imported in this file
96 |     elif policy == 'multistep-poly':
97 |         lr = base_lr
98 |         stepstart = 0
99 |         stepend = policy_parameter['max_iter']
100 |         for stepvalue in policy_parameter['stepvalue']:
101 |             if iters >= stepvalue:
102 |                 lr *= policy_parameter['gamma']
103 |                 stepstart = stepvalue
104 |             else:
105 |                 stepend = stepvalue
106 |                 break
107 |         lr = max(lr * policy_parameter['gamma'], lr * (1 - (iters - stepstart) * 1.0 / (stepend - stepstart)) ** policy_parameter['power'])
108 | 
109 |     for i, param_group in enumerate(optimizer.param_groups):
110 |         param_group['lr'] = lr * multiple[i]
111 |     return lr
112 | 
113 | if __name__ == "__main__":
114 |     pass
115 | 
--------------------------------------------------------------------------------
/env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import sys
3 | import cv2
4 | import torch
5 | from skimage.measure import compare_ssim
6 | 
7 | class Env():
8 |     def __init__(self, config):
9 |         self.image = None
10 |         self.previous_image = None
11 | 
12 |         self.num_actions = config.num_actions
13 |         self.actions = config.actions
14 | 
15 |         self.parameters_scale = config.parameters_scale
16 |         self.parameters = dict()
17 |         self.set_param([0.5] * len(self.parameters_scale))
18 | 
19 |         self.reward_method = config.reward_method
20 | 
21 |     def reset(self, ori_image, image):
22 |         self.ori_image = ori_image
23 |         self.image = image
24 |         self.previous_image = None
25 |         return
26 | 
27 |     def set_param(self, p):
28 |         for i, k in enumerate(sorted(self.parameters_scale.keys())):
29 |             self.parameters[k] = p[i] * self.parameters_scale[k]
30 |         return
31 | 
32 |     def step(self, act):
33 |         self.previous_image = self.image.copy()
34 | 
35 |         canvas = [np.zeros(self.image.shape, self.image.dtype) for _ in range(self.num_actions + 1)]
36 |         b, c, h, w = self.image.shape
37 |         for i in range(b):
38 |             canvas[0][i, 0] = self.image[i,0]
39 |             canvas[self.actions['subtraction']][i, 0] = self.image[i,0] - 3. / 255
40 | 
41 |             if np.sum(act[i] == self.actions['box']) > 0:
42 |                 canvas[self.actions['box']][i, 0] = cv2.boxFilter(self.image[i,0], ddepth=-1, ksize=(5,5))
43 | 
44 |             if np.sum(act[i] == self.actions['bilateral']) > 0:
45 |                 canvas[self.actions['bilateral']][i, 0] = cv2.bilateralFilter(self.image[i,0], d=5, sigmaColor=0.1, sigmaSpace=5)
46 | 
47 |             if True:  # always computed: the 'unsharp' action below reuses the Gaussian canvas
48 |                 canvas[self.actions['Gaussian']][i, 0] = cv2.GaussianBlur(self.image[i,0], ksize=(5,5), sigmaX=0.5)
49 | 
50 |             if np.sum(act[i] == self.actions['median']) > 0:
51 |                 canvas[self.actions['median']][i, 0] = cv2.medianBlur(self.image[i,0], ksize=5)
52 | 
53 |             if np.sum(act[i] == self.actions['Laplace']) > 0:
54 |                 p = self.parameters['Laplace'][i]
55 |                 k = np.array([[0, -p, 0], [-p, 1 + 4 * p, -p], [0, -p, 0]])
56 |                 canvas[self.actions['Laplace']][i, 0] = cv2.filter2D(self.image[i, 0], -1, kernel=k)
57 | 
58 |             if np.sum(act[i] == self.actions['unsharp']) > 0:
59 |                 amount = self.parameters['unsharp'][i]
60 |                 canvas[self.actions['unsharp']][i, 0] = self.image[i, 0] * (1 + amount) - canvas[self.actions['Gaussian']][i, 0] * amount
61 | 
62 |             if np.sum(act[i] == self.actions['Sobel_v1']) > 0:
63 |                 p = self.parameters['Sobel_v1'][i]
64 |                 k = np.array([[p, 0, -p], [2 * p, 1, -2 * p], [p, 0, -p]])
65 |                 canvas[self.actions['Sobel_v1']][i, 0] = cv2.filter2D(self.image[i, 0], -1, kernel=k)
66 | 
67 |             if np.sum(act[i] == self.actions['Sobel_v2']) > 0:
68 |                 p = self.parameters['Sobel_v2'][i]
69 |                 k = np.array([[-p, 0, p], [-2 * p, 1, 2 * p], [-p, 0, p]])
70 |                 canvas[self.actions['Sobel_v2']][i, 0] = cv2.filter2D(self.image[i, 0], -1, kernel=k)
71 | 
72 |             if np.sum(act[i] == self.actions['Sobel_h1']) > 0:
73 |                 p = self.parameters['Sobel_h1'][i]
74 |                 k = np.array([[-p,-2 * p,-p], [0, 1, 0], [p, 2 * p, p]])
75 |                 canvas[self.actions['Sobel_h1']][i, 0] = cv2.filter2D(self.image[i, 0], -1, kernel=k)
76 | 
77 |             if np.sum(act[i] == self.actions['Sobel_h2']) > 0:
78 |                 p = self.parameters['Sobel_h2'][i]
79 |                 k = np.array([[p, 2 * p, p], [0, 1, 0], [-p, -2 * p, -p]])
80 |                 canvas[self.actions['Sobel_h2']][i, 0] = cv2.filter2D(self.image[i, 0], -1, kernel=k)
81 | 
82 |         for a in range(1, self.num_actions + 1):
83 |             self.image = np.where(act[:,np.newaxis,:,:] == a, canvas[a], self.image)
84 |         self.image = np.clip(self.image, 0, 1)
85 | 
86 |         if self.reward_method == 'square':
87 |             reward = np.square(self.ori_image - self.previous_image) * 255 - np.square(self.ori_image - self.image) * 255
88 |         elif self.reward_method == 'abs':
89 |             reward = np.abs(self.ori_image - self.previous_image) * 255 - np.abs(self.ori_image - self.image) * 255
90 | 
91 |         return self.image, reward
92 | 
--------------------------------------------------------------------------------
/pixel_wise_a2c.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torch.distributions import Categorical
4 | 
5 | 
6 | class PixelWiseA2C:
7 |     """A2C: Advantage Actor-Critic.
8 | 
9 |     Args:
10 |         config: configuration object providing the fields below
11 |         gamma (float): Discount factor [0,1]
12 |         beta (float): Weight coefficient for the entropy regularization term.
13 | pi_loss_coeff(float): Weight coefficient for the loss of the policy 14 | v_loss_coeff(float): Weight coefficient for the loss of the value 15 | function 16 | """ 17 | 18 | def __init__(self, config): 19 | 20 | self.gamma = config.gamma 21 | self.beta = config.beta 22 | self.pi_loss_coeff = config.pi_loss_coeff 23 | self.v_loss_coeff = config.v_loss_coeff 24 | 25 | self.t = 0 26 | self.t_start = 0 27 | self.past_action_log_prob = {} 28 | self.past_action_entropy = {} 29 | self.past_rewards = {} 30 | self.past_values = {} 31 | 32 | def reset(self): 33 | self.past_action_log_prob = {} 34 | self.past_action_entropy = {} 35 | self.past_states = {} 36 | self.past_rewards = {} 37 | self.past_values = {} 38 | 39 | self.t_start = 0 40 | self.t = 0 41 | 42 | def compute_loss(self): 43 | assert self.t_start < self.t 44 | R = 0 45 | 46 | pi_loss = 0 47 | v_loss = 0 48 | entropy_loss = 0 49 | for i in reversed(range(self.t_start, self.t)): 50 | R *= self.gamma 51 | R += self.past_rewards[i] 52 | v = self.past_values[i] 53 | advantage = R - v.detach() 54 | selected_log_prob = self.past_action_log_prob[i] 55 | entropy = self.past_action_entropy[i] 56 | 57 | # Log probability is increased proportionally to advantage 58 | pi_loss -= selected_log_prob * advantage 59 | # Entropy is maximized 60 | entropy_loss -= entropy 61 | # Accumulate gradients of value function 62 | v_loss += (v - R) ** 2 63 | 64 | if self.pi_loss_coeff != 1.0: 65 | pi_loss *= self.pi_loss_coeff 66 | 67 | if self.v_loss_coeff != 1.0: 68 | v_loss *= self.v_loss_coeff 69 | 70 | entropy_loss *= self.beta 71 | 72 | losses = dict() 73 | losses['pi_loss'] = pi_loss.mean() 74 | losses['v_loss'] = v_loss.view(pi_loss.shape).mean() 75 | losses['entropy_loss'] = entropy_loss.mean() 76 | return losses 77 | 78 | def act_and_train(self, pi, value, reward): 79 | self.past_rewards[self.t - 1] = reward 80 | 81 | def randomly_choose_actions(pi): 82 | pi = torch.clamp(pi, min=0) 83 | n, num_actions, h, w = pi.shape 84 | pi_reshape = pi.permute(0, 2, 3, 1).contiguous().view(-1, num_actions) 85 | m = Categorical(pi_reshape.data) 86 | actions = m.sample() 87 | 88 | log_pi_reshape = torch.log(torch.clamp(pi_reshape, min=1e-9, max=1-1e-9)) 89 | entropy = -torch.sum(pi_reshape * log_pi_reshape, dim=-1).view(n, 1, h, w) 90 | 91 | selected_log_prob = torch.gather(log_pi_reshape, 1, Variable(actions.unsqueeze(-1))).view(n, 1, h, w) 92 | 93 | actions = actions.view(n, h, w) 94 | 95 | return actions, entropy, selected_log_prob 96 | 97 | actions, entropy, selected_log_prob = randomly_choose_actions(pi) 98 | 99 | self.past_action_log_prob[self.t] = selected_log_prob 100 | self.past_action_entropy[self.t] = entropy 101 | self.past_values[self.t] = value 102 | self.t += 1 103 | return actions.cpu().numpy() 104 | 105 | def act(self, pi, deterministic=True): 106 | if deterministic: 107 | _, actions = torch.max(pi.data, dim=1) 108 | else: 109 | pi = torch.clamp(pi.data, min=0) 110 | n, num_actions, h, w = pi.shape 111 | pi_reshape = pi.permute(0, 2, 3, 1).contiguous().view(-1, num_actions) 112 | m = Categorical(pi_reshape) 113 | actions = m.sample() 114 | actions = actions.view(n, h, w) 115 | 116 | return actions.cpu().numpy() 117 | 118 | def stop_episode_and_compute_loss(self, reward, done=False): 119 | self.past_rewards[self.t - 1] = reward 120 | if done: 121 | losses = self.compute_loss() 122 | else: 123 | raise Exception 124 | self.reset() 125 | return losses 126 | -------------------------------------------------------------------------------- 
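As a reading aid, here is a rough sketch of how `model.py`, `env.py`, and `pixel_wise_a2c.py` interact during one training episode. It is not code from the repo: the input data is random, the transpose passed to `env.set_param` assumes the channel ordering of `parse_p` matches the sorted parameter names, and the optimizer/GPU handling is omitted (see `train.py` for the real loop):

```python
# Hypothetical single-episode walkthrough (CPU, random data, no optimizer step).
import numpy as np
import torch
from torch.autograd import Variable

from env import Env
from model import MyFcn
from pixel_wise_a2c import PixelWiseA2C
from MICCAI.config import config  # assumption: run from the repository root

model = MyFcn(config)
a2c = PixelWiseA2C(config)
env = Env(config)

# stand-in batch: 2 ground-truth images and their degraded versions
ori = np.random.rand(2, 1, 64, 64).astype(np.float32)
img = np.clip(ori + 0.1 * np.random.randn(2, 1, 64, 64).astype(np.float32), 0, 1)
env.reset(ori_image=ori, image=img)

reward = Variable(torch.zeros(2, 1, 64, 64))
for t in range(config.episode_len):
    pi, v, p = model(Variable(torch.from_numpy(env.image)))
    env.set_param(p.data.numpy().transpose(1, 0))  # one row per filter parameter (assumed ordering)
    actions = a2c.act_and_train(pi, v, reward)     # pixel-wise action map, shape (2, 64, 64)
    _, r = env.step(actions)                       # apply the chosen classic filters
    reward = Variable(torch.from_numpy(r).float())
losses = a2c.stop_episode_and_compute_loss(reward, done=True)
print({k: float(l.data) for k, l in losses.items()})  # pi_loss, v_loss, entropy_loss
```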
/unet/unet_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class ConvBlock(nn.Module): 14 | """ 15 | A Convolutional Block that consists of two convolution layers each followed by 16 | instance normalization, relu activation and dropout. 17 | """ 18 | 19 | def __init__(self, in_chans, out_chans, drop_prob): 20 | """ 21 | Args: 22 | in_chans (int): Number of channels in the input. 23 | out_chans (int): Number of channels in the output. 24 | drop_prob (float): Dropout probability. 25 | """ 26 | super().__init__() 27 | 28 | self.in_chans = in_chans 29 | self.out_chans = out_chans 30 | self.drop_prob = drop_prob 31 | 32 | self.layers = nn.Sequential( 33 | nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1), 34 | nn.InstanceNorm2d(out_chans), 35 | nn.ReLU(), 36 | nn.Dropout2d(drop_prob), 37 | nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1), 38 | nn.InstanceNorm2d(out_chans), 39 | nn.ReLU(), 40 | nn.Dropout2d(drop_prob) 41 | ) 42 | 43 | def forward(self, input): 44 | """ 45 | Args: 46 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] 47 | 48 | Returns: 49 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] 50 | """ 51 | return self.layers(input) 52 | 53 | def __repr__(self): 54 | return f'ConvBlock(in_chans={self.in_chans}, out_chans={self.out_chans}, ' \ 55 | f'drop_prob={self.drop_prob})' 56 | 57 | 58 | class UnetModel(nn.Module): 59 | """ 60 | PyTorch implementation of a U-Net model. 61 | 62 | This is based on: 63 | Olaf Ronneberger, Philipp Fischer, and Thomas Brox. U-net: Convolutional networks 64 | for biomedical image segmentation. In International Conference on Medical image 65 | computing and computer-assisted intervention, pages 234–241. Springer, 2015. 66 | """ 67 | 68 | def __init__(self, in_chans, out_chans, chans, num_pool_layers, drop_prob): 69 | """ 70 | Args: 71 | in_chans (int): Number of channels in the input to the U-Net model. 72 | out_chans (int): Number of channels in the output to the U-Net model. 73 | chans (int): Number of output channels of the first convolution layer. 74 | num_pool_layers (int): Number of down-sampling and up-sampling layers. 75 | drop_prob (float): Dropout probability. 
76 | """ 77 | super().__init__() 78 | 79 | self.in_chans = in_chans 80 | self.out_chans = out_chans 81 | self.chans = chans 82 | self.num_pool_layers = num_pool_layers 83 | self.drop_prob = drop_prob 84 | 85 | self.down_sample_layers = nn.ModuleList([ConvBlock(in_chans, chans, drop_prob)]) 86 | ch = chans 87 | for i in range(num_pool_layers - 1): 88 | self.down_sample_layers += [ConvBlock(ch, ch * 2, drop_prob)] 89 | ch *= 2 90 | self.conv = ConvBlock(ch, ch, drop_prob) 91 | 92 | self.up_sample_layers = nn.ModuleList() 93 | for i in range(num_pool_layers - 1): 94 | self.up_sample_layers += [ConvBlock(ch * 2, ch // 2, drop_prob)] 95 | ch //= 2 96 | self.up_sample_layers += [ConvBlock(ch * 2, ch, drop_prob)] 97 | self.conv2 = nn.Sequential( 98 | nn.Conv2d(ch, ch // 2, kernel_size=1), 99 | nn.Conv2d(ch // 2, out_chans, kernel_size=1), 100 | nn.Conv2d(out_chans, out_chans, kernel_size=1), 101 | ) 102 | 103 | def forward(self, input): 104 | """ 105 | Args: 106 | input (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width] 107 | 108 | Returns: 109 | (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width] 110 | """ 111 | stack = [] 112 | output = input 113 | # Apply down-sampling layers 114 | for layer in self.down_sample_layers: 115 | output = layer(output) 116 | stack.append(output) 117 | output = F.max_pool2d(output, kernel_size=2) 118 | 119 | output = self.conv(output) 120 | 121 | # Apply up-sampling layers 122 | for layer in self.up_sample_layers: 123 | #output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=False) 124 | output = F.upsample(output, scale_factor=2, mode='bilinear') 125 | output = torch.cat([output, stack.pop()], dim=1) 126 | output = layer(output) 127 | return self.conv2(output) 128 | -------------------------------------------------------------------------------- /fastMRI/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from utils import Downsample 7 | from mri_data import SliceData 8 | from subsample import MaskFunc 9 | from transforms import normalize_instance, to_tensor 10 | 11 | class DataTransform: 12 | """ 13 | Data Transformer for training U-Net models. 14 | """ 15 | 16 | def __init__(self, mask_func, resolution, which_challenge, use_seed=True, normalize=False): 17 | """ 18 | Args: 19 | mask_func (common.subsample.MaskFunc): A function that can create a mask of 20 | appropriate shape. 21 | resolution (int): Resolution of the image. 22 | which_challenge (str): Either "singlecoil" or "multicoil" denoting the dataset. 23 | use_seed (bool): If true, this class computes a pseudo random number generator seed 24 | from the filename. This ensures that the same mask is used for all the slices of 25 | a given volume every time. 26 | """ 27 | if which_challenge not in ('singlecoil', 'multicoil'): 28 | raise ValueError(f'Challenge should either be "singlecoil" or "multicoil"') 29 | self.mask_func = mask_func 30 | self.resolution = resolution 31 | self.which_challenge = which_challenge 32 | self.use_seed = use_seed 33 | self.normalize = normalize 34 | 35 | def __call__(self, kspace, target, attrs, fname, slice): 36 | """ 37 | Args: 38 | kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil 39 | data or (rows, cols, 2) for single coil data. 
40 |             target (numpy.array): Target image
41 |             attrs (dict): Acquisition related information stored in the HDF5 object.
42 |             fname (str): File name
43 |             slice (int): Serial number of the slice.
44 |         Returns:
45 |             (tuple): tuple containing:
46 |                 target (torch.Tensor): Target image converted to a torch Tensor.
47 |                 image (torch.Tensor): Zero-filled input image.
48 |                 mask (torch.Tensor): The sampling mask.
49 |         """
50 | 
51 |         # this is the original normalization method from the official fastMRI code
52 |         def normalize_image(x):
53 |             x, mean, std = normalize_instance(x, eps=1e-11)
54 |             x = x.clip(-6, 6)
55 |             return x
56 | 
57 |         if target is not None:
58 |             target = normalize_image(target)
59 |         else:
60 |             # fallback when no ground truth is provided; keeps the shapes below valid
61 |             target = np.zeros((self.resolution, self.resolution), dtype=np.float32)
62 | 
63 |         # Apply mask
64 |         seed = None if not self.use_seed else tuple(map(ord, fname))
65 |         mask = self.mask_func(target.shape + (2,), seed)
66 |         mask = mask[:, :, 0].numpy()
67 | 
68 |         m = min(float(np.min(target)), 0)
69 |         target_01 = (target - m) / (6 - m) # normalization into the range [0, 1]
70 |         image, _, _ = Downsample(target_01, mask)
71 |         if self.normalize:
72 |             target = target_01
73 |         else:
74 |             image = image * (6 - m) + m # for unet, to scale back
75 |         #else:
76 |         #    image, _, _ = Downsample(target - m, mask) # make sure that the data are non-negative before downsampling
77 |         #    image += m
78 | 
79 |         target = to_tensor(target)
80 |         image = to_tensor(image)
81 |         mask = to_tensor(mask)
82 |         return target.unsqueeze(0).float(), image.unsqueeze(0).float(), mask.float()
83 | 
84 | 
85 | def MRIDataset(image_set, transform, config):
86 |     '''
87 |     transform: rescale the image into [0, 1]
88 |         For our model, set transform True.
89 |         For unet, set transform False.
90 |     No data augmentation is implemented for fastMRI.
91 |     '''
92 | 
93 |     train_mask = MaskFunc(*config.sampling_scheme)
94 | 
95 |     if image_set == 'train':
96 |         root = config.root + '/singlecoil_train'
97 |     elif image_set == 'test':
98 |         root = config.root + '/singlecoil_val'
99 |     dataset = SliceData(
100 |         root=root,
101 |         transform=DataTransform(train_mask, config.resolution, 'singlecoil', normalize=transform),
102 |         sample_rate=1.,
103 |         challenge='singlecoil',
104 |     )
105 |     return dataset
106 | 
107 | if __name__ == '__main__':
108 |     from config import config
109 |     test_loader = DataLoader(
110 |         dataset=MRIDataset('test', True, config),
111 |         batch_size=1,
112 |         shuffle=False,
113 |         num_workers=1,
114 |         pin_memory=True,
115 |     )
116 | 
117 |     import cv2
118 |     for _, data in enumerate(test_loader):
119 |         ori_image, image, _ = data
120 |         image = image.numpy()
121 |         target = ori_image.numpy()
122 |         # cv2.imshow expects uint8 (or floats in [0, 1]), so convert before display
123 |         cv2.imshow('test', (image[0, 0] * 255).astype(np.uint8))
124 |         cv2.imshow('test_ori', (target[0, 0] * 255).astype(np.uint8))
125 |         cv2.waitKey(0)
126 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import argparse
5 | import numpy as np
6 | import cv2
7 | from collections import defaultdict
8 | 
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import torch.optim as optim
13 | from torch.autograd import Variable
14 | 
15 | from tensorboardX import SummaryWriter
16 | 
17 | from env import Env
18 | from model import MyFcn
19 | from pixel_wise_a2c import PixelWiseA2C
20 | from utils import PSNR, SSIM, NMSE, DC, computePSNR, computeSSIM, computeNMSE
21 | 
22 | def parse():
23 | 
24 |     parser = argparse.ArgumentParser()
25 |     parser.add_argument('--dataset', default='MICCAI', type=str,
26 |                         dest='dataset', help='directory whose dataset.py and config.py should be used')
27 |     parser.add_argument('--gpu', default=[0, 1], nargs='+', type=int,
28 |                         dest='gpu', help='the gpu used')
29 |     parser.add_argument('--model', type=str, help='file of the trained model')
30 | 
31 |     return parser.parse_args()
32 | 
33 | 
34 | def test(model, a2c, config, early_break=True, batch_size=None, verbose=False):
35 |     if batch_size is None:
36 |         batch_size = config.batch_size
37 |     env = Env(config)
38 |     if not os.path.exists('results/'):
39 |         os.mkdir('results/')
40 | 
41 |     from dataset import MRIDataset
42 |     test_loader = torch.utils.data.DataLoader(
43 |         dataset = MRIDataset(image_set='test', transform=(config.dataset=='fastMRI_data'), config=config),
44 |         batch_size=batch_size, shuffle=False,
45 |         num_workers=1, pin_memory=True)
46 | 
47 |     reward_sum = 0
48 |     p_list = defaultdict(list)
49 |     PSNR_dict = defaultdict(list)
50 |     SSIM_dict = defaultdict(list)
51 |     NMSE_dict = defaultdict(list)
52 |     count = 0
53 |     actions_prob = np.zeros((config.num_actions, config.episode_len))
54 |     image_history = dict()
55 | 
56 |     for i, (ori_image, image, mask) in enumerate(test_loader):
57 |         count += 1
58 |         if early_break and count == 101: # test only part of the dataset
59 |             break
60 |         if count % 100 == 0:
61 |             print('tested: ', count)
62 | 
63 |         ori_image = ori_image.numpy()
64 |         image = image.numpy()
65 |         previous_image = image.copy()
66 |         env.reset(ori_image=ori_image, image=image)
67 | 
68 |         for t in range(config.episode_len):
69 |             if verbose:
70 |                 image_history[t] = image
71 |             image_input = Variable(torch.from_numpy(image).cuda(), volatile=True)
72 |             pi_out, v_out, p = model(image_input, flag_a2c=True)
73 | 
74 |             p = p.permute(1, 0).cpu().data.numpy()
75 |             env.set_param(p)
76 |             p_list[t].append(p)
77 | 
78 |             actions = a2c.act(pi_out, deterministic=True)
79 |             last_image = image.copy()
80 |             image, reward = env.step(actions)
81 |             image = np.clip(image, 0, 1)
82 | 
83 |             reward_sum += np.mean(reward)
84 | 
85 |             actions = actions.astype(np.uint8)
86 |             prob = pi_out.cpu().data.numpy()
87 |             total = actions.size
88 |             for n in range(config.num_actions):
89 |                 actions_prob[n, t] += np.sum(actions==n) / total
90 | 
91 |             # draw action distribution on pixels
92 |             for j in range(ori_image.shape[0]):
93 |                 if i > 0:
94 |                     break
95 |                 for dd in ['results/actions/', 'results/time_steps/']:
96 |                     if not os.path.exists(dd + str(j)):
97 |                         os.makedirs(dd + str(j))  # also creates the parent directory
98 |                 a = actions[j].astype(np.uint8)
99 |                 total = a.size
100 |                 canvas = last_image[j, 0].copy()
101 |                 unchanged_mask = np.abs(last_image[j, 0] - image[j, 0]) < (1 / 255) # some pixel values are not changed
102 |                 for n in range(config.num_actions):
103 |                     A = np.tile(canvas[..., np.newaxis], (1, 1, 3)) * 255
104 |                     a_mask = (a == n) & ~unchanged_mask
105 |                     A[a_mask, 2] += 250
106 |                     cv2.imwrite('results/actions/' + str(j) + '/' + str(n) + '_' + str(t) +'.bmp', A)
107 |                 cv2.imwrite('results/actions/' + str(t) + '_unchanged.jpg', np.abs(last_image[j, 0] - image[j, 0]) * 255 * 255)
108 | 
109 |         for j in range(image.shape[0]):
110 |             if 'fastMRI' in config.dataset:
111 |                 mask_j = mask.numpy()[j]
112 |                 mask_j = np.tile(mask_j, (image.shape[2], 1))
113 |             else:
114 |                 mask_j = test_loader.dataset.mask
115 |             image_with_DC = DC(ori_image[j, 0], image[j, 0], mask_j)
116 |             image_with_DC = np.clip(image_with_DC, 0, 1)
117 |             for k in range(2):
118 |                 key = ['wo', 'DC'][k]
119 |                 tmp_image = [image[j, 0], image_with_DC][k]
120 |                 PSNR_dict[key].append(computePSNR(ori_image[j, 0], previous_image[j, 0], tmp_image))
121 |                 SSIM_dict[key].append(computeSSIM(ori_image[j, 0], previous_image[j, 0], tmp_image))
122 |                 NMSE_dict[key].append(computeNMSE(ori_image[j, 0], previous_image[j, 0], tmp_image))
123 |                 if verbose:
124 |                     print(j, key, PSNR_dict[key][-1], SSIM_dict[key][-1], NMSE_dict[key][-1])
125 | 
126 |             # draw input, output and error maps
127 |             cv2.imwrite('results/'+str(i)+'_'+str(j)+'.bmp', np.concatenate((ori_image[j, 0], mask_j, previous_image[j, 0], image[j, 0], image_with_DC, np.abs(ori_image[j, 0] - image[j, 0]) * 10), axis=1) * 255)
128 |             # draw output of different timesteps
129 |             if verbose:
130 |                 cv2.imwrite('results/time_steps/'+str(i)+'_'+str(j)+'.bmp', np.concatenate([image_history[jj][j, 0] for jj in range(config.episode_len)] + [image[j, 0], image_with_DC, ori_image[j, 0]], axis=1) * 255)
131 | 
132 |     print('actions_prob', actions_prob / count)
133 | 
134 |     for key in PSNR_dict.keys():
135 |         PSNR_list, SSIM_list, NMSE_list = map(lambda x: x[key], [PSNR_dict, SSIM_dict, NMSE_dict])
136 |         print('number of test images: ', len(PSNR_list))
137 |         psnr_res = np.mean(np.array(PSNR_list), axis=0)
138 |         ssim_res = np.mean(np.array(SSIM_list), axis=0)
139 |         nmse_res = np.mean(np.array(NMSE_list), axis=0)
140 | 
141 |         print('PSNR', psnr_res)
142 |         print('SSIM', ssim_res)
143 |         print('NMSE', nmse_res)
144 | 
145 |     for t in range(config.episode_len):
146 |         print('parameters at {}: '.format(t), np.mean(np.concatenate(p_list[t], axis=1), axis=1))
147 | 
148 |     avg_reward = reward_sum / (i + 1)
149 |     print('test finished: reward ', avg_reward)
150 | 
151 |     return avg_reward, psnr_res, ssim_res
152 | 
153 | 
154 | if __name__ == "__main__":
155 |     args = parse()
156 |     sys.path.append(args.dataset)
157 |     from
config import config 158 | 159 | torch.backends.cudnn.benchmark = True 160 | 161 | env = Env(config) 162 | model = MyFcn(config) 163 | model.load_state_dict(torch.load(args.model)) 164 | model = torch.nn.DataParallel(model, device_ids=args.gpu).cuda() 165 | a2c = PixelWiseA2C(config) 166 | 167 | avg_reward, psnr_res, ssim_res = test(model, a2c, config, early_break=False, batch_size=50, verbose=True) 168 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | import numpy as np 6 | import cv2 7 | from collections import defaultdict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.autograd import Variable 14 | 15 | from tensorboardX import SummaryWriter 16 | 17 | from env import Env 18 | from model import MyFcn 19 | from pixel_wise_a2c import PixelWiseA2C 20 | from test import test 21 | 22 | from utils import adjust_learning_rate 23 | from utils import PSNR, SSIM, NMSE, DC, computePSNR, computeSSIM, computeNMSE 24 | 25 | def parse(): 26 | 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--dataset', default='MICCAI', type=str, 29 | dest='dataset', help='to use dataset.py and config.py in which directory') 30 | parser.add_argument('--gpu', default=[0, 1], nargs='+', type=int, 31 | dest='gpu', help='the gpu used') 32 | 33 | return parser.parse_args() 34 | 35 | 36 | def train(): 37 | torch.backends.cudnn.benchmark = False 38 | 39 | # load config 40 | args = parse() 41 | sys.path.append(args.dataset) 42 | from config import config 43 | assert config.switch % config.iter_size == 0 44 | time_tuple = time.localtime(time.time()) 45 | log_dir = './logs/' + '_'.join(map(lambda x: str(x), time_tuple[1:4])) 46 | print('log_dir: ', log_dir) 47 | writer = SummaryWriter(log_dir) 48 | if not os.path.exists('model/'): 49 | os.mkdir('model/') 50 | 51 | # dataset 52 | from dataset import MRIDataset 53 | train_loader = torch.utils.data.DataLoader( 54 | dataset = MRIDataset(image_set='train', transform=True, config=config), 55 | batch_size=config.batch_size, shuffle=True, 56 | num_workers=config.workers, pin_memory=True) 57 | 58 | env = Env(config) 59 | a2c = PixelWiseA2C(config) 60 | 61 | episodes = 0 62 | model = MyFcn(config) 63 | if len(config.resume_model) > 0: # resume training 64 | model.load_state_dict(torch.load(config.resume_model)) 65 | episodes = int(config.resume_model.split('.')[0].split('_')[-1]) 66 | print('resume from episodes {}'.format(episodes)) 67 | model = torch.nn.DataParallel(model, device_ids=args.gpu).cuda() 68 | 69 | # construct optimizers for a2c and ddpg 70 | parameters_wo_p = [value for key, value in dict(model.module.named_parameters()).items() if '_p.' not in key] 71 | optimizer = torch.optim.Adam(parameters_wo_p, config.base_lr) 72 | 73 | parameters_p = [value for key, value in dict(model.module.named_parameters()).items() if '_p.' in key] 74 | #parameters_pi = [value for key, value in dict(model.module.named_parameters()).items() if '_pi.' 
in key] 75 | params = [ 76 | {'params': parameters_p, 'lr': config.base_lr}, 77 | ] 78 | optimizer_p = torch.optim.SGD(params, config.base_lr) 79 | 80 | # training 81 | flag_a2c = True # if True, use a2c; if False, use ddpg 82 | while episodes < config.num_episodes: 83 | 84 | for i, (ori_image, image, _) in enumerate(train_loader): 85 | # adjust learning rate 86 | learning_rate = adjust_learning_rate(optimizer, episodes, config.base_lr, policy=config.lr_policy, policy_parameter=config.policy_parameter) 87 | _ = adjust_learning_rate(optimizer_p, episodes, config.base_lr, policy=config.lr_policy, policy_parameter=config.policy_parameter) 88 | 89 | ori_image = ori_image.numpy() 90 | image = image.numpy() 91 | env.reset(ori_image=ori_image, image=image) 92 | reward = np.zeros((1)) 93 | 94 | # forward 95 | if not flag_a2c: 96 | v_out_dict = dict() 97 | for t in range(config.episode_len): 98 | image_input = Variable(torch.from_numpy(image).cuda()) 99 | reward_input = Variable(torch.from_numpy(reward).cuda()) 100 | pi_out, v_out, p = model(image_input, flag_a2c, add_noise=flag_a2c) 101 | if flag_a2c: 102 | actions = a2c.act_and_train(pi_out, v_out, reward_input) 103 | else: 104 | v_out_dict[t] = - v_out.mean() 105 | actions = a2c.act(pi_out, deterministic=True) 106 | 107 | p = p.cpu().data.numpy().transpose(1, 0) 108 | env.set_param(p) 109 | previous_image = image 110 | image, reward = env.step(actions) 111 | 112 | if not(episodes % config.display): 113 | print('\na2c: ', flag_a2c) 114 | print('episode {}: reward@{} = {:.4f}'.format(episodes, t, np.mean(reward))) 115 | for k, v in env.parameters.items(): 116 | print(k, ' parameters: ', v.mean()) 117 | print("PSNR: {:.5f} -> {:.5f}".format(*computePSNR(ori_image[0, 0], previous_image[0, 0], image[0, 0]))) 118 | print("SSIM: {:.5f} -> {:.5f}".format(*computeSSIM(ori_image[0, 0], previous_image[0, 0], image[0, 0]))) 119 | 120 | image = np.clip(image, 0, 1) 121 | 122 | 123 | # compute loss and backpropagate 124 | if flag_a2c: 125 | losses = a2c.stop_episode_and_compute_loss(reward=Variable(torch.from_numpy(reward).cuda()), done=True) 126 | loss = sum(losses.values()) / config.iter_size 127 | loss.backward() 128 | else: 129 | loss = sum(v_out_dict.values()) * config.c_loss_coeff / config.iter_size 130 | loss.backward() 131 | 132 | if not(episodes % config.display): 133 | print('\na2c: ', flag_a2c) 134 | print('episode {}: loss = {}'.format(episodes, float(loss.data))) 135 | 136 | # update model and write into tensorboard 137 | if not(episodes % config.iter_size): 138 | if flag_a2c: 139 | optimizer.step() 140 | optimizer.zero_grad() 141 | optimizer_p.zero_grad() 142 | else: 143 | optimizer_p.step() 144 | optimizer_p.zero_grad() 145 | optimizer.zero_grad() 146 | for l in v_out_dict.keys(): 147 | writer.add_scalar('v_out_{}'.format(l), float(v_out_dict[l].cpu().data.numpy()), episodes) 148 | 149 | for l in losses.keys(): 150 | writer.add_scalar(l, float(losses[l].cpu().data.numpy()), episodes) 151 | writer.add_scalar('lr', float(learning_rate), episodes) 152 | for k, v in env.parameters.items(): 153 | writer.add_scalar(k, float(v.mean()), episodes) 154 | 155 | if not(episodes % config.switch): 156 | flag_a2c = not flag_a2c 157 | if episodes < config.warm_up_episodes: 158 | flag_a2c = True 159 | 160 | episodes += 1 161 | 162 | # save model 163 | if not(episodes % config.save_episodes): 164 | torch.save(model.module.state_dict(), 'model/' + '_'.join(map(lambda x: str(x), time_tuple[1:4])) + '_' + str(episodes) + '.pth') 165 | print('model saved') 166 | 167 
| # test model
168 |             if not(episodes % config.test_episodes):
169 |                 avg_reward, psnr_res, ssim_res = test(model, a2c, config, batch_size=10)
170 |                 writer.add_scalar('test reward', avg_reward, episodes)
171 |                 writer.add_scalar('test psnr', psnr_res[1], episodes)
172 |                 writer.add_scalar('test ssim', ssim_res[1], episodes)
173 | 
174 |             if episodes == config.num_episodes:
175 |                 writer.close()
176 |                 break
177 | 
178 | if __name__ == "__main__":
179 |     train()
180 | 
--------------------------------------------------------------------------------
/fastMRI/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | 
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 | 
8 | import numpy as np
9 | import torch
10 | 
11 | 
12 | def to_tensor(data):
13 |     """
14 |     Convert numpy array to PyTorch tensor. For complex arrays, the real and imaginary parts
15 |     are stacked along the last dimension.
16 | 
17 |     Args:
18 |         data (np.array): Input numpy array
19 | 
20 |     Returns:
21 |         torch.Tensor: PyTorch version of data
22 |     """
23 |     if np.iscomplexobj(data):
24 |         data = np.stack((data.real, data.imag), axis=-1)
25 |     return torch.from_numpy(data)
26 | 
27 | 
28 | def apply_mask(data, mask_func, seed=None):
29 |     """
30 |     Subsample given k-space by multiplying with a mask.
31 | 
32 |     Args:
33 |         data (torch.Tensor): The input k-space data. This should have at least 3 dimensions, where
34 |             dimensions -3 and -2 are the spatial dimensions, and the final dimension has size
35 |             2 (for complex values).
36 |         mask_func (callable): A function that takes a shape (tuple of ints) and a random
37 |             number seed and returns a mask.
38 |         seed (int or 1-d array_like, optional): Seed for the random number generator.
39 | 
40 |     Returns:
41 |         mask (torch.Tensor): The generated mask. (Modified 0830: the official
42 |             fastMRI code returned the subsampled k-space together with the mask;
43 |             this version returns only the mask.)
44 |     """
45 |     shape = np.array(data.shape)
46 |     shape[:-3] = 1
47 |     mask = mask_func(shape, seed)
48 |     # original fastMRI behaviour, disabled here:
49 |     # return torch.where(mask == 0, torch.Tensor([0]), data), mask
50 |     return mask
51 | 
52 | 
53 | def fft2(data):
54 |     """
55 |     Apply centered 2-dimensional Fast Fourier Transform.
56 | 
57 |     Args:
58 |         data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
59 |             -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
60 |             assumed to be batch dimensions.
61 | 
62 |     Returns:
63 |         torch.Tensor: The FFT of the input.
64 |     """
65 |     assert data.size(-1) == 2
66 |     data = ifftshift(data, dim=(-3, -2))
67 |     data = torch.fft(data, 2, normalized=True)
68 |     data = fftshift(data, dim=(-3, -2))
69 |     return data
70 | 
71 | 
72 | def ifft2(data):
73 |     """
74 |     Apply centered 2-dimensional Inverse Fast Fourier Transform.
75 | 
76 |     Args:
77 |         data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
78 |             -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
79 |             assumed to be batch dimensions.
80 | 
81 |     Returns:
82 |         torch.Tensor: The IFFT of the input.
83 |     """
84 |     assert data.size(-1) == 2
85 |     data = ifftshift(data, dim=(-3, -2))
86 |     data = torch.ifft(data, 2, normalized=True)
87 |     data = fftshift(data, dim=(-3, -2))
88 |     return data
89 | 
90 | 
91 | def complex_abs(data):
92 |     """
93 |     Compute the absolute value of a complex valued input tensor.
94 | 95 | Args: 96 | data (torch.Tensor): A complex valued tensor, where the size of the final dimension 97 | should be 2. 98 | 99 | Returns: 100 | torch.Tensor: Absolute value of data 101 | """ 102 | assert data.size(-1) == 2 103 | return (data ** 2).sum(dim=-1).sqrt() 104 | 105 | 106 | def root_sum_of_squares(data, dim=0): 107 | """ 108 | Compute the Root Sum of Squares (RSS) transform along a given dimension of a tensor. 109 | 110 | Args: 111 | data (torch.Tensor): The input tensor 112 | dim (int): The dimensions along which to apply the RSS transform 113 | 114 | Returns: 115 | torch.Tensor: The RSS value 116 | """ 117 | return torch.sqrt((data ** 2).sum(dim)) 118 | 119 | 120 | def center_crop(data, shape): 121 | """ 122 | Apply a center crop to the input real image or batch of real images. 123 | 124 | Args: 125 | data (torch.Tensor): The input tensor to be center cropped. It should have at 126 | least 2 dimensions and the cropping is applied along the last two dimensions. 127 | shape (int, int): The output shape. The shape should be smaller than the 128 | corresponding dimensions of data. 129 | 130 | Returns: 131 | torch.Tensor: The center cropped image 132 | """ 133 | assert 0 < shape[0] <= data.shape[-2] 134 | assert 0 < shape[1] <= data.shape[-1] 135 | w_from = (data.shape[-2] - shape[0]) // 2 136 | h_from = (data.shape[-1] - shape[1]) // 2 137 | w_to = w_from + shape[0] 138 | h_to = h_from + shape[1] 139 | return data[..., w_from:w_to, h_from:h_to] 140 | 141 | 142 | def complex_center_crop(data, shape): 143 | """ 144 | Apply a center crop to the input image or batch of complex images. 145 | 146 | Args: 147 | data (torch.Tensor): The complex input tensor to be center cropped. It should 148 | have at least 3 dimensions and the cropping is applied along dimensions 149 | -3 and -2 and the last dimensions should have a size of 2. 150 | shape (int, int): The output shape. The shape should be smaller than the 151 | corresponding dimensions of data. 152 | 153 | Returns: 154 | torch.Tensor: The center cropped image 155 | """ 156 | assert 0 < shape[0] <= data.shape[-3] 157 | assert 0 < shape[1] <= data.shape[-2] 158 | w_from = (data.shape[-3] - shape[0]) // 2 159 | h_from = (data.shape[-2] - shape[1]) // 2 160 | w_to = w_from + shape[0] 161 | h_to = h_from + shape[1] 162 | return data[..., w_from:w_to, h_from:h_to, :] 163 | 164 | 165 | def normalize(data, mean, stddev, eps=0.): 166 | """ 167 | Normalize the given tensor using: 168 | (data - mean) / (stddev + eps) 169 | 170 | Args: 171 | data (torch.Tensor): Input data to be normalized 172 | mean (float): Mean value 173 | stddev (float): Standard deviation 174 | eps (float): Added to stddev to prevent dividing by zero 175 | 176 | Returns: 177 | torch.Tensor: Normalized tensor 178 | """ 179 | return (data - mean) / (stddev + eps) 180 | 181 | 182 | def normalize_instance(data, eps=0.): 183 | """ 184 | Normalize the given tensor using: 185 | (data - mean) / (stddev + eps) 186 | where mean and stddev are computed from the data itself. 
187 | 188 | Args: 189 | data (torch.Tensor): Input data to be normalized 190 | eps (float): Added to stddev to prevent dividing by zero 191 | 192 | Returns: 193 | torch.Tensor: Normalized tensor 194 | """ 195 | mean = data.mean() 196 | std = data.std() 197 | return normalize(data, mean, std, eps), mean, std 198 | 199 | 200 | # Helper functions 201 | 202 | def roll(x, shift, dim): 203 | """ 204 | Similar to np.roll but applies to PyTorch Tensors 205 | """ 206 | if isinstance(shift, (tuple, list)): 207 | assert len(shift) == len(dim) 208 | for s, d in zip(shift, dim): 209 | x = roll(x, s, d) 210 | return x 211 | shift = shift % x.size(dim) 212 | if shift == 0: 213 | return x 214 | left = x.narrow(dim, 0, x.size(dim) - shift) 215 | right = x.narrow(dim, x.size(dim) - shift, shift) 216 | return torch.cat((right, left), dim=dim) 217 | 218 | 219 | def fftshift(x, dim=None): 220 | """ 221 | Similar to np.fft.fftshift but applies to PyTorch Tensors 222 | """ 223 | if dim is None: 224 | dim = tuple(range(x.dim())) 225 | shift = [dim // 2 for dim in x.shape] 226 | elif isinstance(dim, int): 227 | shift = x.shape[dim] // 2 228 | else: 229 | shift = [x.shape[i] // 2 for i in dim] 230 | return roll(x, shift, dim) 231 | 232 | 233 | def ifftshift(x, dim=None): 234 | """ 235 | Similar to np.fft.ifftshift but applies to PyTorch Tensors 236 | """ 237 | if dim is None: 238 | dim = tuple(range(x.dim())) 239 | shift = [(dim + 1) // 2 for dim in x.shape] 240 | elif isinstance(dim, int): 241 | shift = (x.shape[dim] + 1) // 2 242 | else: 243 | shift = [(x.shape[i] + 1) // 2 for i in dim] 244 | return roll(x, shift, dim) 245 | -------------------------------------------------------------------------------- /unet/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Facebook, Inc. and its affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import logging 9 | import pathlib 10 | import random 11 | import shutil 12 | import time 13 | from collections import defaultdict 14 | 15 | import numpy as np 16 | import cv2 17 | import torch 18 | import torchvision 19 | from tensorboardX import SummaryWriter 20 | from torch.nn import functional as F 21 | from torch.utils.data import DataLoader 22 | from torch.autograd import Variable 23 | 24 | import sys 25 | sys.path.append('..') 26 | from utils import PSNR, SSIM, NMSE, DC, computePSNR, computeSSIM, computeNMSE 27 | 28 | from unet_model import UnetModel 29 | from args import Args 30 | sys.path.append('../fastMRI/') 31 | from subsample import MaskFunc 32 | from dataset import MRIDataset 33 | 34 | logging.basicConfig(level=logging.INFO) 35 | logger = logging.getLogger(__name__) 36 | 37 | 38 | def create_datasets(args): 39 | from config import config 40 | train_data = MRIDataset(image_set='train', transform=False, config=config) 41 | dev_data = MRIDataset(image_set='test', transform=False, config=config) 42 | return dev_data, train_data 43 | 44 | 45 | def create_data_loaders(args): 46 | dev_data, train_data = create_datasets(args) 47 | display_data = []#[dev_data[i] for i in range(0, len(dev_data), len(dev_data) // 16)] 48 | 49 | train_loader = DataLoader( 50 | dataset=train_data, 51 | batch_size=args.batch_size, 52 | shuffle=True, 53 | num_workers=0, 54 | pin_memory=True, 55 | ) 56 | dev_loader = DataLoader( 57 | dataset=dev_data, 58 | #batch_size=args.batch_size, 59 | batch_size=10, 60 | num_workers=1, 61 | pin_memory=False, 62 | ) 63 | #display_loader = DataLoader( 64 | # dataset=display_data, 65 | # batch_size=16, 66 | # num_workers=8, 67 | # pin_memory=True, 68 | #) 69 | return train_loader, dev_loader, None#display_loader 70 | 71 | 72 | def train_epoch(args, epoch, model, data_loader, optimizer, writer): 73 | model.train() 74 | avg_loss = 0. 
75 |     start_epoch = start_iter = time.perf_counter()
76 |     global_step = epoch * len(data_loader)
77 |     iter = 0
78 |     while True:
79 |         for _, data in enumerate(data_loader):
80 |             iter += 1
81 |             target, input, _ = data
82 |             input = Variable(input).cuda()
83 |             target = Variable(target).cuda()
84 | 
85 |             output = model(input)
86 |             loss = F.l1_loss(output, target)
87 |             optimizer.zero_grad()
88 |             loss.backward()
89 |             optimizer.step()
90 | 
91 |             avg_loss = 0.99 * avg_loss + 0.01 * float(loss.data) if iter > 1 else float(loss.data)  # EMA of the batch loss; iter starts at 1
92 |             writer.add_scalar('TrainLoss', float(loss.data), global_step + iter)
93 | 
94 |             if iter % args.report_interval == 0:
95 |                 logging.info(
96 |                     f'Epoch = [{epoch:3d}/{args.num_epochs:3d}] '
97 |                     f'Iter = [{iter:4d}/{len(data_loader):4d}] '
98 |                     f'Loss = {float(loss.data):.4g} Avg Loss = {avg_loss:.4g} '
99 |                     f'Time = {time.perf_counter() - start_iter:.4f}s',
100 |                 )
101 |             start_iter = time.perf_counter()
102 |             if args.iters_per_epoch and iter >= args.iters_per_epoch:
103 |                 break
104 |         if not args.iters_per_epoch or iter >= args.iters_per_epoch:
105 |             break  # one pass over the data is one epoch unless a fixed iteration count is given
106 | 
107 |     return avg_loss, time.perf_counter() - start_epoch
108 | 
109 | 
110 | def evaluate(args, epoch, model, data_loader, writer):
111 |     model.eval()
112 |     losses = []
113 |     start = time.perf_counter()
114 |     early_break = True
115 |     pathlib.Path('unet_results').mkdir(parents=True, exist_ok=True)  # output directory for the error maps below
116 | 
117 |     PSNR_dict = defaultdict(list)
118 |     SSIM_dict = defaultdict(list)
119 |     NMSE_dict = defaultdict(list)
120 |     count = 0
121 | 
122 |     for i, (target, input, mask) in enumerate(data_loader):
123 |         ori_image = target.numpy()
124 |         previous_image = input.numpy()
125 |         mask = mask.numpy()
126 |         count += 1
127 |         if early_break and count == 101:
128 |             break
129 |         if count % 100 == 0:
130 |             print('tested: ', count)
131 | 
132 |         input = Variable(input, volatile=True).cuda()
133 |         target = Variable(target).cuda()
134 |         output = model(input)
135 | 
136 |         loss = F.mse_loss(output, target, size_average=False)
137 |         losses.append(float(loss.data))
138 | 
139 |         image = output.cpu().data.numpy()
140 |         for ii in range(image.shape[0]):
141 |             m = min(float(np.min(ori_image[ii, 0])), 0)
142 |             def rescale(x):
143 |                 return (x - m) / (6 - m)
144 |             ori_image[ii, 0] = rescale(ori_image[ii, 0])
145 |             previous_image[ii, 0] = rescale(previous_image[ii, 0])
146 |             image[ii, 0] = rescale(image[ii, 0])
147 |             image_with_DC = DC(ori_image[ii, 0], image[ii, 0], mask[ii])
148 | 
149 |             for k in range(2):
150 |                 key = ['wo', 'DC'][k]
151 |                 tmp_image = [image[ii, 0], image_with_DC][k]
152 |                 PSNR_dict[key].append(computePSNR(ori_image[ii, 0], previous_image[ii, 0], tmp_image))
153 |                 SSIM_dict[key].append(computeSSIM(ori_image[ii, 0], previous_image[ii, 0], tmp_image))
154 |                 NMSE_dict[key].append(computeNMSE(ori_image[ii, 0], previous_image[ii, 0], tmp_image))
155 | 
156 |             cv2.imwrite('unet_results/'+str(i)+'_'+str(ii)+'.bmp', np.concatenate((ori_image[ii, 0], previous_image[ii, 0], image[ii, 0], np.abs(ori_image[ii, 0] - image[ii, 0]) * 10), axis=1) * 255)
157 |         writer.add_scalar('Dev_Loss', np.mean(losses), epoch)
158 | 
159 |     for key in PSNR_dict.keys():
160 |         PSNR_list, SSIM_list, NMSE_list = map(lambda x: x[key], [PSNR_dict, SSIM_dict, NMSE_dict])
161 |         print('number of test images: ', len(PSNR_list))
162 |         psnr_res = np.mean(np.array(PSNR_list), axis=0)
163 |         ssim_res = np.mean(np.array(SSIM_list), axis=0)
164 |         nmse_res = np.mean(np.array(NMSE_list), axis=0)
165 | 
166 |         print('PSNR', psnr_res)
167 |         print('SSIM', ssim_res)
168 |         print('NMSE', nmse_res)
169 | 
170 |     return np.mean(losses), time.perf_counter() - start
171 | 
172 | 
173 | def save_model(args, exp_dir,
epoch, model, optimizer, best_dev_loss, is_new_best): 174 | torch.save( 175 | { 176 | 'epoch': epoch, 177 | 'args': args, 178 | 'model': model.state_dict(), 179 | 'optimizer': optimizer.state_dict(), 180 | 'best_dev_loss': best_dev_loss, 181 | 'exp_dir': exp_dir 182 | }, 183 | f=exp_dir / 'model.pt' 184 | ) 185 | if is_new_best: 186 | shutil.copyfile(exp_dir / 'model.pt', exp_dir / 'best_model.pt') 187 | 188 | 189 | def build_model(args): 190 | model = UnetModel( 191 | in_chans=1, 192 | out_chans=1, 193 | chans=args.num_chans, 194 | num_pool_layers=args.num_pools, 195 | drop_prob=args.drop_prob 196 | ).cuda()#to(args.device) 197 | return model 198 | 199 | 200 | def load_model(checkpoint_file): 201 | checkpoint = torch.load(checkpoint_file) 202 | args = checkpoint['args'] 203 | model = build_model(args) 204 | if args.data_parallel: 205 | model = torch.nn.DataParallel(model) 206 | model.load_state_dict(checkpoint['model']) 207 | 208 | optimizer = build_optim(args, model.parameters()) 209 | optimizer.load_state_dict(checkpoint['optimizer']) 210 | return checkpoint, model, optimizer 211 | 212 | 213 | def build_optim(args, params): 214 | optimizer = torch.optim.RMSprop(params, args.lr, weight_decay=args.weight_decay) 215 | return optimizer 216 | 217 | 218 | def main(args): 219 | args.exp_dir.mkdir(parents=True, exist_ok=True) 220 | writer = SummaryWriter(args.exp_dir / 'summary') 221 | 222 | if args.test: 223 | checkpoint, model, optimizer = load_model(args.checkpoint) 224 | start_epoch = checkpoint['epoch'] 225 | del checkpoint 226 | elif args.resume: 227 | checkpoint, model, optimizer = load_model(args.checkpoint) 228 | args = checkpoint['args'] 229 | best_dev_loss = checkpoint['best_dev_loss'] 230 | start_epoch = checkpoint['epoch'] 231 | del checkpoint 232 | else: 233 | model = build_model(args) 234 | if args.data_parallel: 235 | model = torch.nn.DataParallel(model) 236 | optimizer = build_optim(args, model.parameters()) 237 | best_dev_loss = 1e9 238 | start_epoch = 0 239 | logging.info(args) 240 | logging.info(model) 241 | 242 | train_loader, dev_loader, display_loader = create_data_loaders(args) 243 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_step_size, args.lr_gamma) 244 | 245 | for epoch in range(start_epoch, args.num_epochs): 246 | if args.test: 247 | print('evaluating') 248 | dev_loss, dev_time = evaluate(args, epoch, model, dev_loader, writer) 249 | exit() 250 | 251 | scheduler.step(epoch) 252 | train_loss, train_time = train_epoch(args, epoch, model, train_loader, optimizer, writer) 253 | if (epoch + 1) % 5 == 0: 254 | #dev_loss, dev_time = evaluate(args, epoch, model, dev_loader, writer) 255 | 256 | is_new_best = True #dev_loss < best_dev_loss 257 | best_dev_loss = 0 #min(best_dev_loss, dev_loss) 258 | save_model(args, args.exp_dir, epoch, model, optimizer, best_dev_loss, is_new_best) 259 | logging.info( 260 | 'saved', 261 | #f'Epoch = [{epoch:4d}/{args.num_epochs:4d}] TrainLoss = {train_loss:.4g} ' 262 | #f'DevLoss = {dev_loss:.4g} TrainTime = {train_time:.4f}s DevTime = {dev_time:.4f}s', 263 | ) 264 | writer.close() 265 | 266 | 267 | def create_arg_parser(): 268 | parser = Args() 269 | parser.add_argument('--num-pools', type=int, default=4, help='Number of U-Net pooling layers') 270 | parser.add_argument('--drop-prob', type=float, default=0.0, help='Dropout probability') 271 | parser.add_argument('--num-chans', type=int, default=32, help='Number of U-Net channels') 272 | 273 | parser.add_argument('--batch-size', default=16, type=int, help='Mini batch size') 
274 | parser.add_argument('--num-epochs', type=int, default=50, help='Number of training epochs') 275 | parser.add_argument('--iters-per-epoch', type=int, default=0, help='Number of iterations per epoch') 276 | parser.add_argument('--lr', type=float, default=0.001, help='Learning rate') 277 | parser.add_argument('--lr-step-size', type=int, default=40, 278 | help='Period of learning rate decay') 279 | parser.add_argument('--lr-gamma', type=float, default=0.1, 280 | help='Multiplicative factor of learning rate decay') 281 | parser.add_argument('--weight-decay', type=float, default=0., 282 | help='Strength of weight decay regularization') 283 | 284 | parser.add_argument('--report-interval', type=int, default=40, help='Period of loss reporting') 285 | parser.add_argument('--data-parallel', action='store_true', 286 | help='If set, use multiple GPUs using data parallelism') 287 | parser.add_argument('--device', type=str, default='cuda', 288 | help='Which device to train on. Set to "cuda" to use the GPU') 289 | parser.add_argument('--exp-dir', type=pathlib.Path, default='checkpoints', 290 | help='Path where model and results should be saved') 291 | parser.add_argument('--resume', action='store_true', 292 | help='If set, resume the training from a previous model checkpoint. ' 293 | '"--checkpoint" should be set with this') 294 | parser.add_argument('--checkpoint', type=str, 295 | help='Path to an existing checkpoint. Used along with "--resume"') 296 | parser.add_argument('--test', action='store_true', default=False) 297 | return parser 298 | 299 | 300 | if __name__ == '__main__': 301 | args = create_arg_parser().parse_args() 302 | random.seed(args.seed) 303 | np.random.seed(args.seed) 304 | torch.manual_seed(args.seed) 305 | main(args) 306 | --------------------------------------------------------------------------------
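A quick shape sanity check for the UnetModel defined in unet/unet_model.py. This is
a minimal sketch, assuming it is run from inside the unet/ directory; the
(1, 1, 64, 64) input is an arbitrary choice (height and width only need to be
divisible by 2 ** num_pool_layers so that pooling and upsampling line up):

import torch

from unet_model import UnetModel

# same hyper-parameters as the defaults of main.py's argument parser
model = UnetModel(in_chans=1, out_chans=1, chans=32, num_pool_layers=4, drop_prob=0.0)
model.eval()
x = torch.randn(1, 1, 64, 64)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 1, 64, 64])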