├── .gitignore
├── Img
│   ├── myAutomap_5000im80_ep100_lr000002.png
│   ├── myAutomap_7500im80_ep150_lr000002.png
│   ├── myAutomap_7500im80_ep200_lr000002_nonnorm.png
│   └── myAutomap_7500im80_ep200_lr000002_nonnorm_brain.png
├── README.md
├── generate_input.py
├── generate_input_motion.py
├── myAutomap_recon.py
└── myAutomap.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store

--------------------------------------------------------------------------------
/Img/myAutomap_5000im80_ep100_lr000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetianadadakova/MRI-CNN/HEAD/Img/myAutomap_5000im80_ep100_lr000002.png

--------------------------------------------------------------------------------
/Img/myAutomap_7500im80_ep150_lr000002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetianadadakova/MRI-CNN/HEAD/Img/myAutomap_7500im80_ep150_lr000002.png

--------------------------------------------------------------------------------
/Img/myAutomap_7500im80_ep200_lr000002_nonnorm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetianadadakova/MRI-CNN/HEAD/Img/myAutomap_7500im80_ep200_lr000002_nonnorm.png

--------------------------------------------------------------------------------
/Img/myAutomap_7500im80_ep200_lr000002_nonnorm_brain.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tetianadadakova/MRI-CNN/HEAD/Img/myAutomap_7500im80_ep200_lr000002_nonnorm_brain.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
This is my implementation of the AUTOMAP algorithm described in the following paper:
B. Zhu, J. Z. Liu, B. R. Rosen, and M. S. Rosen. Image reconstruction by domain transform manifold learning. arXiv preprint arXiv:1704.08841, 2017.
https://arxiv.org/abs/1704.08841

NB1: I run the code on an AWS cluster, using the Deep Learning AMI (Ubuntu) on a p3.2xlarge instance. In addition, I use the CPU's memory to initialize the second fully-connected layer for 128x128 images (otherwise there is a memory error).
NB2: I use the following Python package to download images from ImageNet: imagenetscraper 1.0 (https://goo.gl/QK6f8p)

I encourage you to contact me if you have any questions, comments, or suggestions: tetiana.d@gmail.com.


The code uses data in image space and the corresponding frequency space to teach a CNN model to reconstruct an MRI image. The architecture consists of fully-connected (FC) and convolutional (Conv) layers and is the following:
FC1 -> tanh activation -> FC2 -> tanh activation -> Conv1 -> ReLU activation -> Conv2 -> ReLU activation -> de-Conv
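
For concreteness, the tensor shapes through the network for 64x64 inputs (the size used by generate_input.py; generate_input_motion.py pads images to 80x80) are:

```
input k-space          : (n, 64, 64, 2)   real and imaginary channels
flatten                : (n, 8192)
FC1 + tanh             : (n, 4096)        4096 = 64 * 64
FC2 + tanh             : (n, 4096)        reshaped to (n, 64, 64, 1)
Conv1 (5x5, 64) + ReLU : (n, 64, 64, 64)
Conv2 (5x5, 64) + ReLU : (n, 64, 64, 64)
de-Conv (7x7)          : (n, 64, 64, 1)   squeezed to (n, 64, 64)
```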

**generate_input.py**
This includes the function load_images_from_folder, which creates the training set for the model. It loads images into array Y, performs a Fourier transform of each image, and saves the real and imaginary parts into array X.
Optional normalization of the data and rotation of the input images are available.

**generate_input_motion.py**
This includes the function load_images_from_folder, which creates the training set for the model.
It loads images into array Y, "moves" each image by 8 pixels, performs a Fourier transform, and combines the frequency space of the two versions of the image (before and after the move) - as if the patient had moved by 8 pixels in one direction in the middle of the acquisition. The function then saves the real and imaginary parts of the motion-corrupted frequency space into array X.
Optional normalization of the data and rotation of the input images are available.

**myAutomap.py**
This includes the CNN model, implemented in TensorFlow.

**myAutomap_recon.py**
Uses forward propagation to reconstruct an image from frequency space using the trained model saved by myAutomap.py.
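
The intended workflow across these scripts, as a minimal sketch (the folder path is a placeholder; the image count matches the run below):

```python
from generate_input_motion import load_images_from_folder

# Motion-corrupted k-space (X) and ground-truth images (Y);
# imrotate=True augments 7500 images to 30000:
X_train, Y_train = load_images_from_folder('path/to/training/images', 7500,
                                           normalize=False, imrotate=True)

# Training: myAutomap.py defines model(X_train, Y_train, ...) and saves a
# checkpoint; myAutomap_recon.py restores it and reconstructs dev/test images.
```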

**(Very) preliminary results:**
Y_dev - original images; X_iFFT - images reconstructed from frequency space corrupted by motion (ghosting artifacts are clearly visible); Y_recon - images reconstructed using the trained model. The ghosting is gone; however, the images look very blurry, and the cost was still quite high - the training needs further optimization.

Hyperparameters: learning rate 0.00002, 7500 images (30000 after augmentation), 80x80 resolution, 200 epochs.

Example with a brain image:
![alt text](Img/myAutomap_7500im80_ep200_lr000002_nonnorm_brain.png)

--------------------------------------------------------------------------------
/generate_input.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os
# from matplotlib import pyplot as plt


def load_images_from_folder(folder, n_im, normalize=False, imrotate=False):
    """ Loads n_im images from the folder and puts them in an array bigy of
    size (n_im, im_size1, im_size2), where (im_size1, im_size2) is the image
    size.
    Performs an FFT of every input image and puts it in an array bigx of size
    (n_im, im_size1, im_size2, 2), where "2" represents the real and imaginary
    dimensions
    :param folder: path to the folder which contains the images
    :param n_im: number of images to load from the folder
    :param normalize: if True, the bigx data will be normalized
    :param imrotate: if True, each input image will also be rotated by 90,
    180, and 270 degrees
    :return:
    bigx: 4D array of frequency data of size (n_im, im_size1, im_size2, 2)
    bigy: 3D array of images of size (n_im, im_size1, im_size2)
    """

    # Initialize the arrays (input images are assumed to be 64x64):
    if imrotate:  # number of images is 4 * n_im
        bigy = np.empty((n_im * 4, 64, 64))
        bigx = np.empty((n_im * 4, 64, 64, 2))
    else:
        bigy = np.empty((n_im, 64, 64))
        bigx = np.empty((n_im, 64, 64, 2))

    im = 0  # image counter
    for filename in os.listdir(folder):
        if not filename.startswith('.'):
            bigy_temp = cv2.imread(os.path.join(folder, filename),
                                   cv2.IMREAD_GRAYSCALE)
            bigy[im, :, :] = bigy_temp
            bigx[im, :, :, :] = create_x(bigy_temp, normalize)
            im += 1
            if imrotate:
                for angle in [90, 180, 270]:
                    bigy_rot = im_rotate(bigy_temp, angle)
                    bigx_rot = create_x(bigy_rot, normalize)
                    bigy[im, :, :] = bigy_rot
                    bigx[im, :, :, :] = bigx_rot
                    im += 1

        if imrotate:
            if im > (n_im * 4 - 1):  # loaded the requested number of images
                break
        else:
            if im > (n_im - 1):  # loaded the requested number of images
                break

    if normalize:
        bigx = (bigx - np.amin(bigx)) / (np.amax(bigx) - np.amin(bigx))

    return bigx, bigy


def create_x(y, normalize=False):
    """
    Prepares frequency data from image data: applies to_freq_space,
    expands the dimensions from 3D to 4D, and normalizes if normalize=True
    :param y: input image
    :param normalize: if True, the frequency data will be normalized
    :return: frequency data, 4D array of size (1, im_size1, im_size2, 2)
    """
    x = to_freq_space(y)  # FFT: (im_size1, im_size2, 2)
    x = np.expand_dims(x, axis=0)  # (1, im_size1, im_size2, 2)
    if normalize:
        x = x - np.mean(x)

    return x


def to_freq_space(img):
    """ Performs an FFT of an image
    :param img: input 2D image
    :return: frequency-space data of the input image; the third dimension
    (size: 2) contains the real and imaginary parts
    """

    img_f = np.fft.fft2(img)  # FFT
    img_fshift = np.fft.fftshift(img_f)  # FFT shift
    img_real = img_fshift.real  # Real part: (im_size1, im_size2)
    img_imag = img_fshift.imag  # Imaginary part: (im_size1, im_size2)
    img_real_imag = np.dstack((img_real, img_imag))  # (im_size1, im_size2, 2)

    return img_real_imag


def im_rotate(img, angle):
    """ Rotates an image by angle degrees
    :param img: input image
    :param angle: angle by which the image is rotated, in degrees
    :return: rotated image
    """
    rows, cols = img.shape
    rotM = cv2.getRotationMatrix2D((cols/2-0.5, rows/2-0.5), angle, 1)
    imrotated = cv2.warpAffine(img, rotM, (cols, rows))

    return imrotated
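

# For reference, the inverse mapping (a helper added for illustration; it is
# not part of the original pipeline, but it is the counterpart of
# to_freq_space above and mirrors the manual iFFT done in myAutomap_recon.py):
def from_freq_space(img_real_imag):
    """ Inverse of to_freq_space: recovers the image magnitude from stacked
    real/imaginary frequency data
    :param img_real_imag: array of size (im_size1, im_size2, 2)
    :return: 2D image magnitude
    """
    img_fshift = img_real_imag[:, :, 0] + 1j * img_real_imag[:, :, 1]
    img_f = np.fft.ifftshift(img_fshift)  # undo the FFT shift
    img = np.fft.ifft2(img_f)  # inverse FFT
    return np.abs(img)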


'''
# For debugging: show the images and their frequency space

dir_temp = 'path to folder with images'
X, Y = load_images_from_folder(dir_temp, 5, normalize=False, imrotate=True)

print(Y.shape)
print(X.shape)


plt.subplot(221), plt.imshow(Y[12, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.subplot(222), plt.imshow(Y[13, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.subplot(223), plt.imshow(Y[14, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.subplot(224), plt.imshow(Y[15, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.show()

X_m = 20*np.log(np.sqrt(np.power(X[:, :, :, 0], 2) +
                        np.power(X[:, :, :, 1], 2)))  # Magnitude (log scale)
plt.subplot(221), plt.imshow(X_m[12, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.subplot(222), plt.imshow(X_m[13, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.subplot(223), plt.imshow(X_m[14, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.subplot(224), plt.imshow(X_m[15, :, :], cmap='gray')
plt.xticks([]), plt.yticks([])
plt.show()
'''

--------------------------------------------------------------------------------
/generate_input_motion.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os
# from matplotlib import pyplot as plt


def load_images_from_folder(folder, n_im, normalize=False, imrotate=False):
    """ Loads n_im images from the folder, zero-pads them from 64x64 to 80x80,
    and puts them in an array bigy of size (n_im, im_size1, im_size2), where
    (im_size1, im_size2) is the padded image size.
    Creates motion-corrupted frequency data for every input image (see
    create_x) and puts it in an array bigx of size
    (n_im, im_size1, im_size2, 2), where "2" represents the real and imaginary
    dimensions
    :param folder: path to the folder which contains the images
    :param n_im: number of images to load from the folder
    :param normalize: if True, the bigx data will be normalized
    :param imrotate: if True, each input image will also be rotated by 90,
    180, and 270 degrees
    :return:
    bigx: 4D array of frequency data of size (n_im, im_size1, im_size2, 2)
    bigy: 3D array of images of size (n_im, im_size1, im_size2)
    """

    # Initialize the arrays:
    if imrotate:  # number of images is 4 * n_im
        bigy = np.empty((n_im * 4, 80, 80))
        bigx = np.empty((n_im * 4, 80, 80, 2))
    else:
        bigy = np.empty((n_im, 80, 80))
        bigx = np.empty((n_im, 80, 80, 2))

    im = 0  # image counter
    for filename in os.listdir(folder):
        if not filename.startswith('.'):
            bigy_temp = cv2.imread(os.path.join(folder, filename),
                                   cv2.IMREAD_GRAYSCALE)
            bigy_padded = np.zeros((80, 80))
            bigy_padded[8:72, 8:72] = bigy_temp
            bigy[im, :, :] = bigy_padded
            bigx[im, :, :, :] = create_x(bigy_temp, normalize)
            im += 1
            if imrotate:
                for angle in [90, 180, 270]:
                    bigy_rot = im_rotate(bigy_temp, angle)
                    bigx_rot = create_x(bigy_rot, normalize)

                    bigy_rot_padded = np.zeros((80, 80))
                    bigy_rot_padded[8:72, 8:72] = bigy_rot

                    bigy[im, :, :] = bigy_rot_padded
                    bigx[im, :, :, :] = bigx_rot
                    im += 1

        if imrotate:
            if im > (n_im * 4 - 1):  # loaded the requested number of images
                break
        else:
            if im > (n_im - 1):  # loaded the requested number of images
                break

    if normalize:
        bigx = (bigx - np.amin(bigx)) / (np.amax(bigx) - np.amin(bigx))

    return bigx, bigy


def create_x(y, normalize=False):
    """
    Prepares motion-corrupted frequency data from image data: the input image
    y (64x64) is zero-padded to 80x80 at one location (y_pad_loc1); a second
    copy is created by placing the image 8 pixels higher (y_pad_loc2), so the
    same image exists at two different locations. Both copies are transformed
    to frequency space, and the two k-spaces are combined as if the image had
    moved halfway through the acquisition (the upper part of k-space comes
    from one copy and the lower part from the other).
    Expands the dimensions from 3D to 4D, and normalizes if normalize=True
    :param y: input image
    :param normalize: if True, the frequency data will be normalized
    :return: "motion-corrupted" frequency-space data of the input image,
    4D array of size (1, im_size1, im_size2, 2); the third dimension (size: 2)
    contains the real and imaginary parts
    """

    # Pad y at two locations 8 pixels apart
    y_pad_loc1 = np.zeros((80, 80))
    y_pad_loc2 = np.zeros((80, 80))
    y_pad_loc1[8:72, 8:72] = y
    y_pad_loc2[0:64, 8:72] = y

    # FFT of both images
    img_f1 = np.fft.fft2(y_pad_loc1)  # FFT
    img_fshift1 = np.fft.fftshift(img_f1)  # FFT shift
    img_f2 = np.fft.fft2(y_pad_loc2)  # FFT
    img_fshift2 = np.fft.fftshift(img_f2)  # FFT shift

    # Combine halves of both k-spaces - as if the subject had moved 8 pixels
    # in the middle of the acquisition (rows 0-40 from location 1,
    # rows 41-79 from location 2)
    x_compl = np.zeros((80, 80), dtype=np.complex_)
    x_compl[0:41, :] = img_fshift1[0:41, :]
    x_compl[41:80, :] = img_fshift2[41:80, :]

    # Finally, separate into real and imaginary channels
    x_real = x_compl.real
    x_imag = x_compl.imag
    x = np.dstack((x_real, x_imag))

    x = np.expand_dims(x, axis=0)

    if normalize:
        x = x - np.mean(x)

    return x


def im_rotate(img, angle):
    """ Rotates an image by angle degrees
    :param img: input image
    :param angle: angle by which the image is rotated, in degrees
    :return: rotated image
    """

    rows, cols = img.shape
    rotM = cv2.getRotationMatrix2D((cols/2-0.5, rows/2-0.5), angle, 1)
    imrotated = cv2.warpAffine(img, rotM, (cols, rows))

    return imrotated
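

# Background check, added for illustration (not part of the original
# pipeline): a spatial shift corresponds to a linear phase ramp in k-space,
# which is why splicing together halves of two shifted acquisitions produces
# ghosting along the shift direction.
def _demo_shift_phase_ramp():
    y = np.zeros((80, 80))
    y[30:50, 30:50] = 1.0  # synthetic test image
    f1 = np.fft.fftshift(np.fft.fft2(y))
    f2 = np.fft.fftshift(np.fft.fft2(np.roll(y, 8, axis=0)))  # circular shift
    mask = np.abs(f1) > 1e-6  # avoid dividing by (near-)zeros
    ratio = f2[mask] / f1[mask]
    # The ratio is a pure phase factor: unit magnitude everywhere
    print('unit-magnitude phase ramp:', np.allclose(np.abs(ratio), 1.0))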


'''
# For debugging: show the images and their frequency space

dir_temp = 'path to folder with images'
X, Y = load_images_from_folder(dir_temp, 5, normalize=False, imrotate=True)

print(Y.shape)
print(X.shape)

# Image
plt.subplot(221), plt.imshow(Y[8, :, :], cmap='gray')
plt.title('Y_rot0'), plt.xticks([]), plt.yticks([])
plt.subplot(222), plt.imshow(Y[9, :, :], cmap='gray')
plt.title('Y_rot90'), plt.xticks([]), plt.yticks([])
plt.subplot(223), plt.imshow(Y[10, :, :], cmap='gray')
plt.title('Y_rot180'), plt.xticks([]), plt.yticks([])
plt.subplot(224), plt.imshow(Y[11, :, :], cmap='gray')
plt.title('Y_rot270'), plt.xticks([]), plt.yticks([])
plt.show()

# Corresponding frequency space (magnitude)
X_m = np.sqrt(np.power(X[:, :, :, 0], 2)
              + np.power(X[:, :, :, 1], 2))
plt.subplot(221), plt.imshow(X_m[8, :, :], cmap='gray')
plt.title('X_freq_rot0'), plt.xticks([]), plt.yticks([])
plt.subplot(222), plt.imshow(X_m[9, :, :], cmap='gray')
plt.title('X_freq_rot90'), plt.xticks([]), plt.yticks([])
plt.subplot(223), plt.imshow(X_m[10, :, :], cmap='gray')
plt.title('X_freq_rot180'), plt.xticks([]), plt.yticks([])
plt.subplot(224), plt.imshow(X_m[11, :, :], cmap='gray')
plt.title('X_freq_rot270'), plt.xticks([]), plt.yticks([])
plt.show()


# iFFT back to image from corrupted frequency space
# (the missing ifftshift only changes the phase, not the magnitude shown)
X_compl = X[:, :, :, 0] + X[:, :, :, 1] * 1j

im_artif0 = np.fft.ifft2(X_compl[8, :, :])
im_artif1 = np.fft.ifft2(X_compl[9, :, :])
im_artif2 = np.fft.ifft2(X_compl[10, :, :])
im_artif3 = np.fft.ifft2(X_compl[11, :, :])

img_artif_M0 = np.sqrt(np.power(im_artif0.real, 2)
                       + np.power(im_artif0.imag, 2))
img_artif_M1 = np.sqrt(np.power(im_artif1.real, 2)
                       + np.power(im_artif1.imag, 2))
img_artif_M2 = np.sqrt(np.power(im_artif2.real, 2)
                       + np.power(im_artif2.imag, 2))
img_artif_M3 = np.sqrt(np.power(im_artif3.real, 2)
                       + np.power(im_artif3.imag, 2))

plt.subplot(221), plt.imshow(img_artif_M0, cmap='gray')
plt.title('X_rot0'), plt.xticks([]), plt.yticks([])
plt.subplot(222), plt.imshow(img_artif_M1, cmap='gray')
plt.title('X_rot90'), plt.xticks([]), plt.yticks([])
plt.subplot(223), plt.imshow(img_artif_M2, cmap='gray')
plt.title('X_rot180'), plt.xticks([]), plt.yticks([])
plt.subplot(224), plt.imshow(img_artif_M3, cmap='gray')
plt.title('X_rot270'), plt.xticks([]), plt.yticks([])
plt.show()
'''

--------------------------------------------------------------------------------
/myAutomap_recon.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
from matplotlib import pyplot as plt
from generate_input_motion import load_images_from_folder


# Load development/test data:
dir_dev = "path to the folder with dev/test images"
n_im_dev = 60  # How many images to load
# Load images and create motion-corrupted frequency space
# No normalization or rotations:
X_dev, Y_dev = load_images_from_folder(  # Load images for evaluating the model
    dir_dev,
    n_im_dev,
    normalize=False,
    imrotate=False)
print('X_dev.shape at input = ', X_dev.shape)
print('Y_dev.shape at input = ', Y_dev.shape)


def create_placeholders(n_H0, n_W0):
    """ Creates placeholders for x and y for tf.Session
    :param n_H0: image height
    :param n_W0: image width
    :return: x and y - tf placeholders
    """

    x = tf.placeholder(tf.float32, shape=[None, n_H0, n_W0, 2], name='x')
    y = tf.placeholder(tf.float32, shape=[None, n_H0, n_W0], name='y')

    return x, y


def initialize_parameters():
    """ Initializes filters for the convolutional and de-convolutional layers
    :return: parameters - a dictionary of filters (W1 - first convolutional
    layer, W2 - second convolutional layer, W3 - de-convolutional layer)
    """

    W1 = tf.get_variable("W1", [5, 5, 1, 64],  # 64 filters of size 5x5
                         initializer=tf.contrib.layers.xavier_initializer
                         (seed=0))
    W2 = tf.get_variable("W2", [5, 5, 64, 64],  # 64 filters of size 5x5
                         initializer=tf.contrib.layers.xavier_initializer
                         (seed=0))
    W3 = tf.get_variable("W3", [7, 7, 1, 64],  # 64 filters of size 7x7
                         initializer=tf.contrib.layers.xavier_initializer
                         (seed=0))  # conv2d_transpose

    parameters = {"W1": W1,
                  "W2": W2,
                  "W3": W3}

    return parameters


def forward_propagation(x, parameters):
    """ Defines all layers for forward propagation:
    Fully connected (FC1) -> tanh activation: size (n_im, n_H0 * n_W0)
    -> Fully connected (FC2) -> tanh activation: size (n_im, n_H0 * n_W0)
    -> Convolutional -> ReLU activation: size (n_im, n_H0, n_W0, 64)
    -> Convolutional -> ReLU activation: size (n_im, n_H0, n_W0, 64)
    -> De-convolutional: size (n_im, n_H0, n_W0)
    :param x: input - images in frequency space, size (n_im, n_H0, n_W0, 2)
    :param parameters: parameters of the layers (e.g. filters)
    :return: output of the last layer of the neural network
    """

    x_temp = tf.contrib.layers.flatten(x)  # size (n_im, n_H0 * n_W0 * 2)
    n_out = int(x.shape[1] * x.shape[2])  # size (n_im, n_H0 * n_W0)

    # FC: input size (n_im, n_H0 * n_W0 * 2), output size (n_im, n_H0 * n_W0)
    # reuse=tf.AUTO_REUSE: the variables are created here and then restored
    # from the checkpoint (reuse=True would fail, because the variables do
    # not exist yet when the graph is first built)
    FC1 = tf.contrib.layers.fully_connected(
        x_temp,
        n_out,
        activation_fn=tf.tanh,
        normalizer_fn=None,
        normalizer_params=None,
        weights_initializer=tf.contrib.layers.xavier_initializer(),
        weights_regularizer=None,
        biases_initializer=None,
        biases_regularizer=None,
        reuse=tf.AUTO_REUSE,
        variables_collections=None,
        outputs_collections=None,
        trainable=True,
        scope='fc1')

    # FC: input size (n_im, n_H0 * n_W0), output size (n_im, n_H0 * n_W0)
    FC2 = tf.contrib.layers.fully_connected(
        FC1,
        n_out,
        activation_fn=tf.tanh,
        normalizer_fn=None,
        normalizer_params=None,
        weights_initializer=tf.contrib.layers.xavier_initializer(),
        weights_regularizer=None,
        biases_initializer=None,
        biases_regularizer=None,
        reuse=tf.AUTO_REUSE,
        variables_collections=None,
        outputs_collections=None,
        trainable=True,
        scope='fc2')

    # Reshape output from FC layers into array of size (n_im, n_H0, n_W0, 1):
    FC_M = tf.reshape(FC2, [tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1])

    # Retrieve the parameters from the dictionary "parameters":
    W1 = parameters['W1']
    W2 = parameters['W2']
    W3 = parameters['W3']

    # CONV2D: filters W1, stride of 1, padding 'SAME'
    # Input size (n_im, n_H0, n_W0, 1), output size (n_im, n_H0, n_W0, 64)
    Z1 = tf.nn.conv2d(FC_M, W1, strides=[1, 1, 1, 1], padding='SAME')
    # RELU
    CONV1 = tf.nn.relu(Z1)

    # CONV2D: must match the training graph in myAutomap.py, which uses a
    # tf.layers.conv2d layer named 'conv2' for this step (the filter W2 is
    # created there but not used in the forward pass)
    # Input size (n_im, n_H0, n_W0, 64), output size (n_im, n_H0, n_W0, 64)
    CONV2 = tf.layers.conv2d(
        CONV1,
        filters=64,
        kernel_size=5,
        strides=(1, 1),
        padding='same',
        activation=tf.nn.relu,
        name='conv2',
        reuse=tf.AUTO_REUSE)

    # DE-CONV2D: filters W3, stride 1, padding 'SAME'
    # Input size (n_im, n_H0, n_W0, 64), output size (n_im, n_H0, n_W0, 1)
    batch_size = tf.shape(x)[0]
    deconv_shape = tf.stack([batch_size, x.shape[1], x.shape[2], 1])
    DECONV = tf.nn.conv2d_transpose(CONV2, W3, output_shape=deconv_shape,
                                    strides=[1, 1, 1, 1], padding='SAME')
    DECONV = tf.squeeze(DECONV)

    return DECONV


def model(X_dev):
    """ Runs forward propagation to reconstruct images using the trained model
    :param X_dev: input development frequency-space data
    :return: the parameters and the images reconstructed with the trained
    model
    """

    ops.reset_default_graph()  # to not overwrite tf variables
    (_, n_H0, n_W0, _) = X_dev.shape

    # Create placeholders
    X, Y = create_placeholders(n_H0, n_W0)

    # Initialize parameters
    parameters = initialize_parameters()

    # Build the forward propagation in the tf graph once and keep its output
    # tensor (rebuilding it after the restore would add duplicate ops)
    DECONV = forward_propagation(X, parameters)

    # Add ops to save and restore all the variables
    saver = tf.train.Saver()

    # Start the session to compute the tf graph
    with tf.Session() as sess:

        saver.restore(sess, "path to saved model/model_name.ckpt")

        print("Model restored")

        Y_recon = DECONV.eval({X: X_dev})

    return parameters, Y_recon


# Reconstruct the images using the trained model
_, Y_recon = model(X_dev)
print('Y_recon.shape = ', Y_recon.shape)
print('Y_dev.shape = ', Y_dev.shape)
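
# Optional quantitative check, added for illustration (not part of the
# original script): mean squared error and PSNR of the model reconstruction
# against the ground-truth images.
mse = np.mean((Y_recon - Y_dev) ** 2)
psnr = 10 * np.log10(np.amax(Y_dev) ** 2 / mse)
print('MSE = %.4f, PSNR = %.2f dB' % (mse, psnr))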


# Visualize the images, their reconstruction using the iFFT, and their
# reconstruction using the trained model
# 4 images to visualize:
im1 = 32
im2 = 33
im3 = 34
im4 = 35

# iFFT back to image from the corrupted frequency space
# Complex image from the real and imaginary parts
X_dev_compl = X_dev[:, :, :, 0] + X_dev[:, :, :, 1] * 1j

# iFFT
X_iFFT0 = np.fft.ifft2(X_dev_compl[im1, :, :])
X_iFFT1 = np.fft.ifft2(X_dev_compl[im2, :, :])
X_iFFT2 = np.fft.ifft2(X_dev_compl[im3, :, :])
X_iFFT3 = np.fft.ifft2(X_dev_compl[im4, :, :])

# Magnitude of the complex image
X_iFFT_M1 = np.sqrt(np.power(X_iFFT0.real, 2)
                    + np.power(X_iFFT0.imag, 2))
X_iFFT_M2 = np.sqrt(np.power(X_iFFT1.real, 2)
                    + np.power(X_iFFT1.imag, 2))
X_iFFT_M3 = np.sqrt(np.power(X_iFFT2.real, 2)
                    + np.power(X_iFFT2.imag, 2))
X_iFFT_M4 = np.sqrt(np.power(X_iFFT3.real, 2)
                    + np.power(X_iFFT3.imag, 2))

# SHOW
# Show Y - input images
plt.subplot(341), plt.imshow(Y_dev[im1, :, :], cmap='gray')
plt.title('Y_dev1'), plt.xticks([]), plt.yticks([])
plt.subplot(342), plt.imshow(Y_dev[im2, :, :], cmap='gray')
plt.title('Y_dev2'), plt.xticks([]), plt.yticks([])
plt.subplot(343), plt.imshow(Y_dev[im3, :, :], cmap='gray')
plt.title('Y_dev3'), plt.xticks([]), plt.yticks([])
plt.subplot(344), plt.imshow(Y_dev[im4, :, :], cmap='gray')
plt.title('Y_dev4'), plt.xticks([]), plt.yticks([])

# Show images reconstructed using the iFFT
plt.subplot(345), plt.imshow(X_iFFT_M1, cmap='gray')
plt.title('X_iFFT1'), plt.xticks([]), plt.yticks([])
plt.subplot(346), plt.imshow(X_iFFT_M2, cmap='gray')
plt.title('X_iFFT2'), plt.xticks([]), plt.yticks([])
plt.subplot(347), plt.imshow(X_iFFT_M3, cmap='gray')
plt.title('X_iFFT3'), plt.xticks([]), plt.yticks([])
plt.subplot(348), plt.imshow(X_iFFT_M4, cmap='gray')
plt.title('X_iFFT4'), plt.xticks([]), plt.yticks([])

# Show images reconstructed using the model
plt.subplot(349), plt.imshow(Y_recon[im1, :, :], cmap='gray')
plt.title('Y_recon1'), plt.xticks([]), plt.yticks([])
plt.subplot(3, 4, 10), plt.imshow(Y_recon[im2, :, :], cmap='gray')
plt.title('Y_recon2'), plt.xticks([]), plt.yticks([])
plt.subplot(3, 4, 11), plt.imshow(Y_recon[im3, :, :], cmap='gray')
plt.title('Y_recon3'), plt.xticks([]), plt.yticks([])
plt.subplot(3, 4, 12), plt.imshow(Y_recon[im4, :, :], cmap='gray')
plt.title('Y_recon4'), plt.xticks([]), plt.yticks([])
plt.show()

--------------------------------------------------------------------------------
/myAutomap.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
import math
import time
from generate_input import load_images_from_folder


# Load training data:
tic1 = time.time()
dir_train = 'path to the folder with images for training'  # Folder with images
n_im = 10000  # How many images to load
X_train, Y_train = load_images_from_folder(  # Load images for training
    dir_train,
    n_im,
    normalize=False,
    imrotate=True)
toc1 = time.time()
print('Time to load data = ', (toc1 - tic1))
print('X_train.shape at input = ', X_train.shape)
print('Y_train.shape at input = ', Y_train.shape)


def create_placeholders(n_H0, n_W0):
    """ Creates placeholders for x and y for tf.Session
    :param n_H0: image height
    :param n_W0: image width
    :return: x and y - tf placeholders
    """

    x = tf.placeholder(tf.float32, shape=[None, n_H0, n_W0, 2], name='x')
    y = tf.placeholder(tf.float32, shape=[None, n_H0, n_W0], name='y')

    return x, y


def initialize_parameters():
    """ Initializes filters for the convolutional and de-convolutional layers
    :return: parameters - a dictionary of filters (W1 - first convolutional
    layer, W2 - second convolutional layer, W3 - de-convolutional layer)
    """

    W1 = tf.get_variable("W1", [5, 5, 1, 64],  # 64 filters of size 5x5
                         initializer=tf.contrib.layers.xavier_initializer
                         (seed=0))
    W2 = tf.get_variable("W2", [5, 5, 64, 64],  # 64 filters of size 5x5
                         initializer=tf.contrib.layers.xavier_initializer
                         (seed=0))
    W3 = tf.get_variable("W3", [7, 7, 1, 64],  # 64 filters of size 7x7
                         initializer=tf.contrib.layers.xavier_initializer
                         (seed=0))  # conv2d_transpose

    parameters = {"W1": W1,
                  "W2": W2,
                  "W3": W3}

    return parameters


def forward_propagation(x, parameters):
    """ Defines all layers for forward propagation:
    Fully connected (FC1) -> tanh activation: size (n_im, n_H0 * n_W0)
    -> Fully connected (FC2) -> tanh activation: size (n_im, n_H0 * n_W0)
    -> Convolutional -> ReLU activation: size (n_im, n_H0, n_W0, 64)
    -> Convolutional -> ReLU activation with L1 regularization:
    size (n_im, n_H0, n_W0, 64)
    -> De-convolutional: size (n_im, n_H0, n_W0)
    :param x: input - images in frequency space, size (n_im, n_H0, n_W0, 2)
    :param parameters: parameters of the layers (e.g. filters)
    :return: output of the last layer of the neural network
    """

    x_temp = tf.contrib.layers.flatten(x)  # size (n_im, n_H0 * n_W0 * 2)
    n_out = int(x.shape[1] * x.shape[2])  # size (n_im, n_H0 * n_W0)

    # The first FC layer lives on the GPU; the second is placed on the CPU
    # because its (n_out x n_out) weight matrix does not fit into GPU memory
    # for 128x128 images (see README, NB1)
    with tf.device('/gpu:0'):
        # FC: input size (n_im, n_H0 * n_W0 * 2), output (n_im, n_H0 * n_W0)
        FC1 = tf.contrib.layers.fully_connected(
            x_temp,
            n_out,
            activation_fn=tf.tanh,
            normalizer_fn=None,
            normalizer_params=None,
            weights_initializer=tf.contrib.layers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=None,
            biases_regularizer=None,
            reuse=tf.AUTO_REUSE,
            variables_collections=None,
            outputs_collections=None,
            trainable=True,
            scope='fc1')

    with tf.device('/cpu:0'):
        # FC: input size (n_im, n_H0 * n_W0), output size (n_im, n_H0 * n_W0)
        FC2 = tf.contrib.layers.fully_connected(
            FC1,
            n_out,
            activation_fn=tf.tanh,
            normalizer_fn=None,
            normalizer_params=None,
            weights_initializer=tf.contrib.layers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=None,
            biases_regularizer=None,
            reuse=tf.AUTO_REUSE,
            variables_collections=None,
            outputs_collections=None,
            trainable=True,
            scope='fc2')

    # Reshape output from FC layers into array of size (n_im, n_H0, n_W0, 1):
    FC_M = tf.reshape(FC2, [tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], 1])

    # Retrieve the parameters from the dictionary "parameters":
    W1 = parameters['W1']
    W2 = parameters['W2']
    W3 = parameters['W3']

    # CONV2D: filters W1, stride of 1, padding 'SAME'
    # Input size (n_im, n_H0, n_W0, 1), output size (n_im, n_H0, n_W0, 64)
    Z1 = tf.nn.conv2d(FC_M, W1, strides=[1, 1, 1, 1], padding='SAME')
    # RELU
    CONV1 = tf.nn.relu(Z1)

    # CONV2D: 64 filters of size 5x5, stride 1, padding 'SAME', with L1
    # kernel regularization (the equivalent unregularized tf.nn.conv2d
    # version is kept below for reference)
    # Input size (n_im, n_H0, n_W0, 64), output size (n_im, n_H0, n_W0, 64)
    # Z2 = tf.nn.conv2d(CONV1, W2, strides=[1, 1, 1, 1], padding='SAME')
    # CONV2 = tf.nn.relu(Z2)
    # NB: the L1 term is collected in tf.GraphKeys.REGULARIZATION_LOSSES but
    # is not added to the cost computed below
    CONV2 = tf.layers.conv2d(
        CONV1,
        filters=64,
        kernel_size=5,
        strides=(1, 1),
        padding='same',
        data_format='channels_last',
        dilation_rate=(1, 1),
        activation=tf.nn.relu,
        use_bias=True,
        kernel_initializer=None,
        bias_initializer=tf.zeros_initializer(),
        kernel_regularizer=tf.contrib.layers.l1_regularizer(0.0001),
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name='conv2',
        reuse=tf.AUTO_REUSE)

    # DE-CONV2D: filters W3, stride 1, padding 'SAME'
    # Input size (n_im, n_H0, n_W0, 64), output size (n_im, n_H0, n_W0, 1)
    batch_size = tf.shape(x)[0]
    deconv_shape = tf.stack([batch_size, x.shape[1], x.shape[2], 1])
    DECONV = tf.nn.conv2d_transpose(CONV2, W3, output_shape=deconv_shape,
                                    strides=[1, 1, 1, 1], padding='SAME')
    DECONV = tf.squeeze(DECONV)

    return DECONV


def compute_cost(DECONV, Y):
    """
    Computes the cost (elementwise squared loss) between the output of
    forward propagation and the label image
    :param DECONV: output of forward propagation
    :param Y: label image
    :return: elementwise squared loss; minimize() below implicitly sums the
    elements when computing gradients
    """

    cost = tf.square(DECONV - Y)

    return cost
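

# An equivalent scalar formulation, added for illustration (not used by the
# training loop below): averaging instead of the implicit sum rescales the
# gradients, so the learning rate would need re-tuning.
def compute_cost_scalar(DECONV, Y):
    """ Reduced (scalar) version of compute_cost """
    return tf.reduce_mean(tf.square(DECONV - Y))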


def random_mini_batches(x, y, mini_batch_size=64, seed=0):
    """ Shuffles training examples and partitions them into mini-batches
    to speed up gradient descent
    :param x: input frequency-space data
    :param y: input image-space data
    :param mini_batch_size: mini-batch size
    :param seed: can be chosen to keep the random choice consistent
    :return: list of mini-batches, each a tuple (mini_batch_X, mini_batch_Y)
    of mini_batch_size training examples (the last one may be smaller)
    """

    m = x.shape[0]  # number of input images
    mini_batches = []
    np.random.seed(seed)

    # Shuffle (x, y)
    permutation = list(np.random.permutation(m))
    shuffled_X = x[permutation, :]
    shuffled_Y = y[permutation, :]

    # Partition (shuffled_X, shuffled_Y), minus the end case
    num_complete_minibatches = int(math.floor(
        m / mini_batch_size))  # number of mini-batches of size mini_batch_size

    for k in range(0, num_complete_minibatches):
        mini_batch_X = shuffled_X[k * mini_batch_size:k * mini_batch_size
                                  + mini_batch_size, :, :, :]
        mini_batch_Y = shuffled_Y[k * mini_batch_size:k * mini_batch_size
                                  + mini_batch_size, :, :]
        mini_batch = (mini_batch_X, mini_batch_Y)
        mini_batches.append(mini_batch)

    # Handle the end case (last mini-batch < mini_batch_size)
    if m % mini_batch_size != 0:
        mini_batch_X = shuffled_X[num_complete_minibatches
                                  * mini_batch_size: m, :, :, :]
        mini_batch_Y = shuffled_Y[num_complete_minibatches
                                  * mini_batch_size: m, :, :]
        mini_batch = (mini_batch_X, mini_batch_Y)
        mini_batches.append(mini_batch)

    return mini_batches
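
# Quick usage example, added for illustration (shapes are synthetic):
# >>> xs = np.zeros((10, 64, 64, 2)); ys = np.zeros((10, 64, 64))
# >>> batches = random_mini_batches(xs, ys, mini_batch_size=4, seed=0)
# >>> [b[0].shape[0] for b in batches]
# [4, 4, 2]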


def model(X_train, Y_train, learning_rate=0.0001,
          num_epochs=100, minibatch_size=64, print_cost=True):
    """ Runs the forward and backward propagation
    :param X_train: input training frequency-space data
    :param Y_train: input training image-space data
    :param learning_rate: learning rate of gradient descent
    :param num_epochs: number of epochs
    :param minibatch_size: size of a mini-batch
    :param print_cost: if True, the cost is printed every epoch, along with
    how long the epoch took to run
    :return: this function saves the model to a file. The model can then
    be used to reconstruct images from frequency space
    """

    with tf.device('/gpu:0'):
        ops.reset_default_graph()  # to not overwrite tf variables
        seed = 3
        (m, n_H0, n_W0, _) = X_train.shape

        # Create placeholders
        X, Y = create_placeholders(n_H0, n_W0)

        # Initialize parameters
        parameters = initialize_parameters()

        # Build the forward propagation in the tf graph
        DECONV = forward_propagation(X, parameters)

        # Add the cost function to the tf graph
        cost = compute_cost(DECONV, Y)

        # Backpropagation (minimize() implicitly sums the elementwise cost)
        optimizer = tf.train.RMSPropOptimizer(learning_rate).minimize(cost)

        # Initialize all the variables globally
        init = tf.global_variables_initializer()

        # Add ops to save and restore all the variables
        saver = tf.train.Saver()

    # Memory config: let the GPU memory allocation grow as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.log_device_placement = True  # uncomment to log op placement

    # Start the session to compute the tf graph
    with tf.Session(config=config) as sess:

        # Initialization
        sess.run(init)

        # Training loop
        for epoch in range(num_epochs):
            tic = time.time()

            minibatch_cost = 0.
            num_minibatches = int(m / minibatch_size)  # number of mini-batches
            seed += 1
            minibatches = random_mini_batches(X_train, Y_train,
                                              minibatch_size, seed)
            # Mini-batch loop
            for minibatch in minibatches:
                # Select a mini-batch
                (minibatch_X, minibatch_Y) = minibatch
                # Run the session to execute the optimizer and the cost;
                # temp_cost is elementwise, so average it for reporting
                _, temp_cost = sess.run(
                    [optimizer, cost],
                    feed_dict={X: minibatch_X, Y: minibatch_Y})

                cost_mean = np.mean(temp_cost) / num_minibatches
                minibatch_cost += cost_mean

            # Print the cost every epoch
            if print_cost:
                toc = time.time()
                print('EPOCH = ', epoch, 'COST = ', minibatch_cost,
                      'Elapsed time = ', (toc - tic))

        # Save the variables to disk
        save_path = saver.save(sess, "path to save model/model_name.ckpt")
        print("Model saved in file: %s" % save_path)


# Finally, run the model!
model(X_train, Y_train,
      learning_rate=0.00002,
      num_epochs=30,
      minibatch_size=64,  # should be smaller than the number of input examples
      print_cost=True)

--------------------------------------------------------------------------------