├── MNIST_files
    ├── t10k-images-idx3-ubyte
    ├── t10k-labels-idx1-ubyte
    ├── train-images-idx3-ubyte
    └── train-labels-idx1-ubyte
├── README.md
├── convolutional_autoencoder.py
├── mnist.py
├── models
    ├── __init__.py
    ├── layers.py
    └── model.py
└── saver
    ├── checkpoint
    ├── cnn-100000.data-00000-of-00001
    ├── cnn-100000.index
    └── cnn-100000.meta


/MNIST_files/t10k-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Seratna/TensorFlow-Convolutional-AutoEncoder/131e19c657ea23e3c0845fb24e15dae9b0c2917a/MNIST_files/t10k-images-idx3-ubyte
--------------------------------------------------------------------------------
/MNIST_files/t10k-labels-idx1-ubyte:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/MNIST_files/train-images-idx3-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Seratna/TensorFlow-Convolutional-AutoEncoder/131e19c657ea23e3c0845fb24e15dae9b0c2917a/MNIST_files/train-images-idx3-ubyte
--------------------------------------------------------------------------------
/MNIST_files/train-labels-idx1-ubyte:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Seratna/TensorFlow-Convolutional-AutoEncoder/131e19c657ea23e3c0845fb24e15dae9b0c2917a/MNIST_files/train-labels-idx1-ubyte
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TensorFlow Convolutional AutoEncoder

This project provides utilities to build a deep Convolutional AutoEncoder (CAE) in just a few lines of code.

This project is based only on TensorFlow.


## Experiments

`convolutional_autoencoder.py` shows an example of a CAE for the MNIST dataset.

The structure of this conv autoencoder is shown below:

![autoencoder structure](https://cloud.githubusercontent.com/assets/13087207/23317657/540f170a-fa9d-11e6-9bcb-8b529a805a9f.png)

The encoding part has 2 convolution layers (each followed by a max-pooling layer) and a fully connected layer. This part
encodes an input image into a 20-dimensional vector (representation). The decoding part, which has 1 fully connected layer
and 2 transposed-convolution layers (each preceded by an un-pooling layer), decodes the representation back to a 28x28 image (reconstruction).

Training was done on a GTX 1070 GPU, with batch size 100, for 100000 passes.

Trained weights (saved in the saver directory) of the 1st convolutional layer are shown below:
![conv_1_weights](https://cloud.githubusercontent.com/assets/13087207/23318050/e4ae8006-fa9e-11e6-8687-c1b732241136.png)

And here are some of the reconstruction results:
![reconstructions](https://cloud.githubusercontent.com/assets/13087207/23318055/e717e6e8-fa9e-11e6-91b4-f4bed411c5b8.png)

## Implementation

### Un-pooling

Since the max-pooling operation is not injective, and TensorFlow does not have a built-in unpooling method,
we have to implement our own approximation.
But it is actually easy to do so using TensorFlow's [`tf.nn.conv2d_transpose()`](https://www.tensorflow.org/api_docs/python/nn/convolution#conv2d_transpose) method.

The idea is to replace each entry in the pooled map with an NxM block that has the original entry in the upper left and zeros elsewhere,
where N and M are the height and width of the pooling kernel.

![un-pooling](https://cloud.githubusercontent.com/assets/13087207/22672037/77e521c6-ec9f-11e6-9aba-119f954cd9f8.png)

This is equivalent to applying the transpose of conv2d to the input map
with a kernel that has 1 in the upper left and 0 elsewhere, using strides equal to the kernel shape.
Therefore we can do this trick with the `tf.nn.conv2d_transpose()` method.
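
As a rough illustration of the un-pooling rule described above, here is a minimal pure-Python sketch on a single-channel 2-D map. The `unpool` helper is hypothetical and is not part of this project; it only mimics, without TensorFlow, what the project's `UnPooling` layer is described as doing.

```python
def unpool(pooled, kernel_rows, kernel_cols):
    """Place each entry of `pooled` in the upper-left corner of a
    kernel_rows x kernel_cols block; every other cell of the block is 0.

    Hypothetical helper for illustration only (plain nested lists, one channel).
    """
    out_rows = len(pooled) * kernel_rows
    out_cols = len(pooled[0]) * kernel_cols
    out = [[0] * out_cols for _ in range(out_rows)]
    for i, row in enumerate(pooled):
        for j, value in enumerate(row):
            # upper-left corner of the block that belongs to entry (i, j)
            out[i * kernel_rows][j * kernel_cols] = value
    return out


unpooled = unpool([[1, 2], [3, 4]], 2, 2)
for row in unpooled:
    print(row)
# [1, 0, 2, 0]
# [0, 0, 0, 0]
# [3, 0, 4, 0]
# [0, 0, 0, 0]
```

The result has the same layout as the 1x4x4x1 example in the `UnPooling` docstring in `models/layers.py`.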
--------------------------------------------------------------------------------
/convolutional_autoencoder.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

from models import *
from mnist import MNIST  # this is the MNIST data manager that provides training/testing batches


class ConvolutionalAutoencoder(object):
    """
    Convolutional autoencoder for 28x28 MNIST images, built with the TensorFlow 1.x graph API.
    """
    def __init__(self):
        """
        build the graph
        """
        # place holder of input data
        x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])  # [#batch, img_height, img_width, #channels]

        # encode
        conv1 = Convolution2D([5, 5, 1, 32], activation=tf.nn.relu, scope='conv_1')(x)
        pool1 = MaxPooling(kernel_shape=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', scope='pool_1')(conv1)
        conv2 = Convolution2D([5, 5, 32, 32], activation=tf.nn.relu, scope='conv_2')(pool1)
        pool2 = MaxPooling(kernel_shape=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', scope='pool_2')(conv2)
        unfold = Unfold(scope='unfold')(pool2)
        encoded = FullyConnected(20, activation=tf.nn.relu, scope='encode')(unfold)
        # decode
        decoded = FullyConnected(7*7*32, activation=tf.nn.relu, scope='decode')(encoded)
        fold = Fold([-1, 7, 7, 32], scope='fold')(decoded)
        unpool1 = UnPooling((2, 2), output_shape=tf.shape(conv2), scope='unpool_1')(fold)
        deconv1 = DeConvolution2D([5, 5, 32, 32], output_shape=tf.shape(pool1), activation=tf.nn.relu, scope='deconv_1')(unpool1)
        unpool2 = UnPooling((2, 2), output_shape=tf.shape(conv1), scope='unpool_2')(deconv1)
        reconstruction = DeConvolution2D([5, 5, 1, 32], output_shape=tf.shape(x), activation=tf.nn.sigmoid, scope='deconv_2')(unpool2)

        # loss function
        loss = tf.nn.l2_loss(x - reconstruction)  # L2 loss

        # training
        training = tf.train.AdamOptimizer(1e-4).minimize(loss)

        # keep references to the graph nodes used for training and reconstruction
        self.x = x
        self.reconstruction = reconstruction
        self.loss = loss
        self.training = training

    def train(self, batch_size, passes, new_training=True):
        """
        Train the autoencoder on random batches of the MNIST training set.

        :param batch_size: number of images per training batch
        :param passes: number of training steps to run
        :param new_training: if True, initialize the variables from scratch;
                             otherwise resume from the latest checkpoint in saver/
        :return: None
        """
        mnist = MNIST()

        with tf.Session() as sess:
            # prepare session
            if new_training:
                saver, global_step = Model.start_new_session(sess)
            else:
                saver, global_step = Model.continue_previous_session(sess, ckpt_file='saver/checkpoint')

            # start training
            for step in range(1+global_step, 1+passes+global_step):
                x, y = mnist.get_batch(batch_size)
                self.training.run(feed_dict={self.x: x})

                if step % 10 == 0:
                    loss = self.loss.eval(feed_dict={self.x: x})
                    print("pass {}, training loss {}".format(step, loss))

                if step % 1000 == 0:  # save weights
                    saver.save(sess, 'saver/cnn', global_step=step)
                    print('checkpoint saved')

    def reconstruct(self):
        """
        Restore the latest checkpoint, then plot the first-layer weights
        and the reconstructions of a batch of test images.
        """
        def weights_to_grid(weights, rows, cols):
            """convert the weights tensor into a grid for visualization"""
            height, width, in_channel, out_channel = weights.shape
            padded = np.pad(weights, [(1, 1), (1, 1), (0, 0), (0, rows * cols - out_channel)],
                            mode='constant', constant_values=0)
            transposed = padded.transpose((3, 1, 0, 2))
            reshaped = transposed.reshape((rows, -1))
            grid_rows = [row.reshape((-1, height + 2, in_channel)).transpose((1, 0, 2)) for row in reshaped]
            grid = np.concatenate(grid_rows, axis=0)

            return grid.squeeze()

        mnist = MNIST()

        with tf.Session() as sess:
            saver, global_step = Model.continue_previous_session(sess, ckpt_file='saver/checkpoint')

            # visualize weights
            first_layer_weights = tf.get_default_graph().get_tensor_by_name("conv_1/kernel:0").eval()
            grid_image = weights_to_grid(first_layer_weights, 4, 8)

            fig, ax0 = plt.subplots(ncols=1, figsize=(8, 4))
            ax0.imshow(grid_image, cmap=plt.cm.gray, interpolation='nearest')
            ax0.set_title('first conv layers weights')
            plt.show()

            # visualize results
            batch_size = 36
            x, y = mnist.get_batch(batch_size, dataset='testing')
            org, recon = sess.run((self.x, self.reconstruction), feed_dict={self.x: x})

            input_images = weights_to_grid(org.transpose((1, 2, 3, 0)), 6, 6)
            recon_images = weights_to_grid(recon.transpose((1, 2, 3, 0)), 6, 6)

            fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 5))
            ax0.imshow(input_images, cmap=plt.cm.gray, interpolation='nearest')
            ax0.set_title('input images')
            ax1.imshow(recon_images, cmap=plt.cm.gray, interpolation='nearest')
            ax1.set_title('reconstructed images')
            plt.show()


def main():
    conv_autoencoder = ConvolutionalAutoencoder()
    # conv_autoencoder.train(batch_size=100, passes=100000, new_training=True)
    conv_autoencoder.reconstruct()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
from matplotlib import pyplot as plt
import numpy as np
import os


class MNIST(object):
    """
    Prepare data batches for training and testing.

    Data files are from http://yann.lecun.com/exdb/mnist/

    The training set contains 60000 examples, and the test set 10000 examples.
    The first 5000 examples of the test set are taken from the original NIST training set.
    The last 5000 are taken from the original NIST test set.
    The first 5000 are cleaner and easier than the last 5000.
    each image is 28x28 pixels
    """
    def __init__(self):
        self.training_images = None
        self.training_labels = None

        self.testing_images = None
        self.testing_labels = None

    def read_file(self, file_name) -> np.ndarray:
        """
        read images or labels from idx files.

        All the data are stored in idx format.
        All the integers in the files are stored in the MSB first format used by most non-Intel processors.
        Pixel values in the raw data are 0~255 and are returned unchanged;
        get_batch() normalizes the images to 0~1.
        A trailing axis of size 1 is appended to the shape (the channel axis for images).
        """
        def bytes2int(input_bytes):
            return int.from_bytes(input_bytes, byteorder='big')

        with open(file_name, 'rb') as file:
            first2bytes = bytes2int(file.read(2))
            assert first2bytes == 0  # The first 2 bytes are always 0

            data_type = bytes2int(file.read(1))
            assert data_type == 8  # 0x08: unsigned byte

            num_dimensions = bytes2int(file.read(1))
            shape = [bytes2int(file.read(4)) for _ in range(num_dimensions)] + [1]  # make the num of channels to be 1

            matrix = np.frombuffer(file.read(), dtype=np.uint8).reshape(shape)

        return matrix

    def vectorize(self, labels):
        """
        convert labels into 1-hot vectors
        """
        num_labels = labels.shape[0]
        num_classes = 10

        keys = np.zeros((num_labels, num_classes))
        for key, label in zip(keys, labels):
            key[label] = 1
        return keys

    def get_batch(self, batch_size, dataset='training'):
        """
        get a batch of images and corresponding labels.

        returned images have the shape of (batch_size, 28, 28, 1), with values normalized to 0~1;
        returned labels have the shape of (batch_size, 10)

        :param batch_size: number of samples to draw (sampled at random, with replacement)
        :param dataset: 'training' or 'testing'
        """
        if dataset == 'training':
            if self.training_images is None or self.training_labels is None:
                self.training_images = self.read_file('MNIST_files/train-images-idx3-ubyte') / 255
                self.training_labels = self.vectorize(self.read_file('MNIST_files/train-labels-idx1-ubyte'))
            images = self.training_images
            labels = self.training_labels
        elif dataset == 'testing':
            if self.testing_images is None or self.testing_labels is None:
                self.testing_images = self.read_file('MNIST_files/t10k-images-idx3-ubyte') / 255
                self.testing_labels = self.vectorize(self.read_file('MNIST_files/t10k-labels-idx1-ubyte'))
            images = self.testing_images
            labels = self.testing_labels
        else:
            # fail loudly instead of returning None, which callers would try to unpack
            raise ValueError("dataset must be 'training' or 'testing'")

        num_samples = labels.shape[0]
        idx = np.random.randint(num_samples, size=batch_size)

        return images[idx], labels[idx]


def main():
    mnist = MNIST()
    images, labels = mnist.get_batch(10, 'training')

    for im, lb in zip(images, labels):
        # drop the channel axis: (28, 28, 1) -> (28, 28)
        plt.imshow(im.squeeze(), cmap=plt.cm.gray, interpolation='nearest')
        plt.text(1, 1, lb, color='w')
        plt.show()


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from models.layers import Convolution2D, DeConvolution2D, MaxPooling, UnPooling, Fold, Unfold, FullyConnected
from models.model import Model

--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
from abc import ABCMeta, abstractmethod

import tensorflow as tf
import numpy as np


class Layer(object, metaclass=ABCMeta):
    """
    Abstract base class: calling a layer instance forwards to its call() method.
    """
    def __init__(self):
        pass

    @abstractmethod
    def call(self, *args, **kwargs):
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)


class Convolution2D(Layer):
    """
    2-D convolution followed by a bias add and an optional activation.
    """
    def __init__(self,
                 kernel_shape,
                 kernel=None,
                 bias=None,
                 strides=(1, 1, 1, 1),
                 padding='SAME',
                 activation=None,
                 scope=''):
        Layer.__init__(self)

        self.kernel_shape = kernel_shape
        self.kernel = kernel
        self.bias = bias
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.scope = scope

    def build(self, input_tensor):
        # build kernel (compare against None: the truth value of a tensor is not a reliable test)
        if self.kernel is not None:
            assert self.kernel.get_shape() == self.kernel_shape
        else:
            self.kernel = tf.Variable(tf.truncated_normal(self.kernel_shape, stddev=0.1), name='kernel')

        # build bias
        kernel_height, kernel_width, num_input_channels, num_output_channels = self.kernel.get_shape()
        if self.bias is not None:
            assert self.bias.get_shape() == (num_output_channels, )
        else:
            self.bias = tf.Variable(tf.constant(0.1, shape=[num_output_channels]), name='bias')

        # convolution
        conv = tf.nn.conv2d(input_tensor, self.kernel, strides=self.strides, padding=self.padding)

        # activation
        if self.activation:
            return self.activation(conv + self.bias)
        return conv + self.bias

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


class DeConvolution2D(Layer):
    """
    Transposed 2-D convolution followed by a bias add and an optional activation.
    """
    def __init__(self,
                 kernel_shape,
                 output_shape,
                 kernel=None,
                 bias=None,
                 strides=(1, 1, 1, 1),
                 padding='SAME',
                 activation=None,
                 scope=''):
        Layer.__init__(self)

        self.kernel_shape = kernel_shape
        self.output_shape = output_shape
        self.kernel = kernel
        self.bias = bias
        self.strides = strides
        self.padding = padding
        self.activation = activation
        self.scope = scope

    def build(self, input_tensor):
        # build kernel (compare against None: the truth value of a tensor is not a reliable test)
        if self.kernel is not None:
            assert self.kernel.get_shape() == self.kernel_shape
        else:
            self.kernel = tf.Variable(tf.truncated_normal(self.kernel_shape, stddev=0.1), name='kernel')

        # build bias
        window_height, window_width, num_output_channels, num_input_channels = self.kernel.get_shape()
        if self.bias is not None:
            assert self.bias.get_shape() == (num_output_channels, )
        else:
            self.bias = tf.Variable(tf.constant(0.1, shape=[num_output_channels]), name='bias')

        # transposed convolution
        deconv = tf.nn.conv2d_transpose(input_tensor,
                                        self.kernel,
                                        output_shape=self.output_shape,
                                        strides=self.strides,
                                        padding=self.padding)

        # activation
        if self.activation:
            return self.activation(deconv + self.bias)
        return deconv + self.bias

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


class MaxPooling(Layer):
    """
    Max-pooling layer (a thin wrapper around tf.nn.max_pool).
    """
    def __init__(self,
                 kernel_shape,
                 strides,
                 padding,
                 scope=''):
        Layer.__init__(self)

        self.kernel_shape = kernel_shape
        self.strides = strides
        self.padding = padding
        self.scope = scope

    def build(self, input_tensor):
        return tf.nn.max_pool(input_tensor, ksize=self.kernel_shape, strides=self.strides, padding=self.padding)

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


class UnPooling(Layer):
    """
    Unpool a max-pooled layer.

    Currently this method does not use the argmax information from the previous pooling layer.
    Currently this method assumes that the size of the max-pooling filter is the same as the strides.

    Each entry in the pooled map would be replaced with an NxN kernel with the original entry in the upper left.
    For example: a 1x2x2x1 map of

        [[[[1], [2]],
          [[3], [4]]]]

    could be unpooled to a 1x4x4x1 map of

        [[[[ 1.], [ 0.], [ 2.], [ 0.]],
          [[ 0.], [ 0.], [ 0.], [ 0.]],
          [[ 3.], [ 0.], [ 4.], [ 0.]],
          [[ 0.], [ 0.], [ 0.], [ 0.]]]]
    """
    def __init__(self,
                 kernel_shape,
                 output_shape,
                 scope=''):
        Layer.__init__(self)

        self.kernel_shape = kernel_shape
        self.output_shape = output_shape
        self.scope = scope

    def build(self, input_tensor):
        num_channels = input_tensor.get_shape()[-1]
        input_dtype_as_numpy = input_tensor.dtype.as_numpy_dtype  # a property that returns the NumPy type
        kernel_rows, kernel_cols = self.kernel_shape

        # build kernel: identity across channels at the upper-left position, zeros elsewhere
        kernel_value = np.zeros((kernel_rows, kernel_cols, num_channels, num_channels), dtype=input_dtype_as_numpy)
        kernel_value[0, 0, :, :] = np.eye(num_channels, num_channels)
        kernel = tf.constant(kernel_value)

        # do the un-pooling using conv2d_transpose
        unpool = tf.nn.conv2d_transpose(input_tensor,
                                        kernel,
                                        output_shape=self.output_shape,
                                        strides=(1, kernel_rows, kernel_cols, 1),
                                        padding='VALID')
        # TODO test!!!
        return unpool

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


class Unfold(Layer):
    """
    Flatten a [batch, height, width, channels] tensor into [batch, height * width * channels].
    """
    def __init__(self,
                 scope=''):
        Layer.__init__(self)

        self.scope = scope

    def build(self, input_tensor):
        num_batch, height, width, num_channels = input_tensor.get_shape()

        return tf.reshape(input_tensor, [-1, (height * width * num_channels).value])

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


class Fold(Layer):
    """
    Reshape a tensor to fold_shape (the inverse of Unfold).
    """
    def __init__(self,
                 fold_shape,
                 scope=''):
        Layer.__init__(self)

        self.fold_shape = fold_shape
        self.scope = scope

    def build(self, input_tensor):
        return tf.reshape(input_tensor, self.fold_shape)

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


class FullyConnected(Layer):
    """
    Fully connected layer: matmul with the weights, bias add, and an optional activation.
    """
    def __init__(self,
                 output_dim,
                 weights=None,
                 bias=None,
                 activation=None,
                 scope=''):
        Layer.__init__(self)

        self.output_dim = output_dim
        self.weights = weights
        self.bias = bias
        self.activation = activation
        self.scope = scope

    def build(self, input_tensor):
        num_batch, input_dim = input_tensor.get_shape()

        # build weights (compare against None: the truth value of a tensor is not a reliable test)
        if self.weights is not None:
            assert self.weights.get_shape() == (input_dim.value, self.output_dim)
        else:
            self.weights = tf.Variable(tf.truncated_normal((input_dim.value, self.output_dim), stddev=0.1),
                                       name='weights')

        # build bias
        if self.bias is not None:
            assert self.bias.get_shape() == (self.output_dim, )
        else:
            self.bias = tf.Variable(tf.constant(0.1, shape=[self.output_dim]), name='bias')

        # fully connected layer
        fc = tf.matmul(input_tensor, self.weights) + self.bias

        # activation
        if self.activation:
            return self.activation(fc)
        return fc

    def call(self, input_tensor):
        if self.scope:
            with tf.variable_scope(self.scope) as scope:
                return self.build(input_tensor)
        else:
            return self.build(input_tensor)


def main():
    conv = Convolution2D([5, 5, 1, 32])


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
import tensorflow as tf


class Model(object):
    """
    Helpers for starting a fresh training session or resuming from a saved checkpoint.
    """

    def __init__(self):
        pass

    @staticmethod
    def start_new_session(sess):
        saver = tf.train.Saver()  # create a saver
        global_step = 0

        sess.run(tf.global_variables_initializer())
        print('started a new session')

        return saver, global_step

    @staticmethod
    def continue_previous_session(sess, ckpt_file):
        saver = tf.train.Saver()  # create a saver

        with open(ckpt_file) as file:  # read checkpoint file
            line = file.readline()  # read the first line, which contains the file name of the latest checkpoint
            ckpt = line.split('"')[1]
            global_step = int(ckpt.split('-')[1])

        # restore
        saver.restore(sess, 'saver/'+ckpt)
        print('restored from checkpoint ' + ckpt)

        return saver, global_step


def main():
    pass


if __name__ == '__main__':
    main()

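The checkpoint lookup in `continue_previous_session` relies on the first line of `saver/checkpoint` having the form `model_checkpoint_path: "cnn-100000"`. As a minimal standalone sketch of that parsing step: `parse_checkpoint_line` is a hypothetical helper, not part of this repository, and it uses `rsplit` rather than `split` on the assumption that a checkpoint prefix might itself contain a `-`.

```python
def parse_checkpoint_line(line):
    """Return (checkpoint_name, global_step) from a line such as
    'model_checkpoint_path: "cnn-100000"'.

    Hypothetical helper for illustration only.
    """
    name = line.split('"')[1]            # text between the first pair of double quotes
    step = int(name.rsplit('-', 1)[1])   # digits after the last '-'
    return name, step


name, step = parse_checkpoint_line('model_checkpoint_path: "cnn-100000"')
print(name, step)  # cnn-100000 100000
```

For the checkpoint file shipped in `saver/`, this yields the same name and step that `continue_previous_session` extracts.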
--------------------------------------------------------------------------------
/saver/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "cnn-100000"
all_model_checkpoint_paths: "cnn-96000"
all_model_checkpoint_paths: "cnn-97000"
all_model_checkpoint_paths: "cnn-98000"
all_model_checkpoint_paths: "cnn-99000"
all_model_checkpoint_paths: "cnn-100000"

--------------------------------------------------------------------------------
/saver/cnn-100000.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Seratna/TensorFlow-Convolutional-AutoEncoder/131e19c657ea23e3c0845fb24e15dae9b0c2917a/saver/cnn-100000.data-00000-of-00001
--------------------------------------------------------------------------------
/saver/cnn-100000.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Seratna/TensorFlow-Convolutional-AutoEncoder/131e19c657ea23e3c0845fb24e15dae9b0c2917a/saver/cnn-100000.index
--------------------------------------------------------------------------------
/saver/cnn-100000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Seratna/TensorFlow-Convolutional-AutoEncoder/131e19c657ea23e3c0845fb24e15dae9b0c2917a/saver/cnn-100000.meta
--------------------------------------------------------------------------------