├── AUTHOR.txt ├── License └── Apache License_ver2.txt ├── README.md ├── main.py └── src ├── __init__.py ├── function ├── functions.py └── preprocessing.py ├── layer └── layers.py ├── models └── BEGAN.py └── operator ├── op_BEGAN.py └── op_base.py /AUTHOR.txt: -------------------------------------------------------------------------------- 1 | Copyright 2018 (Institution) under XAI Project supported by Ministry of Science and ICT, Korea 2 | 3 | # This is the list of (Institution) authors for copyright purposes. 4 | # This does not necessarily list everyone who has contributed code, since in 5 | # some cases, their employer may be the copyright holder. To see the full list 6 | # of contributors, see the revision history in source control. 7 | -------------------------------------------------------------------------------- /License/Apache License_ver2.txt: -------------------------------------------------------------------------------- 1 | Copyright [yyyy] [name of copyright owner] 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License.
14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Deep Generative Model 4 | 5 | ### **CONTENT** 6 | > MRI generative model based on Boundary Equilibrium Generative Adversarial Networks (BEGAN) 7 | 8 | ### **Dataset** 9 | > Human Connectome Project 10 | > https://www.humanconnectome.org/study/hcp-young-adult/data-releases 11 | 12 | ### **Reference** 13 | > BEGAN 14 | > https://arxiv.org/abs/1703.10717 15 | 16 | # XAI Project 17 | 18 | ### **Project Name** 19 | > A machine learning and statistical inference framework for explainable artificial intelligence (의사결정 이유를 설명할 수 있는 인간 수준의 학습·추론 프레임워크 개발) 20 | ### **Managed by** 21 | > Ministry of Science and ICT/XAIC 22 | ### **Participating Affiliations** 23 | > UNIST, Korea Univ., Yonsei Univ., KAIST, AItrics 24 | ### **Web Site** 25 | > 26 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea 2 | 3 | #Licensed under the Apache License, Version 2.0 (the "License"); 4 | #you may not use this file except in compliance with the License. 5 | #You may obtain a copy of the License at 6 | 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | 9 | #Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import argparse
import distutils.util
import os
import tensorflow as tf
import src.models.BEGAN as began


def main():
    """Parse command-line options, build the BEGAN model and either train or test it.

    ``--flag`` is parsed with ``distutils.util.strtobool``: a true value runs
    ``model.train``, a false value (the default ``'0'``) runs ``model.test``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("-f", "--flag", type=distutils.util.strtobool, default='0')
    # Written verbatim to CUDA_VISIBLE_DEVICES (e.g. "1" or "0,1").
    parser.add_argument("-g", "--gpu_number", type=str, default="1")
    parser.add_argument("-p", "--project", type=str, default="MRIGAN_2D_g0.3_d3")

    # Train Data
    parser.add_argument("-d", "--data_dir", type=str, default="./Data/MRI")
    parser.add_argument("-trd", "--dataset", type=str, default="HCP_MRI")
    parser.add_argument("-tro", "--data_opt", type=str, default="crop")
    parser.add_argument("-trs", "--data_size", type=int, default=256)
    # Number of neighbouring slices stacked on the channel axis of each sample.
    parser.add_argument("-ndp", "--num_depth", type=int, default=3)

    # Train Iteration
    parser.add_argument("-n", "--niter", type=int, default=200)         # number of epochs
    parser.add_argument("-ns", "--nsnapshot", type=int, default=5000)   # iterations between snapshots
    parser.add_argument("-mx", "--max_to_keep", type=int, default=5)    # checkpoints kept by the saver

    # Train Parameter
    parser.add_argument("-b", "--batch_size", type=int, default=1)
    parser.add_argument("-lr", "--learning_rate", type=float, default=1e-4)
    parser.add_argument("-m", "--momentum", type=float, default=0.5)    # Adam beta1
    parser.add_argument("-m2", "--momentum2", type=float, default=0.999)
    parser.add_argument("-gm", "--gamma", type=float, default=0.3)      # BEGAN equilibrium ratio
    parser.add_argument("-lm", "--lamda", type=float, default=0.001)    # step size of the k_t update
    parser.add_argument("-fn", "--filter_number", type=int, default=64)
    parser.add_argument("-z", "--input_size", type=int, default=256)    # latent vector size
    parser.add_argument("-em", "--embedding", type=int, default=256)    # encoder output size

    args = parser.parse_args()

    # Only the requested physical GPU(s) are exposed to this process, and CUDA
    # renumbers the visible devices from 0.  The first requested GPU is therefore
    # always '/gpu:0' here; '/gpu:<physical id>' would name a device that does
    # not exist, leaving placement entirely to allow_soft_placement.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_number

    with tf.device('/gpu:0'):
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.90)
        config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

        with tf.Session(config=config) as sess:
            model = began.BEGAN(args, sess)

            # TRAIN / TEST
            if args.flag:
                model.train(args.flag)
            else:
                model.test(args.flag)


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /src/__init__.py
# --------------------------------------------------------------------------------
## __init__.py

# --------------------------------------------------------------------------------
# /src/function/functions.py
# --------------------------------------------------------------------------------
#Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea

#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

#Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

import os
import numpy as np
import scipy.misc as scm
import nibabel as nib


def make_project_dir(project_dir):
    """Ensure ``project_dir`` and its ``models``, ``result`` and ``result_test`` sub-directories exist.

    Each sub-directory is created with ``exist_ok=True``; ``project_dir``
    itself is created as their parent. A project directory that already
    exists but is missing a sub-directory is therefore completed rather than
    skipped.
    """
    for sub_dir in ('models', 'result', 'result_test'):
        os.makedirs(os.path.join(project_dir, sub_dir), exist_ok=True)


def get_image(img_path):
    """Read an image, scale it by 1/255 and shift by -0.5, then reverse the channel order (RGB -> BGR).

    Assumes 8-bit input, giving values in [-0.5, 0.5].
    NOTE: ``scipy.misc.imread`` was removed in SciPy 1.2; this helper needs an
    older SciPy.
    """
    img = scm.imread(img_path) / 255. - 0.5
    return img[..., ::-1]  # rgb to bgr


def inverse_image(img):
    """Invert ``get_image``: map back to [0, 255] (clipped) and reverse the channel order (BGR -> RGB)."""
    img = np.clip((img + 0.5) * 255., 0, 255)
    return img[..., ::-1]  # bgr to rgb


def save_as_nii(vol, aff, save_dir):
    """Save every ``vol[i]`` as a NIfTI-1 file named ``'<save_dir>_<i>.nii'`` with affine ``aff``.

    ``save_dir`` is used as a filename prefix, not as a directory.
    """
    for i, volume in enumerate(vol):
        img = nib.Nifti1Image(dataobj=volume, affine=aff)
        nib.save(img, '{}_{}.nii'.format(save_dir, i))
    print("MRI file saved..!")

# --------------------------------------------------------------------------------
# /src/function/preprocessing.py
# --------------------------------------------------------------------------------
#Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea

#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

#Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import os
import numpy as np
import nibabel as nib
from scipy.ndimage import zoom

TARGET_DIR = '../../Data/HCP_MRI'
TARGET_FNAME = 'T1w_restore_brain.nii.gz'
# SAVE_DIR = '../../Data/MRI/HCP_MRI_256.npy'
SAVE_DIR = '/DATA_1/HCP_MRI_256.npy'
COUNT = 0  # incremented by get_mri each time a volume yields a full stack of slices
MIN = 0


def make_path_list(dir, filename):
    """Return an array of every path under ``dir`` whose basename equals ``filename``.

    Directories (by root path) and file names are visited in sorted order, so
    the result is deterministic.
    """
    pathlist = [os.path.join(root, fname)
                for root, _, fnames in sorted(os.walk(dir))
                for fname in sorted(fnames)
                if fname == filename]
    return np.asarray(pathlist)


def normalize(img):
    """Min-max scale ``img`` to [0, 1].

    A constant image (max == min) is returned as all zeros instead of the
    NaNs a division by zero would produce.
    """
    lo = np.min(img)
    hi = np.max(img)
    shifted = img - lo
    span = hi - lo
    if span == 0:
        return shifted
    return shifted / span


def rescale(vol, scale):
    """Zoom a 3-D volume by ``scale``, additionally scaling axis 1 by h/w.

    The MRI dataset shape is w > h = d, so the extra factor on axis 1 makes
    the rescaled MRI isotropic.
    """
    h, w, d = vol.shape
    vol_rs = zoom(vol, zoom=(scale, scale * float(h) / w, scale), mode='nearest')
    return vol_rs


def get_mri(data_path, max_slices=257):
    """Load a NIfTI volume and return its non-empty slices along axis 1 stacked on axis 2.

    The volume is cropped by 2 voxels on both ends of axes 0 and 2. Each
    slice ``data[:, s, :]`` that contains any non-zero voxel is transposed,
    flipped vertically and min-max normalised. Collection stops once
    ``max_slices`` slices are gathered; in that case the global ``COUNT`` is
    incremented and printed. Fewer slices are returned if the volume runs
    out first.

    Returns:
        float32 array of shape ``(axis2 - 4, axis0 - 4, n_slices)``.

    Raises:
        ValueError: if the volume has no non-empty slice.
    """
    global COUNT
    proxy_img = nib.load(data_path)
    data_array = np.asarray(proxy_img.dataobj).astype(np.float32)
    data_array = data_array[2:-2, :, 2:-2]

    # Collect slices in a list and stack once: concatenating inside the loop
    # copies the whole stack on every iteration.
    slices = []
    for s in range(data_array.shape[1]):
        _slice = data_array[:, s, :]
        if np.count_nonzero(_slice) == 0:
            continue
        slices.append(normalize(_slice.T[::-1, :]))
        if len(slices) == max_slices:
            COUNT += 1
            print(COUNT)
            break

    if not slices:
        raise ValueError('No non-zero slice found in {}'.format(data_path))
    return np.stack(slices, axis=2)


def print_aff(data_path):
    """Print the affine matrix stored in the NIfTI file at ``data_path``."""
    proxy_img = nib.load(data_path)
    print(proxy_img.affine)


def get_aff(dir=TARGET_DIR, fname=TARGET_FNAME):
    """Return the affine of the first (sorted) file named ``fname`` under ``dir``."""
    f_list = make_path_list(dir, fname)
    proxy_img = nib.load(f_list[0])
    return proxy_img.affine


def preprocessing(data_dir=TARGET_DIR, save_dir=SAVE_DIR, fname=TARGET_FNAME):
    """Run ``get_mri`` on every ``fname`` under ``data_dir`` and save the stacked float32 array to ``save_dir``.

    NOTE(review): volumes that yield fewer than 257 slices give a ragged
    list, which ``np.asarray(...).astype(np.float32)`` cannot convert.
    """
    print("Preprocessing Start")
    f_list = make_path_list(data_dir, fname)
    concat_mri = [get_mri(path) for path in f_list]
    concat_mri = np.asarray(concat_mri).astype(np.float32)
    np.save(save_dir, concat_mri)
    print("Concatenation Done")


def get_min_nonzero_slice(data_dir=TARGET_DIR, fname=TARGET_FNAME):
    """Return the minimum, over all matching volumes, of the number of axis-1 slices with a non-zero maximum.

    The count starts from 10000 and is printed after each volume. Unlike
    ``get_mri``, the volumes are not cropped first.
    """
    f_list = make_path_list(data_dir, fname)
    m_count = 10000
    for i, path in enumerate(f_list, start=1):
        data = np.asarray(nib.load(path).dataobj)
        data = np.transpose(data, (0, 2, 1))
        # One column per axis-1 slice after flattening the other two axes.
        data = data.reshape((-1, data.shape[-1]))
        slice_max = np.max(data, axis=0)
        count = np.count_nonzero(slice_max)
        m_count = min(m_count, count)
        print(i, ":")
        print(m_count)
    print(m_count)
    return m_count


if __name__ == "__main__":
    preprocessing()
    # get_min_nonzero_slice()

# --------------------------------------------------------------------------------
# /src/layer/layers.py
# --------------------------------------------------------------------------------
#Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea

#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

#Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import tensorflow as tf
import numpy as np


def conv2d(x, filter_shape, bias=True, stride=1, padding="SAME", name="conv2d"):
    """2-D convolution with N(0, 0.02)-initialised weights and an optional zero-initialised bias.

    Args:
        x: NHWC input tensor.
        filter_shape: ``[kernel_h, kernel_w, in_channels, out_channels]``, the
            filter layout ``tf.nn.conv2d`` expects.
        bias: add a per-output-channel bias when True.
        stride: stride used for both spatial dimensions.
        padding: ``"SAME"`` or ``"VALID"``. For ``"VALID"`` the input is first
            symmetrically (mirror) padded by ``(k - 1) // 2`` per side, so an
            odd kernel with stride 1 keeps the spatial size without zero padding.
        name: variable scope holding ``weight`` (and ``bias``).
    """
    kh, kw = filter_shape[0], filter_shape[1]

    if padding == "VALID":
        # Integer division: tf.pad needs integer paddings, while `/` yields a
        # float under Python 3. Height and width are padded independently.
        pad_h = (kh - 1) // 2
        pad_w = (kw - 1) // 2
        x = tf.pad(x, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]], "SYMMETRIC")

    initializer = tf.random_normal_initializer(0., 0.02)
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", shape=filter_shape, initializer=initializer)
        x = tf.nn.conv2d(x, weight, [1, stride, stride, 1], padding=padding)

        if bias:
            b = tf.get_variable("bias", shape=[filter_shape[-1]], initializer=tf.constant_initializer(0.))
            x = tf.nn.bias_add(x, b)
    return x


def fc(x, output_shape, bias=True, name='fc'):
    """Fully connected layer: flatten all non-batch dims of ``x`` and project to ``output_shape`` units.

    The non-batch dimensions of ``x`` must be statically known.
    """
    shape = x.get_shape().as_list()
    dim = int(np.prod(shape[1:]))
    x = tf.reshape(x, [-1, dim])

    initializer = tf.random_normal_initializer(0., 0.02)
    with tf.variable_scope(name):
        weight = tf.get_variable("weight", shape=[dim, output_shape], initializer=initializer)
        x = tf.matmul(x, weight)

        if bias:
            b = tf.get_variable("bias", shape=[output_shape], initializer=tf.constant_initializer(0.))
            x = tf.nn.bias_add(x, b)
    return x


def pool(x, r=2, s=1):
    """Average pooling with an ``r`` x ``r`` window, stride ``s`` and SAME padding."""
    return tf.nn.avg_pool(x, ksize=[1, r, r, 1], strides=[1, s, s, 1], padding="SAME")


def l1_loss(x, y):
    """Mean absolute difference between ``x`` and ``y``."""
    return tf.reduce_mean(tf.abs(x - y))


def resize_nn(x, size):
    """Nearest-neighbour resize of ``x`` to ``size`` x ``size`` (``size`` is truncated to int)."""
    return tf.image.resize_nearest_neighbor(x, size=(int(size), int(size)))

# --------------------------------------------------------------------------------
# /src/models/BEGAN.py
# --------------------------------------------------------------------------------
#Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea

#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

#Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

from src.layer.layers import *
from src.operator.op_BEGAN import Operator


class BEGAN(Operator):
    """BEGAN networks: a generator and an auto-encoder discriminator (encoder + decoder).

    Graph assembly and the training loop live in ``Operator``; this class only
    defines the three sub-networks. Generator variables live under scope
    ``gen_``; encoder and decoder variables share scope ``disc_``.

    The stage layout is chosen for ``data_size`` 64, 128 and 256. Other sizes
    use the three-stage layout of 64.
    """

    def __init__(self, args, sess):
        Operator.__init__(self, args, sess)

    def generator(self, x, reuse=None):
        """Map latent codes ``x`` to ``num_depth``-channel images of side ``data_size`` (scope ``gen_``)."""
        return self._upsampling_net(x, 'gen_', reuse)

    def encoder(self, x, reuse=None):
        """Encode images ``x`` to an ``embedding``-dimensional vector (scope ``disc_``).

        Each downsampling stage ``k`` applies a 1x1 conv from ``(k-2)*f`` to
        ``(k-1)*f`` channels, 2x2 average pooling with stride 2, and two 3x3
        conv + ELU layers. Stages 3-5 always run, stage 6 runs for
        ``data_size`` 128/256, and stage 7 runs for 256.
        """
        with tf.variable_scope('disc_') as scope:
            if reuse:
                scope.reuse_variables()

            f = self.filter_number
            v = self.num_depth
            p = "SAME"

            x = tf.nn.elu(conv2d(x, [3, 3, v, f], stride=1, padding=p, name='conv1_enc_a'))
            x = tf.nn.elu(conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_enc_a'))
            x = tf.nn.elu(conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv2_enc_b'))

            stages = [3, 4, 5]
            if self.data_size in (128, 256):
                stages.append(6)
            if self.data_size == 256:
                stages.append(7)

            for k in stages:
                c_in, c_out = (k - 2) * f, (k - 1) * f
                x = conv2d(x, [1, 1, c_in, c_out], stride=1, padding=p, name='conv%d_enc_0' % k)
                x = pool(x, r=2, s=2)
                x = tf.nn.elu(conv2d(x, [3, 3, c_out, c_out], stride=1, padding=p, name='conv%d_enc_a' % k))
                x = tf.nn.elu(conv2d(x, [3, 3, c_out, c_out], stride=1, padding=p, name='conv%d_enc_b' % k))

            return fc(x, self.embedding, name='enc_fc')

    def decoder(self, x, reuse=None):
        """Decode vectors ``x`` back to images (same architecture as the generator, scope ``disc_``)."""
        return self._upsampling_net(x, 'disc_', reuse)

    def _upsampling_net(self, x, scope_name, reuse):
        """Shared generator/decoder body: fc to an 8x8xf map, then nearest-neighbour upsampling to ``data_size``.

        Variable names (``fc``, ``conv1_a`` ... ``conv7_a``) match the
        original per-network code, so existing checkpoints still load.
        """
        with tf.variable_scope(scope_name) as scope:
            if reuse:
                scope.reuse_variables()

            w = self.data_size
            f = self.filter_number
            v = self.num_depth
            p = "SAME"

            x = fc(x, 8 * 8 * f, name='fc')
            x = tf.reshape(x, [-1, 8, 8, f])

            x = tf.nn.elu(conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_a'))
            x = tf.nn.elu(conv2d(x, [3, 3, f, f], stride=1, padding=p, name='conv1_b'))

            # (layer name prefix, target spatial size) for each upsampling stage
            stages = []
            if w == 256:
                stages.append(('conv2', w // 16))
            if w in (128, 256):
                stages.append(('conv3', w // 8))
            stages += [('conv4', w // 4), ('conv5', w // 2), ('conv6', w)]

            for prefix, size in stages:
                x = resize_nn(x, size)
                x = tf.nn.elu(conv2d(x, [3, 3, f, f], stride=1, padding=p, name=prefix + '_a'))
                x = tf.nn.elu(conv2d(x, [3, 3, f, f], stride=1, padding=p, name=prefix + '_b'))

            return conv2d(x, [3, 3, f, v], stride=1, padding=p, name='conv7_a')

# --------------------------------------------------------------------------------
# /src/operator/op_BEGAN.py
# --------------------------------------------------------------------------------
#Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea

#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

#Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import time
import datetime
import numpy as np
import tensorflow as tf
from src.layer.layers import *
from src.function.functions import *
from src.function.preprocessing import *
from src.operator.op_base import op_base


class Operator(op_base):
    """Builds the BEGAN training graph and runs the train / test loops.

    Subclasses must provide ``generator``, ``encoder`` and ``decoder``.
    """

    def __init__(self, args, sess):
        op_base.__init__(self, args, sess)
        self.build_model()

    def build_model(self):
        """Create placeholders, losses, Adam optimizers, the saver and (when training) summaries."""
        # Input placeholders: latent codes, real slabs, BEGAN k_t and learning rate.
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.input_size], name='x')
        self.y = tf.placeholder(tf.float32, shape=[self.batch_size, self.data_size, self.data_size, self.num_depth], name='y')
        self.kt = tf.placeholder(tf.float32, name='kt')
        self.lr = tf.placeholder(tf.float32, name='lr')

        # Generator
        self.recon_gen = self.generator(self.x)

        # Discriminator (auto-encoder critic).
        # self.aaaaaa is the encoder embedding of the real batch; the attribute
        # name is kept unchanged for existing callers.
        self.aaaaaa = self.encoder(self.y)
        d_real = self.decoder(self.aaaaaa)
        d_fake = self.decoder(self.encoder(self.recon_gen, reuse=True), reuse=True)
        # The decoder is reused directly on latent codes, so its input width
        # (input_size) must equal the encoder embedding size.
        self.recon_dec = self.decoder(self.x, reuse=True)

        # Loss
        self.d_real_loss = l1_loss(self.y, d_real)
        self.d_fake_loss = l1_loss(self.recon_gen, d_fake)
        self.d_loss = self.d_real_loss - self.kt * self.d_fake_loss
        self.g_loss = self.d_fake_loss
        self.m_global = self.d_real_loss + tf.abs(self.gamma * self.d_real_loss - self.d_fake_loss)

        # Variables (selected by scope prefix)
        g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "gen_")
        d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, "disc_")

        # Optimizer: --momentum / --momentum2 are Adam's beta1 / beta2.
        self.opt_g = tf.train.AdamOptimizer(self.lr, self.mm, self.mm2).minimize(self.g_loss, var_list=g_vars)
        self.opt_d = tf.train.AdamOptimizer(self.lr, self.mm, self.mm2).minimize(self.d_loss, var_list=d_vars)

        # initializer
        self.sess.run(tf.global_variables_initializer())

        # tf saver
        self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)

        if tf.train.get_checkpoint_state(self.ckpt_dir) is None:
            # No checkpoint yet: save the freshly initialised graph (with meta graph).
            self.saver.save(self.sess, self.ckpt_model_name, write_meta_graph=True)
        else:
            # Resume. Restore errors propagate instead of being masked by
            # re-saving fresh weights over the existing checkpoint.
            self.load(self.sess, self.saver, self.ckpt_dir)

        # Summary (only needed, and only created, when training)
        if self.flag:
            tf.summary.scalar('loss/loss', self.d_loss + self.g_loss)
            tf.summary.scalar('loss/g_loss', self.g_loss)
            tf.summary.scalar('loss/d_loss', self.d_loss)
            tf.summary.scalar('loss/d_real_loss', self.d_real_loss)
            tf.summary.scalar('loss/d_fake_loss', self.d_fake_loss)
            tf.summary.scalar('misc/kt', self.kt)
            tf.summary.scalar('misc/m_global', self.m_global)
            self.merged = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(self.project_dir, self.sess.graph)

    def train(self, train_flag):
        """Train for ``self.niter`` epochs over every (volume, slice) position of ``self.train_data``.

        Each sample is a slab of ``num_depth`` consecutive slices taken from
        ``train_data[vol, :, :, ...]`` (see ``_get_slab``). The learning rate
        is decayed by 0.95, a checkpoint is saved and ``test`` is run after
        the first iteration and then every ``niter_snapshot`` iterations.
        """
        # load data
        train_data = self.train_data
        print('Shuffle ....')
        num_vol = train_data.shape[0]
        num_sli = train_data.shape[3]
        vol_grid, sli_grid = np.meshgrid(np.arange(num_vol), np.arange(num_sli))
        positions = np.column_stack([vol_grid.flat, sli_grid.flat])  # rows of (volume, slice)
        data_length = len(positions)
        # Shuffled once: every epoch visits the samples in the same order.
        random_order = np.random.permutation(data_length)
        print('Shuffle Done')

        # initial parameter
        start_time = time.time()
        kt = np.float32(0.)
        lr = np.float32(self.learning_rate)
        self.count = 0
        batch_idxs = data_length // self.batch_size  # a trailing partial batch is dropped

        # opt & fetch lists (different with paper); they do not change per iteration
        g_opt = [self.opt_g, self.g_loss, self.d_real_loss, self.d_fake_loss]
        d_opt = [self.opt_d, self.d_loss, self.merged]

        for epoch in range(self.niter):
            for idx in range(batch_idxs):
                self.count += 1

                batch_x = np.random.uniform(-1., 1., size=[self.batch_size, self.input_size])
                batch_order = random_order[idx * self.batch_size:(idx + 1) * self.batch_size]
                batch_data = [self._get_slab(train_data, *positions[k]) for k in batch_order]

                feed_dict = {self.x: batch_x, self.y: batch_data, self.kt: kt, self.lr: lr}

                # run tensorflow
                _, loss_g, d_real_loss, d_fake_loss = self.sess.run(g_opt, feed_dict=feed_dict)
                _, loss_d, summary = self.sess.run(d_opt, feed_dict=feed_dict)

                # BEGAN equilibrium control: k_t += lambda * (gamma * L_real - L_fake), clipped to [0, 1]
                kt = np.maximum(np.minimum(1., kt + self.lamda * (self.gamma * d_real_loss - d_fake_loss)), 0.)
                m_global = d_real_loss + np.abs(self.gamma * d_real_loss - d_fake_loss)
                loss = loss_g + loss_d

                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, "
                      "loss: %.4f, loss_g: %.4f, loss_d: %.4f, d_real: %.4f, d_fake: %.4f, kt: %.8f, M: %.8f"
                      % (epoch, idx, batch_idxs, time.time() - start_time,
                         loss, loss_g, loss_d, d_real_loss, d_fake_loss, kt, m_global))

                # write train summary
                self.writer.add_summary(summary, self.count)

                # Test during Training
                if (self.count % self.niter_snapshot == (self.niter_snapshot - 1)) or (self.count == 1):
                    # update learning rate
                    lr *= 0.95
                    # save & test
                    self.saver.save(self.sess, self.ckpt_model_name, global_step=self.count, write_meta_graph=False)
                    self.test(train_flag)

    def _get_slab(self, train_data, vol_idx, slice_idx):
        """Return ``num_depth`` consecutive slices of volume ``vol_idx`` around ``slice_idx``.

        The window starts ``(num_depth - 1) // 2`` slices before
        ``slice_idx``. It is shifted inward at either end of the slice axis,
        so it always holds exactly ``num_depth`` slices, provided the volume
        has at least that many. Bounds come from the real slice count
        ``train_data.shape[3]``.
        """
        num_sli = train_data.shape[3]
        side_depth = (self.num_depth - 1) // 2
        start = max(min(slice_idx - side_depth, num_sli - self.num_depth), 0)
        return train_data[vol_idx, :, :, start:start + self.num_depth]

    def test(self, train_flag=True):
        """Sample random latent codes and save the centre slice of the outputs as a tiled BMP image."""
        print("tesing..")
        self._sample_and_save(train_flag)

    def get_latent(self, train_flag=True):
        """Load ``<data_dir>/Test_data/test.nii.gz`` with ``get_mri``, then sample and save images like ``test``.

        TODO(review): the loaded volume is not used yet; the saved images come
        from random latent codes, exactly as in ``test``.
        """
        print("latent_tesing..")
        test_data_path = self.data_dir + '/Test_data/test.nii.gz'
        test_volume = get_mri(test_data_path)  # noqa: F841 -- unused, see TODO in docstring
        self._sample_and_save(train_flag)

    def _sample_and_save(self, train_flag):
        """Run generator and decoder on one batch of U(-1, 1) latent codes and save tiled centre slices.

        During training (``train_flag`` truthy) only the generator image is
        saved, to ``result/<count>_output.bmp``. Otherwise both generator and
        decoder images go to ``result_test/`` with a timestamped name.
        """
        img_num = self.batch_size
        img_size = self.data_size
        output_f = int(np.sqrt(img_num))  # tiles per side; only output_f**2 samples are shown

        test_data = np.random.uniform(-1., 1., size=[img_num, self.input_size])
        output_gen = self.sess.run(self.recon_gen, feed_dict={self.x: test_data})  # generator output
        output_dec = self.sess.run(self.recon_dec, feed_dict={self.x: test_data})  # decoder output

        center = self.num_depth // 2
        im_output_gen = self._tile(output_gen[:, :, :, center], output_f, img_size)
        im_output_dec = self._tile(output_dec[:, :, :, center], output_f, img_size)

        # output save
        if train_flag:
            scm.imsave(self.project_dir + '/result/' + str(self.count) + '_output.bmp', im_output_gen)
        else:
            now = datetime.datetime.now()
            nowDatetime = now.strftime('%Y-%m-%d_%H:%M:%S')
            scm.imsave(self.project_dir + '/result_test/gen_{}_output.bmp'.format(nowDatetime), im_output_gen)
            scm.imsave(self.project_dir + '/result_test/dec_{}_output.bmp'.format(nowDatetime), im_output_dec)

    @staticmethod
    def _tile(slices, grid, size):
        """Arrange ``slices[0 : grid*grid]`` row-major on a ``grid`` x ``grid`` canvas of ``size``-pixel tiles."""
        canvas = np.zeros([size * grid, size * grid])
        for i in range(grid):
            for j in range(grid):
                canvas[i * size:(i + 1) * size, j * size:(j + 1) * size] = slices[j + i * grid]
        return canvas

# --------------------------------------------------------------------------------
# /src/operator/op_base.py
# --------------------------------------------------------------------------------
#Copyright 2018 UNIST under XAI Project supported by Ministry of Science and ICT, Korea

#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

#Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
import glob
import os
import time
import numpy as np
import tensorflow as tf
from src.function.functions import *


class op_base:
    """Base class: copies run options onto the instance, memory-maps the training array and prepares output paths.

    Args:
        args: ``argparse`` namespace with the options defined in ``main.py``.
        sess: ``tf.Session`` that subclasses use to build and run the graph.
    """

    def __init__(self, args, sess):
        self.sess = sess

        # Train
        self.flag = args.flag  # truthy -> training run
        self.gpu_number = args.gpu_number
        self.project = args.project

        # Train Data
        self.data_dir = args.data_dir  # ./Data
        self.dataset = args.dataset  # HCP_MRI
        self.data_size = args.data_size  # 256
        self.data_opt = args.data_opt  # raw or crop
        # Read-only memory map of '<data_dir>/<dataset>_<data_size>.npy'
        self.train_data = np.load('{0}/{1}_{2}.npy'.format(self.data_dir, self.dataset, self.data_size), mmap_mode='r')
        self.num_depth = args.num_depth  # 3

        # Train Iteration
        self.niter = args.niter
        self.niter_snapshot = args.nsnapshot
        self.max_to_keep = args.max_to_keep

        # Train Parameter
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.mm = args.momentum
        self.mm2 = args.momentum2
        self.lamda = args.lamda
        self.gamma = args.gamma
        self.filter_number = args.filter_number
        self.input_size = args.input_size
        self.embedding = args.embedding

        # Result Dir & File
        self.project_dir = 'assets/{0}_{1}_{2}_{3}/'.format(self.project, self.dataset, self.data_opt, self.data_size)
        self.ckpt_dir = os.path.join(self.project_dir, 'models')
        self.model_name = "{0}.model".format(self.project)
        self.ckpt_model_name = os.path.join(self.ckpt_dir, self.model_name)

        # etc.
        os.makedirs('assets', exist_ok=True)
        make_project_dir(self.project_dir)

    def load(self, sess, saver, ckpt_dir):
        """Restore the latest checkpoint recorded in ``ckpt_dir`` into ``sess``.

        Raises:
            FileNotFoundError: if ``ckpt_dir`` has no checkpoint state.
        """
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError('No checkpoint found in {0}'.format(ckpt_dir))
        # Re-root the recorded path in ckpt_dir so a moved project dir still loads.
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        saver.restore(sess, os.path.join(ckpt_dir, ckpt_name))

# --------------------------------------------------------------------------------