├── LICENSE
├── README.md
├── data
│   ├── NMI
│   │   ├── test
│   │   │   ├── FL_gt
│   │   │   │   └── FL_gt_0001.png
│   │   │   ├── FL_motion
│   │   │   │   └── FL_motion_0001.png
│   │   │   ├── T1_gt
│   │   │   │   └── T1_gt_0001.png
│   │   │   ├── T1_motion
│   │   │   │   └── T1_motion_0001.png
│   │   │   ├── T2_gt
│   │   │   │   └── T2_gt_0001.png
│   │   │   └── T2_motion
│   │   │       └── T2_motion_0001.png
│   │   ├── train
│   │   │   ├── FL_gt
│   │   │   │   └── FL_gt_0001.png
│   │   │   ├── FL_motion
│   │   │   │   └── FL_motion_0001.png
│   │   │   ├── T1_gt
│   │   │   │   └── T1_gt_0001.png
│   │   │   ├── T1_motion
│   │   │   │   └── T1_motion_0001.png
│   │   │   ├── T2_gt
│   │   │   │   └── T2_gt_0001.png
│   │   │   └── T2_motion
│   │   │       └── T2_motion_0001.png
│   │   └── valid
│   │       ├── FL_gt
│   │       │   └── FL_gt_0001.png
│   │       ├── FL_motion
│   │       │   └── FL_motion_0001.png
│   │       ├── T1_gt
│   │       │   └── T1_gt_0001.png
│   │       ├── T1_motion
│   │       │   └── T1_motion_0001.png
│   │       ├── T2_gt
│   │       │   └── T2_gt_0001.png
│   │       └── T2_motion
│   │           └── T2_motion_0001.png
│   ├── Network
│   │   └── test
│   │       ├── FL_gt
│   │       │   └── FL_gt_0001.png
│   │       ├── FL_motion
│   │       │   └── FL_motion_0001.png
│   │       ├── T1_gt
│   │       │   └── T1_gt_0001.png
│   │       ├── T1_motion
│   │       │   └── T1_motion_0001.png
│   │       ├── T2_gt
│   │       │   └── T2_gt_0001.png
│   │       └── T2_motion
│   │           └── T2_motion_0001.png
│   └── No
│       ├── test
│       │   ├── FL_gt
│       │   │   └── FL_gt_0001.png
│       │   ├── FL_motion
│       │   │   └── FL_motion_0001.png
│       │   ├── T1_gt
│       │   │   └── T1_gt_0001.png
│       │   ├── T1_motion
│       │   │   └── T1_motion_0001.png
│       │   ├── T2_gt
│       │   │   └── T2_gt_0001.png
│       │   └── T2_motion
│       │       └── T2_motion_0001.png
│       ├── train
│       │   ├── FL_gt
│       │   │   └── FL_gt_0001.png
│       │   ├── FL_motion
│       │   │   └── FL_motion_0001.png
│       │   ├── T1_gt
│       │   │   └── T1_gt_0001.png
│       │   ├── T1_motion
│       │   │   └── T1_motion_0001.png
│       │   ├── T2_gt
│       │   │   └── T2_gt_0001.png
│       │   └── T2_motion
│       │       └── T2_motion_0001.png
│       └── valid
│           ├── FL_gt
│           │   └── FL_gt_0001.png
│           ├── FL_motion
│           │   └── FL_motion_0001.png
│           ├── T1_gt
│           │   └── T1_gt_0001.png
│           ├── T1_motion
│           │   └── T1_motion_0001.png
│           ├── T2_gt
│           │   └── T2_gt_0001.png
│           └── T2_motion
│               └── T2_motion_0001.png
├── dataset.py
├── figure
│   ├── figure1.png
│   ├── figure2.png
│   └── figure3.png
├── logs
│   ├── log.INFO
│   ├── log.user-SYS-4029GP-TRT2.jylee.log.INFO.20200708-145230.46000
│   └── log.user-SYS-4029GP-TRT2.jylee.log.INFO.20200708-150247.1749
├── model.py
├── requirements.txt
├── test.py
├── train.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 OpenXAIProject
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI
2 |
3 | TensorFlow implementation of "MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI"
4 |
5 | 
6 |
7 | Contact: Jongyeon Lee, KAIST (jyl4you@kaist.ac.kr)
8 |
9 | ## Paper
10 | [MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI]()
11 | Jongyeon Lee, Byungjai Kim, and [Hyunwook Park](http://athena.kaist.ac.kr)
12 | (Under review)
13 |
14 | Please cite our paper if you find it useful for your research. (Currently under review for MRM)
15 |
16 | ```
17 | @inproceedings{,
18 | author = {},
19 | booktitle = {},
20 | title = {MC2-Net: Motion Correction Network for Multi-Contrast Brain MRI},
21 | year = {2020}
22 | }
23 | ```
24 |
25 | ## Example Results
26 |
27 | 
28 |
29 | ## Quantitative Results
30 |
31 | 
32 |
33 | ## Installation
34 | * Install `tensorflow-gpu` 2.2.0 with Python 3.6 and CUDA 10.2
35 |
36 | ```
37 | pip install -r requirements.txt
38 | ```
39 | or
40 | ```
41 | conda create -n [ENV_NAME] python=3.6
42 | conda activate [ENV_NAME]
43 | conda install tensorflow-gpu=2.2.0 scikit-image=0.16.2
44 | ```
45 |
46 | * Clone this repo
47 | ```
48 | git clone https://github.com/OpenXAIProject/mc2-net.git
49 | cd mc2-net
50 | ```
51 |
52 | ## Dataset
53 | * The dataset used in this study is not publicly available; sample images are provided in the `data` directory for the demo. The subdirectories `No`, `Network`, and `NMI` correspond to the registration types selectable with the `--reg_type` flag.
54 | * You may use [BraTS 2020](https://ipp.cbica.upenn.edu/#BraTS20_registration) as an alternative dataset. To generate training data, please refer to the paper and its Figure 4 for motion simulation. This step is not implemented in Python code because it depends on the pulse sequences; a simplified k-space sketch is given at the end of this README.
55 |
56 | ## Testing
57 | * Use `CUDA_VISIBLE_DEVICES` to select a GPU
58 | * Evaluate the corrected images (SSIM, NMI, and NRMSE against the ground truth) and save them as PNG files; weights are loaded from `--path_weight` (default `./weight/`), so pass only the filename to `--load_weight_name`
59 |
60 | ```
61 | CUDA_VISIBLE_DEVICES=0 python test.py --load_weight_name [weight filename].h5
62 | ```
63 |
64 | ## Training Examples
65 | * Use `CUDA_VISIBLE_DEVICES` to select a GPU
66 | * Train on the sample images
67 |
68 | ```
69 | CUDA_VISIBLE_DEVICES=0 python train.py --num_epoch 1000
70 | ```
71 |
72 |
73 |
74 |
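75 | ## Motion Simulation (Sketch)
76 |
77 | The motion simulation from the paper depends on the acquisition pulse sequences, so it is not included in this repo. As a rough illustration only (this is not the paper's method), the snippet below corrupts a 2D image by filling interleaved blocks of k-space phase-encode lines from rigidly moved copies of the image. The shot count, motion ranges, and interleaving scheme are arbitrary assumptions.
78 |
79 | ```
80 | import numpy as np
81 | from scipy.ndimage import rotate, shift
82 |
83 | def simulate_rigid_motion(img, num_shots=4, max_rot=5.0, max_shift=3.0, seed=0):
84 |     """Assemble k-space from rigidly moved copies of a 2D float image."""
85 |     rng = np.random.default_rng(seed)
86 |     k_out = np.zeros(img.shape, dtype=complex)
87 |     for block in np.array_split(np.arange(img.shape[0]), num_shots):
88 |         r = rng.uniform(-max_rot, max_rot)                   # rotation in degrees
89 |         tx, ty = rng.uniform(-max_shift, max_shift, size=2)  # translation in pixels
90 |         moved = shift(rotate(img, r, order=1, reshape=False), [tx, ty], order=1)
91 |         k = np.fft.fftshift(np.fft.fft2(moved))
92 |         k_out[block, :] = k[block, :]                        # keep this shot's phase-encode lines
93 |     return np.abs(np.fft.ifft2(np.fft.ifftshift(k_out)))
94 | ```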
--------------------------------------------------------------------------------
/data/NMI/test/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/test/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/test/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/test/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/test/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/test/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/test/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/train/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/train/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/train/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/train/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/train/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/train/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/train/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/valid/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/valid/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/valid/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/valid/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/NMI/valid/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/NMI/valid/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/NMI/valid/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/data/Network/test/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/Network/test/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/Network/test/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/Network/test/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/Network/test/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/Network/test/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/Network/test/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/data/No/test/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/No/test/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/No/test/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/No/test/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/No/test/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/No/test/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/test/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/data/No/train/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/No/train/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/No/train/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/No/train/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/No/train/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/No/train/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/train/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/data/No/valid/FL_gt/FL_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/FL_gt/FL_gt_0001.png
--------------------------------------------------------------------------------
/data/No/valid/FL_motion/FL_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/FL_motion/FL_motion_0001.png
--------------------------------------------------------------------------------
/data/No/valid/T1_gt/T1_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/T1_gt/T1_gt_0001.png
--------------------------------------------------------------------------------
/data/No/valid/T1_motion/T1_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/T1_motion/T1_motion_0001.png
--------------------------------------------------------------------------------
/data/No/valid/T2_gt/T2_gt_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/T2_gt/T2_gt_0001.png
--------------------------------------------------------------------------------
/data/No/valid/T2_motion/T2_motion_0001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/data/No/valid/T2_motion/T2_motion_0001.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import glob
5 | import random
6 |
7 |
8 | def datalist_loader(path, reg_type=None, data_type='train'):
9 | data_path = os.path.join(path, reg_type, data_type)
10 |
11 | contrasts = ['T1', 'T2', 'FL']
12 | gt_motion = ['gt', 'motion']
13 |
14 | datalist = []
15 | for x_or_y in gt_motion:
16 | for contrast in contrasts:
17 | datalist.append(sorted(glob.glob(data_path+'/'+contrast+'_'+x_or_y+'/*.png')))
18 |
19 | return datalist[0:3], datalist[3:6]
20 |
21 |
22 | def load_batch(fname_list):
23 | out = []
24 | for fname in fname_list:
25 | img = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
26 | out.append(img)
27 | out = np.expand_dims(np.array(out).astype('float32'), axis=-1)
28 | out = out/255
29 | return out
30 |
31 |
32 | def train_batch_data_loader(datalist, num_contrast, shuffle=True):
33 | datalist_t = list(zip(*datalist))
34 |
35 | train_size = len(datalist_t)
36 | if shuffle:
37 | random.shuffle(datalist_t)
38 | datalist = list(zip(*datalist_t))
39 |
40 | datalist_out_y = []
41 | datalist_out_x = []
42 | for i in range(train_size):
43 | datalist_single_y = []
44 | datalist_single_x = []
45 | for j in range(num_contrast):
46 | gt_or_motion = random.randint(0, 1)*num_contrast  # offset 0 selects the gt list, num_contrast selects the motion list
47 | datalist_single_y.append(datalist[j][i])
48 | datalist_single_x.append(datalist[j+gt_or_motion][i])
49 |
50 | datalist_out_y.append(datalist_single_y)
51 | datalist_out_x.append(datalist_single_x)
52 |
53 | datalist_out_y = list(map(list, zip(*datalist_out_y)))
54 | datalist_out_x = list(map(list, zip(*datalist_out_x)))
55 |
56 | return datalist_out_y, datalist_out_x
57 |
58 |
59 |
60 | def batch_data_loader(batch_datalist, num_contrast):
61 | '''
62 | Arguments:
63 | batch_datalist: list of per-contrast filename lists
64 | num_contrast: the number of contrasts
65 | '''
66 |
67 | batch = []
68 | for i in range(num_contrast):
69 | batch.append(load_batch(batch_datalist[i]))
70 |
71 | return batch
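72 |
73 |
74 | if __name__ == '__main__':
75 |     # Editor's sketch (not part of the original file): smoke-test the loaders on the demo data.
76 |     y_lists, x_lists = datalist_loader('./data/', reg_type='NMI', data_type='test')
77 |     batch = batch_data_loader(x_lists, num_contrast=3)
78 |     print([b.shape for b in batch])  # three arrays of shape (N, H, W, 1), scaled to [0, 1]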
--------------------------------------------------------------------------------
/figure/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/figure/figure1.png
--------------------------------------------------------------------------------
/figure/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/figure/figure2.png
--------------------------------------------------------------------------------
/figure/figure3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenXAIProject/mc2-net/b3a28547ee2b5631e05654fd27b1c3af1f586023/figure/figure3.png
--------------------------------------------------------------------------------
/logs/log.user-SYS-4029GP-TRT2.jylee.log.INFO.20200708-145230.46000:
--------------------------------------------------------------------------------
1 | Epoch [ 1/1000] | Iter [ 0/ 193] 5.35s.. train loss for T1: 4.5392, T2: 1.6255, FL: 1.7207
2 | Epoch [ 1/1000] | Iter [ 10/ 193] 26.13s.. train loss for T1: 2.7811, T2: 0.9277, FL: 0.9363
3 | Epoch [ 1/1000] | Iter [ 20/ 193] 46.85s.. train loss for T1: 2.7792, T2: 0.9316, FL: 0.9271
4 | Epoch [ 1/1000] | Iter [ 30/ 193] 67.59s.. train loss for T1: 2.7379, T2: 0.9186, FL: 0.9023
5 | Epoch [ 1/1000] | Iter [ 40/ 193] 88.29s.. train loss for T1: 2.6266, T2: 0.8918, FL: 0.8410
6 | Epoch [ 1/1000] | Iter [ 50/ 193] 109.22s.. train loss for T1: 2.4726, T2: 0.8600, FL: 0.7427
7 | Epoch [ 1/1000] | Iter [ 60/ 193] 130.27s.. train loss for T1: 2.3297, T2: 0.8144, FL: 0.6976
8 | Epoch [ 1/1000] | Iter [ 70/ 193] 151.45s.. train loss for T1: 2.1408, T2: 0.7545, FL: 0.6696
9 | Epoch [ 1/1000] | Iter [ 80/ 193] 172.51s.. train loss for T1: 1.9576, T2: 0.6982, FL: 0.6087
10 | Epoch [ 1/1000] | Iter [ 90/ 193] 193.78s.. train loss for T1: 1.7520, T2: 0.6326, FL: 0.5330
11 | Epoch [ 1/1000] | Iter [ 100/ 193] 215.08s.. train loss for T1: 1.4899, T2: 0.5489, FL: 0.4644
12 | Epoch [ 1/1000] | Iter [ 110/ 193] 236.10s.. train loss for T1: 1.1551, T2: 0.3951, FL: 0.4216
13 | Epoch [ 1/1000] | Iter [ 120/ 193] 257.25s.. train loss for T1: 0.8505, T2: 0.2494, FL: 0.3962
14 | Epoch [ 1/1000] | Iter [ 130/ 193] 278.38s.. train loss for T1: 0.5245, T2: 0.1725, FL: 0.2284
15 | Epoch [ 1/1000] | Iter [ 140/ 193] 299.52s.. train loss for T1: 0.3533, T2: 0.1238, FL: 0.1439
16 | Epoch [ 1/1000] | Iter [ 150/ 193] 320.51s.. train loss for T1: 0.2639, T2: 0.0932, FL: 0.1028
17 | Epoch [ 1/1000] | Iter [ 160/ 193] 341.63s.. train loss for T1: 0.2052, T2: 0.0712, FL: 0.0700
18 | Epoch [ 1/1000] | Iter [ 170/ 193] 362.74s.. train loss for T1: 0.1873, T2: 0.0683, FL: 0.0640
19 | Epoch [ 1/1000] | Iter [ 180/ 193] 383.87s.. train loss for T1: 0.1631, T2: 0.0555, FL: 0.0532
20 | Epoch [ 1/1000] | Iter [ 190/ 193] 405.23s.. train loss for T1: 0.1547, T2: 0.0469, FL: 0.0540
21 | Epoch [ 1/1000] 409.42s.. train loss for T1: 0.0745, T2: 0.0256, FL: 0.0247
22 | Epoch [ 2/1000] | Iter [ 0/ 193] 2.12s.. train loss for T1: 0.1274, T2: 0.0408, FL: 0.0424
23 | Epoch [ 2/1000] | Iter [ 10/ 193] 23.34s.. train loss for T1: 0.1252, T2: 0.0380, FL: 0.0426
24 | Epoch [ 2/1000] | Iter [ 20/ 193] 44.53s.. train loss for T1: 0.1145, T2: 0.0309, FL: 0.0419
25 | Epoch [ 2/1000] | Iter [ 30/ 193] 65.86s.. train loss for T1: 0.1120, T2: 0.0334, FL: 0.0377
26 | Epoch [ 2/1000] | Iter [ 40/ 193] 87.25s.. train loss for T1: 0.1031, T2: 0.0245, FL: 0.0353
27 | Epoch [ 2/1000] | Iter [ 50/ 193] 108.66s.. train loss for T1: 0.1035, T2: 0.0273, FL: 0.0349
28 | Epoch [ 2/1000] | Iter [ 60/ 193] 129.98s.. train loss for T1: 0.0999, T2: 0.0230, FL: 0.0430
29 | Epoch [ 2/1000] | Iter [ 70/ 193] 151.49s.. train loss for T1: 0.0969, T2: 0.0282, FL: 0.0348
30 | Epoch [ 2/1000] | Iter [ 80/ 193] 173.11s.. train loss for T1: 0.0911, T2: 0.0232, FL: 0.0351
31 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import tensorflow.keras.backend as K
4 | from tensorflow import keras
5 |
6 | tf.random.set_seed(22)
7 | np.random.seed(22)
8 |
9 | class InstanceNormalization(keras.layers.Layer):
10 | def __init__(self,
11 | axis=None,
12 | epsilon=1e-3,
13 | center=True,
14 | scale=True,
15 | beta_initializer='zeros',
16 | gamma_initializer='ones',
17 | beta_regularizer=None,
18 | gamma_regularizer=None,
19 | beta_constraint=None,
20 | gamma_constraint=None,
21 | **kwargs):
22 | super(InstanceNormalization, self).__init__(**kwargs)
23 | self.supports_masking = True
24 | self.axis = axis
25 | self.epsilon = epsilon
26 | self.center = center
27 | self.scale = scale
28 | self.beta_initializer = keras.initializers.get(beta_initializer)
29 | self.gamma_initializer = keras.initializers.get(gamma_initializer)
30 | self.beta_regularizer = keras.regularizers.get(beta_regularizer)
31 | self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
32 | self.beta_constraint = keras.constraints.get(beta_constraint)
33 | self.gamma_constraint = keras.constraints.get(gamma_constraint)
34 |
35 | def build(self, input_shape):
36 | ndim = len(input_shape)
37 | if self.axis == 0:
38 | raise ValueError('Axis cannot be zero')
39 |
40 | if (self.axis is not None) and (ndim == 2):
41 | raise ValueError('Cannot specify axis for rank 1 tensor')
42 |
43 | self.input_spec = keras.layers.InputSpec(ndim=ndim)
44 |
45 | if self.axis is None:
46 | shape = (1,)
47 | else:
48 | shape = (input_shape[self.axis],)
49 |
50 | if self.scale:
51 | self.gamma = self.add_weight(shape=shape,
52 | name='gamma',
53 | initializer=self.gamma_initializer,
54 | regularizer=self.gamma_regularizer,
55 | constraint=self.gamma_constraint)
56 | else:
57 | self.gamma = None
58 | if self.center:
59 | self.beta = self.add_weight(shape=shape,
60 | name='beta',
61 | initializer=self.beta_initializer,
62 | regularizer=self.beta_regularizer,
63 | constraint=self.beta_constraint)
64 | else:
65 | self.beta = None
66 | self.built = True
67 |
68 | def call(self, inputs, training=None):
69 | input_shape = K.int_shape(inputs)
70 | reduction_axes = list(range(0, len(input_shape)))
71 |
72 | if self.axis is not None:
73 | del reduction_axes[self.axis]
74 |
75 | del reduction_axes[0]
76 |
77 | mean = K.mean(inputs, reduction_axes, keepdims=True)
78 | stddev = K.std(inputs, reduction_axes, keepdims=True) + self.epsilon
79 | normed = (inputs - mean) / stddev
80 |
81 | broadcast_shape = [1] * len(input_shape)
82 | if self.axis is not None:
83 | broadcast_shape[self.axis] = input_shape[self.axis]
84 |
85 | if self.scale:
86 | broadcast_gamma = K.reshape(self.gamma, broadcast_shape)
87 | normed = normed * broadcast_gamma
88 | if self.center:
89 | broadcast_beta = K.reshape(self.beta, broadcast_shape)
90 | normed = normed + broadcast_beta
91 | return normed
92 |
93 | def get_config(self):
94 | config = {
95 | 'axis': self.axis,
96 | 'epsilon': self.epsilon,
97 | 'center': self.center,
98 | 'scale': self.scale,
99 | 'beta_initializer': keras.initializers.serialize(self.beta_initializer),
100 | 'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
101 | 'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
102 | 'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
103 | 'beta_constraint': keras.constraints.serialize(self.beta_constraint),
104 | 'gamma_constraint': keras.constraints.serialize(self.gamma_constraint)
105 | }
106 | base_config = super(InstanceNormalization, self).get_config()
107 | return dict(list(base_config.items()) + list(config.items()))
108 |
109 |
110 | """ Motion correction network for MC2-Net """
111 |
112 |
113 | class Encoder(keras.Model):
114 | def __init__(self, initial_filters=64):
115 | super(Encoder, self).__init__()
116 |
117 | self.filters = initial_filters
118 |
119 | self.conv1 = keras.layers.Conv2D(self.filters, kernel_size=7, strides=1, padding='same',
120 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
121 | self.conv2 = keras.layers.Conv2D(self.filters*2, kernel_size=3, strides=2, padding='same',
122 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
123 | self.conv3 = keras.layers.Conv2D(self.filters*4, kernel_size=3, strides=2, padding='same',
124 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
125 |
126 | self.n1 = InstanceNormalization()
127 | self.n2 = InstanceNormalization()
128 | self.n3 = InstanceNormalization()
129 |
130 |
131 | def call(self, x, training=True):
132 | x = self.conv1(x)
133 | x = self.n1(x, training=training)
134 | x = tf.nn.relu(x)
135 |
136 | x = self.conv2(x)
137 | x = self.n2(x, training=training)
138 | x = tf.nn.relu(x)
139 |
140 | x = self.conv3(x)
141 | x = self.n3(x, training=training)
142 | x = tf.nn.relu(x)
143 |
144 | return x
145 |
146 |
147 | class Residual(keras.Model):
148 | def __init__(self, initial_filters=256):
149 | super(Residual, self).__init__()
150 |
151 | self.filters = initial_filters
152 |
153 | self.conv1 = keras.layers.Conv2D(self.filters, kernel_size=3, strides=1, padding='same',
154 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
155 | self.conv2 = keras.layers.Conv2D(self.filters, kernel_size=3, strides=1, padding='same',
156 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
157 |
158 | self.in1 = InstanceNormalization()
159 | self.in2 = InstanceNormalization()
160 |
161 | def call(self, x, training=True):
162 | inputs = x
163 |
164 | x = self.conv1(x)
165 | # x = self.in1(x, training=training)
166 | x = tf.nn.relu(x)
167 |
168 | x = self.conv2(x)
169 | # x = self.in2(x, training=training)
170 | x = tf.nn.relu(x)
171 |
172 | x = tf.add(x, inputs)
173 |
174 | return x
175 |
176 |
177 | class Decoder(keras.Model):
178 | def __init__(self, initial_filters=128):
179 | super(Decoder, self).__init__()
180 |
181 | self.filters = initial_filters
182 |
183 | self.conv1 = keras.layers.Conv2DTranspose(self.filters, kernel_size=3, strides=2, padding='same',
184 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
185 | self.conv2 = keras.layers.Conv2DTranspose(self.filters//2, kernel_size=3, strides=2, padding='same',
186 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
187 | self.conv3 = keras.layers.Conv2D(1, kernel_size=7, strides=1, padding='same',
188 | kernel_initializer=tf.random_normal_initializer(stddev=0.02))
189 |
190 | self.in1 = InstanceNormalization()
191 | self.in2 = InstanceNormalization()
192 | self.in3 = InstanceNormalization()
193 |
194 | def call(self, x, training=True):
195 | x = self.conv1(x)
196 | x = self.in1(x, training=training)
197 | x = tf.nn.relu(x)
198 |
199 | x = self.conv2(x)
200 | x = self.in2(x, training=training)
201 | x = tf.nn.relu(x)
202 |
203 | x = self.conv3(x)
204 | x = self.in3(x, training=training)
205 | x = tf.nn.relu(x)
206 |
207 | return x
208 |
209 |
210 | class MC_Net(keras.Model):
211 | def __init__(self,
212 | img_size=256,
213 | num_filter=16,
214 | num_contrast=3,
215 | num_res_block=9):
216 | super(MC_Net, self).__init__()
217 |
218 | self.img_size = img_size
219 | self.filters = num_filter
220 | self.num_contrast = num_contrast
221 | self.num_res_block = num_res_block
222 |
223 | self.encoder_list = []
224 | for _ in range(num_contrast):
225 | self.encoder_list.append(Encoder(initial_filters=self.filters))
226 |
227 | self.res_block_list = []
228 | for _ in range(num_res_block):
229 | self.res_block_list.append(Residual(initial_filters=self.filters*4*num_contrast))
230 |
231 | self.decoder_list = []
232 | for _ in range(num_contrast):
233 | self.decoder_list.append(Decoder(initial_filters=self.filters*2))
234 |
235 | def build(self, input_shape):
236 | assert isinstance(input_shape, list)
237 | super(MC_Net, self).build(input_shape)
238 |
239 | def call(self, x, training=True):
240 | x_list = []
241 | for i in range(self.num_contrast):
242 | x_list.append(self.encoder_list[i](x[i], training))
243 | x = tf.concat(x_list, axis=-1)
244 |
245 | for i in range(self.num_res_block):
246 | x = (self.res_block_list[i](x, training))
247 |
248 | y = tf.split(x, num_or_size_splits=self.num_contrast, axis=-1)
249 |
250 | y_list = []
251 | for i in range(self.num_contrast):
252 | y_list.append(self.decoder_list[i](y[i], training))
253 |
254 | return y_list
255 |
256 |
257 | def ssim_loss(img1, img2):
258 | return -tf.math.log((tf.image.ssim(img1, img2, max_val=1.0)+1)/2)
259 |
260 |
261 | def vgg_layers(layer_names):
262 | vgg = tf.keras.applications.vgg16.VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))
263 | vgg.trainable = False
264 |
265 | outputs = [vgg.get_layer(name).output for name in layer_names]
266 | model = tf.keras.Model([vgg.input], outputs)
267 | return model
268 |
269 |
270 | def vgg_loss(img1, img2, loss_model):
271 | img1 = tf.repeat(img1, 3, -1)
272 | img2 = tf.repeat(img2, 3, -1)
273 |
274 | return tf.reduce_mean(tf.square(loss_model(img1) - loss_model(img2)))
275 |
276 |
277 | def make_custom_loss(l1, l2, loss_model):
278 | def custom_loss(y_true, y_pred):
279 | return l1*ssim_loss(y_true, y_pred) + l2*vgg_loss(y_true, y_pred, loss_model)
280 |
281 | return custom_loss
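282 |
283 |
284 | if __name__ == '__main__':
285 |     # Editor's sketch (not part of the original file): build MC_Net with the default
286 |     # configuration used by train.py/test.py and run one forward pass on dummy inputs.
287 |     model = MC_Net(img_size=256, num_filter=16, num_contrast=3, num_res_block=9)
288 |     model.build(input_shape=[(None, 256, 256, 1)] * 3)
289 |     dummy = [tf.zeros([1, 256, 256, 1]) for _ in range(3)]  # one slice per contrast (T1, T2, FL)
290 |     outputs = model(dummy, training=False)
291 |     print([tuple(o.shape) for o in outputs])                # three (1, 256, 256, 1) corrected images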
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | astunparse==1.6.3
3 | blinker==1.4
4 | brotlipy==0.7.0
5 | cachetools==4.1.0
6 | certifi==2020.6.20
7 | cffi==1.14.0
8 | chardet==3.0.4
9 | click==7.1.2
10 | cloudpickle==1.4.1
11 | cryptography==2.9.2
12 | cycler==0.10.0
13 | cytoolz==0.10.1
14 | dask
15 | decorator==4.4.2
16 | gast==0.3.3
17 | google-auth==1.14.1
18 | google-auth-oauthlib==0.4.1
19 | google-pasta==0.2.0
20 | grpcio==1.27.2
21 | h5py
22 | idna
23 | imageio==2.8.0
24 | Keras-Preprocessing==1.1.0
25 | kiwisolver==1.2.0
26 | Markdown==3.1.1
27 | matplotlib
28 | mkl-fft==1.1.0
29 | mkl-random==1.1.1
30 | mkl-service==2.3.0
31 | networkx==2.4
32 | nibabel==2.5.1
33 | numpy==1.17.4
34 | oauthlib==3.1.0
35 | olefile==0.46
36 | opencv-python==4.1.1.26
37 | opt-einsum==3.1.0
38 | Pillow==7.1.2
39 | protobuf==3.12.3
40 | pyasn1==0.4.8
41 | pyasn1-modules==0.2.7
42 | pycparser==2.20
43 | PyJWT==1.7.1
44 | pyOpenSSL==19.1.0
45 | pyparsing==2.4.7
46 | PySocks==1.7.1
47 | python-dateutil==2.8.1
48 | PyWavelets==1.1.1
49 | PyYAML==5.3.1
50 | requests
51 | requests-oauthlib==1.3.0
52 | rsa==4.0
53 | scikit-image==0.16.2
54 | scipy
55 | six==1.13.0
56 | tensorboard==2.2.1
57 | tensorboard-plugin-wit==1.6.0
58 | tensorflow==2.2.0
59 | tensorflow-estimator==2.2.0
60 | termcolor==1.1.0
61 | toolz==0.10.0
62 | tornado==6.0.4
63 | urllib3==1.25.9
64 | Werkzeug==1.0.1
65 | wrapt==1.12.1
66 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | tf.config.experimental_run_functions_eagerly(True)
7 |
8 | from model import MC_Net
9 | from dataset import datalist_loader, batch_data_loader
10 | from utils import test_ssim, test_nmi, test_nrmse, save_image
11 |
12 | tf.random.set_seed(22)
13 | np.random.seed(22)
14 | physical_devices = tf.config.experimental.list_physical_devices('GPU')
15 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
16 |
17 | FLAGS = tf.compat.v1.flags.FLAGS
18 |
19 | tf.compat.v1.flags.DEFINE_integer('batch_size', 1,
20 | 'Batch size (Default: 1)')
21 | tf.compat.v1.flags.DEFINE_integer('image_size', 256,
22 | 'Image size (size x size) (Default: 256)')
23 | tf.compat.v1.flags.DEFINE_string('load_weight_name', 'weight_min_val_loss.h5',
24 | 'Load weight of given name (Default: weight_min_val_loss.h5)')
25 | tf.compat.v1.flags.DEFINE_integer('num_contrast', 3,
26 | 'Number of contrasts of MR images (Default: 3)')
27 | tf.compat.v1.flags.DEFINE_integer('num_filter', 16,
28 | 'Number of filters in the first layer of the encoder (Default: 16)')
29 | tf.compat.v1.flags.DEFINE_integer('num_res_block', 9,
30 | 'Number of residual blocks (Default: 9)')
31 | tf.compat.v1.flags.DEFINE_string('path_data', './data/',
32 | 'Data load path (Default: ./data/)')
33 | tf.compat.v1.flags.DEFINE_string('path_save', './test/',
34 | 'Image save path (Default: ./test/)')
35 | tf.compat.v1.flags.DEFINE_string('path_weight', './weight/',
36 | 'Weight load path (Default: ./weight/)')
37 | tf.compat.v1.flags.DEFINE_string('reg_type', 'NMI',
38 | 'Registration type of input images (No, Network, or NMI) (Default: NMI)')
39 |
40 | assert tf.__version__.startswith('2.')
41 | print('Tensorflow version: ', tf.__version__)
42 | # (seed and GPU memory growth are already configured above)
43 |
44 |
45 |
46 |
47 |
48 | def test():
49 | model = MC_Net(img_size=FLAGS.image_size,
50 | num_filter=FLAGS.num_filter,
51 | num_contrast=FLAGS.num_contrast,
52 | num_res_block=FLAGS.num_res_block)
53 |
54 | input_shape = [(None, FLAGS.image_size, FLAGS.image_size, 1)]
55 | model.build(input_shape=input_shape * FLAGS.num_contrast)
56 |
57 | model.load_weights(FLAGS.path_weight + FLAGS.load_weight_name)
58 | print('Model building completed!')
59 |
60 | y_test_datalist, x_test_datalist = datalist_loader(FLAGS.path_data, FLAGS.reg_type, 'test')
61 | x_test = batch_data_loader(x_test_datalist, FLAGS.num_contrast)
62 | y_test = batch_data_loader(y_test_datalist, FLAGS.num_contrast)
63 | print('Data loading completed!')
64 |
65 | p_test = model.predict(x_test, batch_size=FLAGS.batch_size)
66 | print('Prediction completed!')
67 |
68 | x_ssim_T1, p_ssim_T1 = test_ssim(x_test[0], y_test[0], p_test[0])
69 | x_ssim_T2, p_ssim_T2 = test_ssim(x_test[1], y_test[1], p_test[1])
70 | x_ssim_FL, p_ssim_FL = test_ssim(x_test[2], y_test[2], p_test[2])
71 |
72 | print(' c | x_ssim | p_ssim')
73 | print(f'T1 | {x_ssim_T1:.4f} | {p_ssim_T1:.4f}')
74 | print(f'T2 | {x_ssim_T2:.4f} | {p_ssim_T2:.4f}')
75 | print(f'FL | {x_ssim_FL:.4f} | {p_ssim_FL:.4f}')
76 |
77 | x_nmi_T1, p_nmi_T1 = test_nmi(x_test[0], y_test[0], p_test[0])
78 | x_nmi_T2, p_nmi_T2 = test_nmi(x_test[1], y_test[1], p_test[1])
79 | x_nmi_FL, p_nmi_FL = test_nmi(x_test[2], y_test[2], p_test[2])
80 |
81 | print(' c | x_nmi | p_nmi ')
82 | print(f'T1 | {x_nmi_T1:.4f} | {p_nmi_T1:.4f}')
83 | print(f'T2 | {x_nmi_T2:.4f} | {p_nmi_T2:.4f}')
84 | print(f'FL | {x_nmi_FL:.4f} | {p_nmi_FL:.4f}')
85 |
86 | x_nrmse_T1, p_nrmse_T1 = test_nrmse(x_test[0], y_test[0], p_test[0])
87 | x_nrmse_T2, p_nrmse_T2 = test_nrmse(x_test[1], y_test[1], p_test[1])
88 | x_nrmse_FL, p_nrmse_FL = test_nrmse(x_test[2], y_test[2], p_test[2])
89 |
90 | print(' c | x_nrmse | p_nrmse')
91 | print(f'T1 | {x_nrmse_T1:.4f} | {p_nrmse_T1:.4f}')
92 | print(f'T2 | {x_nrmse_T2:.4f} | {p_nrmse_T2:.4f}')
93 | print(f'FL | {x_nrmse_FL:.4f} | {p_nrmse_FL:.4f}')
94 |
95 | os.makedirs(FLAGS.path_save, exist_ok=True)
96 | for i in range(p_test[0].shape[0]):
97 | save_image(f'{FLAGS.path_save}/T1_pred_{i+1:04d}.png', p_test[0][i])
98 | save_image(f'{FLAGS.path_save}/T2_pred_{i+1:04d}.png', p_test[1][i])
99 | save_image(f'{FLAGS.path_save}/FL_pred_{i+1:04d}.png', p_test[2][i])
100 | print('Image saving completed!')
101 |
102 |
103 | if __name__ == '__main__':
104 | test()
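105 |
106 |
107 | # Editor's note (not part of the original file): the three registration variants under
108 | # ./data/ are selected with --reg_type, e.g.
109 | #   CUDA_VISIBLE_DEVICES=0 python test.py --reg_type No --load_weight_name weight_min_val_loss.h5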
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
3 |
4 | from absl import app, flags, logging
5 | from absl.flags import FLAGS
6 | import numpy as np
7 | import tensorflow as tf
8 | from tensorflow import keras
9 | tf.config.experimental_run_functions_eagerly(True)
10 | import time
11 | from tqdm import tqdm
12 |
13 | from model import MC_Net, vgg_layers, make_custom_loss
14 | from dataset import datalist_loader, train_batch_data_loader, batch_data_loader
15 | from utils import rot_tra_augmentation
16 |
17 | tf.random.set_seed(22)
18 | np.random.seed(22)
19 | physical_devices = tf.config.experimental.list_physical_devices('GPU')
20 | tf.config.experimental.set_memory_growth(physical_devices[0], True)
21 |
22 | # FLAGS = flags.FLAGS
23 |
24 | flags.DEFINE_integer('batch_size', 1,
25 | 'Batch size (Default: 1)')
26 | flags.DEFINE_integer('image_size', 256,
27 | 'Image size (size x size) (Default: 256)')
28 | flags.DEFINE_integer('iter_interval', 1,
29 | 'Iteration interval for logging (Default: 1)')
30 | flags.DEFINE_float('lambda_ssim', 1,
31 | 'Weight for SSIM loss (Default: 1)')
32 | flags.DEFINE_float('lambda_vgg', 1e-2,
33 | 'Weight for VGG loss (Default: 0.01)')
34 | flags.DEFINE_float('learning_rate', 1e-4,
35 | 'Initial learning rate for Adam (Default: 0.0001)')
36 | flags.DEFINE_string('load_weight_name', None,
37 | 'Load weight of given name (Default: None)')
38 | flags.DEFINE_integer('num_contrast', 3,
39 | 'Number of contrasts of MR images (Default: 3)')
40 | flags.DEFINE_integer('num_epoch', 1,
41 | 'Number of epochs for training (Default: 1)')
42 | flags.DEFINE_integer('num_filter', 16,
43 | 'Number of filters in the first layer of the encoder (Default: 16)')
44 | flags.DEFINE_integer('num_res_block', 9,
45 | 'Number of residual blocks (Default: 9)')
46 | flags.DEFINE_string('path_data', './data/',
47 | 'Data load path (Default: ./data/)')
48 | flags.DEFINE_string('path_weight', './weight/',
49 | 'Weight save path (Default: ./weight/)')
50 | flags.DEFINE_string('reg_type', 'NMI',
51 | 'Registration type of input images (No, Network, or NMI) (Default: NMI)')
52 | flags.DEFINE_integer('save_epoch', 10,
53 | 'Save weights by every given number of epochs (Default: 10)')
54 |
55 |
56 | def train(_argv):
57 | os.makedirs('./logs', exist_ok=True)
58 | logging.get_absl_handler().use_absl_log_file('log', "./logs")
59 | logging.get_absl_handler().setFormatter(None)
60 |
61 | os.makedirs(FLAGS.path_weight, exist_ok=True)
62 |
63 | model = MC_Net(img_size=FLAGS.image_size,
64 | num_filter=FLAGS.num_filter,
65 | num_contrast=FLAGS.num_contrast,
66 | num_res_block=FLAGS.num_res_block)
67 |
68 | loss_model = vgg_layers(['block3_conv1'])
69 | final_loss = make_custom_loss(FLAGS.lambda_ssim, FLAGS.lambda_vgg, loss_model)
70 | model.compile(optimizer=keras.optimizers.Adam(FLAGS.learning_rate),
71 | loss=final_loss)
72 | input_shape = [(None, FLAGS.image_size, FLAGS.image_size, 1)]
73 | model.build(input_shape=input_shape * FLAGS.num_contrast)
74 | model.summary()
75 |
76 | # Data loading assumes that the number of contrasts is 3 and contrasts are T1, T2, and FL.
77 | # If you have different datasets, please modify dataset.datalist_loader.
78 | y_train_datalist, x_train_datalist = datalist_loader(FLAGS.path_data, FLAGS.reg_type, 'train')
79 | y_valid_datalist, x_valid_datalist = datalist_loader(FLAGS.path_data, FLAGS.reg_type, 'valid')
80 |
81 | batch_size = FLAGS.batch_size
82 | epochs = FLAGS.num_epoch
83 | train_size = len(y_train_datalist[0])
84 | batch_number = int(np.ceil(train_size / batch_size))  # true division, so the final partial batch is kept
85 |
86 | min_val_loss = 100000
87 |
88 | if FLAGS.load_weight_name is not None:
89 | weight_path = FLAGS.path_weight + '/' + FLAGS.load_weight_name
90 | model.load_weights(weight_path)
91 |
92 | for epoch in range(epochs):
93 | start_time = time.time()
94 | train_loss = [0, 0, 0]
95 | y_train_datalist_shuffle, x_train_datalist_shuffle =\
96 | train_batch_data_loader(y_train_datalist+x_train_datalist, FLAGS.num_contrast)
97 | for batch_index in tqdm(range(batch_number), ncols=100):
98 | start = batch_index*batch_size
99 |
100 | y_train_datalist_batch = []
101 | x_train_datalist_batch = []
102 | for i in range(FLAGS.num_contrast):
103 | y_train_datalist_batch.append(y_train_datalist_shuffle[i][start:start+batch_size])
104 | x_train_datalist_batch.append(x_train_datalist_shuffle[i][start:start+batch_size])
105 |
106 | y_train_batch = batch_data_loader(y_train_datalist_batch, FLAGS.num_contrast)
107 | x_train_batch = batch_data_loader(x_train_datalist_batch, FLAGS.num_contrast)
108 | y_train_batch, x_train_batch = rot_tra_augmentation(y_train_batch, x_train_batch, FLAGS.num_contrast)
109 |
110 | batch_size_tmp = x_train_batch[0].shape[0]
111 | tmp_loss = model.train_on_batch(x_train_batch, y_train_batch)
112 |
113 | if batch_index % FLAGS.iter_interval == 0:
114 | logging.info(f'Epoch [{epoch+1:4d}/{epochs:4d}] | Iter [{batch_index:4d}/{batch_number:4d}] '
115 | f'{time.time() - start_time:.2f}s.. '
116 | f'train loss for T1: {tmp_loss[0]:.4f}, T2: {tmp_loss[1]:.4f}, FL: {tmp_loss[2]:.4f}')
117 |
118 | train_loss = [(x + y*batch_size_tmp) for (x, y) in zip(train_loss, tmp_loss)]
119 |
120 | train_loss = [x / train_size for x in train_loss]
121 |
122 | print(f'Epoch [{epoch+1:4d}/{epochs:4d}] {time.time() - start_time:.2f}s.. '
123 | f'train loss for T1: {train_loss[0]:.4f}, T2: {train_loss[1]:.4f}, FL: {train_loss[2]:.4f}')
124 | logging.info(f'Epoch [{epoch+1:4d}/{epochs:4d}] {time.time() - start_time:.2f}s.. '
125 | f'train loss for T1: {train_loss[0]:.4f}, T2: {train_loss[1]:.4f}, FL: {train_loss[2]:.4f}')
126 |
127 | if ((epoch+1) % FLAGS.save_epoch) == 0:
128 | x_valid = batch_data_loader(x_valid_datalist, FLAGS.num_contrast)
129 | y_valid = batch_data_loader(y_valid_datalist, FLAGS.num_contrast)
130 | val_loss = model.evaluate(x_valid, y_valid, verbose=0)
131 | model.save_weights(f'{FLAGS.path_weight}weight_e{epoch+1:04d}.h5', overwrite=True)
132 |
133 | if np.mean(val_loss) < min_val_loss:
134 | model.save_weights(f'{FLAGS.path_weight}weight_min_val_loss.h5', overwrite=True)
135 | min_val_loss = np.mean(val_loss)
136 |
137 | print(f'Weight saved! val loss T1: {val_loss[0]:.4f}, T2: {val_loss[1]:.4f}, FL: {val_loss[2]:.4f}')
138 | del x_valid, y_valid
139 |
140 | if epoch+1 == epochs:
141 | model.save_weights(f'{FLAGS.path_weight}weight_final.h5',
142 | overwrite=True)
143 | print(f'Weight saved! Training finished.')
144 |
145 | if __name__ == '__main__':
146 | try:
147 | app.run(train)
148 | except SystemExit:
149 | pass
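150 |
151 |
152 | # Editor's note (not part of the original file): to resume training from a saved checkpoint,
153 | # pass its filename relative to --path_weight (default ./weight/), e.g.
154 | #   CUDA_VISIBLE_DEVICES=0 python train.py --num_epoch 1000 --load_weight_name weight_min_val_loss.h5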
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import rotate, shift
3 | import cv2
4 | from skimage.metrics import structural_similarity
5 |
6 |
7 | def transpose(img, r, tx, ty, channel_last=True):  # rigid transform: rotate by r degrees, then shift by (tx, ty); not a matrix transpose
8 | if img.ndim == 2:
9 | s = shift(rotate(img, r, [0, 1], order=1, reshape=False), [tx, ty], order=1)
10 | elif img.ndim == 3:
11 | if channel_last:
12 | s = shift(rotate(img, r, [0, 1], order=1, reshape=False), [tx, ty, 0], order=1)
13 | else:
14 | s = shift(rotate(img, r, [1, 2], order=1, reshape=False), [0, tx, ty], order=1)
15 | elif img.ndim == 4:
16 | s = shift(rotate(img, r, [1, 2], order=1, reshape=False), [0, tx, ty, 0], order=1)
17 |     else:
18 |         raise TypeError(f'transpose expects a 2D-4D array, got ndim={img.ndim}')
19 |
20 | return s
21 |
22 |
23 | def rand_range(r):
24 |     return (np.random.rand() - 0.5) * 2 * r  # uniform in [-r, r]
25 |
26 |
27 | def generate_rand_par(num):
28 | x = (np.random.rand(num)-.5)*num/(num-1)
29 | return x
30 |
31 |
32 | def rot_tra_augmentation(y, x, num_contrast):
33 | size = y[0].shape[0]
34 |
35 |     x_copy = x  # aliases, not copies: the input batch arrays are modified in place
36 |     y_copy = y
37 | for i in range(size):
38 | r = rand_range(10)
39 | tx = rand_range(5)
40 | ty = rand_range(5)
41 | for j in range(num_contrast):
42 | y_copy[j][i,:,:,:] = transpose(y[j][i,:,:,:], r, tx, ty)
43 | x_copy[j][i,:,:,:] = transpose(x[j][i,:,:,:], r, tx, ty)
44 |
45 | return y_copy, x_copy
46 |
47 |
48 | def save_image(file_name, image):
49 | image = np.squeeze(image)*255
50 | cv2.imwrite(file_name, image)
51 |
52 |
53 | def merge_images(x, y, p):
54 | size = x[0].shape[0]
55 | h = x[0].shape[1]
56 | w = x[0].shape[2]
57 | out = np.zeros([size, h*3, w*3])
58 | for i in range(size):
59 | out[i, 0:h, 0:w] = x[0][i,:,:,0]
60 | out[i, 0:h, w:w*2] = x[1][i,:,:,0]
61 | out[i, 0:h, w*2:w*3] = x[2][i,:,:,0]
62 | out[i, h:h*2, 0:w] = y[0][i,:,:,0]
63 | out[i, h:h*2, w:w*2] = y[1][i,:,:,0]
64 | out[i, h:h*2, w*2:w*3] = y[2][i,:,:,0]
65 | out[i, h*2:h*3, 0:w] = p[0][i,:,:,0]
66 | out[i, h*2:h*3, w:w*2] = p[1][i,:,:,0]
67 | out[i, h*2:h*3, w*2:w*3] = p[2][i,:,:,0]
68 | return out
69 |
70 |
71 | def ssim(img, ref):
72 | return structural_similarity(img, ref, data_range=1.)
73 |
74 |
75 | def test_ssim(x_img, y_img, p_img):
76 | x_ssim = []
77 | p_ssim = []
78 | for i in range(x_img.shape[0]):
79 | x_ssim.append(ssim(x_img[i,:,:,0], y_img[i,:,:,0]))
80 | p_ssim.append(ssim(p_img[i,:,:,0], y_img[i,:,:,0]))
81 | return np.mean(x_ssim), np.mean(p_ssim)
82 |
83 |
84 | def nmi(img, ref, bins=16):
85 | eps = 1e-10
86 | hist = np.histogram2d(img.flatten(), ref.flatten(), bins=bins)[0]
87 | pxy = hist / np.sum(hist)
88 | px = np.sum(pxy, axis=1)
89 | py = np.sum(pxy, axis=0)
90 | px_py = px[:, None] * py[None, :]
91 | nzs = pxy > 0
92 | mi = np.sum(pxy[nzs] * np.log2((pxy[nzs] + eps) / (px_py[nzs] + eps)))
93 | entx = - np.sum(px * np.log2(px + eps))
94 | enty = - np.sum(py * np.log2(py + eps))
95 |
96 | return (2 * mi + eps) / (entx + enty + eps)
97 |
98 |
99 | def test_nmi(x_img, y_img, p_img):
100 | x_nmi = []
101 | p_nmi = []
102 | for i in range(x_img.shape[0]):
103 | x_nmi.append(nmi(x_img[i,:,:,0], y_img[i,:,:,0]))
104 | p_nmi.append(nmi(p_img[i,:,:,0], y_img[i,:,:,0]))
105 | return np.mean(x_nmi), np.mean(p_nmi)
106 |
107 |
108 | def nrmse(img, ref):
109 | rmse = np.sqrt(np.mean((img-ref)**2))
110 | return rmse/np.mean(ref)
111 |
112 |
113 | def test_nrmse(x_img, y_img, p_img):
114 | x_nrmse = []
115 | p_nrmse = []
116 | for i in range(x_img.shape[0]):
117 | x_nrmse.append(nrmse(x_img[i,:,:,0], y_img[i,:,:,0]))
118 | p_nrmse.append(nrmse(p_img[i,:,:,0], y_img[i,:,:,0]))
119 | return np.mean(x_nrmse), np.mean(p_nrmse)
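120 |
121 |
122 | if __name__ == '__main__':
123 |     # Editor's sketch (not part of the original file): sanity-check the metrics on synthetic data.
124 |     rng = np.random.default_rng(0)
125 |     y = rng.random((2, 64, 64, 1)).astype('float32')                                  # "ground truth"
126 |     x = np.clip(y + 0.1 * rng.standard_normal(y.shape).astype('float32'), 0.0, 1.0)  # "corrupted" input
127 |     print(test_ssim(x, y, y))   # (input-vs-gt SSIM, prediction-vs-gt SSIM); second value is 1.0
128 |     print(test_nmi(x, y, y))
129 |     print(test_nrmse(x, y, y))  # second value is 0.0 for a perfect prediction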
--------------------------------------------------------------------------------