├── README.md ├── checkpoints ├── mini_TPN_5w1s_5tw1ts_rn300_k20 │ └── models │ │ ├── ckpt-81500.data-00000-of-00001 │ │ ├── ckpt-81500.index │ │ └── ckpt-81500.meta └── mini_TPN_5w5s_5tw5ts_rn300_k20 │ └── models │ ├── ckpt-50100.data-00000-of-00001 │ ├── ckpt-50100.index │ └── ckpt-50100.meta ├── dataset_mini.py ├── dataset_tiered.py ├── models.py ├── test.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Transductive Propagation Network 2 | Code for ICLR19 paper: 3 | *Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning.* [pdf](https://openreview.net/pdf?id=SyVuRiC5K7) 4 | 5 | ## Pytorch Version 6 | https://github.com/csyanbin/TPN-pytorch 7 | 8 | ## Requirements 9 | * Python 3.5 10 | * Tensorflow 1.3+ 11 | * tqdm 12 | 13 | 14 | ## Data Download (miniImagenet and tieredImagenet) 15 | Please download the compressed tar files from: https://github.com/renmengye/few-shot-ssl-public 16 | 17 | ``` 18 | mkdir -p data/miniImagenet/data 19 | tar -zxvf mini-imagenet.tar.gz 20 | mv *.pkl data/miniImagenet/data 21 | 22 | mkdir -p data/tieredImagenet/data 23 | tar -xvf tiered-imagenet.tar 24 | mv *.pkl data/tieredImagenet/data 25 | 26 | ``` 27 | 28 | ## TPN mini-5way1shot 29 | ``` 30 | python train.py --gpu=0 --n_way=5 --n_shot=1 --n_test_way=5 --n_test_shot=1 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w1s_5tw1ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 31 | ``` 32 | 33 | ``` 34 | python test.py --gpu=0 --n_way=5 --n_shot=1 --n_test_way=5 --n_test_shot=1 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w1s_5tw1ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 --iters=81500 35 | ``` 36 | 37 | ## TPN mini-5way5shot 38 | ``` 39 | python train.py --gpu=0 --n_way=5 --n_shot=5 --n_test_way=5 --n_test_shot=5 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w5s_5tw5ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 40 | ``` 41 | 42 | ``` 43 | python test.py --gpu=0 --n_way=5 --n_shot=5 --n_test_way=5 --n_test_shot=5 --lr=0.001 --step_size=10000 --dataset=mini --exp_name=mini_TPN_5w5s_5tw5ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 --iters=50100 44 | 45 | ``` 46 | 47 | ## TPN tiered-5way1shot 48 | ``` 49 | python train.py --gpu=0 --n_way=5 --n_shot=1 --n_test_way=5 --n_test_shot=1 --lr=0.001 --step_size=25000 --dataset=tiered --exp_name=tiered_TPN_5w1s_5tw1ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 50 | ``` 51 | 52 | ## TPN tiered-5way5shot 53 | ``` 54 | python train.py --gpu=0 --n_way=5 --n_shot=5 --n_test_way=5 --n_test_shot=5 --lr=0.001 --step_size=25000 --dataset=tiered --exp_name=tiered_TPN_5w5s_5tw5ts_rn300_k20 --rn=300 --alpha=0.99 --k=20 55 | ``` 56 | 57 | 58 | ## Citation 59 | If you use our code, please consider to cite the following paper: 60 | * Yanbin Liu, Juho Lee, Minseop Park, Saehoon Kim, Eunho Yang, Sungju Hwang, Yi Yang. Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning. In *Proceedings of 7th International Conference on Learning Representations (ICLR)*, 2019. 61 | 62 | ``` 63 | 64 | @inproceedings{liu2019fewTPN, 65 | title={Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning}, 66 | author={Yanbin Liu and 67 | Juho Lee and 68 | Minseop Park and 69 | Saehoon Kim and 70 | Eunho Yang and 71 | Sungju Hwang and 72 | Yi Yang}, 73 | booktitle={International Conference on Learning Representations}, 74 | year={2019}, 75 | } 76 | 77 | ``` 78 | 79 | -------------------------------------------------------------------------------- /checkpoints/mini_TPN_5w1s_5tw1ts_rn300_k20/models/ckpt-81500.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csyanbin/TPN/3bc9eeaa1d4883ee237c9811cf9eb7c33bfadd24/checkpoints/mini_TPN_5w1s_5tw1ts_rn300_k20/models/ckpt-81500.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoints/mini_TPN_5w1s_5tw1ts_rn300_k20/models/ckpt-81500.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csyanbin/TPN/3bc9eeaa1d4883ee237c9811cf9eb7c33bfadd24/checkpoints/mini_TPN_5w1s_5tw1ts_rn300_k20/models/ckpt-81500.index -------------------------------------------------------------------------------- /checkpoints/mini_TPN_5w1s_5tw1ts_rn300_k20/models/ckpt-81500.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csyanbin/TPN/3bc9eeaa1d4883ee237c9811cf9eb7c33bfadd24/checkpoints/mini_TPN_5w1s_5tw1ts_rn300_k20/models/ckpt-81500.meta -------------------------------------------------------------------------------- /checkpoints/mini_TPN_5w5s_5tw5ts_rn300_k20/models/ckpt-50100.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csyanbin/TPN/3bc9eeaa1d4883ee237c9811cf9eb7c33bfadd24/checkpoints/mini_TPN_5w5s_5tw5ts_rn300_k20/models/ckpt-50100.data-00000-of-00001 -------------------------------------------------------------------------------- /checkpoints/mini_TPN_5w5s_5tw5ts_rn300_k20/models/ckpt-50100.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csyanbin/TPN/3bc9eeaa1d4883ee237c9811cf9eb7c33bfadd24/checkpoints/mini_TPN_5w5s_5tw5ts_rn300_k20/models/ckpt-50100.index -------------------------------------------------------------------------------- /checkpoints/mini_TPN_5w5s_5tw5ts_rn300_k20/models/ckpt-50100.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csyanbin/TPN/3bc9eeaa1d4883ee237c9811cf9eb7c33bfadd24/checkpoints/mini_TPN_5w5s_5tw5ts_rn300_k20/models/ckpt-50100.meta -------------------------------------------------------------------------------- /dataset_mini.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from PIL import Image 4 | import pickle as pkl 5 | import os 6 | import glob 7 | import csv 8 | from scipy.ndimage import imread 9 | from scipy.misc import imresize 10 | 11 | class dataset_mini(object): 12 | def __init__(self, n_examples, n_episodes, split, args): 13 | self.im_width, self.im_height, self.channels = list(map(int, args['x_dim'].split(','))) 14 | self.n_examples = n_examples 15 | self.n_episodes = n_episodes 16 | self.split = split 17 | self.ratio = args['ratio'] 18 | self.seed = args['seed'] 19 | self.root_dir = './data/miniImagenet' 20 | 21 | self.n_label = int(self.ratio*self.n_examples) 22 | self.n_unlabel = self.n_examples-self.n_label 23 | self.dataset_l = [] 24 | self.dataset_u = [] 25 | 26 | self.args = args 27 | 28 | def load_data(self): 29 | """ 30 | Load data into memory and partition into label,unlabel 31 | """ 32 | print('Loading {} dataset'.format(self.split)) 33 | data_split_path = os.path.join(self.root_dir, 'splits', '{}.csv'.format(self.split)) 34 | with open(data_split_path,'r') as f: 35 | reader = csv.reader(f, delimiter=',') 36 | data_classes = {} 37 | for i,row in enumerate(reader): 38 | if i==0: 39 | continue 40 | data_classes[row[1]] = 1 41 | data_classes = data_classes.keys() 42 | print(data_classes) 43 | 44 | n_classes = len(data_classes) 45 | print('n_classes:{}, n_label:{}, n_unlabel:{}'.format(n_classes,self.n_label,self.n_unlabel)) 46 | dataset_l = np.zeros([n_classes, self.n_label, self.im_height, self.im_width, self.channels], dtype=np.float32) 47 | if self.n_unlabel>0: 48 | dataset_u = np.zeros([n_classes, self.n_unlabel, self.im_height, self.im_width, self.channels], dtype=np.float32) 49 | else: 50 | dataset_u = [] 51 | 52 | for i, cls in enumerate(data_classes): 53 | im_dir = os.path.join(self.root_dir, 'data/{}/'.format(self.split), cls) 54 | im_files = sorted(glob.glob(os.path.join(im_dir, '*.jpg'))) 55 | np.random.RandomState(self.seed).shuffle(im_files) # fix the seed to keep label,unlabel fixed 56 | for j, im_file in enumerate(im_files): 57 | im = np.array(Image.open(im_file).resize((self.im_width, self.im_height)), 58 | np.float32, copy=False) 59 | #im = np.array(imresize(imread(im_file), (self.im_width,self.im_height,3))) / 255.0 60 | if j0: 97 | dataset_u = np.zeros([n_classes, self.n_unlabel, self.im_height, self.im_width, self.channels], dtype=np.float32) 98 | else: 99 | dataset_u = [] 100 | 101 | for i, cls in enumerate(data_classes): 102 | idxs = class_dict[cls] 103 | np.random.RandomState(self.seed).shuffle(idxs) # fix the seed to keep label,unlabel fixed 104 | dataset_l[i] = image_data[idxs[0:self.n_label]] 105 | if self.n_unlabel>0: 106 | dataset_u[i] = image_data[idxs[self.n_label:]] 107 | print('labeled data:', np.shape(dataset_l)) 108 | print('unlabeled data:', np.shape(dataset_u)) 109 | 110 | self.dataset_l = dataset_l 111 | self.dataset_u = dataset_u 112 | self.n_classes = n_classes 113 | 114 | del image_data 115 | 116 | 117 | def next_data(self, n_way, n_shot, n_query, num_unlabel=0, n_distractor=0, train=True): 118 | """ 119 | get support,query,unlabel data from n_way 120 | get unlabel data from n_distractor 121 | """ 122 | support = np.zeros([n_way, n_shot, self.im_height, self.im_width, self.channels], dtype=np.float32) 123 | query = np.zeros([n_way, n_query, self.im_height, self.im_width, self.channels], dtype=np.float32) 124 | if num_unlabel>0: 125 | unlabel = np.zeros([n_way+n_distractor, num_unlabel, self.im_height, self.im_width, self.channels], dtype=np.float32) 126 | else: 127 | unlabel = [] 128 | n_distractor = 0 129 | 130 | selected_classes = np.random.permutation(self.n_classes)[:n_way+n_distractor] 131 | for i, cls in enumerate(selected_classes[0:n_way]): # train way 132 | # labled data 133 | idx1 = np.random.permutation(self.n_label)[:n_shot + n_query] 134 | support[i] = self.dataset_l[cls, idx1[:n_shot]] 135 | query[i] = self.dataset_l[cls, idx1[n_shot:]] 136 | # unlabel 137 | if num_unlabel>0: 138 | idx2 = np.random.permutation(self.n_unlabel)[:num_unlabel] 139 | unlabel[i] = self.dataset_u[cls,idx2] 140 | 141 | for j,cls in enumerate(selected_classes[self.n_classes:]): # distractor way 142 | idx3 = np.random.permutation(self.n_unlabel)[:num_unlabel] 143 | unlabel[i+j] = self.dataset_u[cls,idx3] 144 | 145 | support_labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_shot)).astype(np.uint8) 146 | query_labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.uint8) 147 | # unlabel_labels = np.tile(np.arange(n_way+n_distractor)[:, np.newaxis], (1, num_unlabel)).astype(np.uint8) 148 | 149 | return support, support_labels, query, query_labels, unlabel 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /dataset_tiered.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | from PIL import Image 4 | import pickle as pkl 5 | import os 6 | import glob 7 | import csv 8 | from tqdm import tqdm 9 | import cv2 10 | 11 | class dataset_tiered(object): 12 | def __init__(self, n_examples, n_episodes, split, args): 13 | self.im_width, self.im_height, self.channels = list(map(int, args['x_dim'].split(','))) 14 | self.n_examples = n_examples 15 | self.n_episodes = n_episodes 16 | self.split = split 17 | self.ratio = args['ratio'] 18 | self.seed = args['seed'] 19 | self.root_dir = './data/tieredImagenet' 20 | 21 | self.iamge_data = [] 22 | self.dict_index_label = [] 23 | self.dict_index_unlabel = [] 24 | 25 | 26 | 27 | def load_data_pkl(self): 28 | """ 29 | load the pkl processed tieredImagenet into label,unlabel 30 | maintain label,unlabel data dictionary for indexes 31 | """ 32 | labels_name = '{}/data/{}_labels.pkl'.format(self.root_dir, self.split) 33 | images_name = '{}/data/{}_images.npz'.format(self.root_dir, self.split) 34 | print('labels:',labels_name) 35 | print('images:',images_name) 36 | 37 | # decompress images if npz not exits 38 | if not os.path.exists(images_name): 39 | png_pkl = images_name[:-4] + '_png.pkl' 40 | if os.path.exists(png_pkl): 41 | decompress(images_name, png_pkl) 42 | else: 43 | raise ValueError('path png_pkl not exits') 44 | 45 | if os.path.exists(images_name) and os.path.exists(labels_name): 46 | try: 47 | with open(labels_name) as f: 48 | data = pkl.load(f) 49 | label_specific = data["label_specific"] 50 | #label_general = data["label_general"] 51 | #label_specific_str = data["label_specific_str"] 52 | #label_general_str = data["label_general_str"] 53 | except: 54 | with open(labels_name, 'rb') as f: 55 | data = pkl.load(f, encoding='bytes') 56 | label_specific = data[b'label_specific'] 57 | #label_general = data[b"label_general"] 58 | #label_specific_str = data[b"label_specific_str"] 59 | #label_general_str = data[b"label_general_str"] 60 | print('read label data:{}'.format(len(label_specific))) 61 | labels = label_specific 62 | 63 | with np.load(images_name, mmap_mode="r", encoding='latin1') as data: 64 | image_data = data["images"] 65 | print('read image data:{}'.format(image_data.shape)) 66 | 67 | n_classes = np.max(labels)+1 68 | 69 | print('n_classes:{}, n_label:{}%, n_unlabel:{}%'.format(n_classes,self.ratio*100,(1-self.ratio)*100)) 70 | dict_index_label = {} # key:label, value:idxs 71 | dict_index_unlabel = {} 72 | 73 | for cls in range(n_classes): 74 | idxs = np.where(labels==cls)[0] 75 | nums = idxs.shape[0] 76 | np.random.RandomState(self.seed).shuffle(idxs) # fix the seed to keep label,unlabel fixed 77 | 78 | n_label = int(self.ratio*nums) 79 | n_unlabel = nums-n_label 80 | 81 | dict_index_label[cls] = idxs[0:n_label] 82 | dict_index_unlabel[cls] = idxs[n_label:] 83 | 84 | self.image_data = image_data 85 | self.dict_index_label = dict_index_label 86 | self.dict_index_unlabel = dict_index_unlabel 87 | self.n_classes = n_classes 88 | print(dict_index_label[0]) 89 | print(dict_index_unlabel[0]) 90 | 91 | 92 | def next_data(self, n_way, n_shot, n_query, num_unlabel=0, n_distractor=0, train=True): 93 | """ 94 | get support,query,unlabel data from n_way 95 | get unlabel data from n_distractor 96 | """ 97 | support = np.zeros([n_way, n_shot, self.im_height, self.im_width, self.channels], dtype=np.float32) 98 | query = np.zeros([n_way, n_query, self.im_height, self.im_width, self.channels], dtype=np.float32) 99 | if num_unlabel>0: 100 | unlabel = np.zeros([n_way+n_distractor, num_unlabel, self.im_height, self.im_width, self.channels], dtype=np.float32) 101 | else: 102 | unlabel = [] 103 | n_distractor = 0 104 | 105 | selected_classes = np.random.permutation(self.n_classes)[:n_way+n_distractor] 106 | for i, cls in enumerate(selected_classes[0:n_way]): # train way 107 | # labled data 108 | idx = self.dict_index_label[cls] 109 | np.random.RandomState().shuffle(idx) 110 | idx1 = idx[0:n_shot + n_query] 111 | support[i] = self.image_data[idx1[:n_shot]] 112 | query[i] = self.image_data[idx1[n_shot:]] 113 | 114 | # unlabel 115 | if num_unlabel>0: 116 | idx = self.dict_index_unlabel[cls] 117 | np.random.RandomState().shuffle(idx) 118 | idx2 = idx[0:num_unlabel] 119 | unlabel[i] = self.image_data[idx2] 120 | 121 | for j,cls in enumerate(selected_classes[self.n_classes:]): # distractor way 122 | idx = self.dict_index_unlabel[cls] 123 | np.random.RandomState().shuffle(idx) 124 | idx3 = idx[0:num_unlabel] 125 | unlabel[i+j] = self.image_data[idx3] 126 | 127 | support_labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_shot)).astype(np.uint8) 128 | query_labels = np.tile(np.arange(n_way)[:, np.newaxis], (1, n_query)).astype(np.uint8) 129 | # unlabel_labels = np.tile(np.arange(n_way+n_distractor)[:, np.newaxis], (1, num_unlabel)).astype(np.uint8) 130 | return support, support_labels, query, query_labels, unlabel 131 | 132 | 133 | 134 | def compress(path, output): 135 | with np.load(path, mmap_mode="r") as data: 136 | images = data["images"] 137 | array = [] 138 | for ii in tqdm(six.moves.xrange(images.shape[0]), desc='compress'): 139 | im = images[ii] 140 | im_str = cv2.imencode('.png', im)[1] 141 | array.append(im_str) 142 | with open(output, 'wb') as f: 143 | pkl.dump(array, f, protocol=pkl.HIGHEST_PROTOCOL) 144 | 145 | 146 | def decompress(path, output): 147 | with open(output, 'rb') as f: 148 | array = pkl.load(f, encoding='bytes') 149 | images = np.zeros([len(array), 84, 84, 3], dtype=np.uint8) 150 | for ii, item in tqdm(enumerate(array), desc='decompress'): 151 | im = cv2.imdecode(item, 1) 152 | images[ii] = im 153 | np.savez(path, images=images) 154 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | import numpy as np 4 | import tensorflow as tf 5 | from tensorflow.contrib.layers.python import layers as tf_layers 6 | 7 | class models(object): 8 | def __init__(self, args): 9 | # parameters 10 | self.im_width, self.im_height, self.channels = list(map(int, args['x_dim'].split(','))) 11 | self.h_dim, self.z_dim = args['h_dim'], args['z_dim'] 12 | self.args = args 13 | 14 | # placeholders for data and label 15 | self.x = tf.placeholder(tf.float32, [None, None, self.im_height, self.im_width, self.channels]) 16 | self.ys = tf.placeholder(tf.int64, [None, None]) 17 | self.q = tf.placeholder(tf.float32, [None, None, self.im_height, self.im_width, self.channels]) 18 | self.y = tf.placeholder(tf.int64, [None, None]) 19 | self.phase = tf.placeholder(tf.bool, name='phase') 20 | 21 | self.alpha = args['alpha'] 22 | 23 | 24 | def conv_block(self, inputs, out_channels, pool_pad='VALID', name='conv'): 25 | with tf.variable_scope(name): 26 | conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, padding="same") 27 | conv = tf.contrib.layers.batch_norm(conv, is_training=self.phase, decay=0.999, epsilon=1e-3, scale=True, center=True) 28 | conv = tf.nn.relu(conv) 29 | conv = tf.contrib.layers.max_pool2d(conv, 2, padding=pool_pad) 30 | 31 | return conv 32 | 33 | 34 | def encoder(self, x, h_dim, z_dim, reuse=False): 35 | """Feature embedding network""" 36 | with tf.variable_scope('encoder', reuse=reuse): 37 | net = self.conv_block(x, h_dim, name='conv_1') 38 | net = self.conv_block(net, h_dim, name='conv_2') 39 | net = self.conv_block(net, h_dim, name='conv_3') 40 | net = self.conv_block(net, z_dim, name='conv_4') 41 | 42 | net = tf.contrib.layers.flatten(net) 43 | 44 | return net 45 | 46 | 47 | def relation(self, x, h_dim, z_dim, reuse=False): 48 | """Graph Construction Module""" 49 | with tf.variable_scope('relation', reuse=reuse): 50 | x = tf.reshape(x, (-1,5,5,64)) 51 | 52 | net = self.conv_block(x, h_dim, pool_pad='SAME', name='conv_5') 53 | net = self.conv_block(net, 1, pool_pad='SAME', name='conv_6') 54 | 55 | net = tf.contrib.layers.flatten(net) 56 | 57 | net = tf.contrib.layers.fully_connected(net, 8) 58 | net = tf.contrib.layers.fully_connected(net, 1, tf.identity) 59 | 60 | net = tf.contrib.layers.flatten(net) 61 | 62 | return net 63 | 64 | 65 | # contruct the model 66 | def construct(self): 67 | # data input 68 | x_shape = tf.shape(self.x) 69 | q_shape = tf.shape(self.q) 70 | num_classes, num_support = x_shape[0], x_shape[1] 71 | num_queries = q_shape[1] 72 | 73 | ys_one_hot = tf.one_hot(self.ys, depth=num_classes) 74 | y_one_hot = tf.one_hot(self.y, depth=num_classes) 75 | 76 | # construct the model 77 | x = tf.reshape(self.x, [num_classes * num_support, self.im_height, self.im_width, self.channels]) 78 | q = tf.reshape(self.q, [num_classes * num_queries, self.im_height, self.im_width, self.channels]) 79 | emb_x = self.encoder(x, self.h_dim, self.z_dim) 80 | emb_q = self.encoder(q, self.h_dim, self.z_dim, reuse=True) 81 | emb_dim = tf.shape(emb_x)[-1] 82 | 83 | if self.args['rn']==300: # learned sigma, fixed alpha 84 | self.alpha = tf.constant(self.args['alpha']) 85 | else: # learned sigma and alpha 86 | self.alpha = tf.Variable(self.alpha, name='alpha') 87 | 88 | ce_loss, acc, sigma_value = self.label_prop(emb_x, emb_q, ys_one_hot) 89 | 90 | return ce_loss, acc, sigma_value 91 | 92 | 93 | 94 | def label_prop(self, x, u, ys): 95 | 96 | epsilon = np.finfo(float).eps 97 | # x: NxD, u: UxD 98 | s = tf.shape(ys) 99 | ys = tf.reshape(ys, (s[0]*s[1],-1)) 100 | Ns, C = tf.shape(ys)[0], tf.shape(ys)[1] 101 | Nu = tf.shape(u)[0] 102 | 103 | yu = tf.zeros((Nu,C))/tf.cast(C,tf.float32) + epsilon # 0 initialization 104 | #yu = tf.ones((Nu,C))/tf.cast(C,tf.float32) # 1/C initialization 105 | y = tf.concat([ys,yu],axis=0) 106 | gt = tf.reshape(tf.tile(tf.expand_dims(tf.range(C),1), [1,tf.cast(Nu/C,tf.int32)]), [-1]) 107 | 108 | all_un = tf.concat([x,u],0) 109 | all_un = tf.reshape(all_un, [-1, 1600]) 110 | N, d = tf.shape(all_un)[0], tf.shape(all_un)[1] 111 | 112 | # compute graph weights 113 | if self.args['rn'] in [30, 300]: # compute example-wise sigma 114 | self.sigma = self.relation(all_un, self.h_dim, self.z_dim) 115 | 116 | all_un = all_un / (self.sigma+epsilon) 117 | all1 = tf.expand_dims(all_un, axis=0) 118 | all2 = tf.expand_dims(all_un, axis=1) 119 | W = tf.reduce_mean(tf.square(all1-all2), axis=2) 120 | W = tf.exp(-W/2) 121 | 122 | # kNN Graph 123 | if self.args['k']>0: 124 | W = self.topk(W, self.args['k']) 125 | 126 | # Laplacian norm 127 | D = tf.reduce_sum(W, axis=0) 128 | D_inv = 1.0/(D+epsilon) 129 | D_sqrt_inv = tf.sqrt(D_inv) 130 | 131 | # compute propagated label 132 | D1 = tf.expand_dims(D_sqrt_inv, axis=1) 133 | D2 = tf.expand_dims(D_sqrt_inv, axis=0) 134 | S = D1*W*D2 135 | F = tf.matrix_inverse(tf.eye(N)-self.alpha*S+epsilon) 136 | F = tf.matmul(F,y) 137 | label = tf.argmax(F, 1) 138 | 139 | # loss computation 140 | F = tf.nn.softmax(F) 141 | 142 | y_one_hot = tf.reshape(tf.one_hot(gt,depth=C),[Nu, -1]) 143 | y_one_hot = tf.concat([ys,y_one_hot], axis=0) 144 | 145 | ce_loss = y_one_hot*tf.log(F+epsilon) 146 | ce_loss = tf.negative(ce_loss) 147 | ce_loss = tf.reduce_mean(tf.reduce_sum(ce_loss,1)) 148 | 149 | # only consider query examples acc 150 | F_un = F[Ns:,:] 151 | acc = tf.reduce_mean(tf.to_float(tf.equal(label[Ns:],tf.cast(gt,tf.int64)))) 152 | 153 | return ce_loss, acc, self.sigma 154 | 155 | 156 | def topk(self, W, k): 157 | # construct k-NN and compute margin loss 158 | values, indices = tf.nn.top_k(W, k, sorted=False) 159 | my_range = tf.expand_dims(tf.range(0, tf.shape(indices)[0]), 1) 160 | my_range_repeated = tf.tile(my_range, [1, k]) 161 | full_indices = tf.concat([tf.expand_dims(my_range_repeated, 2), tf.expand_dims(indices, 2)], axis=2) 162 | full_indices = tf.reshape(full_indices, [-1, 2]) 163 | 164 | topk_W = tf.sparse_to_dense(full_indices, tf.shape(W), tf.reshape(values, [-1]), default_value=0., validate_indices=False) 165 | ind1 = (topk_W>0)|(tf.transpose(topk_W)>0) # union, k-nearest neighbor 166 | ind2 = (topk_W>0)&(tf.transpose(topk_W)>0) # intersection, mutal k-nearest neighbor 167 | ind1 = tf.cast(ind1,tf.float32) 168 | 169 | topk_W = ind1*W 170 | 171 | return topk_W 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # coding: utf 2 | from __future__ import print_function 3 | 4 | from PIL import Image 5 | import tensorflow as tf 6 | import os 7 | import glob 8 | import csv 9 | from models import * 10 | from dataset_mini import * 11 | from dataset_tiered import * 12 | from tqdm import tqdm 13 | import argparse 14 | import random 15 | import numpy as np 16 | import scipy as sp 17 | import scipy.stats 18 | 19 | 20 | parser = argparse.ArgumentParser(description='Test TPN') 21 | 22 | # parse gpu and random params 23 | default_gpu = "0" 24 | parser.add_argument('--gpu', type=str, default=0, metavar='GPU', 25 | help="gpu name, default:{}".format(default_gpu)) 26 | parser.add_argument('--seed', type=int, default=1000, metavar='SEED', 27 | help="random seed, -1 means no seed") 28 | 29 | # model params 30 | n_examples = 600 31 | parser.add_argument('--x_dim', type=str, default="84,84,3", metavar='XDIM', 32 | help='input image dims') 33 | parser.add_argument('--h_dim', type=int, default=64, metavar='HDIM', 34 | help="dimensionality of hidden layers (default: 64)") 35 | parser.add_argument('--z_dim', type=int, default=64, metavar='ZDIM', 36 | help="dimensionality of input images (default: 64)") 37 | 38 | # basic training hyper-parameters 39 | n_episodes = 100 40 | parser.add_argument('--n_way', type=int, default=5, metavar='NWAY', 41 | help="nway") 42 | parser.add_argument('--n_shot', type=int, default=5, metavar='NSHOT', 43 | help="nshot") 44 | parser.add_argument('--n_query', type=int, default=15, metavar='NQUERY', 45 | help="nquery") 46 | parser.add_argument('--n_epochs', type=int, default=1100, metavar='NEPOCHS', 47 | help="nepochs") 48 | 49 | # val and test hyper-parameters 50 | parser.add_argument('--n_test_way', type=int, default=5, metavar='NTESTWAY', 51 | help="ntestway") 52 | parser.add_argument('--n_test_shot', type=int, default=5, metavar='NTESTSHOT', 53 | help="ntestshot") 54 | parser.add_argument('--n_test_query', type=int, default=15, metavar='NTESTQUERY', 55 | help="ntestquery") 56 | parser.add_argument('--n_test_episodes',type=int, default=600, metavar='NTESTEPI', 57 | help="ntestepisodes") 58 | 59 | # optimization params 60 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 61 | help="base learning rate") 62 | parser.add_argument('--step_size', type=int, default=2000, metavar='DSTEP', 63 | help="step size") 64 | parser.add_argument('--gamma', type=float, default=0.5, metavar='DRATE', 65 | help="gamma") 66 | parser.add_argument('--patience', type=int, default=200, metavar='PATIENCE', 67 | help="patience") 68 | 69 | # dataset params 70 | parser.add_argument('--dataset', type=str, default='mini', metavar='DATASET', 71 | help="mini or tiered") 72 | parser.add_argument('--ratio', type=float, default=1.0, metavar='RATIO', 73 | help="ratio of labeled data") 74 | parser.add_argument('--pkl', type=int, default=1, metavar='PKL', 75 | help="") 76 | 77 | # label propagation params 78 | parser.add_argument('--k', type=int, default=20, metavar='K', 79 | help="K in refine protos") 80 | parser.add_argument('--sigma', type=float, default=0.25, metavar='SIGMA', 81 | help="SIGMA of k-NN graph construction") 82 | parser.add_argument('--alpha', type=float, default=0.99, metavar='ALPHA', 83 | help="ALPHA in label propagation") 84 | parser.add_argument('--rn', type=int, default=300, metavar='RN', 85 | help="relation types" + 86 | "30:learned sigma and alpha, 300:learned sigma, fixed alpha") 87 | 88 | # restore params 89 | parser.add_argument('--iters', type=int, default=0, metavar='ITERS', 90 | help="iteration to restore params") 91 | parser.add_argument('--exp_name', type=str, default='exp', metavar='EXPNAME', 92 | help="experiment description name") 93 | 94 | 95 | args = vars(parser.parse_args()) 96 | im_width, im_height, channels = list(map(int, args['x_dim'].split(','))) 97 | print(args) 98 | for key,v in args.items(): exec(key+'=v') 99 | 100 | 101 | #seed = 1423 102 | #random.seed(seed) 103 | #np.random.seed(seed) 104 | #tf.set_random_seed(seed) 105 | 106 | 107 | os.environ["CUDA_VISIBLE_DEVICES"] = args['gpu'] 108 | is_training = False 109 | 110 | 111 | 112 | # construct dataset 113 | if dataset=='mini': 114 | loader_test = dataset_mini(n_examples, n_episodes, 'test', args) 115 | elif dataset=='tiered': 116 | loader_test = dataset_tiered(n_examples, n_episodes, 'test', args) 117 | 118 | if not pkl: 119 | loader_test.load_data() 120 | else: 121 | loader_test.load_data_pkl() 122 | 123 | 124 | # construct model 125 | m = models(args) 126 | ce_loss,acc,sigma_value = m.construct() 127 | 128 | 129 | # init session and start training 130 | config = tf.ConfigProto() 131 | config.gpu_options.allow_growth=True 132 | sess = tf.Session(config=config) 133 | 134 | init_op = tf.global_variables_initializer() 135 | sess.run(init_op) 136 | 137 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=100) 138 | save_dir = 'checkpoints/'+args['exp_name'] 139 | 140 | model_path = save_dir+'/models' 141 | # restore pre-trained model 142 | if iters>0: 143 | ckpt_path = model_path+'/ckpt-'+str(iters) 144 | 145 | saver.restore(sess, ckpt_path) 146 | print('Load model from {}'.format(ckpt_path)) 147 | 148 | 149 | 150 | print("Testing...") 151 | 152 | all_acc = [] 153 | all_std = [] 154 | all_ci95 = [] 155 | 156 | nums = n_test_way * (n_test_shot+n_test_query) 157 | 158 | 159 | list_acc = [] 160 | # test epochs 161 | for epi in tqdm(range(n_test_episodes), desc='test'): 162 | support, s_labels, query, q_labels, unlabel = loader_test.next_data(n_test_way, n_test_shot, n_test_query, train=False) 163 | vls, vac, vsigma = sess.run([ce_loss, acc, sigma_value], feed_dict={m.x: support, m.ys:s_labels, m.q: query, m.y:q_labels, m.phase:0}) 164 | list_acc.append(vac) 165 | 166 | mean_acc = np.mean(list_acc) 167 | std_acc = np.std(list_acc) 168 | ci95 = 1.96*std_acc/np.sqrt(n_test_episodes) 169 | 170 | print('Acc:{:.4f},std:{:.4f},ci95:{:.4f}'.format(mean_acc, std_acc, ci95)) 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #------------------------------------- 2 | # Paper: Learning to Propagate Labels: Transductive Propagation Network for Few-shot Learning 3 | # Date: 2018.11.17 4 | # Author: Anonymous 5 | # All Rights Reserved 6 | #------------------------------------- 7 | 8 | from __future__ import print_function 9 | from PIL import Image 10 | import numpy as np 11 | import tensorflow as tf 12 | import os 13 | import glob 14 | import csv 15 | from models import * 16 | from dataset_mini import * 17 | from dataset_tiered import * 18 | from tqdm import tqdm 19 | import argparse 20 | import random 21 | 22 | parser = argparse.ArgumentParser(description='Train TPN') 23 | 24 | # parse gpu 25 | parser.add_argument('--gpu', type=str, default=0, metavar='GPU', 26 | help="gpu name, default:0") 27 | 28 | # model params 29 | n_examples = 600 30 | parser.add_argument('--x_dim', type=str, default="84,84,3", metavar='XDIM', 31 | help='input image dims') 32 | parser.add_argument('--h_dim', type=int, default=64, metavar='HDIM', 33 | help="channels of hidden conv layers (default: 64)") 34 | parser.add_argument('--z_dim', type=int, default=64, metavar='ZDIM', 35 | help="channels of last conv layer (default: 64)") 36 | 37 | # training hyper-parameters 38 | n_episodes = 100 39 | parser.add_argument('--n_way', type=int, default=5, metavar='NWAY', 40 | help="nway") 41 | parser.add_argument('--n_shot', type=int, default=5, metavar='NSHOT', 42 | help="nshot") 43 | parser.add_argument('--n_query', type=int, default=15, metavar='NQUERY', 44 | help="nquery") 45 | parser.add_argument('--n_epochs', type=int, default=2100, metavar='NEPOCHS', 46 | help="nepochs") 47 | # test hyper-parameters 48 | parser.add_argument('--n_test_way', type=int, default=5, metavar='NTESTWAY', 49 | help="ntestway") 50 | parser.add_argument('--n_test_shot',type=int, default=5, metavar='NTESTSHOT', 51 | help="ntestshot") 52 | parser.add_argument('--n_test_query',type=int, default=15, metavar='NTESTQUERY', 53 | help="ntestquery") 54 | 55 | # optimization params 56 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 57 | help="base learning rate") 58 | parser.add_argument('--step_size', type=int, default=10000, metavar='DSTEP', 59 | help="step_size") 60 | parser.add_argument('--gamma', type=float, default=0.5, metavar='DRATE', 61 | help="gamma") 62 | parser.add_argument('--patience', type=int, default=200, metavar='PATIENCE', 63 | help="patience") 64 | 65 | # dataset params 66 | parser.add_argument('--dataset', type=str, default='mini',metavar='DATASET', 67 | help="mini or tiered") 68 | parser.add_argument('--ratio', type=float, default=1.0, metavar='RATIO', 69 | help="ratio of labeled data (for semi-supervised setting") 70 | parser.add_argument('--pkl', type=int, default=1, metavar='PKL', 71 | help="1 for use pkl dataset, 0 for original images") 72 | 73 | # label propagation params 74 | parser.add_argument('--k', type=int, default=20, metavar='K', 75 | help="top k in constructing the graph W") 76 | parser.add_argument('--sigma', type=float, default=0.25, metavar='SIGMA', 77 | help="sigma of graph computing parameter") 78 | parser.add_argument('--alpha', type=float, default=0.99, metavar='ALPHA', 79 | help="alpha in label propagation") 80 | parser.add_argument('--rn', type=int, default=300, metavar='RN', 81 | help="graph construction types: " 82 | "300: sigma is learned, alpha is fixed" + 83 | "30: both sigma and alpha learned") 84 | 85 | # seed and exp_name 86 | parser.add_argument('--seed', type=int, default=1000, metavar='SEED', 87 | help="random seed, -1 means no seed") 88 | parser.add_argument('--exp_name', type=str, default='exp', metavar='EXPNAME', 89 | help="experiment description name") 90 | parser.add_argument('--iters', type=int, default=0, metavar='ITERS', 91 | help="checkpoint restore iters") 92 | 93 | 94 | 95 | # deal with params 96 | args = vars(parser.parse_args()) 97 | im_width, im_height, channels = list(map(int, args['x_dim'].split(','))) 98 | for key,v in args.items(): exec(key+'=v') 99 | 100 | ## RANDOM SEED 101 | #random.seed(seed) 102 | #np.random.seed(seed) 103 | #tf.set_random_seed(seed) 104 | 105 | # set environment variables 106 | os.environ["CUDA_VISIBLE_DEVICES"] = args['gpu'] 107 | is_training = True 108 | 109 | 110 | # deal with checkpoints save folder 111 | def _init_(): 112 | if not os.path.exists('checkpoints'): 113 | os.makedirs('checkpoints') 114 | if not os.path.exists('checkpoints/'+args['exp_name']): 115 | os.makedirs('checkpoints/'+args['exp_name']) 116 | if not os.path.exists('checkpoints/'+args['exp_name']+'/'+'models'): 117 | os.makedirs('checkpoints/'+args['exp_name']+'/'+'models') 118 | if not os.path.exists('checkpoints/'+args['exp_name']+'/'+'summaries'): 119 | os.makedirs('checkpoints/'+args['exp_name']+'/'+'summaries') 120 | os.system('cp train.py checkpoints'+'/'+args['exp_name']+'/'+'train.py.backup') 121 | os.system('cp models.py checkpoints' + '/' + args['exp_name'] + '/' + 'models.py.backup') 122 | f = open('checkpoints/'+args['exp_name']+'/log.txt', 'a') 123 | print(args, file=f) 124 | f.close() 125 | _init_() 126 | 127 | 128 | # construct dataset 129 | if dataset=='mini': 130 | loader_train = dataset_mini(n_examples, n_episodes, 'train', args) 131 | loader_val = dataset_mini(n_examples, n_episodes, 'val', args) 132 | elif dataset=='tiered': 133 | loader_train = dataset_tiered(n_examples, n_episodes, 'train', args) 134 | loader_val = dataset_tiered(n_examples, n_episodes, 'val', args) 135 | 136 | if pkl==0: 137 | print('Load image data rather than PKL') 138 | loader_train.load_data() 139 | loader_val.load_data() 140 | else: 141 | print('Load PKL data') 142 | loader_train.load_data_pkl() 143 | loader_val.load_data_pkl() 144 | 145 | 146 | # construct model 147 | m = models(args) 148 | ce_loss,acc,sigma_value = m.construct() 149 | 150 | 151 | # train and stepsize 152 | global_step = tf.Variable(0, name="global_step", trainable=False) 153 | learning_rate = tf.train.exponential_decay(lr, global_step, 154 | step_size, gamma, staircase=True) 155 | # update ops for batch norm 156 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 157 | with tf.control_dependencies(update_ops): 158 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(ce_loss, global_step=global_step) 159 | 160 | # init session and start training 161 | config = tf.ConfigProto() 162 | config.gpu_options.allow_growth=True 163 | sess = tf.Session(config=config) 164 | 165 | init_op = tf.global_variables_initializer() 166 | sess.run(init_op) 167 | 168 | 169 | # summary 170 | save_dir = 'checkpoints/'+args['exp_name'] 171 | 172 | loss_summary = tf.summary.scalar("loss", ce_loss) 173 | acc_summary = tf.summary.scalar("accuracy", acc) 174 | lr_summary = tf.summary.scalar("lr", learning_rate) 175 | sigma_summary = tf.summary.histogram("sigma", sigma_value) 176 | 177 | train_summary_op = tf.summary.merge([loss_summary, acc_summary, lr_summary, sigma_summary]) 178 | train_summary_dir = os.path.join(save_dir, "summaries", "train") 179 | train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph) 180 | 181 | val_summary_op = tf.summary.merge([loss_summary, acc_summary, sigma_summary]) 182 | val_summary_dir = os.path.join(save_dir, "summaries", "val") 183 | val_summary_writer = tf.summary.FileWriter(val_summary_dir, sess.graph) 184 | 185 | 186 | # restore pre-trained model 187 | saver = tf.train.Saver(tf.global_variables(), max_to_keep=100) 188 | model_path = save_dir+'/models' 189 | if iters>0: 190 | ckpt_path = model_path+'/ckpt-'+str(iters) 191 | 192 | saver.restore(sess, ckpt_path) 193 | print('Load model from {}'.format(ckpt_path)) 194 | 195 | 196 | # Train and Val stages 197 | best_acc = 0 198 | best_loss = np.inf 199 | wait = 0 200 | 201 | for ep in range(int(iters/100), n_epochs): 202 | loss_tr = [] 203 | acc_tr = [] 204 | loss_val = [] 205 | acc_val = [] 206 | # run episodes training and then val 207 | for epi in tqdm(range(n_episodes), desc='train epoc:{}'.format(ep)): 208 | if ratio==1.0: 209 | support, s_labels, query, q_labels, _ = loader_train.next_data(n_way, n_shot, n_query) 210 | else: 211 | support, s_labels, query, q_labels, _ = loader_train.next_data_un(n_way, n_shot, n_query) 212 | 213 | _, summaries, step, ls, ac = sess.run([train_op, train_summary_op, global_step, ce_loss, acc], feed_dict={m.x: support, m.ys:s_labels, m.q: query, m.y:q_labels, m.phase:1}) 214 | 215 | train_summary_writer.add_summary(summaries, step) 216 | loss_tr.append(ls) 217 | acc_tr.append(ac) 218 | 219 | # validation after each episode training, and decide if stop after train_patience steps 220 | for epi in tqdm(range(n_episodes), desc='val epoc:{}'.format(ep)): 221 | # validation to decide if stop 222 | support, s_labels, query, q_labels, _ = loader_val.next_data(n_test_way, n_test_shot, n_test_query, train=False) 223 | summaries, vls, vac = sess.run([val_summary_op, ce_loss, acc], feed_dict={m.x: support, m.ys:s_labels, m.q: query, m.y:q_labels, m.phase:0}) 224 | 225 | val_summary_writer.add_summary(summaries, step) 226 | loss_val.append(vls) 227 | acc_val.append(vac) 228 | 229 | print('epoch:{}, loss:{:.5f}, acc:{:.5f}, val, loss:{:.5f}, acc:{:.5f}'.format(ep, np.mean(loss_tr), np.mean(acc_tr), np.mean(loss_val), np.mean(acc_val))) 230 | 231 | 232 | # Model save and stop criterion 233 | cond1 = (np.mean(acc_val)>best_acc) 234 | cond2 = (np.mean(loss_val)patience and ep>n_epochs and rn>=0: 256 | break 257 | 258 | 259 | --------------------------------------------------------------------------------