├── .gitignore
├── README.md
├── checkpoint
    └── checkpoint path for Tensorflow
├── data
    └── test
    │   ├── test_data_cranial.mat
    │   └── test_data_pleural.mat
├── result
    ├── result_FCN.mat
    ├── result_itertive_decomposition.mat
    └── result_matrix_inversion.mat
└── src
    ├── AXfunc_pwls.m
    ├── get_weight.m
    ├── gradient_LS.m
    ├── iterative_decomposition.m
    ├── layer_xyf.py
    ├── main.py
    ├── matrix_inversion.m
    ├── model.py
    └── solver.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ImageDecomposition-DECT

This repository is organized mainly around an image decomposition algorithm proposed to solve the material decomposition problem in dual-energy computed tomography (DECT).
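As a minimal illustration of the decomposition problem, the sketch below assumes the two-material linear model implied by the coefficients in src/matrix_inversion.m: each high-energy (I_H) and low-energy (I_L) pixel is a linear combination of a bone and a tissue component, and the direct method inverts that 2×2 system pixel by pixel. The function name `decompose_direct` and the attenuation values are hypothetical placeholders, not values taken from the test .mat files.

```python
import numpy as np

def decompose_direct(I_H, I_L, mu_bone_high, mu_tissue_high, mu_bone_low, mu_tissue_low):
    """Pixel-wise inversion of the assumed 2x2 model
        I_H = mu_bone_high * bone + mu_tissue_high * tissue
        I_L = mu_bone_low  * bone + mu_tissue_low  * tissue
    using the same coefficients (yita_11 .. yita_22) as matrix_inversion.m."""
    det = mu_bone_high * mu_tissue_low - mu_tissue_high * mu_bone_low
    bone = (mu_tissue_low * I_H - mu_tissue_high * I_L) / det
    tissue = (-mu_bone_low * I_H + mu_bone_high * I_L) / det
    return bone, tissue

# Hypothetical attenuation coefficients, for illustration only.
mu = dict(mu_bone_high=0.30, mu_tissue_high=0.20, mu_bone_low=0.50, mu_tissue_low=0.25)

# Synthesize dual-energy images from known material maps, then invert.
rng = np.random.default_rng(0)
bone_true = rng.random((4, 4))
tissue_true = rng.random((4, 4))
I_H = mu["mu_bone_high"] * bone_true + mu["mu_tissue_high"] * tissue_true
I_L = mu["mu_bone_low"] * bone_true + mu["mu_tissue_low"] * tissue_true

bone, tissue = decompose_direct(I_H, I_L, **mu)
print(np.allclose(bone, bone_true), np.allclose(tissue, tissue_true))
```

With noise-free inputs the inversion is exact; on real scans this direct approach amplifies noise, which is what the iterative and learned methods below try to address.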

The algorithm is based on the deep learning paradigm. For more theoretical background, see [Deep Learning](http://www.deeplearningbook.org/) and [Material Decomposition Using DECT](https://pubs.rsna.org/doi/10.1148/rg.2016150220).

The algorithm is described in the paper ["Image Decomposition Algorithm for Dual-Energy Computed Tomography via Fully Convolutional Network"](https://www.hindawi.com/journals/cmmm/2018/2527516/cta/) (DOI: 10.1155/2018/2527516).

All code has been tested with Python 3.6 and TensorFlow 1.4.0 on Linux.
* checkpoint: the checkpoint directory for the model trained with TensorFlow. The [pre-trained model](https://pan.baidu.com/s/1r1OTjid2muWWZfxURB8Pjw) was trained on a dataset of 2,454,300 samples in total; each sample is a 65×65 image patch, extracted from 5,987 image slices.
* data: contains two subdirectories.
    * test: two test data files, 'test_data_cranial.mat' and 'test_data_pleural.mat'.
    * train: only a subset (90,000 training samples) is provided, in the file 'training_samples_90000.rar', which can be downloaded from [here](https://pan.baidu.com/s/1r1OTjid2muWWZfxURB8Pjw).
* result: stores the decomposition results.
* src: the code for the three decomposition algorithms:
    * Direct matrix inversion (matrix_inversion.m)
    * Iterative decomposition (iterative_decomposition.m). Related paper: [Iterative image-domain decomposition for dual-energy CT](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1118/1.4889338)
    * The proposed deep model (main.py). After downloading the pre-trained model, you can run the algorithm with the following command:
      >> `python main.py --dataset="../data/test/test_data_cranial.mat" --mode="feedforward" --model_name="your-saved-result-name" --checkpoint="../checkpoint/FCN_trained_model"`
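The layer sizes noted in src/model.py follow TensorFlow's documented output-size rules for `tf.nn.conv2d`: with "SAME" padding the output is ceil(n / stride), with "VALID" it is ceil((n - k + 1) / stride). The sketch below applies those rules to show why a 65×65 training patch collapses to a single 1×1 output per material, and what the stride-1 `feedforward_test` graph produces for a whole slice (a hypothetical 512×512 input here). The helpers `conv_out` and `trace` are illustrative only, not part of the repository.

```python
import math

def conv_out(n, k, stride, padding):
    """Spatial output size of tf.nn.conv2d along one axis."""
    if padding == "SAME":
        return math.ceil(n / stride)
    if padding == "VALID":
        return math.ceil((n - k + 1) / stride)
    raise ValueError("unknown padding: " + padding)

# (kernel, stride, padding) of layers 1-5 in DecompositionModel.feedforward
TRAIN_LAYERS = [(5, 2, "SAME"), (5, 2, "SAME"), (5, 2, "SAME"), (9, 1, "VALID"), (1, 1, "VALID")]
# feedforward_test: layers 1-3 use stride 1, the last layer uses "SAME"
TEST_LAYERS = [(5, 1, "SAME"), (5, 1, "SAME"), (5, 1, "SAME"), (9, 1, "VALID"), (1, 1, "SAME")]

def trace(n, layers):
    """Return the spatial size before the first layer and after each layer."""
    sizes = [n]
    for k, s, p in layers:
        n = conv_out(n, k, s, p)
        sizes.append(n)
    return sizes

print(trace(65, TRAIN_LAYERS))   # [65, 33, 17, 9, 1, 1]
print(trace(512, TEST_LAYERS))   # [512, 512, 512, 512, 504, 504]
```

Because layer 4 keeps its 9×9 "VALID" convolution at test time, each output map is 8 pixels smaller than the input slice in each dimension.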

# Contact
Email: vastcyclone@yeah.net

--------------------------------------------------------------------------------
/checkpoint/checkpoint path for Tensorflow:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data/test/test_data_cranial.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/data/test/test_data_cranial.mat
--------------------------------------------------------------------------------
/data/test/test_data_pleural.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/data/test/test_data_pleural.mat
--------------------------------------------------------------------------------
/result/result_FCN.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/result/result_FCN.mat
--------------------------------------------------------------------------------
/result/result_itertive_decomposition.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/result/result_itertive_decomposition.mat
--------------------------------------------------------------------------------
/result/result_matrix_inversion.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/result/result_matrix_inversion.mat
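The iterative decomposition in src/iterative_decomposition.m passes MATLAB's `pcg` a function handle (`AXfunc_pwls`) instead of an explicit matrix, so the per-slice system (AᵀA + R)x = Aᵀy is solved matrix-free. As a rough illustration of that pattern only, here is a minimal conjugate-gradient loop in NumPy that takes a callable. The toy mixing matrix, the 1-D first-difference penalty, and the names `conjugate_gradient` and `apply_pwls` are hypothetical stand-ins, not the repository's edge-weighted `gradient_LS` term.

```python
import numpy as np

def conjugate_gradient(apply_A, b, x0, tol=1e-10, max_iter=500):
    """Solve apply_A(x) = b for a symmetric positive-definite operator that is
    only available as a matrix-vector product function."""
    x = np.array(x0, dtype=float)          # copy of the initial guess
    r = b - apply_A(x)
    p = r.copy()
    rs_old = r @ r
    b_norm = np.linalg.norm(b)
    if np.sqrt(rs_old) <= tol * b_norm:
        return x
    for _ in range(max_iter):
        Ap = apply_A(p)
        alpha = rs_old / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * b_norm:  # relative residual test, as in pcg's tol
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

# Hypothetical toy problem: a 2x2 mixing matrix applied pixel-wise to the
# stacked [bone; tissue] vector (n pixels each) plus a quadratic smoothness term.
n = 8
M = np.array([[0.30, 0.20], [0.50, 0.25]])
A = np.kron(M, np.eye(n))                     # dense stand-in for kron(A, speye(h*w))
D = np.eye(n) - np.eye(n, k=-1)               # differences to the previous pixel
R = np.kron(np.diag([1e-3, 7e-3]), D.T @ D)   # beta2 = 7 * beta1, as in the script

def apply_pwls(x):
    # analogue of AXfunc_pwls: At*(A*x) + penalty term
    return A.T @ (A @ x) + R @ x

rng = np.random.default_rng(0)
y = A @ rng.random(2 * n)
x = conjugate_gradient(apply_pwls, A.T @ y, np.zeros(2 * n))
print(np.allclose(apply_pwls(x), A.T @ y, atol=1e-8))
```

Only the operator changes between this toy and the MATLAB script; the solver never needs AᵀA + R as an explicit matrix, which matters when each material image has h·w unknowns.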
--------------------------------------------------------------------------------
/src/AXfunc_pwls.m:
--------------------------------------------------------------------------------
function [y] = AXfunc_pwls(x,A,At,weight,beta1,beta2)
%%% description:
% Matrix-vector product passed to pcg in iterative_decomposition.m.
% x stacks the vectorized bone image (first half) and tissue image
% (second half); the result is y = At*(A*x) plus the smoothness term from
% gradient_LS, scaled by beta1 for bone and by beta2 for tissue.

% qCGMRF weighting
% r = [beta1*gradient_qCGMRF(x(1:end/2), weight) ...
%      beta2*gradient_qCGMRF(x(end/2+1:end), weight)];

% quadratic weighting
r = [beta1*gradient_LS(x(1:end/2), weight) ...
     beta2*gradient_LS(x(end/2+1:end), weight)];

r = r(:);
y = (At*(A*x)) + r;

--------------------------------------------------------------------------------
/src/get_weight.m:
--------------------------------------------------------------------------------
function weight = get_weight(img, h, w)
% Builds the h x w x 4 neighbour weights used by gradient_LS.
% img is the h x 2w image [I_H I_L]. Weights drop to Lvalue wherever the
% combined edge map changes between a pixel and its neighbour in that
% direction, so less smoothing is applied across edges.
weight = ones([h, w, 4]);
std_1 = std(img(1:end/2));
std_2 = std(img(end/2 + 1:end));
std_ratio = std_1/std_2;   % scale I_L to the standard deviation of I_H

BW = edge(img(:,1:end/2) + img(:,end/2 + 1:end)*std_ratio,'canny',0.1,1.5);
BW_2 = edge(img(:,1:end/2) + img(:,end/2 + 1:end)*std_ratio,'prewitt',0.1,'both');
BW = BW + BW_2;

dummy = BW;
g1_d2 = dummy - [dummy(2:end,:); zeros(1, size(dummy,2))];
g1_d1 = dummy - [zeros(1, size(dummy,2)); dummy(1:end - 1,:)];
g1_d0 = dummy - [zeros(size(dummy,1),1) dummy(:,1:end - 1)];
g1_d3 = dummy - [dummy(:,2:end) zeros(size(dummy,1),1)];

thresh = 0.0001;
Lvalue = 0.1;
dummy = ones([h w]);
dummy(abs(g1_d0) > thresh) = Lvalue;
weight(:,:,1) = dummy;
dummy = ones([h w]);
dummy(abs(g1_d1) > thresh) = Lvalue;
weight(:,:,2) = dummy;
dummy = ones([h w]);
dummy(abs(g1_d2) > thresh) = Lvalue;
weight(:,:,3) = dummy;
dummy = ones([h w]);
dummy(abs(g1_d3) > thresh) = Lvalue;
weight(:,:,4) = dummy;
end
--------------------------------------------------------------------------------
/src/gradient_LS.m:
--------------------------------------------------------------------------------
function gf = gradient_LS(x,weight)
% Weighted quadratic smoothness term used by AXfunc_pwls.
% x is a vectorized image of the same size as weight(:,:,1); each term is
% the difference to one of the 4 neighbours, scaled by its weight.

%%
% g1_d2 = dummy-[dummy(2:end,:); zeros(1, size(dummy,1))];
% g1_d1 = dummy-[zeros(1, size(dummy,1)); dummy(1:end-1,:)];
% g1_d0 = dummy-[zeros(size(dummy,1),1) dummy(:,1:end-1)];
% g1_d3 = dummy-[dummy(:,2:end) zeros(size(dummy,1),1)];
%%
[row,col,~]=size(weight);
xf = reshape(x,row,col);

xf_d0 = xf-[zeros(size(xf,1),1) xf(:,1:end-1)];
xf_d1 = xf-[zeros(1, size(xf,2)); xf(1:end-1,:)];
xf_d2 = xf-[xf(2:end,:); zeros(1, size(xf,2))];
xf_d3 = xf-[xf(:,2:end) zeros(size(xf,1),1)];

gf = 4*(xf_d0.*weight(:,:,1)+xf_d1.*weight(:,:,2)+xf_d2.*weight(:,:,3)+xf_d3.*weight(:,:,4));
--------------------------------------------------------------------------------
/src/iterative_decomposition.m:
--------------------------------------------------------------------------------
close all;
clear all;
clc;

% paths are relative to src/
test_data_name = '../data/test/test_data_cranial.mat';

resultPath = '../result/';
result_name = [resultPath, 'result_itertive_decomposition.mat'];

disp('Loading test data...');
load(test_data_name);

%%
mu_ = mu_bone_high*mu_tissue_low - mu_tissue_high*mu_bone_low;
a = mu_tissue_low / mu_;
b = -mu_tissue_high / mu_;
c = -mu_bone_low / mu_;
d = mu_bone_high / mu_;

%%
pcgmaxi = 500;
pcgtol = 1e-12;
beta1 = 2e-6; %1;
beta2 = beta1*7;

[h, w, slice] = size(I_L);

A = [a b; c d];
A = inv(A);
A = kron(A,speye(h*w));
At = A';

I_bone = zeros([h, w, slice], 'single');
I_tissue = zeros([h, w, slice], 'single');

img = zeros(h, 2*w);
img_d = zeros(h, 2*w);
for i = 1:slice
    disp(['Decomposing image ', num2str(i), '/', num2str(slice)]);
    img(:, 1:end/2) = I_H(:,:,i);
    img(:, (end/2 + 1):end) = I_L(:,:,i);
    % direct-inversion result, used as the initial guess for pcg
    img_d = [a*img(:,1:end/2) + b*img(:,end/2+1:end) c*img(:,1:end/2) + d*img(:,end/2+1:end)];
    x = img_d(:);
    data = A'*img(:);
    ratio = max(data(:));
    data = data / ratio;
    weight = get_weight(img, h, w);
    [x_1, flag, relres, iter, rv] = pcg(@AXfunc_pwls,data,pcgtol,pcgmaxi,[], ...
        [],x,A,At,weight,beta1,beta2);
    x_1 = reshape(x_1, [h, 2*w])*ratio;

    I_bone(:,:,i) = x_1(:, 1:end/2);
    I_tissue(:,:,i) = x_1(:, (end/2 + 1):end);
end

I_bone(I_bone < 0.00001) = 0;
I_tissue(I_tissue < 0.00001) = 0;

% imshow expects a 2-D (or RGB) image, so display the first slice
figure(1), imshow(I_bone(:,:,1), []);
figure(2), imshow(I_tissue(:,:,1), []);

disp(['Saving result data... ', result_name]);
save(result_name, 'I_bone', 'I_tissue');

--------------------------------------------------------------------------------
/src/layer_xyf.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import tensorflow as tf

def variable_summaries(var, var_name):
    with tf.name_scope(var_name):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('mean', mean)
        with tf.name_scope('stddev'):
            stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
        tf.summary.scalar('stddev', stddev)
        tf.summary.scalar('max', tf.reduce_max(var))
        tf.summary.scalar('min', tf.reduce_min(var))
        tf.summary.histogram('histogram', var)

def convo(input_layer, layer_name, filter_shape, stride, padStr):
    with tf.variable_scope(layer_name):
        input_shape = input_layer.get_shape().as_list()
        print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])

        conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
        conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
        conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
        relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b))

        variable_summaries(conv_w, "conv_w")
        variable_summaries(conv_b, "conv_b")
        variable_summaries(relu, "relu")
    return relu

def convo_noneRelu(input_layer, layer_name, filter_shape, stride, padStr):
    with tf.variable_scope(layer_name):
        input_shape = input_layer.get_shape().as_list()
        print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])

        conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
        conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
        conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
        none_relu = tf.nn.bias_add(conv, conv_b)

        variable_summaries(conv_w, "conv_w")
        variable_summaries(conv_b, "conv_b")
        variable_summaries(none_relu, "none_relu")
    return none_relu

def deconvo(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
    with tf.variable_scope(layer_name):
        input_shape = input_layer.get_shape().as_list()
        print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
        conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
        input_shape = tf.shape(input_layer)
        output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
        conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
        conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
        deconv = tf.nn.bias_add(conv, conv_b)

        variable_summaries(conv_w, "conv_w")
        variable_summaries(conv_b, "conv_b")
        variable_summaries(deconv, "deconv")
    return deconv

def deconv_withRelu(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
    with tf.variable_scope(layer_name):
        input_shape = input_layer.get_shape().as_list()
        print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
        conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
        input_shape = tf.shape(input_layer)
        output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
        conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
        conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
        relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b))

        variable_summaries(conv_w, "conv_w")
        variable_summaries(conv_b, "conv_b")
        variable_summaries(relu, "relu")
    return relu

def pooling(input_layer, layer_name, kernel_shape, stride, padStr):
    with tf.name_scope(layer_name):
        pool = tf.nn.max_pool(input_layer, kernel_shape, [1, stride, stride, 1], padStr)
        return pool

def pooling_withmax(input_layer, layer_name, kernel_shape, stride, padStr):
    with tf.name_scope(layer_name):
        return tf.nn.max_pool_with_argmax(input_layer, kernel_shape, [1, stride, stride, 1], padStr)

def unravel_argmax(argmax, shape):
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
    return tf.stack(output_list)

def unpooling_layer2x2(x, layer_name, raveled_argmax, out_shape):
    with tf.name_scope(layer_name):
        argmax = unravel_argmax(raveled_argmax, tf.to_int64(out_shape))
        output = tf.zeros([out_shape[1], out_shape[2], out_shape[3]])

        height = tf.shape(output)[0]
        width = tf.shape(output)[1]
        channels = tf.shape(output)[2]

        t1 = tf.to_int64(tf.range(channels))
        t1 = tf.tile(t1, [((width + 1) // 2) * ((height + 1) // 2)])
        t1 = tf.reshape(t1, [-1, channels])
        t1 = tf.transpose(t1, perm=[1, 0])
        t1 = tf.reshape(t1, [channels, (height + 1) // 2, (width + 1) // 2, 1])

        t2 = tf.squeeze(argmax)
        t2 = tf.stack((t2[0], t2[1]), axis=0)
        t2 = tf.transpose(t2, perm=[3, 1, 2, 0])

        t = tf.concat([t2, t1], 3)
        indices = tf.reshape(t, [((height + 1) // 2) * ((width + 1) // 2) * channels, 3])

        x1 = tf.squeeze(x)
        x1 = tf.reshape(x1, [-1, channels])
        x1 = tf.transpose(x1, perm=[1, 0])
        values = tf.reshape(x1, [-1])

        delta = tf.SparseTensor(indices, values, tf.to_int64(tf.shape(output)))
        return tf.expand_dims(tf.sparse_tensor_to_dense(tf.sparse_reorder(delta)), 0)

def unpooling_layer2x2_batch(bottom, layer_name, argmax):
    with tf.name_scope(layer_name):
        bottom_shape = tf.shape(bottom)
        top_shape = [bottom_shape[0], bottom_shape[1] * 2, bottom_shape[2] * 2, bottom_shape[3]]

        batch_size = top_shape[0]
        height = top_shape[1]
        width = top_shape[2]
        channels = top_shape[3]

        argmax_shape = tf.to_int64([batch_size, height, width, channels])
        argmax = unravel_argmax(argmax, argmax_shape)

        t1 = tf.to_int64(tf.range(channels))
        t1 = tf.tile(t1, [batch_size * (width // 2) * (height // 2)])
        t1 = tf.reshape(t1, [-1, channels])
        t1 = tf.transpose(t1, perm=[1, 0])
        t1 = tf.reshape(t1, [channels, batch_size, height // 2, width // 2, 1])
        t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4])

        t2 = tf.to_int64(tf.range(batch_size))
        t2 = tf.tile(t2, [channels * (width // 2) * (height // 2)])
        t2 = tf.reshape(t2, [-1, batch_size])
        t2 = tf.transpose(t2, perm=[1, 0])
        t2 = tf.reshape(t2, [batch_size, channels, height // 2, width // 2, 1])

        t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0])

        t = tf.concat([t2, t3, t1], 4)
        indices = tf.reshape(t, [(height // 2) * (width // 2) * channels * batch_size, 4])

        x1 = tf.transpose(bottom, perm=[0, 3, 1, 2])
        values = tf.reshape(x1, [-1])
        return tf.scatter_nd(indices, values, tf.to_int64(top_shape))

def fullyconnect(input_layer, layer_name, input_size, output_size, regularizer):
    with tf.variable_scope(layer_name):
        fc_w = tf.get_variable("weight", [input_size, output_size],
                               initializer = tf.truncated_normal_initializer(stddev = 0.1))
        if regularizer:  # skip when regularizer is 0 or None
            tf.add_to_collection('losses', regularizer(fc_w))
        fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1))
        fc = tf.nn.sigmoid(tf.matmul(input_layer, fc_w) + fc_b)
        # if train: fc = tf.nn.dropout(fc, 0.5)
    return fc

def linearlyconnect(input_layer, layer_name, input_size, output_size, regularizer):
    with tf.variable_scope(layer_name):
        fc_w = tf.get_variable("weight", [input_size, output_size],
                               initializer = tf.truncated_normal_initializer(stddev = 0.1))
        if regularizer:  # skip when regularizer is 0 or None
            tf.add_to_collection('losses', regularizer(fc_w))

        fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1))
        logit = tf.matmul(input_layer, fc_w) + fc_b
    return logit
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
import os
import sys
import tensorflow as tf
import pprint

from model import DecompositionModel
from solver import nn_train, nn_feedforward

flags = tf.app.flags
flags.DEFINE_string("dataset", "data", "Directory of .tfrecords files (train) or .mat file (feedforward) [data]")
flags.DEFINE_string("mode", "train", "train, train_goon or feedforward [train]")
flags.DEFINE_string("model_name", "model", "The name of model [model]")
flags.DEFINE_integer("epoch", 30, "Epoch to train [30]")
flags.DEFINE_integer("batch_size", 10, "The size of batch images [10]")
flags.DEFINE_integer("model_step", 1000, "The number of iterations between model saves [1000]")
flags.DEFINE_float("lr", 0.01, "The base learning rate [0.01]")
flags.DEFINE_string("checkpoint", "checkpoint", "The path of checkpoint")

FLAGS = flags.FLAGS

def main(_):
    # relative paths below are resolved from the src/ directory
    os.chdir(sys.path[0])
    print("Current cwd: ", os.getcwd())

    FLAGS.goon = False
    FLAGS.output_model_dir = "../checkpoint"
    FLAGS.output_mat_dir = "../result"

    if not os.path.exists(FLAGS.output_model_dir):
        os.makedirs(FLAGS.output_model_dir)
    if not os.path.exists(FLAGS.output_mat_dir):
        os.makedirs(FLAGS.output_mat_dir)

    pp = pprint.PrettyPrinter()
    pp.pprint(FLAGS.__flags)

    run_config = tf.ConfigProto()
    run_config.gpu_options.allow_growth = True

    with tf.Session(config = run_config) as sess:
        print("Building Decomposition model ...")
        nn = DecompositionModel(sess, epoch = FLAGS.epoch,
                                batch_size = FLAGS.batch_size,
                                model_name = FLAGS.model_name)

        if FLAGS.mode == "train":
            print("training ...")
            nn_train(nn, sess, FLAGS)
        elif FLAGS.mode == "feedforward":
            print("calculating ...")
            nn_feedforward(nn, sess, FLAGS)
        elif FLAGS.mode == "train_goon":
            print("go on training ...")
            FLAGS.goon = True
            nn_train(nn, sess, FLAGS)

        pp.pprint(FLAGS.__flags)

    print("Sess closed.")

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/src/matrix_inversion.m:
--------------------------------------------------------------------------------
clear all;
close all;
clc;

% paths are relative to src/
test_data_name = '../data/test/test_data_cranial.mat';

resultPath = '../result/';
result_name = [resultPath, 'result_matrix_inversion.mat'];

disp('Loading test data...');
load(test_data_name);

%%
mu_ = mu_bone_high*mu_tissue_low - mu_tissue_high*mu_bone_low;
yita_11 = mu_tissue_low / mu_;
yita_12 = -mu_tissue_high / mu_;
yita_21 = -mu_bone_low / mu_;
yita_22 = mu_bone_high / mu_;

%%
[h, w, slice] = size(I_L);
I_bone = zeros(h, w, slice, 'single');
I_tissue = zeros(h, w, slice, 'single');

for i = 1:slice
    I_bone(:,:,i) = yita_11*I_H(:,:,i) + yita_12*I_L(:,:,i);
    I_tissue(:,:,i) = yita_21*I_H(:,:,i) + yita_22*I_L(:,:,i);
end

I_bone(I_bone < 0.00001) = 0;
I_tissue(I_tissue < 0.00001) = 0;

% imshow expects a 2-D (or RGB) image, so display the first slice
figure(1), imshow(I_bone(:,:,1), []);
figure(2), imshow(I_tissue(:,:,1), []);

disp(['Saving result data... ', result_name]);
save(result_name, 'I_bone', 'I_tissue');

--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import layer_xyf

# input size
INPUT_W = 65
INPUT_H = 65
INPUT_DEPTH = 1

# model definition
# layer_1 input size 65*65  [filter_height, filter_width, in_channels, out_channels]
FILTER_1 = [5, 5, 1, 64]
STRIDE_1 = 2
PAD_1 = "SAME"

# layer_2 input_size 33*33
FILTER_2 = [5, 5, 64, 128]
STRIDE_2 = 2
PAD_2 = "SAME"

# layer_3 input_size 17*17
FILTER_3 = [5, 5, 128, 256]
STRIDE_3 = 2
PAD_3 = "SAME"

# layer_4 input_size 9*9
FILTER_4 = [9, 9, 256, 256]
STRIDE_4 = 1
PAD_4 = "VALID"

# layer_5 input_size 1*1*512 (256 low-energy + 256 high-energy channels)
FILTER_5 = [1, 1, 512, 1]
STRIDE_5 = 1
PAD_5 = "VALID"

class DecompositionModel(object):
    def __init__(self, sess, epoch, batch_size = 64, model_name = 'model'):
        self.sess = sess
        self.epoch = epoch
        self.input_width = INPUT_W
        self.input_height = INPUT_H
        self.batch_size = batch_size
        self.model_name = model_name

    def feedforward(self, input_L_batch, input_H_batch, regularizer = 0):
        # training graph: each 65*65 patch pair is mapped to one bone and one tissue value
        layer_L1 = layer_xyf.convo(input_L_batch, "conv_L1", FILTER_1, STRIDE_1, PAD_1)
        layer_L2 = layer_xyf.convo(layer_L1, "conv_L2", FILTER_2, STRIDE_2, PAD_2)
        layer_L3 = layer_xyf.convo(layer_L2, "conv_L3", FILTER_3, STRIDE_3, PAD_3)
        layer_L4 = layer_xyf.convo(layer_L3, "conv_L4", FILTER_4, STRIDE_4, PAD_4)

        layer_H1 = layer_xyf.convo(input_H_batch, "conv_H1", FILTER_1, STRIDE_1, PAD_1)
        layer_H2 = layer_xyf.convo(layer_H1, "conv_H2", FILTER_2, STRIDE_2, PAD_2)
        layer_H3 = layer_xyf.convo(layer_H2, "conv_H3", FILTER_3, STRIDE_3, PAD_3)
        layer_H4 = layer_xyf.convo(layer_H3, "conv_H4", FILTER_4, STRIDE_4, PAD_4)

        combine_LH = tf.concat([layer_L4, layer_H4], 3)

        layer_bone = layer_xyf.convo_noneRelu(combine_LH, "conv_bone", FILTER_5, STRIDE_5, PAD_5)
        layer_tissue = layer_xyf.convo_noneRelu(combine_LH, "conv_tissue", FILTER_5, STRIDE_5, PAD_5)
        return layer_bone, layer_tissue

    def feedforward_test(self, input_L_batch, input_H_batch, regularizer = 0):
        # test graph: same variable names as feedforward, but layers 1-3 use
        # stride 1 so that a whole slice is processed in one pass
        layer_L1 = layer_xyf.convo(input_L_batch, "conv_L1", FILTER_1, 1, PAD_1)
        layer_L2 = layer_xyf.convo(layer_L1, "conv_L2", FILTER_2, 1, PAD_2)
        layer_L3 = layer_xyf.convo(layer_L2, "conv_L3", FILTER_3, 1, PAD_3)
        layer_L4 = layer_xyf.convo(layer_L3, "conv_L4", FILTER_4, STRIDE_4, PAD_4)

        layer_H1 = layer_xyf.convo(input_H_batch, "conv_H1", FILTER_1, 1, PAD_1)
        layer_H2 = layer_xyf.convo(layer_H1, "conv_H2", FILTER_2, 1, PAD_2)
        layer_H3 = layer_xyf.convo(layer_H2, "conv_H3", FILTER_3, 1, PAD_3)
        layer_H4 = layer_xyf.convo(layer_H3, "conv_H4", FILTER_4, STRIDE_4, PAD_4)

        combine_LH = tf.concat([layer_L4, layer_H4], 3)

        layer_bone = layer_xyf.convo_noneRelu(combine_LH, "conv_bone", FILTER_5, STRIDE_5, "SAME")
        layer_tissue = layer_xyf.convo_noneRelu(combine_LH, "conv_tissue", FILTER_5, STRIDE_5, "SAME")
        return layer_bone, layer_tissue

    def input_layer_width(self):
        return self.input_width

    def input_layer_height(self):
        return self.input_height

--------------------------------------------------------------------------------
/src/solver.py:
--------------------------------------------------------------------------------
import os
import scipy.io
import time
import tensorflow as tf
import numpy as np
import math

LEARNING_RATE_DECAY = 0.85
MOVING_AVERAGE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001

NUM_SAMPLE = 2454300

SUMMARY_DIR = "../log"
SUMMARY_STEP = 100

INPUT_H = 65
INPUT_W = 65
INPUT_CHANNEL = 1

OUTPUT_H = 1
OUTPUT_W = 1
OUTPUT_CHANNEL = 1

def nn_train(nn, sess, config):
    # read data
    print("Train, loading .tfrecords data...", config.dataset)
    if not os.path.exists(config.dataset):
        raise Exception("training data not found.")

    config.sampleNum = NUM_SAMPLE
    config.iteration = math.ceil(config.sampleNum / config.batch_size * config.epoch)

    time_str = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime(time.time()))

    train_file_path = os.path.join(config.dataset, "train_sample_batches_*")
    train_files = tf.train.match_filenames_once(train_file_path)
    filename_queue = tf.train.string_input_producer(train_files, shuffle = True)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
        features = {'train_x_low': tf.FixedLenFeature([], tf.string),
                    'train_x_high': tf.FixedLenFeature([], tf.string),
                    'train_y_bone': tf.FixedLenFeature([], tf.string),
                    'train_y_tissue': tf.FixedLenFeature([], tf.string)})

    train_x_low = tf.decode_raw(features['train_x_low'], tf.float32)
    train_x_high = tf.decode_raw(features['train_x_high'], tf.float32)
    train_y_bone = tf.decode_raw(features['train_y_bone'], tf.float32)
    train_y_tissue = tf.decode_raw(features['train_y_tissue'], tf.float32)

    train_x_low = tf.reshape(train_x_low, [INPUT_H, INPUT_W, INPUT_CHANNEL])
    train_x_high = tf.reshape(train_x_high, [INPUT_H, INPUT_W, INPUT_CHANNEL])
    train_y_bone = tf.reshape(train_y_bone, [OUTPUT_H, OUTPUT_W, OUTPUT_CHANNEL])
    train_y_tissue = tf.reshape(train_y_tissue, [OUTPUT_H, OUTPUT_W, OUTPUT_CHANNEL])

    train_x_low_batch, train_x_high_batch, train_y_bone_batch, train_y_tissue_batch = tf.train.shuffle_batch(
        [train_x_low, train_x_high, train_y_bone, train_y_tissue],
        batch_size = config.batch_size,
        capacity = config.batch_size*3 + 1000,
        min_after_dequeue = 1000)
    # loss function define
    output_bone_batch, output_tissue_batch = nn.feedforward(train_x_low_batch, train_x_high_batch, 0)
    global_step = tf.Variable(0, trainable = False)
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    mse_loss = tf.reduce_mean(tf.square(output_bone_batch - train_y_bone_batch) + tf.square(output_tissue_batch - train_y_tissue_batch))
    tf.add_to_collection('losses', mse_loss)
    tf.summary.scalar('MSE_losses', mse_loss)

    learning_rate = tf.train.exponential_decay(config.lr, global_step,
                                               config.sampleNum / config.batch_size, LEARNING_RATE_DECAY)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(mse_loss, global_step = global_step)

    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name = 'train')

    merged = tf.summary.merge_all()
    summary_path = os.path.join(SUMMARY_DIR, time_str)
    os.makedirs(summary_path)
    summary_writer = tf.summary.FileWriter(summary_path, sess.graph)

    saver = tf.train.Saver()
    # match_filenames_once creates a local variable, so local variables
    # must be initialized whether or not a checkpoint is restored
    sess.run(tf.local_variables_initializer())
    if config.goon:
        if not os.path.exists(config.checkpoint):
            raise Exception("checkpoint path not found.")
        print("Loading trained model... ", config.checkpoint)
        ckpt = tf.train.get_checkpoint_state(config.checkpoint)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())

    save_model_path = os.path.join(config.output_model_dir, time_str)
    os.makedirs(save_model_path)

    print("start training sess...")
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess = sess, coord = coord)

    start_time = time.time()
    for i in range(config.iteration):
        batch_start_time = time.time()
        summary, _, loss_value, step, learningRate = sess.run([merged, train_op, mse_loss, global_step, learning_rate])
        batch_end_time = time.time()
        sec_per_batch = batch_end_time - batch_start_time
        if step % SUMMARY_STEP == 0:
            summary_writer.add_summary(summary, step//SUMMARY_STEP)
        if step % config.model_step == 0:
            print("Saving model (after %d iteration)... " %(step))
            saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
            print("sec/batch(%d) %gs, global step %d batches, training epoch %d/%d, learningRate %g, loss on training is %g"
                  % (config.batch_size, sec_per_batch, step, i*config.batch_size / config.sampleNum,
                     config.epoch, learningRate, loss_value))
            print("Elapsed time: %g" %(batch_end_time - start_time))

    coord.request_stop()
    coord.join(threads)
    summary_writer.close()
    print("Train done. ")
    print("Saving model... ", save_model_path)
    saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)

def nn_feedforward(nn, sess, config):
    print("Feed forward, loading .mat data...", config.dataset)
    if not os.path.exists(config.dataset):
        raise Exception(".mat file not found.")

    mat = scipy.io.loadmat(config.dataset)
    input_xl = mat['I_L']
    input_xh = mat['I_H']

    dim_h = input_xl.shape[0]
    dim_w = input_xl.shape[1]
    if input_xl.ndim < 3:
        input_xl = input_xl.reshape(dim_h, dim_w, 1)
        input_xh = input_xh.reshape(dim_h, dim_w, 1)

    slice_num = input_xl.shape[2]
    print('image num: ', slice_num)
    print('input type: ', input_xl.dtype)

    # slices are fed one at a time, so the placeholders hold a single image
    x_L = tf.placeholder(tf.float32, [1, dim_h, dim_w, 1], name = 'xl-input')
    x_H = tf.placeholder(tf.float32, [1, dim_h, dim_w, 1], name = 'xh-input')
    y_bone, y_tissue = nn.feedforward_test(x_L, x_H, None)

    output_dim = 0
    I_bone = np.array([], dtype = np.float32)
    I_tissue = np.array([], dtype = np.float32)

    # restore the moving-average (shadow) values of the trained weights
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    ckpt = tf.train.get_checkpoint_state(config.checkpoint)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        print('global step: ', global_step)
        for i in range(slice_num):
            xl = input_xl[:,:,i].reshape([1, dim_h, dim_w, 1])
            xh = input_xh[:,:,i].reshape([1, dim_h, dim_w, 1])
            print('xl: ', xl.shape)
            output_bone, output_tissue = sess.run([y_bone, y_tissue], feed_dict={x_L: xl, x_H: xh})
            print('TF Done')
            if i == 0:
                output_shape = output_bone.shape
                print("output shape: ", output_shape)
                output_dim = output_bone.shape[1]
                I_bone = np.zeros((output_dim, output_dim, slice_num), dtype = np.float32)
                I_tissue = np.zeros((output_dim, output_dim, slice_num), dtype = np.float32)

            I_bone[:,:,i] = output_bone.reshape(output_dim, output_dim)
            I_tissue[:,:,i] = output_tissue.reshape(output_dim, output_dim)

        result_fileName = os.path.join(config.output_mat_dir, 'result_' + config.model_name + '.mat')
        scipy.io.savemat(result_fileName, {'I_bone':I_bone, 'I_tissue':I_tissue})
        print('Feed forward done. Result: ', result_fileName)
    else:
        print('No checkpoint file found.')
        return
--------------------------------------------------------------------------------
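`nn_train` sizes the run as ceil(NUM_SAMPLE / batch_size * epoch) iterations and decays the learning rate with `tf.train.exponential_decay` (non-staircase), whose documented formula is lr * decay_rate ** (global_step / decay_steps). Since decay_steps is NUM_SAMPLE / batch_size here, the rate is multiplied by 0.85 once per epoch. The plain-Python sketch below reproduces both quantities for the default flags; `total_iterations` and `decayed_lr` are hypothetical helpers, not repository code.

```python
import math

NUM_SAMPLE = 2454300          # value from solver.py
LEARNING_RATE_DECAY = 0.85    # value from solver.py

def total_iterations(epoch, batch_size, num_sample=NUM_SAMPLE):
    # same expression as config.iteration in nn_train
    return math.ceil(num_sample / batch_size * epoch)

def decayed_lr(base_lr, step, batch_size, num_sample=NUM_SAMPLE, decay=LEARNING_RATE_DECAY):
    # non-staircase exponential decay; decay_steps equals one epoch of batches
    decay_steps = num_sample / batch_size
    return base_lr * decay ** (step / decay_steps)

print(total_iterations(epoch=30, batch_size=10))   # 7362900
print(decayed_lr(0.01, 245430, batch_size=10))     # about 0.0085, i.e. 0.01 * 0.85 after one epoch
```

With the defaults (lr = 0.01, 30 epochs), the rate at the end of training is roughly 0.01 * 0.85**30, a little under 1% of its starting value.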