├── LICENSE ├── README.md ├── data ├── NMI │ ├── test │ │ ├── FL_gt │ │ │ └── FL_gt_0001.png │ │ ├── FL_motion │ │ │ └── FL_motion_0001.png │ │ ├── T1_gt │ │ │ └── T1_gt_0001.png │ │ ├── T1_motion │ │ │ └── T1_motion_0001.png │ │ ├── T2_gt │ │ │ └── T2_gt_0001.png │ │ └── T2_motion │ │ │ └── T2_motion_0001.png │ ├── train │ │ ├── FL_gt │ │ │ └── FL_gt_0001.png │ │ ├── FL_motion │ │ │ └── FL_motion_0001.png │ │ ├── T1_gt │ │ │ └── T1_gt_0001.png │ │ ├── T1_motion │ │ │ └── T1_motion_0001.png │ │ ├── T2_gt │ │ │ └── T2_gt_0001.png │ │ └── T2_motion │ │ │ └── T2_motion_0001.png │ └── valid │ │ ├── FL_gt │ │ └── FL_gt_0001.png │ │ ├── FL_motion │ │ └── FL_motion_0001.png │ │ ├── T1_gt │ │ └── T1_gt_0001.png │ │ ├── T1_motion │ │ └── T1_motion_0001.png │ │ ├── T2_gt │ │ └── T2_gt_0001.png │ │ └── T2_motion │ │ └── T2_motion_0001.png ├── Network │ └── test │ │ ├── FL_gt │ │ └── FL_gt_0001.png │ │ ├── FL_motion │ │ └── FL_motion_0001.png │ │ ├── T1_gt │ │ └── T1_gt_0001.png │ │ ├── T1_motion │ │ └── T1_motion_0001.png │ │ ├── T2_gt │ │ └── T2_gt_0001.png │ │ └── T2_motion │ │ └── T2_motion_0001.png └── No │ ├── test │ ├── FL_gt │ │ └── FL_gt_0001.png │ ├── FL_motion │ │ └── FL_motion_0001.png │ ├── T1_gt │ │ └── T1_gt_0001.png │ ├── T1_motion │ │ └── T1_motion_0001.png │ ├── T2_gt │ │ └── T2_gt_0001.png │ └── T2_motion │ │ └── T2_motion_0001.png │ ├── train │ ├── FL_gt │ │ └── FL_gt_0001.png │ ├── FL_motion │ │ └── FL_motion_0001.png │ ├── T1_gt │ │ └── T1_gt_0001.png │ ├── T1_motion │ │ └── T1_motion_0001.png │ ├── T2_gt │ │ └── T2_gt_0001.png │ └── T2_motion │ │ └── T2_motion_0001.png │ └── valid │ ├── FL_gt │ └── FL_gt_0001.png │ ├── FL_motion │ └── FL_motion_0001.png │ ├── T1_gt │ └── T1_gt_0001.png │ ├── T1_motion │ └── T1_motion_0001.png │ ├── T2_gt │ └── T2_gt_0001.png │ └── T2_motion │ └── T2_motion_0001.png ├── dataset.py ├── figure ├── figure1.png ├── figure2.png └── figure3.png ├── logs ├── log.INFO ├── 
log.user-SYS-4029GP-TRT2.jylee.log.INFO.20200708-145230.46000 └── log.user-SYS-4029GP-TRT2.jylee.log.INFO.20200708-150247.1749 ├── model.py ├── requirements.txt ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 OpenXAIProject 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI 2 | 3 | Tensorflow implementation of "MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI" 4 | 5 | ![](figure/figure1.png) 6 | 7 | Contact: Jongyeon Lee, KAIST (jyl4you@kaist.ac.kr) 8 | 9 | ## Paper 10 | [MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI]()
11 | Jongyeon Lee, Byungjai Kim, and [Hyunwook Park](http://athena.kaist.ac.kr)
12 | (Under review) 13 | 14 | Please cite our paper if you find it useful for your research (currently under review at MRM). 15 | 16 | ``` 17 | @inproceedings{, 18 | author = {}, 19 | booktitle = {}, 20 | title = {MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI}, 21 | year = {2020} 22 | } 23 | ``` 24 | 25 | ## Example Results 26 | 27 | ![](figure/figure2.png) 28 | 29 | ## Quantitative Results 30 | 31 | ![](figure/figure3.png) 32 | 33 | ## Installation 34 | * Install tensorflow-gpu 2.2.0 with Python 3.6 and CUDA 10.2 35 | 36 | ``` 37 | pip install -r requirements.txt 38 | ``` 39 | or 40 | ``` 41 | conda create -n [ENV_NAME] python=3.6 42 | conda activate [ENV_NAME] 43 | conda install tensorflow-gpu=2.2.0 scikit-image=0.16.2 44 | ``` 45 | 46 | * Clone this repo 47 | ``` 48 | git clone https://github.com/OpenXAIProject/mc2-net.git 49 | cd mc2-net 50 | ``` 51 | 52 | ## Dataset 53 | * The dataset used in this study is not publicly available. Sample images for a demo are provided in the data directory. 54 | * You may use [BraTS 2020](https://ipp.cbica.upenn.edu/#BraTS20_registration) as an alternative dataset. To generate training data, please refer to the paper (Figure 4) for the motion simulation. This step is not implemented in Python because it depends on the pulse sequences.
55 | 56 | ## Testing 57 | * Use CUDA_VISIBLE_DEVICES for GPU selection 58 | * Evaluate the predicted images and save them as png files 59 | 60 | ``` 61 | CUDA_VISIBLE_DEVICES=0 python test.py --load_weight_name ./weight/[weight filename].h5 62 | ``` 63 | 64 | ## Training Examples 65 | * Use CUDA_VISIBLE_DEVICES for GPU selection 66 | * Train the sample images 67 | 68 | ``` 69 | CUDA_VISIBLE_DEVICES=0 python train.py --num_epoch 1000 70 | ``` 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /data/NMI/test/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/test/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/test/T1_gt/T1_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T1_gt/T1_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/test/T1_motion/T1_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T1_motion/T1_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/test/T2_gt/T2_gt_0001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T2_gt/T2_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/test/T2_motion/T2_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T2_motion/T2_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/train/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/train/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/train/T1_gt/T1_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T1_gt/T1_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/train/T1_motion/T1_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T1_motion/T1_motion_0001.png 
-------------------------------------------------------------------------------- /data/NMI/train/T2_gt/T2_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T2_gt/T2_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/train/T2_motion/T2_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T2_motion/T2_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/valid/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/valid/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/valid/T1_gt/T1_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T1_gt/T1_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/valid/T1_motion/T1_motion_0001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T1_motion/T1_motion_0001.png -------------------------------------------------------------------------------- /data/NMI/valid/T2_gt/T2_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T2_gt/T2_gt_0001.png -------------------------------------------------------------------------------- /data/NMI/valid/T2_motion/T2_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T2_motion/T2_motion_0001.png -------------------------------------------------------------------------------- /data/Network/test/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/Network/test/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/Network/test/T1_gt/T1_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T1_gt/T1_gt_0001.png -------------------------------------------------------------------------------- /data/Network/test/T1_motion/T1_motion_0001.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T1_motion/T1_motion_0001.png -------------------------------------------------------------------------------- /data/Network/test/T2_gt/T2_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T2_gt/T2_gt_0001.png -------------------------------------------------------------------------------- /data/Network/test/T2_motion/T2_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T2_motion/T2_motion_0001.png -------------------------------------------------------------------------------- /data/No/test/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/No/test/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/No/test/T1_gt/T1_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T1_gt/T1_gt_0001.png 
-------------------------------------------------------------------------------- /data/No/test/T1_motion/T1_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T1_motion/T1_motion_0001.png -------------------------------------------------------------------------------- /data/No/test/T2_gt/T2_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T2_gt/T2_gt_0001.png -------------------------------------------------------------------------------- /data/No/test/T2_motion/T2_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T2_motion/T2_motion_0001.png -------------------------------------------------------------------------------- /data/No/train/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/No/train/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/No/train/T1_gt/T1_gt_0001.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T1_gt/T1_gt_0001.png -------------------------------------------------------------------------------- /data/No/train/T1_motion/T1_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T1_motion/T1_motion_0001.png -------------------------------------------------------------------------------- /data/No/train/T2_gt/T2_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T2_gt/T2_gt_0001.png -------------------------------------------------------------------------------- /data/No/train/T2_motion/T2_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T2_motion/T2_motion_0001.png -------------------------------------------------------------------------------- /data/No/valid/FL_gt/FL_gt_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/FL_gt/FL_gt_0001.png -------------------------------------------------------------------------------- /data/No/valid/FL_motion/FL_motion_0001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/FL_motion/FL_motion_0001.png -------------------------------------------------------------------------------- /data/No/valid/T1_gt/T1_gt_0001.png: 
import glob
import os
import random

import numpy as np


def datalist_loader(path, reg_type=None, data_type='train'):
    """Collect sorted PNG file lists for every contrast of one data split.

    Files are searched in ``<path>/<reg_type>/<data_type>/<contrast>_<kind>/*.png``
    for contrasts ``T1``, ``T2``, ``FL`` and kinds ``gt`` / ``motion``.

    Args:
        path: Root data directory.
        reg_type: Optional sub-directory between ``path`` and ``data_type``
            (the sample data ships ``NMI``, ``No`` and ``Network``). ``None``
            or ``''`` omits this level.
        data_type: Split directory name, e.g. ``'train'``, ``'valid'``, ``'test'``.

    Returns:
        ``(gt_lists, motion_lists)``: each is a list of three sorted path
        lists in ``T1, T2, FL`` order. Missing folders yield empty lists.
    """
    # reg_type defaults to None, and os.path.join(path, None, ...) raises
    # TypeError, so the level is skipped when it is not given.
    base_parts = [path, reg_type] if reg_type else [path]
    data_path = os.path.join(*base_parts, data_type)
    # Escape the directory part so '[', '*', '?' in real paths are not
    # interpreted as glob wildcards; only the '*.png' suffix is a pattern.
    escaped_path = glob.escape(data_path)

    contrasts = ['T1', 'T2', 'FL']

    def _sorted_pngs(contrast, kind):
        # One sorted list of PNGs for a single <contrast>_<kind> folder.
        pattern = os.path.join(escaped_path, contrast + '_' + kind, '*.png')
        return sorted(glob.glob(pattern))

    gt_lists = [_sorted_pngs(contrast, 'gt') for contrast in contrasts]
    motion_lists = [_sorted_pngs(contrast, 'motion') for contrast in contrasts]
    return gt_lists, motion_lists


def load_batch(fname_list):
    """Read grayscale images and stack them into a float32 batch.

    Args:
        fname_list: Iterable of image paths. All images must have the same
            height and width so they can be stacked into one array.

    Returns:
        ``float32`` array of shape ``(N, H, W, 1)`` with pixel values divided
        by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the files
            (``cv2.imread`` returns ``None`` instead of raising).
    """
    # Imported lazily so the path/shuffling helpers in this module can be
    # used without OpenCV installed.
    import cv2

    images = []
    for fname in fname_list:
        img = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise FileNotFoundError('Could not read image: {}'.format(fname))
        images.append(img)
    batch = np.asarray(images, dtype='float32')
    return np.expand_dims(batch, axis=-1) / 255


def train_batch_data_loader(datalist, num_contrast, shuffle=True):
    """Pair ground-truth targets with randomly chosen network inputs.

    Args:
        datalist: ``2 * num_contrast`` parallel path lists: the ground-truth
            lists for each contrast followed by the motion-corrupted lists in
            the same contrast order (such as ``gt_lists + motion_lists`` from
            ``datalist_loader``). The lists are zipped together, so entries
            beyond the shortest list are ignored. The input is not modified.
        num_contrast: Number of contrasts.
        shuffle: Shuffle the sample order, keeping each sample's files together.

    Returns:
        ``(targets, inputs)``: each a list of ``num_contrast`` path lists.
        ``targets[j][i]`` is the ground-truth file of contrast ``j`` for
        sample ``i``. ``inputs[j][i]`` is, with probability 1/2 each, that
        same ground-truth file or its motion-corrupted counterpart, chosen
        independently per sample and contrast. Empty input gives
        ``num_contrast`` empty lists in both outputs.
    """
    # One tuple per sample holding all 2*num_contrast paths of that sample.
    samples = list(zip(*datalist))
    if shuffle:
        random.shuffle(samples)

    targets = [[] for _ in range(num_contrast)]
    inputs = [[] for _ in range(num_contrast)]
    # Draw order (one randint per sample, then per contrast) is kept stable
    # so seeded runs remain reproducible.
    for sample in samples:
        for j in range(num_contrast):
            # 0 selects the gt input, num_contrast selects the motion input.
            offset = random.randint(0, 1) * num_contrast
            targets[j].append(sample[j])
            inputs[j].append(sample[j + offset])

    return targets, inputs


def batch_data_loader(batch_datalist, num_contrast):
    """Load one image batch per contrast.

    Args:
        batch_datalist: Sequence with at least ``num_contrast`` per-contrast
            path lists (for example, lists shaped like the outputs of
            ``train_batch_data_loader``).
        num_contrast: Number of contrasts to load.

    Returns:
        List of ``num_contrast`` arrays, each as produced by ``load_batch``.
    """
    return [load_batch(batch_datalist[i]) for i in range(num_contrast)]
-------------------------------------------------------------------------------- /figure/figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/figure/figure3.png -------------------------------------------------------------------------------- /logs/log.user-SYS-4029GP-TRT2.jylee.log.INFO.20200708-145230.46000: -------------------------------------------------------------------------------- 1 | Epoch [ 1/1000] | Iter [ 0/ 193] 5.35s.. train loss for T1: 4.5392, T2: 1.6255, FL: 1.7207 2 | Epoch [ 1/1000] | Iter [ 10/ 193] 26.13s.. train loss for T1: 2.7811, T2: 0.9277, FL: 0.9363 3 | Epoch [ 1/1000] | Iter [ 20/ 193] 46.85s.. train loss for T1: 2.7792, T2: 0.9316, FL: 0.9271 4 | Epoch [ 1/1000] | Iter [ 30/ 193] 67.59s.. train loss for T1: 2.7379, T2: 0.9186, FL: 0.9023 5 | Epoch [ 1/1000] | Iter [ 40/ 193] 88.29s.. train loss for T1: 2.6266, T2: 0.8918, FL: 0.8410 6 | Epoch [ 1/1000] | Iter [ 50/ 193] 109.22s.. train loss for T1: 2.4726, T2: 0.8600, FL: 0.7427 7 | Epoch [ 1/1000] | Iter [ 60/ 193] 130.27s.. train loss for T1: 2.3297, T2: 0.8144, FL: 0.6976 8 | Epoch [ 1/1000] | Iter [ 70/ 193] 151.45s.. train loss for T1: 2.1408, T2: 0.7545, FL: 0.6696 9 | Epoch [ 1/1000] | Iter [ 80/ 193] 172.51s.. train loss for T1: 1.9576, T2: 0.6982, FL: 0.6087 10 | Epoch [ 1/1000] | Iter [ 90/ 193] 193.78s.. train loss for T1: 1.7520, T2: 0.6326, FL: 0.5330 11 | Epoch [ 1/1000] | Iter [ 100/ 193] 215.08s.. train loss for T1: 1.4899, T2: 0.5489, FL: 0.4644 12 | Epoch [ 1/1000] | Iter [ 110/ 193] 236.10s.. train loss for T1: 1.1551, T2: 0.3951, FL: 0.4216 13 | Epoch [ 1/1000] | Iter [ 120/ 193] 257.25s.. train loss for T1: 0.8505, T2: 0.2494, FL: 0.3962 14 | Epoch [ 1/1000] | Iter [ 130/ 193] 278.38s.. train loss for T1: 0.5245, T2: 0.1725, FL: 0.2284 15 | Epoch [ 1/1000] | Iter [ 140/ 193] 299.52s.. 
train loss for T1: 0.3533, T2: 0.1238, FL: 0.1439 16 | Epoch [ 1/1000] | Iter [ 150/ 193] 320.51s.. train loss for T1: 0.2639, T2: 0.0932, FL: 0.1028 17 | Epoch [ 1/1000] | Iter [ 160/ 193] 341.63s.. train loss for T1: 0.2052, T2: 0.0712, FL: 0.0700 18 | Epoch [ 1/1000] | Iter [ 170/ 193] 362.74s.. train loss for T1: 0.1873, T2: 0.0683, FL: 0.0640 19 | Epoch [ 1/1000] | Iter [ 180/ 193] 383.87s.. train loss for T1: 0.1631, T2: 0.0555, FL: 0.0532 20 | Epoch [ 1/1000] | Iter [ 190/ 193] 405.23s.. train loss for T1: 0.1547, T2: 0.0469, FL: 0.0540 21 | Epoch [ 1/1000] 409.42s.. train loss for T1: 0.0745, T2: 0.0256, FL: 0.0247 22 | Epoch [ 2/1000] | Iter [ 0/ 193] 2.12s.. train loss for T1: 0.1274, T2: 0.0408, FL: 0.0424 23 | Epoch [ 2/1000] | Iter [ 10/ 193] 23.34s.. train loss for T1: 0.1252, T2: 0.0380, FL: 0.0426 24 | Epoch [ 2/1000] | Iter [ 20/ 193] 44.53s.. train loss for T1: 0.1145, T2: 0.0309, FL: 0.0419 25 | Epoch [ 2/1000] | Iter [ 30/ 193] 65.86s.. train loss for T1: 0.1120, T2: 0.0334, FL: 0.0377 26 | Epoch [ 2/1000] | Iter [ 40/ 193] 87.25s.. train loss for T1: 0.1031, T2: 0.0245, FL: 0.0353 27 | Epoch [ 2/1000] | Iter [ 50/ 193] 108.66s.. train loss for T1: 0.1035, T2: 0.0273, FL: 0.0349 28 | Epoch [ 2/1000] | Iter [ 60/ 193] 129.98s.. train loss for T1: 0.0999, T2: 0.0230, FL: 0.0430 29 | Epoch [ 2/1000] | Iter [ 70/ 193] 151.49s.. train loss for T1: 0.0969, T2: 0.0282, FL: 0.0348 30 | Epoch [ 2/1000] | Iter [ 80/ 193] 173.11s.. 
import tensorflow as tf
import numpy as np
import tensorflow.keras.backend as K
from tensorflow import keras

# Importing this module seeds the global TensorFlow and NumPy RNGs as a side
# effect (affects every later user of those generators in the process).
tf.random.set_seed(22)
np.random.seed(22)

class InstanceNormalization(keras.layers.Layer):
    """Per-sample normalization with optional learned scale and offset.

    For each sample, subtracts the mean and divides by (std + epsilon). The
    statistics are taken over every axis except the batch axis (0) and, if
    given, ``axis``. ``gamma``/``beta`` have one entry per index along
    ``axis``, or a single scalar when ``axis`` is None.

    NOTE(review): every use in this file keeps the default ``axis=None``, so
    statistics are pooled over height, width *and* channels with scalar
    gamma/beta. That is closer to layer normalization than to per-channel
    instance normalization (``axis=-1`` for channels-last inputs). Changing
    ``axis`` changes the gamma/beta weight shapes, so existing weight files
    would likely stop loading; confirm intent before changing it.
    """

    def __init__(self,
                 axis=None,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        """Store the configuration; weights are created later in ``build``.

        Args:
            axis: Axis kept out of the statistics and given its own
                gamma/beta entries; None pools all non-batch axes.
            epsilon: Added to the standard deviation (not the variance).
            center: Whether to learn an additive offset ``beta``.
            scale: Whether to learn a multiplicative scale ``gamma``.
            *_initializer / *_regularizer / *_constraint: Keras identifiers
                or objects for the gamma/beta weights.
        """
        super(InstanceNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)

    def build(self, input_shape):
        """Create the gamma/beta weights for the given input shape.

        Raises:
            ValueError: if ``axis`` is 0 (the batch axis), or if ``axis`` is
                given for a 2-D input (rank 1 per sample).
        """
        ndim = len(input_shape)
        if self.axis == 0:
            raise ValueError('Axis cannot be zero')

        if (self.axis is not None) and (ndim == 2):
            raise ValueError('Cannot specify axis for rank 1 tensor')

        self.input_spec = keras.layers.InputSpec(ndim=ndim)

        # Scalar parameters when no axis is kept, otherwise one per index.
        if self.axis is None:
            shape = (1,)
        else:
            shape = (input_shape[self.axis],)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        self.built = True

    def call(self, inputs, training=None):
        """Normalize ``inputs``; ``training`` is accepted but not used."""
        input_shape = K.int_shape(inputs)
        reduction_axes = list(range(0, len(input_shape)))

        # Remove the kept axis first (its index refers to the full list),
        # then the batch axis, leaving the axes the statistics are taken over.
        if self.axis is not None:
            del reduction_axes[self.axis]

        del reduction_axes[0]

        mean = K.mean(inputs, reduction_axes, keepdims=True)
        stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
        normed = (inputs - mean) / stddev

        # Reshape gamma/beta so they broadcast along ``axis`` only (or as a
        # scalar when axis is None).
        broadcast_shape = [1] * len(input_shape)
        if self.axis is not None:
            broadcast_shape[self.axis] = input_shape[self.axis]

        if self.scale:
            broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
            normed = normed * broadcast_gamma
        if self.center:
            broadcast_beta = K.reshape(self.beta, broadcast_shape)
            normed = normed + broadcast_beta
        return normed

    def get_config(self):
        """Return constructor arguments merged with the base layer config."""
        config = {
            'axis': self.axis,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint)
        }
        base_config = super(InstanceNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


""" Motion correction network for MC2-Net """


class Encoder(keras.Model):
    """Single-contrast feature encoder.

    A 7x7 stride-1 conv, then two 3x3 stride-2 convs, each followed by
    InstanceNormalization and ReLU. Height and width are halved by each
    stride-2 conv ('same' padding), and channels go
    initial_filters -> 2x -> 4x.
    """

    def __init__(self, initial_filters=64):
        super(Encoder, self).__init__()

        self.filters = initial_filters

        self.conv1 = keras.layers.Conv2D(self.filters, kernel_size=7, strides=1, padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        self.conv2 = keras.layers.Conv2D(self.filters*2, kernel_size=3, strides=2, padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        self.conv3 = keras.layers.Conv2D(self.filters*4, kernel_size=3, strides=2, padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))

        self.n1 = InstanceNormalization()
        self.n2 = InstanceNormalization()
        self.n3 = InstanceNormalization()

    def call(self, x, training=True):
        """Apply conv -> norm -> ReLU three times and return the features."""
        x = self.conv1(x)
        x = self.n1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.n2(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv3(x)
        x = self.n3(x, training=training)
        x = tf.nn.relu(x)

        return x


class Residual(keras.Model):
    """Residual block: two 3x3 stride-1 convs with ReLU, plus the block input.

    The input must already have ``initial_filters`` channels for the final
    ``tf.add``. ReLU is applied before the addition, so the sum itself is
    not re-activated.

    NOTE(review): ``in1``/``in2`` are created but never used; their calls
    are commented out in ``call``.
    """

    def __init__(self, initial_filters=256):
        super(Residual, self).__init__()

        self.filters = initial_filters

        self.conv1 = keras.layers.Conv2D(self.filters, kernel_size=3, strides=1, padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        self.conv2 = keras.layers.Conv2D(self.filters, kernel_size=3, strides=1, padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))

        self.in1 = InstanceNormalization()
        self.in2 = InstanceNormalization()

    def call(self, x, training=True):
        """Return relu(conv2(relu(conv1(x)))) + x."""
        inputs = x

        x = self.conv1(x)
        # x = self.in1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        # x = self.in2(x, training=training)
        x = tf.nn.relu(x)

        x = tf.add(x, inputs)

        return x


class Decoder(keras.Model):
    """Single-contrast image decoder.

    Two 3x3 stride-2 transposed convs (``initial_filters`` then
    ``initial_filters // 2`` channels) double height/width twice, then a
    7x7 conv maps to one output channel. Every conv, including the last, is
    followed by InstanceNormalization and ReLU, so the output is
    non-negative.

    NOTE(review): normalizing the final one-channel output fixes its
    per-image mean/std up to the learned scalar gamma/beta before the ReLU;
    confirm this is intended for image reconstruction.
    """

    def __init__(self, initial_filters=128):
        super(Decoder, self).__init__()

        self.filters = initial_filters

        self.conv1 = keras.layers.Conv2DTranspose(self.filters, kernel_size=3, strides=2, padding='same',
                                                  kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        self.conv2 = keras.layers.Conv2DTranspose(self.filters//2, kernel_size=3, strides=2, padding='same',
                                                  kernel_initializer=tf.random_normal_initializer(stddev=0.02))
        self.conv3 = keras.layers.Conv2D(1, kernel_size=7, strides=1, padding='same',
                                         kernel_initializer=tf.random_normal_initializer(stddev=0.02))

        self.in1 = InstanceNormalization()
        self.in2 = InstanceNormalization()
        self.in3 = InstanceNormalization()

    def call(self, x, training=True):
        """Upsample, project to one channel, and return the image."""
        x = self.conv1(x)
        x = self.in1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv2(x)
        x = self.in2(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv3(x)
        x = self.in3(x, training=training)
        x = tf.nn.relu(x)

        return x


class MC_Net(keras.Model):
    def __init__(self,
                 img_size=256,
                 num_filter=16,
                 num_contrast=3,
                 num_res_block=9):
        super(MC_Net, self).__init__()

        self.img_size = img_size
        self.filters = num_filter
        self.num_contrast = num_contrast
        self.num_res_block = num_res_block

        self.encoder_list = []
        for _ in range(num_contrast):
            self.encoder_list.append(Encoder(initial_filters=self.filters))

        self.res_block_list = []
        for _ in range(num_res_block):
            self.res_block_list.append(Residual(initial_filters=self.filters*4*num_contrast))

        self.decoder_list = []
        for _ in range(num_contrast):
            self.decoder_list.append(Decoder(initial_filters=self.filters*2))

    def build(self, input_shape):
assert isinstance(input_shape, list) 237 | super(MC_Net, self).build(input_shape) 238 | 239 | def call(self, x, training=True): 240 | x_list = [] 241 | for i in range(self.num_contrast): 242 | x_list.append(self.encoder_list[i](x[i], training)) 243 | x = tf.concat(x_list, axis=-1) 244 | 245 | for i in range(self.num_res_block): 246 | x = (self.res_block_list[i](x, training)) 247 | 248 | y = tf.split(x, num_or_size_splits=self.num_contrast, axis=-1) 249 | 250 | y_list = [] 251 | for i in range(self.num_contrast): 252 | y_list.append(self.decoder_list[i](y[i], training)) 253 | 254 | return y_list 255 | 256 | 257 | def ssim_loss(img1, img2): 258 | return -tf.math.log((tf.image.ssim(img1, img2, max_val=1.0)+1)/2) 259 | 260 | 261 | def vgg_layers(layer_names): 262 | vgg = tf.keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3)) 263 | vgg.trainable = False 264 | 265 | outputs = [vgg.get_layer(name).output for name in layer_names] 266 | model = tf.keras.Model([vgg.input], outputs) 267 | return model 268 | 269 | 270 | def vgg_loss(img1, img2, loss_model): 271 | img1 = tf.repeat(img1, 3, -1) 272 | img2 = tf.repeat(img2, 3, -1) 273 | 274 | return tf.reduce_mean(tf.square(loss_model(img1) - loss_model(img2))) 275 | 276 | 277 | def make_custom_loss(l1, l2, loss_model): 278 | def custom_loss(y_true, y_pred): 279 | return l1*ssim_loss(y_true, y_pred) + l2*vgg_loss(y_true, y_pred, loss_model) 280 | 281 | return custom_loss -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | astunparse==1.6.3 3 | blinker==1.4 4 | brotlipy==0.7.0 5 | cachetools==4.1.0 6 | certifi==2020.6.20 7 | cffi==1.14.0 8 | chardet==3.0.4 9 | click==7.1.2 10 | cloudpickle==1.4.1 11 | cryptography==2.9.2 12 | cycler==0.10.0 13 | cytoolz==0.10.1 14 | dask @ file:///tmp/build/80754af9/dask-core_1592842333140/work 15 | 
decorator==4.4.2 16 | gast==0.3.3 17 | google-auth==1.14.1 18 | google-auth-oauthlib==0.4.1 19 | google-pasta==0.2.0 20 | grpcio==1.27.2 21 | h5py @ file:///tmp/build/80754af9/h5py_1593454121459/work 22 | idna @ file:///tmp/build/80754af9/idna_1593446292537/work 23 | imageio==2.8.0 24 | Keras-Preprocessing==1.1.0 25 | kiwisolver==1.2.0 26 | Markdown==3.1.1 27 | matplotlib @ file:///tmp/build/80754af9/matplotlib-base_1592846044287/work 28 | mkl-fft==1.1.0 29 | mkl-random==1.1.1 30 | mkl-service==2.3.0 31 | networkx==2.4 32 | nibabel==2.5.1 33 | numpy==1.17.4 34 | oauthlib==3.1.0 35 | olefile==0.46 36 | opencv-python==4.1.1.26 37 | opt-einsum==3.1.0 38 | Pillow==7.1.2 39 | protobuf==3.12.3 40 | pyasn1==0.4.8 41 | pyasn1-modules==0.2.7 42 | pycparser==2.20 43 | PyJWT==1.7.1 44 | pyOpenSSL==19.1.0 45 | pyparsing==2.4.7 46 | PySocks==1.7.1 47 | python-dateutil==2.8.1 48 | PyWavelets==1.1.1 49 | PyYAML==5.3.1 50 | requests @ file:///tmp/build/80754af9/requests_1592841827918/work 51 | requests-oauthlib==1.3.0 52 | rsa==4.0 53 | scikit-image==0.16.2 54 | scipy @ file:///tmp/build/80754af9/scipy_1592930540905/work 55 | six==1.13.0 56 | tensorboard==2.2.1 57 | tensorboard-plugin-wit==1.6.0 58 | tensorflow==2.2.0 59 | tensorflow-estimator==2.2.0 60 | termcolor==1.1.0 61 | toolz==0.10.0 62 | tornado==6.0.4 63 | urllib3==1.25.9 64 | Werkzeug==1.0.1 65 | wrapt==1.12.1 66 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | tf.config.experimental_run_functions_eagerly(True) 7 | 8 | from model import MC_Net 9 | from dataset import datalist_loader, batch_data_loader 10 | from utils import test_ssim, test_nmi, test_nrmse, save_image 11 | 12 | tf.random.set_seed(22) 13 | np.random.seed(22) 14 | physical_devices = 
tf.config.experimental.list_physical_devices('GPU') 15 | tf.config.experimental.set_memory_growth(physical_devices[0], True) 16 | 17 | FLAGS = tf.compat.v1.flags.FLAGS 18 | 19 | tf.compat.v1.flags.DEFINE_integer('batch_size', 1, 20 | 'Batch size (Default: 1)') 21 | tf.compat.v1.flags.DEFINE_integer('image_size', 256, 22 | 'Image size (size x size) (Default: 256)') 23 | tf.compat.v1.flags.DEFINE_string('load_weight_name', 'weight_min_val_loss.h5', 24 | 'Load weight of given name (Default: weight_min_val_loss.h5)') 25 | tf.compat.v1.flags.DEFINE_integer('num_contrast', 3, 26 | 'Number of contrasts of MR images (Default: 3)') 27 | tf.compat.v1.flags.DEFINE_integer('num_filter', 16, 28 | 'Number of filters in the first layer of the encoder (Default: 16)') 29 | tf.compat.v1.flags.DEFINE_integer('num_res_block', 9, 30 | 'Number of residual blocks (Default: 9)') 31 | tf.compat.v1.flags.DEFINE_string('path_data', './data/', 32 | 'Data load path (Default: ./data/') 33 | tf.compat.v1.flags.DEFINE_string('path_save', './test/', 34 | 'Image save path (Default: ./test/') 35 | tf.compat.v1.flags.DEFINE_string('path_weight', './weight/', 36 | 'Weight load path (Default: ./weight/') 37 | tf.compat.v1.flags.DEFINE_string('reg_type', 'NMI', 38 | 'Registration type of input images (No, Network, or NMI) (Default: NMI)') 39 | 40 | assert tf.__version__.startswith('2.') 41 | print('Tensorflow version: ', tf.__version__) 42 | tf.random.set_seed(22) 43 | np.random.seed(22) 44 | 45 | physical_devices = tf.config.experimental.list_physical_devices('GPU') 46 | tf.config.experimental.set_memory_growth(physical_devices[0], True) 47 | 48 | def test(): 49 | model = MC_Net(img_size=FLAGS.image_size, 50 | num_filter=FLAGS.num_filter, 51 | num_contrast=FLAGS.num_contrast, 52 | num_res_block=FLAGS.num_res_block) 53 | 54 | input_shape = [(None, FLAGS.image_size, FLAGS.image_size, 1)] 55 | model.build(input_shape=input_shape * FLAGS.num_contrast) 56 | 57 | model.load_weights(FLAGS.path_weight + 
FLAGS.load_weight_name) 58 | print('Model building completed!') 59 | 60 | y_test_datalist, x_test_datalist = datalist_loader(FLAGS.path_data, FLAGS.reg_type, 'test') 61 | x_test = batch_data_loader(x_test_datalist, FLAGS.num_contrast) 62 | y_test = batch_data_loader(y_test_datalist, FLAGS.num_contrast) 63 | print('Data loading completed!') 64 | 65 | p_test = model.predict(x_test, batch_size=FLAGS.batch_size) 66 | print('Prediction completed!') 67 | 68 | x_ssim_T1, p_ssim_T1 = test_ssim(x_test[0], y_test[0], p_test[0]) 69 | x_ssim_T2, p_ssim_T2 = test_ssim(x_test[1], y_test[1], p_test[1]) 70 | x_ssim_FL, p_ssim_FL = test_ssim(x_test[2], y_test[2], p_test[2]) 71 | 72 | print(' c | x_ssim | p_ssim') 73 | print(f'T1 | {x_ssim_T1:.4f} | {p_ssim_T1:.4f}') 74 | print(f'T2 | {x_ssim_T2:.4f} | {p_ssim_T2:.4f}') 75 | print(f'FL | {x_ssim_FL:.4f} | {p_ssim_FL:.4f}') 76 | 77 | x_nmi_T1, p_nmi_T1 = test_nmi(x_test[0], y_test[0], p_test[0]) 78 | x_nmi_T2, p_nmi_T2 = test_nmi(x_test[1], y_test[1], p_test[1]) 79 | x_nmi_FL, p_nmi_FL = test_nmi(x_test[2], y_test[2], p_test[2]) 80 | 81 | print(' c | x_nmi | p_nmi ') 82 | print(f'T1 | {x_nmi_T1:.4f} | {p_nmi_T1:.4f}') 83 | print(f'T2 | {x_nmi_T2:.4f} | {p_nmi_T2:.4f}') 84 | print(f'FL | {x_nmi_FL:.4f} | {p_nmi_FL:.4f}') 85 | 86 | x_nrmse_T1, p_nrmse_T1 = test_nrmse(x_test[0], y_test[0], p_test[0]) 87 | x_nrmse_T2, p_nrmse_T2 = test_nrmse(x_test[1], y_test[1], p_test[1]) 88 | x_nrmse_FL, p_nrmse_FL = test_nrmse(x_test[2], y_test[2], p_test[2]) 89 | 90 | print(' c | x_nrmse | p_nrmse') 91 | print(f'T1 | {x_nrmse_T1:.4f} | {p_nrmse_T1:.4f}') 92 | print(f'T2 | {x_nrmse_T2:.4f} | {p_nrmse_T2:.4f}') 93 | print(f'FL | {x_nrmse_FL:.4f} | {p_nrmse_FL:.4f}') 94 | 95 | os.makedirs(FLAGS.path_save, exist_ok=True) 96 | for i in range(p_test[0].shape[0]): 97 | save_image(f'{FLAGS.path_save}/T1_pred_{i+1:04d}.png', p_test[0][i]) 98 | save_image(f'{FLAGS.path_save}/T2_pred_{i+1:04d}.png', p_test[1][i]) 99 | 
save_image(f'{FLAGS.path_save}/FL_pred_{i+1:04d}.png', p_test[2][i]) 100 | print('Image saving completed!') 101 | 102 | 103 | if __name__ == '__main__': 104 | test() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 3 | 4 | from absl import app, flags, logging 5 | from absl.flags import FLAGS 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow import keras 9 | tf.config.experimental_run_functions_eagerly(True) 10 | import time 11 | from tqdm import tqdm 12 | 13 | from model import MC_Net, vgg_layers, make_custom_loss 14 | from dataset import datalist_loader, train_batch_data_loader, batch_data_loader 15 | from utils import rot_tra_argumentation 16 | 17 | tf.random.set_seed(22) 18 | np.random.seed(22) 19 | physical_devices = tf.config.experimental.list_physical_devices('GPU') 20 | tf.config.experimental.set_memory_growth(physical_devices[0], True) 21 | 22 | # FLAGS = flags.FLAGS 23 | 24 | flags.DEFINE_integer('batch_size', 1, 25 | 'Batch size (Default: 1)') 26 | flags.DEFINE_integer('image_size', 256, 27 | 'Image size (size x size) (Default: 256)') 28 | flags.DEFINE_integer('iter_interval', 1, 29 | 'Iteration interval for logging (Default: 1)') 30 | flags.DEFINE_float('lambda_ssim', 1, 31 | 'Weight for SSIM loss (Default: 1)') 32 | flags.DEFINE_float('lambda_vgg', 1e-2, 33 | 'Weight for VGG loss (Default: 0.01)') 34 | flags.DEFINE_float('learning_rate', 1e-4, 35 | 'Initial learning rate for Adam (Default: 0.0001)') 36 | flags.DEFINE_string('load_weight_name', None, 37 | 'Load weight of given name (Default: None)') 38 | flags.DEFINE_integer('num_contrast', 3, 39 | 'Number of contrasts of MR images (Default: 3)') 40 | flags.DEFINE_integer('num_epoch', 1, 41 | 'Number of epochs for training (Default: 1)') 42 | flags.DEFINE_integer('num_filter', 16, 43 | 'Number of filters 
in the first layer of the encoder (Default: 16)') 44 | flags.DEFINE_integer('num_res_block', 9, 45 | 'Number of residual blocks (Default: 9)') 46 | flags.DEFINE_string('path_data', './data/', 47 | 'Data load path (Default: ./data/') 48 | flags.DEFINE_string('path_weight', './weight/', 49 | 'Weight save path (Default: ./weight/') 50 | flags.DEFINE_string('reg_type', 'NMI', 51 | 'Registration type of input images (No, Network, or NMI) (Default: NMI)') 52 | flags.DEFINE_integer('save_epoch', 10, 53 | 'Save weights by every given number of epochs (Default: 10)') 54 | 55 | 56 | def train(_argv): 57 | os.makedirs('./logs', exist_ok=True) 58 | logging.get_absl_handler().use_absl_log_file('log', "./logs") 59 | logging.get_absl_handler().setFormatter(None) 60 | 61 | os.makedirs(FLAGS.path_weight, exist_ok=True) 62 | 63 | model = MC_Net(img_size=FLAGS.image_size, 64 | num_filter=FLAGS.num_filter, 65 | num_contrast=FLAGS.num_contrast, 66 | num_res_block=FLAGS.num_res_block) 67 | 68 | loss_model = vgg_layers(['block3_conv1']) 69 | final_loss = make_custom_loss(FLAGS.lambda_ssim, FLAGS.lambda_vgg, loss_model) 70 | model.compile(optimizer=keras.optimizers.Adam(FLAGS.learning_rate), 71 | loss=final_loss) 72 | input_shape = [(None, FLAGS.image_size, FLAGS.image_size, 1)] 73 | model.build(input_shape=input_shape * FLAGS.num_contrast) 74 | model.summary() 75 | 76 | # Data loading assumes that the number of contrasts is 3 and contrasts are T1, T2, and FL. 77 | # If you have different datasets, please modify dataset.datalist_loader. 
78 | y_train_datalist, x_train_datalist = datalist_loader(FLAGS.path_data, FLAGS.reg_type, 'train') 79 | y_valid_datalist, x_valid_datalist = datalist_loader(FLAGS.path_data, FLAGS.reg_type, 'valid') 80 | 81 | batch_size = FLAGS.batch_size 82 | epochs = FLAGS.num_epoch 83 | train_size = len(y_train_datalist[0]) 84 | batch_number = int(np.ceil(train_size//batch_size)) 85 | 86 | min_val_loss = 100000 87 | 88 | if FLAGS.load_weight_name is not None: 89 | weight_path = FLAGS.path_weight + '/' + FLAGS.load_weight_name 90 | model.load_weights(weight_path) 91 | 92 | for epoch in range(epochs): 93 | start_time = time.time() 94 | train_loss = [0, 0, 0] 95 | y_train_datalist_shuffle, x_train_datalist_shuffle =\ 96 | train_batch_data_loader(y_train_datalist+x_train_datalist, FLAGS.num_contrast) 97 | for batch_index in tqdm(range(batch_number), ncols=100): 98 | start = batch_index*batch_size 99 | 100 | y_train_datalist_batch = [] 101 | x_train_datalist_batch = [] 102 | for i in range(FLAGS.num_contrast): 103 | y_train_datalist_batch.append(y_train_datalist_shuffle[i][start:start+batch_size]) 104 | x_train_datalist_batch.append(x_train_datalist_shuffle[i][start:start+batch_size]) 105 | 106 | y_train_batch = batch_data_loader(y_train_datalist_batch, FLAGS.num_contrast) 107 | x_train_batch = batch_data_loader(x_train_datalist_batch, FLAGS.num_contrast) 108 | y_train_batch, x_train_batch = rot_tra_argumentation(y_train_batch, x_train_batch, FLAGS.num_contrast) 109 | 110 | batch_size_tmp = x_train_batch[0].shape[0] 111 | tmp_loss = model.train_on_batch(x_train_batch, y_train_batch) 112 | 113 | if batch_index % FLAGS.iter_interval == 0: 114 | logging.info(f'Epoch [{epoch+1:4d}/{epochs:4d}] | Iter [{batch_index:4d}/{batch_number:4d}] ' 115 | f'{time.time() - start_time:.2f}s.. 
' 116 | f'train loss for T1: {tmp_loss[0]:.4f}, T2: {tmp_loss[1]:.4f}, FL: {tmp_loss[2]:.4f}') 117 | 118 | train_loss = [(x + y*batch_size_tmp) for (x, y) in zip(train_loss, tmp_loss)] 119 | 120 | train_loss = [x / train_size for x in train_loss] 121 | 122 | print(f'Epoch [{epoch+1:4d}/{epochs:4d}] {time.time() - start_time:.2f}s.. ' 123 | f'train loss for T1: {train_loss[0]:.4f}, T2: {train_loss[1]:.4f}, FL: {train_loss[2]:.4f}') 124 | logging.info(f'Epoch [{epoch+1:4d}/{epochs:4d}] {time.time() - start_time:.2f}s.. ' 125 | f'train loss for T1: {train_loss[0]:.4f}, T2: {train_loss[1]:.4f}, FL: {train_loss[2]:.4f}') 126 | 127 | if ((epoch+1) % FLAGS.save_epoch) == 0: 128 | x_valid = batch_data_loader(x_valid_datalist, FLAGS.num_contrast) 129 | y_valid = batch_data_loader(y_valid_datalist, FLAGS.num_contrast) 130 | val_loss = model.evaluate(x_valid, y_valid, verbose=0) 131 | model.save_weights(f'{FLAGS.path_weight}weight_e{epoch+1:04d}.h5', overwrite=True) 132 | 133 | if np.mean(val_loss) < min_val_loss: 134 | model.save_weights(f'{FLAGS.path_weight}weight_min_val_loss.h5', overwrite=True) 135 | min_val_loss = np.mean(val_loss) 136 | 137 | print(f'Weight saved! val loss T1: {val_loss[0]:.4f}, T2: {val_loss[1]:.4f}, FL: {val_loss[2]:.4f}') 138 | del x_valid, y_valid 139 | 140 | if epoch+1 == epochs: 141 | model.save_weights(f'{FLAGS.path_weight}weight_final.h5', 142 | overwrite=True) 143 | print(f'Weight saved! 
Training finished.') 144 | 145 | if __name__ == '__main__': 146 | try: 147 | app.run(train) 148 | except SystemExit: 149 | pass -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import rotate, shift 3 | import cv2 4 | from skimage.metrics import structural_similarity 5 | 6 | 7 | def transpose(img, r, tx, ty, channel_last=True): 8 | if img.ndim == 2: 9 | s = shift(rotate(img, r, [0, 1], order=1, reshape=False), [tx, ty], order=1) 10 | elif img.ndim == 3: 11 | if channel_last: 12 | s = shift(rotate(img, r, [0, 1], order=1, reshape=False), [tx, ty, 0], order=1) 13 | else: 14 | s = shift(rotate(img, r, [1, 2], order=1, reshape=False), [0, tx, ty], order=1) 15 | elif img.ndim == 4: 16 | s = shift(rotate(img, r, [1, 2], order=1, reshape=False), [0, tx, ty, 0], order=1) 17 | else: 18 | raise TypeError 19 | 20 | return s 21 | 22 | 23 | def rand_range(range): 24 | return (np.random.rand()-0.5)*2*range 25 | 26 | 27 | def generate_rand_par(num): 28 | x = (np.random.rand(num)-.5)*num/(num-1) 29 | return x 30 | 31 | 32 | def rot_tra_argumentation(y, x, num_contrast): 33 | size = y[0].shape[0] 34 | 35 | x_copy = x 36 | y_copy = y 37 | for i in range(size): 38 | r = rand_range(10) 39 | tx = rand_range(5) 40 | ty = rand_range(5) 41 | for j in range(num_contrast): 42 | y_copy[j][i,:,:,:] = transpose(y[j][i,:,:,:], r, tx, ty) 43 | x_copy[j][i,:,:,:] = transpose(x[j][i,:,:,:], r, tx, ty) 44 | 45 | return y_copy, x_copy 46 | 47 | 48 | def save_image(file_name, image): 49 | image = np.squeeze(image)*255 50 | cv2.imwrite(file_name, image) 51 | 52 | 53 | def merge_images(x, y, p): 54 | size = x[0].shape[0] 55 | h = x[0].shape[1] 56 | w = x[0].shape[2] 57 | out = np.zeros([size, h*3, w*3]) 58 | for i in range(size): 59 | out[i, 0:h, 0:w] = x[0][i,:,:,0] 60 | out[i, 0:h, w:w*2] = x[1][i,:,:,0] 61 | out[i, 0:h, w*2:w*3] = 
x[2][i,:,:,0] 62 | out[i, h:h*2, 0:w] = y[0][i,:,:,0] 63 | out[i, h:h*2, w:w*2] = y[1][i,:,:,0] 64 | out[i, h:h*2, w*2:w*3] = y[2][i,:,:,0] 65 | out[i, h*2:h*3, 0:w] = p[0][i,:,:,0] 66 | out[i, h*2:h*3, w:w*2] = p[1][i,:,:,0] 67 | out[i, h*2:h*3, w*2:w*3] = p[2][i,:,:,0] 68 | return out 69 | 70 | 71 | def ssim(img, ref): 72 | return structural_similarity(img, ref, data_range=1.) 73 | 74 | 75 | def test_ssim(x_img, y_img, p_img): 76 | x_ssim = [] 77 | p_ssim = [] 78 | for i in range(x_img.shape[0]): 79 | x_ssim.append(ssim(x_img[i,:,:,0], y_img[i,:,:,0])) 80 | p_ssim.append(ssim(p_img[i,:,:,0], y_img[i,:,:,0])) 81 | return np.mean(x_ssim), np.mean(p_ssim) 82 | 83 | 84 | def nmi(img, ref, bins=16): 85 | eps = 1e-10 86 | hist = np.histogram2d(img.flatten(), ref.flatten(), bins=bins)[0] 87 | pxy = hist / np.sum(hist) 88 | px = np.sum(pxy, axis=1) 89 | py = np.sum(pxy, axis=0) 90 | px_py = px[:, None] * py[None, :] 91 | nzs = pxy > 0 92 | mi = np.sum(pxy[nzs] * np.log2((pxy[nzs] + eps) / (px_py[nzs] + eps))) 93 | entx = - np.sum(px * np.log2(px + eps)) 94 | enty = - np.sum(py * np.log2(py + eps)) 95 | 96 | return (2 * mi + eps) / (entx + enty + eps) 97 | 98 | 99 | def test_nmi(x_img, y_img, p_img): 100 | x_nmi = [] 101 | p_nmi = [] 102 | for i in range(x_img.shape[0]): 103 | x_nmi.append(nmi(x_img[i,:,:,0], y_img[i,:,:,0])) 104 | p_nmi.append(nmi(p_img[i,:,:,0], y_img[i,:,:,0])) 105 | return np.mean(x_nmi), np.mean(p_nmi) 106 | 107 | 108 | def nrmse(img, ref): 109 | rmse = np.sqrt(np.mean((img-ref)**2)) 110 | return rmse/np.mean(ref) 111 | 112 | 113 | def test_nrmse(x_img, y_img, p_img): 114 | x_nrmse = [] 115 | p_nrmse = [] 116 | for i in range(x_img.shape[0]): 117 | x_nrmse.append(nrmse(x_img[i,:,:,0], y_img[i,:,:,0])) 118 | p_nrmse.append(nrmse(p_img[i,:,:,0], y_img[i,:,:,0])) 119 | return np.mean(x_nrmse), np.mean(p_nrmse) --------------------------------------------------------------------------------