├── .gitignore
├── README.md
├── checkpoint
│   └── checkpoint path for Tensorflow
├── data
│   └── test
│       ├── test_data_cranial.mat
│       └── test_data_pleural.mat
├── result
│   ├── result_FCN.mat
│   ├── result_itertive_decomposition.mat
│   └── result_matrix_inversion.mat
└── src
    ├── AXfunc_pwls.m
    ├── get_weight.m
    ├── gradient_LS.m
    ├── iterative_decomposition.m
    ├── layer_xyf.py
    ├── main.py
    ├── matrix_inversion.m
    ├── model.py
    └── solver.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ImageDecomposition-DECT
2 |
3 | This repository is organized around an image decomposition algorithm proposed to solve the material decomposition problem in dual-energy computed tomography (DECT).
4 |
5 | The algorithm is based on the deep learning paradigm. For theoretical background, see [Deep Learning](http://www.deeplearningbook.org/) and [Material Decomposition Using DECT](https://pubs.rsna.org/doi/10.1148/rg.2016150220).
6 |
7 | The algorithm corresponds to the paper ["Image Decomposition Algorithm for Dual-Energy Computed Tomography via Fully Convolutional Network"](https://www.hindawi.com/journals/cmmm/2018/2527516/cta/) (DOI: 10.1155/2018/2527516).
8 |
9 | All code has been tested with Python 3.6 and TensorFlow 1.4.0 on Linux.
10 | * checkpoint: the checkpoint directory for the model trained with TensorFlow. The [pre-trained model](https://pan.baidu.com/s/1r1OTjid2muWWZfxURB8Pjw) was trained on a dataset of 2,454,300 samples in total; each sample is a 65*65 image patch extracted from 5,987 image slices.
11 | * data: contains two sub-directories.
12 |   * test: two test data files, 'test_data_cranial.mat' and 'test_data_pleural.mat'.
13 |   * train: we only provide a subset (90,000 training samples) in the 'training_samples_90000.rar' file, which can be downloaded from [here](https://pan.baidu.com/s/1r1OTjid2muWWZfxURB8Pjw).
14 | * result: the directory where decomposition results are saved; each .mat file stores the two arrays I_bone and I_tissue (see the loading sketch below).
15 | * src: the code for three decomposition algorithms:
16 |   * Direct matrix inversion (matrix_inversion.m); a brief NumPy restatement of this step is sketched below, after the run command.
17 |   * Iterative decomposition (iterative_decomposition.m). Related paper: [Iterative image-domain decomposition for dual-energy CT](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1118/1.4889338)
18 |   * The proposed deep model (main.py). After downloading the pre-trained model, you can run the algorithm with the following command.
19 | >> python main.py --dataset="../data/test/test_data_cranial.mat" --mode="feedforward" --model_name="your-saved-result-name" --checkpoint="../checkpoint/FCN_trained_model"
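The feedforward run above writes its output to `../result/result_<model_name>.mat` (the provided `result_FCN.mat` was presumably produced this way). Each result file stores the two arrays `I_bone` and `I_tissue`; for the FCN result they are height x width x slices. Below is a minimal loading sketch; the use of matplotlib for display is our assumption and not a dependency of this repository.

```python
# Minimal sketch for inspecting a saved decomposition result (keys written by solver.py).
import scipy.io
import matplotlib.pyplot as plt  # only for display; not required by the repo

res = scipy.io.loadmat("../result/result_FCN.mat")
I_bone, I_tissue = res["I_bone"], res["I_tissue"]
print("bone:", I_bone.shape, "tissue:", I_tissue.shape)

plt.subplot(1, 2, 1); plt.imshow(I_bone[:, :, 0], cmap="gray");   plt.title("bone")
plt.subplot(1, 2, 2); plt.imshow(I_tissue[:, :, 0], cmap="gray"); plt.title("tissue")
plt.show()
```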
20 |
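For reference, the direct matrix-inversion baseline reduces to a per-pixel 2x2 linear inversion. The snippet below is a minimal NumPy restatement of what `src/matrix_inversion.m` computes, assuming the test `.mat` file exposes `I_H`, `I_L` and the four attenuation coefficients (`mu_bone_high`, `mu_bone_low`, `mu_tissue_high`, `mu_tissue_low`), as the MATLAB scripts expect; it is an illustrative sketch, not part of the original pipeline.

```python
# Sketch of the direct matrix-inversion baseline (cf. src/matrix_inversion.m).
import numpy as np
import scipy.io

mat = scipy.io.loadmat("../data/test/test_data_cranial.mat")
I_H, I_L = mat["I_H"], mat["I_L"]
mu_bh, mu_bl = mat["mu_bone_high"].item(), mat["mu_bone_low"].item()
mu_th, mu_tl = mat["mu_tissue_high"].item(), mat["mu_tissue_low"].item()

det = mu_bh * mu_tl - mu_th * mu_bl            # determinant of the 2x2 mixing matrix
I_bone   = ( mu_tl * I_H - mu_th * I_L) / det  # eta_11*I_H + eta_12*I_L
I_tissue = (-mu_bl * I_H + mu_bh * I_L) / det  # eta_21*I_H + eta_22*I_L

I_bone[I_bone < 1e-5] = 0                      # same small-value clipping as the MATLAB script
I_tissue[I_tissue < 1e-5] = 0
```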
21 | # Contact
22 | Email: vastcyclone@yeah.net
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/checkpoint/checkpoint path for Tensorflow:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/test/test_data_cranial.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/data/test/test_data_cranial.mat
--------------------------------------------------------------------------------
/data/test/test_data_pleural.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/data/test/test_data_pleural.mat
--------------------------------------------------------------------------------
/result/result_FCN.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/result/result_FCN.mat
--------------------------------------------------------------------------------
/result/result_itertive_decomposition.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/result/result_itertive_decomposition.mat
--------------------------------------------------------------------------------
/result/result_matrix_inversion.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/ImageDecomposition-DECT/d8c1c6ba078e24f54c723d703927823de08b1712/result/result_matrix_inversion.mat
--------------------------------------------------------------------------------
/src/AXfunc_pwls.m:
--------------------------------------------------------------------------------
1 | function [y] = AXfunc_pwls(x,A,At,weight,beta1,beta2)
2 | %%% description: applies the PWLS system operator y = At*(A*x) + r for the pcg solver (see iterative_decomposition.m)
3 |
4 | % qCGMRF weighting
5 | % r = [beta1*gradient_qCGMRF(x(1:end/2), weight) ...
6 | % beta2*gradient_qCGMRF(x(end/2+1:end), weight)];
7 |
8 | % quadratic weighting
9 | r = [beta1*gradient_LS(x(1:end/2), weight) ...
10 | beta2*gradient_LS(x(end/2+1:end), weight)];
11 |
12 |
13 | r = r(:);
14 | y = (At*(A*x)) + r;
15 |
--------------------------------------------------------------------------------
/src/get_weight.m:
--------------------------------------------------------------------------------
1 | function weight = get_weight(img, h, w)
2 | weight = ones([h, w, 4]);
3 | std_1 = std(img(1:end/2));
4 | std_2 = std(img(end/2 + 1:end));
5 | std_ratio = std_1/std_2;
6 |
7 | BW = edge(img(:,1:end/2) + img(:,end/2 + 1:end)*std_ratio,'canny',0.1,1.5);
8 | BW_2 = edge(img(:,1:end/2) + img(:,end/2 + 1:end)*std_ratio,'prewitt',0.1,'both');
9 | BW = BW + BW_2;
10 |
11 | dummy = BW;
12 | g1_d2 = dummy - [dummy(2:end,:); zeros(1, size(dummy,1))];
13 | g1_d1 = dummy - [zeros(1, size(dummy,1)); dummy(1:end - 1,:)];
14 | g1_d0 = dummy - [zeros(size(dummy,1),1) dummy(:,1:end - 1)];
15 | g1_d3 = dummy - [dummy(:,2:end) zeros(size(dummy,1),1)];
16 |
17 | thresh = 0.0001;
18 | Lvalue = 0.1;
19 | dummy = ones([h, w]);
20 | dummy(abs(g1_d0) > thresh) = Lvalue;
21 | weight(:,:,1) = dummy;
22 | dummy = ones([h, w]);
23 | dummy(abs(g1_d1) > thresh) = Lvalue;
24 | weight(:,:,2) = dummy;
25 | dummy = ones([h, w]);
26 | dummy(abs(g1_d2) > thresh) = Lvalue;
27 | weight(:,:,3) = dummy;
28 | dummy = ones([h, w]);
29 | dummy(abs(g1_d3) > thresh) = Lvalue;
30 | weight(:,:,4) = dummy;
31 | end
--------------------------------------------------------------------------------
/src/gradient_LS.m:
--------------------------------------------------------------------------------
1 | function gf = gradient_LS(x,weight)
2 | % Gradient of the weighted quadratic (least-squares) neighbor-difference penalty used in AXfunc_pwls.m.
3 |
4 | %%
5 | % g1_d2 = dummy-[dummy(2:end,:); zeros(1, size(dummy,1))];
6 | % g1_d1 = dummy-[zeros(1, size(dummy,1)); dummy(1:end-1,:)];
7 | % g1_d0 = dummy-[zeros(size(dummy,1),1) dummy(:,1:end-1)];
8 | % g1_d3 = dummy-[dummy(:,2:end) zeros(size(dummy,1),1)];
9 | %%
10 | [row,col,~]=size(weight);
11 | xf = reshape(x,row,col);
12 |
13 | xf_d0 = xf-[zeros(size(xf,1),1) xf(:,1:end-1)];
14 | xf_d1 = xf-[zeros(1, size(xf,1)); xf(1:end-1,:)];
15 | xf_d2 = xf-[xf(2:end,:); zeros(1, size(xf,1))];
16 | xf_d3 = xf-[xf(:,2:end) zeros(size(xf,1),1)];
17 |
18 | gf = 4*(xf_d0.*weight(:,:,1)+xf_d1.*weight(:,:,2)+xf_d2.*weight(:,:,3)+xf_d3.*weight(:,:,4));
--------------------------------------------------------------------------------
/src/iterative_decomposition.m:
--------------------------------------------------------------------------------
1 | close all;
2 | clear all;
3 | clc;
4 |
5 | test_data_name = '../../data/test/test_data_cranial.mat';
6 |
7 | resultPath = '../../result/';
8 | result_name = [resultPath, 'result_itertive_decomposition.mat'];
9 |
10 | disp('Loading test data...');
11 | load(test_data_name);
12 |
13 | %%
14 | mu_ = mu_bone_high*mu_tissue_low - mu_tissue_high*mu_bone_low;
15 | a = mu_tissue_low / mu_;
16 | b = -mu_tissue_high / mu_;
17 | c = -mu_bone_low / mu_;
18 | d = mu_bone_high / mu_;
19 |
20 | %%
21 | pcgmaxi = 500;
22 | pcgtol = 1e-12;
23 | beta1 = 2e-6; %1;
24 | beta2 = beta1*7;
25 |
26 | A = [a b; c d];
27 | A = inv(A);
28 | A = kron(A,speye(512*512));
29 | At = A';
30 |
31 | [h, w, slice] = size(I_L);
32 |
33 | I_bone = zeros([h, w, slice], 'single');
34 | I_tissue = zeros([h, w, slice], 'single');
35 |
36 | img = zeros(h, 2*w);
37 | img_d = zeros(h, 2*w);
38 | for i = 1:slice
39 | disp(['Decomposing image ', num2str(i), '/', num2str(slice)]);
40 | img(:, 1:end/2) = I_H(:,:,i);
41 | img(:, (end/2 + 1):end) = I_L(:,:,i);
42 | img_d = [a*img(:,1:end/2) + b*img(:,end/2+1:end) c*img(:,1:end/2) + d*img(:,end/2+1:end)];
43 | x = img_d(:);
44 | data = A'*img(:);
45 | ratio = max(data(:));
46 | data = data / ratio;
47 | weight = get_weight(img, h, w);
48 |     [x_1, flag, relres, iter, rv] = pcg(@(z) AXfunc_pwls(z, A, At, weight, beta1, beta2), ...
49 |                                         data, pcgtol, pcgmaxi, [], [], x);
50 | x_1 = reshape(x_1, [h, 2*w])*ratio;
51 |
52 | I_bone(:,:,i) = x_1(:, 1:end/2);
53 | I_tissue(:,:,i) = x_1(:, (end/2 + 1):end);
54 | end
55 |
56 | I_bone(I_bone < 0.00001) = 0;
57 | I_tissue(I_tissue < 0.00001) = 0;
58 |
59 | figure(1), imshow(I_bone, []);
60 | figure(2), imshow(I_tissue, []);
61 |
62 | disp(['Saving result data... ', result_name]);
63 | save(result_name, 'I_bone', 'I_tissue');
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/src/layer_xyf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import tensorflow as tf
4 |
5 | def variable_summaries(var, var_name):
6 | with tf.name_scope(var_name):
7 | mean = tf.reduce_mean(var)
8 | tf.summary.scalar('mean', mean)
9 | with tf.name_scope('stddev'):
10 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
11 | tf.summary.scalar('stddev', stddev)
12 | tf.summary.scalar('max', tf.reduce_max(var))
13 | tf.summary.scalar('min', tf.reduce_min(var))
14 | tf.summary.histogram('histogram', var)
15 |
16 | def convo(input_layer, layer_name, filter_shape, stride, padStr):
17 | with tf.variable_scope(layer_name):
18 | input_shape = input_layer.get_shape().as_list()
19 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
20 |
21 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
22 | conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
23 | conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
24 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b))
25 |
26 | variable_summaries(conv_w, "conv_w")
27 | variable_summaries(conv_b, "conv_b")
28 | variable_summaries(relu, "relu")
29 | return relu
30 |
31 | def convo_noneRelu(input_layer, layer_name, filter_shape, stride, padStr):
32 | with tf.variable_scope(layer_name):
33 | input_shape = input_layer.get_shape().as_list()
34 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
35 |
36 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
37 | conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
38 | conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
39 | none_relu = tf.nn.bias_add(conv, conv_b)
40 |
41 | variable_summaries(conv_w, "conv_w")
42 | variable_summaries(conv_b, "conv_b")
43 | variable_summaries(none_relu, "none_relu")
44 | return none_relu
45 |
46 | def deconvo(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
47 | with tf.variable_scope(layer_name):
48 | input_shape = input_layer.get_shape().as_list()
49 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
50 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
51 | input_shape = tf.shape(input_layer)
52 | output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
53 | conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
54 | conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
55 | deconv = tf.nn.bias_add(conv, conv_b)
56 |
57 | variable_summaries(conv_w, "conv_w")
58 | variable_summaries(conv_b, "conv_b")
59 | variable_summaries(deconv, "deconv")
60 | return deconv
61 |
62 | def deconv_withRelu(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
63 | with tf.variable_scope(layer_name):
64 | input_shape = input_layer.get_shape().as_list()
65 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
66 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
67 | input_shape = tf.shape(input_layer)
68 | output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
69 | conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
70 | conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
71 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b))
72 |
73 | variable_summaries(conv_w, "conv_w")
74 | variable_summaries(conv_b, "conv_b")
75 | variable_summaries(relu, "relu")
76 | return relu
77 |
78 | def pooling(input_layer, layer_name, kernal_shape, stride, padStr):
79 | with tf.name_scope(layer_name):
80 | pool = tf.nn.max_pool(input_layer, kernal_shape, [1,stride,stride,1], padStr)
81 | return pool
82 |
83 | def pooling_withmax(input_layer, layer_name, kernal_shape, stride, padStr):
84 | with tf.name_scope(layer_name):
85 | return tf.nn.max_pool_with_argmax(input_layer, kernal_shape, [1,stride,stride,1], padStr)
86 |
87 | def unravel_argmax(argmax, shape):
88 | output_list = []
89 | output_list.append(argmax // (shape[2] * shape[3]))
90 | output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
91 | return tf.stack(output_list)
92 |
93 | def unpooling_layer2x2(x, layer_name, raveled_argmax, out_shape):
94 | with tf.name_scope(layer_name):
95 | argmax = unravel_argmax(raveled_argmax, tf.to_int64(out_shape))
96 | output = tf.zeros([out_shape[1], out_shape[2], out_shape[3]])
97 |
98 | height = tf.shape(output)[0]
99 | width = tf.shape(output)[1]
100 | channels = tf.shape(output)[2]
101 |
102 | t1 = tf.to_int64(tf.range(channels))
103 | t1 = tf.tile(t1, [((width + 1) // 2) * ((height + 1) // 2)])
104 | t1 = tf.reshape(t1, [-1, channels])
105 | t1 = tf.transpose(t1, perm=[1, 0])
106 | t1 = tf.reshape(t1, [channels, (height + 1) // 2, (width + 1) // 2, 1])
107 |
108 | t2 = tf.squeeze(argmax)
109 | t2 = tf.stack((t2[0], t2[1]), axis=0)
110 | t2 = tf.transpose(t2, perm=[3, 1, 2, 0])
111 |
112 | t = tf.concat([t2, t1], 3)
113 | indices = tf.reshape(t, [((height + 1) // 2) * ((width + 1) // 2) * channels, 3])
114 |
115 | x1 = tf.squeeze(x)
116 | x1 = tf.reshape(x1, [-1, channels])
117 | x1 = tf.transpose(x1, perm=[1, 0])
118 | values = tf.reshape(x1, [-1])
119 |
120 | delta = tf.SparseTensor(indices, values, tf.to_int64(tf.shape(output)))
121 | return tf.expand_dims(tf.sparse_tensor_to_dense(tf.sparse_reorder(delta)), 0)
122 |
123 | def unpooling_layer2x2_batch(bottom, layer_name, argmax):
124 | with tf.name_scope(layer_name):
125 | bottom_shape = tf.shape(bottom)
126 | top_shape = [bottom_shape[0], bottom_shape[1] * 2, bottom_shape[2] * 2, bottom_shape[3]]
127 |
128 | batch_size = top_shape[0]
129 | height = top_shape[1]
130 | width = top_shape[2]
131 | channels = top_shape[3]
132 |
133 | argmax_shape = tf.to_int64([batch_size, height, width, channels])
134 | argmax = unravel_argmax(argmax, argmax_shape)
135 |
136 | t1 = tf.to_int64(tf.range(channels))
137 | t1 = tf.tile(t1, [batch_size * (width // 2) * (height // 2)])
138 | t1 = tf.reshape(t1, [-1, channels])
139 | t1 = tf.transpose(t1, perm=[1, 0])
140 | t1 = tf.reshape(t1, [channels, batch_size, height // 2, width // 2, 1])
141 | t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4])
142 |
143 | t2 = tf.to_int64(tf.range(batch_size))
144 | t2 = tf.tile(t2, [channels * (width // 2) * (height // 2)])
145 | t2 = tf.reshape(t2, [-1, batch_size])
146 | t2 = tf.transpose(t2, perm=[1, 0])
147 | t2 = tf.reshape(t2, [batch_size, channels, height // 2, width // 2, 1])
148 |
149 | t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0])
150 |
151 | t = tf.concat([t2, t3, t1], 4)
152 | indices = tf.reshape(t, [(height // 2) * (width // 2) * channels * batch_size, 4])
153 |
154 | x1 = tf.transpose(bottom, perm=[0, 3, 1, 2])
155 | values = tf.reshape(x1, [-1])
156 | return tf.scatter_nd(indices, values, tf.to_int64(top_shape))
157 |
158 | def fullyconnect(input_layer, layer_name, input_size, output_size, regularizer):
159 | with tf.variable_scope(layer_name):
160 | fc_w = tf.get_variable("weight", [input_size, output_size],
161 | initializer = tf.truncated_normal_initializer(stddev = 0.1))
162 | if regularizer != 0:
163 | tf.add_to_collection('losses', regularizer(fc_w))
164 | fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1))
165 | fc = tf.nn.sigmoid(tf.matmul(input_layer, fc_w) + fc_b)
166 |         #if train: fc = tf.nn.dropout(fc, 0.5)
167 | return fc
168 |
169 | def linearlyconnect(input_layer, layer_name, input_size, output_size, regularizer):
170 | with tf.variable_scope(layer_name):
171 | fc_w = tf.get_variable("weight", [input_size, output_size],
172 | initializer = tf.truncated_normal_initializer(stddev = 0.1))
173 | if regularizer != 0:
174 | tf.add_to_collection('losses', regularizer(fc_w))
175 |
176 | fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1))
177 | logit = tf.matmul(input_layer, fc_w) + fc_b
178 | return logit
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tensorflow as tf
4 | import pprint
5 |
6 | from model import DecompositionModel
7 | from solver import nn_train, nn_feedforward
8 |
9 | flags = tf.app.flags
10 | flags.DEFINE_string("dataset", "data", ".tfRecord file [data]")
11 | flags.DEFINE_string("mode", "train", "train, train_goon or feedforward [train]")
12 | flags.DEFINE_string("model_name", "model", "The name of model [model]")
13 | flags.DEFINE_integer("epoch", 30, "Epoch to train [500]")
14 | flags.DEFINE_integer("batch_size", 10, "The size of batch images [10]")
15 | flags.DEFINE_integer("model_step", 1000, "The number of iteration to save model [1000]")
16 | flags.DEFINE_float("lr", 0.01, "The base learning rate [0.01]")
17 | flags.DEFINE_string("checkpoint", "checkpoint", "The path of checkpoint")
18 |
19 | FLAGS = flags.FLAGS
20 |
21 | def main(_):
22 | os.chdir(sys.path[0])
23 | print("Current cwd: ", os.getcwd())
24 |
25 | FLAGS = flags.FLAGS
26 | FLAGS.goon = False
27 | FLAGS.output_model_dir = "../checkpoint"
28 | FLAGS.output_mat_dir = "../result"
29 |
30 | if not os.path.exists(FLAGS.output_model_dir):
31 | os.makedirs(FLAGS.output_model_dir)
32 | if not os.path.exists(FLAGS.output_mat_dir):
33 | os.makedirs(FLAGS.output_mat_dir)
34 |
35 | pp = pprint.PrettyPrinter()
36 | pp.pprint(FLAGS.__flags)
37 |
38 | run_config = tf.ConfigProto()
39 | run_config.gpu_options.allow_growth = True
40 |
41 | with tf.Session(config = run_config) as sess:
42 | print("Building Decomposition model ...")
43 | nn = DecompositionModel(sess, epoch = FLAGS.epoch,
44 | batch_size = FLAGS.batch_size,
45 | model_name = FLAGS.model_name)
46 |
47 | if FLAGS.mode == "train":
48 | print("training ...")
49 | nn_train(nn, sess, FLAGS)
50 | elif FLAGS.mode == "feedforward":
51 | print("caculating ...")
52 | nn_feedforward(nn, sess, FLAGS)
53 | elif FLAGS.mode == "train_goon":
54 | print("go on training ...")
55 | FLAGS.goon = True
56 |             nn_train(nn, sess, FLAGS)
57 |
58 | pp.pprint(FLAGS.__flags)
59 |
60 | sess.close()
61 | print("Sess closed.")
62 | exit()
63 |
64 | if __name__ == '__main__':
65 | tf.app.run()
--------------------------------------------------------------------------------
/src/matrix_inversion.m:
--------------------------------------------------------------------------------
1 | clear all;
2 | close all;
3 | clc;
4 |
5 | test_data_name = '../../data/test/test_data_cranial.mat';
6 |
7 | resultPath = '../../result/';
8 | result_name = [resultPath, 'result_matrix_inversion.mat'];
9 |
10 | disp('Loading test data...');
11 | load(test_data_name);
12 |
13 | %%
14 | mu_ = mu_bone_high*mu_tissue_low - mu_tissue_high*mu_bone_low;
15 | yita_11 = mu_tissue_low / mu_;
16 | yita_12 = -mu_tissue_high / mu_;
17 | yita_21 = -mu_bone_low / mu_;
18 | yita_22 = mu_bone_high / mu_;
19 |
20 | %%
21 | [h, w, slice] = size(I_L);
22 | I_bone = zeros(h, w, slice, 'single');
23 | I_tissue = zeros(h, w, slice, 'single');
24 |
25 | for i = 1:slice
26 | I_bone(:,:,i) = yita_11*I_H(:,:,i) + yita_12*I_L(:,:,i);
27 | I_tissue(:,:,i) = yita_21*I_H(:,:,i) + yita_22*I_L(:,:,i);
28 | end
29 |
30 | I_bone(I_bone < 0.00001) = 0;
31 | I_tissue(I_tissue < 0.00001) = 0;
32 |
33 | figure(1), imshow(I_bone, []);
34 | figure(2), imshow(I_tissue, []);
35 |
36 | disp(['Saving result data... ', result_name]);
37 | save(result_name, 'I_bone', 'I_tissue');
38 |
39 |
40 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import layer_xyf
3 |
4 | #input size
5 | INPUT_W = 65
6 | INPUT_H = 65
7 | INPUT_DEPTH = 1
8 |
9 | # model definition
10 | #layer_1 input size 65*65 [filter_height, filter_width, in_channels, out_channels]
11 | FILTER_1 = [5, 5, 1, 64]
12 | STRIDE_1 = 2
13 | PAD_1 = "SAME"
14 |
15 | #layer_2 input_size 33*33
16 | FILTER_2 = [5, 5, 64, 128]
17 | STRIDE_2 = 2
18 | PAD_2 = "SAME"
19 |
20 | #layer_3 input_size 17*17
21 | FILTER_3 = [5, 5, 128, 256]
22 | STRIDE_3 = 2
23 | PAD_3 = "SAME"
24 |
25 | #layer_4 input_size 9*9
26 | FILTER_4 = [9, 9, 256, 256]
27 | STRIDE_4 = 1
28 | PAD_4 = "VALID"
29 |
30 | #layer_5 input_size 1*1*512 (the two 1*1*256 branches concatenated)
31 | FILTER_5 = [1, 1, 512, 1]
32 | STRIDE_5 = 1
33 | PAD_5 = "VALID"
34 |
35 | class DecompositionModel(object):
36 | def __init__(self, sess, epoch, batch_size = 64, model_name = 'model'):
37 | self.sess = sess
38 | self.epoch = epoch
39 | self.input_width = INPUT_W
40 | self.input_height = INPUT_H
41 | self.batch_size = batch_size
42 | self.model_name = model_name
43 |
44 | def feedforward(self, input_L_batch, input_H_batch, regularizer = 0):
45 | layer_L1 = layer_xyf.convo(input_L_batch, "conv_L1", FILTER_1, STRIDE_1, PAD_1)
46 | layer_L2 = layer_xyf.convo(layer_L1, "conv_L2", FILTER_2, STRIDE_2, PAD_2)
47 | layer_L3 = layer_xyf.convo(layer_L2, "conv_L3", FILTER_3, STRIDE_3, PAD_3)
48 | layer_L4 = layer_xyf.convo(layer_L3, "conv_L4", FILTER_4, STRIDE_4, PAD_4)
49 |
50 | layer_H1 = layer_xyf.convo(input_H_batch, "conv_H1", FILTER_1, STRIDE_1, PAD_1)
51 | layer_H2 = layer_xyf.convo(layer_H1, "conv_H2", FILTER_2, STRIDE_2, PAD_2)
52 | layer_H3 = layer_xyf.convo(layer_H2, "conv_H3", FILTER_3, STRIDE_3, PAD_3)
53 | layer_H4 = layer_xyf.convo(layer_H3, "conv_H4", FILTER_4, STRIDE_4, PAD_4)
54 |
55 | combine_LH = tf.concat([layer_L4, layer_H4], 3)
56 |
57 | layer_bone = layer_xyf.convo_noneRelu(combine_LH, "conv_bone", FILTER_5, STRIDE_5, PAD_5)
58 | layer_tissue = layer_xyf.convo_noneRelu(combine_LH, "conv_tissue", FILTER_5, STRIDE_5, PAD_5)
59 | return layer_bone, layer_tissue
60 |
61 | def feedforward_test(self, input_L_batch, input_H_batch, regularizer = 0):
62 | layer_L1 = layer_xyf.convo(input_L_batch, "conv_L1", FILTER_1, 1, PAD_1)
63 | layer_L2 = layer_xyf.convo(layer_L1, "conv_L2", FILTER_2, 1, PAD_2)
64 | layer_L3 = layer_xyf.convo(layer_L2, "conv_L3", FILTER_3, 1, PAD_3)
65 | layer_L4 = layer_xyf.convo(layer_L3, "conv_L4", FILTER_4, STRIDE_4, PAD_4)
66 |
67 | layer_H1 = layer_xyf.convo(input_H_batch, "conv_H1", FILTER_1, 1, PAD_1)
68 | layer_H2 = layer_xyf.convo(layer_H1, "conv_H2", FILTER_2, 1, PAD_2)
69 | layer_H3 = layer_xyf.convo(layer_H2, "conv_H3", FILTER_3, 1, PAD_3)
70 | layer_H4 = layer_xyf.convo(layer_H3, "conv_H4", FILTER_4, STRIDE_4, PAD_4)
71 |
72 | combine_LH = tf.concat([layer_L4, layer_H4], 3)
73 |
74 | layer_bone = layer_xyf.convo_noneRelu(combine_LH, "conv_bone", FILTER_5, STRIDE_5, "SAME")
75 | layer_tissue = layer_xyf.convo_noneRelu(combine_LH, "conv_tissue", FILTER_5, STRIDE_5, "SAME")
76 | return layer_bone, layer_tissue
77 |
78 | def input_layer_width(self):
79 | return self.input_width
80 |
81 | def input_layer_height(self):
82 | return self.input_height
83 |
--------------------------------------------------------------------------------
/src/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.io
3 | import time
4 | import tensorflow as tf
5 | import numpy as np
6 | import math
7 |
8 | LEARNING_RATE_DECAY = 0.85
9 | MOVING_AVERAGE_DECAY = 0.99
10 | REGULARAZTION_RATE = 0.0001
11 |
12 | NUM_SAMPLE = 2454300
13 |
14 | SUMMARY_DIR = "../log"
15 | SUMMARY_STEP = 100
16 |
17 | INPUT_H = 65
18 | INPUT_W = 65
19 | INPUT_CHANNEL = 1
20 |
21 | OUTPUT_H = 1
22 | OUTPUT_W = 1
23 | OUTPUT_CHANNEL = 1
24 |
25 | def nn_train(nn, sess, config):
26 | # read data
27 | print ("Train, loading .tfrecords data...", config.dataset)
28 | if not os.path.exists(config.dataset):
29 | raise Exception("training data not find.")
30 | return
31 |
32 | config.sampleNum = NUM_SAMPLE
33 | config.iteration = math.ceil( config.sampleNum / config.batch_size * config.epoch )
34 |
35 | time_str = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime(time.time()))
36 |
37 | train_file_path = os.path.join(config.dataset, "train_sample_batches_*")
38 | train_files = tf.train.match_filenames_once(train_file_path)
39 | filename_queue = tf.train.string_input_producer(train_files, shuffle = True)
40 |
41 | reader = tf.TFRecordReader()
42 | _, serialized_example = reader.read(filename_queue)
43 | features = tf.parse_single_example(serialized_example,
44 | features = {'train_x_low': tf.FixedLenFeature([], tf.string),
45 | 'train_x_high': tf.FixedLenFeature([], tf.string),
46 | 'train_y_bone': tf.FixedLenFeature([], tf.string),
47 | 'train_y_tissue': tf.FixedLenFeature([], tf.string)})
48 |
49 | train_x_low = tf.decode_raw(features['train_x_low'], tf.float32)
50 | train_x_high = tf.decode_raw(features['train_x_high'], tf.float32)
51 | train_y_bone = tf.decode_raw(features['train_y_bone'], tf.float32)
52 | train_y_tissue = tf.decode_raw(features['train_y_tissue'], tf.float32)
53 |
54 | train_x_low = tf.reshape(train_x_low, [INPUT_H, INPUT_W, INPUT_CHANNEL])
55 | train_x_high = tf.reshape(train_x_high, [INPUT_H, INPUT_W, INPUT_CHANNEL])
56 | train_y_bone = tf.reshape(train_y_bone, [OUTPUT_H, OUTPUT_W, OUTPUT_CHANNEL])
57 | train_y_tissue = tf.reshape(train_y_tissue, [OUTPUT_H, OUTPUT_W, OUTPUT_CHANNEL])
58 |
59 | train_x_low_batch, train_x_high_batch, train_y_bone_batch, train_y_tissue_batch = tf.train.shuffle_batch([train_x_low, train_x_high, train_y_bone, train_y_tissue],
60 | batch_size = config.batch_size,
61 | capacity = config.batch_size*3 + 1000,
62 | min_after_dequeue = 1000)
63 | # loss function define
64 | ouput_bone_batch, ouput_tissue_batch = nn.feedforward(train_x_low_batch, train_x_high_batch, 0)
65 | global_step = tf.Variable(0, trainable = False)
66 | variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
67 | variables_averages_op = variable_averages.apply(tf.trainable_variables())
68 |
69 | mse_loss = tf.reduce_mean(tf.square(ouput_bone_batch - train_y_bone_batch) + tf.square(ouput_tissue_batch - train_y_tissue_batch))
70 | tf.add_to_collection('losses', mse_loss)
71 | tf.summary.scalar('MSE_losses', mse_loss)
72 |
73 | learning_rate = tf.train.exponential_decay(config.lr, global_step,
74 | config.sampleNum / config.batch_size, LEARNING_RATE_DECAY)
75 | train_step = tf.train.AdamOptimizer(learning_rate).minimize(mse_loss, global_step = global_step)
76 |
77 | with tf.control_dependencies([train_step, variables_averages_op]):
78 | train_op = tf.no_op(name = 'train')
79 |
80 | merged = tf.summary.merge_all()
81 | summary_path = os.path.join(SUMMARY_DIR, time_str)
82 |     os.makedirs(summary_path)
83 | summary_writer = tf.summary.FileWriter(summary_path, sess.graph)
84 |
85 | saver = tf.train.Saver()
86 | if config.goon:
87 | if not os.path.exists(config.checkpoint):
88 | raise Exception("checkpoint path not find.")
89 | return
90 | print("Loading trained model... ", config.checkpoint)
91 | ckpt = tf.train.get_checkpoint_state(config.checkpoint)
92 |         saver.restore(sess, ckpt.model_checkpoint_path)
93 | else:
94 | init = (tf.global_variables_initializer(), tf.local_variables_initializer())
95 | sess.run(init)
96 | #with tf.Session() as sess:
97 | #init = tf.global_variables_initializer()
98 |
99 | save_model_path = os.path.join(config.output_model_dir, time_str)
100 | os.mkdir(save_model_path)
101 |
102 | print("start training sess...")
103 | coord = tf.train.Coordinator()
104 | threads = tf.train.start_queue_runners(sess = sess, coord = coord)
105 |
106 | start_time = time.time()
107 | for i in range(config.iteration):
108 | batch_start_time = time.time()
109 | summary, _, loss_value, step, learningRate = sess.run([merged, train_op, mse_loss, global_step, learning_rate])
110 | batch_end_time = time.time()
111 | sec_per_batch = batch_end_time - batch_start_time
112 | if step % SUMMARY_STEP == 0:
113 | summary_writer.add_summary(summary, step//SUMMARY_STEP)
114 | if step % config.model_step == 0:
115 | print("Saving model (after %d iteration)... " %(step))
116 | saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
117 | print("sec/batch(%d) %gs, global step %d batches, training epoch %d/%d, learningRate %g, loss on training is %g"
118 | % (config.batch_size, sec_per_batch, step, i*config.batch_size / config.sampleNum,
119 | config.epoch, learningRate, loss_value))
120 | print("Elapsed time: %g" %(batch_end_time - start_time))
121 |
122 | coord.request_stop()
123 | coord.join(threads)
124 | summary_writer.close()
125 | print("Train done. ")
126 | print("Saving model... ", save_model_path)
127 | saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
128 |
129 | def nn_feedforward(nn, sess, config):
130 | print ("Feed forward, loading .mat data...", config.dataset)
131 | if not os.path.exists(config.dataset):
132 | raise Exception(".mat file not find.")
133 | return
134 |
135 | mat = scipy.io.loadmat(config.dataset)
136 | input_xl = mat['I_L']
137 | input_xh = mat['I_H']
138 |
139 | dim_h = input_xl.shape[0]
140 | dim_w = input_xl.shape[1]
141 | if input_xl.ndim < 3:
142 | input_xl = input_xl.reshape(dim_h, dim_w, 1)
143 | input_xh = input_xh.reshape(dim_h, dim_w, 1)
144 |
145 | batch_size = input_xl.shape[2]
146 | print('image num: ', batch_size)
147 | print('input type: ', input_xl.dtype)
148 |
149 |     x_L = tf.placeholder(tf.float32, [1, dim_h, dim_w, 1], name = 'xl-input')
150 |     x_H = tf.placeholder(tf.float32, [1, dim_h, dim_w, 1], name = 'xh-input')
151 | y_bone, y_tissue = nn.feedforward_test(x_L, x_H, None)
152 |
153 | output_dim = 0
154 | I_bone = np.array([], dtype = np.float32)
155 | I_tissue = np.array([], dtype = np.float32)
156 |
157 | variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
158 | variables_to_restore = variable_averages.variables_to_restore()
159 | saver = tf.train.Saver(variables_to_restore)
160 |
161 | ckpt = tf.train.get_checkpoint_state(config.checkpoint)
162 | if ckpt and ckpt.model_checkpoint_path:
163 | saver.restore(sess, ckpt.model_checkpoint_path)
164 |         global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
165 | print('global step: ', global_step)
166 | for i in range(batch_size):
167 |             xl = input_xl[:,:,i].reshape([1, dim_h, dim_w, 1])
168 |             xh = input_xh[:,:,i].reshape([1, dim_h, dim_w, 1])
169 | print('xl: ', xl.shape)
170 | output_bone, output_tissue = sess.run([y_bone, y_tissue], feed_dict={x_L: xl, x_H: xh})
171 | print('TF Done')
172 | if i == 0:
173 | output_shape = output_bone.shape
174 | print("output shape: ", output_shape)
175 | output_dim = output_bone.shape[1]
176 | I_bone = np.zeros((output_dim, output_dim, batch_size), dtype = np.float32)
177 | I_tissue = np.zeros((output_dim, output_dim, batch_size), dtype = np.float32)
178 |
179 | output_bone.resize(output_dim, output_dim)
180 | output_tissue.resize(output_dim, output_dim)
181 | I_bone[:,:,i] = output_bone
182 | I_tissue[:,:,i] = output_tissue
183 |
184 | result_fileName = config.output_mat_dir + '/' + 'result_' \
185 | + config.model_name + '.mat'
186 | scipy.io.savemat(result_fileName, {'I_bone':I_bone, 'I_tissue':I_tissue})
187 | print('Feed forward done. Result: ', result_fileName)
188 | else:
189 | print('No checkpoint file found.')
190 | return
191 |
--------------------------------------------------------------------------------