├── README.md ├── data ├── system_matrix │ └── Read me ├── testing_set │ └── Read me └── training_set │ └── Read me ├── log └── Read me ├── model └── Read me ├── result └── Read me └── src ├── FBP_A.m ├── FBP_reconstruction.m ├── cell.py ├── combinedGradient.m ├── combinedGradientDescart.m ├── config.py ├── iterative_reconstruction.m ├── layer_xyf.py ├── main.py ├── model.py ├── solver.py └── sysMatrix_2D_all_angle.m /README.md: -------------------------------------------------------------------------------- 1 | # DualReconstruction 2 | 3 | This repository hosts an image decomposition algorithm proposed to solve the image reconstruction and image-domain decomposition problems in Dual-energy Computed Tomography (DECT).
4 | 5 | The algorithm is built on the deep learning paradigm. For theoretical background, see [Deep Learning](http://www.deeplearningbook.org/) and [Material Decomposition Using DECT](https://pubs.rsna.org/doi/10.1148/rg.2016150220).
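As a rough illustration of image-domain material decomposition, the sketch below applies the direct 2x2 matrix inversion that `src/FBP_reconstruction.m` performs after reconstruction. The four coefficients are copied from that script; reading them as attenuation values of two basis materials at high and low energy is an assumption, and the function names `decompose` and `compose` are hypothetical helpers, not part of this repository.

```python
# Coefficients copied from src/FBP_reconstruction.m.
# Assumption: x<k>H / x<k>L are the attenuation values of basis material k
# at the high / low energy spectrum.
X1H, X1L = 0.0342, 0.0588
X2H, X2L = 0.019, 0.0251


def compose(d1, d2):
    """Forward model: mix two material values into high/low-energy pixel values."""
    rh = X1H * d1 + X2H * d2
    rl = X1L * d1 + X2L * d2
    return rh, rl


def decompose(rh, rl):
    """Invert the 2x2 mixing matrix, mirroring a, b, c, d in the MATLAB script."""
    det = X1H * X2L - X2H * X1L
    a, b = X2L / det, -X2H / det
    c, d = -X1L / det, X1H / det
    d1 = a * rh + b * rl
    d2 = c * rh + d * rl
    return d1, d2


if __name__ == "__main__":
    # Round trip on a single pixel: decompose should undo compose.
    rh, rl = compose(1.0, 0.5)
    print(decompose(rh, rl))
```

The same two formulas apply element-wise to whole reconstructed images, which is how the MATLAB scripts compute `d1` and `d2` from `rh` and `rl`.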
6 | 7 | The code currently targets Python 3.6, [TensorFlow](https://github.com/tensorflow/tensorflow) 1.4.0 and [ODL](https://github.com/odlgroup/odl) on Windows 7.
8 | 9 | * data: contains 3 paths. 10 | * system_matrix: the system matrix used in the reconstruction algorithms. You can generate the matrix by running the 'iterative_reconstruction.m' or 'FBP_reconstruction.m' file in the src path. 11 | * testing_set: the data used for testing. 12 | * training_set: the data used for training the deep model. Both the training and testing sets can be downloaded from [here](https://pan.baidu.com/s/1VfhTuNenuy2C6HAw1aWbZA) (extraction code: t4ya).
13 | * log: stores the TensorFlow log files written during training. 14 | * model: stores the trained model. 15 | * result: stores the results generated by the reconstruction algorithms. 16 | * src: the code for the proposed algorithm and two competing ones: 17 | * Filtered back projection (FBP) followed by direct matrix inversion (FBP_reconstruction.m) 18 | * Combined iterative reconstruction and image decomposition (iterative_reconstruction.m). Related paper: [Combined iterative reconstruction and image-domain decomposition for dual energy CT using total-variation regularization.](https://aapm.onlinelibrary.wiley.com/doi/abs/10.1118/1.4870375) 19 | * The proposed deep model (main.py). You can start training the proposed deep model with the command: 20 | ```bash 21 | python main.py --dataset="../data/training_set/" --mode="train" --model_name="your-saved-model-result-name" --lr=0.0001 --epoch=30 --model_step=1000 --batch_size=1 22 | ``` 23 | After training finishes, you can test the trained model with the command: 24 | ```bash 25 | python main.py --dataset="../data/testing_set/" --mode="feedforward" --model_name="your-saved-model-result-name" --checkpoint="../model/your-saved-model-result-name" 26 | ``` 27 | 28 | # Contact 29 | Email: vastcyclone@yeah.net 30 | -------------------------------------------------------------------------------- /data/system_matrix/Read me: -------------------------------------------------------------------------------- 1 | This path is used to store the system matrix applied in the algorithms. 2 | We do not provide a download link for the matrix since it is as large as 2GB for a 256*256 reconstruction problem. 3 | Please generate the system matrix locally by running the 'iterative_reconstruction.m' or 'FBP_reconstruction.m' file in the src path. 4 | This should be done before starting to train the deep neural net. 
5 | -------------------------------------------------------------------------------- /data/testing_set/Read me: -------------------------------------------------------------------------------- 1 | Please download the testing set from https://pan.baidu.com/s/1VfhTuNenuy2C6HAw1aWbZA (extraction code: t4ya) and unzip the file to this path. 2 | -------------------------------------------------------------------------------- /data/training_set/Read me: -------------------------------------------------------------------------------- 1 | Please download the training set from https://pan.baidu.com/s/1VfhTuNenuy2C6HAw1aWbZA (extraction code: t4ya) and unzip the file to this path. 2 | -------------------------------------------------------------------------------- /log/Read me: -------------------------------------------------------------------------------- 1 | This path is used to store the TensorFlow log files written during training. 2 | -------------------------------------------------------------------------------- /model/Read me: -------------------------------------------------------------------------------- 1 | This path is used to store the trained model in TensorFlow .ckpt format. 2 | -------------------------------------------------------------------------------- /result/Read me: -------------------------------------------------------------------------------- 1 | This path is used to store the testing result files (.mat file format) generated by the algorithms. 
2 | -------------------------------------------------------------------------------- /src/FBP_A.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XYF-GitHub/DualReconstruction/65e95bfb54d0f555c32d8c92417f6f3a63eb5078/src/FBP_A.m -------------------------------------------------------------------------------- /src/FBP_reconstruction.m: -------------------------------------------------------------------------------- 1 | close all; 2 | clear all; 3 | clc; 4 | 5 | test_data_path = '..\data\testing_set\'; 6 | result_path = '..\result\'; 7 | 8 | N = 256; 9 | prjLen = 1024; 10 | view_num = 360; % Number of projection views; assumed value, taken from view_num in src/model.py (undefined here after 'clear all') 11 | opts.Nx = N; % Size of the object 12 | opts.Ny = N; % 13 | opts.sod = 1000; % Distance from source to object (mm) 14 | opts.sdd = 1500; % Distance from source to detector 15 | opts.dt = 0.388; % Size of detector voxel 16 | opts.Uy = prjLen; % Number of projections 17 | opts.voxel = 1.0; % opts.voxel = opts.sod/opts.sdd*opts.dt*opts.Uy/opts.Nx; 18 | opts.Nz = 1; % 19 | opts.Vz = 1; % 20 | angcov = 360; 21 | angstp = angcov / (view_num + 1); 22 | theta_vec = 0:angstp:angcov - 1; 23 | 24 | systemMatrix_file = '..\data\A_256.mat'; 25 | if exist(systemMatrix_file, 'file') 26 | disp('Loading sysMatrix...'); 27 | load(systemMatrix_file); 28 | else 29 | tic; 30 | disp('sysMatrix_2D_all_angle...'); 31 | [W_row, W_col, W_val, sumP2R, sumC, sumR, Row, Col, Val_Num] = sysMatrix_2D_all_angle(theta_vec,opts); 32 | disp('A sparse...'); 33 | A = sparse(W_row, W_col, W_val, Row, Col); 34 | clear W_row; 35 | clear W_col; 36 | clear W_val; 37 | toc; 38 | disp('Saving system matrix...'); 39 | save(systemMatrix_file, '-v7.3', 'A'); 40 | end 41 | 42 | %% 43 | x1H = 0.0342; 44 | x1L = 0.0588; 45 | x2H = 0.019; 46 | x2L = 0.0251; 47 | a = x2L/(x1H*x2L - x2H*x1L); 48 | b = -1*x2H/(x1H*x2L - x2H*x1L); 49 | c = -1*x1L/(x1H*x2L - x2H*x1L); 50 | d = x1H/(x1H*x2L - x2H*x1L); 51 | 52 | mat_files = dir(test_data_path); 53 | 54 | fopt.angstp = 1; 55 | 
fopt.angcov = 360; 56 | fopt.voxel = 1.03; 57 | fopt.filter = 1; % 1- RL;2- SL;3- cos;4- hamming;5- hann 58 | 59 | for f = 3:length(mat_files) 60 | file = [test_data_path, mat_files(f).name]; 61 | load(file); 62 | [view_num, prjLen, img_num] = size(mh); 63 | 64 | rh = zeros(N, N, img_num); 65 | rl = zeros(N, N, img_num); 66 | d1 = zeros(N, N, img_num); 67 | d2 = zeros(N, N, img_num); 68 | 69 | disp(file); 70 | for n = 1:img_num 71 | tic; 72 | disp( sprintf('Image reconstruction %01d / %d ', n, img_num) ); 73 | 74 | xh = FBP_A(A,double(mh(:,:,n)'),fopt); 75 | xl = FBP_A(A,double(ml(:,:,n)'),fopt); 76 | 77 | rh(:,:,n) = reshape(xh,N,N); 78 | rl(:,:,n) = reshape(xl,N,N); 79 | d1(:,:,n) = a*rh(:,:,n) + b*rl(:,:,n); 80 | d2(:,:,n) = c*rh(:,:,n) + d*rl(:,:,n); 81 | toc; 82 | end 83 | savefileName = [result_path, sprintf('result_test_%04d.mat', f - 2) ]; 84 | disp(['Saving file: ', savefileName]); 85 | save(savefileName, 'rh', 'rl', 'd1', 'd2'); 86 | end 87 | 88 | -------------------------------------------------------------------------------- /src/cell.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class ConvLSTMCell(tf.nn.rnn_cell.RNNCell): 4 | """A LSTM cell with convolutions instead of multiplications. 5 | 6 | Reference: 7 | Xingjian, S. H. I., et al. "Convolutional LSTM network: A machine learning approach for precipitation nowcasting." Advances in Neural Information Processing Systems. 2015. 
8 | """ 9 | 10 | def __init__(self, shape, filters, kernel, forget_bias=1.0, activation=tf.tanh, normalize=True, peephole=True, data_format='channels_last', reuse=None): 11 | super(ConvLSTMCell, self).__init__(_reuse=reuse) 12 | self._kernel = kernel 13 | self._filters = filters 14 | self._forget_bias = forget_bias 15 | self._activation = activation 16 | self._normalize = normalize 17 | self._peephole = peephole 18 | if data_format == 'channels_last': 19 | self._size = tf.TensorShape(shape + [self._filters]) 20 | self._feature_axis = self._size.ndims 21 | self._data_format = None 22 | elif data_format == 'channels_first': 23 | self._size = tf.TensorShape([self._filters] + shape) 24 | self._feature_axis = 0 25 | self._data_format = 'NC' 26 | else: 27 | raise ValueError('Unknown data_format') 28 | 29 | @property 30 | def state_size(self): 31 | return tf.nn.rnn_cell.LSTMStateTuple(self._size, self._size) 32 | 33 | @property 34 | def output_size(self): 35 | return self._size 36 | 37 | def call(self, x, state): 38 | c, h = state 39 | 40 | x = tf.concat([x, h], axis=self._feature_axis) 41 | n = x.shape[-1].value 42 | m = 4 * self._filters if self._filters > 1 else 4 43 | W = tf.get_variable('kernel', self._kernel + [n, m]) 44 | y = tf.nn.convolution(x, W, 'SAME', data_format=self._data_format) 45 | if not self._normalize: 46 | y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer()) 47 | j, i, f, o = tf.split(y, 4, axis=self._feature_axis) 48 | 49 | if self._peephole: 50 | i += tf.get_variable('W_ci', c.shape[1:]) * c 51 | f += tf.get_variable('W_cf', c.shape[1:]) * c 52 | 53 | if self._normalize: 54 | j = tf.contrib.layers.layer_norm(j) 55 | i = tf.contrib.layers.layer_norm(i) 56 | f = tf.contrib.layers.layer_norm(f) 57 | 58 | f = tf.sigmoid(f + self._forget_bias) 59 | i = tf.sigmoid(i) 60 | c = c * f + i * self._activation(j) 61 | 62 | if self._peephole: 63 | o += tf.get_variable('W_co', c.shape[1:]) * c 64 | 65 | if self._normalize: 66 | o = 
tf.contrib.layers.layer_norm(o) 67 | c = tf.contrib.layers.layer_norm(c) 68 | 69 | o = tf.sigmoid(o) 70 | h = o * self._activation(c) 71 | 72 | state = tf.nn.rnn_cell.LSTMStateTuple(c, h) 73 | 74 | return h, state 75 | 76 | 77 | class ConvGRUCell(tf.nn.rnn_cell.RNNCell): 78 | """A GRU cell with convolutions instead of multiplications.""" 79 | 80 | def __init__(self, shape, filters, kernel, activation=tf.tanh, normalize=True, data_format='channels_last', reuse=None): 81 | super(ConvGRUCell, self).__init__(_reuse=reuse) 82 | self._filters = filters 83 | self._kernel = kernel 84 | self._activation = activation 85 | self._normalize = normalize 86 | if data_format == 'channels_last': 87 | self._size = tf.TensorShape(shape + [self._filters]) 88 | self._feature_axis = self._size.ndims 89 | self._data_format = None 90 | elif data_format == 'channels_first': 91 | self._size = tf.TensorShape([self._filters] + shape) 92 | self._feature_axis = 0 93 | self._data_format = 'NC' 94 | else: 95 | raise ValueError('Unknown data_format') 96 | 97 | @property 98 | def state_size(self): 99 | return self._size 100 | 101 | @property 102 | def output_size(self): 103 | return self._size 104 | 105 | def call(self, x, h): 106 | channels = x.shape[self._feature_axis].value 107 | 108 | with tf.variable_scope('gates'): 109 | inputs = tf.concat([x, h], axis=self._feature_axis) 110 | n = channels + self._filters 111 | m = 2 * self._filters if self._filters > 1 else 2 112 | W = tf.get_variable('kernel', self._kernel + [n, m]) 113 | y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format) 114 | if self._normalize: 115 | r, u = tf.split(y, 2, axis=self._feature_axis) 116 | r = tf.contrib.layers.layer_norm(r) 117 | u = tf.contrib.layers.layer_norm(u) 118 | else: 119 | y += tf.get_variable('bias', [m], initializer=tf.ones_initializer()) 120 | r, u = tf.split(y, 2, axis=self._feature_axis) 121 | r, u = tf.sigmoid(r), tf.sigmoid(u) 122 | 123 | with tf.variable_scope('candidate'): 124 | 
inputs = tf.concat([x, r * h], axis=self._feature_axis) 125 | n = channels + self._filters 126 | m = self._filters 127 | W = tf.get_variable('kernel', self._kernel + [n, m]) 128 | y = tf.nn.convolution(inputs, W, 'SAME', data_format=self._data_format) 129 | if self._normalize: 130 | y = tf.contrib.layers.layer_norm(y) 131 | else: 132 | y += tf.get_variable('bias', [m], initializer=tf.zeros_initializer()) 133 | h = u * h + (1 - u) * self._activation(y) 134 | 135 | return h, h 136 | -------------------------------------------------------------------------------- /src/combinedGradient.m: -------------------------------------------------------------------------------- 1 | function [g1,g2] = combinedGradient(f1,f2,a,b) 2 | %% 3 | 4 | %% 5 | [row1,col1] = size(f1); 6 | [row2,col2] = size(f2); 7 | 8 | ff1 = zeros(row1+4,col1+4); 9 | ff2 = zeros(row2+4,col2+4); 10 | 11 | ff1(3:end-2,3:end-2) = f1; 12 | ff2(3:end-2,3:end-2) = f2; 13 | 14 | g1 = zeros(size(f1)); 15 | g2 = zeros(size(f2)); 16 | 17 | if row1~=row2 || col1~=col2 18 | error('Error! 
Size of f1 and f2 not match!'); 19 | end 20 | 21 | for i=3:(row1+2) 22 | for j=3:(col1+2) 23 | %%% 24 | v1 = a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i-1,j)-b*ff2(i-1,j) ) + a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j-1)-b*ff2(i,j-1) ); 25 | v2 = a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i+1,j)-b*ff2(i+1,j) ); 26 | v3 = a*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j+1)-b*ff2(i,j+1) ); 27 | %%%% 28 | t1 = b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i-1,j)-b*ff2(i-1,j) ) + b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j-1)-b*ff2(i,j-1) ); 29 | t2 = b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i+1,j)-b*ff2(i+1,j) ); 30 | t3 = b*( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j+1)-b*ff2(i,j+1) ); 31 | %%% 32 | denom1 = sqrt( eps+( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i-1,j)-b*ff2(i-1,j) )^2+( a*ff1(i,j)+b*ff2(i,j)-a*ff1(i,j-1)-b*ff2(i,j-1) )^2 ); 33 | denom2 = sqrt( eps+( a*ff1(i+1,j)+b*ff2(i+1,j)-a*ff1(i,j)-b*ff2(i,j) )^2+( a*ff1(i+1,j)+b*ff2(i+1,j)-a*ff1(i+1,j-1)-b*ff2(i+1,j-1) )^2 ); 34 | denom3 = sqrt( eps+( a*ff1(i,j+1)+b*ff2(i,j+1)-a*ff1(i-1,j+1)-b*ff2(i-1,j+1) )^2+( a*ff1(i,j+1)+b*ff2(i,j+1)-a*ff1(i,j)-b*ff2(i,j) )^2 ); 35 | 36 | g1(i-2,j-2) = v1/denom1+v2/denom2+v3/denom3; 37 | g2(i-2,j-2) = t1/denom1+t2/denom2+t3/denom3; 38 | end 39 | end 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | end 54 | -------------------------------------------------------------------------------- /src/combinedGradientDescart.m: -------------------------------------------------------------------------------- 1 | function [g] = combinedGradientDescart(f1,f2,beta1,beta2,a,b,c,d) 2 | %% 3 | 4 | %% 5 | 6 | [gh1,gl1] = combinedGradient(f1,f2,a,b); 7 | [gh2,gl2] = combinedGradient(f1,f2,c,d); 8 | g = [beta1*gh1(:)+beta2*gh2(:);beta1*gl1(:)+beta2*gl2(:)]; 9 | 10 | end 11 | -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | def config_par(config): 5 | config.DEFINE_string("output_model_dir", "../model", 
"") 6 | config.DEFINE_string("output_data_dir", "../result", "") 7 | config.DEFINE_string("model_name", "lstm", "") 8 | config.DEFINE_integer("sampleNum", 3226, "") 9 | config.DEFINE_integer("iteration", math.ceil( config.FLAGS.sampleNum / config.FLAGS.batch_size * config.FLAGS.epoch ), "") 10 | 11 | config.DEFINE_string("summary_dir", "../log", "") 12 | config.DEFINE_integer("summary_step", 20, "") 13 | 14 | config.DEFINE_float("learning_rate_decay", 0.85, "") 15 | config.DEFINE_float("moving_average_decay", 0.99, "") 16 | config.DEFINE_float("regularazition_rate", 0.0001, "") 17 | 18 | if not os.path.exists(config.FLAGS.output_model_dir): 19 | os.makedirs(config.FLAGS.output_model_dir) 20 | if not os.path.exists(config.FLAGS.output_data_dir): 21 | os.makedirs(config.FLAGS.output_data_dir) 22 | if not os.path.exists(config.FLAGS.summary_dir): 23 | os.makedirs(config.FLAGS.summary_dir) 24 | 25 | return config -------------------------------------------------------------------------------- /src/iterative_reconstruction.m: -------------------------------------------------------------------------------- 1 | close all; 2 | clear all; 3 | clc; 4 | 5 | test_data_path = '..\data\testing_set\'; 6 | result_path = '..\result\'; 7 | 8 | N = 256; 9 | prjLen = 1024; 10 | view_num = 360; % Number of projection views; assumed value, taken from view_num in src/model.py (undefined here after 'clear all') 11 | opts.Nx = N; % Size of the object 12 | opts.Ny = N; % 13 | opts.sod = 1000; % Distance from source to object (mm) 14 | opts.sdd = 1500; % Distance from source to detector 15 | opts.dt = 0.388; % Size of detector voxel 16 | opts.Uy = prjLen; % Number of projections 17 | opts.voxel = 1.03; % opts.voxel = opts.sod/opts.sdd*opts.dt*opts.Uy/opts.Nx; 18 | opts.Nz = 1; % 19 | opts.Vz = 1; % 20 | angcov = 360; 21 | angstp = angcov / (view_num + 1); 22 | theta_vec = 0:angstp:angcov - 1; 23 | 24 | systemMatrix_file = '..\data\A_256.mat'; 25 | if exist(systemMatrix_file, 'file') 26 | disp('Loading sysMatrix...'); 27 | load(systemMatrix_file); 28 | else 29 | tic; 30 | disp('sysMatrix_2D_all_angle...'); 31 | [W_row, 
W_col, W_val, sumP2R, sumC, sumR, Row, Col, Val_Num] = sysMatrix_2D_all_angle(theta_vec,opts); 32 | disp('A sparse...'); 33 | A = sparse(W_row, W_col, W_val, Row, Col); 34 | clear W_row; 35 | clear W_col; 36 | clear W_val; 37 | toc; 38 | disp('Saving system matrix...'); 39 | save(systemMatrix_file, '-v7.3', 'A'); 40 | end 41 | 42 | Niter = 100; 43 | x1H = 0.0342; 44 | x1L = 0.0588; 45 | x2H = 0.019; 46 | x2L = 0.0251; 47 | a = x2L/(x1H*x2L - x2H*x1L); 48 | b = -1*x2H/(x1H*x2L - x2H*x1L); 49 | c = -1*x1L/(x1H*x2L - x2H*x1L); 50 | d = x1H/(x1H*x2L - x2H*x1L); 51 | 52 | mat_files = dir(test_data_path); 53 | 54 | for f = 3:length(mat_files) 55 | file = [test_data_path, mat_files(f).name]; 56 | load(file); 57 | [view_num, prjLen, img_num] = size(mh); 58 | 59 | rh = zeros(N, N, img_num); 60 | rl = zeros(N, N, img_num); 61 | d1 = zeros(N, N, img_num); 62 | d2 = zeros(N, N, img_num); 63 | 64 | disp(file); 65 | for n = 1:img_num 66 | tic 67 | disp( sprintf('Image reconstruction %01d / %d ', n, img_num) ); 68 | ml_slice = ml(:,:,n)'; 69 | mh_slice = mh(:,:,n)'; 70 | ml_slice = double(ml_slice(:)); 71 | mh_slice = double(mh_slice(:)); 72 | 73 | xh = zeros(N*N,1); 74 | xl = zeros(N*N,1); 75 | xh_old = zeros(N*N,1); 76 | xl_old = zeros(N*N,1); 77 | x = zeros(2*N*N,1); 78 | px = zeros(2*N*N,1); 79 | px_old = zeros(2*N*N,1); 80 | 81 | beta1 = 0.003; 82 | beta2 = 0.0045; 83 | beta1_red = 0.99; 84 | beta2_red = 0.99; 85 | 86 | kai = 0.3; 87 | tol = eps; 88 | 89 | for i = 1:Niter 90 | x = [xh;xl]; 91 | gtv = combinedGradientDescart(reshape(xh,N,N),reshape(xl,N,N),beta1,beta2,a,b,c,d); 92 | gquadh = A'*(A*double(xh) - mh_slice); 93 | gquadl = A'*(A*double(xl) - ml_slice); 94 | g = single(gtv + [gquadh(:);gquadl(:)]); 95 | 96 | px(:) = 0; 97 | px(g <= 0 | x > 0) = g(g <= 0 | x > 0); 98 | if i == 1 99 | xh_old = single(zeros(N*N,1)); 100 | xl_old = single(zeros(N*N,1)); 101 | x_old = [xh_old;xl_old]; 102 | px_old = single(zeros(2*N*N,1)); 103 | alpha = 10E-8; 104 | else 105 | alpha1 = 
(x - x_old)'*(x - x_old) / ((x - x_old)'*(px - px_old)); 106 | alpha2 = (x - x_old)'*(px - px_old) / ((px - px_old)'*(px - px_old)); 107 | if alpha2 / alpha1 < kai 108 | alpha = alpha1; 109 | else 110 | alpha = alpha2; 111 | end 112 | x_old = x; 113 | px_old = px; 114 | end 115 | x = x - alpha*px; 116 | x(x < 0) = 0; 117 | xh = x(1:N*N); 118 | xl = x((N*N + 1):end); 119 | 120 | if norm(x - x_old) <= tol 121 | break; 122 | end 123 | beta1 = beta1*beta1_red; 124 | beta2 = beta2*beta2_red; 125 | end 126 | 127 | rh(:,:,n) = reshape(xh,N,N); 128 | rl(:,:,n) = reshape(xl,N,N); 129 | d1(:,:,n) = a*rh(:,:,n) + b*rl(:,:,n); 130 | d2(:,:,n) = c*rh(:,:,n) + d*rl(:,:,n); 131 | % figure(1); imshow(rh(:,:,n),[]), title(['rh image: ', num2str(n)]); 132 | % figure(2); imshow(rl(:,:,n),[]), title(['rl image: ', num2str(n)]); 133 | % figure(3); imshow(d1(:,:,n),[]), title(['d1 image: ', num2str(n)]); 134 | % figure(4); imshow(d2(:,:,n),[]), title(['d2 image: ', num2str(n)]); 135 | toc; 136 | end 137 | savefileName = [result_path, sprintf('result_test_%04d.mat', f - 2) ]; 138 | disp(['Saving file: ', savefileName]); 139 | save(savefileName, 'rh', 'rl', 'd1', 'd2'); 140 | end 141 | 142 | -------------------------------------------------------------------------------- /src/layer_xyf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Dec 1 20:39:12 2017 4 | 5 | @author: XYF 6 | """ 7 | 8 | import tensorflow as tf 9 | 10 | def variable_summaries(var, var_name): 11 | with tf.name_scope(var_name): 12 | mean = tf.reduce_mean(var) 13 | tf.summary.scalar('mean', mean) 14 | with tf.name_scope('stddev'): 15 | stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean))) 16 | tf.summary.scalar('stddev', stddev) 17 | tf.summary.scalar('max', tf.reduce_max(var)) 18 | tf.summary.scalar('min', tf.reduce_min(var)) 19 | tf.summary.histogram('histogram', var) 20 | 21 | def convo(input_layer, layer_name, filter_shape, 
stride, padStr): 22 | with tf.variable_scope(layer_name): 23 | input_shape = input_layer.get_shape().as_list() 24 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3]) 25 | 26 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1)) 27 | conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0)) 28 | conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr) 29 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b)) 30 | return relu 31 | 32 | def convo_sigmoid(input_layer, layer_name, filter_shape, stride, padStr): 33 | with tf.variable_scope(layer_name): 34 | input_shape = input_layer.get_shape().as_list() 35 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3]) 36 | 37 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1)) 38 | conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0)) 39 | conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr) 40 | sig = tf.nn.sigmoid(tf.nn.bias_add(conv, conv_b)) 41 | return sig 42 | 43 | def convo_noneRelu(input_layer, layer_name, filter_shape, stride, padStr): 44 | with tf.variable_scope(layer_name): 45 | input_shape = input_layer.get_shape().as_list() 46 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3]) 47 | 48 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1)) 49 | conv_b = tf.get_variable("bias", filter_shape[3], initializer = tf.constant_initializer(0.0)) 50 | conv = tf.nn.conv2d(input_layer, conv_w, [1, stride, stride, 1], padStr) 51 | none_relu = tf.nn.bias_add(conv, conv_b) 52 | 53 | variable_summaries(conv_w, "conv_w") 54 | variable_summaries(conv_b, "conv_b") 55 | variable_summaries(none_relu, "none_relu") 56 | 
return none_relu 57 | 58 | def deconvo(input_layer, layer_name, filter_shape, out_img_size, stride, padStr): 59 | with tf.variable_scope(layer_name): 60 | input_shape = input_layer.get_shape().as_list() 61 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3]) 62 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1)) 63 | input_shape = tf.shape(input_layer) 64 | output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]]) 65 | conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0)) 66 | conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr) 67 | deconv = tf.nn.bias_add(conv, conv_b) 68 | return deconv 69 | 70 | def deconv_withRelu(input_layer, layer_name, filter_shape, out_img_size, stride, padStr): 71 | with tf.variable_scope(layer_name): 72 | input_shape = input_layer.get_shape().as_list() 73 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3]) 74 | conv_w = tf.get_variable("weight", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1)) 75 | input_shape = tf.shape(input_layer) 76 | output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]]) 77 | conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0)) 78 | conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr) 79 | relu = tf.nn.relu(tf.nn.bias_add(conv, conv_b)) 80 | return relu 81 | 82 | def deconv_withSigmoid(input_layer, layer_name, filter_shape, out_img_size, stride, padStr): 83 | with tf.variable_scope(layer_name): 84 | input_shape = input_layer.get_shape().as_list() 85 | print(layer_name, " input shape:", input_shape[0], input_shape[1], input_shape[2], input_shape[3]) 86 | conv_w = tf.get_variable("weight", 
filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.1)) 87 | input_shape = tf.shape(input_layer) 88 | output_shape = tf.stack([input_shape[0], out_img_size[0], out_img_size[1], filter_shape[2]]) 89 | conv_b = tf.get_variable("bias", filter_shape[2], initializer = tf.constant_initializer(0.0)) 90 | conv = tf.nn.conv2d_transpose(input_layer, conv_w, output_shape, [1, stride, stride, 1], padStr) 91 | sig = tf.nn.sigmoid(tf.nn.bias_add(conv, conv_b)) 92 | return sig 93 | 94 | def pooling(input_layer, layer_name, kernal_shape, stride, padStr): 95 | with tf.name_scope(layer_name): 96 | pool = tf.nn.max_pool(input_layer, kernal_shape, [1,stride,stride,1], padStr) 97 | return pool 98 | 99 | def pooling_withmax(input_layer, layer_name, kernal_shape, stride, padStr): 100 | with tf.name_scope(layer_name): 101 | return tf.nn.max_pool_with_argmax(input_layer, kernal_shape, [1,stride,stride,1], padStr) 102 | 103 | def unravel_argmax(argmax, shape): 104 | output_list = [] 105 | output_list.append(argmax // (shape[2] * shape[3])) 106 | output_list.append(argmax % (shape[2] * shape[3]) // shape[3]) 107 | return tf.stack(output_list) 108 | 109 | def unpooling_layer2x2(x, layer_name, raveled_argmax, out_shape): 110 | with tf.name_scope(layer_name): 111 | argmax = unravel_argmax(raveled_argmax, tf.to_int64(out_shape)) 112 | output = tf.zeros([out_shape[1], out_shape[2], out_shape[3]]) 113 | 114 | height = tf.shape(output)[0] 115 | width = tf.shape(output)[1] 116 | channels = tf.shape(output)[2] 117 | 118 | t1 = tf.to_int64(tf.range(channels)) 119 | t1 = tf.tile(t1, [((width + 1) // 2) * ((height + 1) // 2)]) 120 | t1 = tf.reshape(t1, [-1, channels]) 121 | t1 = tf.transpose(t1, perm=[1, 0]) 122 | t1 = tf.reshape(t1, [channels, (height + 1) // 2, (width + 1) // 2, 1]) 123 | 124 | t2 = tf.squeeze(argmax) 125 | t2 = tf.stack((t2[0], t2[1]), axis=0) 126 | t2 = tf.transpose(t2, perm=[3, 1, 2, 0]) 127 | 128 | t = tf.concat([t2, t1], 3) 129 | indices = tf.reshape(t, 
[((height + 1) // 2) * ((width + 1) // 2) * channels, 3]) 130 | 131 | x1 = tf.squeeze(x) 132 | x1 = tf.reshape(x1, [-1, channels]) 133 | x1 = tf.transpose(x1, perm=[1, 0]) 134 | values = tf.reshape(x1, [-1]) 135 | 136 | delta = tf.SparseTensor(indices, values, tf.to_int64(tf.shape(output))) 137 | return tf.expand_dims(tf.sparse_tensor_to_dense(tf.sparse_reorder(delta)), 0) 138 | 139 | def unpooling_layer2x2_batch(bottom, layer_name, argmax): 140 | with tf.name_scope(layer_name): 141 | bottom_shape = tf.shape(bottom) 142 | top_shape = [bottom_shape[0], bottom_shape[1] * 2, bottom_shape[2] * 2, bottom_shape[3]] 143 | 144 | batch_size = top_shape[0] 145 | height = top_shape[1] 146 | width = top_shape[2] 147 | channels = top_shape[3] 148 | 149 | argmax_shape = tf.to_int64([batch_size, height, width, channels]) 150 | argmax = unravel_argmax(argmax, argmax_shape) 151 | 152 | t1 = tf.to_int64(tf.range(channels)) 153 | t1 = tf.tile(t1, [batch_size * (width // 2) * (height // 2)]) 154 | t1 = tf.reshape(t1, [-1, channels]) 155 | t1 = tf.transpose(t1, perm=[1, 0]) 156 | t1 = tf.reshape(t1, [channels, batch_size, height // 2, width // 2, 1]) 157 | t1 = tf.transpose(t1, perm=[1, 0, 2, 3, 4]) 158 | 159 | t2 = tf.to_int64(tf.range(batch_size)) 160 | t2 = tf.tile(t2, [channels * (width // 2) * (height // 2)]) 161 | t2 = tf.reshape(t2, [-1, batch_size]) 162 | t2 = tf.transpose(t2, perm=[1, 0]) 163 | t2 = tf.reshape(t2, [batch_size, channels, height // 2, width // 2, 1]) 164 | 165 | t3 = tf.transpose(argmax, perm=[1, 4, 2, 3, 0]) 166 | 167 | t = tf.concat([t2, t3, t1], 4) 168 | indices = tf.reshape(t, [(height // 2) * (width // 2) * channels * batch_size, 4]) 169 | 170 | x1 = tf.transpose(bottom, perm=[0, 3, 1, 2]) 171 | values = tf.reshape(x1, [-1]) 172 | return tf.scatter_nd(indices, values, tf.to_int64(top_shape)) 173 | 174 | def fullyconnect(input_layer, layer_name, input_size, output_size, regularizer): 175 | with tf.variable_scope(layer_name): 176 | fc_w = 
tf.get_variable("weight", [input_size, output_size], 177 | initializer = tf.truncated_normal_initializer(stddev = 0.1)) 178 | if regularizer != 0: 179 | tf.add_to_collection('losses', regularizer(fc_w)) 180 | fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1)) 181 | fc = tf.nn.sigmoid(tf.matmul(input_layer, fc_w) + fc_b) 182 | #if tain: fc = tf.nn.dropout(fc, 0.5) 183 | return fc 184 | 185 | def linearlyconnect(input_layer, layer_name, input_size, output_size, regularizer): 186 | with tf.variable_scope(layer_name): 187 | fc_w = tf.get_variable("weight", [input_size, output_size], 188 | initializer = tf.truncated_normal_initializer(stddev = 0.1)) 189 | if regularizer != 0: 190 | tf.add_to_collection('losses', regularizer(fc_w)) 191 | 192 | fc_b = tf.get_variable("bias", [output_size], initializer = tf.constant_initializer(0.1)) 193 | logit = tf.matmul(input_layer, fc_w) + fc_b 194 | return logit -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tensorflow as tf 4 | import pprint 5 | 6 | from model import DualReconstructionModel 7 | from solver import nn_train, nn_test, nn_feedforward 8 | from config import config_par 9 | 10 | flags = tf.app.flags 11 | flags.DEFINE_string("dataset", "data", ".tfRecord training file or .mat testing file [data]") 12 | flags.DEFINE_string("mode", "train", "train, test or feedforward [train]") 13 | flags.DEFINE_integer("epoch", 30, "Epoch to train [500]") 14 | flags.DEFINE_integer("batch_size", 10, "The size of batch images [10]") 15 | flags.DEFINE_integer("model_step", 1000, "The number of iteration to save model [1000]") 16 | flags.DEFINE_float("lr", 0.01, "The base learning rate [0.01]") 17 | flags.DEFINE_string("checkpoint", "checkpoint", "The path of checkpoint") 18 | flags.DEFINE_boolean("goon", False, "Go on training flag [0]") 19 
| 20 | def main(_): 21 | os.chdir(sys.path[0]) 22 | print("Current cwd: ", os.getcwd()) 23 | 24 | config = config_par(flags) 25 | pp = pprint.PrettyPrinter() 26 | pp.pprint(config.FLAGS.__flags) 27 | FLAGS = config.FLAGS 28 | 29 | print("Building model ...") 30 | nn = DualReconstructionModel() 31 | if FLAGS.mode == "train": 32 | print("Mode: train ...") 33 | nn_train(nn, FLAGS) 34 | elif FLAGS.mode == "test": 35 | print("Mode: test ...") 36 | nn_test(nn, FLAGS) 37 | elif FLAGS.mode == "feedforward": 38 | print("Mode: feedforward ...") 39 | nn_feedforward(nn, FLAGS) 40 | 41 | exit() 42 | 43 | if __name__ == '__main__': 44 | tf.app.run() -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import numpy as np 4 | import odl 5 | import odl.contrib.tensorflow 6 | 7 | from cell import ConvLSTMCell 8 | 9 | import layer_xyf 10 | 11 | # input size 12 | prjLen = 1024 13 | view_num = 360 14 | channel = 1 15 | 16 | # output size 17 | output_h = 256 18 | output_w = 256 19 | output_depth = 1 20 | 21 | # LSTM hidden size 22 | filters_lstm = 64 23 | kernal_lstm = [3, 3] 24 | 25 | #conv: [filter_height, filter_width, in_channels, out_channels] 26 | filter_shape = [256, 256, filters_lstm, 1] 27 | 28 | sod = 1000 29 | sdd = 1500 30 | 31 | iteration = 10 32 | 33 | # CNN 34 | #layer_1 input_size 128*128 / 256*256 / 512*512 conv: [filter_height, filter_width, in_channels, out_channels] 35 | FILTER_1 = [3, 3, 1, 32] 36 | STRIDE_1 = 2 37 | PAD_1 = "SAME" 38 | 39 | #layer_2 input_size 64*64 / 128*128 / 256*256 40 | FILTER_2 = [3, 3, 32, 32] 41 | STRIDE_2 = 2 42 | PAD_2 = "SAME" 43 | 44 | #layer_3 input_size 32*32 / 64*64 / 128*128 45 | FILTER_3 = [3, 3, 32, 32] 46 | STRIDE_3 = 2 47 | PAD_3 = "SAME" 48 | 49 | #layer_4 input_size 16*16 / 32*32 / 64*64 50 | #FILTER_4 = [9, 9, 256, 512] 51 | FILTER_4 = [3, 3, 32, 32] 52 | STRIDE_4 = 2 
PAD_4 = "SAME"

#layer_5 input_size 8*8 / 16*16 / 32*32
FILTER_5 = [16, 16, 32, 32]
STRIDE_5 = 1
PAD_5 = "VALID"

#layer_6 input_size 1*1*64 (low- and high-energy branches concatenated, 32 channels each)
FILTER_6 = [1, 1, 64, 1]
STRIDE_6 = 1
PAD_6 = "VALID"

space = odl.uniform_discr([-256, -256], [256, 256], [output_h, output_w], dtype='float32')
angle_partition = odl.uniform_partition(0, 2 * np.pi, view_num)
detector_partition = odl.uniform_partition(-360, 360, prjLen)
geometry = odl.tomo.FanFlatGeometry(angle_partition, detector_partition,
                                    src_radius = sod, det_radius = sdd - sod)
operator = odl.tomo.RayTransform(space, geometry, impl='astra_cuda')

odl_op_layer = odl.contrib.tensorflow.as_tensorflow_layer(operator, 'RayTransform')
odl_op_layer_adjoint = odl.contrib.tensorflow.as_tensorflow_layer(operator.adjoint, 'RayTransformAdjoint')


class DualReconstructionModel(object):
    def __init__(self):
        self.input_width = prjLen
        self.input_height = view_num
        self.num_channel = channel
        self.output_h = output_h
        self.output_w = output_w
        self.output_depth = output_depth

    def feedforward(self, ml, mh, regularizer = 0):
        input_shape = mh.get_shape().as_list()
        batch_size = input_shape[0]

        #LSTM
        with tf.variable_scope('LSTM_1'):
            rh = tf.get_variable("rh", [batch_size, self.output_h, self.output_w, self.output_depth], initializer = tf.constant_initializer(0.0), trainable = False)
            conv_w1 = tf.get_variable( "weight_lstm1", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.001) )
            conv_b1 = tf.get_variable( "bias1", 1, initializer = tf.constant_initializer(0.0) )
            cell_g1 = ConvLSTMCell([self.output_h, self.output_w], filters_lstm, kernal_lstm)
            init_state_g1 = cell_g1.zero_state(batch_size, dtype = tf.float32)
            state1 = init_state_g1
            for timestep in range(iteration):
                if timestep > 0:
                    tf.get_variable_scope().reuse_variables()
                rh_tr = tf.transpose(rh, perm = [0, 2, 1, 3])
                g1 = odl_op_layer_adjoint( (odl_op_layer(rh_tr) - mh) )
                gt1 = tf.transpose(g1, perm = [0, 2, 1, 3])
                (cell_output1, state1) = cell_g1(gt1, state1)
                conv1 = tf.nn.conv2d(cell_output1, conv_w1, [1, 1, 1, 1], "VALID")
                s1 = tf.nn.tanh(tf.nn.bias_add(conv1, conv_b1))
                self.variable_summaries(s1, ('s1_%d'%timestep))

                rh = rh + 0.0001*s1*gt1
                rh = tf.clip_by_value(rh, 0, 5)
                tf.summary.image('rh_pred_%d'%timestep, rh, 1)

        with tf.variable_scope('LSTM_2'):
            rl = tf.get_variable("rl", [batch_size, self.output_h, self.output_w, self.output_depth], initializer = tf.constant_initializer(0.0), trainable = False)
            conv_w2 = tf.get_variable("weight_lstm2", filter_shape, initializer = tf.truncated_normal_initializer(stddev = 0.001) )
            conv_b2 = tf.get_variable("bias2", 1, initializer = tf.constant_initializer(0.0) )
            cell_g2 = ConvLSTMCell([self.output_h, self.output_w], filters_lstm, kernal_lstm)
            init_state_g2 = cell_g2.zero_state(batch_size, dtype = tf.float32)
            state2 = init_state_g2
            for timestep in range(iteration):
                if timestep > 0:
                    tf.get_variable_scope().reuse_variables()
                rl_tr = tf.transpose(rl, perm = [0, 2, 1, 3])
                g2 = odl_op_layer_adjoint( (odl_op_layer(rl_tr) - ml) )
                gt2 = tf.transpose(g2, perm = [0, 2, 1, 3])

                (cell_output2, state2) = cell_g2(gt2, state2)
                conv2 = tf.nn.conv2d(cell_output2, conv_w2, [1, 1, 1, 1], "VALID")
                s2 = tf.nn.tanh(tf.nn.bias_add(conv2, conv_b2))
                self.variable_summaries(s2, ('s2_%d'%timestep))

                rl = rl + 0.0001*s2*gt2
                rl = tf.clip_by_value(rl, 0, 5)
                tf.summary.image('rl_pred_%d'%timestep, rl, 1)

        #CNN
        layer_L1 = layer_xyf.convo(rl, "conv_L1", FILTER_1, STRIDE_1, PAD_1)
        layer_L2 = layer_xyf.convo(layer_L1, "conv_L2", FILTER_2, STRIDE_2, PAD_2)
        layer_L3 = layer_xyf.convo(layer_L2, "conv_L3", FILTER_3, STRIDE_3, PAD_3)
        layer_L4 = layer_xyf.convo(layer_L3, "conv_L4", FILTER_4, STRIDE_4, PAD_4)
        layer_L5 = layer_xyf.convo(layer_L4, "conv_L5", FILTER_5, STRIDE_5, PAD_5)

        layer_H1 = layer_xyf.convo(rh, "conv_H1", FILTER_1, STRIDE_1, PAD_1)
        layer_H2 = layer_xyf.convo(layer_H1, "conv_H2", FILTER_2, STRIDE_2, PAD_2)
        layer_H3 = layer_xyf.convo(layer_H2, "conv_H3", FILTER_3, STRIDE_3, PAD_3)
        layer_H4 = layer_xyf.convo(layer_H3, "conv_H4", FILTER_4, STRIDE_4, PAD_4)
        layer_H5 = layer_xyf.convo(layer_H4, "conv_H5", FILTER_5, STRIDE_5, PAD_5)

        combine_LH = tf.concat([layer_L5, layer_H5], 3)

        pa_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pa", FILTER_6, STRIDE_6, PAD_6)
        pb_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pb", FILTER_6, STRIDE_6, PAD_6)
        pc_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pc", FILTER_6, STRIDE_6, PAD_6)
        pd_pred = layer_xyf.convo_noneRelu(combine_LH, "conv_pd", FILTER_6, STRIDE_6, PAD_6)

        d1 = pa_pred*rh + pb_pred*rl
        d2 = pc_pred*rh + pd_pred*rl

        d1 = tf.clip_by_value(d1, 0, 5)
        d2 = tf.clip_by_value(d2, 0, 5)

        tf.summary.image('d1_pred', d1, 1)
        tf.summary.image('d2_pred', d2, 1)

        return d1, d2, rl, rh

    def variable_summaries(self, var, var_name):
        with tf.name_scope(var_name):
            mean = tf.reduce_mean(var)
            tf.summary.scalar('mean', mean)
            with tf.name_scope('stddev'):
                stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
            tf.summary.scalar('stddev', stddev)
            tf.summary.scalar('max', tf.reduce_max(var))
            tf.summary.scalar('min', tf.reduce_min(var))
            tf.summary.histogram('histogram', var)

--------------------------------------------------------------------------------
/src/solver.py:
--------------------------------------------------------------------------------
import os
import scipy.io
import time
import tensorflow as tf
import numpy as np

from functools import reduce
from operator import mul

def nn_train(nn, config):
    print ("Loading .tfrecords data...", config.dataset)
    if not os.path.exists(config.dataset):
        raise Exception("Training data not found.")

    time_str = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime(time.time()))
    train_file_path = os.path.join(config.dataset, "train_sample_batches_*")
    train_files = tf.train.match_filenames_once(train_file_path)
    filename_queue = tf.train.string_input_producer(train_files, shuffle = True)

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                    features = {'ml': tf.FixedLenFeature([], tf.string),
                                'mh': tf.FixedLenFeature([], tf.string),
                                'd_bone': tf.FixedLenFeature([], tf.string),
                                'd_tissue': tf.FixedLenFeature([], tf.string),
                                'rl': tf.FixedLenFeature([], tf.string),
                                'rh': tf.FixedLenFeature([], tf.string)})

    ml = tf.decode_raw(features['ml'], tf.float32)
    mh = tf.decode_raw(features['mh'], tf.float32)
    d1 = tf.decode_raw(features['d_bone'], tf.float32)
    d2 = tf.decode_raw(features['d_tissue'], tf.float32)
    rl = tf.decode_raw(features['rl'], tf.float32)
    rh = tf.decode_raw(features['rh'], tf.float32)

    ml = tf.reshape(ml, [nn.input_height, nn.input_width, nn.num_channel])
    mh = tf.reshape(mh, [nn.input_height, nn.input_width, nn.num_channel])

    d1 = tf.reshape(d1, [512, 512, nn.output_depth])
    d2 = tf.reshape(d2, [512, 512, nn.output_depth])
    rl = tf.reshape(rl, [512, 512, nn.output_depth])
    rh = tf.reshape(rh, [512, 512, nn.output_depth])
    com_resize = 256
    d1_resized = tf.image.resize_images(d1, [com_resize, com_resize], method = 0)
    d2_resized = tf.image.resize_images(d2, [com_resize, com_resize],
                                        method = 0)
    rl_resized = tf.image.resize_images(rl, [com_resize, com_resize], method = 0)
    rh_resized = tf.image.resize_images(rh, [com_resize, com_resize], method = 0)


    ml_batch, mh_batch, d1_batch, d2_batch, rl_batch, rh_batch = tf.train.shuffle_batch([ml, mh, d1_resized, d2_resized, rl_resized, rh_resized],
                                                    batch_size = config.batch_size,
                                                    capacity = config.batch_size*3 + 50,
                                                    min_after_dequeue = 30)

    # loss function define
    # feedforward() accepts only the two sinograms; rl_batch and rh_batch serve as targets in the loss below
    d1_batch_pred, d2_batch_pred, rl_batch_pred, rh_batch_pred = nn.feedforward(ml_batch, mh_batch)

    global_step = tf.Variable(0, trainable = False)
    variable_averages = tf.train.ExponentialMovingAverage(config.moving_average_decay, global_step)
    variables_averages_op = variable_averages.apply(tf.trainable_variables())

    mse_loss = tf.reduce_mean(0.005*tf.square(d1_batch_pred - d1_batch) / 2 + 0.005*tf.square(d2_batch_pred - d2_batch) / 2 +
                              tf.square(rl_batch_pred - rl_batch) / 2 + tf.square(rh_batch_pred - rh_batch) / 2)

    tf.add_to_collection('losses', mse_loss)
    tf.summary.scalar('MSE_losses', mse_loss)

    learning_rate = tf.train.exponential_decay(config.lr, global_step,
                                               config.sampleNum / config.batch_size, config.learning_rate_decay)
    train_step = tf.train.AdamOptimizer(learning_rate).minimize(mse_loss, global_step = global_step)

    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name = 'train')

    merged = tf.summary.merge_all()
    summary_path = os.path.join(config.summary_dir, time_str)
    os.mkdir(summary_path)

    run_config = tf.ConfigProto(allow_soft_placement = True)
    run_config.gpu_options.allow_growth = True

    num_params = 0
    for variable in tf.trainable_variables():
        shape = variable.get_shape()
        num_params += reduce(mul, [dim.value for dim in shape], 1)
    print("Number of trainable parameters: %d"%(num_params))

    with tf.Session(config = run_config) as sess:
        summary_writer = tf.summary.FileWriter(summary_path, sess.graph)
        saver = tf.train.Saver()
        if config.goon:
            if not os.path.exists(config.checkpoint):
                raise Exception("Checkpoint path not found.")
            print("Loading trained model... ", config.checkpoint)
            ckpt = tf.train.get_checkpoint_state(config.checkpoint)
            if not ckpt or not ckpt.model_checkpoint_path:
                raise Exception("No checkpoint file found.")
            saver.restore(sess, ckpt.model_checkpoint_path)
            # local variables (e.g. the file list from match_filenames_once) are not stored in checkpoints
            sess.run(tf.local_variables_initializer())
        else:
            init = (tf.global_variables_initializer(), tf.local_variables_initializer())
            sess.run(init)

        save_model_path = os.path.join(config.output_model_dir, time_str)
        os.mkdir(save_model_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess = sess, coord = coord)

        variable_name = [v.name for v in tf.trainable_variables()]
        print(variable_name)

        print("start training sess...")
        start_time = time.time()
        for i in range(config.iteration):
            batch_start_time = time.time()
            summary, _, loss_value, step, learningRate = sess.run([merged, train_op, mse_loss, global_step, learning_rate])
            batch_end_time = time.time()
            sec_per_batch = batch_end_time - batch_start_time
            if step % config.summary_step == 0:
                summary_writer.add_summary(summary, step//config.summary_step)
            if step % config.model_step == 0:
                print("Saving model (after %d iteration)... " %(step))
                saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
            print("sec/batch(%d) %gs, global step %d batches, training epoch %d/%d, learningRate %g, loss on training is %g"
                  % (config.batch_size, sec_per_batch, step, i*config.batch_size / config.sampleNum,
                     config.epoch, learningRate, loss_value))
            print("Elapsed time: %gs" %(batch_end_time - start_time))

        coord.request_stop()
        coord.join(threads)
        summary_writer.close()
        print("Train done. ")
        print("Saving model... ", save_model_path)
        saver.save(sess, os.path.join(save_model_path, config.model_name + ".ckpt"), global_step = global_step)
        sess.close()

    print("Sess closed.")

def nn_test(nn, config):
    print ("Loading .tfrecords data...", config.dataset)
    if not os.path.exists(config.dataset):
        raise Exception("Testing data not found.")

def nn_feedforward(nn, config):
    print ("Loading .mat data...", config.dataset)
    if not os.path.exists(config.dataset):
        print("Testing .mat file not found.")
        raise Exception("Testing .mat file not found.")

    run_config = tf.ConfigProto()
    with tf.Session(config = run_config) as sess:
        mat_file_list = []
        filelist = os.listdir(config.dataset)
        for line in filelist:
            file = os.path.join(config.dataset, line)
            if os.path.isfile(file):
                mat_file_list.append(file)
                print(file)

        ml = tf.placeholder(tf.float32, [1, 360, 1024, 1], name = 'xl-input')
        mh = tf.placeholder(tf.float32, [1, 360, 1024, 1], name = 'xh-input')

        d1_pred, d2_pred, rl_pred, rh_pred = nn.feedforward(ml, mh)

        variable_averages = tf.train.ExponentialMovingAverage(config.moving_average_decay)
        variables_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        ckpt = tf.train.get_checkpoint_state(config.checkpoint)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            # checkpoint files are named "<model_name>.ckpt-<global_step>"
            global_step = os.path.basename(ckpt.model_checkpoint_path).split('-')[-1]
            print('global step: ', global_step)

            file_num = 1
            for mat_file in mat_file_list:
                print('loading file: ', mat_file)
                mat = scipy.io.loadmat(mat_file)
                mat_ml = mat['ml']
                mat_mh = mat['mh']

                dim_h = mat_ml.shape[0]
                dim_w = mat_ml.shape[1]
                if mat_ml.ndim < 3:
                    mat_ml = mat_ml.reshape(dim_h, dim_w, 1)
                    mat_mh = mat_mh.reshape(dim_h, dim_w, 1)
                img_num = mat_ml.shape[2]
                print('image num: ', img_num)

                d1 = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
                d2 = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
                rl = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)
                rh = np.zeros((nn.output_h, nn.output_w, img_num), dtype = np.float32)

                input_ml = np.zeros( (1, dim_h, dim_w, 1) )
                input_mh = np.zeros( (1, dim_h, dim_w, 1) )

                start_time = time.time()
                for i in range( int(img_num) ):
                    input_ml = mat_ml[:,:,i].reshape([1, dim_h, dim_w, 1])
                    input_mh = mat_mh[:,:,i].reshape([1, dim_h, dim_w, 1])

                    output_d1, output_d2, output_rl, output_rh = sess.run([d1_pred, d2_pred, rl_pred, rh_pred],
                                                                          feed_dict={ml: input_ml, mh: input_mh})
                    output_d1.resize(1, nn.output_h, nn.output_w, 1)
                    output_d2.resize(1, nn.output_h, nn.output_w, 1)
                    output_rl.resize(1, nn.output_h, nn.output_w, 1)
                    output_rh.resize(1, nn.output_h, nn.output_w, 1)

                    d1[:,:,i] = output_d1.reshape([nn.output_h, nn.output_w])
                    d2[:,:,i] = output_d2.reshape([nn.output_h, nn.output_w])
                    rl[:,:,i] = output_rl.reshape([nn.output_h, nn.output_w])
                    rh[:,:,i] = output_rh.reshape([nn.output_h, nn.output_w])

                end_time = time.time()
                print("Elapsed time: %gs" %(end_time - start_time))

                result_fileName = os.path.join(config.output_data_dir, 'result_test_%.4d.mat' %(file_num))
                scipy.io.savemat(result_fileName, {'d1':d1, 'd2':d2, 'rl':rl, 'rh':rh})
                print('Feed forward done. Result: ', result_fileName)
                file_num = file_num + 1
        else:
            print('No checkpoint file found.')
            return

--------------------------------------------------------------------------------
/src/sysMatrix_2D_all_angle.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XYF-GitHub/DualReconstruction/65e95bfb54d0f555c32d8c92417f6f3a63eb5078/src/sysMatrix_2D_all_angle.m
--------------------------------------------------------------------------------