├── README.md
├── data
│   ├── system_matrix
│   │   └── Read me
│   ├── testing_set
│   │   └── Read me
│   └── training_set
│       └── Read me
├── log
│   └── Read me
├── model
│   └── Read me
├── result
│   └── Read me
└── src
    ├── FBP_A.m
    ├── FBP_reconstruction.m
    ├── cell.py
    ├── combinedGradient.m
    ├── combinedGradientDescart.m
    ├── config.py
    ├── iterative_reconstruction.m
    ├── layer_xyf.py
    ├── main.py
    ├── model.py
    ├── solver.py
    └── sysMatrix_2D_all_angle.m
/README.md:
--------------------------------------------------------------------------------
1 | # DualReconstruction
2 |
3 | This repository contains an image decomposition algorithm proposed to solve the image reconstruction and image-domain decomposition problem in dual-energy computed tomography (DECT).
4 |
5 | The algorithm is based on the deep learning paradigm. For theoretical background, see [Deep Learning](http://www.deeplearningbook.org/) and [Material Decomposition Using DECT](https://pubs.rsna.org/doi/10.1148/rg.2016150220).
6 |
7 | The code is based on Python 3.6, [TensorFlow](https://github.com/tensorflow/tensorflow) 1.4.0 and [ODL](https://github.com/odlgroup/odl), and runs on the Windows 7 platform.
8 |
9 | * data: contains three subdirectories:
10 |   * system_matrix: the system matrix used by the reconstruction algorithms. Generate it by running 'iterative_reconstruction.m' or 'FBP_reconstruction.m' in the src directory.
11 |   * testing_set: the data used for testing.
12 |   * training_set: the data used for training the deep model. Both the training and testing sets can be downloaded from [here](https://pan.baidu.com/s/1VfhTuNenuy2C6HAw1aWbZA) (extraction code: t4ya).
13 | * log: stores the TensorFlow log files written during training.
14 | * model: stores the trained models.
15 | * result: stores the results generated by the reconstruction algorithms.
16 | * src: the code for the proposed algorithm and two competing methods:
17 |   * Filtered back-projection (FBP) followed by direct matrix inversion (FBP_reconstruction.m).
18 |   * Combined iterative reconstruction and image decomposition (iterative_reconstruction.m). Related paper: [Combined iterative reconstruction and image-domain decomposition for dual energy CT using total-variation regularization.](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1118/1.4870375)
19 |   * The proposed deep model (main.py). To train it, run:
20 |     ```bash
21 |     python main.py --dataset="../data/training_set/" --mode="train" --model_name="your-saved-model-result-name" --lr=0.0001 --epoch=30 --model_step=1000 --batch_size=1
22 |     ```
23 |     After training finishes, run the trained model on the testing set with:
24 |     ```bash
25 |     python main.py --dataset="../data/testing_set/" --mode="feedforward" --model_name="your-saved-model-result-name" --checkpoint="../model/your-saved-model-result-name"
26 |     ```
27 |
28 | # Contact
29 | Email: vastcyclone@yeah.net
30 |
--------------------------------------------------------------------------------
/data/system_matrix/Read me:
--------------------------------------------------------------------------------
1 | This directory stores the system matrix used by the reconstruction algorithms.
2 | No download link is provided because the matrix is about 2 GB for a 256 x 256 reconstruction problem.
3 | Please generate the system matrix locally by running 'iterative_reconstruction.m' or 'FBP_reconstruction.m' in the src directory.
4 | Do this before training the deep neural network.
5 |
--------------------------------------------------------------------------------
/data/testing_set/Read me:
--------------------------------------------------------------------------------
1 | Please download the testing set from https://pan.baidu.com/s/1VfhTuNenuy2C6HAw1aWbZA (extraction code: t4ya) and unzip it into this directory.
2 |
--------------------------------------------------------------------------------
/data/training_set/Read me:
--------------------------------------------------------------------------------
1 | Please download the training set from https://pan.baidu.com/s/1VfhTuNenuy2C6HAw1aWbZA (extraction code: t4ya) and unzip it into this directory.
2 |
--------------------------------------------------------------------------------
/log/Read me:
--------------------------------------------------------------------------------
1 | This directory stores the TensorFlow log files written during training.
2 |
--------------------------------------------------------------------------------
/model/Read me:
--------------------------------------------------------------------------------
1 | This directory stores the trained models in TensorFlow checkpoint (.ckpt) format.
2 |
--------------------------------------------------------------------------------
/result/Read me:
--------------------------------------------------------------------------------
1 | This directory stores the test results (.mat files) generated by the algorithms.
2 |
--------------------------------------------------------------------------------
/src/FBP_A.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/DualReconstruction/65e95bfb54d0f555c32d8c92417f6f3a63eb5078/src/FBP_A.m
--------------------------------------------------------------------------------
/src/FBP_reconstruction.m:
--------------------------------------------------------------------------------
1 | close all;
2 | clear all;
3 | clc;
4 |
5 | test_data_path = '..\data\testing_set\';
6 | result_path = '..\result\';
7 |
8 | N = 256;
9 | prjLen = 1024;
10 |
11 | opts.Nx = N; % Size of the object
12 | opts.Ny = N; %
13 | opts.sod = 1000; % Distance from source to object (mm)
14 | opts.sdd = 1500; % Distance from source to detector
15 | opts.dt = 0.388; % Size of detector voxel
16 | opts.Uy = prjLen; % Number of projections
17 | opts.voxel = 1.0; % opts.voxel = opts.sod/opts.sdd*opts.dt*opts.Uy/opts.Nx;
18 | opts.Nz = 1; %
19 | opts.Vz = 1; %
20 | angcov = 360; view_num = 360; % 360 projection views over a full rotation
21 | angstp = angcov / view_num; % 1-degree angular step
22 | theta_vec = 0:angstp:angcov - 1;
23 |
24 | systemMatrix_file = '..\data\A_256.mat';
25 | if exist(systemMatrix_file, 'file')
26 | disp('Loading sysMatrix...');
27 | load(systemMatrix_file);
28 | else
29 | tic;
30 | disp('sysMatrix_2D_all_angle...');
31 | [W_row, W_col, W_val, sumP2R, sumC, sumR, Row, Col, Val_Num] = sysMatrix_2D_all_angle(theta_vec,opts);
32 | disp('A sparse...');
33 | A = sparse(W_row, W_col, W_val, Row, Col);
34 | clear W_row;
35 | clear W_col;
36 | clear W_val;
37 | toc;
38 | disp('Saving system matrix...');
39 | save(systemMatrix_file, '-v7.3', 'A');
40 | end
41 |
42 | %% Decomposition coefficients: [d1; d2] = inv([x1H x2H; x1L x2L]) * [rh; rl]
43 | x1H = 0.0342;
44 | x1L = 0.0588;
45 | x2H = 0.019;
46 | x2L = 0.0251;
47 | a = x2L/(x1H*x2L - x2H*x1L);
48 | b = -1*x2H/(x1H*x2L - x2H*x1L);
49 | c = -1*x1L/(x1H*x2L - x2H*x1L);
50 | d = x1H/(x1H*x2L - x2H*x1L);
51 |
52 | mat_files = dir([test_data_path, '*.mat']); % list only .mat files, so '.', '..' and 'Read me' are skipped
53 |
54 | fopt.angstp = 1;
55 | fopt.angcov = 360;
56 | fopt.voxel = 1.03;
57 | fopt.filter = 1; % 1 - Ram-Lak; 2 - Shepp-Logan; 3 - cosine; 4 - Hamming; 5 - Hann
58 |
59 | for f = 1:length(mat_files)
60 | file = [test_data_path, mat_files(f).name];
61 | load(file);
62 | [view_num, prjLen, img_num] = size(mh);
63 |
64 | rh = zeros(N, N, img_num);
65 | rl = zeros(N, N, img_num);
66 | d1 = zeros(N, N, img_num);
67 | d2 = zeros(N, N, img_num);
68 |
69 | disp(file);
70 | for n = 1:img_num
71 | tic;
72 | disp( sprintf('Image reconstruction %01d / %d ', n, img_num) );
73 |
74 | xh = FBP_A(A,double(mh(:,:,n)'),fopt);
75 | xl = FBP_A(A,double(ml(:,:,n)'),fopt);
76 |
77 | rh(:,:,n) = reshape(xh,N,N);
78 | rl(:,:,n) = reshape(xl,N,N);
79 | d1(:,:,n) = a*rh(:,:,n) + b*rl(:,:,n);
80 | d2(:,:,n) = c*rh(:,:,n) + d*rl(:,:,n);
81 | toc;
82 | end
83 | savefileName = [result_path, sprintf('result_test_%04d.mat', f) ];
84 | disp(['Saving file: ', savefileName]);
85 | save(savefileName, 'rh', 'rl', 'd1', 'd2');
86 | end
87 |
88 |
--------------------------------------------------------------------------------
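The coefficients `a`, `b`, `c`, `d` used by both MATLAB scripts are the entries of the inverse of the 2 x 2 matrix `[x1H x2H; x1L x2L]`, so `d1` and `d2` undo a linear mixing of the form `rh = x1H*d1 + x2H*d2`, `rl = x1L*d1 + x2L*d2`. That mixing model is inferred from the formulas rather than stated in the repository. The NumPy sketch below checks the inversion numerically; the helper names `decomposition_coefficients` and `decompose` are hypothetical.

```python
import numpy as np

# Constants copied from FBP_reconstruction.m / iterative_reconstruction.m.
x1H, x1L, x2H, x2L = 0.0342, 0.0588, 0.019, 0.0251


def decomposition_coefficients(x1H, x1L, x2H, x2L):
    """Return (a, b, c, d) exactly as the MATLAB scripts compute them."""
    det = x1H * x2L - x2H * x1L
    return x2L / det, -x2H / det, -x1L / det, x1H / det


def decompose(rh, rl, coeffs):
    """Map high/low-energy images to basis images (d1, d2), as in the reconstruction loop."""
    a, b, c, d = coeffs
    return a * rh + b * rl, c * rh + d * rl


coeffs = decomposition_coefficients(x1H, x1L, x2H, x2L)
mixing = np.array([[x1H, x2H], [x1L, x2L]])   # inferred model: [rh; rl] = mixing @ [d1; d2]
inverse = np.array(coeffs).reshape(2, 2)      # [[a, b], [c, d]]

# Round trip on a tiny synthetic pair of basis images.
d1_true = np.array([[1.0, 0.0], [0.5, 0.2]])
d2_true = np.array([[0.0, 1.0], [0.3, 0.8]])
rh = x1H * d1_true + x2H * d2_true
rl = x1L * d1_true + x2L * d2_true
d1, d2 = decompose(rh, rl, coeffs)
print(np.allclose(inverse @ mixing, np.eye(2)), np.allclose(d1, d1_true), np.allclose(d2, d2_true))
# prints: True True True
```

Because the decomposition is a fixed per-pixel linear map, it can be applied to whole reconstructed images at once, which is what the `d1(:,:,n) = a*rh(:,:,n) + b*rl(:,:,n)` lines do.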
/src/cell.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | class ConvLSTMCell(tf.nn.rnn_cell.RNNCell):
4 |     """An LSTM cell with convolutions instead of multiplications.
5 |
6 |     Reference:
7 |       Shi, X., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015.
8 |     """
9 |
10 |     def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=True, peephole=True, data_format='channels_last', reuse=None):
11 |         super(ConvLSTMCell, self).__init__(_reuse=reuse)
12 |         self._kernel = kernel
13 |         self._filters = filters
14 |         self._forget_bias = forget_bias
15 |         self._activation = activation
16 |         self._normalize = normalize
17 |         self._peephole = peephole
18 |         if data_format == 'channels_last':
19 |             self._size = tf.TensorShape(shape + [self._filters])
20 |             self._feature_axis = self._size.ndims
21 |             self._data_format = None
22 |         elif data_format == 'channels_first':
23 |             self._size = tf.TensorShape([self._filters] + shape)
24 |             self._feature_axis = 1  # channel axis of the batched [N, C, ...] tensor
25 |             self._data_format = 'NC'
26 |         else:
27 |             raise ValueError('Unknown data_format')
28 |
29 |     @property
30 |     def state_size(self):
31 |         return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size)
32 |
33 |     @property
34 |     def output_size(self):
35 |         return self._size
36 |
37 |     def call(self, x, state):
38 |         c, h = state
39 |
40 |         x = tf.concat([x, h], axis=self._feature_axis)
41 |         n = x.shape[self._feature_axis].value
42 |         m = 4 * self._filters if self._filters > 1 else 4
43 |         W = tf.get_variable('kernel', self._kernel + [n, m])
44 |         y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format)
45 |         if not self._normalize:
46 |             y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
47 |         j, i, f, o = tf.split(y, 4, axis=self._feature_axis)
48 |
49 |         if self._peephole:
50 |             i += tf.get_variable('W_ci', c.shape[1:]) * c
51 |             f += tf.get_variable('W_cf', c.shape[1:]) * c
52 |
53 |         if self._normalize:
54 |             j = tf.contrib.layers.layer_norm(j)
55 |             i = tf.contrib.layers.layer_norm(i)
56 |             f = tf.contrib.layers.layer_norm(f)
57 |
58 |         f = tf.sigmoid(f + self._forget_bias)
59 |         i = tf.sigmoid(i)
60 |         c = c * f + i * self._activation(j)
61 |
62 |         if self._peephole:
63 |             o += tf.get_variable('W_co', c.shape[1:]) * c
64 |
65 |         if self._normalize:
66 |             o = tf.contrib.layers.layer_norm(o)
67 |             c = tf.contrib.layers.layer_norm(c)
68 |
69 |         o = tf.sigmoid(o)
70 |         h = o * self._activation(c)
71 |
72 |         state = tf.nn.rnn_cell.LSTMStateTuple(c, h)
73 |
74 |         return h, state
75 |
76 |
77 | class ConvGRUCell(tf.nn.rnn_cell.RNNCell):
78 |     """A GRU cell with convolutions instead of multiplications."""
79 |
80 |     def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True, data_format='channels_last', reuse=None):
81 |         super(ConvGRUCell, self).__init__(_reuse=reuse)
82 |         self._filters = filters
83 |         self._kernel = kernel
84 |         self._activation = activation
85 |         self._normalize = normalize
86 |         if data_format == 'channels_last':
87 |             self._size = tf.TensorShape(shape + [self._filters])
88 |             self._feature_axis = self._size.ndims
89 |             self._data_format = None
90 |         elif data_format == 'channels_first':
91 |             self._size = tf.TensorShape([self._filters] + shape)
92 |             self._feature_axis = 1  # channel axis of the batched [N, C, ...] tensor
93 |             self._data_format = 'NC'
94 |         else:
95 |             raise ValueError('Unknown data_format')
96 |
97 |     @property
98 |     def state_size(self):
99 |         return self._size
100 |
101 |     @property
102 |     def output_size(self):
103 |         return self._size
104 |
105 |     def call(self, x, h):
106 |         channels = x.shape[self._feature_axis].value
107 |
108 |         with tf.variable_scope('gates'):
109 |             inputs = tf.concat([x, h], axis=self._feature_axis)
110 |             n = channels + self._filters
111 |             m = 2 * self._filters if self._filters > 1 else 2
112 |             W = tf.get_variable('kernel', self._kernel + [n, m])
113 |             y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
114 |             if self._normalize:
115 |                 r, u = tf.split(y, 2, axis=self._feature_axis)
116 |                 r = tf.contrib.layers.layer_norm(r)
117 |                 u = tf.contrib.layers.layer_norm(u)
118 |             else:
119 |                 y += tf.get_variable('bias', [m], initializer=tf.ones_initializer())
120 |                 r, u = tf.split(y, 2, axis=self._feature_axis)
121 |             r, u = tf.sigmoid(r), tf.sigmoid(u)
122 |
123 |         with tf.variable_scope('candidate'):
124 |             inputs = tf.concat([x, r * h], axis=self._feature_axis)
125 |             n = channels + self._filters
126 |             m = self._filters
127 |             W = tf.get_variable('kernel', self._kernel + [n, m])
128 |             y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format)
129 |             if self._normalize:
130 |                 y = tf.contrib.layers.layer_norm(y)
131 |             else:
132 |                 y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer())
133 |             h = u * h + (1 - u) * self._activation(y)
134 |
135 |         return h, h
136 |
--------------------------------------------------------------------------------
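`ConvLSTMCell.call` concatenates the input and the hidden state along the channel axis, runs a single 'SAME' convolution that produces the four gate pre-activations in the order j, i, f, o, and then applies the usual LSTM update with `forget_bias` added to the forget gate. The NumPy sketch below reproduces one step under simplifying assumptions: `channels_last`, no batch axis, and the `normalize=False, peephole=False` path (layer norm and peepholes are omitted). The function names are hypothetical and not part of the repository.

```python
import numpy as np


def conv2d_same(x, w):
    """'SAME' 2-D cross-correlation with stride 1.

    x: [H, W, C_in]; w: [kh, kw, C_in, C_out] with odd kh and kw.
    """
    kh, kw = w.shape[:2]
    height, width = x.shape[:2]
    xp = np.pad(x, ((kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))  # zero padding
    out = np.zeros((height, width, w.shape[3]))
    for dy in range(kh):
        for dx in range(kw):
            window = xp[dy:dy + height, dx:dx + width, :]
            out += np.einsum('hwc,co->hwo', window, w[dy, dx])
    return out


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def conv_lstm_step(x, c, h, kernel, bias, forget_bias=1.0):
    """One ConvLSTM step returning the new (c, h); gate order j, i, f, o as in cell.py."""
    y = conv2d_same(np.concatenate([x, h], axis=-1), kernel) + bias
    j, i, f, o = np.split(y, 4, axis=-1)
    c_new = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h_new = sigmoid(o) * np.tanh(c_new)
    return c_new, h_new


rng = np.random.default_rng(0)
filters = 2
x = rng.standard_normal((4, 4, 1))
c = rng.standard_normal((4, 4, filters))
h = np.zeros((4, 4, filters))
kernel = 0.1 * rng.standard_normal((3, 3, 1 + filters, 4 * filters))
bias = np.zeros(4 * filters)
c1, h1 = conv_lstm_step(x, c, h, kernel, bias)
print(c1.shape, h1.shape)  # (4, 4, 2) (4, 4, 2)
```

With an all-zero kernel and bias every gate pre-activation is 0, so the cell state is simply scaled by `sigmoid(forget_bias)`; this is a quick sanity check on the update order.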
/src/combinedGradient.m:
--------------------------------------------------------------------------------
1 | function [g1,g2] = combinedGradient(f1,f2,a,b)
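2 | %% Gradient of the isotropic TV of u = a*f1 + b*f2 (backward differences,
3 | %  zero padding) with respect to f1 (returned as g1) and f2 (returned as g2).
4 | %%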
5 | [row1,col1] = size(f1);
6 | [row2,col2] = size(f2);
7 |
8 | ff1 = zeros(row1+4,col1+4);
9 | ff2 = zeros(row2+4,col2+4);
10 |
11 | ff1(3:end-2,3:end-2) = f1;
12 | ff2(3:end-2,3:end-2) = f2;
13 |
14 | g1 = zeros(size(f1));
15 | g2 = zeros(size(f2));
16 |
17 | if row1~=row2 || col1~=col2
18 | error('Sizes of f1 and f2 do not match.');
19 | end
20 |
21 | for i=3:(row1+2)
22 | for j=3:(col1+2)
23 | %%%
24 | v1 = a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i-1,j)-b*ff2(i-1,j) ) + a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j-1)-b*ff2(i,j-1) );
25 | v2 = a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i+1,j)-b*ff2(i+1,j) );
26 | v3 = a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j+1)-b*ff2(i,j+1) );
27 | %%%%
28 | t1 = b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i-1,j)-b*ff2(i-1,j) ) + b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j-1)-b*ff2(i,j-1) );
29 | t2 = b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i+1,j)-b*ff2(i+1,j) );
30 | t3 = b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j+1)-b*ff2(i,j+1) );
31 | %%%
32 | denom1 = sqrt( eps+( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i-1,j)-b*ff2(i-1,j) )^2+( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j-1)-b*ff2(i,j-1) )^2 );
33 | denom2 = sqrt( eps+( a*ff1(i+1,j)+b*ff2(i+1,j)-a*ff1(i,j)-b*ff2(i,j) )^2+( a*ff1(i+1,j)+b*ff2(i+1,j)-a*ff1(i+1,j-1)-b*ff2(i+1,j-1) )^2 );
34 | denom3 = sqrt( eps+( a*ff1(i,j+1)+b*ff2(i,j+1)-a*ff1(i-1,j+1)-b*ff2(i-1,j+1) )^2+( a*ff1(i,j+1)+b*ff2(i,j+1)-a*ff1(i,j)-b*ff2(i,j) )^2 );
35 |
36 | g1(i-2,j-2) = v1/denom1+v2/denom2+v3/denom3;
37 | g2(i-2,j-2) = t1/denom1+t2/denom2+t3/denom3;
38 | end
39 | end
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 | end
54 |
--------------------------------------------------------------------------------
/src/combinedGradientDescart.m:
--------------------------------------------------------------------------------
1 | function [g] = combinedGradientDescart(f1,f2,beta1,beta2,a,b,c,d)
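2 | %% Gradient of beta1*TV(a*f1 + b*f2) + beta2*TV(c*f1 + d*f2) with respect to
3 | %  [f1; f2], stacked as one column vector to match x = [xh; xl].
4 | %%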
5 |
6 | [gh1,gl1] = combinedGradient(f1,f2,a,b);
7 | [gh2,gl2] = combinedGradient(f1,f2,c,d);
8 | g = [beta1*gh1(:)+beta2*gh2(:);beta1*gl1(:)+beta2*gl2(:)];
9 |
10 | end
11 |
--------------------------------------------------------------------------------
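`combinedGradient` differentiates an isotropic total-variation term of the combined image u = a*f1 + b*f2: each pixel collects the contribution of its own term (the `v1`/`t1` parts) and of the terms belonging to its lower (`v2`/`t2`) and right (`v3`/`t3`) neighbours, and the chain rule scales the result by `a` for `f1` and by `b` for `f2`. `combinedGradientDescart` then sums two such gradients weighted by `beta1` and `beta2`. The vectorised NumPy port below was written from the MATLAB loops, not from the referenced paper; the function names are hypothetical.

```python
import numpy as np

EPS = np.finfo(float).eps  # same smoothing constant as MATLAB's eps


def _differences(u):
    """Backward differences of u on a grid padded with one ring of zeros."""
    up = np.pad(u, 1)                 # (N+2) x (M+2)
    dx = up[1:, 1:] - up[:-1, 1:]     # u(i,j) - u(i-1,j), shape (N+1) x (M+1)
    dy = up[1:, 1:] - up[1:, :-1]     # u(i,j) - u(i,j-1)
    return dx, dy


def tv(u):
    """Smoothed isotropic TV whose gradient combinedGradient.m computes."""
    dx, dy = _differences(u)
    return np.sum(np.sqrt(EPS + dx**2 + dy**2))


def combined_gradient(f1, f2, a, b):
    """Gradient of tv(a*f1 + b*f2) with respect to f1 and f2."""
    if f1.shape != f2.shape:
        raise ValueError("Sizes of f1 and f2 do not match.")
    dx, dy = _differences(a * f1 + b * f2)
    w = 1.0 / np.sqrt(EPS + dx**2 + dy**2)
    px, py = dx * w, dy * w
    # own term, minus the lower-neighbour and right-neighbour terms
    gu = (px + py)[:-1, :-1] - px[1:, :-1] - py[:-1, 1:]
    return a * gu, b * gu


def combined_gradient_descart(f1, f2, beta1, beta2, a, b, c, d):
    """Gradient of beta1*tv(a*f1 + b*f2) + beta2*tv(c*f1 + d*f2), stacked as [g_f1; g_f2]."""
    gh1, gl1 = combined_gradient(f1, f2, a, b)
    gh2, gl2 = combined_gradient(f1, f2, c, d)
    # order='F' mirrors MATLAB's column-major gh1(:)
    return np.concatenate([(beta1 * gh1 + beta2 * gh2).ravel(order='F'),
                           (beta1 * gl1 + beta2 * gl2).ravel(order='F')])


rng = np.random.default_rng(1)
f1, f2 = rng.random((5, 6)), rng.random((5, 6))
g1, g2 = combined_gradient(f1, f2, 1.3, -0.7)
print(g1.shape, g2.shape)  # (5, 6) (5, 6)
```

A central finite difference of `tv` with respect to a single pixel is a convenient way to confirm that the closed-form gradient matches the functional.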
/src/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 |
4 | def config_par(config):
5 |     config.DEFINE_string("output_model_dir", "../model", "")
6 |     config.DEFINE_string("output_data_dir", "../result", "")
7 |     config.DEFINE_string("model_name", "lstm", "")
8 |     config.DEFINE_integer("sampleNum", 3226, "")
9 |     config.DEFINE_integer("iteration", math.ceil( config.FLAGS.sampleNum / config.FLAGS.batch_size * config.FLAGS.epoch ), "")
10 |
11 |     config.DEFINE_string("summary_dir", "../log", "")
12 |     config.DEFINE_integer("summary_step", 20, "")
13 |
14 |     config.DEFINE_float("learning_rate_decay", 0.85, "")
15 |     config.DEFINE_float("moving_average_decay", 0.99, "")
16 |     config.DEFINE_float("regularazition_rate", 0.0001, "")
17 |
18 |     if not os.path.exists(config.FLAGS.output_model_dir):
19 |         os.makedirs(config.FLAGS.output_model_dir)
20 |     if not os.path.exists(config.FLAGS.output_data_dir):
21 |         os.makedirs(config.FLAGS.output_data_dir)
22 |     if not os.path.exists(config.FLAGS.summary_dir):
23 |         os.makedirs(config.FLAGS.summary_dir)
24 |
25 |     return config
--------------------------------------------------------------------------------
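`config_par` derives the total number of training iterations as `ceil(sampleNum / batch_size * epoch)`, so a partial batch is not rounded up separately in every epoch. The sketch below reproduces that arithmetic with the default `sampleNum` of 3226; `total_iterations_whole_batches` is a hypothetical alternative included only for comparison.

```python
import math


def total_iterations(sample_num, batch_size, epoch):
    """Same expression as the `iteration` flag in config.py."""
    return math.ceil(sample_num / batch_size * epoch)


def total_iterations_whole_batches(sample_num, batch_size, epoch):
    """Hypothetical alternative: round up to whole batches in every epoch."""
    return math.ceil(sample_num / batch_size) * epoch


# Training command from the README: batch_size=1, epoch=30.
print(total_iterations(3226, 1, 30))  # 96780
print(total_iterations(3226, 4, 30), total_iterations_whole_batches(3226, 4, 30))  # 24195 24210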
/src/iterative_reconstruction.m:
--------------------------------------------------------------------------------
1 | close all;
2 | clear all;
3 | clc;
4 |
5 | test_data_path = '..\data\testing_set\';
6 | result_path = '..\result\';
7 |
8 | N = 256;
9 | prjLen = 1024;
10 |
11 | opts.Nx = N; % Size of the object
12 | opts.Ny = N; %
13 | opts.sod = 1000; % Distance from source to object (mm)
14 | opts.sdd = 1500; % Distance from source to detector (mm)
15 | opts.dt = 0.388; % Detector element (bin) size
16 | opts.Uy = prjLen; % Number of detector bins per view
17 | opts.voxel = 1.03; % opts.voxel = opts.sod/opts.sdd*opts.dt*opts.Uy/opts.Nx;
18 | opts.Nz = 1; %
19 | opts.Vz = 1; %
20 | angcov = 360; view_num = 360; % 360 projection views over a full rotation
21 | angstp = angcov / view_num; % 1-degree angular step
22 | theta_vec = 0:angstp:angcov - 1;
23 |
24 | systemMatrix_file = '..\data\A_256.mat';
25 | if exist(systemMatrix_file, 'file')
26 | disp('Loading sysMatrix...');
27 | load(systemMatrix_file);
28 | else
29 | tic;
30 | disp('sysMatrix_2D_all_angle...');
31 | [W_row, W_col, W_val, sumP2R, sumC, sumR, Row, Col, Val_Num] = sysMatrix_2D_all_angle(theta_vec,opts);
32 | disp('A sparse...');
33 | A = sparse(W_row, W_col, W_val, Row, Col);
34 | clear W_row;
35 | clear W_col;
36 | clear W_val;
37 | toc;
38 | disp('Saving system matrix...');
39 | save(systemMatrix_file, '-v7.3', 'A');
40 | end
41 |
42 | Niter = 100;
43 | x1H = 0.0342;
44 | x1L = 0.0588;
45 | x2H = 0.019;
46 | x2L = 0.0251;
47 | a = x2L/(x1H*x2L - x2H*x1L);
48 | b = -1*x2H/(x1H*x2L - x2H*x1L);
49 | c = -1*x1L/(x1H*x2L - x2H*x1L);
50 | d = x1H/(x1H*x2L - x2H*x1L);
51 |
52 | mat_files = dir([test_data_path, '*.mat']); % list only .mat files, so '.', '..' and 'Read me' are skipped
53 |
54 | for f = 1:length(mat_files)
55 | file = [test_data_path, mat_files(f).name];
56 | load(file);
57 | [view_num, prjLen, img_num] = size(mh);
58 |
59 | rh = zeros(N, N, img_num);
60 | rl = zeros(N, N, img_num);
61 | d1 = zeros(N, N, img_num);
62 | d2 = zeros(N, N, img_num);
63 |
64 | disp(file);
65 | for n = 1:img_num
66 | tic
67 | disp( sprintf('Image reconstruction %01d / %d ', n, img_num) );
68 | ml_slice = ml(:,:,n)';
69 | mh_slice = mh(:,:,n)';
70 | ml_slice = double(ml_slice(:));
71 | mh_slice = double(mh_slice(:));
72 |
73 | xh = zeros(N*N,1);
74 | xl = zeros(N*N,1);
75 | xh_old = zeros(N*N,1);
76 | xl_old = zeros(N*N,1);
77 | x = zeros(2*N*N,1);
78 | px = zeros(2*N*N,1);
79 | px_old = zeros(2*N*N,1);
80 |
81 | beta1 = 0.003;
82 | beta2 = 0.0045;
83 | beta1_red = 0.99;
84 | beta2_red = 0.99;
85 |
86 | kai = 0.3;
87 | tol = eps;
88 |
89 | for i = 1:Niter
90 | x = [xh;xl];
91 | gtv = combinedGradientDescart(reshape(xh,N,N),reshape(xl,N,N),beta1,beta2,a,b,c,d);
92 | gquadh = A'*(A*double(xh) - mh_slice);
93 | gquadl = A'*(A*double(xl) - ml_slice);
94 | g = single(gtv + [gquadh(:);gquadl(:)]);
95 |
96 | px(:) = 0;
97 | px(g <= 0 | x > 0) = g(g <= 0 | x > 0);
98 | if i == 1
99 | xh_old = xh; % Previous iterate (all zeros at the first iteration)
100 | xl_old = xl;
101 | x_old = [xh_old;xl_old];
102 | px_old = px; % Previous projected gradient, needed by the Barzilai-Borwein step
103 | alpha = 10E-8;
104 | else
105 | alpha1 = (x - x_old)'*(x - x_old) / ((x - x_old)'*(px - px_old)); % BB1 step: s'*s / s'*y
106 | alpha2 = (x - x_old)'*(px - px_old) / ((px - px_old)'*(px - px_old)); % BB2 step: s'*y / y'*y
107 | if alpha2 / alpha1 < kai
108 | alpha = alpha1;
109 | else
110 | alpha = alpha2;
111 | end
112 | x_old = x;
113 | px_old = px;
114 | end
115 | x = x - alpha*px;
116 | x(x < 0) = 0;
117 | xh = x(1:N*N);
118 | xl = x((N*N + 1):end);
119 |
120 | if norm(x - x_old) <= tol
121 | break;
122 | end
123 | beta1 = beta1*beta1_red;
124 | beta2 = beta2*beta2_red;
125 | end
126 |
127 | rh(:,:,n) = reshape(xh,N,N);
128 | rl(:,:,n) = reshape(xl,N,N);
129 | d1(:,:,n) = a*rh(:,:,n) + b*rl(:,:,n);
130 | d2(:,:,n) = c*rh(:,:,n) + d*rl(:,:,n);
131 | % figure(1); imshow(rh(:,:,n),[]), title(['rh image: ', num2str(n)]);
132 | % figure(2); imshow(rl(:,:,n),[]), title(['rl image: ', num2str(n)]);
133 | % figure(3); imshow(d1(:,:,n),[]), title(['d1 image: ', num2str(n)]);
134 | % figure(4); imshow(d2(:,:,n),[]), title(['d2 image: ', num2str(n)]);
135 | toc;
136 | end
137 | savefileName = [result_path, sprintf('result_test_%04d.mat', f) ];
138 | disp(['Saving file: ', savefileName]);
139 | save(savefileName, 'rh', 'rl', 'd1', 'd2');
140 | end
141 |
142 |
--------------------------------------------------------------------------------
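Each inner iteration of `iterative_reconstruction.m` is a projected-gradient step for a nonnegativity-constrained problem: gradient components are zeroed where a pixel already sits at 0 and the gradient is positive, the step length is picked from the Barzilai-Borwein (BB) pair `alpha1 = s'*s / s'*y` and `alpha2 = s'*y / y'*y` with the `kai` ratio test (copied as written in the MATLAB code), and negative values are clipped after the step. The NumPy sketch below applies the same update to a tiny nonnegative least-squares problem `min 0.5*||A x - b||^2, x >= 0`, without the TV term. The test problems, the initial step `alpha0` and the `s'*y > 0` safeguard are assumptions added for the sketch.

```python
import numpy as np


def projected_bb(A, b, n_iter=200, alpha0=0.1, kai=0.3, tol=1e-12):
    """Projected gradient with BB step selection, mirroring the MATLAB inner loop."""
    x = np.zeros(A.shape[1])
    x_old = px_old = None
    alpha = alpha0
    for k in range(n_iter):
        g = A.T @ (A @ x - b)                          # gradient of the data term
        px = np.where((g <= 0) | (x > 0), g, 0.0)      # freeze x_i = 0 when g_i > 0
        if k > 0:
            s, y = x - x_old, px - px_old
            sy = s @ y
            if sy > 0:                                 # safeguard, not in the MATLAB code
                alpha1 = (s @ s) / sy                  # BB1
                alpha2 = sy / (y @ y)                  # BB2
                alpha = alpha1 if alpha2 / alpha1 < kai else alpha2
        x_old, px_old = x, px
        x = np.maximum(x - alpha * px, 0.0)            # step, then project onto x >= 0
        if np.linalg.norm(x - x_old) <= tol:
            break
    return x


A = np.diag([1.0, 2.0])
x_true = np.array([1.0, 1.0])
x = projected_bb(A, A @ x_true)
print(np.round(x, 6))  # [1. 1.]
```

When the unconstrained minimiser has a negative component, the mask keeps that component pinned at zero; for example `projected_bb(np.eye(2), np.array([1.0, -1.0]))` returns approximately `[1, 0]`.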
/src/layer_xyf.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Fri Dec 1 20:39:12 2017
4 |
5 | @author: XYF
6 | """
7 |
8 | import tensorflow as tf
9 |
10 | def variable_summaries(var, var_name):
11 |     with tf.name_scope(var_name):
12 |         mean = tf.reduce_mean(var)
13 |         tf.summary.scalar('mean', mean)
14 |         with tf.name_scope('stddev'):
15 |             stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
16 |         tf.summary.scalar('stddev', stddev)
17 |         tf.summary.scalar('max', tf.reduce_max(var))
18 |         tf.summary.scalar('min', tf.reduce_min(var))
19 |         tf.summary.histogram('histogram', var)
20 |
21 | def convo(input_layer, layer_name, filter_shape, stride, padStr):
22 |     with tf.variable_scope(layer_name):
23 |         input_shape = input_layer.get_shape().as_list()
24 |         print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
25 |
26 |         conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
27 |         conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
28 |         conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
29 |         relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b))
30 |     return relu
31 |
32 | def convo_sigmoid(input_layer, layer_name, filter_shape, stride, padStr):
33 |     with tf.variable_scope(layer_name):
34 |         input_shape = input_layer.get_shape().as_list()
35 |         print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
36 |
37 |         conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
38 |         conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
39 |         conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
40 |         sig = tf.nn.sigmoid(tf.nn.bias_add(conv, conv_b))
41 |     return sig
42 |
43 | def convo_noneRelu(input_layer, layer_name, filter_shape, stride, padStr):
44 |     with tf.variable_scope(layer_name):
45 |         input_shape = input_layer.get_shape().as_list()
46 |         print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
47 |
48 |         conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
49 |         conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0))
50 |         conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr)
51 |         none_relu = tf.nn.bias_add(conv, conv_b)
52 |
53 |         variable_summaries(conv_w, "conv_w")
54 |         variable_summaries(conv_b, "conv_b")
55 |         variable_summaries(none_relu, "none_relu")
56 |     return none_relu
57 |
58 | def deconvo(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
59 |     with tf.variable_scope(layer_name):
60 |         input_shape = input_layer.get_shape().as_list()
61 |         print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
62 |         conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
63 |         input_shape = tf.shape(input_layer)
64 |         output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
65 |         conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
66 |         conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
67 |         deconv = tf.nn.bias_add(conv, conv_b)
68 |     return deconv
69 |
70 | def deconv_withRelu(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
71 |     with tf.variable_scope(layer_name):
72 |         input_shape = input_layer.get_shape().as_list()
73 |         print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
74 |         conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
75 |         input_shape = tf.shape(input_layer)
76 |         output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
77 |         conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
78 |         conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
79 |         relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b))
80 |     return relu
81 |
82 | def deconv_withSigmoid(input_layer, layer_name, filter_shape, out_img_size, stride, padStr):
83 |     with tf.variable_scope(layer_name):
84 |         input_shape = input_layer.get_shape().as_list()
85 |         print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3])
86 |         conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1))
87 |         input_shape = tf.shape(input_layer)
88 |         output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]])
89 |         conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0))
90 |         conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr)
91 |         sig = tf.nn.sigmoid(tf.nn.bias_add(conv, conv_b))
92 |     return sig
93 |
94 | def pooling(input_layer, layer_name, kernal_shape, stride, padStr):
95 |     with tf.name_scope(layer_name):
96 |         pool = tf.nn.max_pool(input_layer, kernal_shape, [1,stride,stride,1], padStr)
97 |     return pool
98 |
99 | def pooling_withmax(input_layer, layer_name, kernal_shape, stride, padStr):
100 |     with tf.name_scope(layer_name):
101 |         return tf.nn.max_pool_with_argmax(input_layer, kernal_shape, [1,stride,stride,1], padStr)
102 |
103 | def unravel_argmax(argmax, shape):
104 |     output_list = []
105 |     output_list.append(argmax // (shape[2] * shape[3]))
106 |     output_list.append(argmax % (shape[2] * shape[3]) // shape[3])
107 |     return tf.stack(output_list)
108 |
109 | def unpooling_layer2x2(x, layer_name, raveled_argmax, out_shape):
110 |     with tf.name_scope(layer_name):
111 |         argmax = unravel_argmax(raveled_argmax, tf.to_int64(out_shape))
112 |         output = tf.zeros([out_shape[1], out_shape[2], out_shape[3]])
113 |
114 |         height = tf.shape(output)[0]
115 |         width = tf.shape(output)[1]
116 |         channels = tf.shape(output)[2]
117 |
118 |         t1 = tf.to_int64(tf.range(channels))
119 |         t1 = tf.tile(t1, [((width + 1) // 2) * ((height + 1) // 2)])
120 |         t1 = tf.reshape(t1, [-1, channels])
121 |         t1 = tf.transpose(t1, perm=[1, 0])
122 |         t1 = tf.reshape(t1, [channels, (height + 1) // 2, (width + 1) // 2, 1])
123 |
124 |         t2 = tf.squeeze(argmax)
125 |         t2 = tf.stack((t2[0], t2[1]), axis=0)
126 |         t2 = tf.transpose(t2, perm=[3, 1, 2, 0])
127 |
128 |         t = tf.concat([t2, t1], 3)
129 |         indices = tf.reshape(t, [((height + 1) // 2) * ((width + 1) // 2) * channels, 3])
130 |
131 |         x1 = tf.squeeze(x)
132 |         x1 = tf.reshape(x1, [-1, channels])
133 |         x1 = tf.transpose(x1, perm=[1, 0])
134 |         values = tf.reshape(x1, [-1])
135 |
136 |         delta = tf.SparseTensor(indices, values, tf.to_int64(tf.shape(output)))
137 |         return tf.expand_dims(tf.sparse_tensor_to_dense(tf.sparse_reorder(delta)), 0)
138 |
139 | def unpooling_layer2x2_batch(bottom, layer_name, argmax):
140 |     with tf.name_scope(layer_name):
141 |         bottom_shape = tf.shape(bottom)
142 |         top_shape = [bottom_shape[0], bottom_shape[1] * 2, bottom_shape[2] * 2, bottom_shape[3]]
143 |
144 |         batch_size = top_shape[0]
145 |         height = top_shape[1]
146 |         width = top_shape[2]
147 |         channels = top_shape[3]
148 |
149 |         argmax_shape = tf.to_int64([batch_size, height, width, channels])
150 |         argmax = unravel_argmax(argmax, argmax_shape)
151 |
152 |         t1 = tf.to_int64(tf.range(channels))
153 |         t1 = tf.tile(t1, [batch_size * (width // 2) * (height // 2)])
154 |         t1 = tf.reshape(t1, [-1, channels])
155 |         t1 = tf.transpose(t1, perm=[1, 0])
156 |         t1 = tf.reshape(t1, [channels, batch_size, height // 2, width // 2, 1])
157 |         t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4])
158 |
159 |         t2 = tf.to_int64(tf.range(batch_size))
160 |         t2 = tf.tile(t2, [channels * (width // 2) * (height // 2)])
161 |         t2 = tf.reshape(t2, [-1, batch_size])
162 |         t2 = tf.transpose(t2, perm=[1, 0])
163 |         t2 = tf.reshape(t2, [batch_size, channels, height // 2, width // 2, 1])
164 |
165 |         t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0])
166 |
167 |         t = tf.concat([t2, t3, t1], 4)
168 |         indices = tf.reshape(t, [(height // 2) * (width // 2) * channels * batch_size, 4])
169 |
170 |         x1 = tf.transpose(bottom, perm=[0, 3, 1, 2])
171 |         values = tf.reshape(x1, [-1])
172 |         return tf.scatter_nd(indices, values, tf.to_int64(top_shape))
173 |
174 | def fullyconnect(input_layer, layer_name, input_size, output_size, regularizer):
175 |     with tf.variable_scope(layer_name):
176 |         fc_w = tf.get_variable("weight", [input_size, output_size],
177 |                                initializer = tf.truncated_normal_initializer(stddev = 0.1))
178 |         if regularizer != 0:
179 |             tf.add_to_collection('losses', regularizer(fc_w))
180 |         fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1))
181 |         fc = tf.nn.sigmoid(tf.matmul(input_layer, fc_w) + fc_b)
182 |         # if train: fc = tf.nn.dropout(fc, 0.5)
183 |     return fc
184 |
185 | def linearlyconnect(input_layer, layer_name, input_size, output_size, regularizer):
186 |     with tf.variable_scope(layer_name):
187 |         fc_w = tf.get_variable("weight", [input_size, output_size],
188 |                                initializer = tf.truncated_normal_initializer(stddev = 0.1))
189 |         if regularizer != 0:
190 |             tf.add_to_collection('losses', regularizer(fc_w))
191 |
192 |         fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1))
193 |         logit = tf.matmul(input_layer, fc_w) + fc_b
194 |     return logit
--------------------------------------------------------------------------------
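`unravel_argmax` relies on the flattened index layout described in the TensorFlow documentation of `tf.nn.max_pool_with_argmax`, where a maximum at `[b, y, x, c]` is stored as `((b * height + y) * width + x) * channels + c`. Integer division by `width * channels` therefore recovers the row, and the remainder divided by `channels` recovers the column. The NumPy check below covers a single image (b = 0) and compares against `np.unravel_index`; for b > 0 the first component returned is `b * height + y`, not `y`. The NumPy port is a hypothetical illustration, not code from the repository.

```python
import numpy as np


def unravel_argmax(argmax, shape):
    """NumPy version of layer_xyf.unravel_argmax; shape = [batch, height, width, channels]."""
    width, channels = shape[2], shape[3]
    rows = argmax // (width * channels)
    cols = argmax % (width * channels) // channels
    return np.stack([rows, cols])


shape = [1, 4, 6, 3]
flat = np.arange(shape[1] * shape[2] * shape[3])  # every flattened index for b = 0
rows, cols = unravel_argmax(flat, shape)
ref_rows, ref_cols, ref_ch = np.unravel_index(flat, shape[1:])
print(np.array_equal(rows, ref_rows), np.array_equal(cols, ref_cols))  # True True
```

The channel index is not returned by `unravel_argmax`; the unpooling functions rebuild it separately with `tf.range(channels)`.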
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tensorflow as tf
4 | import pprint
5 |
6 | from model import DualReconstructionModel
7 | from solver import nn_train, nn_test, nn_feedforward
8 | from config import config_par
9 |
10 | flags = tf.app.flags
11 | flags.DEFINE_string("dataset", "data", ".tfRecord training file or .mat testing file [data]")
12 | flags.DEFINE_string("mode", "train", "train, test or feedforward [train]")
13 | flags.DEFINE_integer("epoch", 30, "Number of epochs to train [30]")
14 | flags.DEFINE_integer("batch_size", 10, "The size of batch images [10]")
15 | flags.DEFINE_integer("model_step", 1000, "Save the model every N iterations [1000]")
16 | flags.DEFINE_float("lr", 0.01, "The base learning rate [0.01]")
17 | flags.DEFINE_string("checkpoint", "checkpoint", "The path of checkpoint")
18 | flags.DEFINE_boolean("goon", False, "Continue a previous training run [False]")
19 |
20 | def main(_):
21 |     os.chdir(sys.path[0])
22 |     print("Current cwd: ", os.getcwd())
23 |
24 |     config = config_par(flags)
25 |     pp = pprint.PrettyPrinter()
26 |     pp.pprint(config.FLAGS.__flags)
27 |     FLAGS = config.FLAGS
28 |
29 |     print("Building model ...")
30 |     nn = DualReconstructionModel()
31 |     if FLAGS.mode == "train":
32 |         print("Mode: train ...")
33 |         nn_train(nn, FLAGS)
34 |     elif FLAGS.mode == "test":
35 |         print("Mode: test ...")
36 |         nn_test(nn, FLAGS)
37 |     elif FLAGS.mode == "feedforward":
38 |         print("Mode: feedforward ...")
39 |         nn_feedforward(nn, FLAGS)
40 |
41 |     exit()
42 |
43 | if __name__ == '__main__':
44 |     tf.app.run()
--------------------------------------------------------------------------------
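With `--mode=feedforward`, `main.py` passes `--dataset` to `nn_feedforward` in `solver.py`. That function treats every regular file in the directory as a `.mat` file and reads two arrays from each one: `ml` (used to reconstruct `rl`) and `mh` (used for `rh`). Each array is either 360×1024 (views × detector bins) or stacked as 360×1024×N. Below is a minimal sketch for writing such a file, assuming the sinograms are already NumPy arrays. `write_feedforward_input` and the file name are hypothetical.

```python
import os
import tempfile

import numpy as np
import scipy.io

VIEW_NUM, PRJ_LEN = 360, 1024  # view_num and prjLen in model.py


def write_feedforward_input(path, ml, mh):
    """Save two sinograms as a .mat file with the keys 'ml' and 'mh' read by nn_feedforward.

    ml, mh: arrays of shape (VIEW_NUM, PRJ_LEN) or (VIEW_NUM, PRJ_LEN, N).
    """
    ml = np.asarray(ml, dtype=np.float32)
    mh = np.asarray(mh, dtype=np.float32)
    if ml.shape != mh.shape:
        raise ValueError("ml and mh must have the same shape")
    if ml.shape[:2] != (VIEW_NUM, PRJ_LEN):
        raise ValueError("expected sinograms of shape (%d, %d, ...)" % (VIEW_NUM, PRJ_LEN))
    scipy.io.savemat(path, {'ml': ml, 'mh': mh})


out_dir = tempfile.mkdtemp()
path = os.path.join(out_dir, 'sample_0001.mat')
stack = np.random.rand(VIEW_NUM, PRJ_LEN, 2)  # two random sinogram pairs as stand-in data
write_feedforward_input(path, stack, 1.1 * stack)
loaded = scipy.io.loadmat(path)
print(loaded['ml'].shape)
```

Keep only such `.mat` files in the dataset directory, because the loader checks `os.path.isfile` but does not filter by extension.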
/src/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | import numpy as np
4 | import odl
5 | import odl.contrib.tensorflow
6 |
7 | from cell import ConvLSTMCell
8 |
9 | import layer_xyf
10 |
11 | # input size
12 | prjLen = 1024
13 | view_num = 360
14 | channel = 1
15 |
16 | # output size
17 | output_h = 256
18 | output_w = 256
19 | output_depth = 1
20 |
21 | # LSTM hidden size
22 | filters_lstm = 64
23 | kernal_lstm = [3, 3]
24 |
25 | #conv: [filter_height, filter_width, in_channels, out_channels]; the 256*256 VALID conv maps each ConvLSTM output to one step factor per sample
26 | filter_shape = [256, 256, filters_lstm, 1]
27 |
28 | sod = 1000
29 | sdd = 1500
30 |
31 | iteration = 10
32 |
33 | # CNN
34 | #layer_1 input_size 128*128 / 256*256 / 512*512 conv: [filter_height, filter_width, in_channels, out_channels]
35 | FILTER_1 = [3, 3, 1, 32]
36 | STRIDE_1 = 2
37 | PAD_1 = "SAME"
38 |
39 | #layer_2 input_size 64*64 / 128*128 / 256*256
40 | FILTER_2 = [3, 3, 32, 32]
41 | STRIDE_2 = 2
42 | PAD_2 = "SAME"
43 |
44 | #layer_3 input_size 32*32 / 64*64 / 128*128
45 | FILTER_3 = [3, 3, 32, 32]
46 | STRIDE_3 = 2
47 | PAD_3 = "SAME"
48 |
49 | #layer_4 input_size 16*16 / 32*32 / 64*64
50 | #FILTER_4 = [9, 9, 256, 512]
51 | FILTER_4 = [3, 3, 32, 32]
52 | STRIDE_4 = 2
53 | PAD_4 = "SAME"
54 |
55 | #layer_5 input_size 8*8 / 16*16 / 32*32
56 | FILTER_5 = [16, 16, 32, 32]
57 | STRIDE_5 = 1
58 | PAD_5 = "VALID"
59 |
60 | #layer_6 input_size 1*1*64 (concatenation of the two 1*1*32 branch outputs)
61 | FILTER_6 = [1, 1, 64, 1]
62 | STRIDE_6 = 1
63 | PAD_6 = "VALID"
64 |
65 | space = odl.uniform_discr([-256, -256], [256, 256], [output_h, output_w], dtype='float32')
66 | angle_partition = odl.uniform_partition(0, 2 * np.pi, view_num)
67 | detector_partition = odl.uniform_partition(-360, 360, prjLen)
68 | geometry = odl.tomo.FanFlatGeometry(angle_partition, detector_partition,
69 | src_radius = sod, det_radius = sdd - sod)
70 | operator = odl.tomo.RayTransform(space, geometry, impl='astra_cuda')
71 |
72 | odl_op_layer = odl.contrib.tensorflow.as_tensorflow_layer(operator, 'RayTransform')
73 | odl_op_layer_adjoint = odl.contrib.tensorflow.as_tensorflow_layer(operator.adjoint, 'RayTransformAdjoint')
74 |
75 |
76 | class DualReconstructionModel(object):
77 |     def __init__(self):
78 |         self.input_width = prjLen
79 |         self.input_height = view_num
80 |         self.num_channel = channel
81 |         self.output_h = output_h
82 |         self.output_w = output_w
83 |         self.output_depth = output_depth
84 |
85 |     def feedforward(self, ml, mh, regularizer = 0):
86 |         input_shape = mh.get_shape().as_list()
87 |         batch_size = input_shape[0]
88 |
89 |         # LSTM: unrolled reconstruction of rh from mh and of rl from ml.
90 |         # Each step moves r along the data-fidelity gradient g = A^T(A r - m), scaled by
91 |         # 0.0001 * s, where s in (-1, 1) is predicted per sample by ConvLSTM + conv + tanh.
92 |         with tf.variable_scope('LSTM_1'):
93 |             rh = tf.get_variable("rh", [batch_size, self.output_h, self.output_w, self.output_depth], initializer = tf.constant_initializer(0.0), trainable = False)
94 |             conv_w1 = tf.get_variable("weight_lstm1", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.001))
95 |             conv_b1 = tf.get_variable("bias1", 1, initializer = tf.constant_initializer(0.0))
96 |             cell_g1 = ConvLSTMCell([self.output_h, self.output_w], filters_lstm, kernal_lstm)
97 |             init_state_g1 = cell_g1.zero_state(batch_size, dtype = tf.float32)
98 |             state1 = init_state_g1
99 |             for timestep in range(iteration):
100 |                 if timestep > 0:
101 |                     tf.get_variable_scope().reuse_variables()
102 |                 # Swap the two spatial axes to match the ODL operator's axis order
103 |                 rh_tr = tf.transpose(rh, perm = [0, 2, 1, 3])
104 |                 g1 = odl_op_layer_adjoint(odl_op_layer(rh_tr) - mh)
105 |                 gt1 = tf.transpose(g1, perm = [0, 2, 1, 3])
106 |                 (cell_output1, state1) = cell_g1(gt1, state1)
107 |                 conv1 = tf.nn.conv2d(cell_output1, conv_w1, [1, 1, 1, 1], "VALID")
108 |                 s1 = tf.nn.tanh(tf.nn.bias_add(conv1, conv_b1))
109 |                 self.variable_summaries(s1, ('s1_%d'%timestep))
110 |
111 |                 rh = rh + 0.0001*s1*gt1
112 |                 rh = tf.clip_by_value(rh, 0, 5)
113 |                 tf.summary.image('rh_pred_%d'%timestep, rh, 1)
114 |
115 |         with tf.variable_scope('LSTM_2'):
116 |             rl = tf.get_variable("rl", [batch_size, self.output_h, self.output_w, self.output_depth], initializer = tf.constant_initializer(0.0), trainable = False)
117 |             conv_w2 = tf.get_variable("weight_lstm2", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.001))
118 |             conv_b2 = tf.get_variable("bias2", 1, initializer = tf.constant_initializer(0.0))
119 |             cell_g2 = ConvLSTMCell([self.output_h, self.output_w], filters_lstm, kernal_lstm)
120 |             init_state_g2 = cell_g2.zero_state(batch_size, dtype = tf.float32)
121 |             state2 = init_state_g2
122 |             for timestep in range(iteration):
123 |                 if timestep > 0:
124 |                     tf.get_variable_scope().reuse_variables()
125 |                 # Swap the two spatial axes to match the ODL operator's axis order
126 |                 rl_tr = tf.transpose(rl, perm = [0, 2, 1, 3])
127 |                 g2 = odl_op_layer_adjoint(odl_op_layer(rl_tr) - ml)
128 |                 gt2 = tf.transpose(g2, perm = [0, 2, 1, 3])
129 |
130 |                 (cell_output2, state2) = cell_g2(gt2, state2)
131 |                 conv2 = tf.nn.conv2d(cell_output2, conv_w2, [1, 1, 1, 1], "VALID")
132 |                 s2 = tf.nn.tanh(tf.nn.bias_add(conv2, conv_b2))
133 |                 self.variable_summaries(s2, ('s2_%d'%timestep))
134 |
135 |                 rl = rl + 0.0001*s2*gt2
136 |                 rl = tf.clip_by_value(rl, 0, 5)
137 |                 tf.summary.image('rl_pred_%d'%timestep, rl, 1)
138 |
139 |         # CNN: four stride-2 convs and a 16*16 VALID conv reduce rl and rh to 1*1*32 features each;
140 |         # 1*1 convs on their concatenation predict the per-sample 2x2 mixing matrix [[pa, pb], [pc, pd]].
141 |         layer_L1 = layer_xyf.convo(rl, "conv_L1", FILTER_1, STRIDE_1, PAD_1)
142 |         layer_L2 = layer_xyf.convo(layer_L1, "conv_L2", FILTER_2, STRIDE_2, PAD_2)
143 |         layer_L3 = layer_xyf.convo(layer_L2, "conv_L3", FILTER_3, STRIDE_3, PAD_3)
144 |         layer_L4 = layer_xyf.convo(layer_L3, "conv_L4", FILTER_4, STRIDE_4, PAD_4)
145 |         layer_L5 = layer_xyf.convo(layer_L4, "conv_L5", FILTER_5, STRIDE_5, PAD_5)
146 |
147 |         layer_H1 = layer_xyf.convo(rh, "conv_H1", FILTER_1, STRIDE_1, PAD_1)
148 |         layer_H2 = layer_xyf.convo(layer_H1, "conv_H2", FILTER_2, STRIDE_2, PAD_2)
149 |         layer_H3 = layer_xyf.convo(layer_H2, "conv_H3", FILTER_3, STRIDE_3, PAD_3)
150 |         layer_H4 = layer_xyf.convo(layer_H3, "conv_H4", FILTER_4, STRIDE_4, PAD_4)
151 |         layer_H5 = layer_xyf.convo(layer_H4, "conv_H5", FILTER_5, STRIDE_5, PAD_5)
152 |
153 |         combine_LH = tf.concat([layer_L5, layer_H5], 3)
154 |
155 |         pa_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pa", FILTER_6, STRIDE_6, PAD_6)
156 |         pb_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pb", FILTER_6, STRIDE_6, PAD_6)
157 |         pc_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pc", FILTER_6, STRIDE_6, PAD_6)
158 |         pd_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pd", FILTER_6, STRIDE_6, PAD_6)
159 |
160 |         # Image-domain decomposition d = P r with r = (rh, rl); d1/d2 are trained against d_bone/d_tissue
161 |         d1 = pa_pred*rh + pb_pred*rl
162 |         d2 = pc_pred*rh + pd_pred*rl
163 |
164 |         d1 = tf.clip_by_value(d1, 0, 5)
165 |         d2 = tf.clip_by_value(d2, 0, 5)
166 |
167 |         tf.summary.image('d1_pred', d1, 1)
168 |         tf.summary.image('d2_pred', d2, 1)
169 |
170 |         return d1, d2, rl, rh
171 |
172 |     def variable_summaries(self, var, var_name):
173 |         """Attach mean, stddev, max, min and a histogram of var to the TensorBoard summaries."""
174 |         with tf.name_scope(var_name):
175 |             mean = tf.reduce_mean(var)
176 |             tf.summary.scalar('mean', mean)
177 |             with tf.name_scope('stddev'):
178 |                 stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
179 |             tf.summary.scalar('stddev', stddev)
180 |             tf.summary.scalar('max', tf.reduce_max(var))
181 |             tf.summary.scalar('min', tf.reduce_min(var))
182 |             tf.summary.histogram('histogram', var)
183 |
--------------------------------------------------------------------------------
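`feedforward` runs two independent unrolled loops, one per energy, and then mixes their outputs with a per-sample 2×2 matrix predicted by the CNN head. The NumPy sketch below reproduces that arithmetic under two stated assumptions. First, a small random dense matrix `A` stands in for ODL's fan-beam `RayTransform`, so `A.T` plays the role of its adjoint. Second, a fixed list `step_factors` replaces the ConvLSTM output `s`. All function names are hypothetical.

With `s = -1` and `scale` equal to 1/‖A‖₂², each clipped update is a projected gradient-descent step on ½‖Ar − m‖², so the data residual shrinks. `cnn_head_size` checks the conv output-size arithmetic of `FILTER_1` to `FILTER_5`. Only a 256×256 input reduces to 1×1, which the per-sample coefficients assume.

```python
import math

import numpy as np


def data_gradient(A, r, m):
    """Gradient of 0.5 * ||A r - m||^2, i.e. A^T (A r - m)."""
    return A.T @ (A @ r - m)


def unrolled_reconstruction(A, m, step_factors, scale=1e-4, lo=0.0, hi=5.0):
    """r <- clip(r + scale * s * A^T(A r - m), lo, hi), starting from r = 0.

    step_factors stands in for the per-step tanh output s in model.py.
    """
    r = np.zeros(A.shape[1])
    for s in step_factors:
        r = np.clip(r + scale * s * data_gradient(A, r, m), lo, hi)
    return r


def decompose(rh, rl, pa, pb, pc, pd):
    """Per-sample 2x2 mixing: d1 = pa*rh + pb*rl, d2 = pc*rh + pd*rl, clipped to [0, 5]."""
    d1 = np.clip(pa * rh + pb * rl, 0.0, 5.0)
    d2 = np.clip(pc * rh + pd * rl, 0.0, 5.0)
    return d1, d2


def conv_out(n, k, stride, padding):
    """Spatial output size of tf.nn.conv2d for 'SAME' or 'VALID' padding."""
    if padding == "SAME":
        return math.ceil(n / stride)
    return (n - k) // stride + 1


def cnn_head_size(n):
    """Output size of the FILTER_1..FILTER_5 stack in model.py for an n x n input."""
    for k, stride, pad in [(3, 2, "SAME")] * 4 + [(16, 1, "VALID")]:
        n = conv_out(n, k, stride, pad)
    return n


rng = np.random.default_rng(0)
A = rng.random((40, 25))              # toy stand-in for the ray transform
r_true = rng.uniform(0.0, 5.0, 25)
m = A @ r_true                        # noiseless toy "sinogram"
lipschitz = np.linalg.norm(A, 2) ** 2
# A negative factor turns r + scale*s*g into a descent step; here scale*|s| = 1/L
r = unrolled_reconstruction(A, m, [-1.0] * 50, scale=1.0 / lipschitz)
print(np.linalg.norm(A @ r - m) < np.linalg.norm(m))
print(cnn_head_size(256))
```

In `model.py` the factor `s` is learned per sample and per iteration, and `scale` is fixed at 0.0001. The sketch fixes `s` only to make the descent property easy to check.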
/src/solver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.io
3 | import time
4 | import tensorflow as tf
5 | import numpy as np
6 |
7 | from functools import reduce
8 | from operator import mul
9 |
10 | def nn_train(nn, config):
11 |     print("Loading .tfrecords data...", config.dataset)
12 |     if not os.path.exists(config.dataset):
13 |         # Fail fast: training reads the train_sample_batches_* .tfrecords shards in this directory
14 |         raise Exception("training data not found.")
15 |
16 |     time_str = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime(time.time()))
17 |     train_file_path = os.path.join(config.dataset, "train_sample_batches_*")
18 |     train_files = tf.train.match_filenames_once(train_file_path)
19 |     filename_queue = tf.train.string_input_producer(train_files, shuffle = True)
20 |
21 |     reader = tf.TFRecordReader()
22 |     _, serialized_example = reader.read(filename_queue)
23 |     features = tf.parse_single_example(serialized_example,
24 |                                        features = {'ml': tf.FixedLenFeature([], tf.string),
25 |                                                    'mh': tf.FixedLenFeature([], tf.string),
26 |                                                    'd_bone': tf.FixedLenFeature([], tf.string),
27 |                                                    'd_tissue': tf.FixedLenFeature([], tf.string),
28 |                                                    'rl': tf.FixedLenFeature([], tf.string),
29 |                                                    'rh': tf.FixedLenFeature([], tf.string)})
30 |
31 |     ml = tf.decode_raw(features['ml'], tf.float32)
32 |     mh = tf.decode_raw(features['mh'], tf.float32)
33 |     d1 = tf.decode_raw(features['d_bone'], tf.float32)
34 |     d2 = tf.decode_raw(features['d_tissue'], tf.float32)
35 |     rl = tf.decode_raw(features['rl'], tf.float32)
36 |     rh = tf.decode_raw(features['rh'], tf.float32)
37 |
38 |     ml = tf.reshape(ml, [nn.input_height, nn.input_width, nn.num_channel])
39 |     mh = tf.reshape(mh, [nn.input_height, nn.input_width, nn.num_channel])
40 |
41 |     d1 = tf.reshape(d1, [512, 512, nn.output_depth])
42 |     d2 = tf.reshape(d2, [512, 512, nn.output_depth])
43 |     rl = tf.reshape(rl, [512, 512, nn.output_depth])
44 |     rh = tf.reshape(rh, [512, 512, nn.output_depth])
45 |     com_resize = 256  # references are stored at 512 x 512 and resized to the 256 x 256 model output
46 |     d1_resized = tf.image.resize_images(d1, [com_resize, com_resize], method = tf.image.ResizeMethod.BILINEAR)
47 |     d2_resized = tf.image.resize_images(d2, [com_resize, com_resize], method = tf.image.ResizeMethod.BILINEAR)
48 |     rl_resized = tf.image.resize_images(rl, [com_resize, com_resize], method = tf.image.ResizeMethod.BILINEAR)
49 |     rh_resized = tf.image.resize_images(rh, [com_resize, com_resize], method = tf.image.ResizeMethod.BILINEAR)
50 |
51 |     # Shuffled mini-batches of sinogram pairs and their 256 x 256 reference images
52 |     ml_batch, mh_batch, d1_batch, d2_batch, rl_batch, rh_batch = tf.train.shuffle_batch([ml, mh, d1_resized, d2_resized, rl_resized, rh_resized],
53 |                                                                                          batch_size = config.batch_size,
54 |                                                                                          capacity = config.batch_size*3 + 50,
55 |                                                                                          min_after_dequeue = 30)
56 |
57 |     # Build the forward graph; feedforward() takes only the two sinograms
58 |     d1_batch_pred, d2_batch_pred, rl_batch_pred, rh_batch_pred = nn.feedforward(ml_batch, mh_batch)
59 |
60 |     global_step = tf.Variable(0, trainable = False)
61 |     variable_averages = tf.train.ExponentialMovingAverage(config.moving_average_decay, global_step)
62 |     variables_averages_op = variable_averages.apply(tf.trainable_variables())
63 |     # Decomposition errors (d1, d2) are weighted by 0.005 relative to reconstruction errors (rl, rh)
64 |     mse_loss = tf.reduce_mean(0.005*tf.square(d1_batch_pred - d1_batch) / 2 + 0.005*tf.square(d2_batch_pred - d2_batch) / 2 +
65 |                               tf.square(rl_batch_pred - rl_batch) / 2 + tf.square(rh_batch_pred - rh_batch) / 2)
66 |
67 |     tf.add_to_collection('losses', mse_loss)
68 |     tf.summary.scalar('MSE_losses', mse_loss)
69 |
70 |     learning_rate = tf.train.exponential_decay(config.lr, global_step,
71 |                                                config.sampleNum / config.batch_size, config.learning_rate_decay)
72 |     train_step = tf.train.AdamOptimizer(learning_rate).minimize(mse_loss, global_step = global_step)
73 |
74 |     with tf.control_dependencies([train_step, variables_averages_op]):
75 |         train_op = tf.no_op(name = 'train')
76 |
77 |     merged = tf.summary.merge_all()
78 |     summary_path = os.path.join(config.summary_dir, time_str)
79 |     os.makedirs(summary_path)
80 |
81 |     run_config = tf.ConfigProto(allow_soft_placement = True)
82 |     run_config.gpu_options.allow_growth = True
83 |
84 |     num_params = 0
85 |     for variable in tf.trainable_variables():
86 |         shape = variable.get_shape()
87 |         num_params += reduce(mul, [dim.value for dim in shape], 1)
88 |     print("Number of trainable parameters: %d"%(num_params))
89 |
90 |     with tf.Session(config = run_config) as sess:
91 |         summary_writer = tf.summary.FileWriter(summary_path, sess.graph)
92 |         saver = tf.train.Saver()
93 |         if config.goon:
94 |             if not os.path.exists(config.checkpoint):
95 |                 raise Exception("checkpoint path not found.")
96 |             print("Loading trained model... ", config.checkpoint)
97 |             ckpt = tf.train.get_checkpoint_state(config.checkpoint)
98 |             saver.restore(sess, ckpt.model_checkpoint_path)
99 |         else:
100 |             sess.run(tf.global_variables_initializer())
101 |         # match_filenames_once creates a local variable, so locals need initializing in both cases
102 |         sess.run(tf.local_variables_initializer())
103 |
104 |         save_model_path = os.path.join(config.output_model_dir, time_str)
105 |         os.makedirs(save_model_path)
106 |
107 |         coord = tf.train.Coordinator()
108 |         threads = tf.train.start_queue_runners(sess = sess, coord = coord)
109 |
110 |         variable_name = [v.name for v in tf.trainable_variables()]
111 |         print(variable_name)
112 |
113 |         print("start training sess...")
114 |         start_time = time.time()
115 |         for i in range(config.iteration):
116 |             batch_start_time = time.time()
117 |             summary, _, loss_value, step, learningRate = sess.run([merged, train_op, mse_loss, global_step, learning_rate])
118 |             batch_end_time = time.time()
119 |             sec_per_batch = batch_end_time - batch_start_time
120 |             if step % config.summary_step == 0:
121 |                 summary_writer.add_summary(summary, step//config.summary_step)
122 |             if step % config.model_step == 0:
123 |                 print("Saving model (after %d iterations)... " %(step))
124 |                 saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
125 |             print("sec/batch(%d) %gs, global step %d batches, training epoch %d/%d, learningRate %g, loss on training is %g"
126 |                   % (config.batch_size, sec_per_batch, step, i*config.batch_size / config.sampleNum,
127 |                      config.epoch, learningRate, loss_value))
128 |             print("Elapsed time: %gs" %(batch_end_time - start_time))
129 |
130 |         coord.request_stop()
131 |         coord.join(threads)
132 |         summary_writer.close()
133 |         print("Train done. ")
134 |         print("Saving model... ", save_model_path)
135 |         saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
136 |         sess.close()
137 |
138 |     print("Sess closed.")
139 |
140 | def nn_test(nn, config):
141 |     print("Loading testing data...", config.dataset)
142 |     if not os.path.exists(config.dataset):
143 |         raise Exception("testing data not found.")
144 |     # Evaluation is not implemented here: the function returns after the path check
145 |
146 | def nn_feedforward(nn, config):
147 |     print("Loading .mat data...", config.dataset)
148 |     if not os.path.exists(config.dataset):
149 |         raise Exception("Testing .mat file not found.")
150 |
151 |     run_config = tf.ConfigProto()
152 |     with tf.Session(config = run_config) as sess:
153 |         # Every regular file in the dataset directory is treated as a .mat input
154 |         mat_file_list = []
155 |         filelist = os.listdir(config.dataset)
156 |         for line in filelist:
157 |             file = os.path.join(config.dataset, line)
158 |             if os.path.isfile(file):
159 |                 mat_file_list.append(file)
160 |                 print(file)
161 |
162 |         ml = tf.placeholder(tf.float32, [1, nn.input_height, nn.input_width, nn.num_channel], name = 'xl-input')
163 |         mh = tf.placeholder(tf.float32, [1, nn.input_height, nn.input_width, nn.num_channel], name = 'xh-input')
164 |
165 |         d1_pred, d2_pred, rl_pred, rh_pred = nn.feedforward(ml, mh)
166 |
167 |         # Restore the exponential-moving-average (shadow) weights saved during training.
168 |         # LSTM_1/rh and LSTM_2/rl are zero-initialized buffers whose shape depends on the batch
169 |         # size, so they are initialized here instead of being restored from the checkpoint.
170 |         buffer_names = ('LSTM_1/rh', 'LSTM_2/rl')
171 |         variable_averages = tf.train.ExponentialMovingAverage(config.moving_average_decay)
172 |         variables_to_restore = {name: var for name, var in variable_averages.variables_to_restore().items()
173 |                                 if name not in buffer_names}
174 |         saver = tf.train.Saver(variables_to_restore)
175 |         sess.run(tf.variables_initializer([v for v in tf.global_variables() if v.op.name in buffer_names]))
176 |
177 |         ckpt = tf.train.get_checkpoint_state(config.checkpoint)
178 |         if ckpt and ckpt.model_checkpoint_path:
179 |             saver.restore(sess, ckpt.model_checkpoint_path)
180 |             global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
181 |             print('global step: ', global_step)
182 |
183 |             file_num = 1
184 |             for mat_file in mat_file_list:
185 |                 print('loading file: ', mat_file)
186 |                 mat = scipy.io.loadmat(mat_file)
187 |                 mat_ml = mat['ml']
188 |                 mat_mh = mat['mh']
189 |
190 |                 dim_h = mat_ml.shape[0]
191 |                 dim_w = mat_ml.shape[1]
192 |                 if mat_ml.ndim < 3:
193 |                     mat_ml = mat_ml.reshape(dim_h, dim_w, 1)
194 |                     mat_mh = mat_mh.reshape(dim_h, dim_w, 1)
195 |                 img_num = mat_ml.shape[2]
196 |                 print('image num: ', img_num)
197 |
198 |                 d1 = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
199 |                 d2 = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
200 |                 rl = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
201 |                 rh = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
202 |
203 |                 start_time = time.time()
204 |                 for i in range(int(img_num)):
205 |                     input_ml = mat_ml[:,:,i].reshape([1, dim_h, dim_w, 1])
206 |                     input_mh = mat_mh[:,:,i].reshape([1, dim_h, dim_w, 1])
207 |
208 |                     output_d1, output_d2, output_rl, output_rh = sess.run([d1_pred, d2_pred, rl_pred, rh_pred],
209 |                                                                           feed_dict={ml: input_ml, mh: input_mh})
210 |                     # Each output has shape [1, output_h, output_w, 1]
211 |                     d1[:,:,i] = output_d1.reshape([nn.output_h, nn.output_w])
212 |                     d2[:,:,i] = output_d2.reshape([nn.output_h, nn.output_w])
213 |                     rl[:,:,i] = output_rl.reshape([nn.output_h, nn.output_w])
214 |                     rh[:,:,i] = output_rh.reshape([nn.output_h, nn.output_w])
215 |
216 |                 end_time = time.time()
217 |                 print("Elapsed time: %gs" %(end_time - start_time))
218 |
219 |                 result_fileName = os.path.join(config.output_data_dir, 'result_test_%.4d.mat' %(file_num))
220 |                 scipy.io.savemat(result_fileName, {'d1':d1, 'd2':d2, 'rl':rl, 'rh':rh})
221 |                 print('Feed forward done. Result: ', result_fileName)
222 |                 file_num = file_num + 1
223 |         else:
224 |             print('No checkpoint file found.')
225 |             return
226 |
--------------------------------------------------------------------------------
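The objective in `nn_train` weights the two decomposition errors by 0.005 relative to the reconstruction errors and averages over all pixels. It is minimized with Adam under `tf.train.exponential_decay`, with `decay_steps = sampleNum / batch_size` and `staircase` left at its default (False), so the rate is lr · rate^(step / decay_steps). The NumPy sketch below computes both quantities. `sampleNum` and `learning_rate_decay` come from the configuration built in `config.py`, which is not shown here, so the numbers in the example calls are placeholders. The function names are hypothetical.

```python
import numpy as np


def dual_loss(d1_pred, d1, d2_pred, d2, rl_pred, rl, rh_pred, rh, w_decomp=0.005):
    """Mean of the per-pixel weighted squared errors used as mse_loss in nn_train."""
    per_pixel = (w_decomp * (d1_pred - d1) ** 2 / 2 + w_decomp * (d2_pred - d2) ** 2 / 2
                 + (rl_pred - rl) ** 2 / 2 + (rh_pred - rh) ** 2 / 2)
    return float(np.mean(per_pixel))


def decayed_lr(base_lr, step, sample_num, batch_size, decay_rate):
    """Value of tf.train.exponential_decay with staircase=False and decay_steps = sample_num / batch_size."""
    decay_steps = sample_num / batch_size
    return base_lr * decay_rate ** (step / decay_steps)


shape = (10, 256, 256, 1)  # a batch of 256 x 256 targets, as after the resize in nn_train
zeros, ones = np.zeros(shape), np.ones(shape)
print(dual_loss(zeros, ones, zeros, ones, zeros, ones, zeros, ones))
# Placeholder values: sample_num and decay_rate are illustrative, not read from config.py
print(decayed_lr(0.01, step=100, sample_num=1000, batch_size=10, decay_rate=0.9))
```

With unit errors everywhere, the reconstruction terms contribute 0.5 each and the decomposition terms 0.0025 each. Even large decomposition errors therefore move the loss far less than reconstruction errors of the same size.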
/src/sysMatrix_2D_all_angle.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/DualReconstruction/65e95bfb54d0f555c32d8c92417f6f3a63eb5078/src/sysMatrix_2D_all_angle.m
--------------------------------------------------------------------------------