├── README.md
├── imgs
│   └── Figure_1.png
└── nets
    ├── __init__.py
    ├── l_softmax.py
    ├── mnist_lsoftmax.py
    └── vis_results.py

/README.md:
--------------------------------------------------------------------------------
# L_Softmax_TensorFlow
A TensorFlow implementation of the Large-Margin Softmax (L-Softmax) loss.
### Results:
![results](imgs/Figure_1.png)

I found `prelu` to be more stable than `relu`, so I used `prelu`, as the paper suggests.
### Reference:
* [mx-lsoftmax](https://github.com/luoyetx/mx-lsoftmax)
* [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf)
### Contribution
This is mainly implemented with `py_func`, which is quite slow. If anyone has implemented a native `tf_op` in C++ or CUDA, pull requests are warmly welcome.
--------------------------------------------------------------------------------
/imgs/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/auroua/L_Softmax_TensorFlow/a017f571b2a12da74834b9e9962ec6ce64d56c50/imgs/Figure_1.png
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/auroua/L_Softmax_TensorFlow/a017f571b2a12da74834b9e9962ec6ce64d56c50/nets/__init__.py
--------------------------------------------------------------------------------
/nets/l_softmax.py:
--------------------------------------------------------------------------------
import numpy as np
import math
import uuid
import tensorflow as tf

margin = 4
beta = 100
scale = 0.99
beta_min = 0
eps = 0
c_map = []
k_map = []
# c_m_n(m, n): binomial coefficient "n choose m"
c_m_n = lambda m, n: math.factorial(n) / math.factorial(m) / math.factorial(n - m)
for i in range(margin + 1):
    c_map.append(c_m_n(i, margin))
    k_map.append(math.cos(i * math.pi / margin))


def find_k(cos_t):
    '''find k for cos(theta), i.e. theta lies in [k*pi/m, (k+1)*pi/m]
    '''
    # for numeric issue
    eps = 1e-5
    le = lambda x, y: x < y or abs(x - y) < eps
    for i in range(margin):
        if le(k_map[i + 1], cos_t) and le(cos_t, k_map[i]):
            return i
    raise ValueError('can not find k for cos_t = %f' % cos_t)


def find_k_vector(cos_t_vec):
    k_val = []
    for i in range(cos_t_vec.shape[0]):
        try:
            k_val.append(find_k(cos_t_vec[i]))
        except ValueError:
            # re-raise instead of skipping, which would silently misalign k_val
            print(cos_t_vec)
            raise
    return k_val


def calc_cos_mt(cos_t):
    '''calculate cos(m*theta) via the binomial multiple-angle expansion
    '''
    cos_mt = 0
    sin2_t = 1 - cos_t * cos_t
    flag = -1
    for p in range(margin // 2 + 1):
        flag *= -1
        cos_mt += flag * c_map[2*p] * pow(cos_t, margin-2*p) * pow(sin2_t, p)
    return cos_mt


def calc_cos_mt_vector(cos_t_vector):
    cos_mt_val = []
    for i in range(cos_t_vector.shape[0]):
        cos_mt_val.append(calc_cos_mt(cos_t_vector[i]))
    return cos_mt_val


def lsoftmax(x, weights, labels):
    def _lsoftmax(net_val, weights, labels):
        global beta, scale
        normalize_net = np.linalg.norm(net_val, axis=1).reshape([net_val.shape[0], 1])
        normalize_weights = np.linalg.norm(weights, axis=0).reshape([-1, weights.shape[1]])
        normalize_val = normalize_net * normalize_weights

        indexes = np.arange(net_val.shape[0])
        labels = labels.reshape((-1,))

        normalize_val_target = normalize_val[indexes, labels]
        logit = np.dot(net_val, weights)
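        # L-Softmax replaces the target-class logit |w_y|*|x|*cos(theta) with
        # |w_y|*|x|*psi(theta), where psi(theta) = (-1)^k * cos(m*theta) - 2k for
        # theta in [k*pi/m, (k+1)*pi/m]; blending with the original logit via
        # beta (annealed in the backward pass) keeps early training stable.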
        cos_t_target = logit[indexes, labels] / (normalize_val_target + eps)
        k_val = np.array(find_k_vector(cos_t_target))
        cos_mt_val = np.array(calc_cos_mt_vector(cos_t_target))
        logit_output_cos = np.power(-1, k_val) * cos_mt_val - 2 * k_val
        logit_output = logit_output_cos * normalize_val_target
        logit_output_beta = (logit_output + beta * logit[indexes, labels]) / (1 + beta)
        logit[indexes, labels] = logit_output_beta
        return logit

    def _lsoftmax_grad(x, w, label, grad):
        global beta, scale, beta_min
        # plain softmax gradients, before the lsoftmax correction
        w_grad = x.T.dot(grad)  # feature_dim x num_classes
        x_grad = grad.dot(w.T)  # batch_size x feature_dim
        n = label.shape[0]
        m = w.shape[1]
        feature_dim = w.shape[0]
        cos_t = np.zeros(n, dtype=np.float32)
        cos_mt = np.zeros(n, dtype=np.float32)
        sin2_t = np.zeros(n, dtype=np.float32)
        fo = np.zeros(n, dtype=np.float32)
        k = np.zeros(n, dtype=np.int32)
        x_norm = np.linalg.norm(x, axis=1)
        w_norm = np.linalg.norm(w, axis=0)
        w_tmp = w.T
        for i in range(n):
            yi = int(label[i])
            f = w_tmp[yi].dot(x[i])
            cos_t[i] = f / (w_norm[yi] * x_norm[i])
            k[i] = find_k(cos_t[i])
            cos_mt[i] = calc_cos_mt(cos_t[i])
            sin2_t[i] = 1 - cos_t[i]*cos_t[i]
            fo[i] = f
        # gradient w.r.t. x_i
        for i in range(n):
            # df / dx at x = x_i, w = w_yi
            yi = int(label[i])
            dcos_dx = w_tmp[yi] / (w_norm[yi]*x_norm[i]) - x[i] * fo[i] / (w_norm[yi]*pow(x_norm[i], 3))
            dsin2_dx = -2 * cos_t[i] * dcos_dx
            dcosm_dx = margin*pow(cos_t[i], margin-1) * dcos_dx  # p = 0
            flag = 1
            for p in range(1, margin//2+1):
                flag *= -1
                dcosm_dx += flag * c_map[2*p] * (p*pow(cos_t[i], margin-2*p)*pow(sin2_t[i], p-1)*dsin2_dx +
                                                 (margin-2*p)*pow(cos_t[i], margin-2*p-1)*pow(sin2_t[i], p)*dcos_dx)
            df_dx = (pow(-1, k[i]) * cos_mt[i] - 2*k[i]) * w_norm[yi] / x_norm[i] * x[i] + \
                    pow(-1, k[i]) * w_norm[yi] * x_norm[i] * dcosm_dx
            alpha = 1 / (1 + beta)
            x_grad[i] += alpha * grad[i, yi] * (df_dx - w_tmp[yi])
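        # The gradient w.r.t. each class weight w_j mirrors the x gradient above
        # with the roles of x and w swapped; only samples whose label yi equals j
        # contribute the lsoftmax correction term.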
        # gradient w.r.t. w_j
        for j in range(m):
            dw = np.zeros(feature_dim, dtype=np.float32)
            for i in range(n):
                yi = int(label[i])
                if yi == j:
                    # df / dw at x = x_i, w = w_yi and yi == j
                    dcos_dw = x[i] / (w_norm[yi]*x_norm[i]) - w_tmp[yi] * fo[i] / (x_norm[i]*pow(w_norm[yi], 3))
                    dsin2_dw = -2 * cos_t[i] * dcos_dw
                    dcosm_dw = margin*pow(cos_t[i], margin-1) * dcos_dw  # p = 0
                    flag = 1
                    for p in range(1, margin//2+1):
                        flag *= -1
                        dcosm_dw += flag * c_map[2*p] * (p*pow(cos_t[i], margin-2*p)*pow(sin2_t[i], p-1)*dsin2_dw +
                                                         (margin-2*p)*pow(cos_t[i], margin-2*p-1)*pow(sin2_t[i], p)*dcos_dw)
                    df_dw_j = (pow(-1, k[i]) * cos_mt[i] - 2*k[i]) * x_norm[i] / w_norm[yi] * w_tmp[yi] + \
                              pow(-1, k[i]) * w_norm[yi] * x_norm[i] * dcosm_dw
                    dw += grad[i, yi] * (df_dw_j - x[i])
            alpha = 1 / (1 + beta)
            w_grad[:, j] += alpha * dw
        # anneal beta toward beta_min so the margin takes effect gradually
        beta *= scale
        beta = max(beta, beta_min)
        return x_grad, w_grad

    def _lsoftmax_grad_op(op, grad):
        x = op.inputs[0]
        weights = op.inputs[1]
        labels = op.inputs[2]
        x_grad, w_grad = tf.py_func(_lsoftmax_grad, [x, weights, labels, grad], [tf.float32, tf.float32])
        # labels are integer inputs, so no gradient flows back to them
        return x_grad, w_grad, None

    grad_name = 'lsoftmax_' + str(uuid.uuid4())
    tf.RegisterGradient(grad_name)(_lsoftmax_grad_op)

    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": grad_name}):
        output = tf.py_func(_lsoftmax, [x, weights, labels], tf.float32)
    return output
--------------------------------------------------------------------------------
/nets/mnist_lsoftmax.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from l_softmax import lsoftmax
import numpy as np

flags = tf.app.flags
flags.DEFINE_integer('channel', 1, 'mnist input channel')
flags.DEFINE_integer('batch_size', 100, 'input image batch size')
flags.DEFINE_string('data_dir', '/home/aurora/workspaces/data/mnist', 'mnist dataset dir')
flags.DEFINE_string('save_dir', '/home/aurora/workspaces/PycharmProjects/tensorflow/L_SoftMax_TensorFlow/data/', 'data dir')
FLAGS = flags.FLAGS


def prelu(_x, scope=None):
    """parametric ReLU activation with a learnable slope for negative inputs"""
    with tf.variable_scope(name_or_scope=scope, default_name="prelu"):
        _alpha = tf.get_variable("alpha", shape=_x.get_shape()[-1],
                                 dtype=_x.dtype, initializer=tf.constant_initializer(0.1))
        return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x)
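# Network: three (5x5 conv + PReLU + 2x2 max-pool) blocks -> 256-d PReLU fc ->
# 2-d fc (kept two-dimensional so the learned features can be plotted) ->
# L-Softmax over the 10 digit classes. The fc biases are omitted so the logits
# keep the bias-free w^T x form assumed by the lsoftmax formulation.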
class MNIST(object):
    def __init__(self):
        self.images = tf.placeholder(name='input_images', shape=[None, 28*28], dtype=tf.float32)
        self.images_reshape = tf.reshape(self.images, shape=[-1, 28, 28, FLAGS.channel])
        self.labels = tf.placeholder(name='image_labels', shape=[None], dtype=tf.int64)

    def inference(self):
        # conv1
        weights1 = tf.get_variable(name='weights1', shape=[5, 5, FLAGS.channel, 32], dtype=tf.float32,
                                   initializer=tf.contrib.layers.xavier_initializer())
        bias1 = tf.Variable(tf.constant(0.0, shape=[32]), name='bias1')
        conv1_output = tf.nn.conv2d(self.images_reshape, filter=weights1, strides=[1, 1, 1, 1],
                                    padding='SAME', name='conv1') + bias1
        # conv1_output = tf.nn.relu(conv1_output)
        conv1_output = prelu(conv1_output, scope='conv1')
        conv1_output = tf.nn.max_pool(conv1_output, strides=[1, 2, 2, 1], ksize=[1, 2, 2, 1],
                                      padding='SAME', name='conv1_max_pool')
        tf.logging.info(conv1_output.op.name + ': ' + str(conv1_output.get_shape()))
        tf.add_to_collection('weights', tf.nn.l2_loss(weights1))

        # conv2
        weights2 = tf.get_variable(name='weights2', shape=[5, 5, 32, 64], dtype=tf.float32,
                                   initializer=tf.contrib.layers.xavier_initializer())
        bias2 = tf.Variable(tf.constant(0.0, shape=[64]), name='bias2')
        net = tf.nn.conv2d(conv1_output, filter=weights2, strides=[1, 1, 1, 1], padding='SAME', name='conv2')
        # net = tf.nn.relu(net + bias2)
        net = prelu(net + bias2, scope='conv2')
        net = tf.nn.max_pool(net, strides=[1, 2, 2, 1], ksize=[1, 2, 2, 1], padding='SAME', name='conv2_max_pool')
        tf.logging.info(net.op.name + ': ' + str(net.get_shape()))
        tf.add_to_collection('weights', tf.nn.l2_loss(weights2))

        # conv3
        conv_weights3 = tf.get_variable(name='conv_weights3', shape=[5, 5, 64, 128], dtype=tf.float32,
                                        initializer=tf.contrib.layers.xavier_initializer())
        conv_bias3 = tf.Variable(tf.constant(0.0, shape=[128]), name='conv_bias3')
        net = tf.nn.conv2d(net, filter=conv_weights3, strides=[1, 1, 1, 1], padding='SAME', name='conv3')
        # net = tf.nn.relu(net + conv_bias3)
        net = prelu(net + conv_bias3, scope='conv3')
        net = tf.nn.max_pool(net, strides=[1, 2, 2, 1], ksize=[1, 2, 2, 1], padding='SAME', name='conv3_max_pool')
        tf.logging.info(net.op.name + ': ' + str(net.get_shape()))
        tf.add_to_collection('weights', tf.nn.l2_loss(conv_weights3))

        # fc1 (no bias, see module comment above)
        shapes = net.get_shape().as_list()
        dim_len = shapes[1] * shapes[2] * shapes[3]
        net = tf.reshape(net, shape=(-1, dim_len))
        weights3 = tf.get_variable(name='weights3', shape=[dim_len, 256], dtype=tf.float32,
                                   initializer=tf.contrib.layers.xavier_initializer())
        net = tf.matmul(net, weights3)
        # net = tf.nn.relu(net)
        net = prelu(net, scope='fc1')
        tf.logging.info(net.op.name + ': ' + str(net.get_shape()))
        tf.add_to_collection('weights', tf.nn.l2_loss(weights3))

        # fc2: project down to the 2-d features used for visualization
        weights4_1 = tf.get_variable(name='weights4_1', shape=[256, 2], dtype=tf.float32,
                                     initializer=tf.contrib.layers.xavier_initializer())
        net = tf.matmul(net, weights4_1)
        tf.add_to_collection('weights', tf.nn.l2_loss(weights4_1))
        tf.logging.info(net.op.name + ': ' + str(net.get_shape()))

        # fc3: class weights consumed directly by the lsoftmax op
        weights4 = tf.get_variable(name='weights4', shape=[2, 10], dtype=tf.float32,
                                   initializer=tf.contrib.layers.xavier_initializer())
        tf.add_to_collection('weights', tf.nn.l2_loss(weights4))

        logit = lsoftmax(net, weights4, self.labels)
        # py_func drops static shape information, so restore it for the loss op
        logit.set_shape([FLAGS.batch_size, weights4.shape.as_list()[1]])
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logit, labels=self.labels)
        loss = tf.reduce_mean(loss)
        reg_loss = tf.add_n(tf.get_collection('weights'))
        total_loss = loss + 0.0005*reg_loss
        pred = tf.nn.softmax(logit)
        acc = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), self.labels), dtype=tf.float32))
        return logit, loss, reg_loss, total_loss, acc, net

    def get_shape(self, input_tensor):
        static_shape = input_tensor.get_shape().as_list()
        dynamic_shape = tf.unstack(tf.shape(input_tensor))
        dims = [dim_tensors[0] if dim_tensors[0] is not None else dim_tensors[1]
                for dim_tensors in zip(static_shape, dynamic_shape)]
        return dims
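# Training setup below: momentum 0.9 with base lr 0.01 decayed by 0.9 every
# 3000 steps, L2 weight decay 5e-4, 20 epochs of 600 batches (batch size 100).
# After training, the 2-d test-set features are dumped for vis_results.py.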
if __name__ == '__main__':
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False, validation_size=0)
    tf.logging.set_verbosity(tf.logging.DEBUG)
    mnist_net = MNIST()
    logit, loss, reg_loss, total_loss, acc, net = mnist_net.inference()

    sess = tf.Session()

    global_steps = tf.Variable(0, trainable=False, name='global_steps')
    learning_rate = tf.train.exponential_decay(
        0.01,          # Base learning rate.
        global_steps,  # Current index into the dataset.
        3000,          # Decay step.
        0.9,           # Decay rate.
        staircase=True)
    tf.summary.scalar('lr', learning_rate)

    step_ops = tf.assign_add(global_steps, 1)
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)
    gvs = optimizer.compute_gradients(total_loss, tf.trainable_variables())
    train_op = optimizer.apply_gradients(gvs)

    sess.run(tf.global_variables_initializer())
    # summaries
    tf.summary.scalar("loss", loss)
    merged_summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('/home/aurora/workspaces/PycharmProjects/tensorflow/'
                                           'L_SoftMax_TensorFlow/logs', graph=sess.graph)
    for i in range(20):
        for j in range(600):
            batch = mnist.train.next_batch(FLAGS.batch_size)
            if j % 100 == 0 and i != 0:
                train_accuracy, summary, loss_val = sess.run([acc, merged_summary_op, total_loss],
                                                             feed_dict={mnist_net.images: batch[0], mnist_net.labels: batch[1]})
                print("epoch %d, step %d, training accuracy %g, loss_val %g" % (i, j, train_accuracy, loss_val))
                summary_writer.add_summary(summary, global_step=i)
                summary_writer.flush()
            _, _, loss_val = sess.run([train_op, step_ops, total_loss],
                                      feed_dict={mnist_net.images: batch[0], mnist_net.labels: batch[1]})

    # evaluate on the 10k test images and stash the 2-d features for plotting
    net_vals = np.zeros((10000, 2), dtype=np.float32)
    net_labels = np.zeros((10000, 1), dtype=np.int32)
    total_acc = 0
    for i in range(100):
        batch = mnist.test.next_batch(FLAGS.batch_size)
        hidden_net, acc_val = sess.run([net, acc], feed_dict={mnist_net.images: batch[0], mnist_net.labels: batch[1]})
        total_acc += acc_val
        net_vals[i*FLAGS.batch_size:(i+1)*FLAGS.batch_size, :] = hidden_net
        labels = batch[1]
        labels = labels[:, np.newaxis]
        net_labels[i*FLAGS.batch_size:(i+1)*FLAGS.batch_size, :] = labels
    np.save(FLAGS.save_dir+'hidden_m4', net_vals)
    np.save(FLAGS.save_dir+'labels_m4', net_labels)
    print("test accuracy %g" % (total_acc/100))
    # test accuracy: m=1 0.9811, m=2 0.982, m=3 0.986, m=4 0.9846,
    # m=5 0.9869, m=6 0.9874, m=7 0.889, m=8 0.7902
--------------------------------------------------------------------------------
/nets/vis_results.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt

DATA_DIR = '/home/aurora/workspaces/PycharmProjects/tensorflow/L_SoftMax_TensorFlow/data/'
# test accuracies recorded in mnist_lsoftmax.py for each margin m
TEST_ACC = {1: 0.9811, 2: 0.982, 3: 0.986, 4: 0.9846,
            5: 0.9869, 6: 0.9874, 7: 0.889, 8: 0.7902}
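# Each hidden_m<N>.npy / labels_m<N>.npy pair below is produced by a separate
# run of mnist_lsoftmax.py with `margin` in l_softmax.py set to N.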
if __name__ == '__main__':
    plt.set_cmap('hsv')
    for m in range(1, 9):
        hidden = np.load(DATA_DIR + 'hidden_m%d.npy' % m)
        labels = np.load(DATA_DIR + 'labels_m%d.npy' % m)
        label = 'm=%d, test_acc=%g' % (m, TEST_ACC[m])
        if m > 1:
            label += ', \n beta=100, scale=0.99'
        plt.subplot(2, 4, m)
        handle = plt.scatter(hidden[:, 0], hidden[:, 1], c=labels, label=label)
        plt.legend(handles=[handle], bbox_to_anchor=(0., 1.02, 1., .102),
                   loc=3, mode="expand", borderaxespad=0.)
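    # per the paper, a larger margin m enforces larger angular separation between
    # classes in the 2-d features, at the cost of harder optimization (note the
    # recorded accuracy drop for m=7 and m=8)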
    plt.show()
--------------------------------------------------------------------------------