├── .gitattributes ├── README.md ├── dataset_mini.py ├── main.py ├── model.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Attentive Weights Generation for Few Shot Learning via Information Maximization 2 | 3 | Published at CVPR 2020 4 | 5 | By Yiluan Guo and Ngai-Man Cheung 6 | 7 | [Paper Link](http://openaccess.thecvf.com/content_CVPR_2020/papers/Guo_Attentive_Weights_Generation_for_Few_Shot_Learning_via_Information_Maximization_CVPR_2020_paper.pdf) 8 | 9 | The implementation is written in Python 3 and has been tested with TensorFlow 1.12.0 on Ubuntu 16.04. 10 | 11 | Parts of the code are borrowed from [LEO](https://github.com/deepmind/leo). 12 | 13 | The feature embeddings for miniImageNet and tieredImageNet can be downloaded from https://github.com/deepmind/leo. 14 | 15 | To run the 5-way 1-shot experiment on miniImageNet: 16 | 17 | `python main.py` 18 | 19 | 20 | Hyper-parameters (e.g. `--num_way`, `--num_shot`, `--data_set`) can be tuned in `main.py`; the AWGIM model is implemented in `model.py`. 21 | 22 | 23 | ### Citation 24 | Please cite our work if you find it useful in your research: 25 | ``` 26 | @inproceedings{guo2020awgim, 27 | title = {Attentive Weights Generation for Few Shot Learning via Information Maximization}, 28 | author = {Yiluan Guo and Ngai-Man Cheung}, 29 | booktitle = {CVPR}, 30 | year = {2020} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /dataset_mini.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import pickle as pkl 4 | 5 | 6 | class dataset_mini(object): 7 | def __init__(self, split, args): 8 | self.split = split 9 | self.seed = args.seed 10 | self.root_dir = 'data/miniImagenet' 11 | 12 | def load_data_pkl(self): 13 | pkl_name = '{}/{}_embeddings.pkl'.format(self.root_dir, self.split) 14 | f = open(pkl_name, 'rb') 15 | data = pkl.load(f, encoding='latin1') 16 | f.close() 17 | self.data = data 18 | self.n_classes = data.shape[0] 19 | print('labeled data:', np.shape(self.data), self.n_classes) 20 | 21 | def next_data(self, n_way, n_shot, n_query): 22 | support = np.zeros([n_way, n_shot, 640], dtype=np.float32) 23 | query = np.zeros([n_way, n_query, 640], dtype=np.float32) 24 | selected_classes = np.random.permutation(self.n_classes)[:n_way] 25 | for i, cls in enumerate(selected_classes): # train way 26 | idx1 = np.random.permutation(600)[:n_shot + n_query] 27 | support[i] = self.data[cls, idx1[:n_shot]] 28 | query[i] = self.data[cls, idx1[n_shot:]] 29 | 30 | support_labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_shot)).astype(np.uint8) 31 | query_labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.uint8) 32 | 33 | query = np.reshape(query, (n_way * n_query, 640)) 34 | query_labels = np.reshape(query_labels, (n_way * n_query)) 35 | support = np.reshape(support, (n_way * n_shot, 640)) 36 | support_labels = np.reshape(support_labels, (n_way * n_shot)) 37 | return support, support_labels, query, query_labels 38 | 39 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 
| import os 4 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 5 | import numpy as np 6 | import dataset_mini 7 | import dataset_tiered 8 | import model 9 | import random 10 | import utils 11 | 12 | parser = argparse.ArgumentParser(description="Few Shot classification") 13 | parser.add_argument('-shot', '--num_shot', type=int, default=1) 14 | parser.add_argument('-way', '--num_way', type=int, default=5) 15 | parser.add_argument('-q', '--num_query', type=int, default=15) 16 | parser.add_argument('-stage', '--stage', type=str, default='train') 17 | parser.add_argument('-sd', '--seed', type=int, default=1000) 18 | 19 | parser.add_argument('-gt', '--gradient_threshold', type=float, default=0.1) 20 | parser.add_argument('-gnt', '--gradient_norm_threshold', type=float, default=0.1) 21 | parser.add_argument("-drop", "--dropout", type=float, default=0.3) 22 | parser.add_argument('-step', '--step_size', type=int, default=15000) 23 | parser.add_argument('-e', '--epoch', type=int, default=500) 24 | parser.add_argument('-lr', '--learning_rate', type=float, default=2e-4) 25 | parser.add_argument("-weight", "--weight_decay", type=float, default=1e-6) 26 | 27 | parser.add_argument('-a1', '--alpha_1', type=float, default=1.) 28 | parser.add_argument('-a2', '--alpha_2', type=float, default=0.001) 29 | parser.add_argument('-a3', '--alpha_3', type=float, default=0.001) 30 | parser.add_argument('-sh', '--shuffle', type=int, default=0) 31 | 32 | parser.add_argument('-b', '--batch_size', type=int, default=64) 33 | parser.add_argument("-g", "--gpu", type=int, default=1) 34 | parser.add_argument('-md', '--more_data', type=bool, default=False) 35 | parser.add_argument('-ds', '--data_set', type=str, default='mini') 36 | parser.add_argument('-dl', '--dim_latent', type=int, default=128) 37 | parser.add_argument('-ms', '--mlp_size', type=int, default=2) 38 | 39 | args = parser.parse_args() 40 | tf.logging.set_verbosity(tf.logging.ERROR) 41 | tf.set_random_seed(args.seed) 42 | np.random.seed(args.seed) 43 | random.seed(args.seed) 44 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 45 | batch_size = args.batch_size 46 | 47 | 48 | def main(): 49 | is_training = tf.placeholder(tf.bool, name='is_training') 50 | num_class = args.num_way 51 | num_shot = args.num_shot 52 | num_query = args.num_query 53 | 54 | keep_prob = tf.placeholder(tf.float32, name='keep_prob') 55 | support_label = tf.placeholder(tf.int32, (None, ), 'support_label') 56 | query_label = tf.placeholder(tf.int32, (None, ), 'query_label') 57 | 58 | support_x = tf.placeholder(tf.float32, (None, 640), 'support_x') 59 | query_x = tf.placeholder(tf.float32, (None, 640), 'query_x') 60 | 61 | support_feature = support_x 62 | query_feature = query_x 63 | support_feature = tf.reshape(support_feature, (batch_size, num_class, num_shot, 640)) 64 | query_feature = tf.reshape(query_feature, (batch_size, num_class, num_query, 640)) 65 | support_label_reshape = tf.reshape(support_label, (batch_size, num_class, num_shot)) 66 | query_label_reshape = tf.reshape(query_label, (batch_size, num_class, num_query)) 67 | 68 | awgim = model.AWGIM(args, keep_prob, is_training) 69 | loss_cls, accuracy, tr_loss, tr_accuracy, support_reconstruction, query_reconstruction = \ 70 | awgim.forward(support_feature, support_label_reshape, query_feature, query_label_reshape) 71 | reg_term = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'kernel' in v.name]) 72 | loss_meta = loss_cls + args.alpha_1 * tr_loss + args.alpha_2 * support_reconstruction + args.alpha_3 * 
query_reconstruction 73 | Batch = tf.Variable(0, trainable=False, dtype=tf.float32, name='global_step') 74 | learning_rate = tf.train.exponential_decay(learning_rate=args.learning_rate, global_step=Batch, 75 | decay_steps=args.step_size, decay_rate=0.2, staircase=True) 76 | optim = tf.contrib.opt.AdamWOptimizer(learning_rate=learning_rate, weight_decay=args.weight_decay) 77 | meta_weights = [v for v in tf.trainable_variables()] 78 | print(meta_weights) 79 | 80 | if args.stage == 'train': 81 | meta_gradients = utils.grads_and_vars(loss_meta, meta_weights, reg_term) 82 | meta_gradients = utils.clip_gradients(meta_gradients, args.gradient_threshold, args.gradient_norm_threshold) 83 | train_op = optim.apply_gradients(zip(meta_gradients, meta_weights), global_step=Batch) 84 | 85 | config = tf.ConfigProto() 86 | config.gpu_options.allow_growth = True 87 | sess = tf.Session(config=config) 88 | saver = tf.train.Saver() 89 | sess.run(tf.global_variables_initializer()) 90 | 91 | save_path = utils.save(args) 92 | print(save_path) 93 | os.makedirs(save_path, exist_ok=True) 94 | if args.stage == 'test': 95 | print(tf.train.latest_checkpoint(save_path)) 96 | saver.restore(sess, tf.train.latest_checkpoint(save_path)) 97 | print('load model') 98 | if args.data_set == 'mini': 99 | loader_train = dataset_mini.dataset_mini('train', args) 100 | loader_val = dataset_mini.dataset_mini('val', args) 101 | loader_test = dataset_mini.dataset_mini('test', args) 102 | else: 103 | loader_train = dataset_tiered.dataset_tiered('train', args) 104 | loader_val = dataset_tiered.dataset_tiered('val', args) 105 | loader_test = dataset_tiered.dataset_tiered('test', args) 106 | 107 | if args.stage == 'train': 108 | print('Load PKL data') 109 | loader_train.load_data_pkl() 110 | loader_val.load_data_pkl() 111 | else: 112 | loader_test.load_data_pkl() 113 | 114 | val_best_accuracy = 0. 115 | n_iter = 0 116 | record_val_acc = [] 117 | if args.stage == 'train': 118 | for epoch in range(args.epoch): 119 | training_accuracy, training_loss, acc_cp, acc_real, c_loss, d_loss, g_loss = [], [], [], [], [], [], [] 120 | # training_loss_cls = [] 121 | for epi in range(100): 122 | support_input, s_labels, query_input, q_labels = utils.load_batch(args, loader_train, args.batch_size, True, loader_val) 123 | feed_dict = {support_x: support_input, support_label: s_labels, 124 | query_x: query_input, query_label: q_labels, 125 | is_training: True, keep_prob: 1. 
- args.dropout} 126 | outs = sess.run([train_op, loss_meta, accuracy, Batch], feed_dict=feed_dict) 127 | training_accuracy.append(outs[2]) 128 | training_loss.append(outs[1]) 129 | n_iter += 1 130 | if (epoch+1) % 3 == 0: 131 | log = 'epoch: ', epoch+1, 'accuracy: ', np.mean(training_accuracy), 'loss: ', np.mean(training_loss) 132 | print(log) 133 | if (epoch+1) % 3 == 0: 134 | accuracy_val = [] 135 | loss_val = [] 136 | for epi in range(100): 137 | support_input, s_labels, query_input, q_labels = utils.load_batch(args, loader_val, args.batch_size, training=False) 138 | outs = sess.run([loss_meta, accuracy, Batch], feed_dict={support_x: support_input, support_label: s_labels, 139 | query_x: query_input, query_label: q_labels, 140 | is_training: False, keep_prob: 1.}) 141 | accuracy_val.append(outs[1]) 142 | loss_val.append(outs[0]) 143 | mean_acc = np.mean(accuracy_val) 144 | std_acc = np.std(accuracy_val) 145 | ci95 = 1.96 * std_acc / np.sqrt(100) 146 | print(' Val Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format(mean_acc, std_acc, ci95), 'at epoch: ', epoch+1) 147 | record_val_acc.append(mean_acc) 148 | if mean_acc > val_best_accuracy: 149 | val_best_accuracy = mean_acc 150 | saver.save(sess, save_path=save_path + 'model.ckpt', global_step=Batch) 151 | if (epoch + 1) % 100 == 0: 152 | saver.save(sess, save_path=save_path + 'model.ckpt', global_step=Batch) 153 | elif args.stage == 'test': 154 | accuracy_test = [] 155 | loss_test = [] 156 | num = 600 157 | for epi in range(num): 158 | support_input, s_labels, query_input, q_labels = utils.load_batch(args, loader_test, args.batch_size, False) 159 | outs = sess.run([loss_meta, accuracy], 160 | feed_dict={support_x: support_input, support_label: s_labels, 161 | query_x: query_input, query_label: q_labels, 162 | is_training: False, keep_prob: 1.}) 163 | accuracy_test.append(outs[1]) 164 | loss_test.append(outs[0]) 165 | mean_acc = np.mean(accuracy_test) 166 | std_acc = np.std(accuracy_test) 167 | ci95 = 1.96 * std_acc / np.sqrt(num) 168 | print('Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format(mean_acc, std_acc, ci95)) 169 | 170 | sess.close() 171 | 172 | 173 | if __name__ == '__main__': 174 | main() -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow_probability as tfp 3 | 4 | 5 | class AWGIM: 6 | def __init__(self, args, keep_prob, is_training): 7 | self.dim_latent = args.dim_latent 8 | self._l2_penalty_weight = args.weight_decay 9 | self._float_dtype = tf.float32 10 | self.initializer = tf.initializers.glorot_uniform(dtype=self._float_dtype) 11 | self.keep_prob = keep_prob 12 | self.is_training = is_training 13 | self.num_shot = args.num_shot 14 | self.num_class = args.num_way 15 | self.num_query = args.num_query 16 | self.embedding_dim = 640 17 | self.random_sample = self.sample 18 | self.mlp_size = [2*self.dim_latent] * (args.mlp_size-1) 19 | self.b_size = args.batch_size 20 | self.shuffle = args.shuffle 21 | 22 | def sample(self, distribution_params, stddev_offset=0.): 23 | means, unnormalized_stddev = tf.split(distribution_params, 2, axis=-1) 24 | stddev = tf.exp(unnormalized_stddev) 25 | stddev -= (1. 
- stddev_offset) 26 | stddev = tf.maximum(stddev, 1e-10) 27 | distribution = tfp.distributions.Normal(loc=means, scale=stddev) 28 | samples = distribution.sample() 29 | return tf.cond(self.is_training, false_fn=lambda: (means, means, stddev), 30 | true_fn=lambda: (samples, means, stddev)) 31 | 32 | def support_loss(self, data, label, cls_weights): 33 | cls_weights = tf.reshape(cls_weights, (self.b_size, self.num_class*self.num_query, self.num_class, self.num_shot, self.embedding_dim)) 34 | cls_weights = tf.reduce_mean(cls_weights, 3) 35 | data = tf.reshape(data, (self.b_size, 1, self.num_class*self.num_shot, self.embedding_dim)) 36 | data = tf.tile(data, (1, self.num_query*self.num_class, 1, 1)) 37 | after_dropout = tf.nn.dropout(data, keep_prob=self.keep_prob) 38 | logits = tf.einsum('bqsp,bqcp->bqsc', after_dropout, cls_weights) 39 | 40 | logits = tf.reshape(logits, (self.b_size, self.num_class*self.num_query, self.num_class*self.num_shot, self.num_class)) 41 | label = tf.tile(tf.reshape(label, (self.b_size, 1, self.num_class*self.num_shot)), (1, self.num_class*self.num_query, 1)) 42 | one_hot_label = tf.one_hot(label, self.num_class) 43 | loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(one_hot_label), logits=logits, dim=-1)) 44 | pred = tf.nn.softmax(logits) 45 | accuracy = tf.contrib.metrics.accuracy(tf.argmax(pred, -1, output_type=tf.int32), label) 46 | entropy = tf.reduce_mean(tf.reduce_sum(-1. * tf.multiply(pred, tf.log(pred + 1e-6)), -1)) 47 | return loss, accuracy, entropy, pred 48 | 49 | def query_loss(self, data, label, cls_weights): 50 | data = tf.reshape(data, (self.b_size, self.num_class*self.num_query, 1, self.embedding_dim)) 51 | cls_weights = tf.reshape(cls_weights, (self.b_size, self.num_class*self.num_query, self.num_class, self.num_shot, self.embedding_dim)) 52 | cls_weights = tf.reduce_mean(cls_weights, 3) 53 | if self.shuffle != 0: 54 | cls_weights = tf.reshape(cls_weights, (self.b_size, self.num_class, self.num_query, self.num_class, self.embedding_dim)) 55 | cls_weights = tf.transpose(cls_weights, (1, 0, 2, 3, 4) if self.shuffle==1 else (2, 1, 0, 3, 4)) 56 | cls_weights = tf.random_shuffle(cls_weights) 57 | cls_weights = tf.transpose(cls_weights, (1, 0, 2, 3, 4) if self.shuffle==1 else (2, 1, 0, 3, 4)) 58 | cls_weights = tf.reshape(cls_weights, (self.b_size, self.num_class*self.num_query, self.num_class, self.embedding_dim)) 59 | 60 | after_dropout = tf.nn.dropout(data, keep_prob=self.keep_prob) 61 | logits = tf.einsum('bqip,bqcp->bqic', after_dropout, cls_weights) 62 | logits = tf.squeeze(logits) 63 | logits = tf.reshape(logits, (self.b_size, self.num_class, self.num_query, self.num_class)) 64 | loss = self.loss_fn(logits, label) 65 | pred = tf.nn.softmax(logits) 66 | accuracy = tf.contrib.metrics.accuracy(tf.argmax(pred, -1, output_type=tf.int32), label) 67 | entropy = tf.reduce_mean(tf.reduce_sum(-1. 
* tf.multiply(pred, tf.log(pred + 1e-6)), -1)) 68 | return loss, accuracy, entropy, pred 69 | 70 | def loss_fn(self, logits, label): 71 | one_hot_label = tf.one_hot(label, self.num_class) 72 | return tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=tf.stop_gradient(one_hot_label), logits=logits, dim=-1)) 73 | 74 | def dot_product_attention(self, q, k, v, normalise): 75 | d_k = tf.shape(q)[-1] 76 | scale = tf.sqrt(tf.cast(d_k, tf.float32)) 77 | unnorm_weights = tf.einsum('bjk,bik->bij', k, q) / scale 78 | if normalise: 79 | weight_fn = tf.nn.softmax 80 | else: 81 | weight_fn = tf.sigmoid 82 | weights = weight_fn(unnorm_weights) 83 | rep = tf.einsum('bik,bkj->bij', weights, v) 84 | return rep 85 | 86 | def mlp(self, input, output_sizes, name): 87 | output = tf.nn.dropout(input, self.keep_prob) 88 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 89 | for i, size in enumerate(output_sizes[:-1]): 90 | output = tf.nn.relu(tf.layers.dense(output, size, name="layer_{}".format(i), use_bias=False)) 91 | # Last layer without a ReLu 92 | output = tf.layers.dense(output, output_sizes[-1], name="layer_out", use_bias=False) 93 | return output 94 | 95 | def multihead_attention(self, q, k, v, name, num_heads=4): 96 | d_k = q.get_shape().as_list()[-1] 97 | d_v = v.get_shape().as_list()[-1] 98 | head_size = d_v / num_heads 99 | key_initializer = tf.random_normal_initializer(stddev=d_k ** -0.5) 100 | value_initializer = tf.random_normal_initializer(stddev=d_v ** -0.5) 101 | rep = tf.constant(0.0) 102 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE): 103 | for h in range(num_heads): 104 | o = self.dot_product_attention( 105 | tf.layers.dense(q, head_size, kernel_initializer=key_initializer, use_bias=False, name='wq%d' % h), 106 | tf.layers.dense(k, head_size, kernel_initializer=key_initializer, use_bias=False, name='wk%d' % h), 107 | tf.layers.dense(v, head_size, kernel_initializer=key_initializer, use_bias=False, name='wv%d' % h), 108 | normalise=True) 109 | rep += tf.layers.dense(o, d_v, kernel_initializer=value_initializer, use_bias=False, name='wo%d' % h) 110 | rep += q 111 | rep = tf.contrib.layers.layer_norm(rep, 2) 112 | rep += self.mlp(rep, [2*d_v, d_v], name+'ln_mlp') 113 | rep = tf.contrib.layers.layer_norm(rep, 2) 114 | return rep 115 | 116 | def forward(self, tr_data, tr_label, val_data, val_label): 117 | # tr_data is b x c x k x p, tr_lable is b x c x k, val_data is b x c x q x p, val_label is b x c x q 118 | fan_in = tf.cast(self.embedding_dim, self._float_dtype) 119 | fan_out = tf.cast(self.num_class, self._float_dtype) 120 | stddev_offset = tf.sqrt(2. 
/ (fan_out + fan_in)) 121 | 122 | # attentive path 123 | support = self.mlp(tr_data, [self.dim_latent], 'CA_encoder') 124 | query = self.mlp(val_data, [self.dim_latent], 'CA_encoder') 125 | key_support = tf.reshape(support, (self.b_size, self.num_class*self.num_shot, self.dim_latent)) 126 | query_query = tf.reshape(query, (self.b_size, self.num_class*self.num_query, self.dim_latent)) 127 | value_context = tf.reshape(support, (self.b_size, self.num_class * self.num_shot, self.dim_latent)) 128 | value_context = self.multihead_attention(value_context, value_context, value_context, 'context_CA') 129 | value_context = tf.reshape(value_context, (self.b_size, self.num_class, self.num_shot, self.dim_latent)) 130 | value_context = tf.reduce_mean(value_context, 2, True) 131 | value_context = tf.tile(value_context, (1, 1, self.num_shot, 1)) 132 | value_context = tf.reshape(value_context, (self.b_size, self.num_class*self.num_shot, self.dim_latent)) 133 | 134 | query_ca = self.multihead_attention(query_query, key_support, value_context, 'query_CA') 135 | query_ca_code = tf.tile(tf.reshape(query_ca, (self.b_size, self.num_class * self.num_query, 1, self.dim_latent)), 136 | (1, 1, self.num_shot * self.num_class, 1)) 137 | 138 | # contextual path 139 | support = self.mlp(tr_data, [self.dim_latent], 'SA_encoder') 140 | context = tf.reshape(support, (self.b_size, self.num_class * self.num_shot, self.dim_latent)) 141 | context = self.multihead_attention(context, context, context, 'context_SA') 142 | context_code = tf.reshape(context, (self.b_size, self.num_class, self.num_shot, self.dim_latent)) 143 | context_code = tf.reduce_mean(context_code, 2, True) 144 | context_code = tf.tile(context_code, (1, 1, self.num_shot, 1)) 145 | context_code = tf.tile(tf.reshape(context_code, (self.b_size, 1, self.num_class * self.num_shot, self.dim_latent)), 146 | (1, self.num_class * self.num_query, 1, 1)) 147 | concat = tf.concat((context_code, query_ca_code), 3) 148 | # concat is b x cq x ck x d 149 | decoder_name = 'decoder' 150 | reconstruct_mlp = self.mlp_size 151 | weights_dist_params = self.mlp(concat, self.mlp_size + [2*self.embedding_dim], decoder_name+'mlp_weight') 152 | classifier_weights, mu, sigma = self.random_sample(weights_dist_params, stddev_offset=stddev_offset) 153 | reconstructed_query = self.mlp(classifier_weights, reconstruct_mlp + [self.dim_latent], 154 | 'recontructed_q') 155 | reconstructed_support = self.mlp(classifier_weights, reconstruct_mlp + [self.dim_latent], 156 | 'recontructed_s') 157 | 158 | tr_loss, tr_accuracy, entropy_support, pred_support = self.support_loss(tr_data, tr_label, classifier_weights) 159 | val_loss, val_accuracy, entropy_query, pred_query = self.query_loss(val_data, val_label, classifier_weights) 160 | 161 | loss_support = tf.reduce_mean(tf.reduce_sum(tf.square(tf.stop_gradient(context_code) - reconstructed_support), -1)) 162 | loss_query = tf.reduce_mean(tf.reduce_sum(tf.square(tf.stop_gradient(query_ca_code) - reconstructed_query), -1)) 163 | return val_loss, val_accuracy, tr_loss, tr_accuracy, loss_support, loss_query 164 | 165 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | def save(args): 7 | if args.data_set == 'mini': 8 | save_path = 'saved_models_mini/' 9 | else: 10 | save_path = 'saved_models_tiered/' 11 | save_path += 'AWGIM' + str(args.num_way) + '_way_' + 
str(args.num_shot) + 'shot' 12 | save_path += '_wd' + str(args.weight_decay) + '_dl' + str(args.dim_latent) + '_lr' + str(args.learning_rate) 13 | save_path += '_as' + str(args.alpha_1) 14 | save_path += '_ar' + str(args.alpha_1) + str(args.alpha_2) 15 | save_path += '/' 16 | return save_path 17 | 18 | 19 | def load_batch(args, loader, b, training=True, loader_b=None): 20 | support_input_total, s_labels_total, query_input_total, q_labels_total = [], [], [], [] 21 | for i in range(b): 22 | if training: 23 | if args.more_data: 24 | a = random.uniform(0., (5. if args.data_set=='mini' else 4.62)) 25 | if a <= (4. if args.data_set=='mini' else 3.62): 26 | support_input, s_labels, query_input, q_labels = loader.next_data(args.num_way, args.num_shot, args.num_query) 27 | else: 28 | support_input, s_labels, query_input, q_labels = loader_b.next_data(args.num_way, args.num_shot, args.num_query) 29 | else: 30 | support_input, s_labels, query_input, q_labels = loader.next_data(args.num_way, args.num_shot, args.num_query) 31 | else: 32 | support_input, s_labels, query_input, q_labels = loader.next_data(args.num_way, args.num_shot, args.num_query) 33 | support_input_total.append(support_input) 34 | s_labels_total.append(s_labels) 35 | query_input_total.append(query_input) 36 | q_labels_total.append(q_labels) 37 | support_input = np.concatenate(support_input_total, axis=0) 38 | s_labels = np.concatenate(s_labels_total, axis=0) 39 | query_input = np.concatenate(query_input_total, axis=0) 40 | q_labels = np.concatenate(q_labels_total, axis=0) 41 | return support_input, s_labels, query_input, q_labels 42 | 43 | 44 | def clip_gradients(gradients, gradient_threshold, gradient_norm_threshold): 45 | if gradient_threshold > 0: 46 | gradients = [ 47 | tf.clip_by_value(g, -gradient_threshold, gradient_threshold) 48 | for g in gradients 49 | ] 50 | if gradient_norm_threshold > 0: 51 | gradients = [ 52 | tf.clip_by_norm(g, gradient_norm_threshold) for g in gradients 53 | ] 54 | return gradients 55 | 56 | 57 | def grads_and_vars(metatrain_loss, weights, reg_term): 58 | """Computes gradients of metatrain_loss, avoiding NaN. 59 | 60 | Uses a fixed penalty of 1e-4 to enforce only the l2 regularization (and not 61 | minimize the loss) when metatrain_loss or any of its gradients with respect 62 | to the weights are NaN. In practice, this approach pulls the variables 63 | back into a feasible region of the space when the loss or its gradients are 64 | not defined. 65 | 66 | Args: 67 | metatrain_loss: A tensor with the meta-training loss. 68 | weights: A list of trainable variables to differentiate with respect to. 69 | reg_term: A tensor with the l2 regularization term on the kernel weights. 70 | 71 | Returns: 72 | metatrain_gradients: A list of gradient tensors, one per variable in weights. 73 | """ 74 | metatrain_variables = weights 75 | metatrain_gradients = tf.gradients(metatrain_loss, metatrain_variables) 76 | 77 | nan_loss_or_grad = tf.logical_or( 78 | tf.is_nan(metatrain_loss), 79 | tf.reduce_any([tf.reduce_any(tf.is_nan(g)) 80 | for g in metatrain_gradients])) 81 | 82 | regularization_penalty = (1e-4 * reg_term) 83 | # Pair each variable with the gradient of the regularization penalty; use zeros where that gradient is undefined. 84 | zero_or_regularization_gradients = [ 85 | g if g is not None else tf.zeros_like(v) 86 | for v, g in zip(metatrain_variables, tf.gradients(regularization_penalty, metatrain_variables))] 87 | 88 | metatrain_gradients = tf.cond(nan_loss_or_grad, 89 | lambda: zero_or_regularization_gradients, 90 | lambda: metatrain_gradients, strict=True) 91 | 92 | return metatrain_gradients --------------------------------------------------------------------------------