├── .gitignore
├── README.md
├── complex_utils.py
├── masks
│   ├── mask_1.0x2.0_c1_00.cfl
│   ├── mask_1.0x2.0_c1_00.hdr
│   ├── mask_1.0x2.0_c1_01.cfl
│   ├── mask_1.0x2.0_c1_01.hdr
│   ├── mask_1.0x2.0_c1_02.cfl
│   └── mask_1.0x2.0_c1_02.hdr
├── mri_data.py
├── mri_model.py
├── mri_prep.py
├── mri_util
│   ├── __init__.py
│   ├── bartwrap.py
│   ├── cfl.py
│   ├── coilcomp.py
│   ├── cs_metrics.py
│   ├── fftc.py
│   ├── mask.py
│   ├── metrics.py
│   ├── recon.py
│   ├── tf_util.py
│   └── zReLU.py
├── requirements.txt
├── setup_mri.py
├── test_conv1d.py
├── test_images.py
├── test_loop.py
├── test_script.sh
├── train_loop.py
└── train_script.sh


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

#checkpoints and tensorboards
events*
checkpoint*
graph*

#data
raw*
tmp*
.zip
*.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Implementation related to the paper "Complex-Valued Convolutional Neural Networks for MRI Reconstruction" by Elizabeth K. Cole et al.: https://arxiv.org/abs/2004.01738

# Complex-Valued MRI Reconstruction - Unrolled Architecture
Image Reconstruction using an Unrolled DL Architecture including Complex-Valued Convolution and Activation Functions
* 2018 Elizabeth Cole, Stanford University (ekcole@stanford.edu)
* 2018 Joseph Y. Cheng, Stanford University
* 2018 Feiyu Chen, Stanford University
* 2018 Chris Sandino, Stanford University

## Complex-Valued Utilities
If you are only interested in the complex-valued utilities, they are located in complex_utils.py. This includes complex-valued convolution, complex-valued transposed convolution, and the CReLU, modReLU, zReLU, and cardioid activation functions.

## Setup
Make sure the Python requirements are installed:

    pip3 install -r requirements.txt

The setup assumes that the latest Berkeley Advanced Reconstruction Toolbox (BART) is installed [1]. The scripts have all been tested with v0.4.01.

## Data preparation
We will first download data, generate sampling masks, and generate TFRecords for training. The downloaded datasets are fully sampled volumetric knee scans from mridata [2]. The setup script uses the BART binary. In a new folder, run the following script:

    python3 setup_mri.py -v

## Training
Training can be run using the following script, from the same folder as the prepared data.

    TYPE=complex
    ITERATIONS=4
    FEAT=256
    ACTIVATION=cardioid
    LOG_DIR="f"$FEAT"_g"$ITERATIONS
    python3 train_loop.py \
      --train_dir $TYPE"_"$ACTIVATION \
      --mask_path masks \
      --dataset_dir data \
      --log_root $LOG_DIR \
      --shape_z 256 --shape_y 320 \
      --num_channels 8 \
      --batch_size 2 \
      --device 0 \
      --max_steps 50000 \
      --feat_map $FEAT \
      --num_grad_steps $ITERATIONS \
      --activation $ACTIVATION \
      --conv $TYPE

TYPE denotes the type of convolution; the options are "real" and "complex".

ITERATIONS denotes the number of iterations in the unrolled architecture.

FEAT denotes the number of feature maps in each convolution layer.
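
When TYPE is "complex", the convolutions in complex_utils.py are assembled from real-valued ones. A kernel named "real_conv" and a kernel named "imag_conv" are each applied to the real and imaginary parts of the input, with the second application of each kernel reusing its variables via `reuse=True`. The four outputs are then combined as (W_r * x_r - W_i * x_i) + i (W_r * x_i + W_i * x_r). The NumPy snippet below is a minimal sketch that checks this identity in 1-D. The function name `complex_conv1d_via_real` is hypothetical and not part of this repository. `np.convolve` stands in for `tf.layers.conv2d`. TensorFlow computes a cross-correlation rather than a flipped convolution, but the decomposition holds either way because both operations are bilinear in the kernel and the input.

```python
import numpy as np


def complex_conv1d_via_real(x, w):
    """Complex 1-D convolution built from four real-valued convolutions.

    Hypothetical helper mirroring the real/imaginary split used in
    complex_utils.complex_conv; not part of the repository.
    """
    real_real = np.convolve(x.real, w.real)  # "real_conv" kernel on the real part
    imag_real = np.convolve(x.imag, w.real)  # "real_conv" kernel on the imaginary part
    real_imag = np.convolve(x.real, w.imag)  # "imag_conv" kernel on the real part
    imag_imag = np.convolve(x.imag, w.imag)  # "imag_conv" kernel on the imaginary part
    real_out = real_real - imag_imag
    imag_out = imag_real + real_imag
    return real_out + 1j * imag_out


rng = np.random.default_rng(0)
x = rng.standard_normal(16) + 1j * rng.standard_normal(16)  # complex input signal
w = rng.standard_normal(3) + 1j * rng.standard_normal(3)    # complex kernel

# np.convolve handles complex arrays directly, so it serves as the reference.
print(np.allclose(complex_conv1d_via_real(x, w), np.convolve(x, w)))  # True
```

Because the second call for each kernel reuses its variables, a complex layer stores two real kernels ("real_conv" and "imag_conv") rather than four. In complex_conv, each of these kernels is created with `num_features // 2` filters.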

ACTIVATION denotes the activation function used after each convolution layer; the options are "relu", "crelu", "zrelu", "modrelu", and "cardioid" [3].
If running real convolution, the activation must be relu.

LOG_DIR indicates the directory to which checkpoints and training logs are saved. You can view a TensorBoard summary for this training run by running:

    tensorboard --logdir=./

in this directory.

Various complex-valued utility functions are in complex_utils.py. This includes complex-valued convolution, complex-valued transposed convolution, and various complex-valued activation functions such as CReLU, zReLU, modReLU, and cardioid.

## Testing
Testing can be run using a script similar to the training script, found in test_script.sh. One such example is:

    TYPE=complex
    ITERATIONS=4
    FEAT=256
    ACTIVATION=cardioid
    LOG_DIR="f"$FEAT"_g"$ITERATIONS
    python3 test_images.py \
      --train_dir $TYPE"_"$ACTIVATION \
      --mask_path masks \
      --dataset_dir data \
      --log_root $LOG_DIR \
      --shape_z 256 --shape_y 320 \
      --num_channels 8 \
      --batch_size 2 \
      --device 0 \
      --max_steps 50000 \
      --feat_map $FEAT \
      --num_grad_steps $ITERATIONS \
      --activation $ACTIVATION \
      --conv $TYPE

This script runs a deep learning reconstruction on a trained model with the specified parameters. It also performs CS reconstructions using BART and computes aggregate image metrics PSNR, SSIM, and NRMSE for each test image. Test images are saved in a folder named "images" in the model directory.

## References
1. https://github.com/mrirecon/bart
2. http://mridata.org
3. https://arxiv.org/pdf/1705.09792.pdf

--------------------------------------------------------------------------------
/complex_utils.py:
--------------------------------------------------------------------------------
from math import pi

import numpy as np
import tensorflow as tf


def complex_conv(
    tf_input, num_features, kernel_size, stride=1, data_format="channels_last", dilation_rate=(1, 1), use_bias=True,
    kernel_initializer=None, kernel_regularizer=None, bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True
):
    # allocate half the features to real, half to imaginary
    num_features = num_features // 2

    tf_real = tf.real(tf_input)
    tf_imag = tf.imag(tf_input)

    with tf.variable_scope(None, default_name="complex_conv2d"):
        tf_real_real = tf.layers.conv2d(
            inputs=tf_real,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="real_conv",
        )
        tf_imag_real = tf.layers.conv2d(
            tf_imag,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="real_conv",
            reuse=True,
        )
        tf_real_imag = tf.layers.conv2d(
            tf_real,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="imag_conv",
        )
        tf_imag_imag = tf.layers.conv2d(
            tf_imag,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="imag_conv",
            reuse=True,
        )
        real_out = tf_real_real - tf_imag_imag
        imag_out = tf_imag_real + tf_real_imag
        tf_output = tf.complex(real_out, imag_out)

    return tf_output


def complex_conv_transpose(tf_input, num_features, kernel_size, stride, data_format="channels_last", use_bias=True,
                           kernel_initializer=None, kernel_regularizer=None, bias_regularizer=None,
                           activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True
                           ):
    # allocate half the features to real, half to imaginary
    # num_features = num_features // 2

    tf_real = tf.real(tf_input)
    tf_imag = tf.imag(tf_input)

    with tf.variable_scope(None, default_name="complex_conv2d"):
        tf_real_real = tf.layers.conv2d_transpose(
            inputs=tf_real,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="real_conv",
        )
        tf_imag_real = tf.layers.conv2d_transpose(
            tf_imag,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="real_conv",
            reuse=True,
        )
        tf_real_imag = tf.layers.conv2d_transpose(
            tf_real,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="imag_conv",
        )
        tf_imag_imag = tf.layers.conv2d_transpose(
            tf_imag,
            filters=num_features,
            kernel_size=kernel_size,
            strides=[stride, stride],
            padding="same",
            data_format=data_format,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="imag_conv",
            reuse=True,
        )
        real_out = tf_real_real - tf_imag_imag
        imag_out = tf_imag_real + tf_real_imag
        tf_output = tf.complex(real_out, imag_out)

    return tf_output


def complex_conv1d(
    tf_input, num_features, kernel_size, stride=1, data_format="channels_last", dilation_rate=(1), use_bias=True,
    kernel_initializer=None, kernel_regularizer=None,
    bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None, trainable=True
):
    # allocate half the features to real, half to imaginary
    num_features = num_features // 2

    tf_real = tf.real(tf_input)
    tf_imag = tf.imag(tf_input)

    with tf.variable_scope(None, default_name="complex_conv1d"):
        tf_real_real = tf.layers.conv1d(
            inputs=tf_real,
            filters=num_features,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="real_conv",
        )
        tf_imag_real = tf.layers.conv1d(
            tf_imag,
            filters=num_features,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="real_conv",
            reuse=True,
        )
        tf_real_imag = tf.layers.conv1d(
            tf_real,
            filters=num_features,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="imag_conv",
        )
        tf_imag_imag = tf.layers.conv1d(
            tf_imag,
            filters=num_features,
            kernel_size=kernel_size,
            strides=stride,
            padding="same",
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=None,
            bias_regularizer=None,
            activity_regularizer=None,
            kernel_constraint=None,
            bias_constraint=None,
            name="imag_conv",
            reuse=True,
        )
        real_out = tf_real_real - tf_imag_imag
        imag_out = tf_imag_real + tf_real_imag
        tf_output = tf.complex(real_out, imag_out)

    return tf_output


def zrelu(x):
    # x and tf_output are complex-valued
    phase = tf.angle(x)

    # Check whether phase <= pi/2
    le = tf.less_equal(phase, pi / 2)

    # if phase <= pi/2, keep it in comp
    # if phase > pi/2, throw it away and set comp equal to 0
    y = tf.zeros_like(x)
    x = tf.where(le, x, y)

    # Check whether phase >= 0
    ge = tf.greater_equal(phase, 0)

    # if phase >= 0, keep it
    # if phase < 0, throw it away and set output equal to 0
    output = tf.where(ge, x, y)

    return output


def modrelu(x, data_format="channels_last"):
    input_shape = tf.shape(x)
    if data_format == "channels_last":
        axis_z = 1
        axis_y = 2
        axis_c = 3
    else:
        axis_c = 1
        axis_z = 2
        axis_y = 3

    # Channel size
    shape_c = x.shape[axis_c]

    with tf.name_scope("bias") as scope:
        if data_format == "channels_last":
            bias_shape = (1, 1, 1, shape_c)
        else:
            bias_shape = (1, shape_c, 1, 1)
        bias = tf.get_variable(name=scope,
                               shape=bias_shape,
                               initializer=tf.constant_initializer(0.0),
                               trainable=True)
    # relu(|z|+b) * (z / |z|)
    norm = tf.abs(x)
    scale = tf.nn.relu(norm + bias) / (norm + 1e-6)
    output = tf.complex(tf.real(x) * scale,
                        tf.imag(x) * scale)

    return output


def cardioid(x):
    phase = tf.angle(x)
    scale = 0.5 * (1 + tf.cos(phase))
    output = tf.complex(tf.real(x) * scale, tf.imag(x) * scale)
    # output = 0.5*(1+tf.cos(phase))*z

    return output

--------------------------------------------------------------------------------
/masks/mask_1.0x2.0_c1_00.cfl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MRSRL/complex-networks-release/b661bbe2557366a9c010911fb70e291ebcd4af99/masks/mask_1.0x2.0_c1_00.cfl
--------------------------------------------------------------------------------
/masks/mask_1.0x2.0_c1_00.hdr:
--------------------------------------------------------------------------------
# Dimensions
1 320 256 1 1
# Command
poisson -C 1 -Y 320 -Z 256 -y 1 -z 2 -s 368529 masks/mask_1.0x2.0_c1_00
# Files
>masks/mask_1.0x2.0_c1_00
# Creator
BART v0.4.03-4-gd149bc8
--------------------------------------------------------------------------------
/masks/mask_1.0x2.0_c1_01.cfl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MRSRL/complex-networks-release/b661bbe2557366a9c010911fb70e291ebcd4af99/masks/mask_1.0x2.0_c1_01.cfl
--------------------------------------------------------------------------------
/masks/mask_1.0x2.0_c1_01.hdr:
--------------------------------------------------------------------------------
# Dimensions
1 320 256 1 1
# Command
poisson -C 1 -Y 320 -Z 256 -y 1 -z 2 -s 49575 masks/mask_1.0x2.0_c1_01
# Files
>masks/mask_1.0x2.0_c1_01
# Creator
BART v0.4.03-4-gd149bc8
--------------------------------------------------------------------------------
/masks/mask_1.0x2.0_c1_02.hdr:
--------------------------------------------------------------------------------
# Dimensions
1 320 256 1 1
# Command
poisson -C 1 -Y 320 -Z 256 -y 1 -z 2 -s 632780 masks/mask_1.0x2.0_c1_02
# Files
>masks/mask_1.0x2.0_c1_02
# Creator
BART v0.4.03-4-gd149bc8
--------------------------------------------------------------------------------
/mri_data.py:
--------------------------------------------------------------------------------
"""Util for data management."""
import os
import glob
import random
import tensorflow as tf
import numpy as np
import mri_prep
from mri_util import cfl
from mri_util import recon
from mri_util import tf_util


def prepare_filenames(dir_name, search_str="/*.tfrecords"):
    """Find and return filenames."""
    if not tf.gfile.Exists(dir_name) or not tf.gfile.IsDirectory(dir_name):
        raise FileNotFoundError("Could not find folder `%s'" % (dir_name))

    full_path = os.path.join(dir_name)
    case_list = glob.glob(full_path + search_str)
    random.shuffle(case_list)

    return case_list


def load_masks_cfl(filenames, image_shape=None):
    """Read masks from files."""
    if image_shape is None:
        # First find masks shape...
        image_shape = [0, 0]
        for f in filenames:
            f_cfl = os.path.splitext(f)[0]
            mask = np.squeeze(cfl.read(f_cfl))
            shape_z = mask.shape[-2]
            shape_y = mask.shape[-1]
            if image_shape[-2] < shape_z:
                image_shape[-2] = shape_z
            if image_shape[-1] < shape_y:
                image_shape[-1] = shape_y

    masks = np.zeros([len(filenames)] + image_shape, dtype=np.complex64)

    i_file = 0
    for f in filenames:
        f_cfl = os.path.splitext(f)[0]
        tmp = np.squeeze(cfl.read(f_cfl))
        tmp = recon.zeropad(tmp, image_shape)
        masks[i_file, :, :] = tmp
        i_file = i_file + 1

    return masks


def prep_tfrecord(example, masks,
                  out_shape=[80, 180],
                  num_channels=6, num_emaps=2,
                  random_seed=0,
                  verbose=False):
    """Prepare tfrecord for training"""
    name = "prep_tfrecord"

    _, _, ks_x, map_x = mri_prep.process_tfrecord(
        example, num_channels=num_channels, num_emaps=num_emaps)

    # Randomly select mask
    mask_x = tf.constant(masks, dtype=tf.complex64)
    mask_x = tf.random_shuffle(mask_x)
    mask_x = tf.slice(mask_x, [0, 0, 0], [1, -1, -1])
    # Augment sampling masks
    mask_x = tf.image.random_flip_up_down(mask_x, seed=random_seed)
    mask_x = tf.image.random_flip_left_right(mask_x, seed=random_seed)

    # Transpose to store data as (kz, ky, channels)
    mask_x = tf.transpose(mask_x, [1, 2, 0])
    ks_x = tf.transpose(ks_x, [1, 2, 0])
    map_x = tf.transpose(map_x, [1, 2, 0])

    ks_x = tf.image.flip_up_down(ks_x)
    map_x = tf.image.flip_up_down(map_x)

    # Initially set image size to be all the same
    ks_x = tf.image.resize_image_with_crop_or_pad(
        ks_x, out_shape[0], out_shape[1])
    mask_x = tf.image.resize_image_with_crop_or_pad(
        mask_x, out_shape[0], out_shape[1])

    shape_cal = 20
    if shape_cal > 0:
        with tf.name_scope("CalibRegion"):
            if verbose:
                print("%s> Including calib region (%d, %d)..." %
                      (name, shape_cal, shape_cal))
            mask_calib = tf.ones([shape_cal, shape_cal, 1],
                                 dtype=tf.complex64)
            mask_calib = tf.image.resize_image_with_crop_or_pad(
                mask_calib, out_shape[0], out_shape[1])
            mask_x = mask_x * (1 - mask_calib) + mask_calib

    mask_recon = tf.abs(ks_x) / tf.reduce_max(tf.abs(ks_x))
    mask_recon = tf.cast(mask_recon > 1e-7, dtype=tf.complex64)
    mask_x = mask_x * mask_recon

    # Assuming calibration region is fully sampled
    shape_sc = 5
    scale = tf.image.resize_image_with_crop_or_pad(
        ks_x, shape_sc, shape_sc)
    scale = (tf.reduce_mean(tf.square(tf.abs(scale))) *
             (shape_sc * shape_sc / 1e5))
    scale = tf.cast(1.0 / tf.sqrt(scale), dtype=tf.complex64)
    ks_x = ks_x * scale

    # Make sure size is correct
    map_shape = tf.shape(map_x)
    map_shape_z = tf.slice(map_shape, [0], [1])
    map_shape_y = tf.slice(map_shape, [1], [1])
    assert_z = tf.assert_equal(out_shape[0], map_shape_z)
    assert_y = tf.assert_equal(out_shape[1], map_shape_y)
    with tf.control_dependencies([assert_z, assert_y]):
        map_x = tf.identity(map_x, name="sensemap_size_check")
    map_x = tf.image.resize_image_with_crop_or_pad(map_x,
                                                   out_shape[0],
                                                   out_shape[1])
    map_x = tf.reshape(map_x, [out_shape[0], out_shape[1],
                               num_emaps, num_channels])

    # Ground truth
    ks_truth = ks_x
    # Masked input
    ks_x = tf.multiply(ks_x, mask_x)

    features = {}
    features['ks_input'] = ks_x
    features['sensemap'] = map_x
    features['mask_recon'] = mask_recon
    features['scale'] = scale

    return features, ks_truth


def create_dataset(train_data_dir, mask_data_dir,
                   batch_size=16,
                   buffer_size=10,
                   out_shape=[80, 180],
                   num_channels=6, num_emaps=1,
                   verbose=True,
                   random_seed=0,
                   name="create_dataset"):
    """Set up input tensors."""
    train_filenames_tfrecord = prepare_filenames(train_data_dir,
                                                 search_str="/*.tfrecords")
    mask_filenames_cfl = prepare_filenames(mask_data_dir,
                                           search_str="/*.cfl")
    if verbose:
        print("%s> Number of training files (%s): %d"
              % (name, train_data_dir, len(train_filenames_tfrecord)))
        print("%s> Number of mask files (%s): %d"
              % (name, mask_data_dir, len(mask_filenames_cfl)))

    masks = load_masks_cfl(mask_filenames_cfl)

    with tf.variable_scope(name):
        dataset = tf.data.TFRecordDataset(train_filenames_tfrecord)

        def _prep_tfrecord_with_param(example):
            return prep_tfrecord(example, masks, out_shape=out_shape,
                                 num_channels=num_channels, num_emaps=num_emaps,
                                 random_seed=random_seed, verbose=verbose)
        dataset = dataset.map(_prep_tfrecord_with_param)
        dataset = dataset.prefetch(batch_size * buffer_size)
        dataset = dataset.batch(batch_size)
        dataset = dataset.repeat(-1)

    return dataset, len(train_filenames_tfrecord)
--------------------------------------------------------------------------------
/mri_model.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import, division, print_function

import sys

import numpy as np
import tensorflow as tf
# from tensorflow.python.util import deprecation

import complex_utils
from mri_util import tf_util

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# deprecation._PRINT_DEPRECATION_WARNINGS = False
# tf.logging.set_verbosity(tf.logging.ERROR)


def _batch_norm(tf_input, data_format="channels_last", training=False):
    tf_output = tf.layers.batch_normalization(
        tf_input,
        axis=(1 if data_format == "channels_first" else -1),
        training=training,
        renorm=True,
        fused=True,
    )
    return tf_output


def _batch_norm_relu(tf_input, data_format="channels_last", training=False, activation="relu"):
    tf_output = _batch_norm(
        tf_input, data_format=data_format, training=training)
    input_shape = tf.shape(tf_output)

    # `activation == "relu" or "crelu"` is always truthy, which would skip the
    # complex-valued activations below; test membership instead.
    if activation in ("relu", "crelu"):
        # ReLU on the stacked real/imaginary channels (CReLU for complex data)
        tf_output = tf.nn.relu(tf_output)
    else:
        # convert two channels to complex-valued in preparation for complex-valued activation functions
        tf_output = tf_util.channels_to_complex(tf_output)

        if activation == "zrelu":
            tf_output = complex_utils.zrelu(tf_output)

        if activation == "modrelu":
            tf_output = complex_utils.modrelu(tf_output, data_format)

        if activation == "cardioid":
            tf_output = complex_utils.cardioid(tf_output)

        # convert complex back to two channels
        tf_output = tf_util.complex_to_channels(tf_output)

    return tf_output


def _circular_pad(tf_input, pad, axis):
    """Perform circular padding."""
    shape_input = tf.shape(tf_input)
    shape_0 = tf.cast(tf.reduce_prod(shape_input[:axis]), dtype=tf.int32)
    shape_axis = shape_input[axis]
    tf_output = tf.reshape(tf_input, tf.stack((shape_0, shape_axis, -1)))

    tf_pre = tf_output[:, shape_axis - pad:, :]
    tf_post = tf_output[:, :pad, :]
    tf_output = tf.concat((tf_pre, tf_output, tf_post), axis=1)

    shape_out = tf.concat(
        (shape_input[:axis], [shape_axis + 2 * pad], shape_input[axis + 1:]), axis=0
    )
    tf_output = tf.reshape(tf_output, shape_out)

    return tf_output


def _conv2d(
    tf_input,
    num_features=128,
    kernel_size=3,
    data_format="channels_last",
    circular=True,
    conjugate=False,
):
    """Conv2d with option for circular convolution."""
    if data_format == "channels_last":
        # (batch, z, y, channels)
        axis_z = 1
        axis_y = 2
        axis_c = 3
    else:
        # (batch, channels, z, y)
        axis_c = 1
        axis_z = 2
        axis_y = 3

    pad = int((kernel_size - 0.5) / 2)
    tf_output = tf_input

    if circular:
        with tf.name_scope("circular_pad"):
            tf_output = _circular_pad(tf_output, pad, axis_z)
            tf_output = _circular_pad(tf_output, pad, axis_y)

    # NOTE: `type_conv` is not defined in this file; it is read as a
    # module-level global and must be set before _conv2d is called.
    if type_conv == "real":
        print("real convolution")
        # cast back to int: floor division by np.sqrt(2) yields a float
        num_features = int(int(num_features) // np.sqrt(2))
        tf_output = tf.layers.conv2d(
            tf_output,
            num_features,
            kernel_size,
            padding="same",
            use_bias=False,
            data_format=data_format,
        )
    if type_conv == "complex":
        print("complex convolution")
        # channels to complex
        tf_output = tf_util.channels_to_complex(tf_output)

        if num_features != 2:
            num_features = num_features // 2

        tf_output = complex_utils.complex_conv(
            tf_output, num_features=num_features, kernel_size=kernel_size)

        if conjugate and num_features != 2:
            print("conjugation")
            # conjugate the output and stack it next to the original
            real_out = tf_util.getReal(tf_output, data_format)
            imag_out = tf_util.getImag(tf_output, data_format)
            imag_conj = -1 * imag_out

            real_out = tf.concat([real_out, real_out], axis=-1)
            imag_out = tf.concat([imag_out, imag_conj], axis=-1)

            # recombine into a complex tensor for complex_to_channels below
            tf_output = tf.complex(real_out, imag_out)

        # complex to channels
        tf_output = tf_util.complex_to_channels(tf_output)

    if circular:
        shape_input = tf.shape(tf_input)
        shape_z = shape_input[axis_z]
        shape_y = shape_input[axis_y]
        with tf.name_scope("circular_crop"):
            if data_format == "channels_last":
                tf_output = tf_output[
                    :, pad: (shape_z + pad), pad: (shape_y + pad), :
                ]
            else:
                tf_output = tf_output[
                    :, :, pad: (shape_z + pad), pad: (shape_y + pad)
                ]
        # add all needed attributes to tensor
    else:
        with tf.name_scope("non_circular"):
            tf_output = tf_output[:, :, :, :]

    return tf_output


def _res_block(
    net_input,
    num_features=32,
    kernel_size=3,
    data_format="channels_last",
    circular=True,
    training=True,
    name="res_block",
    activation="relu",
):
    """Create ResNet block.
    [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
        Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
    """
    if data_format == "channels_last":
        axis_z = 1
        axis_y = 2
        axis_c = 3
    else:
        axis_c = 1
        axis_z = 2
        axis_y = 3
    shape_c = net_input.shape[axis_c]
    pad = int((2 * (kernel_size - 1) + 0.5) / 2)

    with tf.name_scope(name):
        shortcut = net_input
        if num_features != shape_c:
            shortcut = _conv2d(
                shortcut,
                num_features=num_features,
                kernel_size=1,
                data_format=data_format,
                circular=circular,
            )

        net_cur = net_input

        if circular:
            with tf.name_scope("circular_pad"):
                net_cur = _circular_pad(net_cur, pad, axis_z)
                net_cur = _circular_pad(net_cur, pad, axis_y)

        net_cur = _batch_norm_relu(
            net_cur, data_format=data_format, training=training, activation=activation)
        net_cur = _conv2d(
            net_cur,
            num_features=num_features,
            kernel_size=kernel_size,
            data_format=data_format,
            circular=False,
        )
        net_cur = _batch_norm_relu(
            net_cur, data_format=data_format, training=training, activation=activation)
        net_cur = _conv2d(
            net_cur,
            num_features=num_features,
            kernel_size=kernel_size,
            data_format=data_format,
            circular=False,
        )

        if circular:
            shape_input = tf.shape(net_input)
            shape_z = shape_input[axis_z]
            shape_y = shape_input[axis_y]
            with tf.name_scope("circular_crop"):
                if data_format == "channels_last":
                    net_cur = net_cur[
                        :, pad: (pad + shape_z), pad: (pad + shape_y), :
                    ]
                else:
                    net_cur = net_cur[
                        :, :, pad: (pad + shape_z), pad: (pad + shape_y)
                    ]

        net_cur = net_cur + shortcut

    return net_cur


def prior_grad_res_net(
    curr_x,
    num_features=32,
    kernel_size=3,
    num_blocks=2,
    data_format="channels_last",
    do_residual=True,
    training=True,
    num_features_out=None,
    circular=True,
    name="prior_grad_resnet",
    activation="relu",
):
    """Create prior gradient."""
    if data_format == "channels_last":
        num_features_in = curr_x.shape[-1]
    else:
        num_features_in = curr_x.shape[1]
    if num_features_out is None:
        num_features_out = num_features_in

    with tf.name_scope(name):
        net = curr_x

        if do_residual:
            if num_features_in != num_features_out:
                shortcut = _conv2d(
                    net,
                    num_features=num_features_out,
                    kernel_size=1,
                    data_format=data_format,
                    circular=circular,
                )
            else:
                shortcut = net

        for _ in range(num_blocks):
            net = _res_block(
                net,
                training=training,
                num_features=num_features,
                kernel_size=kernel_size,
                data_format=data_format,
                circular=circular,
            )

        net = _batch_norm_relu(net, data_format=data_format,
                               training=training, activation=activation)
        net = _conv2d(
            net,
            num_features=num_features_out,
            kernel_size=kernel_size,
            data_format=data_format,
            circular=circular,
        )
        if do_residual:
            net = net + shortcut

    return net


def prior_grad_simple(
    net_input,
    num_features=128,
    num_features_out=None,
    kernel_size=3,
    num_blocks=5,
    data_format="channels_last",
    do_residual=True,
    circular=True,
    training=True,
    name="prior_grad_simple",
    activation="relu",
):
    """Create prior gradient.
316 | This is based on the original work proposed by Diamond et al. 317 | """ 318 | if data_format == "channels_last": 319 | axis_z = 1 320 | axis_y = 2 321 | axis_c = 3 322 | else: 323 | axis_c = 1 324 | axis_z = 2 325 | axis_y = 3 326 | num_features_in = net_input.shape[axis_c] 327 | if num_features_out is None: 328 | num_features_out = num_features_in 329 | # Number of total conv2d: 2 + num_blocks 330 | pad = int(((2 + num_blocks) * (kernel_size - 1) + 0.5) / 2) 331 | 332 | with tf.name_scope(name): 333 | net_cur = net_input 334 | 335 | if circular: 336 | with tf.name_scope("circular_pad"): 337 | net_cur = _circular_pad(net_cur, pad, axis_z) 338 | net_cur = _circular_pad(net_cur, pad, axis_y) 339 | 340 | # Expand to specified number of features 341 | net_cur = _conv2d( 342 | net_cur, 343 | num_features=num_features, 344 | kernel_size=kernel_size, 345 | data_format=data_format, 346 | circular=False, 347 | ) 348 | net_cur = _batch_norm_relu( 349 | net_cur, data_format=data_format, training=training, activation=activation) 350 | 351 | # Repeat conv2d, bn, relu 352 | for _ in range(num_blocks): 353 | net_cur = _conv2d( 354 | net_cur, 355 | num_features=num_features, 356 | kernel_size=kernel_size, 357 | data_format=data_format, 358 | circular=False, 359 | ) 360 | net_cur = _batch_norm_relu( 361 | net_cur, data_format=data_format, training=training, activation=activation 362 | ) 363 | 364 | net_cur = _conv2d( 365 | net_cur, 366 | num_features=num_features_out, 367 | kernel_size=kernel_size, 368 | data_format=data_format, 369 | circular=False, 370 | ) 371 | 372 | if circular: 373 | shape_input = tf.shape(net_input) 374 | shape_z = shape_input[axis_z] 375 | shape_y = shape_input[axis_y] 376 | 377 | with tf.name_scope("circular_crop"): 378 | if data_format == "channels_last": 379 | net_cur = net_cur[ 380 | :, pad: (pad + shape_z), pad: (pad + shape_y), : 381 | ] 382 | else: 383 | net_cur = net_cur[ 384 | :, :, pad: (pad + shape_z), pad: (pad + shape_y) 385 | ] 386 | 387 
| if do_residual: 388 | net_cur = net_cur + net_input 389 | 390 | return net_cur 391 | 392 | 393 | def unroll_fista( 394 | ks_input, 395 | sensemap, 396 | num_grad_steps=5, 397 | resblock_num_features=128, 398 | resblock_num_blocks=2, 399 | is_training=True, 400 | scope="MRI", 401 | mask_output=1, 402 | window=None, 403 | do_hardproj=True, 404 | num_summary_image=0, 405 | mask=None, 406 | verbose=False, 407 | conv="real", 408 | do_conjugate=False, 409 | activation="relu", 410 | ): 411 | """Create general unrolled network for MRI. 412 | x_{k+1} = S( x_k - 2 * t * A^T W (A x_k - b) ) 413 | = S( x_k - 2 * t * (A^T W A x_k - A^T W b)) 414 | """ 415 | if window is None: 416 | window = 1 417 | summary_iter = None 418 | 419 | global type_conv 420 | type_conv = conv 421 | global conjugate 422 | conjugate = do_conjugate 423 | 424 | if verbose: 425 | print( 426 | "%s> Building FISTA unrolled network (%d steps)..." 427 | % (scope, num_grad_steps) 428 | ) 429 | if sensemap is not None: 430 | print("%s> Using sensitivity maps..."
% scope) 431 | with tf.variable_scope(scope): 432 | if mask is None: 433 | mask = tf_util.kspace_mask(ks_input, dtype=tf.complex64) 434 | ks_input = mask * ks_input 435 | ks_0 = ks_input 436 | # x0 = A^T W b 437 | im_0 = tf_util.model_transpose(ks_0 * window, sensemap) 438 | im_0 = tf.identity(im_0, name="input_image") 439 | # To be updated 440 | ks_k = ks_0 441 | im_k = im_0 442 | 443 | for i_step in range(num_grad_steps): 444 | iter_name = "iter_%02d" % i_step 445 | with tf.variable_scope(iter_name): 446 | # = S( x_k - 2 * t * (A^T W A x_k - A^T W b)) 447 | # = S( x_k - 2 * t * (A^T W A x_k - x0)) 448 | with tf.variable_scope("update"): 449 | im_k_orig = im_k 450 | # xk = A^T W A x_k 451 | ks_k = tf_util.model_forward(im_k, sensemap) 452 | ks_k = mask * ks_k 453 | im_k = tf_util.model_transpose(ks_k * window, sensemap) 454 | # xk = A^T W A x_k - A^T W b 455 | im_k = tf_util.complex_to_channels(im_k - im_0) 456 | im_k_orig = tf_util.complex_to_channels(im_k_orig) 457 | # Update step 458 | t_update = tf.get_variable( 459 | "t", dtype=tf.float32, initializer=tf.constant([-2.0]) 460 | ) 461 | im_k = im_k_orig + t_update * im_k 462 | 463 | with tf.variable_scope("prox"): 464 | num_channels_out = im_k.shape[-1] 465 | im_k = prior_grad_res_net( 466 | im_k, 467 | training=is_training, 468 | num_features=resblock_num_features, 469 | num_blocks=resblock_num_blocks, 470 | num_features_out=num_channels_out, 471 | data_format="channels_last", 472 | activation=activation 473 | ) 474 | im_k = tf_util.channels_to_complex(im_k) 475 | 476 | im_k = tf.identity(im_k, name="image") 477 | if num_summary_image > 0: 478 | with tf.name_scope("summary"): 479 | tmp = tf_util.sumofsq(im_k, keep_dims=True) 480 | if summary_iter is None: 481 | summary_iter = tmp 482 | else: 483 | summary_iter = tf.concat( 484 | (summary_iter, tmp), axis=2) 485 | tf.summary.scalar("max/" + iter_name, 486 | tf.reduce_max(tmp)) 487 | 488 | ks_k = tf_util.model_forward(im_k, sensemap) 489 | 
if do_hardproj: 490 | if
verbose: 491 | print("%s> Final hard data projection..." % scope) 492 | # Final data projection 493 | ks_k = mask * ks_0 + (1 - mask) * ks_k 494 | if mask_output is not None: 495 | ks_k = ks_k * mask_output 496 | im_k = tf_util.model_transpose(ks_k * window, sensemap) 497 | 498 | ks_k = tf.identity(ks_k, name="output_kspace") 499 | im_k = tf.identity(im_k, name="output_image") 500 | 501 | if summary_iter is not None: 502 | tf.summary.image("iter/image", summary_iter, 503 | max_outputs=num_summary_image) 504 | 505 | return im_k 506 | -------------------------------------------------------------------------------- /mri_prep.py: -------------------------------------------------------------------------------- 1 | """Data preparation for training.""" 2 | import os 3 | import random 4 | import shutil 5 | import subprocess 6 | import zipfile 7 | 8 | import numpy as np 9 | import tensorflow as tf 10 | import wget 11 | 12 | from mri_util import cfl, fftc, tf_util 13 | 14 | tf.logging.set_verbosity(tf.logging.ERROR) 15 | 16 | BIN_BART = "bart" 17 | 18 | 19 | def download_dataset_knee(dir_out, dir_tmp="tmp", verbose=False, do_cleanup=True): 20 | """Download and unzip knee dataset from mridata.org.""" 21 | if not os.path.isdir(dir_out): 22 | os.makedirs(dir_out) 23 | if os.path.isdir(dir_tmp): 24 | print("WARNING! Temporary folder exists (%s)" % dir_tmp) 25 | else: 26 | os.makedirs(dir_tmp) 27 | 28 | num_data = 1 29 | for i in range(num_data): 30 | if verbose: 31 | print("Processing data (%d)..." % i) 32 | 33 | url = "http://old.mridata.org/knees/fully_sampled/p%d/e1/s1/P%d.zip" % ( 34 | i + 1, 35 | i + 1, 36 | ) 37 | dir_name_i = os.path.join(dir_out, "data%02d" % i) 38 | 39 | if verbose: 40 | print(" downloading from %s..." % url) 41 | if not os.path.isdir(dir_name_i): 42 | os.makedirs(dir_name_i) 43 | file_download = wget.download(url, out=dir_tmp) 44 | 45 | if verbose: 46 | print(" unzipping contents to %s..."
% dir_name_i) 47 | with zipfile.ZipFile(file_download, "r") as zip_ref: 48 | for member in zip_ref.namelist(): 49 | filename = os.path.basename(member) 50 | if not filename: 51 | continue 52 | file_src = zip_ref.open(member) 53 | file_dest = open(os.path.join(dir_name_i, filename), "wb") 54 | with file_src, file_dest: 55 | shutil.copyfileobj(file_src, file_dest) 56 | 57 | if do_cleanup: 58 | if verbose: 59 | print("Cleanup...") 60 | shutil.rmtree(dir_tmp) 61 | 62 | if verbose: 63 | print("Done") 64 | 65 | 66 | def create_masks( 67 | dir_out, 68 | shape_y=320, 69 | shape_z=256, 70 | verbose=False, 71 | acc_y=(1, 2, 3), 72 | acc_z=(1, 2, 3), 73 | shape_calib=1, 74 | variable_density=False, 75 | num_repeat=4, 76 | ): 77 | """Create sampling masks using BART.""" 78 | flags = "" 79 | file_fmt = "mask_%0.1fx%0.1f_c%d_%02d" 80 | if variable_density: 81 | flags = flags + " -v " 82 | file_fmt = file_fmt + "_vd" 83 | 84 | if not os.path.exists(dir_out): 85 | os.mkdir(dir_out) 86 | 87 | for a_y in acc_y: 88 | for a_z in acc_z: 89 | if a_y * a_z != 1: 90 | num_repeat_i = num_repeat 91 | if (a_y == acc_y[-1]) and (a_z == acc_z[-1]): 92 | num_repeat_i = num_repeat_i * 2 93 | for i in range(num_repeat_i): 94 | random_seed = 1e6 * random.random() 95 | file_name = file_fmt % (a_y, a_z, shape_calib, i) 96 | if verbose: 97 | print("creating mask (%s)..." 
% file_name) 98 | file_name = os.path.join(dir_out, file_name) 99 | cmd = "%s poisson -C %d -Y %d -Z %d -y %d -z %d -s %d %s %s" % ( 100 | BIN_BART, 101 | shape_calib, 102 | shape_y, 103 | shape_z, 104 | a_y, 105 | a_z, 106 | random_seed, 107 | flags, 108 | file_name, 109 | ) 110 | subprocess.check_output(["bash", "-c", cmd]) 111 | 112 | 113 | def _int64_feature(value): 114 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 115 | 116 | 117 | def _bytes_feature(value): 118 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 119 | 120 | 121 | def setup_data_tfrecords( 122 | dir_in_root, 123 | dir_out, 124 | data_divide=(0.75, 0.05, 0.2), 125 | min_shape=[80, 180], 126 | num_maps=1, 127 | crop_maps=False, 128 | verbose=False, 129 | ): 130 | """Set up training data as TFRecords. 131 | 132 | prep_data.setup_data('/mnt/raid3/data/Studies_DCE/recon-ccomp6/', 133 | '/mnt/raid3/jycheng/Project/deepspirit/data/train/', verbose=True) 134 | """ 135 | 136 | # Check for two echoes in here 137 | # Use glob to check whether echo01 exists 138 | 139 | if verbose: 140 | print("Directory names:") 141 | print(" Input root: %s" % dir_in_root) 142 | print(" Output root: %s" % dir_out) 143 | 144 | file_kspace = "kspace" 145 | file_sensemap = "sensemap" 146 | 147 | case_list = os.listdir(dir_in_root) 148 | random.shuffle(case_list) 149 | num_cases = len(case_list) 150 | 151 | i_train_1 = np.round(data_divide[0] * num_cases).astype(int) 152 | i_validate_0 = i_train_1 + 1 153 | i_validate_1 = np.round( 154 | data_divide[1] * num_cases).astype(int) + i_validate_0 155 | 156 | if not os.path.exists(dir_out): 157 | os.mkdir(dir_out) 158 | if not os.path.exists(os.path.join(dir_out, "train")): 159 | os.mkdir(os.path.join(dir_out, "train")) 160 | if not os.path.exists(os.path.join(dir_out, "validate")): 161 | os.mkdir(os.path.join(dir_out, "validate")) 162 | if not os.path.exists(os.path.join(dir_out, "test")): 163 | os.mkdir(os.path.join(dir_out, "test")) 
164 | 165
| i_case = 0 166 | for case_name in case_list: 167 | file_kspace_i = os.path.join(dir_in_root, case_name, file_kspace) 168 | file_sensemap_i = os.path.join(dir_in_root, case_name, file_sensemap) 169 | 170 | if i_case < i_train_1: 171 | dir_out_i = os.path.join(dir_out, "train") 172 | elif i_case < i_validate_1: 173 | dir_out_i = os.path.join(dir_out, "validate") 174 | else: 175 | dir_out_i = os.path.join(dir_out, "test") 176 | 177 | if verbose: 178 | print("Processing [%d] %s..." % (i_case, case_name)) 179 | i_case = i_case + 1 180 | 181 | kspace = np.squeeze(cfl.read(file_kspace_i)) 182 | if (min_shape is None) or ( 183 | min_shape[0] <= kspace.shape[1] and min_shape[1] <= kspace.shape[2] 184 | ): 185 | if verbose: 186 | print(" Slice shape: (%d, %d)" % 187 | (kspace.shape[1], kspace.shape[2])) 188 | print(" Num channels: %d" % kspace.shape[0]) 189 | shape_x = kspace.shape[-1] 190 | kspace = fftc.ifftc(kspace, axis=-1) 191 | kspace = kspace.astype(np.complex64) 192 | 193 | # if shape_c_out < shape_c: 194 | # if verbose: 195 | # print(" applying coil compression (%d -> %d)..." 
% 196 | # (shape_c, shape_c_out)) 197 | # shape_cal = 24 198 | # ks_cal = recon.crop(ks, [-1, shape_cal, shape_cal, -1]) 199 | # ks_cal = np.reshape(ks_cal, [shape_c, 200 | # shape_cal*shape_cal, 201 | # shape_x]) 202 | # cc_mat = coilcomp.calc_gcc_weights_c(ks_cal, shape_c_out) 203 | # ks_cc = np.reshape(ks, [shape_c, -1, shape_x]) 204 | # ks_cc = coilcomp.apply_gcc_weights_c(ks_cc, cc_mat) 205 | # ks = np.reshape(ks_cc, [shape_c_out, shape_z, shape_y, shape_x]) 206 | 207 | cmd_flags = "" 208 | if crop_maps: 209 | cmd_flags = cmd_flags + " -c 1e-9" 210 | cmd_flags = cmd_flags + (" -m %d" % num_maps) 211 | cmd = "%s ecalib %s %s %s" % ( 212 | BIN_BART, 213 | cmd_flags, 214 | file_kspace_i, 215 | file_sensemap_i, 216 | ) 217 | if verbose: 218 | print(" Estimating sensitivity maps (bart espirit)...") 219 | print(" %s" % cmd) 220 | subprocess.check_call(["bash", "-c", cmd]) 221 | sensemap = np.squeeze(cfl.read(file_sensemap_i)) 222 | sensemap = np.expand_dims(sensemap, axis=0) 223 | sensemap = sensemap.astype(np.complex64) 224 | 225 | if verbose: 226 | print(" Creating tfrecords (%d)..." 
% shape_x) 227 | for i_x in range(shape_x): 228 | file_out = os.path.join( 229 | dir_out_i, "%s_x%03d.tfrecords" % (case_name, i_x) 230 | ) 231 | kspace_x = kspace[:, :, :, i_x] 232 | sensemap_x = sensemap[:, :, :, :, i_x] 233 | 234 | example = tf.train.Example( 235 | features=tf.train.Features( 236 | feature={ 237 | "name": _bytes_feature(str.encode(case_name)), 238 | "xslice": _int64_feature(i_x), 239 | "ks_shape_x": _int64_feature(kspace.shape[3]), 240 | "ks_shape_y": _int64_feature(kspace.shape[2]), 241 | "ks_shape_z": _int64_feature(kspace.shape[1]), 242 | "ks_shape_c": _int64_feature(kspace.shape[0]), 243 | "map_shape_x": _int64_feature(sensemap.shape[4]), 244 | "map_shape_y": _int64_feature(sensemap.shape[3]), 245 | "map_shape_z": _int64_feature(sensemap.shape[2]), 246 | "map_shape_c": _int64_feature(sensemap.shape[1]), 247 | "map_shape_m": _int64_feature(sensemap.shape[0]), 248 | "ks": _bytes_feature(kspace_x.tostring()), 249 | "map": _bytes_feature(sensemap_x.tostring()), 250 | } 251 | ) 252 | ) 253 | 254 | tf_writer = tf.python_io.TFRecordWriter(file_out) 255 | tf_writer.write(example.SerializeToString()) 256 | tf_writer.close() 257 | 258 | 259 | def process_tfrecord(example, num_channels=None, num_emaps=None): 260 | """Process TFRecord to actual tensors.""" 261 | features = tf.parse_single_example( 262 | example, 263 | features={ 264 | "name": tf.FixedLenFeature([], tf.string), 265 | "xslice": tf.FixedLenFeature([], tf.int64), 266 | "ks_shape_x": tf.FixedLenFeature([], tf.int64), 267 | "ks_shape_y": tf.FixedLenFeature([], tf.int64), 268 | "ks_shape_z": tf.FixedLenFeature([], tf.int64), 269 | "ks_shape_c": tf.FixedLenFeature([], tf.int64), 270 | "map_shape_x": tf.FixedLenFeature([], tf.int64), 271 | "map_shape_y": tf.FixedLenFeature([], tf.int64), 272 | "map_shape_z": tf.FixedLenFeature([], tf.int64), 273 | "map_shape_c": tf.FixedLenFeature([], tf.int64), 274 | "map_shape_m": tf.FixedLenFeature([], tf.int64), 275 | "ks": tf.FixedLenFeature([], tf.string), 
276 | "map": tf.FixedLenFeature([], tf.string), 277 | }, 278 | ) 279 | 280 | name = features["name"] 281 | xslice = tf.cast(features["xslice"], dtype=tf.int32) 282 | # shape_x = tf.cast(features['shape_x'], dtype=tf.int32) 283 | ks_shape_y = tf.cast(features["ks_shape_y"], dtype=tf.int32) 284 | ks_shape_z = tf.cast(features["ks_shape_z"], dtype=tf.int32) 285 | if num_channels is None: 286 | ks_shape_c = tf.cast(features["ks_shape_c"], dtype=tf.int32) 287 | else: 288 | ks_shape_c = num_channels 289 | map_shape_y = tf.cast(features["map_shape_y"], dtype=tf.int32) 290 | map_shape_z = tf.cast(features["map_shape_z"], dtype=tf.int32) 291 | if num_channels is None: 292 | map_shape_c = tf.cast(features["map_shape_c"], dtype=tf.int32) 293 | else: 294 | map_shape_c = num_channels 295 | if num_emaps is None: 296 | map_shape_m = tf.cast(features["map_shape_m"], dtype=tf.int32) 297 | else: 298 | map_shape_m = num_emaps 299 | 300 | with tf.name_scope("kspace"): 301 | ks_record_bytes = tf.decode_raw(features["ks"], tf.float32) 302 | image_shape = [ks_shape_c, ks_shape_z, ks_shape_y] 303 | ks_x = tf.reshape(ks_record_bytes, image_shape + [2]) 304 | ks_x = tf_util.channels_to_complex(ks_x) 305 | ks_x = tf.reshape(ks_x, image_shape) 306 | 307 | with tf.name_scope("sensemap"): 308 | map_record_bytes = tf.decode_raw(features["map"], tf.float32) 309 | map_shape = [map_shape_m * map_shape_c, map_shape_z, map_shape_y] 310 | map_x = tf.reshape(map_record_bytes, map_shape + [2]) 311 | map_x = tf_util.channels_to_complex(map_x) 312 | map_x = tf.reshape(map_x, map_shape) 313 | 314 | return name, xslice, ks_x, map_x 315 | 316 | 317 | def read_tfrecord_with_sess(tf_sess, filename_tfrecord): 318 | """Read TFRecord for debugging.""" 319 | tf_reader = tf.TFRecordReader() 320 | filename_queue = tf.train.string_input_producer([filename_tfrecord]) 321 | _, serialized_example = tf_reader.read(filename_queue) 322 | name, xslice, ks_x, map_x = process_tfrecord(serialized_example) 323 | coord = 
tf.train.Coordinator() 324 | threads = tf.train.start_queue_runners(sess=tf_sess, coord=coord) 325 | name, xslice, ks_x, map_x = tf_sess.run([name, xslice, ks_x, map_x]) 326 | coord.request_stop() 327 | coord.join(threads) 328 | 329 | return {"name": name, "xslice": xslice, "ks": ks_x, "sensemap": map_x} 330 | 331 | 332 | def read_tfrecord(filename_tfrecord): 333 | """Read TFRecord for debugging.""" 334 | session_config = tf.ConfigProto() 335 | session_config.gpu_options.allow_growth = True 336 | tf_sess = tf.Session(config=session_config) 337 | data = read_tfrecord_with_sess(tf_sess, filename_tfrecord) 338 | tf_sess.close() 339 | return data 340 | -------------------------------------------------------------------------------- /mri_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MRSRL/complex-networks-release/b661bbe2557366a9c010911fb70e291ebcd4af99/mri_util/__init__.py -------------------------------------------------------------------------------- /mri_util/bartwrap.py: -------------------------------------------------------------------------------- 1 | """Wraps BART functions.""" 2 | import random 3 | import subprocess 4 | from timeit import default_timer as timer 5 | 6 | import numpy as np 7 | 8 | from mri_util import cfl, recon 9 | 10 | 11 | def bart_generate_mask( 12 | shape, 13 | acc, 14 | variable_density=True, 15 | shape_calib=10, 16 | verbose=False, 17 | tmp_file="mask.tmp", 18 | ): 19 | """Use bart poisson to generate masks.""" 20 | if verbose: 21 | print("Generating sampling mask...") 22 | random_seed = 1e6 * random.random() 23 | flags = "-Z %d -Y %d -z %g -y %g -C %d -s %d" % ( 24 | shape[0], 25 | shape[1], 26 | acc[0], 27 | acc[1], 28 | shape_calib, 29 | random_seed, 30 | ) 31 | if variable_density: 32 | flags = flags + " -v" 33 | cmd = "bart poisson %s %s" % (flags, tmp_file) 34 | if verbose: 35 | print(" %s" % cmd) 36 | subprocess.check_output(["bash", "-c", cmd]) 
37 | mask = np.abs(np.squeeze(cfl.read(tmp_file))) 38 | return mask 39 | 40 | 41 | def bart_espirit( 42 | ks_input, 43 | shape=None, 44 | verbose=False, 45 | shape_e=2, 46 | crop_value=None, 47 | cal_size=None, 48 | smooth=False, 49 | filename_ks_tmp="ks.tmp", 50 | filename_map_tmp="map.tmp", 51 | ): 52 | """Estimate sensitivity maps using BART ESPIRiT. 53 | ks_input dimensions: [emaps, channels, kz, ky, kx] 54 | """ 55 | if verbose: 56 | print("Estimating sensitivity maps...") 57 | if shape is not None: 58 | ks_input = recon.crop(ks_input, [-1, -1, shape[0], shape[1], -1]) 59 | ks_input = recon.zeropad(ks_input, [-1, -1, shape[0], shape[1], -1]) 60 | 61 | flags = "" 62 | if crop_value is not None: 63 | flags = flags + "-c %f " % crop_value 64 | if cal_size is not None: 65 | flags = flags + "-r %d " % cal_size 66 | if smooth: 67 | flags = flags + "-S " 68 | 69 | cfl.write(filename_ks_tmp, ks_input) 70 | cmd = "bart ecalib -m %d %s %s %s" % ( 71 | shape_e, 72 | flags, 73 | filename_ks_tmp, 74 | filename_map_tmp, 75 | ) 76 | if verbose: 77 | print(" %s" % cmd) 78 | time_start = timer() 79 | subprocess.check_output(["bash", "-c", cmd]) 80 | time_end = timer() 81 | sensemap = cfl.read(filename_map_tmp) 82 | return sensemap, time_end - time_start 83 | 84 | 85 | def bart_pics( 86 | ks_input, 87 | verbose=False, 88 | sensemap=None, 89 | shape_e=2, 90 | do_cs=True, 91 | do_imag_reg=False, 92 | filename_ks_tmp="ks.tmp", 93 | filename_map_tmp="map.tmp", 94 | filename_im_tmp="im.tmp", 95 | filename_ks_out_tmp="ks_out.tmp", 96 | ): 97 | """BART PICS reconstruction.""" 98 | if verbose: 99 | print("PICS (l1-ESPIRiT) reconstruction...") 100 | 101 | cfl.write(filename_ks_tmp, ks_input) 102 | if sensemap is None: 103 | cmd = "bart ecalib -m %d -c 1e-9 %s %s" % ( 104 | shape_e, 105 | filename_ks_tmp, 106 | filename_map_tmp, 107 | ) 108 | if verbose: 109 | print(" %s" % cmd) 110 | subprocess.check_output(["bash", "-c", cmd]) 111 | else: 112 | cfl.write(filename_map_tmp, sensemap) 113 
| if do_cs: 114 | flags = "-l1 -r 1e-2" 115 | else: 116 | flags = "-l2 -r 1e-2" 117 | if do_imag_reg: 118 | flags = flags + " -R R1:7:1e-1" 119 | 120 | cmd = "bart pics %s -S %s %s %s" % ( 121 | flags, 122 | filename_ks_tmp, 123 | filename_map_tmp, 124 | filename_im_tmp, 125 | ) 126 | if verbose: 127 | print(" %s" % cmd) 128 | subprocess.check_output(["bash", "-c", cmd]) 129 | 130 | cmd = "bart fakeksp -r %s %s %s %s" % ( 131 | filename_im_tmp, 132 | filename_ks_tmp, 133 | filename_map_tmp, 134 | filename_ks_out_tmp, 135 | ) 136 | if verbose: 137 | print(" %s" % cmd) 138 | subprocess.check_output(["bash", "-c", cmd]) 139 | ks_pics = np.squeeze(cfl.read(filename_ks_out_tmp)) 140 | ks_pics = np.expand_dims(ks_pics, axis=0) 141 | 142 | return ks_pics 143 | -------------------------------------------------------------------------------- /mri_util/cfl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2013-2015. The Regents of the University of California. 2 | # All rights reserved. Use of this source code is governed by 3 | # a BSD-style license which can be found in the LICENSE file. 
4 | # 5 | # Authors: 6 | # 2013 Martin Uecker 7 | # 2015 Jonathan Tamir 8 | 9 | import numpy as np 10 | 11 | 12 | def read_hdr(name, order="C"): 13 | # get dims from .hdr 14 | h = open(name + ".hdr", "r") 15 | h.readline() # skip 16 | l = h.readline() 17 | h.close() 18 | dims = [int(i) for i in l.split()] 19 | if order == "C": 20 | dims.reverse() 21 | return dims 22 | 23 | 24 | def read(name, order="C"): 25 | 26 | dims = read_hdr(name, order) 27 | 28 | # remove singleton dimensions from the end 29 | n = np.prod(dims) 30 | dims_prod = np.cumprod(dims) 31 | dims = dims[: np.searchsorted(dims_prod, n) + 1] 32 | 33 | # load data and reshape into dims 34 | d = open(name + ".cfl", "rb") 35 | a = np.fromfile(d, dtype=np.complex64, count=n) 36 | d.close() 37 | return a.reshape(dims, order=order) # column-major for order="F"; order="C" uses the reversed dims 38 | 39 | 40 | def readcfl(name): 41 | return read(name, order="F") 42 | 43 | 44 | def write(name, array, order="C"): 45 | h = open(name + ".hdr", "w") 46 | h.write("# Dimensions\n") 47 | if order == "C": 48 | for i in array.shape[::-1]: 49 | h.write("%d " % i) 50 | else: 51 | for i in array.shape: 52 | h.write("%d " % i) 53 | h.write("\n") 54 | h.close() 55 | 56 | d = open(name + ".cfl", "wb") 57 | if order == "C": 58 | array.astype(np.complex64).tofile(d) 59 | else: 60 | # transpose for column-major order 61 | array.T.astype(np.complex64).tofile(d) 62 | d.close() 63 | 64 | 65 | def writecfl(name, array): 66 | write(name, array, order="F") 67 | -------------------------------------------------------------------------------- /mri_util/coilcomp.py: -------------------------------------------------------------------------------- 1 | """Coil compression. 2 | 3 | Reference(s): 4 | [1] Zhang T, Pauly JM, Vasanawala SS, Lustig M. Coil compression for 5 | accelerated imaging with Cartesian sampling.
Magn Reson Med 2013 6 | Mar 9;69:571-582 7 | """ 8 | import os 9 | import sys 10 | 11 | import numpy as np 12 | 13 | from mri_util import fftc 14 | 15 | 16 | def calc_gcc_weights_c(ks_calib, num_virtual_channels, correction=True): 17 | """Calculate coil compression weights. 18 | 19 | Input 20 | ks_calib -- raw k-space data of dimensions 21 | (num_channels, num_readout, num_kx) 22 | num_virtual_channels -- number of virtual channels to compress to 23 | correction -- apply rotation correction (default: True) 24 | Output 25 | cc_mat -- coil compression matrix (use apply_gcc_weights_c) 26 | """ 27 | me = "coilcomp.calc_gcc_weights_c" 28 | 29 | num_kx = ks_calib.shape[2] 30 | # num_readout = ks_calib.shape[1] 31 | num_channels = ks_calib.shape[0] 32 | 33 | if num_virtual_channels > num_channels: 34 | print( 35 | ( 36 | "%s> Num of virtual channels (%d) is more than the actual" 37 | + " channels (%d)!" 38 | ) 39 | % (me, num_virtual_channels, num_channels) 40 | ) 41 | return np.eye(num_channels, dtype=np.complex64) 42 | 43 | if num_kx > 1: 44 | # find max along kx (energy summed over channels and readout) 45 | tmp = np.sum(np.sum(np.power(np.abs(ks_calib), 2), axis=0), axis=0) 46 | i_xmax = np.argmax(tmp) 47 | # circ shift to move max to center (make copy to not touch original data) 48 | ks_calib_int = np.roll(ks_calib.copy(), int(num_kx / 2 - i_xmax), axis=-1) 49 | ks_calib_int = fftc.ifftc(ks_calib_int, axis=-1) 50 | else: 51 | ks_calib_int = ks_calib.copy() 52 | 53 | cc_mat = np.zeros((num_virtual_channels, num_channels, num_kx), dtype=np.complex64) 54 | for i_x in range(num_kx): 55 | ks_calib_x = np.squeeze(ks_calib_int[:, :, i_x]) 56 | U, s, Vh = np.linalg.svd(ks_calib_x.T, full_matrices=False) 57 | V = Vh.conj() 58 | cc_mat[:, :, i_x] = V[0:num_virtual_channels, :] 59 | 60 | if correction: 61 | for i_x in range(int(num_kx / 2) - 2, -1, -1): 62 | V1 = cc_mat[:, :, i_x + 1] 63 | V2 = cc_mat[:, :, i_x] 64 | A = np.matmul(V1.conj(), V2.T) 65 | Ua, sa, Vah = 
np.linalg.svd(A, full_matrices=False) 66 | P =
np.matmul(Ua, Vah) 67 | P = P.conj() 68 | cc_mat[:, :, i_x] = np.matmul(P, cc_mat[:, :, i_x]) 69 | 70 | for i_x in range(int(num_kx / 2) - 1, num_kx, 1): 71 | V1 = cc_mat[:, :, i_x - 1] 72 | V2 = cc_mat[:, :, i_x] 73 | A = np.matmul(V1.conj(), V2.T) 74 | Ua, sa, Vah = np.linalg.svd(A, full_matrices=False) 75 | P = np.matmul(Ua, Vah) 76 | P = P.conj() 77 | cc_mat[:, :, i_x] = np.matmul(P, np.squeeze(cc_mat[:, :, i_x])) 78 | 79 | return cc_mat 80 | 81 | 82 | def apply_gcc_weights_c(ks, cc_mat): 83 | """Apply coil compression weights. 84 | 85 | Input 86 | ks -- raw k-space data of dimensions (num_channels, num_readout, num_kx) 87 | cc_mat -- coil compression matrix calculated using calc_gcc_weights_c 88 | Output 89 | ks_out -- coil compressed data 90 | """ 91 | me = "coilcomp.apply_gcc_weights_c" 92 | 93 | num_channels = ks.shape[0] 94 | num_readout = ks.shape[1] 95 | num_kx = ks.shape[2] 96 | num_virtual_channels = cc_mat.shape[0] 97 | 98 | if num_channels != cc_mat.shape[1]: 99 | print("%s> ERROR! num channels does not match!" % me) 100 | print("%s> ks: num channels = %d" % (me, num_channels)) 101 | print("%s> cc_mat: num channels = %d" % (me, cc_mat.shape[1])) 102 | 103 | ks_x = fftc.ifftc(ks, axis=-1) 104 | ks_out = np.zeros((num_virtual_channels, num_readout, num_kx), dtype=np.complex64) 105 | for i_channel in range(num_virtual_channels): 106 | cc_mat_i = np.reshape(cc_mat[i_channel, :, :], (num_channels, 1, num_kx)) 107 | ks_out[i_channel, :, :] = np.sum(ks_x * cc_mat_i, axis=0) 108 | ks_out = fftc.fftc(ks_out, axis=-1) 109 | 110 | return ks_out 111 | 112 | 113 | def calc_gcc_weights(ks_calib, num_virtual_channels, correction=True): 114 | """Calculate coil compression weights.
115 | 116 | Input 117 | ks_calib -- raw k-space data of dimensions (num_kx, num_readout, num_channels) 118 | num_virtual_channels -- number of virtual channels to compress to 119 | correction -- apply rotation correction (default: True) 120 | Output 121 | cc_mat -- coil compression matrix (use apply_gcc_weights) 122 | """ 123 | 124 | me = "coilcomp.calc_gcc_weights" 125 | 126 | num_kx = ks_calib.shape[0] 127 | # num_readout = ks_calib.shape[1] 128 | num_channels = ks_calib.shape[2] 129 | 130 | if num_virtual_channels > num_channels: 131 | print( 132 | "%s> Num of virtual channels (%d) is more than the actual channels (%d)!" 133 | % (me, num_virtual_channels, num_channels) 134 | ) 135 | return np.eye(num_channels, dtype=complex) 136 | 137 | # find max in readout 138 | tmp = np.sum(np.sum(np.power(np.abs(ks_calib), 2), axis=2), axis=1) 139 | i_xmax = np.argmax(tmp) 140 | # circ shift to move max to center (make copy to not touch original data) 141 | ks_calib_int = np.roll(ks_calib.copy(), int(num_kx / 2 - i_xmax), axis=0) 142 | ks_calib_int = fftc.ifftc(ks_calib_int, axis=0) 143 | 144 | cc_mat = np.zeros((num_kx, num_channels, num_virtual_channels), dtype=complex) 145 | for i_x in range(num_kx): 146 | ks_calib_x = np.squeeze(ks_calib_int[i_x, :, :]) 147 | U, s, Vh = np.linalg.svd(ks_calib_x, full_matrices=False) 148 | V = Vh.conj().T 149 | cc_mat[i_x, :, :] = V[:, 0:num_virtual_channels] 150 | 151 | if correction: 152 | for i_x in range(int(num_kx / 2) - 2, -1, -1): 153 | V1 = cc_mat[i_x + 1, :, :] 154 | V2 = cc_mat[i_x, :, :] 155 | A = np.matmul(V1.conj().T, V2) 156 | Ua, sa, Vah = np.linalg.svd(A, full_matrices=False) 157 | P = np.matmul(Ua, Vah) 158 | P = P.conj().T 159 | cc_mat[i_x, :, :] = np.matmul(cc_mat[i_x, :, :], P) 160 | 161 | for i_x in range(int(num_kx / 2) - 1, num_kx, 1): 162 | V1 = cc_mat[i_x - 1, :, :] 163 | V2 = cc_mat[i_x, :, :] 164 | A = np.matmul(V1.conj().T, V2) 165 | Ua, sa, Vah = np.linalg.svd(A, full_matrices=False) 166 | P = np.matmul(Ua, 
Vah) 167 | P = P.conj().T 168 | cc_mat[i_x, :, :] = np.matmul(np.squeeze(cc_mat[i_x, :, :]), P) 169 | 170 | return cc_mat 171 | 172 | 173 | def apply_gcc_weights(ks, cc_mat): 174 | """Apply coil compression weights. 175 | Input 176 | ks -- raw k-space data of dimensions (num_kx, num_readout, num_channels) 177 | cc_mat -- coil compression matrix calculated using calc_gcc_weights 178 | Output 179 | ks_out -- coil compressed data 180 | """ 181 | 182 | me = "coilcomp.apply_gcc_weights" 183 | 184 | if ks.shape[2] != cc_mat.shape[1]: 185 | print("%s> ERROR! num channels does not match!" % me) 186 | print("%s> ks: num channels = %d" % (me, ks.shape[2])) 187 | print("%s> cc_mat: num channels = %d" % (me, cc_mat.shape[1])) 188 | 189 | num_kx = ks.shape[0] 190 | num_readout = ks.shape[1] 191 | num_channels = ks.shape[2] 192 | num_virtual_channels = cc_mat.shape[2] 193 | 194 | ks_x = fftc.ifftc(ks, axis=0) 195 | ks_out = np.zeros((num_kx, num_readout, num_virtual_channels), dtype=complex) 196 | for i_channel in range(num_virtual_channels): 197 | cc_mat_i = np.reshape(cc_mat[:, :, i_channel], (num_kx, 1, num_channels)) 198 | ks_out[:, :, i_channel] = np.sum(ks_x * cc_mat_i, axis=2) 199 | ks_out = fftc.fftc(ks_out, axis=0) 200 | 201 | return ks_out 202 | -------------------------------------------------------------------------------- /mri_util/cs_metrics.py: -------------------------------------------------------------------------------- 1 | """Wraps BART functions.""" 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import os 5 | import subprocess 6 | 7 | import numpy as np 8 | 9 | from mri_util import cfl 10 | from mri_util import recon 11 | 12 | BIN_BART = "bart" 13 | 14 | 15 | def bart_generate_mask( 16 | shape, 17 | acc, 18 | variable_density=True, 19 | shape_calib=10, 20 | verbose=False, 21 | tmp_file="mask.tmp", 22 | ): 23 | """Use bart poisson to generate masks.""" 24 | if verbose: 25 | print("Generating sampling mask...") 26
| flags = "-Z %d -Y %d -z %g -y %g" % (shape[0], shape[1], acc[0], acc[1]) 27 | if shape_calib > 0: 28 | flags = flags + (" -C %d" % shape_calib) 29 | if variable_density: 30 | flags = flags + " -v" 31 | cmd = "%s poisson %s %s" % (BIN_BART, flags, tmp_file) 32 | if verbose: 33 | print(" %s" % cmd) 34 | subprocess.check_output(["bash", "-c", cmd]) 35 | mask = np.abs(np.squeeze(cfl.read(tmp_file))) 36 | return mask 37 | 38 | 39 | def bart_espirit( 40 | ks_input, 41 | shape=None, 42 | verbose=False, 43 | filename_ks_tmp="ks.tmp", 44 | filename_map_tmp="map.tmp", 45 | ): 46 | """Estimate sensitivity maps using BART ESPIRiT. 47 | 48 | ks_input dimensions: [emaps, channels, kz, ky, kx] 49 | """ 50 | if verbose: 51 | print("Estimating sensitivity maps...") 52 | if shape is not None: 53 | ks_input = recon.crop(ks_input, [-1, -1, shape[0], shape[1], -1]) 54 | cfl.write(filename_ks_tmp, ks_input) 55 | cmd = "%s ecalib %s %s" % (BIN_BART, filename_ks_tmp, filename_map_tmp) 56 | if verbose: 57 | print(" %s" % cmd) 58 | subprocess.check_output(["bash", "-c", cmd]) 59 | sensemap = cfl.read(filename_map_tmp) 60 | return sensemap 61 | 62 | 63 | def bart_pics( 64 | ks_input, 65 | sensemap=None, 66 | verbose=False, 67 | do_l1=True, 68 | filename_ks_tmp="ks.tmp", 69 | filename_map_tmp="map.tmp", 70 | filename_im_tmp="im.tmp", 71 | filename_ks_out_tmp="ks_out.tmp", 72 | ): 73 | """BART PICS reconstruction.""" 74 | if verbose: 75 | print("PICS (l1-ESPIRiT) reconstruction...") 76 | 77 | cfl.write(filename_ks_tmp, ks_input) 78 | if sensemap is None: 79 | cmd = "%s ecalib -c 1e-9 %s %s" % (BIN_BART, filename_ks_tmp, filename_map_tmp) 80 | if verbose: 81 | print(" %s" % cmd) 82 | subprocess.check_output(["bash", "-c", cmd]) 83 | else: 84 | cfl.write(filename_map_tmp, sensemap) 85 | 86 | pics_flags = "" 87 | if do_l1: 88 | pics_flags = "-l1 -r 1e-1" 89 | else: 90 | pics_flags = "-l2 -r 1e-1" 91 | cmd = "%s pics %s -S %s %s %s" % ( 92 | BIN_BART, 93 | pics_flags, 94 | filename_ks_tmp, 95 | 
filename_map_tmp, 96 | filename_im_tmp, 97 | ) 98 | if verbose: 99 | print(" %s" % cmd) 100 | subprocess.check_output(["bash", "-c", cmd]) 101 | 102 | cmd = "%s fakeksp -r %s %s %s %s" % ( 103 | BIN_BART, 104 | filename_im_tmp, 105 | filename_ks_tmp, 106 | filename_map_tmp, 107 | filename_ks_out_tmp, 108 | ) 109 | if verbose: 110 | print(" %s" % cmd) 111 | subprocess.check_output(["bash", "-c", cmd]) 112 | ks_pics = np.squeeze(cfl.read(filename_ks_out_tmp)) 113 | ks_pics = np.expand_dims(ks_pics, axis=0) 114 | 115 | return ks_pics 116 | 117 | 118 | def recon_dataset(raw_input, raw_output, dir_tmp=".", tag=None): 119 | """Recon datasets.""" 120 | if tag is None: 121 | tag = "%05d" % np.random.randint(0, 1e4) 122 | # raw_input = np.real(raw_input) 123 | # raw_input = raw_input[:, :, :, ::2] + 1j * raw_input[:, :, :, 1::2] 124 | 125 | # raw_output = np.real(raw_output) 126 | # raw_output = raw_output[:, :, :, ::2] + 1j * raw_output[:, :, :, 1::2] 127 | 128 | file_ksin = os.path.join(dir_tmp, "ksin." + tag) 129 | file_imout = os.path.join(dir_tmp, "imout." + tag) 130 | file_kscalib = os.path.join(dir_tmp, "kscalib." + tag) 131 | file_map = os.path.join(dir_tmp, "map." 
+ tag)
132 |
133 |     im_out = np.zeros((raw_input.shape[0], 2) + raw_input.shape[1:3], dtype=np.complex128)
134 |     acc_list = np.zeros((raw_input.shape[0], 1))
135 |     print("ESPIRiT reconstruction...")
136 |     for i in range(raw_input.shape[0]):
137 |         raw_input_i = raw_input[i, :, :, :]
138 |         raw_output_i = raw_output[i, :, :, :]
139 |
140 |         acc = np.sum(raw_output_i != 0) / np.sum(raw_input_i != 0)
141 |         acc_list[i] = acc
142 |         print(" [%d] Acceleration = %g" % (i, acc))
143 |
144 |         raw_output_i = np.transpose(raw_output_i, (2, 0, 1))
145 |         raw_output_i = np.reshape(raw_output_i, (1,) + raw_output_i.shape + (1,))
146 |         cfl.write(file_kscalib, raw_output_i)
147 |         cmd = "%s ecalib -c 1e-9 %s %s" % (BIN_BART, file_kscalib, file_map)
148 |         subprocess.check_output(["bash", "-c", cmd])
149 |
150 |         raw_input_i = np.transpose(raw_input_i, (2, 0, 1))
151 |         raw_input_i = np.reshape(raw_input_i, (1,) + raw_input_i.shape + (1,))
152 |         cfl.write(file_ksin, raw_input_i)
153 |         cmd = "%s pics -l2 %s %s %s" % (BIN_BART, file_ksin, file_map, file_imout)
154 |         subprocess.check_output(["bash", "-c", cmd])
155 |         im_out[i, :, :, :] = np.squeeze(cfl.read(file_imout))
156 |
157 |     return im_out, acc_list
158 |
--------------------------------------------------------------------------------
/mri_util/fftc.py:
--------------------------------------------------------------------------------
1 | try:
2 |     import pyfftw.interfaces.numpy_fft as fft
3 | except ImportError:
4 |     from numpy import fft
5 | import numpy as np
6 |
7 |
8 | def ifftnc(x, axes):
9 |     tmp = fft.ifftshift(x, axes=axes)
10 |     tmp = fft.ifftn(tmp, axes=axes)
11 |     return fft.fftshift(tmp, axes=axes)
12 |
13 |
14 | def fftnc(x, axes):
15 |     tmp = fft.ifftshift(x, axes=axes)
16 |     tmp = fft.fftn(tmp, axes=axes)
17 |     return fft.fftshift(tmp, axes=axes)
18 |
19 |
20 | def fftc(x, axis=0, do_orthonorm=True):
21 |     if do_orthonorm:
22 |         scale = np.sqrt(x.shape[axis])
23 |     else:
24 |         scale = 1.0
25 |     return fftnc(x, (axis,)) / scale
26 |
27 |
28 | def
ifftc(x, axis=0, do_orthonorm=True): 29 | if do_orthonorm: 30 | scale = np.sqrt(x.shape[axis]) 31 | else: 32 | scale = 1.0 33 | return ifftnc(x, (axis,)) * scale 34 | 35 | 36 | def fft2c(x, order="C", do_orthonorm=True): 37 | if order == "C": 38 | if do_orthonorm: 39 | scale = np.sqrt(np.prod(x.shape[-2:])) 40 | else: 41 | scale = 1.0 42 | return fftnc(x, (-2, -1)) / scale 43 | else: 44 | if do_orthonorm: 45 | scale = np.sqrt(np.prod(x.shape[:2])) 46 | else: 47 | scale = 1.0 48 | return fftnc(x, (0, 1)) / scale 49 | 50 | 51 | def ifft2c(x, order="C", do_orthonorm=True): 52 | if order == "C": 53 | if do_orthonorm: 54 | scale = np.sqrt(np.prod(x.shape[-2:])) 55 | else: 56 | scale = 1.0 57 | return ifftnc(x, (-2, -1)) * scale 58 | else: 59 | if do_orthonorm: 60 | scale = np.sqrt(np.prod(x.shape[:2])) 61 | else: 62 | scale = 1.0 63 | return ifftnc(x, (0, 1)) * scale 64 | 65 | 66 | def fft3c(x, order="C"): 67 | if order == "C": 68 | return fftnc(x, (-3, -2, -1)) 69 | else: 70 | return fftnc(x, (0, 1, 2)) 71 | 72 | 73 | def ifft3c(x, order="C"): 74 | if order == "C": 75 | return ifftnc(x, (-3, -2, -1)) 76 | else: 77 | return ifftnc(x, (0, 1, 2)) 78 | -------------------------------------------------------------------------------- /mri_util/mask.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on Thu Aug 16 11:03:49 2018 3 | 4 | @author: penglai 5 | 6 | routine to generate 2D kt sampling with polynomial (via vdDegree) sampling density 7 | and golden ratio ky-shifting along time 8 | """ 9 | 10 | from math import ceil, floor 11 | 12 | import numpy as np 13 | 14 | 15 | def goldenratio_shift(accel, nt): 16 | GOLDEN_RATIO = 0.618034 17 | 18 | return np.round(np.arange(0, nt) * GOLDEN_RATIO * accel) % accel 19 | 20 | 21 | def generate_perturbed2dvdkt( 22 | ny, 23 | nt, 24 | accel, 25 | nCal, 26 | vdDegree, 27 | partialFourierFactor=0.0, 28 | vdFactor=None, 29 | perturbFactor=0.0, 30 | adhereFactor=0.0, 31 | ): 32 | 33 | 
vdDegree = max(vdDegree, 0.0)
34 |     perturbFactor = min(max(perturbFactor, 0.0), 1.0)
35 |     adhereFactor = min(max(adhereFactor, 0.0), 1.0)
36 |     nCal = max(nCal, 0)
37 |
38 |     if vdFactor is None or vdFactor > accel:
39 |         vdFactor = accel
40 |
41 |     yCent = floor(ny / 2.0)
42 |     yRadius = (ny - 1) / 2.0
43 |
44 |     if vdDegree > 0:
45 |         vdFactor = vdFactor ** (1.0 / vdDegree)
46 |
47 |     accel_aCoef = (vdFactor - 1.0) / vdFactor
48 |     accel_bCoef = 1.0 / vdFactor
49 |
50 |     ktMask = np.zeros([ny, nt], np.complex64)
51 |     ktShift = goldenratio_shift(accel, nt)
52 |
53 |     for t in range(0, nt):
54 |         # initial sampling with uniform density kt
55 |         ySamp = np.arange(ktShift[t], ny, accel)
56 |
57 |         # add random perturbation with certain adherence
58 |         if perturbFactor > 0:
59 |             for n in range(0, ySamp.size):
60 |                 if (
61 |                     ySamp[n] < perturbFactor * accel
62 |                     or ySamp[n] >= ny - perturbFactor * accel
63 |                 ):
64 |                     continue
65 |
66 |                 yPerturb = perturbFactor * accel * (np.random.rand() - 0.5)
67 |
68 |                 ySamp[n] += yPerturb
69 |
70 |                 if n > 0:
71 |                     ySamp[n - 1] += adhereFactor * yPerturb
72 |
73 |                 if n < ySamp.size - 1:
74 |                     ySamp[n + 1] += adhereFactor * yPerturb
75 |
76 |         ySamp = np.clip(ySamp, 0, ny - 1)
77 |
78 |         ySamp = (ySamp - yRadius) / yRadius
79 |
80 |         ySamp = ySamp * (accel_aCoef * np.abs(ySamp) + accel_bCoef) ** vdDegree
81 |
82 |         ind = np.argsort(np.abs(ySamp))
83 |         ySamp = ySamp[ind]
84 |
85 |         yUppHalf = np.where(ySamp >= 0)[0]
86 |         yLowHalf = np.where(ySamp < 0)[0]
87 |
88 |         # fit upper half k-space to Cartesian grid
89 |         yAdjFactor = 1.0
90 |         yEdge = floor(ySamp[yUppHalf[0]] * yRadius + yRadius + 0.0001)
91 |         yOffset = 0.0
92 |
93 |         for n in range(0, yUppHalf.size):
94 |             # add a very small float 0.0001 to be tolerant to numerical error with floor()
95 |             yLoc = min(
96 |                 floor(
97 |                     (yOffset + (ySamp[yUppHalf[n]] - yOffset) * yAdjFactor) * yRadius
98 |                     + yRadius
99 |                     + 0.0001
100 |                 ),
101 |                 ny - 1,
102 |             )
103 |
104 |             if ktMask[yLoc, t] == 0:
105 |                 ktMask[yLoc, t] = 1
106 |
107 |                 yEdge = yLoc + 1
108 |
109 |             else:
110 |                 ktMask[yEdge, t] = 1
111 |                 yOffset = ySamp[yUppHalf[n]]
112 |                 yAdjFactor = (yRadius - float(yEdge - yRadius)) / (
113 |                     yRadius * (1 - abs(yOffset))
114 |                 )
115 |                 yEdge += 1
116 |
117 |         # fit lower half k-space to Cartesian grid
118 |         yAdjFactor = 1.0
119 |         yEdge = floor(ySamp[yLowHalf[0]] * yRadius + yRadius + 0.0001)
120 |         yOffset = 0.0
121 |
122 |         if ktMask[yEdge, t] == 1:
123 |             yEdge -= 1
124 |             yOffset = ySamp[yLowHalf[0]]
125 |             yAdjFactor = (yRadius + float(yEdge - yRadius)) / (
126 |                 yRadius * (1.0 - abs(yOffset))
127 |             )
128 |
129 |         for n in range(0, yLowHalf.size):
130 |             yLoc = max(
131 |                 floor(
132 |                     (yOffset + (ySamp[yLowHalf[n]] - yOffset) * yAdjFactor) * yRadius
133 |                     + yRadius
134 |                     + 0.0001
135 |                 ),
136 |                 0,
137 |             )
138 |
139 |             if ktMask[yLoc, t] == 0:
140 |                 ktMask[yLoc, t] = 1
141 |
142 |                 yEdge = yLoc + 1
143 |
144 |             else:
145 |                 ktMask[yEdge, t] = 1
146 |                 yOffset = ySamp[yLowHalf[n]]
147 |                 yAdjFactor = (yRadius - float(yEdge - yRadius)) / (
148 |                     yRadius * (1 - abs(yOffset))
149 |                 )
150 |                 yEdge -= 1
151 |
152 |     # at last, add calibration data
153 |     ktMask[(yCent - ceil(nCal / 2)) : (yCent + nCal - 1 - ceil(nCal / 2)), :] = 1
154 |
155 |     # CMS: simulate partial Fourier scheme with alternating ky lines
156 |     if partialFourierFactor > 0.0:
157 |         nyMask = int(ny * partialFourierFactor)
158 |         # print(nyMask)
159 |         # print(ny-nyMask)
160 |         ktMask[(ny - nyMask) : ny, 0::2] = 0
161 |         ktMask[0:nyMask, 1::2] = 0
162 |
163 |     return ktMask
164 |
--------------------------------------------------------------------------------
/mri_util/metrics.py:
--------------------------------------------------------------------------------
1 | """Metrics for testing."""
2 | import numpy as np
3 | import skimage.measure
4 |
5 | from mri_util import recon
6 |
7 |
8 | def compute_psnr(ref, x):
9 |     """Compute peak signal-to-noise ratio (PSNR) in dB."""
10 |     max_val = np.max(np.abs(ref))
11 |     mse = np.mean(np.square(np.abs(x - ref)))
12 |
psnr = 10 * np.log10(np.square(max_val) / mse)
13 |     return psnr
14 |
15 |
16 | def compute_nrmse(ref, x):
17 |     """Compute normalized root mean square error.
18 |     The norm of reference is used to normalize the metric.
19 |     """
20 |     rmse = np.sqrt(np.mean(np.square(np.abs(ref - x))))
21 |     norm = np.sqrt(np.mean(np.square(np.abs(ref))))
22 |
23 |     return rmse / norm
24 |
25 |
26 | def compute_ssim(ref, x, sos_axis=None):
27 |     """Compute structural similarity index metric.
28 |     The image is first converted to magnitude image and normalized
29 |     before the metric is computed.
30 |     """
31 |     ref = ref.copy()
32 |     x = x.copy()
33 |     if sos_axis is not None:
34 |         x = recon.sumofsq(x, axis=sos_axis)
35 |         ref = recon.sumofsq(ref, axis=sos_axis)
36 |     x = np.abs(np.squeeze(x))
37 |     ref = np.abs(np.squeeze(ref))
38 |     x /= np.mean(np.square(np.abs(x)))
39 |     ref /= np.mean(np.square(np.abs(ref)))
40 |     return skimage.measure.compare_ssim(ref, x, data_range=x.max() - x.min())
41 |
42 |
43 | def compute_all(ref, x, sos_axis=None):
44 |     psnr = compute_psnr(ref, x)
45 |     nrmse = compute_nrmse(ref, x)
46 |     ssim = compute_ssim(ref, x, sos_axis=sos_axis)
47 |
48 |     return psnr, nrmse, ssim
49 |
--------------------------------------------------------------------------------
/mri_util/recon.py:
--------------------------------------------------------------------------------
1 | """Basic MRI reconstruction functions."""
2 | import numpy as np
3 |
4 |
5 | def sumofsq(im, axis=0):
6 |     """Compute square root of sum of squares.
7 |
8 |     :param im: raw image
9 |     """
10 |     if axis < 0:
11 |         axis = im.ndim - 1
12 |     if axis >= im.ndim:
13 |         print("ERROR!
Dimension %d invalid for given matrix" % axis) 14 | return -1 15 | 16 | out = np.sqrt(np.sum(im.real * im.real + im.imag * im.imag, axis=axis)) 17 | 18 | return out 19 | 20 | 21 | def phasecontrast(im, ref, axis=-1, coilaxis=-1): 22 | """Compute phase contrast.""" 23 | if axis < 0: 24 | axis = im.ndim - 1 25 | 26 | out = np.conj(ref) * im 27 | if coilaxis >= 0: 28 | out = np.sum(out, axis=coilaxis) 29 | out = np.angle(out) 30 | 31 | return out 32 | 33 | 34 | def fftmod(im, axis=-1): 35 | """Apply 1 -1 modulation along dimension specified by axis""" 36 | if axis < 0: 37 | axis = im.ndim - 1 38 | 39 | # generate modulation kernel 40 | dims = im.shape 41 | mod = np.ones( 42 | np.append(dims[axis], np.ones(len(dims) - 1, dtype=int)), dtype=im.dtype 43 | ) 44 | mod[1 : dims[axis] : 2] = -1 45 | mod = np.transpose(mod, np.append(np.arange(1, len(dims)), 0)) 46 | 47 | # apply kernel 48 | tpdims = np.concatenate( 49 | (np.arange(0, axis), np.arange(axis + 1, len(dims)), [axis]) 50 | ) 51 | out = np.transpose(im, tpdims) # transpose for broadcasting 52 | out = out * mod 53 | tpdims = np.concatenate( 54 | (np.arange(0, axis), [len(dims) - 1], np.arange(axis, len(dims) - 1)) 55 | ) 56 | out = np.transpose(out, tpdims) # transpose back to original dims 57 | 58 | return out 59 | 60 | 61 | def crop_in_dim(im, shape, dim): 62 | """Centered crop of image.""" 63 | if dim < 0 or dim >= im.ndim: 64 | print("ERROR! Invalid dimension specified!") 65 | return im 66 | if shape > im.shape[dim]: 67 | print("ERROR! 
Invalid shape specified!") 68 | return im 69 | 70 | im_shape = im.shape 71 | tmp_shape = [ 72 | int(np.prod(im_shape[:dim])), 73 | im_shape[dim], 74 | int(np.prod(im_shape[(dim + 1) :])), 75 | ] 76 | im_out = np.reshape(im, tmp_shape) 77 | ind0 = (im_shape[dim] - shape) // 2 78 | ind1 = ind0 + shape 79 | im_out = im_out[:, ind0:ind1, :].copy() 80 | im_out = np.reshape(im_out, im_shape[:dim] + (shape,) + im_shape[(dim + 1) :]) 81 | return im_out 82 | 83 | 84 | def crop(im, out_shape, verbose=False): 85 | """Centered crop.""" 86 | if im.ndim != np.size(out_shape): 87 | print("ERROR! Num dim of input image not same as desired shape") 88 | print(" %d != %d" % (im.ndim, np.size(out_shape))) 89 | return [] 90 | 91 | im_out = im 92 | for i in range(np.size(out_shape)): 93 | if out_shape[i] > 0: 94 | if verbose: 95 | print("Crop [%d]: %d to %d" % (i, im_out.shape[i], out_shape[i])) 96 | im_out = crop_in_dim(im_out, out_shape[i], i) 97 | 98 | return im_out 99 | 100 | 101 | def zeropad(im, out_shape): 102 | """Zeropad image.""" 103 | if im.ndim != np.size(out_shape): 104 | print("ERROR! 
Num dim of input image not same as desired shape") 105 | print(" %d != %d" % (im.ndim, np.size(out_shape))) 106 | return im 107 | 108 | pad_shape = [] 109 | for i in range(np.size(out_shape)): 110 | if out_shape[i] == -1: 111 | pad_shape_i = [0, 0] 112 | else: 113 | pad_start = int((out_shape[i] - im.shape[i]) / 2) 114 | pad_end = out_shape[i] - im.shape[i] - pad_start 115 | pad_shape_i = [pad_start, pad_end] 116 | 117 | pad_shape = pad_shape + [pad_shape_i] 118 | 119 | im_out = np.pad(im, pad_shape, "constant") 120 | 121 | return im_out 122 | -------------------------------------------------------------------------------- /mri_util/tf_util.py: -------------------------------------------------------------------------------- 1 | """Common functions for setup.""" 2 | import numpy as np 3 | import scipy.signal 4 | import tensorflow as tf 5 | 6 | 7 | def compute_psnr(predictions, ground_truths, maxpsnr=100): 8 | """Compute PSNR.""" 9 | ndims = len(predictions.get_shape().as_list()) 10 | mse = tf.reduce_mean( 11 | tf.square(tf.abs(predictions - ground_truths)), axis=list(range(1, ndims)) 12 | ) 13 | maxvals = tf.reduce_max(tf.abs(ground_truths), axis=list(range(1, ndims))) 14 | psnrs = ( 15 | 20 * tf.log(maxvals / tf.sqrt(mse)) / 16 | tf.log(tf.constant(10, dtype=mse.dtype)) 17 | ) 18 | # Handle case where mse = 0. 
19 | psnrs = tf.minimum(psnrs, maxpsnr) 20 | return psnrs 21 | 22 | 23 | def hartley(im): 24 | ft = fft2c(im) 25 | hart = tf.real(ft) - tf.imag(ft) 26 | return hart 27 | 28 | 29 | def getReal(tf_output, data_format): 30 | if data_format == "channels_last": 31 | real = tf_output[:, :, :, ::2] 32 | else: 33 | real = tf_output[:, ::2, :, :] 34 | return real 35 | 36 | 37 | def getImag(tf_output, data_format): 38 | if data_format == "channels_last": 39 | imag = tf_output[:, :, :, 1::2] 40 | else: 41 | imag = tf_output[:, 1::2, :, :] 42 | return imag 43 | 44 | 45 | def interleave(tf_output, data_format): 46 | if data_format == "channels_last": 47 | output_shape = tf.shape(tf_output) 48 | s = output_shape[3] 49 | realOut = tf_output[:, :, :, 0: s // 2] 50 | imagOut = tf_output[:, :, :, s // 2: s] 51 | tf_output = tf.concat([realOut, imagOut], 2) 52 | tf_output = tf.reshape(tf_output, output_shape) 53 | else: 54 | output_shape = tf.shape(tf_output) 55 | s = output_shape[1] 56 | realOut = tf_output[:, 0: s // 2, :, :] 57 | imagOut = tf_output[:, s // 2: s, :, :] 58 | tf_output = tf.concat([realOut, imagOut], 0) 59 | tf_output = tf.reshape(tf_output, output_shape) 60 | return tf_output 61 | 62 | 63 | def complex_to_channels(image, name="complex2channels"): 64 | """Convert data from complex to channels.""" 65 | with tf.name_scope(name): 66 | image_out = tf.stack([tf.real(image), tf.imag(image)], axis=-1) 67 | shape_out = tf.concat( 68 | [tf.shape(image)[:-1], [image.shape[-1] * 2]], axis=0) 69 | image_out = tf.reshape(image_out, shape_out) 70 | return image_out 71 | 72 | 73 | def channels_to_complex(image, name="channels2complex"): 74 | """Convert data from channels to complex.""" 75 | with tf.name_scope(name): 76 | image_out = tf.reshape(image, [-1, 2]) 77 | image_out = tf.complex(image_out[:, 0], image_out[:, 1]) 78 | shape_out = tf.concat( 79 | [tf.shape(image)[:-1], [image.shape[-1] // 2]], axis=0) 80 | image_out = tf.reshape(image_out, shape_out) 81 | return image_out 82 
| 83 | 84 | def fftshift(im, axis=0, name="fftshift"): 85 | """Perform fft shift. 86 | This function assumes that the axis to perform fftshift is divisible by 2. 87 | """ 88 | with tf.name_scope(name): 89 | split0, split1 = tf.split(im, 2, axis=axis) 90 | output = tf.concat((split1, split0), axis=axis) 91 | 92 | return output 93 | 94 | 95 | def ifftc(im, name="ifftc", do_orthonorm=True): 96 | """Centered iFFT on second to last dimension.""" 97 | with tf.name_scope(name): 98 | im_out = im 99 | if do_orthonorm: 100 | fftscale = tf.sqrt(1.0 * im_out.get_shape().as_list()[-2]) 101 | else: 102 | fftscale = 1.0 103 | fftscale = tf.cast(fftscale, dtype=tf.complex64) 104 | if len(im.get_shape()) == 4: 105 | im_out = tf.transpose(im_out, [0, 3, 1, 2]) 106 | im_out = fftshift(im_out, axis=3) 107 | else: 108 | im_out = tf.transpose(im_out, [2, 0, 1]) 109 | im_out = fftshift(im_out, axis=2) 110 | with tf.device("/gpu:0"): 111 | # FFT is only supported on the GPU 112 | im_out = tf.ifft(im_out) * fftscale 113 | if len(im.get_shape()) == 4: 114 | im_out = fftshift(im_out, axis=3) 115 | im_out = tf.transpose(im_out, [0, 2, 3, 1]) 116 | else: 117 | im_out = fftshift(im_out, axis=2) 118 | im_out = tf.transpose(im_out, [1, 2, 0]) 119 | 120 | return im_out 121 | 122 | 123 | def fftc(im, name="fftc", do_orthonorm=True): 124 | """Centered FFT on second to last dimension.""" 125 | with tf.name_scope(name): 126 | im_out = im 127 | if do_orthonorm: 128 | fftscale = tf.sqrt(1.0 * im_out.get_shape().as_list()[-2]) 129 | else: 130 | fftscale = 1.0 131 | fftscale = tf.cast(fftscale, dtype=tf.complex64) 132 | if len(im.get_shape()) == 4: 133 | im_out = tf.transpose(im_out, [0, 3, 1, 2]) 134 | im_out = fftshift(im_out, axis=3) 135 | else: 136 | im_out = tf.transpose(im_out, [2, 0, 1]) 137 | im_out = fftshift(im_out, axis=2) 138 | with tf.device("/gpu:0"): 139 | im_out = tf.fft(im_out) / fftscale 140 | if len(im.get_shape()) == 4: 141 | im_out = fftshift(im_out, axis=3) 142 | im_out = 
tf.transpose(im_out, [0, 2, 3, 1]) 143 | else: 144 | im_out = fftshift(im_out, axis=2) 145 | im_out = tf.transpose(im_out, [1, 2, 0]) 146 | 147 | return im_out 148 | 149 | 150 | def ifft2c(im, name="ifft2c", do_orthonorm=True): 151 | """Centered inverse FFT2 on second and third dimensions.""" 152 | with tf.name_scope(name): 153 | im_out = im 154 | dims = tf.shape(im_out) 155 | if do_orthonorm: 156 | fftscale = tf.sqrt(tf.cast(dims[1] * dims[2], dtype=tf.float32)) 157 | else: 158 | fftscale = 1.0 159 | fftscale = tf.cast(fftscale, dtype=tf.complex64) 160 | 161 | # permute FFT dimensions to be the last (faster!) 162 | tpdims = list(range(len(im_out.get_shape().as_list()))) 163 | tpdims[-1], tpdims[1] = tpdims[1], tpdims[-1] 164 | tpdims[-2], tpdims[2] = tpdims[2], tpdims[-2] 165 | 166 | im_out = tf.transpose(im_out, tpdims) 167 | im_out = fftshift(im_out, axis=-1) 168 | im_out = fftshift(im_out, axis=-2) 169 | 170 | # with tf.device('/gpu:0'): 171 | im_out = tf.ifft2d(im_out) * fftscale 172 | 173 | im_out = fftshift(im_out, axis=-1) 174 | im_out = fftshift(im_out, axis=-2) 175 | im_out = tf.transpose(im_out, tpdims) 176 | 177 | return im_out 178 | 179 | 180 | def fft2c(im, name="fft2c", do_orthonorm=True): 181 | """Centered FFT2 on second and third dimensions.""" 182 | with tf.name_scope(name): 183 | im_out = im 184 | dims = tf.shape(im_out) 185 | if do_orthonorm: 186 | fftscale = tf.sqrt(tf.cast(dims[1] * dims[2], dtype=tf.float32)) 187 | else: 188 | fftscale = 1.0 189 | fftscale = tf.cast(fftscale, dtype=tf.complex64) 190 | 191 | # permute FFT dimensions to be the last (faster!) 
192 |         tpdims = list(range(len(im_out.get_shape().as_list())))
193 |         tpdims[-1], tpdims[1] = tpdims[1], tpdims[-1]
194 |         tpdims[-2], tpdims[2] = tpdims[2], tpdims[-2]
195 |
196 |         im_out = tf.transpose(im_out, tpdims)
197 |         im_out = fftshift(im_out, axis=-1)
198 |         im_out = fftshift(im_out, axis=-2)
199 |
200 |         # with tf.device('/gpu:0'):
201 |         im_out = tf.fft2d(im_out) / fftscale
202 |
203 |         im_out = fftshift(im_out, axis=-1)
204 |         im_out = fftshift(im_out, axis=-2)
205 |         im_out = tf.transpose(im_out, tpdims)
206 |
207 |         return im_out
208 |
209 |
210 | def sumofsq(image_in, keep_dims=False, axis=-1, name="sumofsq", type="mag"):
211 |     """Compute square root of sum of squares."""
212 |     with tf.variable_scope(name):
213 |         if type == "mag":
214 |             image_out = tf.square(tf.abs(image_in))
215 |         else:
216 |             image_out = tf.square(tf.angle(image_in))
217 |         image_out = tf.reduce_sum(image_out, keepdims=keep_dims, axis=axis)
218 |         image_out = tf.sqrt(image_out)
219 |
220 |         return image_out
221 |
222 |
223 | def conj_kspace(image_in, name="kspace_conj"):
224 |     """Conjugate k-space data stored as interleaved real/imag channels."""
225 |     with tf.variable_scope(name):
226 |         image_out = tf.reverse(image_in, axis=[1])
227 |         image_out = tf.reverse(image_out, axis=[2])
228 |         mod = np.ones((1, 1, 1, image_in.get_shape().as_list()[-1]))
229 |         mod[:, :, :, 1::2] = -1  # keep real (even) channels, negate imag (odd)
230 |         mod = tf.constant(mod, dtype=tf.float32)
231 |         image_out = tf.multiply(image_out, mod)
232 |
233 |         return image_out
234 |
235 |
236 | def replace_kspace(image_orig, image_cur, name="replace_kspace"):
237 |     """Replace k-space with known values."""
238 |     with tf.variable_scope(name):
239 |         mask_x = kspace_mask(image_orig, dtype=image_orig.dtype)
240 |         image_out = tf.add(
241 |             tf.multiply(mask_x, image_orig), tf.multiply(
242 |                 (1 - mask_x), image_cur)
243 |         )
244 |
245 |         return image_out
246 |
247 |
248 | def kspace_mask(image_orig, name="kspace_mask", dtype=None):
249 |     """Find k-space mask."""
250 |     with tf.variable_scope(name):
251 |         mask_x = tf.not_equal(image_orig, 0)
252 |         if
dtype is not None:
253 |             mask_x = tf.cast(mask_x, dtype=dtype)
254 |         return mask_x
255 |
256 |
257 | def kspace_threshhold(image_orig, threshhold=1e-8, name="kspace_threshhold"):
258 |     """Find k-space mask based on threshhold.
259 |     Anything less than the specified threshhold is set to 0.
260 |     Anything above the specified threshhold is set to 1.
261 |     """
262 |     with tf.variable_scope(name):
263 |         mask_x = tf.greater(tf.abs(image_orig), threshhold)
264 |         mask_x = tf.cast(mask_x, dtype=tf.float32)
265 |         return mask_x
266 |
267 |
268 | def kspace_location(image_size):
269 |     """Construct matrix with k-space normalized location."""
270 |     x = np.arange(image_size[0], dtype=np.float32) / image_size[0] - 0.5
271 |     y = np.arange(image_size[1], dtype=np.float32) / image_size[1] - 0.5
272 |     xg, yg = np.meshgrid(x, y)
273 |     out = np.stack((xg.T, yg.T))
274 |     return out
275 |
276 |
277 | def tf_kspace_location(tf_shape_y, tf_shape_x):
278 |     """Construct matrix with k-space normalized location as tensor."""
279 |     tf_y = tf.cast(tf.range(tf_shape_y), tf.float32)
280 |     tf_y = tf_y / tf.cast(tf_shape_y, tf.float32) - 0.5
281 |     tf_x = tf.cast(tf.range(tf_shape_x), tf.float32)
282 |     tf_x = tf_x / tf.cast(tf_shape_x, tf.float32) - 0.5
283 |
284 |     [tf_yg, tf_xg] = tf.meshgrid(tf_y, tf_x)
285 |     tf_yg = tf.transpose(tf_yg, [1, 0])
286 |     tf_xg = tf.transpose(tf_xg, [1, 0])
287 |     out = tf.stack((tf_yg, tf_xg))
288 |     return out
289 |
290 |
291 | def create_window(out_shape, pad_shape=10):
292 |     """Create 2D window mask."""
293 |     g_std = pad_shape / 10
294 |     window_z = np.ones(out_shape[0] - pad_shape)
295 |     window_z = np.convolve(
296 |         window_z, scipy.signal.gaussian(pad_shape + 1, g_std), mode="full"
297 |     )
298 |
299 |     window_z = np.expand_dims(window_z, axis=1)
300 |     window_y = np.ones(out_shape[1] - pad_shape)
301 |     window_y = np.convolve(
302 |         window_y, scipy.signal.gaussian(pad_shape + 1, g_std), mode="full"
303 |     )
304 |     window_y = np.expand_dims(window_y, axis=0)
305 |
306 |     window =
np.expand_dims(window_z * window_y, axis=2) 307 | window = window / np.max(window) 308 | 309 | return window 310 | 311 | 312 | def kspace_radius(image_size): 313 | """Construct matrix with k-space radius.""" 314 | x = np.arange(image_size[0], dtype=np.float32) / image_size[0] - 0.5 315 | y = np.arange(image_size[1], dtype=np.float32) / image_size[1] - 0.5 316 | xg, yg = np.meshgrid(x, y) 317 | kr = np.sqrt(xg * xg + yg * yg) 318 | 319 | return kr.T 320 | 321 | 322 | def sensemap_model(x, sensemap, name="sensemap_model", do_transpose=False): 323 | """Apply sensitivity maps.""" 324 | with tf.variable_scope(name): 325 | if do_transpose: 326 | x_shape = x.get_shape().as_list() 327 | x = tf.expand_dims(x, axis=-2) 328 | x = tf.multiply(tf.conj(sensemap), x) 329 | x = tf.reduce_sum(x, axis=-1) 330 | else: 331 | x = tf.expand_dims(x, axis=-1) 332 | x = tf.multiply(x, sensemap) 333 | x = tf.reduce_sum(x, axis=3) 334 | return x 335 | 336 | 337 | def model_forward(x, sensemap, name="model_forward"): 338 | """Apply forward model. 339 | Image domain to k-space domain. 340 | """ 341 | with tf.variable_scope(name): 342 | if sensemap is not None: 343 | x = sensemap_model(x, sensemap, do_transpose=False) 344 | x = fft2c(x) 345 | return x 346 | 347 | 348 | def model_transpose(x, sensemap, name="model_transpose"): 349 | """Apply transpose model. 
350 | k-Space domain to image domain 351 | """ 352 | with tf.variable_scope(name): 353 | x = ifft2c(x) 354 | if sensemap is not None: 355 | x = sensemap_model(x, sensemap, do_transpose=True) 356 | return x 357 | -------------------------------------------------------------------------------- /mri_util/zReLU.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import theano.tensor as T 6 | from keras import backend as K 7 | from keras.engine.topology import Layer 8 | 9 | 10 | class zReLU(Layer): 11 | def get_realpart(self, x): 12 | image_format = K.image_data_format() 13 | ndim = K.ndim(x) 14 | input_shape = K.shape(x) 15 | 16 | if (image_format == "channels_first" and ndim != 3) or ndim == 2: 17 | input_dim = input_shape[1] // 2 18 | return x[:, :input_dim] 19 | 20 | input_dim = input_shape[-1] // 2 21 | if ndim == 3: 22 | return x[:, :, :input_dim] 23 | elif ndim == 4: 24 | return x[:, :, :, :input_dim] 25 | elif ndim == 5: 26 | return x[:, :, :, :, :input_dim] 27 | 28 | def get_imagpart(self, x): 29 | image_format = K.image_data_format() 30 | ndim = K.ndim(x) 31 | input_shape = K.shape(x) 32 | 33 | if (image_format == "channels_first" and ndim != 3) or ndim == 2: 34 | input_dim = input_shape[1] // 2 35 | return x[:, input_dim:] 36 | 37 | input_dim = input_shape[-1] // 2 38 | if ndim == 3: 39 | return x[:, :, input_dim:] 40 | elif ndim == 4: 41 | return x[:, :, :, input_dim:] 42 | elif ndim == 5: 43 | return x[:, :, :, :, input_dim:] 44 | 45 | def get_angle(self, x): 46 | real = self.get_realpart(x) 47 | imag = self.get_imagpart(x) 48 | # ang = T.arctan2(imag,real) 49 | comp = tf.complex(real, imag) 50 | ang = tf.angle(comp) 51 | return ang 52 | # T.angle(comp_num) 53 | 54 | def __init__(self, **kwargs): 55 | super(zReLU, self).__init__(**kwargs) 56 | 57 | def build(self, input_shape): 58 | super(zReLU, self).build(input_shape) # Be sure to call this 
somewhere!
59 |
60 |     def call(self, x):
61 |         real = self.get_realpart(x)
62 |         imag = self.get_imagpart(x)
63 |         # Small phase offset kept from the original implementation.
64 |         ang = self.get_angle(x) + 0.0001
65 |         # zReLU: keep z only where 0 < angle(z) < pi/2 (first quadrant) and
66 |         # zero it where angle >= pi/2 or angle <= 0. Uses TF ops only, since
67 |         # Theano's nonzero/set_subtensor cannot operate on TF tensors.
68 |         keep = tf.logical_and(tf.greater(ang, 0.0), tf.less(ang, pi / 2))
69 |         keep = tf.cast(keep, real.dtype)
70 |         real = real * keep
71 |         imag = imag * keep
72 |
73 |         # Re-join along the same axis that get_realpart/get_imagpart split.
74 |         ndim = K.ndim(x)
75 |         first = K.image_data_format() == "channels_first" and ndim != 3
76 |         return K.concatenate([real, imag], axis=1 if first or ndim == 2 else -1)
77 |
78 |     def compute_output_shape(self, input_shape):
79 |         return input_shape
80 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib>=1.5.1
2 | numpy>=1.14.0
3 | scipy>=0.19.1
4 | tensorflow-gpu==1.14.0
5 | wget>=3.2
6 | mimir
7 | resampy
8 | scikit-learn
--------------------------------------------------------------------------------
/setup_mri.py:
--------------------------------------------------------------------------------
1 | """Script for setup"""
2 | import sys
3 | import argparse
4 | import mri_prep
5 |
6 |
7 | def main(argv):
8 |     """Parse args and execute commands."""
9 |     parser = argparse.ArgumentParser(
10 |         description="Setup dataset for training")
11 |     parser.add_argument("-d", "--download", default="raw",
12 |                         help="download directory (default: raw)")
13 |     parser.add_argument("-m", "--masks", default="masks",
14 |                         help="mask directory (default: masks)")
15 |     parser.add_argument("-v", "--verbose", action="store_true",
16 |                         help="verbose printing")
17 |     parser.add_argument("-o", "--output", default="data",
18 |                         help="final data directory (default: data)")
19 |     args = parser.parse_args(argv)
20 |
21 |
22 |     verbose = args.verbose
23 |     dir_download = args.download
24 |     dir_masks = args.masks
25 |     dir_output = args.output
26 |
27 |     if verbose:
28 |         print("<<< Downloading data...
>>>") 29 | mri_prep.download_dataset_knee(dir_download, verbose=verbose) 30 | 31 | if verbose: 32 | print("<<< Creating masks... >>>") 33 | mri_prep.create_masks(dir_masks, verbose=verbose) 34 | 35 | if verbose: 36 | print("<<< Preparing TFRecords...>>>") 37 | mri_prep.setup_data_tfrecords(dir_download, dir_output, verbose=verbose) 38 | 39 | if verbose: 40 | print("Done.") 41 | 42 | 43 | if __name__ == "__main__": 44 | main(sys.argv[1:]) -------------------------------------------------------------------------------- /test_conv1d.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import complex_utils 3 | 4 | tf_input = tf.ones( 5 | shape=[1, 256, 1], dtype=tf.dtypes.complex64, name=None) 6 | 7 | num_features = 100 8 | kernel_size = 3 9 | 10 | tf_output = complex_utils.complex_conv1d( 11 | tf_input, num_features=num_features, kernel_size=kernel_size) 12 | 13 | print(tf_output) 14 | -------------------------------------------------------------------------------- /test_images.py: -------------------------------------------------------------------------------- 1 | """Test loop that will calculate image metrics and representative images for a few images in a "test_images" folder.""" 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import os 5 | import random 6 | import subprocess 7 | import sys 8 | 9 | import numpy as np 10 | import scipy.misc 11 | import tensorflow as tf 12 | from tensorflow.python.util import deprecation 13 | 14 | # import mri_data 15 | import mri_data 16 | import mri_model 17 | from mri_util import cfl, fftc, metrics, tf_util 18 | 19 | # BIN_BART = "bart" 20 | BIN_BART = "/home/sandino/bart/bart-0.4.03/bart" 21 | 22 | deprecation._PRINT_DEPRECATION_WARNINGS = False 23 | 24 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 25 | 26 | # import mri_data_cc as mri_data 27 | tf.app.flags.DEFINE_float("l1", 0.001, "L1 regularizer for CS") 28 | 29 | 
tf.app.flags.DEFINE_string("gpu", "single", "Single or multi GPU mode")
30 | tf.app.flags.DEFINE_string("conv", "complex", "Real or complex convolution")
31 | tf.app.flags.DEFINE_boolean("conjugate", False, "Complex conjugate")
32 | # Data dimensions
33 | tf.app.flags.DEFINE_integer("feat_map", 128, "Number of feature maps")
34 |
35 | tf.app.flags.DEFINE_integer("shape_y", 180, "Image shape in Y")
36 | tf.app.flags.DEFINE_integer("shape_z", 80, "Image shape in Z")
37 | tf.app.flags.DEFINE_integer(
38 |     "num_channels", 8, "Number of channels for input datasets.")
39 | tf.app.flags.DEFINE_integer(
40 |     "num_emaps", 1, "Number of eigenmaps for input sensitivity maps."
41 | )
42 |
43 | # For logging
44 | tf.app.flags.DEFINE_integer("print_level", 1, "Print out level.")
45 | tf.app.flags.DEFINE_string(
46 |     "log_root", "summary", "Root directory where logs are written to."
47 | )
48 | tf.app.flags.DEFINE_string(
49 |     "train_dir", "train", "Directory for checkpoints and event logs."
50 | )
51 | tf.app.flags.DEFINE_integer(
52 |     "num_summary_image", 4, "Number of images for summary output"
53 | )
54 | tf.app.flags.DEFINE_integer(
55 |     "log_every_n_steps", 10, "The frequency with which logs are printed."
56 | )
57 | tf.app.flags.DEFINE_integer(
58 |     "save_summaries_secs",
59 |     10,
60 |     "The frequency with which summaries are saved, " + "in seconds.",
61 | )
62 |
63 | tf.app.flags.DEFINE_integer(
64 |     "save_interval_secs",
65 |     10,
66 |     "The frequency with which the model is saved, " + "in seconds.",
67 | )
68 |
69 | tf.app.flags.DEFINE_integer(
70 |     "random_seed", 1000, "Seed to initialize random number generators."
71 | ) 72 | 73 | # For model 74 | tf.app.flags.DEFINE_integer( 75 | "num_grad_steps", 2, "Number of grad steps for unrolled algorithms" 76 | ) 77 | tf.app.flags.DEFINE_boolean( 78 | "do_hard_proj", True, "Turn on/off hard data projection at the end" 79 | ) 80 | 81 | # Optimization Flags 82 | tf.app.flags.DEFINE_string("device", "0", "GPU device to use.") 83 | tf.app.flags.DEFINE_integer( 84 | "batch_size", 4, "The number of samples in each batch.") 85 | 86 | tf.app.flags.DEFINE_float( 87 | "adam_beta2", 0.999, "The exponential decay rate for the 2nd moment estimates." 88 | ) 89 | tf.app.flags.DEFINE_float( 90 | "opt_epsilon", 1.0, "Epsilon term for the optimizer.") 91 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "Initial learning rate.") 92 | tf.app.flags.DEFINE_integer( 93 | "max_steps", None, "The maximum number of training steps.") 94 | 95 | # Dataset Flags 96 | tf.app.flags.DEFINE_string( 97 | "mask_path", "masks", "Directory where masks are located.") 98 | tf.app.flags.DEFINE_string( 99 | "train_path", "train", "Sub directory where training data are located." 100 | ) 101 | tf.app.flags.DEFINE_string( 102 | "dataset_dir", "dataset", "The directory where the dataset files are stored." 
103 | ) 104 | 105 | tf.app.flags.DEFINE_boolean( 106 | "do_validation", True, "Turn on/off validation during training" 107 | ) 108 | 109 | tf.app.flags.DEFINE_string( 110 | "mode", "train_validate", "Train_validate, train, or predict" 111 | ) 112 | 113 | tf.app.flags.DEFINE_string( 114 | "activation", "relu", "The activation function used") 115 | # If not set, loop through the entire test directory 116 | tf.app.flags.DEFINE_integer("num_cases", None, "The number of inference files") 117 | 118 | # plot middle layer weights in frequency domain 119 | tf.app.flags.DEFINE_integer("layer_num", 0, "The layer number to plot") 120 | 121 | FLAGS = tf.app.flags.FLAGS 122 | 123 | 124 | def main(_): 125 | if FLAGS.batch_size != 1: 126 | print("Error: to test images, batch size must be 1") 127 | sys.exit(1) 128 | model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir) 129 | if not os.path.exists(FLAGS.log_root): 130 | os.makedirs(FLAGS.log_root) 131 | if not os.path.exists(model_dir): 132 | os.makedirs(model_dir) 133 | bart_dir = os.path.join(model_dir, "bart_recon") 134 | if not os.path.exists(bart_dir): 135 | os.makedirs(bart_dir) 136 | 137 | image_dir = os.path.join(model_dir, "images") 138 | if not os.path.exists(image_dir): 139 | os.makedirs(image_dir) 140 | 141 | os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device # must be set before the session is created 142 | run_config = tf.ConfigProto() 143 | run_config.gpu_options.allow_growth = True 144 | with tf.Session(config=run_config) as sess: 145 | """Execute main function.""" 146 | 147 | 148 | if not FLAGS.dataset_dir: 149 | raise ValueError( 150 | "You must supply the dataset directory with " + "--dataset_dir" 151 | ) 152 | 153 | if FLAGS.random_seed >= 0: 154 | random.seed(FLAGS.random_seed) 155 | np.random.seed(FLAGS.random_seed) 156 | 157 | tf.logging.set_verbosity(tf.logging.INFO) 158 | 159 | print("Preparing dataset...") 160 | out_shape = [FLAGS.shape_z, FLAGS.shape_y] 161 | 162 | test_dataset, num_files = mri_data.create_dataset( 163 |
os.path.join(FLAGS.dataset_dir, "test_images"), 164 | FLAGS.mask_path, 165 | num_channels=FLAGS.num_channels, 166 | num_emaps=FLAGS.num_emaps, 167 | batch_size=FLAGS.batch_size, 168 | out_shape=out_shape, 169 | ) 170 | # channels last: (batch, z, y, channels) 171 | # placeholders 172 | ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels] 173 | ks_place = tf.placeholder(tf.complex64, ks_shape) 174 | sense_shape = [None, FLAGS.shape_z, 175 | FLAGS.shape_y, 1, FLAGS.num_channels] 176 | sense_place = tf.placeholder(tf.complex64, sense_shape) 177 | im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1] 178 | im_truth_place = tf.placeholder(tf.complex64, im_shape) 179 | # run through unrolled 180 | im_out_place = mri_model.unroll_fista( 181 | ks_place, 182 | sense_place, 183 | is_training=True, 184 | verbose=True, 185 | do_hardproj=FLAGS.do_hard_proj, 186 | num_summary_image=FLAGS.num_summary_image, 187 | resblock_num_features=FLAGS.feat_map, 188 | num_grad_steps=FLAGS.num_grad_steps, 189 | conv=FLAGS.conv, 190 | ) 191 | saver = tf.train.Saver() 192 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph) 193 | 194 | # initialize model 195 | print("[*] initializing network...") 196 | if not load(model_dir, saver, sess): 197 | sess.run(tf.global_variables_initializer()) 198 | coord = tf.train.Coordinator() 199 | threads = tf.train.start_queue_runners(sess, coord) 200 | 201 | # See how many parameters are in the model 202 | total_parameters = 0 203 | for variable in tf.trainable_variables(): 204 | variable_parameters = 1 205 | for dim in variable.get_shape(): 206 | variable_parameters *= dim.value 207 | total_parameters += variable_parameters 208 | print("Total number of trainable parameters: %d" % total_parameters) 209 | tf.summary.scalar("parameters/parameters", total_parameters) 210 | 211 | test_iterator = test_dataset.make_one_shot_iterator() 212 | features, labels = test_iterator.get_next() 213 | 214 | ks_truth = labels 215 | ks_in =
features["ks_input"] 216 | sense_in = features["sensemap"] 217 | mask_recon = features["mask_recon"] 218 | im_in = tf_util.model_transpose(ks_in * mask_recon, sense_in) 219 | im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in) 220 | 221 | total_summary = tf.summary.merge_all() 222 | 223 | output_psnr = [] 224 | output_nrmse = [] 225 | output_ssim = [] 226 | cs_psnr = [] 227 | cs_nrmse = [] 228 | cs_ssim = [] 229 | 230 | for test_file in range(num_files): 231 | ks_in_run, sense_in_run, im_truth_run, im_in_run = sess.run( 232 | [ks_in, sense_in, im_truth, im_in] 233 | ) 234 | im_out, total_summary_run = sess.run( 235 | [im_out_place, total_summary], 236 | feed_dict={ 237 | ks_place: ks_in_run, 238 | sense_place: sense_in_run, 239 | im_truth_place: im_truth_run, 240 | }, 241 | ) 242 | 243 | # CS recon 244 | bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=FLAGS.l1) 245 | 246 | # Rotating images 247 | # im_in_run = np.rot90(np.squeeze(im_in_run), k=3) 248 | # im_out = np.rot90(np.squeeze(im_out), k=3) 249 | # bart_test = np.rot90(np.squeeze(bart_test), k=3) 250 | # im_truth_run = np.rot90(np.squeeze(im_truth_run), k=3) 251 | 252 | # save magnitude input, output, cs, truth as .png 253 | # complex 254 | if FLAGS.conv == "complex": 255 | mag_images = np.squeeze( 256 | np.absolute( 257 | np.concatenate( 258 | (im_out, bart_test, im_truth_run), axis=2) 259 | ) 260 | ) 261 | phase_images = np.squeeze( 262 | np.angle(np.concatenate( 263 | (im_out, bart_test, im_truth_run), axis=2)) 264 | ) 265 | diff_out = im_truth_run - im_out 266 | 267 | diff_cs = im_truth_run - bart_test 268 | 269 | diff_mag = np.squeeze( 270 | np.absolute(np.concatenate((diff_out, diff_cs), axis=2)) 271 | ) 272 | diff_phase = np.squeeze( 273 | np.angle(np.concatenate((diff_out, diff_cs), axis=2)) 274 | ) 275 | 276 | if FLAGS.conv == "real": 277 | mag_images = np.squeeze( 278 | np.absolute(np.concatenate((im_in_run, im_out), axis=2)) 279 | ) 280 | phase_images = np.squeeze( 281 | 
np.angle(np.concatenate((im_in_run, im_out), axis=2)) 282 | ) 283 | diff_in = im_truth_run - im_in_run 284 | diff_out = im_truth_run - im_out 285 | 286 | diff_mag = np.squeeze( 287 | np.absolute(np.concatenate((diff_in, diff_out), axis=2)) 288 | ) 289 | diff_phase = np.squeeze( 290 | np.angle(np.concatenate((diff_in, diff_out), axis=2)) 291 | ) 292 | 293 | filename = image_dir + "/mag_" + str(test_file) + ".png" 294 | scipy.misc.imsave(filename, mag_images) 295 | 296 | # filename = image_dir + "/diff_mag_" + str(test_file) + ".png" 297 | # scipy.misc.imsave(filename, diff_mag) 298 | 299 | filename = image_dir + "/phase_" + str(test_file) + ".png" 300 | scipy.misc.imsave(filename, phase_images) 301 | 302 | # filename = image_dir + "/diff_phase_" + str(test_file) + ".png" 303 | # scipy.misc.imsave(filename, diff_phase) 304 | 305 | filename = image_dir + "/diff_phase_" + str(test_file) + ".npy" 306 | np.save(filename, diff_phase) 307 | 308 | filename = image_dir + "/diff_mag_" + str(test_file) + ".npy" 309 | np.save(filename, diff_mag) 310 | 311 | psnr, nrmse, ssim = metrics.compute_all( 312 | im_truth_run, im_out, sos_axis=-1) 313 | output_psnr.append(psnr) 314 | output_nrmse.append(nrmse) 315 | output_ssim.append(ssim) 316 | 317 | print("output psnr, nrmse, ssim") 318 | print( 319 | np.mean(output_psnr), 320 | np.std(output_psnr), 321 | np.mean(output_nrmse), 322 | np.std(output_nrmse), 323 | np.mean(output_ssim), 324 | np.std(output_ssim), 325 | ) 326 | 327 | psnr, nrmse, ssim = metrics.compute_all( 328 | im_truth_run, bart_test, sos_axis=-1 329 | ) 330 | cs_psnr.append(psnr) 331 | cs_nrmse.append(nrmse) 332 | cs_ssim.append(ssim) 333 | 334 | print("cs psnr, nrmse, ssim") 335 | print( 336 | np.mean(cs_psnr), 337 | np.std(cs_psnr), 338 | np.mean(cs_nrmse), 339 | np.std(cs_nrmse), 340 | np.mean(cs_ssim), 341 | np.std(cs_ssim), 342 | ) 343 | print("End of testing loop") 344 | txt_path = os.path.join(model_dir, "metrics.txt") 345 | f = open(txt_path, "w") 346 | 
f.write( 347 | "parameters = " + str(total_parameters) + "\n" 348 | "output psnr = " 349 | + str(np.mean(output_psnr)) 350 | + " +/- " 351 | + str(np.std(output_psnr)) 352 | + "\n" 353 | + "output nrmse = " 354 | + str(np.mean(output_nrmse)) 355 | + " +/- " 356 | + str(np.std(output_nrmse)) 357 | + "\n" 358 | + "output ssim = " 359 | + str(np.mean(output_ssim)) 360 | + " +/- " 361 | + str(np.std(output_ssim)) 362 | + "\n" 363 | "cs psnr = " 364 | + str(np.mean(cs_psnr)) 365 | + " +/- " 366 | + str(np.std(cs_psnr)) 367 | + "\n" 368 | + "cs nrmse = " 369 | + str(np.mean(cs_nrmse)) 370 | + " +/- " 371 | + str(np.std(cs_nrmse)) 372 | + "\n" 373 | + "cs ssim = " 374 | + str(np.mean(cs_ssim)) 375 | + " +/- " 376 | + str(np.std(cs_ssim)) 377 | ) 378 | f.close() 379 | 380 | 381 | def load(log_dir, saver, sess): 382 | print("[*] Reading Checkpoints...") 383 | ckpt = tf.train.get_checkpoint_state(log_dir) 384 | if ckpt and ckpt.model_checkpoint_path: 385 | saver.restore(sess, ckpt.model_checkpoint_path) 386 | print("[*] Model restored.") 387 | return True 388 | else: 389 | print("[*] Failed to find a checkpoint") 390 | return False 391 | 392 | 393 | def bart_cs(bart_dir, ks, sensemap, l1=0.01): 394 | cfl_ks = np.squeeze(ks) 395 | cfl_ks = np.expand_dims(cfl_ks, -2) 396 | cfl_sensemap = np.squeeze(sensemap) 397 | cfl_sensemap = np.expand_dims(cfl_sensemap, -2) 398 | 399 | ks_dir = os.path.join(bart_dir, "file_ks") 400 | sense_dir = os.path.join(bart_dir, "file_sensemap") 401 | img_dir = os.path.join(bart_dir, "file_img") 402 | 403 | cfl.write(ks_dir, cfl_ks, "R") 404 | cfl.write(sense_dir, cfl_sensemap, "R") 405 | 406 | # L1-wavelet regularized 407 | cmd_flags = "-S -e -R W:3:0:%f -i 100" % l1 408 | 409 | cmd = "%s pics %s %s %s %s" % ( 410 | BIN_BART, cmd_flags, ks_dir, sense_dir, img_dir,) 411 | subprocess.check_call(["bash", "-c", cmd]) 412 | bart_recon = load_recon(img_dir, sense_dir) 413 | return bart_recon 414 | 415 | 416 | def load_recon(file, file_sensemap):
417 | bart_recon = np.squeeze(cfl.read(file)) 418 | if bart_recon.ndim == 2: 419 | bart_recon = np.transpose(bart_recon, [1, 0]) 420 | bart_recon = np.expand_dims(bart_recon, axis=0) 421 | bart_recon = np.expand_dims(bart_recon, axis=-1) 422 | if bart_recon.ndim == 3: 423 | bart_recon = np.transpose(bart_recon, [2, 1, 0]) 424 | bart_recon = np.expand_dims(bart_recon, axis=-1) 425 | # print(bart_recon.shape) 426 | return bart_recon 427 | 428 | 429 | def calculate_metrics(output, bart_test, truth): 430 | cs_psnr = [] 431 | cs_nrmse = [] 432 | cs_ssim = [] 433 | output_psnr = [] 434 | output_nrmse = [] 435 | output_ssim = [] 436 | 437 | psnr, nrmse, ssim = metrics.compute_all(truth, output, sos_axis=-1) 438 | output_psnr.append(psnr) 439 | output_nrmse.append(nrmse) 440 | output_ssim.append(ssim) 441 | 442 | print("cs psnr, nrmse, ssim") 443 | print( 444 | np.mean(cs_psnr), 445 | np.std(cs_psnr), 446 | np.mean(cs_nrmse), 447 | np.std(cs_nrmse), 448 | np.mean(cs_ssim), 449 | np.std(cs_ssim), 450 | ) 451 | print(output_psnr) 452 | print("output psnr, nrmse, ssim") 453 | print( 454 | np.mean(output_psnr), 455 | np.std(output_psnr), 456 | np.mean(output_nrmse), 457 | np.std(output_nrmse), 458 | np.mean(output_ssim), 459 | np.std(output_ssim), 460 | ) 461 | 462 | 463 | def _create_summary(sense_place, ks_place, im_out_place, im_truth_place): 464 | sensemap = sense_place 465 | ks_input = ks_place 466 | image_output = im_out_place 467 | image_truth = im_truth_place 468 | 469 | image_input = tf_util.model_transpose(ks_input, sensemap) 470 | mask_input = tf_util.kspace_mask(ks_input, dtype=tf.complex64) 471 | ks_output = tf_util.model_forward(image_output, sensemap) 472 | ks_truth = tf_util.model_forward(image_truth, sensemap) 473 | 474 | with tf.name_scope("input-output-truth"): 475 | summary_input = tf_util.sumofsq(ks_input, keep_dims=True) 476 | summary_output = tf_util.sumofsq(ks_output, keep_dims=True) 477 | summary_truth = tf_util.sumofsq(ks_truth, keep_dims=True) 478 | 
summary_fft = tf.log( 479 | tf.concat((summary_input, summary_output, 480 | summary_truth), axis=2) + 1e-6 481 | ) 482 | tf.summary.image("kspace", summary_fft, 483 | max_outputs=FLAGS.num_summary_image) 484 | summary_input = tf_util.sumofsq(image_input, keep_dims=True) 485 | summary_output = tf_util.sumofsq(image_output, keep_dims=True) 486 | summary_truth = tf_util.sumofsq(image_truth, keep_dims=True) 487 | summary_image = tf.concat( 488 | (summary_input, summary_output, summary_truth), axis=2 489 | ) 490 | tf.summary.image("image", summary_image, 491 | max_outputs=FLAGS.num_summary_image) 492 | 493 | with tf.name_scope("truth"): 494 | summary_truth_real = tf.reduce_sum( 495 | image_truth, axis=-1, keep_dims=True) 496 | summary_truth_real = tf.real(summary_truth_real) 497 | tf.summary.image( 498 | "image_real", summary_truth_real, max_outputs=FLAGS.num_summary_image 499 | ) 500 | 501 | with tf.name_scope("mask"): 502 | summary_mask = tf_util.sumofsq(mask_input, keep_dims=True) 503 | tf.summary.image("mask", summary_mask, 504 | max_outputs=FLAGS.num_summary_image) 505 | 506 | with tf.name_scope("sensemap"): 507 | summary_map = tf.slice( 508 | tf.abs(sensemap), [0, 0, 0, 0, 0], [-1, -1, -1, 1, -1]) 509 | summary_map = tf.transpose(summary_map, [0, 1, 4, 2, 3]) 510 | summary_map = tf.reshape( 511 | summary_map, [tf.shape(summary_map)[0], 512 | tf.shape(summary_map)[1], -1] 513 | ) 514 | summary_map = tf.expand_dims(summary_map, axis=-1) 515 | tf.summary.image("image", summary_map, 516 | max_outputs=FLAGS.num_summary_image) 517 | 518 | 519 | if __name__ == "__main__": 520 | tf.app.run() 521 | -------------------------------------------------------------------------------- /test_loop.py: -------------------------------------------------------------------------------- 1 | """Test loop that will calculate image metrics.""" 2 | from __future__ import absolute_import, division, print_function 3 | 4 | import os 5 | import random 6 | import subprocess 7 | import sys 8 | 9 | 
import numpy as np 10 | import tensorflow as tf 11 | from tensorflow.python.util import deprecation 12 | 13 | import mri_data 14 | import mri_model 15 | from mri_util import cfl, fftc, metrics, tf_util 16 | 17 | BIN_BART = "bart" 18 | 19 | deprecation._PRINT_DEPRECATION_WARNINGS = False 20 | 21 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 22 | 23 | tf.app.flags.DEFINE_string("gpu", "single", "Single or multi GPU Mode") 24 | tf.app.flags.DEFINE_string("conv", "complex", "Real or complex convolution") 25 | tf.app.flags.DEFINE_boolean("do_conjugate", False, "Complex conjugate") 26 | # Data dimensions 27 | tf.app.flags.DEFINE_integer("feat_map", 128, "Number of feature maps") 28 | 29 | tf.app.flags.DEFINE_integer("shape_y", 180, "Image shape in Y") 30 | tf.app.flags.DEFINE_integer("shape_z", 80, "Image shape in Z") 31 | tf.app.flags.DEFINE_integer( 32 | "num_channels", 8, "Number of channels for input datasets.") 33 | tf.app.flags.DEFINE_integer( 34 | "num_emaps", 1, "Number of eigen maps for input sensitivity maps." 35 | ) 36 | 37 | # For logging 38 | tf.app.flags.DEFINE_integer("print_level", 1, "Print out level.") 39 | tf.app.flags.DEFINE_string( 40 | "log_root", "summary", "Root directory where logs are written to." 41 | ) 42 | tf.app.flags.DEFINE_string( 43 | "train_dir", "train", "Directory for checkpoints and event logs." 44 | ) 45 | tf.app.flags.DEFINE_integer( 46 | "num_summary_image", 4, "Number of images for summary output" 47 | ) 48 | tf.app.flags.DEFINE_integer( 49 | "log_every_n_steps", 10, "The frequency with which logs are printed." 50 | ) 51 | tf.app.flags.DEFINE_integer( 52 | "save_summaries_secs", 53 | 10, 54 | "The frequency with which summaries are saved, " + "in seconds.", 55 | ) 56 | 57 | tf.app.flags.DEFINE_integer( 58 | "save_interval_secs", 59 | 10, 60 | "The frequency with which the model is saved, " + "in seconds.", 61 | ) 62 | 63 | tf.app.flags.DEFINE_integer( 64 | "random_seed", 1000, "Seed to initialize random number generators."
65 | ) 66 | 67 | # For model 68 | tf.app.flags.DEFINE_integer( 69 | "num_grad_steps", 2, "Number of grad steps for unrolled algorithms" 70 | ) 71 | tf.app.flags.DEFINE_boolean( 72 | "do_hard_proj", True, "Turn on/off hard data projection at the end" 73 | ) 74 | 75 | # Optimization Flags 76 | tf.app.flags.DEFINE_string("device", "0", "GPU device to use.") 77 | tf.app.flags.DEFINE_integer( 78 | "batch_size", 4, "The number of samples in each batch.") 79 | 80 | tf.app.flags.DEFINE_float( 81 | "adam_beta2", 0.999, "The exponential decay rate for the 2nd moment estimates." 82 | ) 83 | tf.app.flags.DEFINE_float( 84 | "opt_epsilon", 1.0, "Epsilon term for the optimizer.") 85 | tf.app.flags.DEFINE_float("learning_rate", 0.001, "Initial learning rate.") 86 | tf.app.flags.DEFINE_integer( 87 | "max_steps", None, "The maximum number of training steps.") 88 | 89 | # Dataset Flags 90 | tf.app.flags.DEFINE_string( 91 | "mask_path", "masks", "Directory where masks are located.") 92 | tf.app.flags.DEFINE_string( 93 | "train_path", "train", "Sub directory where training data are located." 94 | ) 95 | tf.app.flags.DEFINE_string( 96 | "dataset_dir", "dataset", "The directory where the dataset files are stored." 
97 | ) 98 | 99 | tf.app.flags.DEFINE_boolean( 100 | "do_validation", True, "Turn on/off validation during training" 101 | ) 102 | 103 | tf.app.flags.DEFINE_string( 104 | "mode", "train_validate", "Train_validate, train, or predict" 105 | ) 106 | 107 | tf.app.flags.DEFINE_string( 108 | "activation", "relu", "The activation function used") 109 | # If not set, loop through the entire test directory 110 | tf.app.flags.DEFINE_integer("num_cases", None, "The number of inference files") 111 | 112 | # plot middle layer weights in frequency domain 113 | tf.app.flags.DEFINE_integer("layer_num", 0, "The layer number to plot") 114 | 115 | FLAGS = tf.app.flags.FLAGS 116 | 117 | 118 | def main(_): 119 | if FLAGS.batch_size != 1: 120 | print("Error: to test images, batch size must be 1") 121 | sys.exit(1) 122 | 123 | model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir) 124 | if not os.path.exists(FLAGS.log_root): 125 | os.makedirs(FLAGS.log_root) 126 | if not os.path.exists(model_dir): 127 | os.makedirs(model_dir) 128 | bart_dir = os.path.join(model_dir, "bart_recon") 129 | if not os.path.exists(bart_dir): 130 | os.makedirs(bart_dir) 131 | 132 | os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device # must be set before the session is created 133 | run_config = tf.ConfigProto() 134 | run_config.gpu_options.allow_growth = True 135 | with tf.Session(config=run_config) as sess: 136 | """Execute main function.""" 137 | 138 | 139 | if not FLAGS.dataset_dir: 140 | raise ValueError( 141 | "You must supply the dataset directory with " + "--dataset_dir" 142 | ) 143 | 144 | if FLAGS.random_seed >= 0: 145 | random.seed(FLAGS.random_seed) 146 | np.random.seed(FLAGS.random_seed) 147 | 148 | tf.logging.set_verbosity(tf.logging.INFO) 149 | 150 | print("Preparing dataset...") 151 | out_shape = [FLAGS.shape_z, FLAGS.shape_y] 152 | 153 | test_dataset, num_files = mri_data.create_dataset( 154 | os.path.join(FLAGS.dataset_dir, "test"), 155 | FLAGS.mask_path, 156 | num_channels=FLAGS.num_channels, 157 | num_emaps=FLAGS.num_emaps, 158 |
batch_size=FLAGS.batch_size, 159 | out_shape=out_shape, 160 | ) 161 | # channels last: (batch, z, y, channels) 162 | # placeholders 163 | ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels] 164 | ks_place = tf.placeholder(tf.complex64, ks_shape) 165 | sense_shape = [None, FLAGS.shape_z, 166 | FLAGS.shape_y, 1, FLAGS.num_channels] 167 | sense_place = tf.placeholder(tf.complex64, sense_shape) 168 | im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1] 169 | im_truth_place = tf.placeholder(tf.complex64, im_shape) 170 | # run through unrolled 171 | im_out_place = mri_model.unroll_fista( 172 | ks_place, 173 | sense_place, 174 | is_training=True, 175 | verbose=True, 176 | do_hardproj=FLAGS.do_hard_proj, 177 | num_summary_image=FLAGS.num_summary_image, 178 | resblock_num_features=FLAGS.feat_map, 179 | num_grad_steps=FLAGS.num_grad_steps, 180 | conv=FLAGS.conv, 181 | do_conjugate=FLAGS.do_conjugate, 182 | ) 183 | 184 | saver = tf.train.Saver() 185 | summary_writer = tf.summary.FileWriter(model_dir, sess.graph) 186 | 187 | # initialize model 188 | print("[*] initializing network...") 189 | if not load(model_dir, saver, sess): 190 | sess.run(tf.global_variables_initializer()) 191 | coord = tf.train.Coordinator() 192 | threads = tf.train.start_queue_runners(sess, coord) 193 | 194 | # See how many parameters are in the model 195 | total_parameters = 0 196 | for variable in tf.trainable_variables(): 197 | variable_parameters = 1 198 | for dim in variable.get_shape(): 199 | variable_parameters *= dim.value 200 | total_parameters += variable_parameters 201 | print("Total number of trainable parameters: %d" % total_parameters) 202 | 203 | test_iterator = test_dataset.make_one_shot_iterator() 204 | features, labels = test_iterator.get_next() 205 | 206 | ks_truth = labels 207 | ks_in = features["ks_input"] 208 | sense_in = features["sensemap"] 209 | mask_recon = features["mask_recon"] 210 | im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in) 211 | 212 |
total_summary = tf.summary.merge_all() 213 | 214 | output_psnr = [] 215 | output_nrmse = [] 216 | output_ssim = [] 217 | cs_psnr = [] 218 | cs_nrmse = [] 219 | cs_ssim = [] 220 | 221 | for test_file in range(num_files): 222 | ks_in_run, sense_in_run, im_truth_run = sess.run( 223 | [ks_in, sense_in, im_truth] 224 | ) 225 | im_out, total_summary_run = sess.run( 226 | [im_out_place, total_summary], 227 | feed_dict={ 228 | ks_place: ks_in_run, 229 | sense_place: sense_in_run, 230 | im_truth_place: im_truth_run, 231 | }, 232 | ) 233 | 234 | # CS recon 235 | bart_test = bart_cs(bart_dir, ks_in_run, sense_in_run, l1=0.007) 236 | # bart_test = None 237 | 238 | # handle batch dimension 239 | for b in range(FLAGS.batch_size): 240 | truth = im_truth_run[b, :, :, :] 241 | out = im_out[b, :, :, :] 242 | psnr, nrmse, ssim = metrics.compute_all( 243 | truth, out, sos_axis=-1) 244 | output_psnr.append(psnr) 245 | output_nrmse.append(nrmse) 246 | output_ssim.append(ssim) 247 | 248 | print("output mean +/- standard deviation psnr, nrmse, ssim") 249 | print( 250 | np.mean(output_psnr), 251 | np.std(output_psnr), 252 | np.mean(output_nrmse), 253 | np.std(output_nrmse), 254 | np.mean(output_ssim), 255 | np.std(output_ssim), 256 | ) 257 | 258 | psnr, nrmse, ssim = metrics.compute_all( 259 | im_truth_run, bart_test, sos_axis=-1 260 | ) 261 | cs_psnr.append(psnr) 262 | cs_nrmse.append(nrmse) 263 | cs_ssim.append(ssim) 264 | 265 | print("cs mean +/- standard deviation psnr, nrmse, ssim") 266 | print( 267 | np.mean(cs_psnr), 268 | np.std(cs_psnr), 269 | np.mean(cs_nrmse), 270 | np.std(cs_nrmse), 271 | np.mean(cs_ssim), 272 | np.std(cs_ssim), 273 | ) 274 | print("End of testing loop") 275 | txt_path = os.path.join(model_dir, "metrics.txt") 276 | f = open(txt_path, "w") 277 | f.write( 278 | "parameters = " 279 | + str(total_parameters) 280 | + "\n" 281 | + "output psnr = " 282 | + str(np.mean(output_psnr)) 283 | + " +/- " 284 | + str(np.std(output_psnr)) 285 | + "\n" 286 | + "output nrmse = "
287 | + str(np.mean(output_nrmse)) 288 | + " +/- " 289 | + str(np.std(output_nrmse)) 290 | + "\n" 291 | + "output ssim = " 292 | + str(np.mean(output_ssim)) 293 | + " +/- " 294 | + str(np.std(output_ssim)) 295 | + "\n" 296 | "cs psnr = " 297 | + str(np.mean(cs_psnr)) 298 | + " +/- " 299 | + str(np.std(cs_psnr)) 300 | + "\n" 301 | + "cs nrmse = " 302 | + str(np.mean(cs_nrmse)) 303 | + " +/- " 304 | + str(np.std(cs_nrmse)) 305 | + "\n" 306 | + "cs ssim = " 307 | + str(np.mean(cs_ssim)) 308 | + " +/- " 309 | + str(np.std(cs_ssim)) 310 | ) 311 | f.close() 312 | 313 | 314 | def load(log_dir, saver, sess): 315 | print("[*] Reading Checkpoints...") 316 | ckpt = tf.train.get_checkpoint_state(log_dir) 317 | if ckpt and ckpt.model_checkpoint_path: 318 | saver.restore(sess, ckpt.model_checkpoint_path) 319 | print("[*] Model restored.") 320 | return True 321 | else: 322 | print("[*] Failed to find a checkpoint") 323 | return False 324 | 325 | 326 | def bart_cs(bart_dir, ks, sensemap, l1=0.01): 327 | cfl_ks = np.squeeze(ks) 328 | cfl_ks = np.expand_dims(cfl_ks, -2) 329 | cfl_sensemap = np.squeeze(sensemap) 330 | cfl_sensemap = np.expand_dims(cfl_sensemap, -2) 331 | 332 | ks_dir = os.path.join(bart_dir, "file_ks") 333 | sense_dir = os.path.join(bart_dir, "file_sensemap") 334 | img_dir = os.path.join(bart_dir, "file_img") 335 | 336 | cfl.write(ks_dir, cfl_ks, "R") 337 | cfl.write(sense_dir, cfl_sensemap, "R") 338 | 339 | # L1-wavelet regularized 340 | cmd_flags = "-S -e -R W:3:0:%f -i 100" % l1 341 | 342 | cmd = "%s pics %s %s %s %s" % ( 343 | BIN_BART, cmd_flags, ks_dir, sense_dir, img_dir,) 344 | subprocess.check_call(["bash", "-c", cmd]) 345 | bart_recon = load_recon(img_dir, sense_dir) 346 | return bart_recon 347 | 348 | 349 | def load_recon(file, file_sensemap): 350 | bart_recon = np.squeeze(cfl.read(file)) 351 | if bart_recon.ndim == 2: 352 | bart_recon = np.transpose(bart_recon, [1, 0]) 353 | bart_recon = np.expand_dims(bart_recon, axis=0) 354 | bart_recon =
np.expand_dims(bart_recon, axis=-1) 355 | if bart_recon.ndim == 3: 356 | bart_recon = np.transpose(bart_recon, [2, 1, 0]) 357 | bart_recon = np.expand_dims(bart_recon, axis=-1) 358 | 359 | return bart_recon 360 | 361 | 362 | def calculate_metrics(output, bart_test, truth): 363 | cs_psnr = [] 364 | cs_nrmse = [] 365 | cs_ssim = [] 366 | output_psnr = [] 367 | output_nrmse = [] 368 | output_ssim = [] 369 | 370 | psnr, nrmse, ssim = metrics.compute_all(truth, output, sos_axis=-1) 371 | output_psnr.append(psnr) 372 | output_nrmse.append(nrmse) 373 | output_ssim.append(ssim) 374 | 375 | 376 | def _create_summary(sense_place, ks_place, im_out_place, im_truth_place): 377 | sensemap = sense_place 378 | ks_input = ks_place 379 | image_output = im_out_place 380 | image_truth = im_truth_place 381 | 382 | image_input = tf_util.model_transpose(ks_input, sensemap) 383 | mask_input = tf_util.kspace_mask(ks_input, dtype=tf.complex64) 384 | ks_output = tf_util.model_forward(image_output, sensemap) 385 | ks_truth = tf_util.model_forward(image_truth, sensemap) 386 | 387 | with tf.name_scope("input-output-truth"): 388 | summary_input = tf_util.sumofsq(ks_input, keep_dims=True) 389 | summary_output = tf_util.sumofsq(ks_output, keep_dims=True) 390 | summary_truth = tf_util.sumofsq(ks_truth, keep_dims=True) 391 | summary_fft = tf.log( 392 | tf.concat((summary_input, summary_output, 393 | summary_truth), axis=2) + 1e-6 394 | ) 395 | tf.summary.image("kspace", summary_fft, 396 | max_outputs=FLAGS.num_summary_image) 397 | summary_input = tf_util.sumofsq(image_input, keep_dims=True) 398 | summary_output = tf_util.sumofsq(image_output, keep_dims=True) 399 | summary_truth = tf_util.sumofsq(image_truth, keep_dims=True) 400 | summary_image = tf.concat( 401 | (summary_input, summary_output, summary_truth), axis=2 402 | ) 403 | tf.summary.image("image", summary_image, 404 | max_outputs=FLAGS.num_summary_image) 405 | 406 | with tf.name_scope("truth"): 407 | summary_truth_real = tf.reduce_sum( 408 | 
image_truth, axis=-1, keep_dims=True) 409 | summary_truth_real = tf.real(summary_truth_real) 410 | tf.summary.image( 411 | "image_real", summary_truth_real, max_outputs=FLAGS.num_summary_image 412 | ) 413 | 414 | with tf.name_scope("mask"): 415 | summary_mask = tf_util.sumofsq(mask_input, keep_dims=True) 416 | tf.summary.image("mask", summary_mask, 417 | max_outputs=FLAGS.num_summary_image) 418 | 419 | with tf.name_scope("sensemap"): 420 | summary_map = tf.slice( 421 | tf.abs(sensemap), [0, 0, 0, 0, 0], [-1, -1, -1, 1, -1]) 422 | summary_map = tf.transpose(summary_map, [0, 1, 4, 2, 3]) 423 | summary_map = tf.reshape( 424 | summary_map, [tf.shape(summary_map)[0], 425 | tf.shape(summary_map)[1], -1] 426 | ) 427 | summary_map = tf.expand_dims(summary_map, axis=-1) 428 | tf.summary.image("image", summary_map, 429 | max_outputs=FLAGS.num_summary_image) 430 | 431 | 432 | if __name__ == "__main__": 433 | tf.app.run() 434 | -------------------------------------------------------------------------------- /test_script.sh: -------------------------------------------------------------------------------- 1 | WORK_DIR=~/Workspace/complex-networks-release #repo directory 2 | DATASET_DIR=/home_local/ekcole/knee_data #where the knee dataset was downloaded to 3 | MASKS_PATH=/home_local/ekcole/knee_masks #where the BART-generated masks are stored 4 | 5 | TYPE=complex #real or complex convolution 6 | ITERATIONS=4 #number of unrolled iterations 7 | FEAT=128 #number of feature maps 8 | ACTIVATION=relu #activation function 9 | 10 | #To test images, batch size has to be 1 11 | 12 | LOG_DIR="f"$FEAT"_g"$ITERATIONS 13 | # testing (evaluates the trained model in $LOG_DIR) 14 | python3 $WORK_DIR/test_images.py \ 15 | --train_dir $TYPE"_"$ACTIVATION \ 16 | --mask_path $MASKS_PATH \ 17 | --dataset_dir $DATASET_DIR \ 18 | --log_root $LOG_DIR \ 19 | --shape_z 256 --shape_y 320 \ 20 | --num_channels 8 \ 21 | --batch_size 1 \ 22 | --device 0 \ 23 | --max_steps 10000 \ 24 | --feat_map $FEAT \ 25 | --num_grad_steps $ITERATIONS \ 26 |
--activation $ACTIVATION \ 27 | --conv $TYPE -------------------------------------------------------------------------------- /train_loop.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import logging 4 | import os 5 | import random 6 | import sys 7 | import types 8 | import warnings 9 | 10 | import numpy as np 11 | import tensorflow as tf 12 | 13 | import mri_data 14 | import mri_model 15 | from mri_util import fftc, metrics, tf_util 16 | 17 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) 18 | 19 | tf.app.flags.DEFINE_string("gpu", "single", "Single or multi GPU Mode") 20 | tf.app.flags.DEFINE_string("conv", "real", "Real or complex convolution") 21 | tf.app.flags.DEFINE_boolean("do_conjugate", False, "Complex conjugate") 22 | # Data dimensions 23 | tf.app.flags.DEFINE_integer("feat_map", 128, "Number of feature maps") 24 | 25 | tf.app.flags.DEFINE_integer("shape_y", 180, "Image shape in Y") 26 | tf.app.flags.DEFINE_integer("shape_z", 80, "Image shape in Z") 27 | tf.app.flags.DEFINE_integer( 28 | "num_channels", 8, "Number of channels for input datasets.") 29 | tf.app.flags.DEFINE_integer( 30 | "num_emaps", 1, "Number of eigen maps for input sensitivity maps." 31 | ) 32 | 33 | # For logging 34 | tf.app.flags.DEFINE_integer("print_level", 1, "Print out level.") 35 | tf.app.flags.DEFINE_string( 36 | "log_root", "summary", "Root directory where logs are written to." 37 | ) 38 | tf.app.flags.DEFINE_string( 39 | "train_dir", "train", "Directory for checkpoints and event logs." 40 | ) 41 | tf.app.flags.DEFINE_integer( 42 | "num_summary_image", 4, "Number of images for summary output" 43 | ) 44 | tf.app.flags.DEFINE_integer( 45 | "log_every_n_steps", 10, "The frequency with which logs are printed."
)
tf.app.flags.DEFINE_integer(
    "save_summaries_secs",
    10,
    "The frequency with which summaries are saved, " + "in seconds.",
)

tf.app.flags.DEFINE_integer(
    "save_interval_secs",
    10,
    "The frequency with which the model is saved, " + "in seconds.",
)

tf.app.flags.DEFINE_integer(
    "random_seed", 1000, "Seed to initialize random number generators."
)

# For model
tf.app.flags.DEFINE_integer(
    "num_grad_steps", 2, "Number of grad steps for unrolled algorithms"
)
tf.app.flags.DEFINE_boolean(
    "do_hard_proj", True, "Turn on/off hard data projection at the end"
)

# Optimization Flags
tf.app.flags.DEFINE_string("device", "0", "GPU device to use.")
tf.app.flags.DEFINE_integer(
    "batch_size", 4, "The number of samples in each batch.")

tf.app.flags.DEFINE_float(
    "adam_beta1", 0.9, "The exponential decay rate for the 1st moment estimates."
)
tf.app.flags.DEFINE_float(
    "adam_beta2", 0.999, "The exponential decay rate for the 2nd moment estimates."
)
tf.app.flags.DEFINE_float(
    "opt_epsilon", 1.0, "Epsilon term for the optimizer.")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "Initial learning rate.")
tf.app.flags.DEFINE_integer(
    "max_steps", None, "The maximum number of training steps.")

# Dataset Flags
tf.app.flags.DEFINE_string(
    "mask_path", "masks", "Directory where masks are located.")
tf.app.flags.DEFINE_string(
    "train_path", "train", "Sub directory where training data are located."
)
tf.app.flags.DEFINE_string(
    "dataset_dir", "dataset", "The directory where the dataset files are stored."
)

tf.app.flags.DEFINE_boolean(
    "do_validation", True, "Turn on/off validation during training"
)

tf.app.flags.DEFINE_string(
    "mode", "train_validate", "Train_validate, train, or predict"
)

tf.app.flags.DEFINE_string(
    "activation", "relu", "The activation function used")
# If not defined, loops through the entire test directory
tf.app.flags.DEFINE_integer("num_cases", None, "The number of inference files")

# plot middle layer weights in frequency domain
tf.app.flags.DEFINE_integer("layer_num", 0, "The layer number to plot")

FLAGS = tf.app.flags.FLAGS


def main(_):
    """Execute main function."""
    # path where model checkpoints and summaries will be saved
    model_dir = os.path.join(FLAGS.log_root, FLAGS.train_dir)
    if not os.path.exists(FLAGS.log_root):
        os.makedirs(FLAGS.log_root)
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # select the GPU before the session is created: CUDA reads
    # CUDA_VISIBLE_DEVICES when it initializes, so setting it inside
    # the session has no effect
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.device

    # one config for the session, so that allow_growth (allocate GPU
    # memory on demand) is actually applied
    run_config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=True)
    run_config.gpu_options.allow_growth = True

    with tf.Session(config=run_config) as sess:
        if not FLAGS.dataset_dir:
            raise ValueError(
                "You must supply the dataset directory with " + "--dataset_dir"
            )

        if FLAGS.random_seed >= 0:
            random.seed(FLAGS.random_seed)
            np.random.seed(FLAGS.random_seed)

        tf.logging.set_verbosity(tf.logging.INFO)

        print("Preparing dataset...")
        out_shape = [FLAGS.shape_z, FLAGS.shape_y]
        train_dataset, num_files = mri_data.create_dataset(
            os.path.join(FLAGS.dataset_dir, "train"),
            FLAGS.mask_path,
            num_channels=FLAGS.num_channels,
            num_emaps=FLAGS.num_emaps,
            batch_size=FLAGS.batch_size,
            out_shape=out_shape,
        )

        # channels last format:
        # batch, z, y, channels
        # placeholders
        ks_shape = [None, FLAGS.shape_z, FLAGS.shape_y, FLAGS.num_channels]
        ks_place = tf.placeholder(tf.complex64, ks_shape)
        sense_shape = [None, FLAGS.shape_z,
                       FLAGS.shape_y, 1, FLAGS.num_channels]
        sense_place = tf.placeholder(tf.complex64, sense_shape)
        im_shape = [None, FLAGS.shape_z, FLAGS.shape_y, 1]
        im_truth_place = tf.placeholder(tf.complex64, im_shape)

        # run through unrolled model
        im_out_place = mri_model.unroll_fista(
            ks_place,
            sense_place,
            is_training=True,
            verbose=True,
            do_hardproj=FLAGS.do_hard_proj,
            num_summary_image=FLAGS.num_summary_image,
            resblock_num_features=FLAGS.feat_map,
            num_grad_steps=FLAGS.num_grad_steps,
            conv=FLAGS.conv,
            do_conjugate=FLAGS.do_conjugate,
            activation=FLAGS.activation
        )

        # tensorboard summary function
        _create_summary(sense_place, ks_place, im_out_place, im_truth_place)

        # define L1 loss between output and ground truth
        loss = tf.reduce_mean(tf.abs(im_out_place - im_truth_place), name="l1")
        loss_sum = tf.summary.scalar("loss/l1", loss)

        # optimize using Adam
        optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate,
            name="opt",
            beta1=FLAGS.adam_beta1,
            beta2=FLAGS.adam_beta2,
        ).minimize(loss)

        # counter for saving checkpoints
        with tf.variable_scope("counter"):
            counter = tf.get_variable(
                "counter",
                shape=[1],
                initializer=tf.constant_initializer([0]),
                dtype=tf.int32,
            )
        update_counter = tf.assign(counter, tf.add(counter, 1))

        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter(model_dir, sess.graph)

        # initialize model
        print("[*] initializing network...")
        if not load(model_dir, saver, sess):
            sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord)

        # calculate number of parameters in model
        total_parameters = 0
        for variable in tf.trainable_variables():
            variable_parameters = 1
            for dim in variable.get_shape():
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("Total number of trainable parameters: %d" % total_parameters)
        tf.summary.scalar("parameters/parameters", total_parameters)

        # use iterator to go through TFRecord dataset
        train_iterator = train_dataset.make_one_shot_iterator()
        features, labels = train_iterator.get_next()

        ks_truth = labels  # ground truth kspace
        ks_in = features["ks_input"]  # input kspace
        sense_in = features["sensemap"]  # sensitivity maps
        mask_recon = features["mask_recon"]  # reconstruction mask

        # ground truth kspace to image domain
        im_truth = tf_util.model_transpose(ks_truth * mask_recon, sense_in)

        # gather summaries for tensorboard
        total_summary = tf.summary.merge_all()

        print("Start from step %d."
              % (sess.run(counter)))
        for step in range(int(sess.run(counter)), FLAGS.max_steps):
            # evaluate input kspace, sensitivity maps, ground truth image
            ks_in_run, sense_in_run, im_truth_run = sess.run(
                [ks_in, sense_in, im_truth]
            )
            # run optimizer and collect output image from model and tensorboard summary
            im_out, total_summary_run, _ = sess.run(
                [im_out_place, total_summary, optimizer],
                feed_dict={
                    ks_place: ks_in_run,
                    sense_place: sense_in_run,
                    im_truth_place: im_truth_run,
                },
            )
            print("step", step)
            # add summary to tensorboard
            summary_writer.add_summary(total_summary_run, step)

            # save checkpoint every 500 steps
            if step % 500 == 0:
                print("saving checkpoint")
                saver.save(sess, model_dir + "/model.ckpt")

            # advance the step counter that is saved with the checkpoint
            sess.run(update_counter)
        print("End of training loop")


def load(log_dir, saver, sess):
    # search for and load a model
    print("[*] Reading Checkpoints...")
    ckpt = tf.train.get_checkpoint_state(log_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print("[*] Model restored.")
        return True
    else:
        print("[*] Failed to find a checkpoint")
        return False


def _create_summary(sense_place, ks_place, im_out_place, im_truth_place):
    # tensorboard summary function
    sensemap = sense_place
    ks_input = ks_place
    image_output = im_out_place
    image_truth = im_truth_place

    image_input = tf_util.model_transpose(ks_input, sensemap)
    mask_input = tf_util.kspace_mask(ks_input, dtype=tf.complex64)
    ks_output = tf_util.model_forward(image_output, sensemap)
    ks_truth = tf_util.model_forward(image_truth, sensemap)

    with tf.name_scope("input-output-truth"):
        summary_input = tf_util.sumofsq(ks_input,
                                        keep_dims=True)
        summary_output = tf_util.sumofsq(ks_output, keep_dims=True)
        summary_truth = tf_util.sumofsq(ks_truth, keep_dims=True)
        summary_fft = tf.log(
            tf.concat((summary_input, summary_output,
                       summary_truth), axis=2) + 1e-6
        )
        tf.summary.image("kspace", summary_fft,
                         max_outputs=FLAGS.num_summary_image)
        summary_input = tf_util.sumofsq(image_input, keep_dims=True)
        summary_output = tf_util.sumofsq(image_output, keep_dims=True)
        summary_truth = tf_util.sumofsq(image_truth, keep_dims=True)
        summary_image = tf.concat(
            (summary_input, summary_output, summary_truth), axis=2
        )
        tf.summary.image("image", summary_image,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("phase"):
        summary_input = tf.angle(image_input)
        summary_output = tf.angle(image_output)
        summary_truth = tf.angle(image_truth)
        summary_image = tf.concat(
            (summary_input, summary_output, summary_truth), axis=2
        )
        tf.summary.image("total", summary_image,
                         max_outputs=FLAGS.num_summary_image)
        tf.summary.image("input", summary_input,
                         max_outputs=FLAGS.num_summary_image)
        tf.summary.image("output", summary_output,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("truth"):
        summary_truth_real = tf.reduce_sum(
            image_truth, axis=-1, keep_dims=True)
        summary_truth_real = tf.real(summary_truth_real)
        tf.summary.image(
            "image_real", summary_truth_real, max_outputs=FLAGS.num_summary_image
        )

    with tf.name_scope("mask"):
        summary_mask = tf_util.sumofsq(mask_input, keep_dims=True)
        tf.summary.image("mask", summary_mask,
                         max_outputs=FLAGS.num_summary_image)

    with tf.name_scope("sensemap"):
        summary_map = tf.slice(
            tf.abs(sensemap), [0, 0, 0, 0, 0], [-1, -1, -1, 1, -1])
        summary_map = tf.transpose(summary_map, [0, 1, 4, 2, 3])
        summary_map = tf.reshape(
            summary_map, [tf.shape(summary_map)[0],
                          tf.shape(summary_map)[1], -1]
        )
        summary_map = tf.expand_dims(summary_map, axis=-1)
        tf.summary.image("image", summary_map,
                         max_outputs=FLAGS.num_summary_image)


if __name__ == "__main__":
    tf.app.run()

--------------------------------------------------------------------------------
/train_script.sh:
--------------------------------------------------------------------------------
WORK_DIR=~/Workspace/complex-networks-release #repo directory
DATASET_DIR=/home_local/ekcole/knee_data #where the knee dataset was downloaded to
MASKS_PATH=/home_local/ekcole/knee_masks #where the masks were generated by BART

TYPE=complex #real or complex convolution
ITERATIONS=4 #number of unrolled iterations
FEAT=128 #number of feature maps
ACTIVATION=relu #activation function

LOG_DIR="f"$FEAT"_g"$ITERATIONS

# training
python3 $WORK_DIR/train_loop.py \
    --train_dir $TYPE"_"$ACTIVATION \
    --mask_path $MASKS_PATH \
    --dataset_dir $DATASET_DIR \
    --log_root $LOG_DIR \
    --shape_z 256 --shape_y 320 \
    --num_channels 8 \
    --batch_size 2 \
    --device 0 \
    --max_steps 10000 \
    --feat_map $FEAT \
    --num_grad_steps $ITERATIONS \
    --activation $ACTIVATION \
    --conv $TYPE

# testing (train_dir and num_grad_steps must match the training run)
# python3 $WORK_DIR/test_loop.py --train_dir $TYPE"_"$ACTIVATION \
#     --shape_z 256 --shape_y 320 \
#     --batch_size 1 \
#     --feat_map $FEAT \
#     --num_grad_steps $ITERATIONS \
#     --mask_path $MASKS_PATH \
#     --dataset_dir $DATASET_DIR \
#     --device 0 \
#     --log_root $LOG_DIR \
#     --num_channels 8 \
#     --activation $ACTIVATION \
#     --gpu single \
#     --conv $TYPE
--------------------------------------------------------------------------------
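The `sensemap` summary at the end of `_create_summary` in `train_loop.py` turns the coil sensitivity maps into a single TensorBoard image. Below is a minimal NumPy re-expression of that `tf.slice` / `tf.transpose` / `tf.reshape` / `tf.expand_dims` chain, assuming the `[batch, z, y, emaps, channels]` layout of `sense_place`. The function name `sensemap_montage` and the toy input are illustrative only and are not part of the repository.

```python
import numpy as np


def sensemap_montage(sensemap):
    """Tile per-coil sensitivity magnitudes side by side along the y axis.

    Illustrative NumPy version of the summary code in _create_summary;
    assumes a [batch, z, y, emaps, channels] input like sense_place.
    """
    # keep the magnitude of the first eigen map only: [batch, z, y, 1, channels]
    first_map = np.abs(sensemap)[:, :, :, :1, :]
    # move channels in front of y: [batch, z, channels, y, 1]
    moved = np.transpose(first_map, (0, 1, 4, 2, 3))
    # merge channels and y into one width axis: [batch, z, channels * y]
    batch, z = moved.shape[0], moved.shape[1]
    tiled = moved.reshape(batch, z, -1)
    # add the trailing image channel expected by tf.summary.image
    return tiled[..., np.newaxis]


# toy input: 1 slice, 2 x 3 pixels, 1 eigen map, 4 coils
maps = np.zeros((1, 2, 3, 1, 4), dtype=np.complex64)
for c in range(4):
    # coil c has constant magnitude c + 1 and an arbitrary phase
    maps[..., c] = (c + 1) * np.exp(1j * np.pi / 4)

montage = sensemap_montage(maps)
print(montage.shape)  # (1, 2, 12, 1)
```

Each coil occupies a contiguous block of `y` columns because the channel axis is moved in front of `y` before the reshape; reshaping the sliced array directly would interleave the coils column by column.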
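The training objective in `main` is `tf.reduce_mean(tf.abs(im_out_place - im_truth_place))` applied to `complex64` tensors, and for complex input `tf.abs` returns the modulus sqrt(re^2 + im^2) of each element. The NumPy sketch below is an illustration, not code from the repository: `complex_l1` computes the same mean-modulus loss, and `split_l1` is a hypothetical comparison that sums absolute real and imaginary errors, which `train_loop.py` does not use.

```python
import numpy as np


def complex_l1(output, truth):
    """Mean modulus of the complex error, mirroring the "l1" loss in train_loop.py."""
    return float(np.mean(np.abs(output - truth)))


def split_l1(output, truth):
    """Hypothetical alternative (not used here): real and imaginary errors summed."""
    diff = output - truth
    return float(np.mean(np.abs(diff.real) + np.abs(diff.imag)))


truth = np.array([1 + 1j, 0 + 0j], dtype=np.complex64)
output = np.array([4 + 5j, 0 + 0j], dtype=np.complex64)
# the first pixel's error is 3 + 4j, whose modulus is 5
print(complex_l1(output, truth))  # 2.5
print(split_l1(output, truth))    # 3.5
```

The modulus treats the error as a distance in the complex plane, so it does not depend on how that error is split between the real and imaginary channels.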