├── figures
│   ├── 0.png
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   ├── 4.png
│   ├── 5.png
│   ├── 6.png
│   ├── 7.png
│   ├── 8.png
│   ├── 9.png
│   ├── f0.png
│   ├── f1.png
│   ├── f2.png
│   └── loss.svg
├── source
│   ├── connector.py
│   ├── utils.py
│   ├── datamanager.py
│   └── tf_process.py
├── LICENSE
├── README.md
├── run.py
└── neuralnet
    └── net00_cnn.py
/figures/0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/0.png
--------------------------------------------------------------------------------
/figures/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/1.png
--------------------------------------------------------------------------------
/figures/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/2.png
--------------------------------------------------------------------------------
/figures/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/3.png
--------------------------------------------------------------------------------
/figures/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/4.png
--------------------------------------------------------------------------------
/figures/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/5.png
--------------------------------------------------------------------------------
/figures/6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/6.png
--------------------------------------------------------------------------------
/figures/7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/7.png
--------------------------------------------------------------------------------
/figures/8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/8.png
--------------------------------------------------------------------------------
/figures/9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/9.png
--------------------------------------------------------------------------------
/figures/f0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/f0.png
--------------------------------------------------------------------------------
/figures/f1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/f1.png
--------------------------------------------------------------------------------
/figures/f2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YeongHyeon/MNIST_AttentionMap/HEAD/figures/f2.png
--------------------------------------------------------------------------------
/source/connector.py:
--------------------------------------------------------------------------------
1 | def connect(nn):
2 |
3 |     # Map the --nn flag from run.py to a model module.
4 |     if(nn == 0): import neuralnet.net00_cnn as nn
5 |     else: raise ValueError("Unknown neural network index: %d" %(nn))
6 |
7 |     return nn
8 |
--------------------------------------------------------------------------------
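connector.py is a small registry that maps the `--nn` flag from run.py to a model module. A usage sketch mirroring run.py (the dimensions shown are the MNIST values that datamanager.py reports):

```python
import source.connector as con

# nn=0 resolves to neuralnet.net00_cnn; other indices raise ValueError.
nn_module = con.connect(nn=0)
agent = nn_module.Agent(dim_h=28, dim_w=28, dim_c=1, num_class=10,
                        ksize=3, learning_rate=1e-4, path_ckpt='Checkpoint')
```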
/source/utils.py:
--------------------------------------------------------------------------------
1 | import os, glob, shutil
2 |
3 | def make_dir(path, refresh=False):
4 |
5 |     try: os.mkdir(path)
6 |     except FileExistsError:
7 |         if(refresh):
8 |             shutil.rmtree(path)
9 |             os.mkdir(path)
10 |
11 | def sorted_list(path):
12 |
13 | tmplist = glob.glob(path)
14 | tmplist.sort()
15 |
16 | return tmplist
17 |
18 | def min_max_norm(x):
19 |
20 | min_x, max_x = x.min(), x.max()
21 | return (x - min_x + 1e-12) / (max_x - min_x + 1e-12)
22 |
--------------------------------------------------------------------------------
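A quick check of `min_max_norm`: it rescales an array to approximately [0, 1], and the 1e-12 terms keep the division finite when the input is constant (max equals min):

```python
import numpy as np
from source.utils import min_max_norm

x = np.array([0., 127.5, 255.], dtype=np.float32)
print(min_max_norm(x))            # ~[0.  0.5 1. ]
print(min_max_norm(np.zeros(3)))  # ~[1. 1. 1.] instead of a 0/0 NaN warning
```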
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 YeongHyeon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [TensorFlow] Attention mechanism with MNIST dataset
2 | =====
3 |
4 | ## Usage
5 | ``` sh
6 | $ python run.py
7 | ```
8 |
9 | ## Result
10 |
11 | ### Training
12 |
13 | ![Loss graph](./figures/loss.svg)
14 |
15 | Loss graph.
16 |
17 | ### Test
18 |
19 | ![Test result 0](./figures/0.png)
20 | ![Test result 1](./figures/1.png)
21 | ![Test result 2](./figures/2.png)
22 | ![Test result 3](./figures/3.png)
23 | ![Test result 4](./figures/4.png)
24 | ![Test result 5](./figures/5.png)
25 | ![Test result 6](./figures/6.png)
26 | ![Test result 7](./figures/7.png)
27 | ![Test result 8](./figures/8.png)
28 | ![Test result 9](./figures/9.png)
29 |
30 | Each figure shows the input digit, the attention map, and the overlapped image, in that order.
31 |
32 | ### Further usage
33 |
34 | ![Further usage 0](./figures/f0.png)
35 | ![Further usage 1](./figures/f1.png)
36 | ![Further usage 2](./figures/f2.png)
37 |
38 | Further usage: the location of digits can be detected using the attention map.
39 |
40 | ## Requirements
41 | * TensorFlow 2.3.0
42 | * Numpy 1.18.5
43 |
44 | ## Additional Resources
45 | [1] Simple attention mechanism test by Myung Jin Kim
46 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import argparse, time, os
2 |
3 | import tensorflow as tf
4 | import source.connector as con
5 | import source.tf_process as tfp
6 | import source.datamanager as dman
7 |
8 | def main():
9 |
10 | os.environ["CUDA_VISIBLE_DEVICES"]=FLAGS.gpu
11 |
12 | gpus = tf.config.experimental.list_physical_devices('GPU')
13 | if gpus:
14 | try:
15 | for gpu in gpus:
16 | tf.config.experimental.set_memory_growth(gpu, True)
17 | logical_gpus = tf.config.experimental.list_logical_devices('GPU')
18 | print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
19 | except RuntimeError as e:
20 | print(e)
21 |
22 | dataset = dman.Dataset()
23 |
24 | agent = con.connect(nn=FLAGS.nn).Agent(\
25 | dim_h = dataset.height, \
26 | dim_w = dataset.width, \
27 | dim_c = dataset.channel, \
28 | num_class = dataset.num_class, \
29 | ksize = FLAGS.ksize, \
30 | learning_rate = FLAGS.lr, \
31 | path_ckpt = 'Checkpoint')
32 |
33 | time_tr = time.time()
34 | tfp.training(agent=agent, dataset=dataset, \
35 | batch_size=FLAGS.batch, epochs=FLAGS.epochs)
36 | time_te = time.time()
37 | tfp.test(agent=agent, dataset=dataset, batch_size=FLAGS.batch)
38 | time_fin = time.time()
39 |
40 | print("Time (TR): %.5f [sec]" %(time_te - time_tr))
41 | te_time = time_fin - time_te
42 | print("Time (TE): %.5f (%.5f [sec/sample])" %(te_time, te_time/dataset.num_te))
43 |
44 | if __name__ == '__main__':
45 |
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--gpu', type=str, default="0", help='')
48 | parser.add_argument('--nn', type=int, default=0, help='')
49 | parser.add_argument('--ksize', type=int, default=3, help='')
50 | parser.add_argument('--lr', type=float, default=1e-4, help='')
51 | parser.add_argument('--batch', type=int, default=32, help='')
52 | parser.add_argument('--epochs', type=int, default=100, help='')
53 |
54 | FLAGS, unparsed = parser.parse_known_args()
55 |
56 | main()
57 |
--------------------------------------------------------------------------------
/source/datamanager.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import source.utils as utils
4 | from sklearn.utils import shuffle
5 |
6 | class Dataset(object):
7 |
8 | def __init__(self, normalize=True):
9 |
10 | print("\nInitializing Dataset...")
11 |
12 | self.normalize = normalize
13 |
14 | (x_tr, y_tr), (x_te, y_te) = tf.keras.datasets.mnist.load_data()
15 | self.x_tr, self.y_tr = x_tr, y_tr
16 | self.x_te, self.y_te = x_te, y_te
17 |
18 |         self.x_tr = self.x_tr.astype(np.float32)
19 |         self.x_te = self.x_te.astype(np.float32)
20 |
21 | self.__normalizing()
22 |
23 | self.num_tr, self.num_te = self.x_tr.shape[0], self.x_te.shape[0]
24 | self.idx_tr, self.idx_te = 0, 0
25 |
26 | x_sample, y_sample = self.x_te[0], self.y_te[0]
27 | self.height = x_sample.shape[0]
28 | self.width = x_sample.shape[1]
29 |         try: self.channel = x_sample.shape[2]
30 |         except IndexError: self.channel = 1
31 |
32 | self.num_class = (y_te.max()+1)
33 |
34 | def __normalizing(self):
35 |
36 | for idx, _ in enumerate(self.x_tr):
37 | self.x_tr[idx] = utils.min_max_norm(self.x_tr[idx])
38 |
39 | for idx, _ in enumerate(self.x_te):
40 | self.x_te[idx] = utils.min_max_norm(self.x_te[idx])
41 |
42 | def __reset_index(self):
43 |
44 | self.idx_tr, self.idx_te = 0, 0
45 |
46 | def next_batch(self, batch_size=1, tt=0):
47 |
48 | if(tt == 0):
49 | idx_d, num_d, data_x, data_y = self.idx_tr, self.num_tr, self.x_tr, self.y_tr
50 | elif(tt == 1):
51 | idx_d, num_d, data_x, data_y = self.idx_te, self.num_te, self.x_te, self.y_te
52 |
53 | batch_x, batch_y, terminator = [], [], False
54 | while(True):
55 |             try:
56 |                 tmp_x, tmp_y = data_x[idx_d].copy(), data_y[idx_d].copy()
57 |             except IndexError:
58 |                 idx_d = 0
59 |                 if(tt == 0): self.x_tr, self.y_tr = shuffle(self.x_tr, self.y_tr) # reshuffle only the training split
60 |                 terminator = True
61 |                 break
62 | else:
63 | batch_x.append(np.expand_dims(tmp_x, axis=-1))
64 | batch_y.append(np.diag(np.ones(self.num_class))[tmp_y])
65 | idx_d += 1
66 | if(len(batch_x) == batch_size): break
67 |
68 | batch_x = np.asarray(batch_x)
69 | batch_y = np.asarray(batch_y)
70 |
71 | if(tt == 0):
72 | self.idx_tr = idx_d
73 | elif(tt == 1):
74 | self.idx_te = idx_d
75 |
76 | return {'x':batch_x.astype(np.float32), 'y':batch_y.astype(np.float32), 't':terminator}
77 |
--------------------------------------------------------------------------------
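Two details of `next_batch` are worth noting. Labels are one-hot encoded by indexing a row of an identity matrix (`np.eye(num_class)[tmp_y]` is the more common spelling), and the `'t'` terminator flag marks the epoch boundary: on the pass that exhausts the split, the returned batch can even be empty, which is why the loops in tf_process.py check `len(minibatch['x'].shape) == 1` before stepping. A minimal sketch of the encoding:

```python
import numpy as np

num_class, tmp_y = 10, 7
onehot = np.diag(np.ones(num_class))[tmp_y]  # row 7 of a 10x10 identity matrix
print(onehot)  # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
```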
/source/tf_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import scipy.ndimage
3 | import numpy as np
4 | import pandas as pd
5 | import matplotlib.pyplot as plt
6 | from sklearn.metrics import roc_curve, auc
7 |
8 | import source.utils as utils
9 |
10 | def training(agent, dataset, batch_size, epochs):
11 |
12 |     print("\n** Training the CNN for %d epochs | Batch size: %d" %(epochs, batch_size))
13 | iteration = 0
14 |
15 | for epoch in range(epochs):
16 |
17 | while(True):
18 | minibatch = dataset.next_batch(batch_size=batch_size, tt=0)
19 | if(len(minibatch['x'].shape) == 1): break
20 | step_dict = agent.step(minibatch=minibatch, iteration=iteration, training=True)
21 | iteration += 1
22 | if(minibatch['t']): break
23 |
24 |         print("Epoch [%d / %d] | Loss: %f" %(epoch+1, epochs, step_dict['losses']['entropy']))
25 |         agent.save_params(model='model_0_finepoch')
26 |
27 | def test(agent, dataset, batch_size):
28 |
29 | savedir = 'results_te'
30 | utils.make_dir(path=savedir, refresh=True)
31 |
32 | list_model = utils.sorted_list(os.path.join('Checkpoint', 'model*'))
33 | for idx_model, path_model in enumerate(list_model):
34 |         list_model[idx_model] = os.path.basename(path_model)
35 |
36 | for idx_model, path_model in enumerate(list_model):
37 |
38 | print("\n** Test with %s" %(path_model))
39 | agent.load_params(model=path_model)
40 | utils.make_dir(path=os.path.join(savedir, path_model), refresh=False)
41 |
42 | saveidx = 0
43 | while(True):
44 | minibatch = dataset.next_batch(batch_size=batch_size, tt=1)
45 | if(len(minibatch['x'].shape) == 1): break
46 | step_dict = agent.step(minibatch=minibatch, training=False)
47 |
48 | for idx_y, _ in enumerate(minibatch['y']):
49 | y_true = np.argmax(minibatch['y'][idx_y])
50 | y_pred = np.argmax(step_dict['y_hat'][idx_y])
51 |
52 | canvas, canvas_attn = \
53 | minibatch['x'][idx_y], scipy.ndimage.zoom(step_dict['attn'][idx_y].numpy(), 4, order=3)
54 |
55 | plt.clf()
56 | plt.figure(figsize=(8, 5))
57 |
58 | plt.subplot(1, 2, 1)
59 | plt.axis('off')
60 | plt.title("Input")
61 | plt.imshow(canvas[:, :, 0], cmap='gray')
62 |
63 | plt.subplot(1, 2, 2)
64 | plt.axis('off')
65 | plt.title("Attention Map")
66 | plt.imshow(canvas_attn[:, :, 0], cmap='jet')
67 |
68 | plt.tight_layout()
69 | plt.savefig(os.path.join(savedir, path_model, "true_%d;pred_%d;%08d.png" %(y_true, y_pred, saveidx)))
70 | plt.close()
71 | saveidx += 1
72 |
73 | if(minibatch['t']): break
74 |
--------------------------------------------------------------------------------
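The `scipy.ndimage.zoom(..., 4, order=3)` call in `test` upsamples the attention map for display: two 2x2 max-pools in net00_cnn reduce the 28x28 input to a 7x7 map, so cubic interpolation by a factor of 4 restores input resolution. A scalar factor zooms every axis, including the size-1 channel axis, which is why the plot indexes `canvas_attn[:, :, 0]`. A sketch with a random stand-in map:

```python
import numpy as np
import scipy.ndimage

attn = np.random.rand(7, 7, 1).astype(np.float32)  # stand-in for step_dict['attn'][idx_y].numpy()
attn_up = scipy.ndimage.zoom(attn, 4, order=3)     # cubic upsampling on all three axes
print(attn_up.shape)                               # (28, 28, 4); channel 0 is plotted
```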
/neuralnet/net00_cnn.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import tensorflow as tf
4 | import source.utils as utils
5 | import whiteboxlayer.layers as wbl
6 | import whiteboxlayer.extensions.utility as wblu
7 |
8 | class Agent(object):
9 |
10 | def __init__(self, **kwargs):
11 |
12 | print("\nInitializing Neural Network...")
13 |
14 | self.dim_h = kwargs['dim_h']
15 | self.dim_w = kwargs['dim_w']
16 | self.dim_c = kwargs['dim_c']
17 | self.num_class = kwargs['num_class']
18 | self.ksize = kwargs['ksize']
19 | self.learning_rate = kwargs['learning_rate']
20 | self.path_ckpt = kwargs['path_ckpt']
21 |
22 | self.variables = {}
23 |
24 | self.__model = Neuralnet(\
25 | who_am_i="CNN", **kwargs, \
26 | filters=[1, 32, 64, 128])
27 |
28 | dummy = tf.zeros((1, self.dim_h, self.dim_w, self.dim_c), dtype=tf.float32)
29 | self.__model.forward(x=dummy, verbose=True)
30 |
31 | self.__init_propagation(path=self.path_ckpt)
32 |
33 | def __init_propagation(self, path):
34 |
35 | self.summary_writer = tf.summary.create_file_writer(self.path_ckpt)
36 |
37 | self.variables['trainable'] = []
38 | ftxt = open("list_parameters.txt", "w")
39 | for key in list(self.__model.layer.parameters.keys()):
40 | trainable = self.__model.layer.parameters[key].trainable
41 | text = "T: " + str(key) + str(self.__model.layer.parameters[key].shape)
42 | if(trainable):
43 | self.variables['trainable'].append(self.__model.layer.parameters[key])
44 | ftxt.write("%s\n" %(text))
45 | ftxt.close()
46 |
47 | self.optimizer = tf.optimizers.Adam(learning_rate=self.learning_rate)
48 | self.save_params()
49 |
50 | conc_func = self.__model.__call__.get_concrete_function(\
51 | tf.TensorSpec(shape=(1, self.dim_h, self.dim_w, self.dim_c), dtype=tf.float32))
52 |
53 | def __loss(self, y, y_hat):
54 |
55 | entropy_b = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=y_hat)
56 | entropy = tf.math.reduce_mean(entropy_b)
57 |
58 | return {'entropy_b': entropy_b, 'entropy': entropy}
59 |
60 | @tf.autograph.experimental.do_not_convert
61 | def step(self, minibatch, iteration=0, training=False):
62 |
63 | x, y = minibatch['x'], minibatch['y']
64 |
65 | with tf.GradientTape() as tape:
66 | attn, logit, y_hat = self.__model.forward(x=x, verbose=False)
67 | losses = self.__loss(y=y, y_hat=logit)
68 |
69 | if(training):
70 | gradients = tape.gradient(losses['entropy'], self.variables['trainable'])
71 | self.optimizer.apply_gradients(zip(gradients, self.variables['trainable']))
72 |
73 | with self.summary_writer.as_default():
74 | tf.summary.scalar('%s/entropy' %(self.__model.who_am_i), losses['entropy'], step=iteration)
75 |
76 | return {'attn':attn, 'y_hat':y_hat, 'losses':losses}
77 |
78 | def save_params(self, model='base', tflite=False):
79 |
80 | if(tflite):
81 | # https://github.com/tensorflow/tensorflow/issues/42818
82 | conc_func = self.__model.__call__.get_concrete_function(\
83 | tf.TensorSpec(shape=(1, self.dim_h, self.dim_w, self.dim_c), dtype=tf.float32))
84 | converter = tf.lite.TFLiteConverter.from_concrete_functions([conc_func])
85 |
86 | converter.optimizations = [tf.lite.Optimize.DEFAULT]
87 | converter.experimental_new_converter = True
88 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
89 |
90 | tflite_model = converter.convert()
91 |
92 | with open('model.tflite', 'wb') as f:
93 | f.write(tflite_model)
94 | else:
95 | vars_to_save = self.__model.layer.parameters.copy()
96 | vars_to_save["optimizer"] = self.optimizer
97 |
98 | ckpt = tf.train.Checkpoint(**vars_to_save)
99 | ckptman = tf.train.CheckpointManager(ckpt, directory=os.path.join(self.path_ckpt, model), max_to_keep=1)
100 | ckptman.save()
101 |
102 | def load_params(self, model):
103 |
104 | vars_to_load = self.__model.layer.parameters.copy()
105 | vars_to_load["optimizer"] = self.optimizer
106 |
107 | ckpt = tf.train.Checkpoint(**vars_to_load)
108 | latest_ckpt = tf.train.latest_checkpoint(os.path.join(self.path_ckpt, model))
109 | status = ckpt.restore(latest_ckpt)
110 | status.expect_partial()
111 |
112 | class Neuralnet(tf.Module):
113 |
114 | def __init__(self, **kwargs):
115 | super(Neuralnet, self).__init__()
116 |
117 | self.who_am_i = kwargs['who_am_i']
118 | self.dim_h = kwargs['dim_h']
119 | self.dim_w = kwargs['dim_w']
120 | self.dim_c = kwargs['dim_c']
121 | self.ksize = kwargs['ksize']
122 | self.num_class = kwargs['num_class']
123 | self.filters = kwargs['filters']
124 |
125 | self.layer = wbl.Layers()
126 |
127 | self.forward = tf.function(self.__call__)
128 |
129 | @tf.function
130 | def __call__(self, x, verbose=False):
131 |
132 | attn, logit = self.__nn(x=x, name=self.who_am_i, verbose=verbose)
133 | y_hat = tf.nn.softmax(logit, name="y_hat")
134 |
135 | return attn, logit, y_hat
136 |
137 | def __nn(self, x, name='neuralnet', verbose=True):
138 |
139 |         attn = None
140 | for idx, _ in enumerate(self.filters[:-1]):
141 | if(idx == 0): continue
142 | x = self.layer.conv2d(x=x, stride=1, \
143 | filter_size=[self.ksize, self.ksize, self.filters[idx-1], self.filters[idx]], \
144 | activation='relu', name='%s-%dconv1' %(name, idx), verbose=verbose)
145 | x = self.layer.conv2d(x=x, stride=1, \
146 | filter_size=[self.ksize, self.ksize, self.filters[idx], self.filters[idx]], \
147 | activation='relu', name='%s-%dconv2' %(name, idx), verbose=verbose)
148 | x = self.layer.maxpool(x=x, ksize=2, strides=2, \
149 | name='%s-%dmp' %(name, idx), verbose=verbose)
150 | if(idx == 2):
151 | attn = wblu.attention(self.layer.conv2d(x=x, stride=1, \
152 | filter_size=[1, 1, self.filters[idx], self.dim_c], \
153 | activation=None, name='%s-attn' %(name), verbose=verbose))
154 | x = x * attn
155 |
156 | x = self.layer.conv2d(x=x, stride=1, \
157 | filter_size=[self.ksize, self.ksize, self.filters[-2], 512], \
158 | activation='relu', name='%s-clf0' %(name), verbose=verbose)
159 | x = tf.math.reduce_mean(x, axis=(1, 2))
160 | x = self.layer.fully_connected(x=x, c_out=self.filters[-1], \
161 | activation='relu', name="%s-clf1" %(name), verbose=verbose)
162 | x = self.layer.fully_connected(x=x, c_out=self.num_class, \
163 | activation=None, name="%s-clf2" %(name), verbose=verbose)
164 |
165 | return attn, x
166 |
--------------------------------------------------------------------------------
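For orientation, a sketch of the shape flow through `Neuralnet` with the default MNIST setup (assumes the `whiteboxlayer` package imported above is installed):

```python
import tensorflow as tf
from neuralnet.net00_cnn import Neuralnet

# Shape flow for filters=[1, 32, 64, 128] and a 28x28x1 input:
#   (N,28,28,1) -> (N,14,14,32) -> (N,7,7,64)
#   attention: 1x1 conv(64->1) + wblu.attention -> attn (N,7,7,1); x = x * attn
#   clf0 conv(64->512) -> GAP (N,512) -> fc (N,128) -> logits (N,10)
net = Neuralnet(who_am_i="CNN", dim_h=28, dim_w=28, dim_c=1,
                ksize=3, num_class=10, filters=[1, 32, 64, 128])
attn, logit, y_hat = net.forward(x=tf.zeros((1, 28, 28, 1)), verbose=False)
print(attn.shape, logit.shape, y_hat.shape)  # (1, 7, 7, 1) (1, 10) (1, 10)
```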
/figures/loss.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------