├── .gitignore ├── README.md ├── config.py ├── data └── MICCAI13_SegChallenge │ ├── README.md │ └── dataset_name_list.txt ├── data_loader.py ├── mask ├── Gaussian1D │ ├── GaussianDistribution1DMask_1.mat │ ├── GaussianDistribution1DMask_10.mat │ ├── GaussianDistribution1DMask_20.mat │ ├── GaussianDistribution1DMask_30.mat │ ├── GaussianDistribution1DMask_40.mat │ ├── GaussianDistribution1DMask_5.mat │ └── GaussianDistribution1DMask_50.mat ├── Gaussian2D │ ├── GaussianDistribution2DMask_1.mat │ ├── GaussianDistribution2DMask_10.mat │ ├── GaussianDistribution2DMask_20.mat │ ├── GaussianDistribution2DMask_30.mat │ ├── GaussianDistribution2DMask_40.mat │ ├── GaussianDistribution2DMask_5.mat │ └── GaussianDistribution2DMask_50.mat └── Poisson2D │ ├── PoissonDistributionMask_1.mat │ ├── PoissonDistributionMask_10.mat │ ├── PoissonDistributionMask_20.mat │ ├── PoissonDistributionMask_30.mat │ ├── PoissonDistributionMask_40.mat │ ├── PoissonDistributionMask_5.mat │ └── PoissonDistributionMask_50.mat ├── model.py ├── test.py ├── train.py ├── trained_model └── VGG16 │ └── README.md └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAGAN 2 | 3 | This is the official implementation code for [DAGAN: Deep De-Aliasing Generative Adversarial Networks for Fast Compressed Sensing MRI Reconstruction](https://ieeexplore.ieee.org/document/8233175/) published in IEEE Transactions on Medical Imaging (2018). 4 | [Guang Yang](https://www.imperial.ac.uk/people/g.yang)\*, [Simiao Yu](https://nebulav.github.io/)\*, et al. 5 | (* equal contributions) 6 | 7 | If you use this code for your research, please cite our paper. 8 | 9 | ``` 10 | @article{yang2018_dagan, 11 | author = {Yang, Guang and Yu, Simiao and Dong, Hao and Slabaugh, Gregory G. and Dragotti, Pier Luigi and Ye, Xujiong and Liu, Fangde and Arridge, Simon R. and Keegan, Jennifer and Guo, Yike and Firmin, David N.}, 12 | journal = {IEEE Trans. Med. Imaging}, 13 | number = 6, 14 | pages = {1310--1321}, 15 | title = {{DAGAN: deep de-aliasing generative adversarial networks for fast compressed sensing MRI reconstruction}}, 16 | volume = 37, 17 | year = 2018 18 | } 19 | ``` 20 | 21 | If you have any questions about this code, please feel free to contact Simiao Yu (simiao.yu13@imperial.ac.uk). 22 | 23 | # Prerequisites 24 | 25 | The original code is in python 3.5 under the following dependencies: 26 | 1. tensorflow (v1.1.0) 27 | 2. tensorlayer (v1.7.2) 28 | 3. easydict (v1.6) 29 | 4. nibabel (v2.1.0) 30 | 5. scikit-image (v0.12.3) 31 | 32 | Code tested in Ubuntu 16.04 with Nvidia GPU + CUDA CuDNN (whose version is compatible to tensorflow v1.1.0). 33 | 34 | # How to use 35 | 36 | 1. Prepare data 37 | 38 | 1) Data used in this work are publicly available from the MICCAI 2013 grand challenge ([link](https://my.vanderbilt.edu/masi/workshops/)). We refer users to register with the grand challenge organisers to be able to download the data. 
39 | 2) Download training and test data respectively into data/MICCAI13_SegChallenge/Training_100 and data/MICCAI13_SegChallenge/Testing_100 (We randomly included 100 T1-weighted MRI datasets for training and 50 datasets for testing) 40 | 3) run 'python data_loader.py' 41 | 4) after running the code, training/validation/testing data should be saved to 'data/MICCAI13_SegChallenge/' in pickle format. 42 | 43 | 2. Download pretrained VGG16 model 44 | 45 | 1) Download 'vgg16_weights.npz' from [this link](http://www.cs.toronto.edu/~frossard/post/vgg16/) 46 | 2) Save 'vgg16_weights.npz' into 'trained_model/VGG16' 47 | 48 | 3. Train model 49 | 1) run 'CUDA_VISIBLE_DEVICES=0 python train.py --model MODEL --mask MASK --maskperc MASKPERC' where you should specify MODEL, MASK, MASKPERC respectively: 50 | - MODEL: choose from 'unet' or 'unet_refine' 51 | - MASK: choose from 'gaussian1d', 'gaussian2d', 'poisson2d' 52 | - MASKPERC: choose from '10', '20', '30', '40', '50' (percentage of mask) 53 | 54 | 4. Test trained model 55 | 56 | 1) run 'CUDA_VISIBLE_DEVICES=0 python test.py --model MODEL --mask MASK --maskperc MASKPERC' where you should specify MODEL, MASK, MASKPERC respectively (as above). 57 | 58 | # Results 59 | 60 | Please refer to the paper for the detailed results. 61 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import json 3 | import os 4 | 5 | config = edict() 6 | config.TRAIN = edict() 7 | 8 | config.TRAIN.batch_size = 25 9 | config.TRAIN.early_stopping_num = 10 10 | config.TRAIN.lr = 0.0001 11 | config.TRAIN.lr_decay = 0.5 12 | config.TRAIN.decay_every = 5 13 | config.TRAIN.beta1 = 0.5 # beta1 in Adam optimiser 14 | config.TRAIN.n_epoch = 9999 15 | config.TRAIN.sample_size = 50 16 | config.TRAIN.g_alpha = 15 # weight for pixel loss 17 | config.TRAIN.g_gamma = 0.0025 # weight for perceptual loss 18 | config.TRAIN.g_beta = 0.1 # weight for frequency loss 19 | config.TRAIN.g_adv = 1 # weight for frequency loss 20 | 21 | config.TRAIN.seed = 100 22 | config.TRAIN.epsilon = 0.000001 23 | 24 | 25 | config.TRAIN.VGG16_path = os.path.join('trained_model', 'VGG16', 'vgg16_weights.npz') 26 | config.TRAIN.training_data_path = os.path.join('data', 'MICCAI13_SegChallenge', 'training.pickle') 27 | config.TRAIN.val_data_path = os.path.join('data', 'MICCAI13_SegChallenge', 'validation.pickle') 28 | config.TRAIN.testing_data_path = os.path.join('data', 'MICCAI13_SegChallenge', 'testing.pickle') 29 | config.TRAIN.mask_Gaussian1D_path = os.path.join('mask', 'Gaussian1D') 30 | config.TRAIN.mask_Gaussian2D_path = os.path.join('mask', 'Gaussian2D') 31 | config.TRAIN.mask_Poisson2D_path = os.path.join('mask', 'Poisson2D') 32 | 33 | def log_config(filename, cfg): 34 | with open(filename, 'w') as f: 35 | f.write("================================================\n") 36 | f.write(json.dumps(cfg, indent=4)) 37 | f.write("\n================================================\n") -------------------------------------------------------------------------------- /data/MICCAI13_SegChallenge/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/data/MICCAI13_SegChallenge/README.md -------------------------------------------------------------------------------- /data/MICCAI13_SegChallenge/dataset_name_list.txt: 
-------------------------------------------------------------------------------- 1 | Training datasets: 2 | 1003_3x1110_3Warped.nii.gz 1004_3x1023_3Warped.nii.gz 1005_3x1010_3Warped.nii.gz 1006_3x1000_3Warped.nii.gz 3 | 1003_3x1113_3Warped.nii.gz 1004_3x1024_3Warped.nii.gz 1005_3x1011_3Warped.nii.gz 1006_3x1001_3Warped.nii.gz 4 | 1003_3x1116_3Warped.nii.gz 1004_3x1025_3Warped.nii.gz 1005_3x1012_3Warped.nii.gz 1006_3x1002_3Warped.nii.gz 5 | 1003_3x1119_3Warped.nii.gz 1004_3x1036_3Warped.nii.gz 1005_3x1013_3Warped.nii.gz 1006_3x1003_3Warped.nii.gz 6 | 1003_3x1122_3Warped.nii.gz 1004_3x1038_3Warped.nii.gz 1005_3x1014_3Warped.nii.gz 1006_3x1004_3Warped.nii.gz 7 | 1003_3x1125_3Warped.nii.gz 1004_3x1039_3Warped.nii.gz 1005_3x1015_3Warped.nii.gz 1006_3x1005_3Warped.nii.gz 8 | 1003_3x1128_3Warped.nii.gz 1004_3x1101_3Warped.nii.gz 1005_3x1017_3Warped.nii.gz 1006_3x1007_3Warped.nii.gz 9 | 1004_3x1000_3Warped.nii.gz 1004_3x1104_3Warped.nii.gz 1005_3x1018_3Warped.nii.gz 1006_3x1008_3Warped.nii.gz 10 | 1004_3x1001_3Warped.nii.gz 1004_3x1107_3Warped.nii.gz 1005_3x1019_3Warped.nii.gz 1006_3x1009_3Warped.nii.gz 11 | 1004_3x1002_3Warped.nii.gz 1004_3x1110_3Warped.nii.gz 1005_3x1023_3Warped.nii.gz 1006_3x1010_3Warped.nii.gz 12 | 1004_3x1003_3Warped.nii.gz 1004_3x1113_3Warped.nii.gz 1005_3x1024_3Warped.nii.gz 1006_3x1011_3Warped.nii.gz 13 | 1004_3x1005_3Warped.nii.gz 1004_3x1116_3Warped.nii.gz 1005_3x1025_3Warped.nii.gz 1006_3x1012_3Warped.nii.gz 14 | 1004_3x1006_3Warped.nii.gz 1004_3x1119_3Warped.nii.gz 1005_3x1036_3Warped.nii.gz 1006_3x1013_3Warped.nii.gz 15 | 1004_3x1007_3Warped.nii.gz 1004_3x1122_3Warped.nii.gz 1005_3x1038_3Warped.nii.gz 1006_3x1014_3Warped.nii.gz 16 | 1004_3x1008_3Warped.nii.gz 1004_3x1125_3Warped.nii.gz 1005_3x1039_3Warped.nii.gz 1006_3x1015_3Warped.nii.gz 17 | 1004_3x1009_3Warped.nii.gz 1004_3x1128_3Warped.nii.gz 1005_3x1101_3Warped.nii.gz 1006_3x1017_3Warped.nii.gz 18 | 1004_3x1010_3Warped.nii.gz 1005_3x1000_3Warped.nii.gz 1005_3x1104_3Warped.nii.gz 1006_3x1018_3Warped.nii.gz 19 | 1004_3x1011_3Warped.nii.gz 1005_3x1001_3Warped.nii.gz 1005_3x1107_3Warped.nii.gz 1006_3x1019_3Warped.nii.gz 20 | 1004_3x1012_3Warped.nii.gz 1005_3x1002_3Warped.nii.gz 1005_3x1110_3Warped.nii.gz 1006_3x1023_3Warped.nii.gz 21 | 1004_3x1013_3Warped.nii.gz 1005_3x1003_3Warped.nii.gz 1005_3x1113_3Warped.nii.gz 1006_3x1024_3Warped.nii.gz 22 | 1004_3x1014_3Warped.nii.gz 1005_3x1004_3Warped.nii.gz 1005_3x1116_3Warped.nii.gz 1006_3x1025_3Warped.nii.gz 23 | 1004_3x1015_3Warped.nii.gz 1005_3x1006_3Warped.nii.gz 1005_3x1119_3Warped.nii.gz 1006_3x1036_3Warped.nii.gz 24 | 1004_3x1017_3Warped.nii.gz 1005_3x1007_3Warped.nii.gz 1005_3x1122_3Warped.nii.gz 1006_3x1038_3Warped.nii.gz 25 | 1004_3x1018_3Warped.nii.gz 1005_3x1008_3Warped.nii.gz 1005_3x1125_3Warped.nii.gz 1006_3x1039_3Warped.nii.gz 26 | 1004_3x1019_3Warped.nii.gz 1005_3x1009_3Warped.nii.gz 1005_3x1128_3Warped.nii.gz 1006_3x1101_3Warped.nii.gz 27 | 28 | Testing datasets: 29 | 1006_3x1104_3Warped.nii.gz 1007_3x1004_3Warped.nii.gz 1007_3x1019_3Warped.nii.gz 1007_3x1119_3Warped.nii.gz 30 | 1006_3x1107_3Warped.nii.gz 1007_3x1005_3Warped.nii.gz 1007_3x1023_3Warped.nii.gz 1007_3x1122_3Warped.nii.gz 31 | 1006_3x1110_3Warped.nii.gz 1007_3x1006_3Warped.nii.gz 1007_3x1024_3Warped.nii.gz 1007_3x1125_3Warped.nii.gz 32 | 1006_3x1113_3Warped.nii.gz 1007_3x1008_3Warped.nii.gz 1007_3x1025_3Warped.nii.gz 1007_3x1128_3Warped.nii.gz 33 | 1006_3x1116_3Warped.nii.gz 1007_3x1009_3Warped.nii.gz 1007_3x1036_3Warped.nii.gz 1008_3x1000_3Warped.nii.gz 34 | 1006_3x1119_3Warped.nii.gz 
1007_3x1010_3Warped.nii.gz 1007_3x1038_3Warped.nii.gz 1008_3x1001_3Warped.nii.gz 35 | 1006_3x1122_3Warped.nii.gz 1007_3x1011_3Warped.nii.gz 1007_3x1039_3Warped.nii.gz 1008_3x1002_3Warped.nii.gz 36 | 1006_3x1125_3Warped.nii.gz 1007_3x1012_3Warped.nii.gz 1007_3x1101_3Warped.nii.gz 1008_3x1003_3Warped.nii.gz 37 | 1006_3x1128_3Warped.nii.gz 1007_3x1013_3Warped.nii.gz 1007_3x1104_3Warped.nii.gz 1008_3x1004_3Warped.nii.gz 38 | 1007_3x1000_3Warped.nii.gz 1007_3x1014_3Warped.nii.gz 1007_3x1107_3Warped.nii.gz 1008_3x1005_3Warped.nii.gz 39 | 1007_3x1001_3Warped.nii.gz 1007_3x1015_3Warped.nii.gz 1007_3x1110_3Warped.nii.gz 1008_3x1006_3Warped.nii.gz 40 | 1007_3x1002_3Warped.nii.gz 1007_3x1017_3Warped.nii.gz 1007_3x1113_3Warped.nii.gz 41 | 1007_3x1003_3Warped.nii.gz 1007_3x1018_3Warped.nii.gz 1007_3x1116_3Warped.nii.gz 42 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tensorlayer as tl 3 | import numpy as np 4 | import os 5 | import nibabel as nib 6 | 7 | training_data_path = "data/MICCAI13_SegChallenge/Training_100" 8 | testing_data_path = "data/MICCAI13_SegChallenge/Testing_50" 9 | val_ratio = 0.3 10 | seed = 100 11 | preserving_ratio = 0.1 # filter out 2d images containing < 10% non-zeros 12 | 13 | 14 | f_train_all = tl.files.load_file_list(path=training_data_path, 15 | regx='.*.gz', 16 | printable=False) 17 | train_all_num = len(f_train_all) 18 | val_num = int(train_all_num * val_ratio) 19 | 20 | f_train = [] 21 | f_val = [] 22 | 23 | val_idex = tl.utils.get_random_int(min=0, 24 | max=train_all_num - 1, 25 | number=val_num, 26 | seed=seed) 27 | for i in range(train_all_num): 28 | if i in val_idex: 29 | f_val.append(f_train_all[i]) 30 | else: 31 | f_train.append(f_train_all[i]) 32 | 33 | f_test = tl.files.load_file_list(path=testing_data_path, 34 | regx='.*.gz', 35 | printable=False) 36 | 37 | train_3d_num, val_3d_num, test_3d_num = len(f_train), len(f_val), len(f_test) 38 | 39 | 40 | X_train = [] 41 | for fi, f in enumerate(f_train): 42 | print("processing [{}/{}] 3d image ({}) for training set ...".format(fi + 1, train_3d_num, f)) 43 | img_path = os.path.join(training_data_path, f) 44 | img = nib.load(img_path).get_data() 45 | img_3d_max = np.max(img) 46 | img = img / img_3d_max * 255 47 | for i in range(img.shape[2]): 48 | img_2d = img[:, :, i] 49 | # filter out 2d images containing < 10% non-zeros 50 | if float(np.count_nonzero(img_2d)) / img_2d.size >= preserving_ratio: 51 | img_2d = img_2d / 127.5 - 1 52 | img_2d = np.transpose(img_2d, (1, 0)) 53 | X_train.append(img_2d) 54 | 55 | X_val = [] 56 | for fi, f in enumerate(f_val): 57 | print("processing [{}/{}] 3d image ({}) for validation set ...".format(fi + 1, val_3d_num, f)) 58 | img_path = os.path.join(training_data_path, f) 59 | img = nib.load(img_path).get_data() 60 | img_3d_max = np.max(img) 61 | img = img / img_3d_max * 255 62 | for i in range(img.shape[2]): 63 | img_2d = img[:, :, i] 64 | # filter out 2d images containing < 10% non-zeros 65 | if float(np.count_nonzero(img_2d)) / img_2d.size >= preserving_ratio: 66 | img_2d = img_2d / 127.5 - 1 67 | img_2d = np.transpose(img_2d, (1, 0)) 68 | X_val.append(img_2d) 69 | 70 | X_test = [] 71 | for fi, f in enumerate(f_test): 72 | print("processing [{}/{}] 3d image ({}) for test set ...".format(fi + 1, test_3d_num, f)) 73 | img_path = os.path.join(testing_data_path, f) 74 | img = nib.load(img_path).get_data() 75 | img_3d_max = 
np.max(img) 76 | img = img / img_3d_max * 255 77 | for i in range(img.shape[2]): 78 | img_2d = img[:, :, i] 79 | # filter out 2d images containing < 10% non-zeros 80 | if float(np.count_nonzero(img_2d)) / img_2d.size >= preserving_ratio: 81 | img_2d = img_2d / 127.5 - 1 82 | img_2d = np.transpose(img_2d, (1, 0)) 83 | X_test.append(img_2d) 84 | 85 | X_train = np.asarray(X_train) 86 | X_train = X_train[:, :, :, np.newaxis] 87 | X_val = np.asarray(X_val) 88 | X_val = X_val[:, :, :, np.newaxis] 89 | X_test = np.asarray(X_test) 90 | X_test = X_test[:, :, :, np.newaxis] 91 | 92 | # save data into pickle format 93 | data_saving_path = 'data/MICCAI13_SegChallenge/' 94 | tl.files.exists_or_mkdir(data_saving_path) 95 | 96 | print("save training set into pickle format") 97 | with open(os.path.join(data_saving_path, 'training.pickle'), 'wb') as f: 98 | pickle.dump(X_train, f, protocol=4) 99 | 100 | print("save validation set into pickle format") 101 | with open(os.path.join(data_saving_path, 'validation.pickle'), 'wb') as f: 102 | pickle.dump(X_val, f, protocol=4) 103 | 104 | print("save test set into pickle format") 105 | with open(os.path.join(data_saving_path, 'testing.pickle'), 'wb') as f: 106 | pickle.dump(X_test, f, protocol=4) 107 | 108 | print("processing data finished!") 109 | -------------------------------------------------------------------------------- /mask/Gaussian1D/GaussianDistribution1DMask_1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_1.mat -------------------------------------------------------------------------------- /mask/Gaussian1D/GaussianDistribution1DMask_10.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_10.mat -------------------------------------------------------------------------------- /mask/Gaussian1D/GaussianDistribution1DMask_20.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_20.mat -------------------------------------------------------------------------------- /mask/Gaussian1D/GaussianDistribution1DMask_30.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_30.mat -------------------------------------------------------------------------------- /mask/Gaussian1D/GaussianDistribution1DMask_40.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_40.mat -------------------------------------------------------------------------------- /mask/Gaussian1D/GaussianDistribution1DMask_5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_5.mat -------------------------------------------------------------------------------- 
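The numeric suffix on each of these mask files (1, 5, 10, ..., 50) is the sampling percentage that the `--maskperc` flag selects, and train.py / test.py read the Gaussian 1D pattern from the `maskRS1` variable stored in each .mat file. A minimal inspection sketch, illustrative only, assuming numpy and scipy as used throughout the repository (the 10% file is picked arbitrarily):

```
# Illustrative only: inspect one of the 1D Gaussian undersampling masks.
import os
import numpy as np
from scipy.io import loadmat

mask_path = os.path.join('mask', 'Gaussian1D', 'GaussianDistribution1DMask_10.mat')
mask = loadmat(mask_path)['maskRS1']   # same key as train.py / test.py

print(mask.shape)                      # should match the 2D slice size fed to the networks
print('sampled fraction: %.3f' % float(np.mean(mask)))  # roughly the percentage in the file name
```

The same check works for the Gaussian2D files (key `maskRS2`) and the Poisson2D files (key `population_matrix`) listed below.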
/mask/Gaussian1D/GaussianDistribution1DMask_50.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian1D/GaussianDistribution1DMask_50.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_1.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_10.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_10.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_20.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_20.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_30.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_30.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_40.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_40.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_5.mat -------------------------------------------------------------------------------- /mask/Gaussian2D/GaussianDistribution2DMask_50.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Gaussian2D/GaussianDistribution2DMask_50.mat -------------------------------------------------------------------------------- /mask/Poisson2D/PoissonDistributionMask_1.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_1.mat -------------------------------------------------------------------------------- /mask/Poisson2D/PoissonDistributionMask_10.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_10.mat -------------------------------------------------------------------------------- 
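Applying any of these masks reproduces the retrospective undersampling implemented in `utils.to_bad_img`: a slice is transformed to k-space, multiplied by the mask, and transformed back, which yields the zero-filled (ZF) input that the generator learns to de-alias. A hedged sketch along those lines (the `population_matrix` key and the fftshift ordering follow utils.py and train.py / test.py; the input slice here is a random placeholder rather than real data):

```
# Illustrative only: retrospective k-space undersampling, mirroring utils.to_bad_img.
import os
import numpy as np
import scipy.fftpack
from scipy.io import loadmat

mask = loadmat(os.path.join('mask', 'Poisson2D',
                            'PoissonDistributionMask_30.mat'))['population_matrix']

def undersample(slice_2d, mask):
    """slice_2d: 2D array scaled to [-1, 1]; returns the zero-filled reconstruction."""
    x = (slice_2d + 1.) / 2.                                   # back to [0, 1], as in utils.py
    k = scipy.fftpack.fftshift(scipy.fftpack.fft2(x))          # centred k-space
    k = k * mask                                               # discard unsampled points
    x_zf = np.abs(scipy.fftpack.ifft2(scipy.fftpack.ifftshift(k)))
    return x_zf * 2. - 1.                                      # back to [-1, 1]

slice_2d = np.random.uniform(-1., 1., size=mask.shape)         # placeholder stand-in for a real slice
print(undersample(slice_2d, mask).shape)
```

Shown here with the 30% Poisson mask, but any of the Gaussian or Poisson files works the same way once the matching variable key is used.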
/mask/Poisson2D/PoissonDistributionMask_20.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_20.mat -------------------------------------------------------------------------------- /mask/Poisson2D/PoissonDistributionMask_30.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_30.mat -------------------------------------------------------------------------------- /mask/Poisson2D/PoissonDistributionMask_40.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_40.mat -------------------------------------------------------------------------------- /mask/Poisson2D/PoissonDistributionMask_5.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_5.mat -------------------------------------------------------------------------------- /mask/Poisson2D/PoissonDistributionMask_50.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/mask/Poisson2D/PoissonDistributionMask_50.mat -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from tensorlayer.layers import * 2 | from utils import * 3 | 4 | 5 | def discriminator(input_images, is_train=True, reuse=False): 6 | w_init = tf.random_normal_initializer(stddev=0.02) 7 | b_init = None 8 | gamma_init = tf.random_normal_initializer(1., 0.02) 9 | df_dim = 64 10 | 11 | with tf.variable_scope("discriminator", reuse=reuse): 12 | tl.layers.set_name_reuse(reuse) 13 | 14 | net_in = InputLayer(input_images, 15 | name='input') 16 | 17 | net_h0 = Conv2d(net_in, df_dim, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), 18 | padding='SAME', W_init=w_init, name='h0/conv2d') 19 | 20 | net_h1 = Conv2d(net_h0, df_dim * 2, (4, 4), (2, 2), act=None, 21 | padding='SAME', W_init=w_init, b_init=b_init, name='h1/conv2d') 22 | net_h1 = BatchNormLayer(net_h1, act=lambda x: tl.act.lrelu(x, 0.2), 23 | is_train=is_train, gamma_init=gamma_init, name='h1/batchnorm') 24 | 25 | net_h2 = Conv2d(net_h1, df_dim * 4, (4, 4), (2, 2), act=None, 26 | padding='SAME', W_init=w_init, b_init=b_init, name='h2/conv2d') 27 | net_h2 = BatchNormLayer(net_h2, act=lambda x: tl.act.lrelu(x, 0.2), 28 | is_train=is_train, gamma_init=gamma_init, name='h2/batchnorm') 29 | 30 | net_h3 = Conv2d(net_h2, df_dim * 8, (4, 4), (2, 2), act=None, 31 | padding='SAME', W_init=w_init, b_init=b_init, name='h3/conv2d') 32 | net_h3 = BatchNormLayer(net_h3, act=lambda x: tl.act.lrelu(x, 0.2), 33 | is_train=is_train, gamma_init=gamma_init, name='h3/batchnorm') 34 | 35 | net_h4 = Conv2d(net_h3, df_dim * 16, (4, 4), (2, 2), act=None, 36 | padding='SAME', W_init=w_init, b_init=b_init, name='h4/conv2d') 37 | net_h4 = BatchNormLayer(net_h4, act=lambda x: tl.act.lrelu(x, 0.2), 38 | is_train=is_train, gamma_init=gamma_init, name='h4/batchnorm') 
39 | 40 | net_h5 = Conv2d(net_h4, df_dim * 32, (4, 4), (2, 2), act=None, 41 | padding='SAME', W_init=w_init, b_init=b_init, name='h5/conv2d') 42 | net_h5 = BatchNormLayer(net_h5, act=lambda x: tl.act.lrelu(x, 0.2), 43 | is_train=is_train, gamma_init=gamma_init, name='h5/batchnorm') 44 | 45 | net_h6 = Conv2d(net_h5, df_dim * 16, (1, 1), (1, 1), act=None, 46 | padding='SAME', W_init=w_init, b_init=b_init, name='h6/conv2d') 47 | net_h6 = BatchNormLayer(net_h6, act=lambda x: tl.act.lrelu(x, 0.2), 48 | is_train=is_train, gamma_init=gamma_init, name='h6/batchnorm') 49 | 50 | net_h7 = Conv2d(net_h6, df_dim * 8, (1, 1), (1, 1), act=None, 51 | padding='SAME', W_init=w_init, b_init=b_init, name='h7/conv2d') 52 | net_h7 = BatchNormLayer(net_h7, is_train=is_train, gamma_init=gamma_init, name='h7/batchnorm') 53 | 54 | net = Conv2d(net_h7, df_dim * 2, (1, 1), (1, 1), act=None, 55 | padding='SAME', W_init=w_init, b_init=b_init, name='h7_res/conv2d') 56 | net = BatchNormLayer(net, act=lambda x: tl.act.lrelu(x, 0.2), 57 | is_train=is_train, gamma_init=gamma_init, name='h7_res/batchnorm') 58 | net = Conv2d(net, df_dim * 2, (3, 3), (1, 1), act=None, 59 | padding='SAME', W_init=w_init, b_init=b_init, name='h7_res/conv2d2') 60 | net = BatchNormLayer(net, act=lambda x: tl.act.lrelu(x, 0.2), 61 | is_train=is_train, gamma_init=gamma_init, name='h7_res/batchnorm2') 62 | net = Conv2d(net, df_dim * 8, (3, 3), (1, 1), act=None, 63 | padding='SAME', W_init=w_init, b_init=b_init, name='h7_res/conv2d3') 64 | net = BatchNormLayer(net, is_train=is_train, gamma_init=gamma_init, name='h7_res/batchnorm3') 65 | 66 | net_h8 = ElementwiseLayer(layer=[net_h7, net], combine_fn=tf.add, name='h8/add') 67 | net_h8.outputs = tl.act.lrelu(net_h8.outputs, 0.2) 68 | 69 | net_ho = FlattenLayer(net_h8, name='output/flatten') 70 | net_ho = DenseLayer(net_ho, n_units=1, act=tf.identity, W_init=w_init, name='output/dense') 71 | logits = net_ho.outputs 72 | net_ho.outputs = tf.nn.sigmoid(net_ho.outputs) 73 | 74 | return net_ho, logits 75 | 76 | 77 | def u_net_bn(x, is_train=False, reuse=False, is_refine=False): 78 | 79 | w_init = tf.truncated_normal_initializer(stddev=0.01) 80 | b_init = tf.constant_initializer(value=0.0) 81 | gamma_init = tf.random_normal_initializer(1., 0.02) 82 | 83 | with tf.variable_scope("u_net", reuse=reuse): 84 | tl.layers.set_name_reuse(reuse) 85 | inputs = InputLayer(x, name='input') 86 | 87 | conv1 = Conv2d(inputs, 64, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), padding='SAME', 88 | W_init=w_init, b_init=b_init, name='conv1') 89 | conv2 = Conv2d(conv1, 128, (4, 4), (2, 2), act=None, padding='SAME', 90 | W_init=w_init, b_init=b_init, name='conv2') 91 | conv2 = BatchNormLayer(conv2, act=lambda x: tl.act.lrelu(x, 0.2), 92 | is_train=is_train, gamma_init=gamma_init, name='bn2') 93 | 94 | conv3 = Conv2d(conv2, 256, (4, 4), (2, 2), act=None, padding='SAME', 95 | W_init=w_init, b_init=b_init, name='conv3') 96 | conv3 = BatchNormLayer(conv3, act=lambda x: tl.act.lrelu(x, 0.2), 97 | is_train=is_train, gamma_init=gamma_init, name='bn3') 98 | 99 | conv4 = Conv2d(conv3, 512, (4, 4), (2, 2), act=None, padding='SAME', 100 | W_init=w_init, b_init=b_init, name='conv4') 101 | conv4 = BatchNormLayer(conv4, act=lambda x: tl.act.lrelu(x, 0.2), 102 | is_train=is_train, gamma_init=gamma_init, name='bn4') 103 | 104 | conv5 = Conv2d(conv4, 512, (4, 4), (2, 2), act=None, padding='SAME', 105 | W_init=w_init, b_init=b_init, name='conv5') 106 | conv5 = BatchNormLayer(conv5, act=lambda x: tl.act.lrelu(x, 0.2), 107 | is_train=is_train, 
gamma_init=gamma_init, name='bn5') 108 | 109 | conv6 = Conv2d(conv5, 512, (4, 4), (2, 2), act=None, padding='SAME', 110 | W_init=w_init, b_init=b_init, name='conv6') 111 | conv6 = BatchNormLayer(conv6, act=lambda x: tl.act.lrelu(x, 0.2), 112 | is_train=is_train, gamma_init=gamma_init, name='bn6') 113 | 114 | conv7 = Conv2d(conv6, 512, (4, 4), (2, 2), act=None, padding='SAME', 115 | W_init=w_init, b_init=b_init, name='conv7') 116 | conv7 = BatchNormLayer(conv7, act=lambda x: tl.act.lrelu(x, 0.2), 117 | is_train=is_train, gamma_init=gamma_init, name='bn7') 118 | 119 | conv8 = Conv2d(conv7, 512, (4, 4), (2, 2), act=lambda x: tl.act.lrelu(x, 0.2), 120 | padding='SAME', W_init=w_init, b_init=b_init, name='conv8') 121 | 122 | up7 = DeConv2d(conv8, 512, (4, 4), out_size=(2, 2), strides=(2, 2), padding='SAME', 123 | act=None, W_init=w_init, b_init=b_init, name='deconv7') 124 | up7 = BatchNormLayer(up7, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn7') 125 | 126 | up6 = ConcatLayer([up7, conv7], concat_dim=3, name='concat6') 127 | up6 = DeConv2d(up6, 1024, (4, 4), out_size=(4, 4), strides=(2, 2), padding='SAME', 128 | act=None, W_init=w_init, b_init=b_init, name='deconv6') 129 | up6 = BatchNormLayer(up6, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn6') 130 | 131 | up5 = ConcatLayer([up6, conv6], concat_dim=3, name='concat5') 132 | up5 = DeConv2d(up5, 1024, (4, 4), out_size=(8, 8), strides=(2, 2), padding='SAME', 133 | act=None, W_init=w_init, b_init=b_init, name='deconv5') 134 | up5 = BatchNormLayer(up5, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn5') 135 | 136 | up4 = ConcatLayer([up5, conv5], concat_dim=3, name='concat4') 137 | up4 = DeConv2d(up4, 1024, (4, 4), out_size=(16, 16), strides=(2, 2), padding='SAME', 138 | act=None, W_init=w_init, b_init=b_init, name='deconv4') 139 | up4 = BatchNormLayer(up4, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn4') 140 | 141 | up3 = ConcatLayer([up4, conv4], concat_dim=3, name='concat3') 142 | up3 = DeConv2d(up3, 256, (4, 4), out_size=(32, 32), strides=(2, 2), padding='SAME', 143 | act=None, W_init=w_init, b_init=b_init, name='deconv3') 144 | up3 = BatchNormLayer(up3, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn3') 145 | 146 | up2 = ConcatLayer([up3, conv3], concat_dim=3, name='concat2') 147 | up2 = DeConv2d(up2, 128, (4, 4), out_size=(64, 64), strides=(2, 2), padding='SAME', 148 | act=None, W_init=w_init, b_init=b_init, name='deconv2') 149 | up2 = BatchNormLayer(up2, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn2') 150 | 151 | up1 = ConcatLayer([up2, conv2], concat_dim=3, name='concat1') 152 | up1 = DeConv2d(up1, 64, (4, 4), out_size=(128, 128), strides=(2, 2), padding='SAME', 153 | act=None, W_init=w_init, b_init=b_init, name='deconv1') 154 | up1 = BatchNormLayer(up1, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn1') 155 | 156 | up0 = ConcatLayer([up1, conv1], concat_dim=3, name='concat0') 157 | up0 = DeConv2d(up0, 64, (4, 4), out_size=(256, 256), strides=(2, 2), padding='SAME', 158 | act=None, W_init=w_init, b_init=b_init, name='deconv0') 159 | up0 = BatchNormLayer(up0, act=tf.nn.relu, is_train=is_train, gamma_init=gamma_init, name='dbn0') 160 | 161 | if is_refine: 162 | out = Conv2d(up0, 1, (1, 1), act=tf.nn.tanh, name='out') 163 | out = ElementwiseLayer([out, inputs], tf.add, 'add_for_refine') 164 | out.outputs = tl.act.ramp(out.outputs, v_min=-1, v_max=1) 165 | else: 166 | out = Conv2d(up0, 1, (1, 1), 
act=tf.nn.tanh, name='out') 167 | 168 | return out 169 | 170 | 171 | def vgg16_cnn_emb(t_image, reuse=False): 172 | with tf.variable_scope("vgg16_cnn", reuse=reuse) as vs: 173 | tl.layers.set_name_reuse(reuse) 174 | t_image = (t_image + 1) * 127.5 # convert input of [-1, 1] to [0, 255] 175 | 176 | mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='img_mean') 177 | net_in = InputLayer(t_image - mean, name='vgg_input_im') 178 | 179 | # conv1 180 | network = tl.layers.Conv2dLayer(net_in, 181 | act=tf.nn.relu, 182 | shape=[3, 3, 3, 64], 183 | strides=[1, 1, 1, 1], 184 | padding='SAME', 185 | name='vgg_conv1_1') 186 | network = tl.layers.Conv2dLayer(network, 187 | act=tf.nn.relu, 188 | shape=[3, 3, 64, 64], 189 | strides=[1, 1, 1, 1], 190 | padding='SAME', 191 | name='vgg_conv1_2') 192 | network = tl.layers.PoolLayer(network, 193 | ksize=[1, 2, 2, 1], 194 | strides=[1, 2, 2, 1], 195 | padding='SAME', 196 | pool=tf.nn.max_pool, 197 | name='vgg_pool1') 198 | 199 | # conv2 200 | network = tl.layers.Conv2dLayer(network, 201 | act=tf.nn.relu, 202 | shape=[3, 3, 64, 128], 203 | strides=[1, 1, 1, 1], 204 | padding='SAME', 205 | name='vgg_conv2_1') 206 | network = tl.layers.Conv2dLayer(network, 207 | act=tf.nn.relu, 208 | shape=[3, 3, 128, 128], 209 | strides=[1, 1, 1, 1], 210 | padding='SAME', 211 | name='vgg_conv2_2') 212 | network = tl.layers.PoolLayer(network, 213 | ksize=[1, 2, 2, 1], 214 | strides=[1, 2, 2, 1], 215 | padding='SAME', 216 | pool=tf.nn.max_pool, 217 | name='vgg_pool2') 218 | 219 | # conv3 220 | network = tl.layers.Conv2dLayer(network, 221 | act=tf.nn.relu, 222 | shape=[3, 3, 128, 256], 223 | strides=[1, 1, 1, 1], 224 | padding='SAME', 225 | name='vgg_conv3_1') 226 | network = tl.layers.Conv2dLayer(network, 227 | act=tf.nn.relu, 228 | shape=[3, 3, 256, 256], 229 | strides=[1, 1, 1, 1], 230 | padding='SAME', 231 | name='vgg_conv3_2') 232 | network = tl.layers.Conv2dLayer(network, 233 | act=tf.nn.relu, 234 | shape=[3, 3, 256, 256], 235 | strides=[1, 1, 1, 1], 236 | padding='SAME', 237 | name='vgg_conv3_3') 238 | network = tl.layers.PoolLayer(network, 239 | ksize=[1, 2, 2, 1], 240 | strides=[1, 2, 2, 1], 241 | padding='SAME', 242 | pool=tf.nn.max_pool, 243 | name='vgg_pool3') 244 | # conv4 245 | network = tl.layers.Conv2dLayer(network, 246 | act=tf.nn.relu, 247 | shape=[3, 3, 256, 512], 248 | strides=[1, 1, 1, 1], 249 | padding='SAME', 250 | name='vgg_conv4_1') 251 | network = tl.layers.Conv2dLayer(network, 252 | act=tf.nn.relu, 253 | shape=[3, 3, 512, 512], 254 | strides=[1, 1, 1, 1], 255 | padding='SAME', 256 | name='vgg_conv4_2') 257 | network = tl.layers.Conv2dLayer(network, 258 | act=tf.nn.relu, 259 | shape=[3, 3, 512, 512], 260 | strides=[1, 1, 1, 1], 261 | padding='SAME', 262 | name='vgg_conv4_3') 263 | 264 | network = tl.layers.PoolLayer(network, 265 | ksize=[1, 2, 2, 1], 266 | strides=[1, 2, 2, 1], 267 | padding='SAME', 268 | pool=tf.nn.max_pool, 269 | name='vgg_pool4') 270 | conv4 = network 271 | 272 | # conv5 273 | network = tl.layers.Conv2dLayer(network, 274 | act=tf.nn.relu, 275 | shape=[3, 3, 512, 512], 276 | strides=[1, 1, 1, 1], 277 | padding='SAME', 278 | name='vgg_conv5_1') 279 | network = tl.layers.Conv2dLayer(network, 280 | act=tf.nn.relu, 281 | shape=[3, 3, 512, 512], 282 | strides=[1, 1, 1, 1], 283 | padding='SAME', 284 | name='vgg_conv5_2') 285 | network = tl.layers.Conv2dLayer(network, 286 | act=tf.nn.relu, 287 | shape=[3, 3, 512, 512], 288 | strides=[1, 1, 1, 1], 289 | padding='SAME', 290 | name='vgg_conv5_3') 291 | 
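# conv4, saved after pool4 above, is returned together with the flattened pool5
# output produced below; train.py passes both the ground-truth and the generated
# image through this sub-network and penalises the squared difference of their
# conv4 feature maps as the perceptual loss term (weighted by config.TRAIN.g_gamma).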
network = tl.layers.PoolLayer(network, 292 | ksize=[1, 2, 2, 1], 293 | strides=[1, 2, 2, 1], 294 | padding='SAME', 295 | pool=tf.nn.max_pool, 296 | name='vgg_pool5') 297 | 298 | network = FlattenLayer(network, name='vgg_flatten') 299 | 300 | return conv4, network 301 | 302 | 303 | if __name__ == "__main__": 304 | pass 305 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from model import * 3 | from utils import * 4 | from config import config, log_config 5 | from scipy.io import loadmat, savemat 6 | 7 | 8 | def main_test(): 9 | mask_perc = tl.global_flag['maskperc'] 10 | mask_name = tl.global_flag['mask'] 11 | model_name = tl.global_flag['model'] 12 | 13 | # =================================== BASIC CONFIGS =================================== # 14 | 15 | print('[*] run basic configs ... ') 16 | 17 | log_dir = "log_inference_{}_{}_{}".format(model_name, mask_name, mask_perc) 18 | tl.files.exists_or_mkdir(log_dir) 19 | _, _, log_inference, _, _, log_inference_filename = logging_setup(log_dir) 20 | 21 | checkpoint_dir = "checkpoint_inference_{}_{}_{}".format(model_name, mask_name, mask_perc) 22 | tl.files.exists_or_mkdir(checkpoint_dir) 23 | 24 | save_dir = "samples_inference_{}_{}_{}".format(model_name, mask_name, mask_perc) 25 | tl.files.exists_or_mkdir(save_dir) 26 | 27 | # configs 28 | sample_size = config.TRAIN.sample_size 29 | 30 | # ==================================== PREPARE DATA ==================================== # 31 | 32 | print('[*] load data ... ') 33 | testing_data_path = config.TRAIN.testing_data_path 34 | 35 | with open(testing_data_path, 'rb') as f: 36 | X_test = pickle.load(f) 37 | 38 | print('X_test shape/min/max: ', X_test.shape, X_test.min(), X_test.max()) 39 | 40 | print('[*] loading mask ... ') 41 | if mask_name == "gaussian2d": 42 | mask = \ 43 | loadmat( 44 | os.path.join(config.TRAIN.mask_Gaussian2D_path, "GaussianDistribution2DMask_{}.mat".format(mask_perc)))[ 45 | 'maskRS2'] 46 | elif mask_name == "gaussian1d": 47 | mask = \ 48 | loadmat( 49 | os.path.join(config.TRAIN.mask_Gaussian1D_path, "GaussianDistribution1DMask_{}.mat".format(mask_perc)))[ 50 | 'maskRS1'] 51 | elif mask_name == "poisson2d": 52 | mask = \ 53 | loadmat( 54 | os.path.join(config.TRAIN.mask_Gaussian1D_path, "PoissonDistributionMask_{}.mat".format(mask_perc)))[ 55 | 'population_matrix'] 56 | else: 57 | raise ValueError("no such mask exists: {}".format(mask_name)) 58 | 59 | # ==================================== DEFINE MODEL ==================================== # 60 | 61 | print('[*] define model ... 
') 62 | 63 | nw, nh, nz = X_test.shape[1:] 64 | 65 | # define placeholders 66 | t_image_good = tf.placeholder('float32', [sample_size, nw, nh, nz], name='good_image') 67 | t_image_bad = tf.placeholder('float32', [sample_size, nw, nh, nz], name='bad_image') 68 | t_gen = tf.placeholder('float32', [sample_size, nw, nh, nz], name='generated_image') 69 | 70 | # define generator network 71 | if tl.global_flag['model'] == 'unet': 72 | net_test = u_net_bn(t_image_bad, is_train=False, reuse=False, is_refine=False) 73 | elif tl.global_flag['model'] == 'unet_refine': 74 | net_test = u_net_bn(t_image_bad, is_train=False, reuse=False, is_refine=True) 75 | else: 76 | raise Exception("unknown model") 77 | 78 | # nmse metric for testing purpose 79 | nmse_a_0_1 = tf.sqrt(tf.reduce_sum(tf.squared_difference(t_gen, t_image_good), axis=[1, 2, 3])) 80 | nmse_b_0_1 = tf.sqrt(tf.reduce_sum(tf.square(t_image_good), axis=[1, 2, 3])) 81 | nmse_0_1 = nmse_a_0_1 / nmse_b_0_1 82 | 83 | # ==================================== INFERENCE ==================================== # 84 | 85 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 86 | tl.files.load_and_assign_npz(sess=sess, 87 | name=os.path.join(checkpoint_dir, tl.global_flag['model']) + '.npz', 88 | network=net_test) 89 | 90 | idex = tl.utils.get_random_int(min=0, max=len(X_test) - 1, number=sample_size, seed=config.TRAIN.seed) 91 | X_samples_good = X_test[idex] 92 | X_samples_bad = threading_data(X_samples_good, fn=to_bad_img, mask=mask) 93 | 94 | x_good_sample_rescaled = (X_samples_good + 1) / 2 95 | x_bad_sample_rescaled = (X_samples_bad + 1) / 2 96 | 97 | tl.visualize.save_images(X_samples_good, 98 | [5, 10], 99 | os.path.join(save_dir, "sample_image_good.png")) 100 | 101 | tl.visualize.save_images(X_samples_bad, 102 | [5, 10], 103 | os.path.join(save_dir, "sample_image_bad.png")) 104 | 105 | tl.visualize.save_images(np.abs(X_samples_good - X_samples_bad), 106 | [5, 10], 107 | os.path.join(save_dir, "sample_image_diff_abs.png")) 108 | 109 | tl.visualize.save_images(np.sqrt(np.abs(X_samples_good - X_samples_bad) / 2 + config.TRAIN.epsilon), 110 | [5, 10], 111 | os.path.join(save_dir, "sample_image_diff_sqrt_abs.png")) 112 | 113 | tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - X_samples_bad) / 2, 0, 1), 114 | [5, 10], 115 | os.path.join(save_dir, "sample_image_diff_sqrt_abs_10_clip.png")) 116 | 117 | tl.visualize.save_images(threading_data(X_samples_good, fn=distort_img), 118 | [5, 10], 119 | os.path.join(save_dir, "sample_image_aug.png")) 120 | scipy.misc.imsave(os.path.join(save_dir, "mask.png"), mask * 255) 121 | 122 | print('[*] start testing ... 
') 123 | 124 | x_gen = sess.run(net_test.outputs, {t_image_bad: X_samples_bad}) 125 | x_gen_0_1 = (x_gen + 1) / 2 126 | 127 | # evaluation for generated data 128 | 129 | nmse_res = sess.run(nmse_0_1, {t_gen: x_gen_0_1, t_image_good: x_good_sample_rescaled}) 130 | ssim_res = threading_data([_ for _ in zip(x_good_sample_rescaled, x_gen_0_1)], fn=ssim) 131 | psnr_res = threading_data([_ for _ in zip(x_good_sample_rescaled, x_gen_0_1)], fn=psnr) 132 | 133 | log = "NMSE testing: {}\nSSIM testing: {}\nPSNR testing: {}\n\n".format( 134 | nmse_res, 135 | ssim_res, 136 | psnr_res) 137 | 138 | log_inference.debug(log) 139 | 140 | log = "NMSE testing average: {}\nSSIM testing average: {}\nPSNR testing average: {}\n\n".format( 141 | np.mean(nmse_res), 142 | np.mean(ssim_res), 143 | np.mean(psnr_res)) 144 | 145 | log_inference.debug(log) 146 | 147 | log = "NMSE testing std: {}\nSSIM testing std: {}\nPSNR testing std: {}\n\n".format(np.std(nmse_res), 148 | np.std(ssim_res), 149 | np.std(psnr_res)) 150 | 151 | log_inference.debug(log) 152 | 153 | # evaluation for zero-filled (ZF) data 154 | nmse_res_zf = sess.run(nmse_0_1, 155 | {t_gen: x_bad_sample_rescaled, t_image_good: x_good_sample_rescaled}) 156 | ssim_res_zf = threading_data([_ for _ in zip(x_good_sample_rescaled, x_bad_sample_rescaled)], fn=ssim) 157 | psnr_res_zf = threading_data([_ for _ in zip(x_good_sample_rescaled, x_bad_sample_rescaled)], fn=psnr) 158 | 159 | log = "NMSE ZF testing: {}\nSSIM ZF testing: {}\nPSNR ZF testing: {}\n\n".format( 160 | nmse_res_zf, 161 | ssim_res_zf, 162 | psnr_res_zf) 163 | 164 | log_inference.debug(log) 165 | 166 | log = "NMSE ZF average testing: {}\nSSIM ZF average testing: {}\nPSNR ZF average testing: {}\n\n".format( 167 | np.mean(nmse_res_zf), 168 | np.mean(ssim_res_zf), 169 | np.mean(psnr_res_zf)) 170 | 171 | log_inference.debug(log) 172 | 173 | log = "NMSE ZF std testing: {}\nSSIM ZF std testing: {}\nPSNR ZF std testing: {}\n\n".format( 174 | np.std(nmse_res_zf), 175 | np.std(ssim_res_zf), 176 | np.std(psnr_res_zf)) 177 | 178 | log_inference.debug(log) 179 | 180 | # sample testing images 181 | tl.visualize.save_images(x_gen, 182 | [5, 10], 183 | os.path.join(save_dir, "final_generated_image.png")) 184 | 185 | tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - x_gen) / 2, 0, 1), 186 | [5, 10], 187 | os.path.join(save_dir, "final_generated_image_diff_abs_10_clip.png")) 188 | 189 | tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - X_samples_bad) / 2, 0, 1), 190 | [5, 10], 191 | os.path.join(save_dir, "final_bad_image_diff_abs_10_clip.png")) 192 | 193 | print("[*] Job finished!") 194 | 195 | if __name__ == "__main__": 196 | import argparse 197 | 198 | parser = argparse.ArgumentParser() 199 | 200 | parser.add_argument('--model', type=str, default='unet', help='unet, unet_refine') 201 | parser.add_argument('--mask', type=str, default='gaussian2d', help='gaussian1d, gaussian2d, poisson2d') 202 | parser.add_argument('--maskperc', type=int, default='30', help='10,20,30,40,50') 203 | 204 | args = parser.parse_args() 205 | 206 | tl.global_flag['model'] = args.model 207 | tl.global_flag['mask'] = args.mask 208 | tl.global_flag['maskperc'] = args.maskperc 209 | 210 | main_test() 211 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from model import * 3 | from utils import * 4 | from config import config, log_config 5 | from scipy.io import loadmat, 
savemat 6 | 7 | 8 | def main_train(): 9 | mask_perc = tl.global_flag['maskperc'] 10 | mask_name = tl.global_flag['mask'] 11 | model_name = tl.global_flag['model'] 12 | 13 | # =================================== BASIC CONFIGS =================================== # 14 | 15 | print('[*] run basic configs ... ') 16 | 17 | log_dir = "log_{}_{}_{}".format(model_name, mask_name, mask_perc) 18 | tl.files.exists_or_mkdir(log_dir) 19 | log_all, log_eval, log_50, log_all_filename, log_eval_filename, log_50_filename = logging_setup(log_dir) 20 | 21 | checkpoint_dir = "checkpoint_{}_{}_{}".format(model_name, mask_name, mask_perc) 22 | tl.files.exists_or_mkdir(checkpoint_dir) 23 | 24 | save_dir = "samples_{}_{}_{}".format(model_name, mask_name, mask_perc) 25 | tl.files.exists_or_mkdir(save_dir) 26 | 27 | # configs 28 | batch_size = config.TRAIN.batch_size 29 | early_stopping_num = config.TRAIN.early_stopping_num 30 | g_alpha = config.TRAIN.g_alpha 31 | g_beta = config.TRAIN.g_beta 32 | g_gamma = config.TRAIN.g_gamma 33 | g_adv = config.TRAIN.g_adv 34 | lr = config.TRAIN.lr 35 | lr_decay = config.TRAIN.lr_decay 36 | decay_every = config.TRAIN.decay_every 37 | beta1 = config.TRAIN.beta1 38 | n_epoch = config.TRAIN.n_epoch 39 | sample_size = config.TRAIN.sample_size 40 | 41 | log_config(log_all_filename, config) 42 | log_config(log_eval_filename, config) 43 | log_config(log_50_filename, config) 44 | 45 | # ==================================== PREPARE DATA ==================================== # 46 | 47 | print('[*] load data ... ') 48 | training_data_path = config.TRAIN.training_data_path 49 | val_data_path = config.TRAIN.val_data_path 50 | testing_data_path = config.TRAIN.testing_data_path 51 | 52 | with open(training_data_path, 'rb') as f: 53 | X_train = pickle.load(f) 54 | 55 | with open(val_data_path, 'rb') as f: 56 | X_val = pickle.load(f) 57 | 58 | with open(testing_data_path, 'rb') as f: 59 | X_test = pickle.load(f) 60 | 61 | print('X_train shape/min/max: ', X_train.shape, X_train.min(), X_train.max()) 62 | print('X_val shape/min/max: ', X_val.shape, X_val.min(), X_val.max()) 63 | print('X_test shape/min/max: ', X_test.shape, X_test.min(), X_test.max()) 64 | 65 | print('[*] loading mask ... ') 66 | if mask_name == "gaussian2d": 67 | mask = \ 68 | loadmat( 69 | os.path.join(config.TRAIN.mask_Gaussian2D_path, "GaussianDistribution2DMask_{}.mat".format(mask_perc)))[ 70 | 'maskRS2'] 71 | elif mask_name == "gaussian1d": 72 | mask = \ 73 | loadmat( 74 | os.path.join(config.TRAIN.mask_Gaussian1D_path, "GaussianDistribution1DMask_{}.mat".format(mask_perc)))[ 75 | 'maskRS1'] 76 | elif mask_name == "poisson2d": 77 | mask = \ 78 | loadmat( 79 | os.path.join(config.TRAIN.mask_Gaussian1D_path, "PoissonDistributionMask_{}.mat".format(mask_perc)))[ 80 | 'population_matrix'] 81 | else: 82 | raise ValueError("no such mask exists: {}".format(mask_name)) 83 | 84 | # ==================================== DEFINE MODEL ==================================== # 85 | 86 | print('[*] define model ... 
') 87 | 88 | nw, nh, nz = X_train.shape[1:] 89 | 90 | # define placeholders 91 | t_image_good = tf.placeholder('float32', [batch_size, nw, nh, nz], name='good_image') 92 | t_image_good_samples = tf.placeholder('float32', [sample_size, nw, nh, nz], name='good_image_samples') 93 | t_image_bad = tf.placeholder('float32', [batch_size, nw, nh, nz], name='bad_image') 94 | t_image_bad_samples = tf.placeholder('float32', [sample_size, nw, nh, nz], name='bad_image_samples') 95 | t_gen = tf.placeholder('float32', [batch_size, nw, nh, nz], name='generated_image_for_test') 96 | t_gen_sample = tf.placeholder('float32', [sample_size, nw, nh, nz], name='generated_sample_image_for_test') 97 | t_image_good_244 = tf.placeholder('float32', [batch_size, 244, 244, 3], name='vgg_good_image') 98 | 99 | # define generator network 100 | if tl.global_flag['model'] == 'unet': 101 | net = u_net_bn(t_image_bad, is_train=True, reuse=False, is_refine=False) 102 | net_test = u_net_bn(t_image_bad, is_train=False, reuse=True, is_refine=False) 103 | net_test_sample = u_net_bn(t_image_bad_samples, is_train=False, reuse=True, is_refine=False) 104 | 105 | elif tl.global_flag['model'] == 'unet_refine': 106 | net = u_net_bn(t_image_bad, is_train=True, reuse=False, is_refine=True) 107 | net_test = u_net_bn(t_image_bad, is_train=False, reuse=True, is_refine=True) 108 | net_test_sample = u_net_bn(t_image_bad_samples, is_train=False, reuse=True, is_refine=True) 109 | else: 110 | raise Exception("unknown model") 111 | 112 | # define discriminator network 113 | net_d, logits_fake = discriminator(net.outputs, is_train=True, reuse=False) 114 | _, logits_real = discriminator(t_image_good, is_train=True, reuse=True) 115 | 116 | # define VGG network 117 | net_vgg_conv4_good, _ = vgg16_cnn_emb(t_image_good_244, reuse=False) 118 | net_vgg_conv4_gen, _ = vgg16_cnn_emb(tf.tile(tf.image.resize_images(net.outputs, [244, 244]), [1, 1, 1, 3]), reuse=True) 119 | 120 | # ==================================== DEFINE LOSS ==================================== # 121 | 122 | print('[*] define loss functions ... 
') 123 | 124 | # discriminator loss 125 | d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1') 126 | d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2') 127 | d_loss = d_loss1 + d_loss2 128 | 129 | # generator loss (adversarial) 130 | g_loss = tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g') 131 | 132 | # generator loss (perceptual) 133 | g_perceptual = tf.reduce_mean(tf.reduce_mean(tf.squared_difference( 134 | net_vgg_conv4_good.outputs, 135 | net_vgg_conv4_gen.outputs), 136 | axis=[1, 2, 3])) 137 | 138 | # generator loss (pixel-wise) 139 | g_nmse_a = tf.sqrt(tf.reduce_sum(tf.squared_difference(net.outputs, t_image_good), axis=[1, 2, 3])) 140 | g_nmse_b = tf.sqrt(tf.reduce_sum(tf.square(t_image_good), axis=[1, 2, 3])) 141 | g_nmse = tf.reduce_mean(g_nmse_a / g_nmse_b) 142 | 143 | # generator loss (frequency) 144 | fft_good_abs = tf.map_fn(fft_abs_for_map_fn, t_image_good) 145 | fft_gen_abs = tf.map_fn(fft_abs_for_map_fn, net.outputs) 146 | g_fft = tf.reduce_mean(tf.reduce_mean(tf.squared_difference(fft_good_abs, fft_gen_abs), axis=[1, 2])) 147 | 148 | # generator loss (total) 149 | g_loss = g_adv * g_loss + g_alpha * g_nmse + g_gamma * g_perceptual + g_beta * g_fft 150 | 151 | # nmse metric for testing purpose 152 | nmse_a_0_1 = tf.sqrt(tf.reduce_sum(tf.squared_difference(t_gen, t_image_good), axis=[1, 2, 3])) 153 | nmse_b_0_1 = tf.sqrt(tf.reduce_sum(tf.square(t_image_good), axis=[1, 2, 3])) 154 | nmse_0_1 = nmse_a_0_1 / nmse_b_0_1 155 | 156 | nmse_a_0_1_sample = tf.sqrt(tf.reduce_sum(tf.squared_difference(t_gen_sample, t_image_good_samples), axis=[1, 2, 3])) 157 | nmse_b_0_1_sample = tf.sqrt(tf.reduce_sum(tf.square(t_image_good_samples), axis=[1, 2, 3])) 158 | nmse_0_1_sample = nmse_a_0_1_sample / nmse_b_0_1_sample 159 | 160 | # ==================================== DEFINE TRAIN OPTS ==================================== # 161 | 162 | print('[*] define training options ... 
') 163 | 164 | g_vars = tl.layers.get_variables_with_name('u_net', True, True) 165 | d_vars = tl.layers.get_variables_with_name('discriminator', True, True) 166 | 167 | with tf.variable_scope('learning_rate'): 168 | lr_v = tf.Variable(lr, trainable=False) 169 | 170 | g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars) 171 | d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars) 172 | 173 | # ==================================== TRAINING ==================================== # 174 | 175 | sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) 176 | tl.layers.initialize_global_variables(sess) 177 | 178 | # load generator and discriminator weights (for continuous training purpose) 179 | tl.files.load_and_assign_npz(sess=sess, 180 | name=os.path.join(checkpoint_dir, tl.global_flag['model']) + '.npz', 181 | network=net) 182 | tl.files.load_and_assign_npz(sess=sess, 183 | name=os.path.join(checkpoint_dir, tl.global_flag['model']) + '_d.npz', 184 | network=net_d) 185 | 186 | # load vgg weights 187 | net_vgg_conv4_path = config.TRAIN.VGG16_path 188 | npz = np.load(net_vgg_conv4_path) 189 | assign_op = [] 190 | for idx, val in enumerate(sorted(npz.items())[0:20]): 191 | print(" Loading pretrained VGG16, CNN part %s" % str(val[1].shape)) 192 | assign_op.append(net_vgg_conv4_good.all_params[idx].assign(val[1])) 193 | sess.run(assign_op) 194 | net_vgg_conv4_good.print_params(False) 195 | 196 | n_training_examples = len(X_train) 197 | n_step_epoch = round(n_training_examples / batch_size) 198 | 199 | # sample testing images 200 | idex = tl.utils.get_random_int(min=0, max=len(X_test) - 1, number=sample_size, seed=config.TRAIN.seed) 201 | X_samples_good = X_test[idex] 202 | X_samples_bad = threading_data(X_samples_good, fn=to_bad_img, mask=mask) 203 | 204 | x_good_sample_rescaled = (X_samples_good + 1) / 2 205 | x_bad_sample_rescaled = (X_samples_bad + 1) / 2 206 | 207 | tl.visualize.save_images(X_samples_good, 208 | [5, 10], 209 | os.path.join(save_dir, "sample_image_good.png")) 210 | 211 | tl.visualize.save_images(X_samples_bad, 212 | [5, 10], 213 | os.path.join(save_dir, "sample_image_bad.png")) 214 | 215 | tl.visualize.save_images(np.abs(X_samples_good - X_samples_bad), 216 | [5, 10], 217 | os.path.join(save_dir, "sample_image_diff_abs.png")) 218 | 219 | tl.visualize.save_images(np.sqrt(np.abs(X_samples_good - X_samples_bad) / 2 + config.TRAIN.epsilon), 220 | [5, 10], 221 | os.path.join(save_dir, "sample_image_diff_sqrt_abs.png")) 222 | 223 | tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - X_samples_bad) / 2, 0, 1), 224 | [5, 10], 225 | os.path.join(save_dir, "sample_image_diff_sqrt_abs_10_clip.png")) 226 | 227 | tl.visualize.save_images(threading_data(X_samples_good, fn=distort_img), 228 | [5, 10], 229 | os.path.join(save_dir, "sample_image_aug.png")) 230 | scipy.misc.imsave(os.path.join(save_dir, "mask.png"), mask * 255) 231 | 232 | print('[*] start training ... 
') 233 | 234 | best_nmse = np.inf 235 | best_epoch = 1 236 | esn = early_stopping_num 237 | for epoch in range(0, n_epoch): 238 | 239 | # learning rate decay 240 | if epoch != 0 and (epoch % decay_every == 0): 241 | new_lr_decay = lr_decay ** (epoch // decay_every) 242 | sess.run(tf.assign(lr_v, lr * new_lr_decay)) 243 | log = " ** new learning rate: %f" % (lr * new_lr_decay) 244 | print(log) 245 | log_all.debug(log) 246 | elif epoch == 0: 247 | log = " ** init lr: %f decay_every_epoch: %d, lr_decay: %f" % (lr, decay_every, lr_decay) 248 | print(log) 249 | log_all.debug(log) 250 | 251 | for step in range(n_step_epoch): 252 | step_time = time.time() 253 | idex = tl.utils.get_random_int(min=0, max=n_training_examples - 1, number=batch_size) 254 | X_good = X_train[idex] 255 | X_good_aug = threading_data(X_good, fn=distort_img) 256 | X_good_244 = threading_data(X_good_aug, fn=vgg_prepro) 257 | X_bad = threading_data(X_good_aug, fn=to_bad_img, mask=mask) 258 | 259 | errD, _ = sess.run([d_loss, d_optim], {t_image_good: X_good_aug, t_image_bad: X_bad}) 260 | errG, errG_perceptual, errG_nmse, errG_fft, _ = sess.run([g_loss, g_perceptual, g_nmse, g_fft, g_optim], 261 | {t_image_good_244: X_good_244, 262 | t_image_good: X_good_aug, 263 | t_image_bad: X_bad}) 264 | 265 | log = "Epoch[{:3}/{:3}] step={:3} d_loss={:5} g_loss={:5} g_perceptual_loss={:5} g_mse={:5} g_freq={:5} took {:3}s".format( 266 | epoch + 1, 267 | n_epoch, 268 | step, 269 | round(float(errD), 3), 270 | round(float(errG), 3), 271 | round(float(errG_perceptual), 3), 272 | round(float(errG_nmse), 3), 273 | round(float(errG_fft), 3), 274 | round(time.time() - step_time, 2)) 275 | 276 | print(log) 277 | log_all.debug(log) 278 | 279 | # evaluation for training data 280 | total_nmse_training = 0 281 | total_ssim_training = 0 282 | total_psnr_training = 0 283 | num_training_temp = 0 284 | for batch in tl.iterate.minibatches(inputs=X_train, targets=X_train, batch_size=batch_size, shuffle=False): 285 | x_good, _ = batch 286 | # x_bad = threading_data(x_good, fn=to_bad_img, mask=mask) 287 | x_bad = threading_data( 288 | x_good, 289 | fn=to_bad_img, 290 | mask=mask) 291 | 292 | x_gen = sess.run(net_test.outputs, {t_image_bad: x_bad}) 293 | 294 | x_good_0_1 = (x_good + 1) / 2 295 | x_gen_0_1 = (x_gen + 1) / 2 296 | 297 | nmse_res = sess.run(nmse_0_1, {t_gen: x_gen_0_1, t_image_good: x_good_0_1}) 298 | ssim_res = threading_data([_ for _ in zip(x_good_0_1, x_gen_0_1)], fn=ssim) 299 | psnr_res = threading_data([_ for _ in zip(x_good_0_1, x_gen_0_1)], fn=psnr) 300 | total_nmse_training += np.sum(nmse_res) 301 | total_ssim_training += np.sum(ssim_res) 302 | total_psnr_training += np.sum(psnr_res) 303 | num_training_temp += batch_size 304 | 305 | total_nmse_training /= num_training_temp 306 | total_ssim_training /= num_training_temp 307 | total_psnr_training /= num_training_temp 308 | 309 | log = "Epoch: {}\nNMSE training: {:8}, SSIM training: {:8}, PSNR training: {:8}".format( 310 | epoch + 1, 311 | total_nmse_training, 312 | total_ssim_training, 313 | total_psnr_training) 314 | print(log) 315 | log_all.debug(log) 316 | log_eval.info(log) 317 | 318 | # evaluation for validation data 319 | total_nmse_val = 0 320 | total_ssim_val = 0 321 | total_psnr_val = 0 322 | num_val_temp = 0 323 | for batch in tl.iterate.minibatches(inputs=X_val, targets=X_val, batch_size=batch_size, shuffle=False): 324 | x_good, _ = batch 325 | # x_bad = threading_data(x_good, fn=to_bad_img, mask=mask) 326 | x_bad = threading_data( 327 | x_good, 328 | fn=to_bad_img, 329 | 
330 | 
331 |             x_gen = sess.run(net_test.outputs, {t_image_bad: x_bad})
332 | 
333 |             x_good_0_1 = (x_good + 1) / 2
334 |             x_gen_0_1 = (x_gen + 1) / 2
335 | 
336 |             nmse_res = sess.run(nmse_0_1, {t_gen: x_gen_0_1, t_image_good: x_good_0_1})
337 |             ssim_res = threading_data([_ for _ in zip(x_good_0_1, x_gen_0_1)], fn=ssim)
338 |             psnr_res = threading_data([_ for _ in zip(x_good_0_1, x_gen_0_1)], fn=psnr)
339 |             total_nmse_val += np.sum(nmse_res)
340 |             total_ssim_val += np.sum(ssim_res)
341 |             total_psnr_val += np.sum(psnr_res)
342 |             num_val_temp += batch_size
343 | 
344 |         total_nmse_val /= num_val_temp
345 |         total_ssim_val /= num_val_temp
346 |         total_psnr_val /= num_val_temp
347 | 
348 |         log = "Epoch: {}\nNMSE val: {:8}, SSIM val: {:8}, PSNR val: {:8}".format(
349 |             epoch + 1,
350 |             total_nmse_val,
351 |             total_ssim_val,
352 |             total_psnr_val)
353 |         print(log)
354 |         log_all.debug(log)
355 |         log_eval.info(log)
356 | 
357 |         img = sess.run(net_test_sample.outputs, {t_image_bad_samples: X_samples_bad})
358 |         tl.visualize.save_images(img,
359 |                                  [5, 10],
360 |                                  os.path.join(save_dir, "image_{}.png".format(epoch)))
361 | 
362 |         if total_nmse_val < best_nmse:
363 |             esn = early_stopping_num  # reset early stopping num
364 |             best_nmse = total_nmse_val
365 |             best_epoch = epoch + 1
366 | 
367 |             # save current best model
368 |             tl.files.save_npz(net.all_params,
369 |                               name=os.path.join(checkpoint_dir, tl.global_flag['model']) + '.npz',
370 |                               sess=sess)
371 | 
372 |             tl.files.save_npz(net_d.all_params,
373 |                               name=os.path.join(checkpoint_dir, tl.global_flag['model']) + '_d.npz',
374 |                               sess=sess)
375 |             print("[*] Save checkpoints SUCCESS!")
376 |         else:
377 |             esn -= 1  # no improvement on validation NMSE this epoch
378 | 
379 |         log = "Best NMSE result: {} at {} epoch".format(best_nmse, best_epoch)
380 |         log_eval.info(log)
381 |         log_all.debug(log)
382 |         print(log)
383 | 
384 |         # early stopping triggered
385 |         if esn == 0:
386 |             log_eval.info(log)
387 | 
388 |             tl.files.load_and_assign_npz(sess=sess,
389 |                                          name=os.path.join(checkpoint_dir, tl.global_flag['model']) + '.npz',
390 |                                          network=net)
391 |             # evaluation for test data
392 |             x_gen = sess.run(net_test_sample.outputs, {t_image_bad_samples: X_samples_bad})
393 |             x_gen_0_1 = (x_gen + 1) / 2
394 |             savemat(os.path.join(save_dir, 'test_random_50_generated.mat'), {'x_gen_0_1': x_gen_0_1})
395 | 
396 |             nmse_res = sess.run(nmse_0_1_sample, {t_gen_sample: x_gen_0_1, t_image_good_samples: x_good_sample_rescaled})
397 |             ssim_res = threading_data([_ for _ in zip(x_good_sample_rescaled, x_gen_0_1)], fn=ssim)
398 |             psnr_res = threading_data([_ for _ in zip(x_good_sample_rescaled, x_gen_0_1)], fn=psnr)
399 | 
400 |             log = "NMSE testing: {}\nSSIM testing: {}\nPSNR testing: {}\n\n".format(
401 |                 nmse_res,
402 |                 ssim_res,
403 |                 psnr_res)
404 | 
405 |             log_50.debug(log)
406 | 
407 |             log = "NMSE testing average: {}\nSSIM testing average: {}\nPSNR testing average: {}\n\n".format(
408 |                 np.mean(nmse_res),
409 |                 np.mean(ssim_res),
410 |                 np.mean(psnr_res))
411 | 
412 |             log_50.debug(log)
413 | 
414 |             log = "NMSE testing std: {}\nSSIM testing std: {}\nPSNR testing std: {}\n\n".format(np.std(nmse_res),
415 |                                                                                                 np.std(ssim_res),
416 |                                                                                                 np.std(psnr_res))
417 | 
418 |             log_50.debug(log)
419 | 
420 |             # evaluation for zero-filled (ZF) data
421 |             nmse_res_zf = sess.run(nmse_0_1_sample,
422 |                                    {t_gen_sample: x_bad_sample_rescaled, t_image_good_samples: x_good_sample_rescaled})
423 |             ssim_res_zf = threading_data([_ for _ in zip(x_good_sample_rescaled, x_bad_sample_rescaled)], fn=ssim)
424 |             psnr_res_zf = threading_data([_ for _ in zip(x_good_sample_rescaled, x_bad_sample_rescaled)], fn=psnr)
425 | 
426 |             log = "NMSE ZF testing: {}\nSSIM ZF testing: {}\nPSNR ZF testing: {}\n\n".format(
427 |                 nmse_res_zf,
428 |                 ssim_res_zf,
429 |                 psnr_res_zf)
430 | 
431 |             log_50.debug(log)
432 | 
433 |             log = "NMSE ZF average testing: {}\nSSIM ZF average testing: {}\nPSNR ZF average testing: {}\n\n".format(
434 |                 np.mean(nmse_res_zf),
435 |                 np.mean(ssim_res_zf),
436 |                 np.mean(psnr_res_zf))
437 | 
438 |             log_50.debug(log)
439 | 
440 |             log = "NMSE ZF std testing: {}\nSSIM ZF std testing: {}\nPSNR ZF std testing: {}\n\n".format(
441 |                 np.std(nmse_res_zf),
442 |                 np.std(ssim_res_zf),
443 |                 np.std(psnr_res_zf))
444 | 
445 |             log_50.debug(log)
446 | 
447 |             # save final sample images for the test set
448 |             tl.visualize.save_images(x_gen,
449 |                                      [5, 10],
450 |                                      os.path.join(save_dir, "final_generated_image.png"))
451 | 
452 |             tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - x_gen) / 2, 0, 1),
453 |                                      [5, 10],
454 |                                      os.path.join(save_dir, "final_generated_image_diff_abs_10_clip.png"))
455 | 
456 |             tl.visualize.save_images(np.clip(10 * np.abs(X_samples_good - X_samples_bad) / 2, 0, 1),
457 |                                      [5, 10],
458 |                                      os.path.join(save_dir, "final_bad_image_diff_abs_10_clip.png"))
459 | 
460 |             print("[*] Job finished!")
461 |             break
462 | 
463 | 
464 | if __name__ == "__main__":
465 |     import argparse
466 | 
467 |     parser = argparse.ArgumentParser()
468 | 
469 |     parser.add_argument('--model', type=str, default='unet', help='unet, unet_refine')
470 |     parser.add_argument('--mask', type=str, default='gaussian2d', help='gaussian1d, gaussian2d, poisson2d')
471 |     parser.add_argument('--maskperc', type=int, default=30, help='10,20,30,40,50')
472 | 
473 |     args = parser.parse_args()
474 | 
475 |     tl.global_flag['model'] = args.model
476 |     tl.global_flag['mask'] = args.mask
477 |     tl.global_flag['maskperc'] = args.maskperc
478 | 
479 |     main_train()
480 | 
--------------------------------------------------------------------------------
/trained_model/VGG16/README.md:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/DAGAN/c8fcd0e69efe46d40279bfea8868e85533f0ef91/trained_model/VGG16/README.md
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from tensorlayer.prepro import *
2 | import numpy as np
3 | import skimage.measure
4 | import scipy.fftpack
5 | from time import localtime, strftime
6 | import logging
7 | import os
8 | import tensorflow as tf
9 | def distort_img(x):
10 |     x = (x + 1.) / 2.  # map [-1, 1] to [0, 1] for the augmentation functions
11 |     x = flip_axis(x, axis=1, is_random=True)
12 |     x = elastic_transform(x, alpha=255 * 3, sigma=255 * 0.10, is_random=True)
13 |     x = rotation(x, rg=10, is_random=True, fill_mode='constant')
14 |     x = shift(x, wrg=0.10, hrg=0.10, is_random=True, fill_mode='constant')
15 |     x = zoom(x, zoom_range=[0.90, 1.10], is_random=True, fill_mode='constant')
16 |     x = brightness(x, gamma=0.05, is_random=True)
17 |     x = x * 2 - 1  # map back to [-1, 1]
18 |     return x
19 | 
20 | 
21 | def to_bad_img(x, mask):
22 |     x = (x + 1.) / 2.
23 |     fft = scipy.fftpack.fft2(x[:, :, 0])
24 |     fft = scipy.fftpack.fftshift(fft)
25 |     fft = fft * mask  # undersample k-space with the given mask
26 |     fft = scipy.fftpack.ifftshift(fft)
27 |     x = scipy.fftpack.ifft2(fft)
28 |     x = np.abs(x)
29 |     x = x * 2 - 1
30 |     return x[:, :, np.newaxis]
31 | 
32 | 
33 | def fft_abs_for_map_fn(x):
34 |     x = (x + 1.) / 2.  # map [-1, 1] to [0, 1] before the FFT
35 | x_complex = tf.complex(x, tf.zeros_like(x))[:, :, 0] 36 | fft = tf.spectral.fft2d(x_complex) 37 | fft_abs = tf.abs(fft) 38 | return fft_abs 39 | 40 | 41 | def ssim(data): 42 | x_good, x_bad = data 43 | x_good = np.squeeze(x_good) 44 | x_bad = np.squeeze(x_bad) 45 | ssim_res = skimage.measure.compare_ssim(x_good, x_bad) 46 | return ssim_res 47 | 48 | 49 | def psnr(data): 50 | x_good, x_bad = data 51 | psnr_res = skimage.measure.compare_psnr(x_good, x_bad) 52 | return psnr_res 53 | 54 | 55 | def vgg_prepro(x): 56 | x = imresize(x, [244, 244], interp='bilinear', mode=None) 57 | x = np.tile(x, 3) 58 | x = x / 127.5 - 1 59 | return x 60 | 61 | 62 | def logging_setup(log_dir): 63 | current_time_str = strftime("%Y_%m_%d_%H_%M_%S", localtime()) 64 | log_all_filename = os.path.join(log_dir, 'log_all_{}.log'.format(current_time_str)) 65 | log_eval_filename = os.path.join(log_dir, 'log_eval_{}.log'.format(current_time_str)) 66 | 67 | log_all = logging.getLogger('log_all') 68 | log_all.setLevel(logging.DEBUG) 69 | log_all.addHandler(logging.FileHandler(log_all_filename)) 70 | 71 | log_eval = logging.getLogger('log_eval') 72 | log_eval.setLevel(logging.INFO) 73 | log_eval.addHandler(logging.FileHandler(log_eval_filename)) 74 | 75 | log_50_filename = os.path.join(log_dir, 'log_50_images_testing_{}.log'.format(current_time_str)) 76 | 77 | log_50 = logging.getLogger('log_50') 78 | log_50.setLevel(logging.DEBUG) 79 | log_50.addHandler(logging.FileHandler(log_50_filename)) 80 | 81 | return log_all, log_eval, log_50, log_all_filename, log_eval_filename, log_50_filename 82 | 83 | 84 | if __name__ == "__main__": 85 | pass 86 | --------------------------------------------------------------------------------
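Usage note (not part of the repository): the retrospective undersampling performed by `to_bad_img` in `utils.py` can be reproduced on its own. The sketch below assumes a single 2-D slice normalised to [-1, 1] and a binary sampling mask of the same shape; the random mask is purely illustrative, whereas the training and test scripts load their masks from the `.mat` files under `mask/`.

```python
# Standalone sketch of the k-space undersampling in utils.to_bad_img
# (illustrative only; the image and mask below are randomly generated).
import numpy as np
import scipy.fftpack


def undersample(image, mask):
    x = (image + 1.) / 2.                                   # [-1, 1] -> [0, 1]
    k = scipy.fftpack.fftshift(scipy.fftpack.fft2(x))       # centred k-space
    k = k * mask                                            # keep only the sampled locations
    x_zf = np.abs(scipy.fftpack.ifft2(scipy.fftpack.ifftshift(k)))
    return x_zf * 2 - 1                                     # zero-filled reconstruction in [-1, 1]


image = np.random.uniform(-1, 1, (256, 256))                # stand-in for an MRI slice
mask = (np.random.rand(256, 256) < 0.3).astype(np.float64)  # ~30% random sampling
zero_filled = undersample(image, mask)
```

In the actual pipeline the same operation is applied slice by slice via `threading_data(..., fn=to_bad_img, mask=mask)` in `train.py`.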