├── figures
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   ├── 8.png
│   ├── 9.png
│   ├── f0.png
│   ├── f1.png
│   ├── f2.png
│   └── loss.svg
├── source
│   ├── connector.py
│   ├── utils.py
│   ├── datamanager.py
│   └── tf_process.py
├── LICENSE
├── README.md
├── run.py
└── neuralnet
    └── net00_cnn.py

/figures/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/0.png
--------------------------------------------------------------------------------

/figures/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/1.png
--------------------------------------------------------------------------------

/figures/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/2.png
--------------------------------------------------------------------------------

/figures/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/3.png
--------------------------------------------------------------------------------

/figures/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/4.png
--------------------------------------------------------------------------------

/figures/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/5.png
--------------------------------------------------------------------------------

/figures/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/6.png
--------------------------------------------------------------------------------

/figures/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/7.png
--------------------------------------------------------------------------------

/figures/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/8.png
--------------------------------------------------------------------------------

/figures/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/9.png
--------------------------------------------------------------------------------

/figures/f0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/f0.png
--------------------------------------------------------------------------------

/figures/f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/f1.png
--------------------------------------------------------------------------------

/figures/f2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/f2.png
--------------------------------------------------------------------------------

/source/connector.py:
--------------------------------------------------------------------------------
def connect(nn):
    """Select and return the neural-network module by index."""
    if(nn == 0): import neuralnet.net00_cnn as nn

    return nn
--------------------------------------------------------------------------------

/source/utils.py:
--------------------------------------------------------------------------------
import os, glob, shutil

def make_dir(path, refresh=False):
    """Create a directory; if it already exists and refresh is set, recreate it empty."""
    try: os.mkdir(path)
    except FileExistsError:
        if(refresh):
            shutil.rmtree(path)
            os.mkdir(path)

def sorted_list(path):
    """Return the glob matches of 'path' in sorted order."""
    tmplist = glob.glob(path)
    tmplist.sort()

    return tmplist

def min_max_norm(x):
    """Scale x into [0, 1]; the epsilon guards against division by zero."""
    min_x, max_x = x.min(), x.max()
    return (x - min_x + 1e-12) / (max_x - min_x + 1e-12)
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 YeongHyeon

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
[TensorFlow] Attention mechanism with MNIST dataset
=====

## Usage
``` sh
$ python run.py
```
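All hyperparameters are optional flags with the defaults defined in `run.py` (`--gpu`, `--nn`, `--ksize`, `--lr`, `--batch`, `--epochs`). For example, to train for 50 epochs with batch size 64 on GPU 0:
``` sh
$ python run.py --gpu 0 --batch 64 --epochs 50
```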
## Result

### Training
<div align="center">
  <img src="./figures/loss.svg">
  <p>Loss graph.</p>
</div>
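The per-step cross-entropy is also written as a TensorFlow summary into the checkpoint directory (see `Agent.step` in `neuralnet/net00_cnn.py`), so the curve above can be monitored live; `Checkpoint` is the `path_ckpt` value set in `run.py`:
``` sh
$ tensorboard --logdir Checkpoint
```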
### Test
<div align="center">
  <img src="./figures/0.png">
  <img src="./figures/1.png">
  <img src="./figures/2.png">
  <img src="./figures/3.png">
  <img src="./figures/4.png">
  <img src="./figures/5.png">
  <img src="./figures/6.png">
  <img src="./figures/7.png">
  <img src="./figures/8.png">
  <img src="./figures/9.png">
  <p>Each figure shows the input digit, the attention map, and the overlaid image, in that order.</p>
</div>
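For reference, the attention map shown above is produced inside `neuralnet/net00_cnn.py`: after the second pooling stage, a 1x1 convolution reduces the 64-channel features to a single channel, `wblu.attention` converts that score into a map, and the features are gated by elementwise multiplication. A minimal sketch in plain TensorFlow, assuming `wblu.attention` applies a sigmoid (an assumption about the whiteboxlayer package, which may differ):
``` python
import tensorflow as tf

def attention_gate(feat, w_1x1):
    # Sketch of the gating step in Neuralnet.__nn (neuralnet/net00_cnn.py).
    # feat:  (N, 7, 7, 64) feature map after the second pooling stage.
    # w_1x1: (1, 1, 64, 1) kernel of the 1x1 scoring convolution.
    score = tf.nn.conv2d(feat, w_1x1, strides=1, padding='SAME')  # (N, 7, 7, 1)
    attn = tf.sigmoid(score)  # assumed squashing; wblu.attention may differ
    return feat * attn, attn  # gated features and the attention map
```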
### Further usage
<div align="center">
  <img src="./figures/f0.png">
  <img src="./figures/f1.png">
  <img src="./figures/f2.png">
  <p>Further usage: the location of digits can be detected using the attention map, as sketched below.</p>
</div>
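A minimal sketch of that localization idea, reusing `min_max_norm` from `source/utils.py`; the helper name and the 0.5 threshold are illustrative choices, not part of this repository:
``` python
import numpy as np
import source.utils as utils

def locate_digit(attn_map, threshold=0.5):
    # Hypothetical helper: bounding box (top, left, bottom, right) of the
    # high-attention region, or None if nothing passes the cut-off.
    norm = utils.min_max_norm(attn_map)       # scale the map into [0, 1]
    rows, cols = np.where(norm >= threshold)  # positions above the cut-off
    if rows.size == 0:
        return None
    return rows.min(), cols.min(), rows.max(), cols.max()
```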
## Requirements
* TensorFlow 2.3.0
* Numpy 1.18.5
* The sources additionally import scipy, matplotlib, scikit-learn, and whiteboxlayer.

## Additional Resources
[1] Simple attention mechanism test by Myung Jin Kim
--------------------------------------------------------------------------------

/run.py:
--------------------------------------------------------------------------------
import argparse, time, os

import tensorflow as tf
import source.connector as con
import source.tf_process as tfp
import source.datamanager as dman

def main():

    # Restrict TensorFlow to the requested GPU and enable memory growth.
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
        except RuntimeError as e:
            print(e)

    dataset = dman.Dataset()

    agent = con.connect(nn=FLAGS.nn).Agent(\
        dim_h=dataset.height, \
        dim_w=dataset.width, \
        dim_c=dataset.channel, \
        num_class=dataset.num_class, \
        ksize=FLAGS.ksize, \
        learning_rate=FLAGS.lr, \
        path_ckpt='Checkpoint')

    time_tr = time.time()
    tfp.training(agent=agent, dataset=dataset, \
        batch_size=FLAGS.batch, epochs=FLAGS.epochs)
    time_te = time.time()
    tfp.test(agent=agent, dataset=dataset, batch_size=FLAGS.batch)
    time_fin = time.time()

    print("Time (TR): %.5f [sec]" %(time_te - time_tr))
    te_time = time_fin - time_te
    print("Time (TE): %.5f (%.5f [sec/sample])" %(te_time, te_time/dataset.num_te))

if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default="0", help='GPU index for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--nn', type=int, default=0, help='neural network index (0: CNN)')
    parser.add_argument('--ksize', type=int, default=3, help='convolution kernel size')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--batch', type=int, default=32, help='mini-batch size')
    parser.add_argument('--epochs', type=int, default=100, help='number of training epochs')

    FLAGS, unparsed = parser.parse_known_args()

    main()
--------------------------------------------------------------------------------

/source/datamanager.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import source.utils as utils
from sklearn.utils import shuffle

class Dataset(object):

    def __init__(self, normalize=True):

        print("\nInitializing Dataset...")

        self.normalize = normalize

        (x_tr, y_tr), (x_te, y_te) = tf.keras.datasets.mnist.load_data()
        self.x_tr, self.y_tr = x_tr, y_tr
        self.x_te, self.y_te = x_te, y_te

        self.x_tr = np.ndarray.astype(self.x_tr, np.float32)
        self.x_te = np.ndarray.astype(self.x_te, np.float32)

        if(self.normalize): self.__normalizing()

        self.num_tr, self.num_te = self.x_tr.shape[0], self.x_te.shape[0]
        self.idx_tr, self.idx_te = 0, 0

        # Infer the input geometry from one sample; MNIST is 28x28, single channel.
        x_sample, y_sample = self.x_te[0], self.y_te[0]
        self.height = x_sample.shape[0]
        self.width = x_sample.shape[1]
        try: self.channel = x_sample.shape[2]
        except IndexError: self.channel = 1

        self.num_class = int(y_te.max() + 1)

    def __normalizing(self):

        for idx, _ in enumerate(self.x_tr):
            self.x_tr[idx] = utils.min_max_norm(self.x_tr[idx])

        for idx, _ in enumerate(self.x_te):
            self.x_te[idx] = utils.min_max_norm(self.x_te[idx])

    def __reset_index(self):

        self.idx_tr, self.idx_te = 0, 0

    def next_batch(self, batch_size=1, tt=0):

        # tt == 0: training set, tt == 1: test set.
        if(tt == 0):
            idx_d, num_d, data_x, data_y = self.idx_tr, self.num_tr, self.x_tr, self.y_tr
        elif(tt == 1):
            idx_d, num_d, data_x, data_y = self.idx_te, self.num_te, self.x_te, self.y_te

        batch_x, batch_y, terminator = [], [], False
        while(True):
            try:
                tmp_x, tmp_y = data_x[idx_d].copy(), data_y[idx_d].copy()
            except IndexError:
                # The set is exhausted: reshuffle the training data and
                # signal the caller that the epoch is over.
                idx_d = 0
                self.x_tr, self.y_tr = shuffle(self.x_tr, self.y_tr)
                terminator = True
                break
            else:
                batch_x.append(np.expand_dims(tmp_x, axis=-1))
                batch_y.append(np.eye(self.num_class)[tmp_y])  # one-hot label
                idx_d += 1
                if(len(batch_x) == batch_size): break

        batch_x = np.asarray(batch_x)
        batch_y = np.asarray(batch_y)

        if(tt == 0):
            self.idx_tr = idx_d
        elif(tt == 1):
            self.idx_te = idx_d

        return {'x':batch_x.astype(np.float32), 'y':batch_y.astype(np.float32), 't':terminator}
--------------------------------------------------------------------------------

/source/tf_process.py:
--------------------------------------------------------------------------------
import os
import scipy.ndimage
import numpy as np
import matplotlib.pyplot as plt

import source.utils as utils

def training(agent, dataset, batch_size, epochs):

    print("\n** Training the CNN for %d epochs | Batch size: %d" %(epochs, batch_size))
    iteration = 0

    for epoch in range(epochs):

        while(True):
            minibatch = dataset.next_batch(batch_size=batch_size, tt=0)
            if(len(minibatch['x'].shape) == 1): break  # empty batch at epoch end
            step_dict = agent.step(minibatch=minibatch, iteration=iteration, training=True)
            iteration += 1
            if(minibatch['t']): break

        print("Epoch [%d / %d] | Loss: %f" %(epoch+1, epochs, step_dict['losses']['entropy']))
        agent.save_params(model='model_0_finepoch')

def test(agent, dataset, batch_size):

    savedir = 'results_te'
    utils.make_dir(path=savedir, refresh=True)

    list_model = utils.sorted_list(os.path.join('Checkpoint', 'model*'))
    for idx_model, path_model in enumerate(list_model):
        list_model[idx_model] = os.path.basename(path_model)

    for idx_model, path_model in enumerate(list_model):

        print("\n** Test with %s" %(path_model))
        agent.load_params(model=path_model)
        utils.make_dir(path=os.path.join(savedir, path_model), refresh=False)

        saveidx = 0
        while(True):
            minibatch = dataset.next_batch(batch_size=batch_size, tt=1)
            if(len(minibatch['x'].shape) == 1): break
            step_dict = agent.step(minibatch=minibatch, training=False)

            for idx_y, _ in enumerate(minibatch['y']):
                y_true = np.argmax(minibatch['y'][idx_y])
                y_pred = np.argmax(step_dict['y_hat'][idx_y])

                # Upsample the 7x7 attention map to the 28x28 input resolution;
                # only the spatial axes are interpolated, not the channel axis.
                canvas = minibatch['x'][idx_y]
                canvas_attn = scipy.ndimage.zoom(step_dict['attn'][idx_y].numpy(), (4, 4, 1), order=3)

                plt.figure(figsize=(8, 5))

                plt.subplot(1, 2, 1)
                plt.axis('off')
                plt.title("Input")
                plt.imshow(canvas[:, :, 0], cmap='gray')

                plt.subplot(1, 2, 2)
                plt.axis('off')
                plt.title("Attention Map")
                plt.imshow(canvas_attn[:, :, 0], cmap='jet')

                plt.tight_layout()
                plt.savefig(os.path.join(savedir, path_model, "true_%d;pred_%d;%08d.png" %(y_true, y_pred, saveidx)))
                plt.close()
                saveidx += 1

            if(minibatch['t']): break
--------------------------------------------------------------------------------
/neuralnet/net00_cnn.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf
import whiteboxlayer.layers as wbl
import whiteboxlayer.extensions.utility as wblu

class Agent(object):

    def __init__(self, **kwargs):

        print("\nInitializing Neural Network...")

        self.dim_h = kwargs['dim_h']
        self.dim_w = kwargs['dim_w']
        self.dim_c = kwargs['dim_c']
        self.num_class = kwargs['num_class']
        self.ksize = kwargs['ksize']
        self.learning_rate = kwargs['learning_rate']
        self.path_ckpt = kwargs['path_ckpt']

        self.variables = {}

        self.__model = Neuralnet(\
            who_am_i="CNN", **kwargs, \
            filters=[1, 32, 64, 128])

        # Build the model once with a dummy input so all parameters exist.
        dummy = tf.zeros((1, self.dim_h, self.dim_w, self.dim_c), dtype=tf.float32)
        self.__model.forward(x=dummy, verbose=True)

        self.__init_propagation(path=self.path_ckpt)

    def __init_propagation(self, path):

        self.summary_writer = tf.summary.create_file_writer(self.path_ckpt)

        # Collect the trainable parameters and dump the full list to a text file.
        self.variables['trainable'] = []
        ftxt = open("list_parameters.txt", "w")
        for key in list(self.__model.layer.parameters.keys()):
            trainable = self.__model.layer.parameters[key].trainable
            text = "T: " + str(key) + str(self.__model.layer.parameters[key].shape)
            if(trainable):
                self.variables['trainable'].append(self.__model.layer.parameters[key])
            ftxt.write("%s\n" %(text))
        ftxt.close()

        self.optimizer = tf.optimizers.Adam(learning_rate=self.learning_rate)
        self.save_params()

        # Trace the concrete function once as a warm-up.
        conc_func = self.__model.__call__.get_concrete_function(\
            tf.TensorSpec(shape=(1, self.dim_h, self.dim_w, self.dim_c), dtype=tf.float32))

    def __loss(self, y, y_hat):

        entropy_b = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_hat)
        entropy = tf.math.reduce_mean(entropy_b)

        return {'entropy_b': entropy_b, 'entropy': entropy}

    @tf.autograph.experimental.do_not_convert
    def step(self, minibatch, iteration=0, training=False):

        x, y = minibatch['x'], minibatch['y']

        with tf.GradientTape() as tape:
            attn, logit, y_hat = self.__model.forward(x=x, verbose=False)
            losses = self.__loss(y=y, y_hat=logit)

        if(training):
            gradients = tape.gradient(losses['entropy'], self.variables['trainable'])
            self.optimizer.apply_gradients(zip(gradients, self.variables['trainable']))

            with self.summary_writer.as_default():
                tf.summary.scalar('%s/entropy' %(self.__model.who_am_i), losses['entropy'], step=iteration)

        return {'attn':attn, 'y_hat':y_hat, 'losses':losses}

    def save_params(self, model='base', tflite=False):

        if(tflite):
            # https://github.com/tensorflow/tensorflow/issues/42818
            conc_func = self.__model.__call__.get_concrete_function(\
                tf.TensorSpec(shape=(1, self.dim_h, self.dim_w, self.dim_c), dtype=tf.float32))
            converter = tf.lite.TFLiteConverter.from_concrete_functions([conc_func])

            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.experimental_new_converter = True
            converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

            tflite_model = converter.convert()

            with open('model.tflite', 'wb') as f:
                f.write(tflite_model)
        else:
            vars_to_save = self.__model.layer.parameters.copy()
            vars_to_save["optimizer"] = self.optimizer

            ckpt = tf.train.Checkpoint(**vars_to_save)
            ckptman = tf.train.CheckpointManager(ckpt, directory=os.path.join(self.path_ckpt, model), max_to_keep=1)
            ckptman.save()

    def load_params(self, model):

        vars_to_load = self.__model.layer.parameters.copy()
        vars_to_load["optimizer"] = self.optimizer

        ckpt = tf.train.Checkpoint(**vars_to_load)
        latest_ckpt = tf.train.latest_checkpoint(os.path.join(self.path_ckpt, model))
        status = ckpt.restore(latest_ckpt)
        status.expect_partial()

class Neuralnet(tf.Module):

    def __init__(self, **kwargs):
        super(Neuralnet, self).__init__()

        self.who_am_i = kwargs['who_am_i']
        self.dim_h = kwargs['dim_h']
        self.dim_w = kwargs['dim_w']
        self.dim_c = kwargs['dim_c']
        self.ksize = kwargs['ksize']
        self.num_class = kwargs['num_class']
        self.filters = kwargs['filters']

        self.layer = wbl.Layers()

        self.forward = tf.function(self.__call__)

    @tf.function
    def __call__(self, x, verbose=False):

        attn, logit = self.__nn(x=x, name=self.who_am_i, verbose=verbose)
        y_hat = tf.nn.softmax(logit, name="y_hat")

        return attn, logit, y_hat

    def __nn(self, x, name='neuralnet', verbose=True):

        attn = None
        for idx, _ in enumerate(self.filters[:-1]):
            if(idx == 0): continue
            x = self.layer.conv2d(x=x, stride=1, \
                filter_size=[self.ksize, self.ksize, self.filters[idx-1], self.filters[idx]], \
                activation='relu', name='%s-%dconv1' %(name, idx), verbose=verbose)
            x = self.layer.conv2d(x=x, stride=1, \
                filter_size=[self.ksize, self.ksize, self.filters[idx], self.filters[idx]], \
                activation='relu', name='%s-%dconv2' %(name, idx), verbose=verbose)
            x = self.layer.maxpool(x=x, ksize=2, strides=2, \
                name='%s-%dmp' %(name, idx), verbose=verbose)
            if(idx == 2):
                # Attention: a 1x1 convolution scores each spatial position,
                # and the resulting map gates the features elementwise.
                attn = wblu.attention(self.layer.conv2d(x=x, stride=1, \
                    filter_size=[1, 1, self.filters[idx], self.dim_c], \
                    activation=None, name='%s-attn' %(name), verbose=verbose))
                x = x * attn

        x = self.layer.conv2d(x=x, stride=1, \
            filter_size=[self.ksize, self.ksize, self.filters[-2], 512], \
            activation='relu', name='%s-clf0' %(name), verbose=verbose)
        x = tf.math.reduce_mean(x, axis=(1, 2))  # global average pooling
        x = self.layer.fully_connected(x=x, c_out=self.filters[-1], \
            activation='relu', name="%s-clf1" %(name), verbose=verbose)
        x = self.layer.fully_connected(x=x, c_out=self.num_class, \
            activation=None, name="%s-clf2" %(name), verbose=verbose)

        return attn, x
--------------------------------------------------------------------------------

/figures/loss.svg:
--------------------------------------------------------------------------------
(SVG figure: training loss curve. Only the axis ticks survived text
extraction: y-axis 0 to 0.8, x-axis -10k to 100k steps.)
--------------------------------------------------------------------------------