├── AUTHOR.txt
├── License
│   └── Apache License_ver2.txt
├── README.md
├── main.py
└── src
    ├── __init__.py
    ├── function
    │   ├── functions.py
    │   └── preprocessing.py
    ├── layer
    │   └── layers.py
    ├── models
    │   └── BEGAN.py
    └── operator
        ├── op_BEGAN.py
        └── op_base.py
/AUTHOR.txt:
--------------------------------------------------------------------------------
1 | Copyright 2018 (Institution) under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | # This is the list of (Institution) for copyright purposes.
4 | # This does not necessarily list everyone who has contributed code, since in
5 | # some cases, their employer may be the copyright holder. To see the full list
5 | # of contributors, see the revision history in source control.
7 |
--------------------------------------------------------------------------------
/License/Apache License_ver2.txt:
--------------------------------------------------------------------------------
1 | Copyright [yyyy] [name of copyright owner]
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Deep Generative Model
4 |
5 | ### **CONTENT**
6 | > An MRI generative model based on Boundary Equilibrium Generative Adversarial Networks (BEGAN)
7 |
8 | ### **Dataset**
9 | > Human Connectome Project
10 | > https://www.humanconnectome.org/study/hcp-young-adult/data-releases
11 |
12 | ### **Reference**
13 | > BEGAN
14 | > https://arxiv.org/abs/1703.10717
15 |
16 | # XAI Project
17 |
18 | ### **Project Name**
19 | > A machine learning and statistical inference framework for explainable artificial intelligence (development of a human-level learning and reasoning framework capable of explaining the rationale behind its decisions)
20 | ### **Managed by**
21 | > Ministry of Science and ICT/XAIC
22 | ### **Participated Affiliation**
23 | > UNIST, Korea Univ., Yonsei Univ., KAIST, AItrics
24 | ### **Web Site**
25 | >
26 |
--------------------------------------------------------------------------------
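
For orientation, the training code in /src/operator/op_BEGAN.py implements the BEGAN objective from the reference above. A sketch of the loss in that notation, where L(·) is the critic autoencoder's L1 reconstruction loss and gamma, lamda correspond to the `--gamma` and `--lamda` flags in main.py:

```latex
\mathcal{L}_D = L(y) - k_t \, L(G(z)), \qquad \mathcal{L}_G = L(G(z))
k_{t+1} = \operatorname{clip}_{[0,1]}\!\left( k_t + \lambda \left( \gamma \, L(y) - L(G(z)) \right) \right)
M_{\mathrm{global}} = L(y) + \left| \, \gamma \, L(y) - L(G(z)) \, \right|
```
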
/main.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | import argparse
12 | import distutils.util
13 | import os
14 | import tensorflow as tf
15 | import src.models.BEGAN as began
16 |
17 |
18 | def main():
19 | parser = argparse.ArgumentParser()
20 |
21 |     parser.add_argument("-f", "--flag", type=distutils.util.strtobool, default='0')  # 1: train, 0: test
22 | parser.add_argument("-g", "--gpu_number", type=str, default="1")
23 | parser.add_argument("-p", "--project", type=str, default="MRIGAN_2D_g0.3_d3")
24 |
25 | # Train Data
26 | parser.add_argument("-d", "--data_dir", type=str, default="./Data/MRI")
27 | parser.add_argument("-trd", "--dataset", type=str, default="HCP_MRI")
28 | parser.add_argument("-tro", "--data_opt", type=str, default="crop")
29 | parser.add_argument("-trs", "--data_size", type=int, default=256)
30 | parser.add_argument("-ndp", "--num_depth", type=int, default=3)
31 |
32 | # Train Iteration
33 | parser.add_argument("-n" , "--niter", type=int, default=200)
34 | parser.add_argument("-ns", "--nsnapshot", type=int, default=5000)
35 | parser.add_argument("-mx", "--max_to_keep", type=int, default=5)
36 |
37 | # Train Parameter
38 | parser.add_argument("-b" , "--batch_size", type=int, default=1)
39 | parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4)
40 | parser.add_argument("-m" , "--momentum", type=float, default=0.5)
41 | parser.add_argument("-m2", "--momentum2", type=float, default=0.999)
42 | parser.add_argument("-gm", "--gamma", type=float, default=0.3)
43 | parser.add_argument("-lm", "--lamda", type=float, default=0.001)
44 | parser.add_argument("-fn", "--filter_number", type=int, default=64)
45 | parser.add_argument("-z", "--input_size", type=int, default=256)
46 | parser.add_argument("-em", "--embedding", type=int, default=256)
47 |
48 | args = parser.parse_args()
49 |
50 |     # select the physical GPU; inside this process TensorFlow will see it as '/gpu:0'
51 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number
52 |
53 |     with tf.device('/gpu:0'):
54 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
55 | config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
56 |
57 | with tf.Session(config=config) as sess:
58 | model = began.BEGAN(args, sess)
59 |
60 | # TRAIN / TEST
61 | if args.flag:
62 | model.train(args.flag)
63 | else:
64 | model.test(args.flag)
65 |
66 | if __name__ == '__main__':
67 | main()
68 |
--------------------------------------------------------------------------------
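
A typical invocation, using only flags defined by the parser above (the data path is the repository default; adjust to your setup): `python main.py --flag 1 --gpu_number 0 --data_dir ./Data/MRI --dataset HCP_MRI --data_size 256`. With `--flag 0` (the default) the script restores the latest checkpoint and runs `model.test` instead of training.
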
/src/__init__.py:
--------------------------------------------------------------------------------
1 | ## __init__.py
--------------------------------------------------------------------------------
/src/function/functions.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | import os
12 | import scipy.misc as scm
13 | import nibabel as nib
14 |
15 | def make_project_dir(project_dir):
16 | if not os.path.exists(project_dir):
17 | os.makedirs(project_dir)
18 | os.makedirs(os.path.join(project_dir, 'models'))
19 | os.makedirs(os.path.join(project_dir, 'result'))
20 | os.makedirs(os.path.join(project_dir, 'result_test'))
21 |
22 | def get_image(img_path):
23 | img = scm.imread(img_path)/255. - 0.5
24 | img = img[..., ::-1] # rgb to bgr
25 | return img
26 |
27 |
28 | def inverse_image(img):
29 | img = (img + 0.5) * 255.
30 | img[img > 255] = 255
31 | img[img < 0] = 0
32 | img = img[..., ::-1] # bgr to rgb
33 | return img
34 |
35 | def save_as_nii(vol, aff, save_prefix):
36 |     for i in range(len(vol)):
37 |         img = nib.Nifti1Image(dataobj=vol[i, ...], affine=aff)
38 |         nib.save(img, '{}_{}.nii'.format(save_prefix, i))  # one file per volume
39 |     print("MRI file saved..!")
40 |     return
41 |
--------------------------------------------------------------------------------
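
As a quick sanity check of the [-0.5, 0.5] scaling and channel-reversal convention used by `get_image`/`inverse_image` above, a minimal self-contained round trip on synthetic data (no image file or scipy.misc needed):

```python
import numpy as np

# emulate get_image on a synthetic RGB image
rgb = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float64)
img = rgb / 255. - 0.5           # scale to [-0.5, 0.5]
img = img[..., ::-1]             # RGB -> BGR

# inverse_image undoes both steps (clipping is a no-op here)
restored = np.clip((img + 0.5) * 255., 0, 255)[..., ::-1]

assert np.allclose(restored, rgb)
```
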
/src/function/preprocessing.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | import os
12 | import numpy as np
13 | import nibabel as nib
14 | from scipy.ndimage.interpolation import zoom
15 |
16 | TARGET_DIR = '../../Data/HCP_MRI'
17 | TARGET_FNAME = 'T1w_restore_brain.nii.gz'
18 | # SAVE_DIR = '../../Data/MRI/HCP_MRI_256.npy'
19 | SAVE_DIR = '/DATA_1/HCP_MRI_256.npy'
20 | COUNT = 0
21 | MIN = 0
22 |
23 | def make_path_list(root_dir, filename):
24 | 
25 |     pathlist = []
26 |     for root, _, fnames in sorted(os.walk(root_dir)):
27 |         for fname in sorted(fnames):
28 |             if fname == filename:
29 |                 path = os.path.join(root, fname)
30 |                 pathlist.append(path)
31 |     pathlist = np.asarray(pathlist)
32 |     return pathlist
33 |
34 | def normalize(img):
35 |     vmax = np.max(img)
36 |     vmin = np.min(img)
37 |     normalized_img = (img - vmin) / (vmax - vmin)
38 | 
39 |     return normalized_img
40 |
41 | def rescale(vol, scale):
42 |     # The MRI volumes have w > h = d, so shrink w by an extra factor of h/w to make the rescaled volume isotropic
43 |     h, w, d = vol.shape
44 |     vol_rs = zoom(vol, zoom=(scale, scale * float(h) / w, scale), mode='nearest')
45 |     return vol_rs
46 |
47 | def get_mri(data_path):
48 | first_flag = True
49 | proxy_img = nib.load(data_path)
50 | data_array = np.asarray(proxy_img.dataobj).astype(np.float32)
51 | data_array = data_array[2:-2,:,2:-2]
52 | global COUNT
53 | for s in range(data_array.shape[1]):
54 | _slice = data_array[:,s,:]
55 | if (np.count_nonzero(_slice)==0):
56 | continue
57 | _slice = _slice.T[::-1,:]
58 | _slice = normalize(_slice)
59 | if first_flag:
60 | concat = _slice[...,None]
61 | first_flag = False
62 | else:
63 | concat = np.concatenate((concat,_slice[...,None]),axis=2)
64 | if concat.shape[2] == 257:
65 | print(COUNT + 1)
66 | COUNT += 1
67 | return concat
68 | return concat
69 |
70 | def print_aff(data_path):
71 | proxy_img = nib.load(data_path)
72 | print(proxy_img.affine)
73 |
74 | return
75 |
76 | def get_aff(root_dir=TARGET_DIR, fname=TARGET_FNAME):
77 |     f_list = make_path_list(root_dir, fname)
78 |     proxy_img = nib.load(f_list[0])
79 |     return proxy_img.affine
80 |
81 | def preprocessing(data_dir=TARGET_DIR,save_dir=SAVE_DIR,fname=TARGET_FNAME):
82 |
83 | print("Preprocessing Start")
84 | f_list = make_path_list(data_dir,fname)
85 | concat_mri = [get_mri(path) for path in f_list]
86 | concat_mri = np.asarray(concat_mri).astype(np.float32)
87 | np.save(save_dir,concat_mri)
88 | print("Concatenation Done")
89 | return
90 |
91 | def get_min_nonzero_slice(data_dir=TARGET_DIR, fname=TARGET_FNAME):
92 |     f_list = make_path_list(data_dir, fname)
93 |     m_count = 10000
94 |     i = 1
95 |     for path in f_list:
96 |         data = np.asarray(nib.load(path).dataobj)
97 |         data = np.transpose(data, (0, 2, 1))
98 |         data = data.reshape((-1, data.shape[-1]))
99 |         col_max = np.max(data, axis=0)
100 |         count = np.count_nonzero(col_max)
101 |         m_count = min(m_count, count)
102 |         print(i, ":")
103 |         print(m_count)
104 | 
105 |         i += 1
106 |     print(m_count)
107 |     return m_count
108 |
109 | if __name__=="__main__":
110 | preprocessing()
111 | # get_min_nonzero_slice()
112 |
--------------------------------------------------------------------------------
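
A toy illustration of the anisotropic zoom factors in `rescale` (shapes are hypothetical; assumes scipy is installed). With w > h = d, shrinking the width axis by the extra factor h/w yields an isotropic cube:

```python
import numpy as np
from scipy.ndimage import zoom

vol = np.random.rand(64, 80, 64)   # (h, w, d) with w > h = d
scale = 0.5
h, w, d = vol.shape
vol_rs = zoom(vol, zoom=(scale, scale * float(h) / w, scale), mode='nearest')
print(vol_rs.shape)                # (32, 32, 32): isotropic after rescaling
```
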
/src/layer/layers.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | import tensorflow as tf
12 | import numpy as np
13 |
14 |
15 | def conv2d(x, filter_shape, bias=True, stride=1, padding="SAME", name="conv2d"):
16 | kw, kh, nin, nout = filter_shape
17 |     pad_size = (kw - 1) // 2  # integer division: symmetric padding for odd kernel sizes
18 | 
19 |     if padding == "VALID":
20 |         x = tf.pad(x, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]], "SYMMETRIC")  # manual symmetric padding, consumed by the VALID conv
21 |
22 | initializer = tf.random_normal_initializer(0., 0.02)
23 | with tf.variable_scope(name):
24 | weight = tf.get_variable("weight", shape=filter_shape, initializer=initializer)
25 | x = tf.nn.conv2d(x, weight, [1, stride, stride, 1], padding=padding)
26 |
27 | if bias:
28 | b = tf.get_variable("bias", shape=filter_shape[-1], initializer=tf.constant_initializer(0.))
29 | x = tf.nn.bias_add(x, b)
30 | return x
31 |
32 |
33 | def fc(x, output_shape, bias=True, name='fc'):
34 | shape = x.get_shape().as_list()
35 | dim = np.prod(shape[1:])
36 | x = tf.reshape(x, [-1, dim])
37 | input_shape = dim
38 |
39 | initializer = tf.random_normal_initializer(0., 0.02)
40 | with tf.variable_scope(name):
41 | weight = tf.get_variable("weight", shape=[input_shape, output_shape], initializer=initializer)
42 | x = tf.matmul(x, weight)
43 |
44 | if bias:
45 | b = tf.get_variable("bias", shape=[output_shape], initializer=tf.constant_initializer(0.))
46 | x = tf.nn.bias_add(x, b)
47 | return x
48 |
49 |
50 | def pool(x, r=2, s=1):
51 | return tf.nn.avg_pool(x, ksize=[1, r, r, 1], strides=[1, s, s, 1], padding="SAME")
52 |
53 |
54 | def l1_loss(x, y):
55 | return tf.reduce_mean(tf.abs(x - y))
56 |
57 |
58 | def resize_nn(x, size):
59 | return tf.image.resize_nearest_neighbor(x, size=(int(size), int(size)))
60 |
--------------------------------------------------------------------------------
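
A minimal smoke test of the helpers above, assuming TensorFlow 1.x (where `tf.placeholder` exists) and the repository root on PYTHONPATH; shapes and names here are illustrative only:

```python
import tensorflow as tf
from src.layer.layers import conv2d, fc, pool

x = tf.placeholder(tf.float32, [None, 32, 32, 3])
h = tf.nn.elu(conv2d(x, [3, 3, 3, 16], name='c1'))  # 3x3 conv: 3 -> 16 channels
h = pool(h, r=2, s=2)                               # 2x2 average pool, stride 2
logits = fc(h, 10, name='fc1')                      # flatten 16*16*16 -> 10
print(logits.get_shape())                           # (?, 10)
```
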
/src/models/BEGAN.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | from src.layer.layers import *
12 | from src.operator.op_BEGAN import Operator
13 |
14 |
15 | class BEGAN(Operator):
16 | def __init__(self, args, sess):
17 | Operator.__init__(self, args, sess)
18 |
19 | def generator(self, x, reuse=None):
20 | with tf.variable_scope('gen_') as scope:
21 | if reuse:
22 | scope.reuse_variables()
23 |
24 | w = self.data_size
25 | f = self.filter_number
26 | v = self.num_depth
27 | p = "SAME"
28 |
29 | x = fc(x, 8 * 8 * f, name='fc')
30 | x = tf.reshape(x, [-1, 8, 8, f])
31 |
32 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a')
33 | x = tf.nn.elu(x)
34 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b')
35 | x = tf.nn.elu(x)
36 |
37 | if self.data_size == 256:
38 | x = resize_nn(x, w/16)
39 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_a')
40 | x = tf.nn.elu(x)
41 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_b')
42 | x = tf.nn.elu(x)
43 |
44 | if (self.data_size == 128) or (self.data_size == 256):
45 | x = resize_nn(x, w / 8)
46 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_a')
47 | x = tf.nn.elu(x)
48 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_b')
49 | x = tf.nn.elu(x)
50 |
51 | x = resize_nn(x, w / 4)
52 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_a')
53 | x = tf.nn.elu(x)
54 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_b')
55 | x = tf.nn.elu(x)
56 |
57 | x = resize_nn(x, w / 2)
58 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_a')
59 | x = tf.nn.elu(x)
60 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_b')
61 | x = tf.nn.elu(x)
62 |
63 | x = resize_nn(x, w)
64 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p,name='conv6_a')
65 | x = tf.nn.elu(x)
66 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p,name='conv6_b')
67 | x = tf.nn.elu(x)
68 |
69 | x = conv2d(x, [3, 3, f, v], stride=1, padding=p,name='conv7_a')
70 | return x
71 |
72 | def encoder(self, x, reuse=None):
73 | with tf.variable_scope('disc_') as scope:
74 | if reuse:
75 | scope.reuse_variables()
76 |
77 | f = self.filter_number
78 | h = self.embedding
79 | v = self.num_depth
80 | p = "SAME"
81 |
82 | x = conv2d(x, [3, 3, v, f], stride=1, padding=p,name='conv1_enc_a')
83 | x = tf.nn.elu(x)
84 |
85 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p,name='conv2_enc_a')
86 | x = tf.nn.elu(x)
87 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p,name='conv2_enc_b')
88 | x = tf.nn.elu(x)
89 |
90 | x = conv2d(x, [1, 1, f, 2 * f], stride=1, padding=p,name='conv3_enc_0')
91 |
92 | x = pool(x, r=2, s=2)
93 |
94 | x = conv2d(x, [3, 3, 2 * f, 2 * f], stride=1, padding=p,name='conv3_enc_a')
95 | x = tf.nn.elu(x)
96 | x = conv2d(x, [3, 3, 2 * f, 2 * f], stride=1, padding=p,name='conv3_enc_b')
97 | x = tf.nn.elu(x)
98 |
99 | x = conv2d(x, [1, 1, 2 * f, 3 * f], stride=1, padding=p,name='conv4_enc_0')
100 |
101 | x = pool(x, r=2, s=2)
102 |
103 | x = conv2d(x, [3, 3, 3 * f, 3 * f], stride=1, padding=p,name='conv4_enc_a')
104 | x = tf.nn.elu(x)
105 | x = conv2d(x, [3, 3, 3 * f, 3 * f], stride=1, padding=p,name='conv4_enc_b')
106 | x = tf.nn.elu(x)
107 |
108 | x = conv2d(x, [1, 1, 3 * f, 4 * f], stride=1, padding=p,name='conv5_enc_0')
109 |
110 | x = pool(x, r=2, s=2)
111 |
112 | x = conv2d(x, [3, 3, 4 * f, 4 * f], stride=1, padding=p,name='conv5_enc_a')
113 | x = tf.nn.elu(x)
114 | x = conv2d(x, [3, 3, 4 * f, 4 * f], stride=1, padding=p,name='conv5_enc_b')
115 | x = tf.nn.elu(x)
116 |
117 | if (self.data_size == 128) or (self.data_size == 256):
118 | x = conv2d(x, [1, 1, 4 * f, 5 * f], stride=1, padding=p,name='conv6_enc_0')
119 | x = pool(x, r=2, s=2)
120 | x = conv2d(x, [3, 3, 5 * f, 5 * f], stride=1, padding=p,name='conv6_enc_a')
121 | x = tf.nn.elu(x)
122 | x = conv2d(x, [3, 3, 5 * f, 5 * f], stride=1, padding=p,name='conv6_enc_b')
123 | x = tf.nn.elu(x)
124 |
125 | if self.data_size == 256:
126 | x = conv2d(x, [1, 1, 5 * f, 6 * f], stride=1, padding=p,name='conv7_enc_0')
127 | x = pool(x, r=2, s=2)
128 | x = conv2d(x, [3, 3, 6 * f, 6 * f], stride=1, padding=p,name='conv7_enc_a')
129 | x = tf.nn.elu(x)
130 | x = conv2d(x, [3, 3, 6 * f, 6 * f], stride=1, padding=p,name='conv7_enc_b')
131 | x = tf.nn.elu(x)
132 |
133 | x = fc(x, h, name='enc_fc')
134 | return x
135 |
136 | def decoder(self, x, reuse=None):
137 | with tf.variable_scope('disc_') as scope:
138 | if reuse:
139 | scope.reuse_variables()
140 |
141 | w = self.data_size
142 | f = self.filter_number
143 | v = self.num_depth
144 | p = "SAME"
145 |
146 | x = fc(x, 8 * 8 * f, name='fc')
147 | x = tf.reshape(x, [-1, 8, 8, f])
148 |
149 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a')
150 | x = tf.nn.elu(x)
151 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b')
152 | x = tf.nn.elu(x)
153 |
154 | if self.data_size == 256:
155 | x = resize_nn(x, w/16)
156 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_a')
157 | x = tf.nn.elu(x)
158 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_b')
159 | x = tf.nn.elu(x)
160 |
161 | if (self.data_size == 128) or (self.data_size == 256):
162 | x = resize_nn(x, w / 8)
163 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_a')
164 | x = tf.nn.elu(x)
165 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv3_b')
166 | x = tf.nn.elu(x)
167 |
168 | x = resize_nn(x, w / 4)
169 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_a')
170 | x = tf.nn.elu(x)
171 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv4_b')
172 | x = tf.nn.elu(x)
173 |
174 | x = resize_nn(x, w / 2)
175 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_a')
176 | x = tf.nn.elu(x)
177 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv5_b')
178 | x = tf.nn.elu(x)
179 |
180 | x = resize_nn(x, w)
181 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv6_a')
182 | x = tf.nn.elu(x)
183 | x = conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv6_b')
184 | x = tf.nn.elu(x)
185 |
186 | x = conv2d(x, [3, 3, f, v], stride=1, padding=p, name='conv7_a')
187 | return x
188 |
--------------------------------------------------------------------------------
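
For orientation, both `generator` and `decoder` above start from an 8x8 feature map and grow it with nearest-neighbor resizes; a quick sketch of the schedule they produce at the default `data_size=256` (the 128 branch skips the `w/16` stage):

```python
w = 256  # data_size
sizes = [8, w // 16, w // 8, w // 4, w // 2, w]
print(sizes)  # [8, 16, 32, 64, 128, 256]
```
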
/src/operator/op_BEGAN.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | import time
12 | import datetime
13 | from src.layer.layers import *
14 | from src.function.functions import *
15 | from src.function.preprocessing import *
16 | from src.operator.op_base import op_base
17 |
18 |
19 | class Operator(op_base):
20 | def __init__(self, args, sess):
21 | op_base.__init__(self, args, sess)
22 | self.build_model()
23 |
24 | def build_model(self):
25 | # Input placeholder
26 | self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.input_size], name='x')
27 | self.y = tf.placeholder(tf.float32, shape=[self.batch_size, self.data_size, self.data_size, self.num_depth], name='y')
28 | self.kt = tf.placeholder(tf.float32, name='kt')
29 | self.lr = tf.placeholder(tf.float32, name='lr')
30 |
31 |         # The latent code x feeds the generator and, for visualization, the critic's decoder
32 | 
33 | 
34 |         # Generator
35 |         self.recon_gen = self.generator(self.x)
36 | 
37 |         # Discriminator (critic): an autoencoder built from encoder + decoder
38 |         self.latent = self.encoder(self.y)  # latent embedding of the real batch
39 |         d_real = self.decoder(self.latent)
40 | d_fake = self.decoder(self.encoder(self.recon_gen, reuse=True), reuse=True)
41 | self.recon_dec = self.decoder(self.x, reuse=True)
42 |
43 | # Loss
44 | self.d_real_loss = l1_loss(self.y, d_real)
45 | self.d_fake_loss = l1_loss(self.recon_gen, d_fake)
46 | self.d_loss = self.d_real_loss - self.kt * self.d_fake_loss
47 | self.g_loss = self.d_fake_loss
48 | self.m_global = self.d_real_loss + tf.abs(self.gamma * self.d_real_loss - self.d_fake_loss)
49 |
50 | # Variables
51 |         g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "gen_")
52 |         d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "disc_")
53 |
54 | # Optimizer
55 |         self.opt_g = tf.train.AdamOptimizer(self.lr, beta1=self.mm, beta2=self.mm2).minimize(self.g_loss, var_list=g_vars)
56 |         self.opt_d = tf.train.AdamOptimizer(self.lr, beta1=self.mm, beta2=self.mm2).minimize(self.d_loss, var_list=d_vars)
57 |
58 |
59 | # initializer
60 | self.sess.run(tf.global_variables_initializer())
61 |
62 | # tf saver
63 | self.saver = tf.train.Saver(max_to_keep=(self.max_to_keep))
64 |
65 | try:
66 | self.load(self.sess, self.saver, self.ckpt_dir)
67 |         except Exception:  # no checkpoint yet; self.load fails on the None returned by get_checkpoint_state
68 | # save full graph
69 | self.saver.save(self.sess, self.ckpt_model_name, write_meta_graph=True)
70 |
71 | # Summary
72 | if self.flag:
73 | tf.summary.scalar('loss/loss', self.d_loss + self.g_loss)
74 | tf.summary.scalar('loss/g_loss', self.g_loss)
75 | tf.summary.scalar('loss/d_loss', self.d_loss)
76 | tf.summary.scalar('loss/d_real_loss', self.d_real_loss)
77 | tf.summary.scalar('loss/d_fake_loss', self.d_fake_loss)
78 | tf.summary.scalar('misc/kt', self.kt)
79 | tf.summary.scalar('misc/m_global', self.m_global)
80 | self.merged = tf.summary.merge_all()
81 | self.writer = tf.summary.FileWriter(self.project_dir, self.sess.graph)
82 |
83 | def train(self, train_flag):
84 | # load data
85 | train_data = self.train_data
86 | print('Shuffle ....')
87 | num_vol = train_data.shape[0]
88 | num_sli = train_data.shape[3]
89 | vv = np.arange(num_vol)
90 | ss = np.arange(num_sli)
91 | v_c, s_c = np.meshgrid(vv,ss)
92 | vs = np.column_stack([v_c.flat,s_c.flat])
93 | data_length = len(vs)
94 | random_order = np.random.permutation(data_length)
95 | print('Shuffle Done')
96 |
97 | # initial parameter
98 | start_time = time.time()
99 | kt = np.float32(0.)
100 | lr = np.float32(self.learning_rate)
101 | self.count = 0
102 |
103 | for epoch in range(self.niter):
104 | batch_idxs = len(vs) // self.batch_size
105 |
106 | for idx in range(0, batch_idxs):
107 | self.count += 1
108 |
109 | batch_x = np.random.uniform(-1., 1., size=[self.batch_size, self.input_size])
110 | side_depth = int((self.num_depth-1)/2)
111 | batch_data = []
112 |
113 |                 for i in range(idx * self.batch_size, (idx + 1) * self.batch_size):
114 |                     v, s = vs[random_order[i]]  # volume index, center-slice index
115 |                     if s - side_depth < 0:
116 |                         batch_data += [train_data[v, :, :, s:s + self.num_depth]]
117 |                     elif s + side_depth > 256:
118 |                         batch_data += [train_data[v, :, :, s - self.num_depth:s]]
119 |                     else:
120 |                         batch_data += [train_data[v, :, :, s - side_depth:s + side_depth + 1]]
121 |                 # opt & feed lists (the update scheme differs slightly from the paper)
122 | g_opt = [self.opt_g, self.g_loss, self.d_real_loss, self.d_fake_loss]
123 | d_opt = [self.opt_d, self.d_loss, self.merged]
124 | feed_dict = {self.x: batch_x, self.y: batch_data, self.kt: kt, self.lr: lr}
125 |
126 | # run tensorflow
127 | _, loss_g, d_real_loss, d_fake_loss = self.sess.run(g_opt, feed_dict=feed_dict)
128 | _, loss_d, summary = self.sess.run(d_opt, feed_dict=feed_dict)
129 |
130 | # update kt, m_global
131 | kt = np.maximum(np.minimum(1., kt + self.lamda * (self.gamma * d_real_loss - d_fake_loss)), 0.)
132 | m_global = d_real_loss + np.abs(self.gamma * d_real_loss - d_fake_loss)
133 | loss = loss_g + loss_d
134 |
135 | print("Epoch: [%2d] [%4d/%4d] time: %4.4f, "
136 | "loss: %.4f, loss_g: %.4f, loss_d: %.4f, d_real: %.4f, d_fake: %.4f, kt: %.8f, M: %.8f"
137 | % (epoch, idx, batch_idxs, time.time() - start_time,
138 | loss, loss_g, loss_d, d_real_loss, d_fake_loss, kt, m_global))
139 |
140 | # write train summary
141 | self.writer.add_summary(summary, self.count)
142 |
143 | # Test during Training
144 | if (self.count % self.niter_snapshot == (self.niter_snapshot - 1)) or (self.count==1):
145 | # update learning rate
146 | lr *= 0.95
147 | # save & test
148 | self.saver.save(self.sess, self.ckpt_model_name, global_step=self.count, write_meta_graph=False)
149 | self.test(train_flag)
150 |
151 | def test(self, train_flag=True):
152 | # generate output
153 |         print("testing..")
154 | img_num = self.batch_size
155 | img_size = self.data_size
156 |
157 | output_f = int(np.sqrt(img_num))
158 | im_output_gen = np.zeros([img_size * output_f, img_size * output_f])
159 | im_output_dec = np.zeros([img_size * output_f, img_size * output_f])
160 |
161 | test_data = np.random.uniform(-1., 1., size=[img_num, self.input_size])
162 | output_gen = (self.sess.run(self.recon_gen, feed_dict={self.x: test_data})) # generator output
163 | output_dec = (self.sess.run(self.recon_dec, feed_dict={self.x: test_data})) # decoder output
164 |
165 | ##
166 | # output_gen = output_gen*256.
167 | # output_dec = output_dec*256
168 | ##
169 |
170 | output_gen_slice = output_gen[:,:,:,int(self.num_depth/2)]
171 | output_dec_slice = output_dec[:,:,:,int(self.num_depth/2)]
172 |
173 | for i in range(output_f):
174 | for j in range(output_f):
175 | im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size] \
176 | = output_gen_slice[j + (i * output_f)]
177 | im_output_dec[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size] \
178 | = output_dec_slice[j + (i * output_f)]
179 |
180 |
181 | # output save
182 | if train_flag:
183 | scm.imsave(self.project_dir + '/result/' + str(self.count) + '_output.bmp', im_output_gen)
184 | else:
185 | now = datetime.datetime.now()
186 | nowDatetime = now.strftime('%Y-%m-%d_%H:%M:%S')
187 | scm.imsave(self.project_dir + '/result_test/gen_{}_output.bmp'.format(nowDatetime), im_output_gen)
188 | scm.imsave(self.project_dir + '/result_test/dec_{}_output.bmp'.format(nowDatetime), im_output_dec)
189 |
190 | def get_latent(self, train_flag=True):
191 | # generate output
192 |         print("latent testing..")
193 | img_num = self.batch_size
194 | img_size = self.data_size
195 |
196 | test_data_path = self.data_dir + '/Test_data/test.nii.gz'
197 | test_data = get_mri(test_data_path)
198 |
199 | output_f = int(np.sqrt(img_num))
200 | im_output_gen = np.zeros([img_size * output_f, img_size * output_f])
201 | im_output_dec = np.zeros([img_size * output_f, img_size * output_f])
202 |
203 |
204 |         test_data = np.random.uniform(-1., 1., size=[img_num, self.input_size])  # NOTE: this overwrites the MRI volume loaded above
205 | output_gen = (self.sess.run(self.recon_gen, feed_dict={self.x: test_data})) # generator output
206 | output_dec = (self.sess.run(self.recon_dec, feed_dict={self.x: test_data})) # decoder output
207 |
208 | ##
209 | # output_gen = output_gen*256.
210 | # output_dec = output_dec*256
211 | ##
212 |
213 | output_gen_slice = output_gen[:,:,:,int(self.num_depth/2)]
214 | output_dec_slice = output_dec[:,:,:,int(self.num_depth/2)]
215 |
216 | for i in range(output_f):
217 | for j in range(output_f):
218 | im_output_gen[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size] \
219 | = output_gen_slice[j + (i * output_f)]
220 | im_output_dec[i * img_size:(i + 1) * img_size, j * img_size:(j + 1) * img_size] \
221 | = output_dec_slice[j + (i * output_f)]
222 |
223 |
224 | # output save
225 | if train_flag:
226 | scm.imsave(self.project_dir + '/result/' + str(self.count) + '_output.bmp', im_output_gen)
227 | else:
228 | now = datetime.datetime.now()
229 | nowDatetime = now.strftime('%Y-%m-%d_%H:%M:%S')
230 | scm.imsave(self.project_dir + '/result_test/gen_{}_output.bmp'.format(nowDatetime), im_output_gen)
231 | scm.imsave(self.project_dir + '/result_test/dec_{}_output.bmp'.format(nowDatetime), im_output_dec)
232 |
--------------------------------------------------------------------------------
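
The `kt` update inside `train` is BEGAN's proportional controller. A standalone numpy sketch of the same rule with toy loss values (the numbers are hypothetical; gamma and lamda are the main.py defaults):

```python
import numpy as np

gamma, lamda = 0.3, 0.001
kt = 0.0
for step in range(3):
    d_real_loss, d_fake_loss = 0.5, 0.2          # toy values
    kt = np.clip(kt + lamda * (gamma * d_real_loss - d_fake_loss), 0., 1.)
    m_global = d_real_loss + abs(gamma * d_real_loss - d_fake_loss)
    print(step, kt, m_global)                    # kt stays clipped at 0 here
```
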
/src/operator/op_base.py:
--------------------------------------------------------------------------------
1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea
2 |
3 | #Licensed under the Apache License, Version 2.0 (the "License");
4 | #you may not use this file except in compliance with the License.
5 | #You may obtain a copy of the License at
6 |
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 |
9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
10 |
11 | import os
12 | import time
13 | import numpy as np
14 | import tensorflow as tf
15 | from src.function.functions import *
16 |
17 | class op_base:
18 | def __init__(self, args, sess):
19 | self.sess = sess
20 |
21 | # Train
22 | self.flag = args.flag
23 | self.gpu_number = args.gpu_number
24 | self.project = args.project
25 |
26 | # Train Data
27 |         self.data_dir = args.data_dir  # e.g. ./Data/MRI
28 | self.dataset = args.dataset # HCP_MRI
29 | self.data_size = args.data_size #256
30 | self.data_opt = args.data_opt # raw or crop
31 | # self.train_data_path = '{0}/{1}_{2}'.format(self.data_dir, self.dataset, self.scale_factor)
32 | self.train_data = np.load('{0}/{1}_{2}.npy'.format(self.data_dir,self.dataset,self.data_size),mmap_mode='r')
33 | self.num_depth = args.num_depth # 3
34 | # Train Iteration
35 | self.niter = args.niter
36 | self.niter_snapshot = args.nsnapshot
37 | self.max_to_keep = args.max_to_keep
38 |
39 | # Train Parameter
40 | self.batch_size = args.batch_size
41 | self.learning_rate = args.learning_rate
42 | self.mm = args.momentum
43 | self.mm2 = args.momentum2
44 | self.lamda = args.lamda
45 | self.gamma = args.gamma
46 | self.filter_number = args.filter_number
47 | self.input_size = args.input_size
48 | self.embedding = args.embedding
49 |
50 | # Result Dir & File
51 | self.project_dir = 'assets/{0}_{1}_{2}_{3}/'.format(self.project, self.dataset, self.data_opt, self.data_size)
52 | self.ckpt_dir = os.path.join(self.project_dir, 'models')
53 | self.model_name = "{0}.model".format(self.project)
54 | self.ckpt_model_name = os.path.join(self.ckpt_dir, self.model_name)
55 |
56 | # etc.
57 | if not os.path.exists('assets'):
58 | os.makedirs('assets')
59 | make_project_dir(self.project_dir)
60 |
61 | def load(self, sess, saver, ckpt_dir):
62 |         ckpt = tf.train.get_checkpoint_state(ckpt_dir)  # returns None when no checkpoint exists; the caller catches the resulting AttributeError
63 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
64 | saver.restore(sess, os.path.join(ckpt_dir, ckpt_name))
65 |
--------------------------------------------------------------------------------
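
`load` above assumes a checkpoint is present; `tf.train.get_checkpoint_state` returns `None` otherwise, which is what the `except` in `op_BEGAN.build_model` ends up catching. An explicit-guard variant (a sketch, not a drop-in used by the code):

```python
import tensorflow as tf

def load_if_available(sess, saver, ckpt_dir):
    # Restore the latest checkpoint if one exists; report whether it did.
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        return True
    return False
```
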