├── utils ├── __init__.py └── io_utils.py ├── datasets ├── .gitignore └── compressed │ ├── omniglot │ └── readme.md │ └── mini_imagenet │ └── readme.md ├── models ├── __init__.py ├── models.py └── gnn_iclr.py ├── data ├── __init__.py ├── parser.py ├── generator.py ├── omniglot.py └── mini_imagenet.py ├── .gitignore ├── test.py ├── README.md └── main.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | */ 3 | !.gitignore -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import gnn_iclr 2 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from . import generator 3 | from . 
import parser 4 | -------------------------------------------------------------------------------- /datasets/compressed/omniglot/readme.md: -------------------------------------------------------------------------------- 1 | Download **images_background.zip** and **images_evaluation.zip** files from [brendenlake/omniglot](https://github.com/brendenlake/omniglot/tree/master/python) and copy it inside this folder 2 | -------------------------------------------------------------------------------- /datasets/compressed/mini_imagenet/readme.md: -------------------------------------------------------------------------------- 1 | ## Instructions 2 | You must copy the file **images.zip** inside this folder 3 | 4 | To get **images.zip** file, you should send a mail to victor.few.shot@gmail.com with the subject: **mini_imagenet dataset** 5 | -------------------------------------------------------------------------------- /utils/io_utils.py: -------------------------------------------------------------------------------- 1 | class IOStream(): 2 | def __init__(self, path): 3 | self.f = open(path, 'a') 4 | 5 | def cprint(self, text): 6 | print(text) 7 | self.f.write(text+'\n') 8 | self.f.flush() 9 | 10 | def close(self): 11 | self.f.close() 12 | 13 | -------------------------------------------------------------------------------- /data/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import fnmatch 3 | 4 | 5 | def get_image_paths(source, extension='png'): 6 | images_path, class_names = [], [] 7 | for root, dirnames, filenames in os.walk(source): 8 | filenames = [filename for filename in filenames if '._' not in filename] 9 | for filename in fnmatch.filter(filenames, '*.'+extension): 10 | images_path.append(os.path.join(root, filename)) 11 | class_name = root.split('/') 12 | class_name = class_name[len(class_name)-2:] 13 | class_name = '/'.join(class_name) 14 | class_names.append(class_name) 15 | return class_names, 
images_path 16 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | *.pickle 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | *.log 37 | *.w 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | # Translations 50 | *.mo 51 | *.pot 52 | .DS_Store 53 | .idea/ 54 | # Django stuff: 55 | *.log 56 | *.gv 57 | *.pdf 58 | *.DS_Store 59 | *.t7 60 | *.backup 61 | *.JPEG 62 | *._ 63 | local_settings.py 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # IPython Notebook 79 | .ipynb_checkpoints 80 | 81 | # pyenv 82 | .python-version 83 | 84 | # celery beat schedule file 85 | celerybeat-schedule 86 | 87 | # dotenv 88 | .env 89 | 90 | # virtualenv 91 | venv/ 92 | ENV/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils 
import io_utils 3 | from data import generator 4 | from torch.autograd import Variable 5 | 6 | 7 | def test_one_shot(args, model, test_samples=5000, partition='test'): 8 | io = io_utils.IOStream('checkpoints/' + args.exp_name + '/run.log') 9 | 10 | io.cprint('\n**** TESTING WITH %s ***' % (partition,)) 11 | 12 | loader = generator.Generator(args.dataset_root, args, partition=partition, dataset=args.dataset) 13 | 14 | [enc_nn, metric_nn, softmax_module] = model 15 | enc_nn.eval() 16 | metric_nn.eval() 17 | correct = 0 18 | total = 0 19 | iterations = int(test_samples/args.batch_size_test) 20 | for i in range(iterations): 21 | data = loader.get_task_batch(batch_size=args.batch_size_test, n_way=args.test_N_way, 22 | num_shots=args.test_N_shots, unlabeled_extra=args.unlabeled_extra) 23 | [x, labels_x_cpu, _, _, xi_s, labels_yi_cpu, oracles_yi, hidden_labels] = data 24 | 25 | if args.cuda: 26 | xi_s = [batch_xi.cuda() for batch_xi in xi_s] 27 | labels_yi = [label_yi.cuda() for label_yi in labels_yi_cpu] 28 | oracles_yi = [oracle_yi.cuda() for oracle_yi in oracles_yi] 29 | hidden_labels = hidden_labels.cuda() 30 | x = x.cuda() 31 | else: 32 | labels_yi = labels_yi_cpu 33 | 34 | xi_s = [Variable(batch_xi) for batch_xi in xi_s] 35 | labels_yi = [Variable(label_yi) for label_yi in labels_yi] 36 | oracles_yi = [Variable(oracle_yi) for oracle_yi in oracles_yi] 37 | hidden_labels = Variable(hidden_labels) 38 | x = Variable(x) 39 | 40 | # Compute embedding from x and xi_s 41 | z = enc_nn(x)[-1] 42 | zi_s = [enc_nn(batch_xi)[-1] for batch_xi in xi_s] 43 | 44 | # Compute metric from embeddings 45 | output, out_logits = metric_nn(inputs=[z, zi_s, labels_yi, oracles_yi, hidden_labels]) 46 | output = out_logits 47 | y_pred = softmax_module.forward(output) 48 | y_pred = y_pred.data.cpu().numpy() 49 | y_pred = np.argmax(y_pred, axis=1) 50 | labels_x_cpu = labels_x_cpu.numpy() 51 | labels_x_cpu = np.argmax(labels_x_cpu, axis=1) 52 | 53 | for row_i in range(y_pred.shape[0]): 54 | if 
y_pred[row_i] == labels_x_cpu[row_i]: 55 | correct += 1 56 | total += 1 57 | 58 | if (i+1) % 100 == 0: 59 | io.cprint('{} correct from {} \tAccuracy: {:.3f}%)'.format(correct, total, 100.0*correct/total)) 60 | 61 | io.cprint('{} correct from {} \tAccuracy: {:.3f}%)'.format(correct, total, 100.0*correct/total)) 62 | io.cprint('*** TEST FINISHED ***\n'.format(correct, total, 100.0 * correct / total)) 63 | enc_nn.train() 64 | metric_nn.train() 65 | 66 | return 100.0 * correct / total 67 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Few-Shot Learning with Graph Neural Networks 2 | Implementation of [Few-Shot Learning with Graph Neural Networks](https://arxiv.org/pdf/1711.04043.pdf) on Python3, Pytorch 0.3.1 3 | 4 | 5 | ## Mini-Imagenet 6 | 7 | ### Download the dataset 8 | Create **images.zip** file and copy it inside ```mini_imagenet``` directory: 9 | 10 | 11 | . 12 | ├── ... 13 | └── datasets 14 | └── compressed 15 | └── mini_imagenet 16 | └── images.zip 17 | 18 | The **images.zip** file must contain the splits and images in the following format: 19 | 20 | ── images.zip 21 | ├── test.csv 22 | ├── train.csv 23 | ├── val.csv 24 | └── images 25 | ├── n0153282900000006.jpg 26 | ├── ... 27 | └── n1313361300001299.jpg 28 | 29 | The splits *{test.csv, train.csv, val.csv}* can be downloaded from [Ravi and Larochelle - splits](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet). 
For more information on how to obtain the images, check the original source [Ravi and Larochelle - github](https://github.com/twitter/meta-learning-lstm). 30 | 31 | 32 | ### Training 33 | 34 | ``` 35 | # 5-Way 1-shot | Few-shot 36 | EXPNAME=minimagenet_N5_S1 37 | python3 main.py --exp_name $EXPNAME --dataset mini_imagenet --test_N_way 5 --train_N_way 5 --train_N_shots 1 --test_N_shots 1 --batch_size 100 --dec_lr=15000 --iterations 80000 38 | 39 | # 5-Way 5-shot | Few-shot 40 | EXPNAME=minimagenet_N5_S5 41 | python3 main.py --exp_name $EXPNAME --dataset mini_imagenet --test_N_way 5 --train_N_way 5 --train_N_shots 5 --test_N_shots 5 --batch_size 40 --dec_lr=15000 --iterations 90000 42 | 43 | # 5-Way 5-shot 20%-labeled | Semi-supervised 44 | EXPNAME=minimagenet_N5_S1_U4 45 | python3 main.py --exp_name $EXPNAME --dataset mini_imagenet --test_N_way 5 --train_N_way 5 --train_N_shots 5 --test_N_shots 5 --unlabeled_extra 4 --batch_size 40 --dec_lr=15000 --iterations 100000 46 | ``` 47 | 48 | 49 | ## Omniglot 50 | 51 | ### Download the dataset 52 | Download the **images_background.zip** and **images_evaluation.zip** files from [brendenlake/omniglot](https://github.com/brendenlake/omniglot/tree/master/python) and copy them into the ```omniglot``` directory: 53 | 54 | . 55 | ├── ...
56 | └── datasets 57 | └── compressed 58 | └── omniglot 59 | ├── images_background.zip 60 | └── images_evaluation.zip 61 | 62 | ### Training 63 | ``` 64 | # 5-Way 1-shot | Few-shot 65 | EXPNAME=omniglot_N5_S1_v2 66 | python3 main.py --exp_name $EXPNAME --dataset omniglot --test_N_way 5 --train_N_way 5 --train_N_shots 1 --test_N_shots 1 --batch_size 300 --dec_lr=10000 --iterations 100000 67 | 68 | # 5-Way 5-shot | Few-shot 69 | EXPNAME=omniglot_N5_S5 70 | python3 main.py --exp_name $EXPNAME --dataset omniglot --test_N_way 5 --train_N_way 5 --train_N_shots 5 --test_N_shots 5 --batch_size 100 --dec_lr=10000 --iterations 80000 71 | 72 | # 20-Way 1-shot | Few-shot 73 | EXPNAME=omniglot_N20_S1 74 | python3 main.py --exp_name $EXPNAME --dataset omniglot --test_N_way 20 --train_N_way 20 --train_N_shots 1 --test_N_shots 1 --batch_size 100 --dec_lr=10000 --iterations 80000 75 | 76 | # 5-Way 5-shot 20%-labeled | Semi-supervised 77 | EXPNAME=omniglot_N5_S1_U4 78 | python3 main.py --exp_name $EXPNAME --dataset omniglot --test_N_way 5 --train_N_way 5 --train_N_shots 5 --test_N_shots 5 --unlabeled_extra 4 --batch_size 100 --dec_lr=10000 --iterations 80000 79 | ``` 80 | 81 | ## Citation 82 | If you find this code useful you can cite us using the following bibTex: 83 | ``` 84 | @article{garcia2017few, 85 | title={Few-Shot Learning with Graph Neural Networks}, 86 | author={Garcia, Victor and Bruna, Joan}, 87 | journal={arXiv preprint arXiv:1711.04043}, 88 | year={2017} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /data/generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | import torch 4 | import numpy as np 5 | import random 6 | from torch.autograd import Variable 7 | from . import omniglot 8 | from . 
import mini_imagenet 9 | 10 | 11 | class Generator(data.Dataset): 12 | def __init__(self, root, args, partition='train', dataset='omniglot'): 13 | self.root = root 14 | self.partition = partition # training set or test set 15 | self.args = args 16 | 17 | assert (dataset == 'omniglot' or 18 | dataset == 'mini_imagenet'), 'Incorrect dataset partition' 19 | self.dataset = dataset 20 | 21 | if self.dataset == 'omniglot': 22 | self.input_channels = 1 23 | self.size = (28, 28) 24 | else: 25 | self.input_channels = 3 26 | self.size = (84, 84) 27 | 28 | if dataset == 'omniglot': 29 | self.loader = omniglot.Omniglot(self.root, dataset=dataset) 30 | self.data = self.loader.load_dataset(self.partition == 'train', self.size) 31 | elif dataset == 'mini_imagenet': 32 | self.loader = mini_imagenet.MiniImagenet(self.root) 33 | self.data, self.label_encoder = self.loader.load_dataset(self.partition, self.size) 34 | else: 35 | raise NotImplementedError 36 | 37 | self.class_encoder = {} 38 | for id_key, key in enumerate(self.data): 39 | self.class_encoder[key] = id_key 40 | 41 | def rotate_image(self, image, times): 42 | rotated_image = np.zeros(image.shape) 43 | for channel in range(image.shape[0]): 44 | rotated_image[channel, :, :] = np.rot90(image[channel, :, :], k=times) 45 | return rotated_image 46 | 47 | def get_task_batch(self, batch_size=5, n_way=20, num_shots=1, unlabeled_extra=0, cuda=False, variable=False): 48 | # Init variables 49 | batch_x = np.zeros((batch_size, self.input_channels, self.size[0], self.size[1]), dtype='float32') 50 | labels_x = np.zeros((batch_size, n_way), dtype='float32') 51 | labels_x_global = np.zeros(batch_size, dtype='int64') 52 | target_distances = np.zeros((batch_size, n_way * num_shots), dtype='float32') 53 | hidden_labels = np.zeros((batch_size, n_way * num_shots + 1), dtype='float32') 54 | numeric_labels = [] 55 | batches_xi, labels_yi, oracles_yi = [], [], [] 56 | for i in range(n_way*num_shots): 57 | batches_xi.append(np.zeros((batch_size, 
self.input_channels, self.size[0], self.size[1]), dtype='float32')) 58 | labels_yi.append(np.zeros((batch_size, n_way), dtype='float32')) 59 | oracles_yi.append(np.zeros((batch_size, n_way), dtype='float32')) 60 | # Iterate over tasks for the same batch 61 | 62 | for batch_counter in range(batch_size): 63 | positive_class = random.randint(0, n_way - 1) 64 | 65 | # Sample random classes for this TASK 66 | classes_ = list(self.data.keys()) 67 | sampled_classes = random.sample(classes_, n_way) 68 | indexes_perm = np.random.permutation(n_way * num_shots) 69 | 70 | counter = 0 71 | for class_counter, class_ in enumerate(sampled_classes): 72 | if class_counter == positive_class: 73 | # We take num_shots + one sample for one class 74 | samples = random.sample(self.data[class_], num_shots+1) 75 | # Test sample is loaded 76 | batch_x[batch_counter, :, :, :] = samples[0] 77 | labels_x[batch_counter, class_counter] = 1 78 | labels_x_global[batch_counter] = self.class_encoder[class_] 79 | samples = samples[1::] 80 | else: 81 | samples = random.sample(self.data[class_], num_shots) 82 | 83 | for s_i in range(0, len(samples)): 84 | batches_xi[indexes_perm[counter]][batch_counter, :, :, :] = samples[s_i] 85 | if s_i < unlabeled_extra: 86 | labels_yi[indexes_perm[counter]][batch_counter, class_counter] = 0 87 | hidden_labels[batch_counter, indexes_perm[counter] + 1] = 1 88 | else: 89 | labels_yi[indexes_perm[counter]][batch_counter, class_counter] = 1 90 | oracles_yi[indexes_perm[counter]][batch_counter, class_counter] = 1 91 | target_distances[batch_counter, indexes_perm[counter]] = 0 92 | counter += 1 93 | 94 | numeric_labels.append(positive_class) 95 | 96 | batches_xi = [torch.from_numpy(batch_xi) for batch_xi in batches_xi] 97 | labels_yi = [torch.from_numpy(label_yi) for label_yi in labels_yi] 98 | oracles_yi = [torch.from_numpy(oracle_yi) for oracle_yi in oracles_yi] 99 | 100 | labels_x_scalar = np.argmax(labels_x, 1) 101 | 102 | return_arr = [torch.from_numpy(batch_x), 
torch.from_numpy(labels_x), torch.from_numpy(labels_x_scalar), 103 | torch.from_numpy(labels_x_global), batches_xi, labels_yi, oracles_yi, 104 | torch.from_numpy(hidden_labels)] 105 | if cuda: 106 | return_arr = self.cast_cuda(return_arr) 107 | if variable: 108 | return_arr = self.cast_variable(return_arr) 109 | return return_arr 110 | 111 | def cast_cuda(self, input): 112 | if type(input) == type([]): 113 | for i in range(len(input)): 114 | input[i] = self.cast_cuda(input[i]) 115 | else: 116 | return input.cuda() 117 | return input 118 | 119 | def cast_variable(self, input): 120 | if type(input) == type([]): 121 | for i in range(len(input)): 122 | input[i] = self.cast_variable(input[i]) 123 | else: 124 | return Variable(input) 125 | 126 | return input 127 | -------------------------------------------------------------------------------- /data/omniglot.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | import os 4 | import os.path 5 | import numpy as np 6 | from PIL import Image as pil_image 7 | import pickle 8 | import random 9 | from . 
import parser 10 | 11 | 12 | class Omniglot(data.Dataset): 13 | def __init__(self, root, dataset='omniglot'): 14 | self.root = root 15 | self.seed = 10 16 | self.dataset = dataset 17 | if not self._check_exists_(): 18 | self._init_folders_() 19 | if self.check_decompress(): 20 | self._decompress_() 21 | self._preprocess_() 22 | 23 | def _init_folders_(self): 24 | decompress = False 25 | if not os.path.exists(self.root): 26 | os.makedirs(self.root) 27 | if not os.path.exists(os.path.join(self.root, 'omniglot')): 28 | os.makedirs(os.path.join(self.root, 'omniglot')) 29 | decompress = True 30 | if not os.path.exists(os.path.join(self.root, 'omniglot', 'train')): 31 | os.makedirs(os.path.join(self.root, 'omniglot', 'train')) 32 | decompress = True 33 | if not os.path.exists(os.path.join(self.root, 'omniglot', 'test')): 34 | os.makedirs(os.path.join(self.root, 'omniglot', 'test')) 35 | decompress = True 36 | if not os.path.exists(os.path.join(self.root, 'compacted_datasets')): 37 | os.makedirs(os.path.join(self.root, 'compacted_datasets')) 38 | decompress = True 39 | return decompress 40 | 41 | def check_decompress(self): 42 | return os.listdir('%s/omniglot/test' % self.root) == [] 43 | 44 | def _decompress_(self): 45 | print("\nDecompressing Images...") 46 | comp_files = ['%s/compressed/omniglot/images_background.zip' % self.root, 47 | '%s/compressed/omniglot/images_evaluation.zip' % self.root] 48 | if os.path.isfile(comp_files[0]) and os.path.isfile(comp_files[1]): 49 | os.system(('unzip %s -d ' % comp_files[0]) + 50 | os.path.join(self.root, 'omniglot', 'train')) 51 | os.system(('unzip %s -d ' % comp_files[1]) + 52 | os.path.join(self.root, 'omniglot', 'test')) 53 | else: 54 | raise Exception('Missing %s or %s' % (comp_files[0], comp_files[1])) 55 | print("Decompressed") 56 | 57 | def _check_exists_(self): 58 | return os.path.exists(os.path.join(self.root, 'compacted_datasets', 'omniglot_train.pickle')) and \ 59 | os.path.exists(os.path.join(self.root, 
'compacted_datasets', 'omniglot_test.pickle')) 60 | 61 | def _preprocess_(self): 62 | print('\nPreprocessing Omniglot images...') 63 | (class_names_train, images_path_train) = parser.get_image_paths(os.path.join(self.root, 'omniglot', 'train')) 64 | (class_names_test, images_path_test) = parser.get_image_paths(os.path.join(self.root, 'omniglot', 'test')) 65 | 66 | keys_all = sorted(list(set(class_names_train + class_names_test))) 67 | label_encoder = {} 68 | label_decoder = {} 69 | for i in range(len(keys_all)): 70 | label_encoder[keys_all[i]] = i 71 | label_decoder[i] = keys_all[i] 72 | 73 | all_set = {} 74 | for class_, path in zip(class_names_train + class_names_test, images_path_train + images_path_test): 75 | img = np.array(pil_image.open(path), dtype='float32') 76 | if label_encoder[class_] not in all_set: 77 | all_set[label_encoder[class_]] = [] 78 | all_set[label_encoder[class_]].append(img) 79 | 80 | # Now we save the 1200 training - 423 testing partition 81 | keys = sorted(list(all_set.keys())) 82 | random.seed(self.seed) 83 | random.shuffle(keys) 84 | 85 | train_set = {} 86 | test_set = {} 87 | for i in range(1200): 88 | train_set[keys[i]] = all_set[keys[i]] 89 | for i in range(1200, len(keys)): 90 | test_set[keys[i]] = all_set[keys[i]] 91 | 92 | self.sanity_check(all_set) 93 | 94 | with open(os.path.join(self.root, 'compacted_datasets', 'omniglot_train.pickle'), 'wb') as handle: 95 | pickle.dump(train_set, handle, protocol=2) 96 | with open(os.path.join(self.root, 'compacted_datasets', 'omniglot_test.pickle'), 'wb') as handle: 97 | pickle.dump(test_set, handle, protocol=2) 98 | 99 | with open(os.path.join(self.root, 'compacted_datasets', 'omniglot_label_encoder.pickle'), 'wb') as handle: 100 | pickle.dump(label_encoder, handle, protocol=2) 101 | with open(os.path.join(self.root, 'compacted_datasets', 'omniglot_label_decoder.pickle'), 'wb') as handle: 102 | pickle.dump(label_decoder, handle, protocol=2) 103 | 104 | print('Images preprocessed') 105 | 106 
| def sanity_check(self, all_set): 107 | all_good = True 108 | for class_ in all_set: 109 | if len(all_set[class_]) != 20: 110 | all_good = False 111 | if all_good: 112 | print("All classes have 20 samples") 113 | 114 | def load_dataset(self, train, size): 115 | print("Loading dataset") 116 | if train: 117 | with open(os.path.join(self.root, 'compacted_datasets', 'omniglot_train.pickle'), 'rb') as handle: 118 | data = pickle.load(handle) 119 | else: 120 | with open(os.path.join(self.root, 'compacted_datasets', 'omniglot_test.pickle'), 'rb') as handle: 121 | data = pickle.load(handle) 122 | print("Num classes before rotations: "+str(len(data))) 123 | 124 | data_rot = {} 125 | # resize images and normalize 126 | for class_ in data: 127 | for rot in range(4): 128 | data_rot[class_ * 4 + rot] = [] 129 | for i in range(len(data[class_])): 130 | image2resize = pil_image.fromarray(np.uint8(data[class_][i]*255)) 131 | image_resized = image2resize.resize((size[1], size[0])) 132 | image_resized = np.array(image_resized, dtype='float32')/127.5 - 1 133 | image = self.rotate_image(image_resized, rot) 134 | image = np.expand_dims(image, axis=0) 135 | data_rot[class_ * 4 + rot].append(image) 136 | 137 | print("Dataset Loaded") 138 | print("Num classes after rotations: "+str(len(data_rot))) 139 | self.sanity_check(data_rot) 140 | return data_rot 141 | 142 | def rotate_image(self, image, times): 143 | rotated_image = np.zeros(image.shape) 144 | for channel in range(image.shape[0]): 145 | rotated_image[:, :] = np.rot90(image[:, :], k=times) 146 | return rotated_image -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | from models import gnn_iclr 7 | 8 | 9 | class EmbeddingOmniglot(nn.Module): 10 | ''' In this network the input 
image is supposed to be 28x28 ''' 11 | 12 | def __init__(self, args, emb_size): 13 | super(EmbeddingOmniglot, self).__init__() 14 | self.emb_size = emb_size 15 | self.nef = 64 16 | self.args = args 17 | 18 | # input is 1 x 28 x 28 19 | self.conv1 = nn.Conv2d(1, self.nef, 3, padding=1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(self.nef) 21 | # state size. (nef) x 14 x 14 22 | self.conv2 = nn.Conv2d(self.nef, self.nef, 3, padding=1, bias=False) 23 | self.bn2 = nn.BatchNorm2d(self.nef) 24 | 25 | # state size. (1.5*ndf) x 7 x 7 26 | self.conv3 = nn.Conv2d(self.nef, self.nef, 3, bias=False) 27 | self.bn3 = nn.BatchNorm2d(self.nef) 28 | # state size. (2*ndf) x 5 x 5 29 | self.conv4 = nn.Conv2d(self.nef, self.nef, 3, bias=False) 30 | self.bn4 = nn.BatchNorm2d(self.nef) 31 | # state size. (2*ndf) x 3 x 3 32 | self.fc_last = nn.Linear(3 * 3 * self.nef, self.emb_size, bias=False) 33 | self.bn_last = nn.BatchNorm1d(self.emb_size) 34 | 35 | def forward(self, inputs): 36 | e1 = F.max_pool2d(self.bn1(self.conv1(inputs)), 2) 37 | x = F.leaky_relu(e1, 0.1, inplace=True) 38 | 39 | e2 = F.max_pool2d(self.bn2(self.conv2(x)), 2) 40 | x = F.leaky_relu(e2, 0.1, inplace=True) 41 | 42 | e3 = self.bn3(self.conv3(x)) 43 | x = F.leaky_relu(e3, 0.1, inplace=True) 44 | e4 = self.bn4(self.conv4(x)) 45 | x = F.leaky_relu(e4, 0.1, inplace=True) 46 | x = x.view(-1, 3 * 3 * self.nef) 47 | 48 | output = F.leaky_relu(self.bn_last(self.fc_last(x))) 49 | 50 | return [e1, e2, e3, output] 51 | 52 | 53 | class EmbeddingImagenet(nn.Module): 54 | ''' In this network the input image is supposed to be 28x28 ''' 55 | 56 | def __init__(self, args, emb_size): 57 | super(EmbeddingImagenet, self).__init__() 58 | self.emb_size = emb_size 59 | self.ndf = 64 60 | self.args = args 61 | 62 | # Input 84x84x3 63 | self.conv1 = nn.Conv2d(3, self.ndf, kernel_size=3, stride=1, padding=1, bias=False) 64 | self.bn1 = nn.BatchNorm2d(self.ndf) 65 | 66 | # Input 42x42x64 67 | self.conv2 = nn.Conv2d(self.ndf, int(self.ndf*1.5), 
kernel_size=3, bias=False) 68 | self.bn2 = nn.BatchNorm2d(int(self.ndf*1.5)) 69 | 70 | # Input 20x20x96 71 | self.conv3 = nn.Conv2d(int(self.ndf*1.5), self.ndf*2, kernel_size=3, padding=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(self.ndf*2) 73 | self.drop_3 = nn.Dropout2d(0.4) 74 | 75 | # Input 10x10x128 76 | self.conv4 = nn.Conv2d(self.ndf*2, self.ndf*4, kernel_size=3, padding=1, bias=False) 77 | self.bn4 = nn.BatchNorm2d(self.ndf*4) 78 | self.drop_4 = nn.Dropout2d(0.5) 79 | 80 | # Input 5x5x256 81 | self.fc1 = nn.Linear(self.ndf*4*5*5, self.emb_size, bias=True) 82 | self.bn_fc = nn.BatchNorm1d(self.emb_size) 83 | 84 | def forward(self, input): 85 | e1 = F.max_pool2d(self.bn1(self.conv1(input)), 2) 86 | x = F.leaky_relu(e1, 0.2, inplace=True) 87 | e2 = F.max_pool2d(self.bn2(self.conv2(x)), 2) 88 | x = F.leaky_relu(e2, 0.2, inplace=True) 89 | e3 = F.max_pool2d(self.bn3(self.conv3(x)), 2) 90 | x = F.leaky_relu(e3, 0.2, inplace=True) 91 | x = self.drop_3(x) 92 | e4 = F.max_pool2d(self.bn4(self.conv4(x)), 2) 93 | x = F.leaky_relu(e4, 0.2, inplace=True) 94 | x = self.drop_4(x) 95 | x = x.view(-1, self.ndf*4*5*5) 96 | output = self.bn_fc(self.fc1(x)) 97 | 98 | return [e1, e2, e3, e4, None, output] 99 | 100 | 101 | class MetricNN(nn.Module): 102 | def __init__(self, args, emb_size): 103 | super(MetricNN, self).__init__() 104 | 105 | self.metric_network = args.metric_network 106 | self.emb_size = emb_size 107 | self.args = args 108 | 109 | if self.metric_network == 'gnn_iclr_nl': 110 | assert(self.args.train_N_way == self.args.test_N_way) 111 | num_inputs = self.emb_size + self.args.train_N_way 112 | if self.args.dataset == 'mini_imagenet': 113 | self.gnn_obj = gnn_iclr.GNN_nl(args, num_inputs, nf=96, J=1) 114 | elif 'omniglot' in self.args.dataset: 115 | self.gnn_obj = gnn_iclr.GNN_nl_omniglot(args, num_inputs, nf=96, J=1) 116 | elif self.metric_network == 'gnn_iclr_active': 117 | assert(self.args.train_N_way == self.args.test_N_way) 118 | num_inputs = self.emb_size + 
self.args.train_N_way 119 | self.gnn_obj = gnn_iclr.GNN_active(args, num_inputs, 96, J=1) 120 | else: 121 | raise NotImplementedError 122 | 123 | def gnn_iclr_forward(self, z, zi_s, labels_yi): 124 | # Creating WW matrix 125 | zero_pad = Variable(torch.zeros(labels_yi[0].size())) 126 | if self.args.cuda: 127 | zero_pad = zero_pad.cuda() 128 | 129 | labels_yi = [zero_pad] + labels_yi 130 | zi_s = [z] + zi_s 131 | 132 | nodes = [torch.cat([zi, label_yi], 1) for zi, label_yi in zip(zi_s, labels_yi)] 133 | nodes = [node.unsqueeze(1) for node in nodes] 134 | nodes = torch.cat(nodes, 1) 135 | 136 | logits = self.gnn_obj(nodes).squeeze(-1) 137 | outputs = F.sigmoid(logits) 138 | 139 | return outputs, logits 140 | 141 | def gnn_iclr_active_forward(self, z, zi_s, labels_yi, oracles_yi, hidden_layers): 142 | # Creating WW matrix 143 | zero_pad = Variable(torch.zeros(labels_yi[0].size())) 144 | if self.args.cuda: 145 | zero_pad = zero_pad.cuda() 146 | 147 | labels_yi = [zero_pad] + labels_yi 148 | zi_s = [z] + zi_s 149 | 150 | nodes = [torch.cat([label_yi, zi], 1) for zi, label_yi in zip(zi_s, labels_yi)] 151 | nodes = [node.unsqueeze(1) for node in nodes] 152 | nodes = torch.cat(nodes, 1) 153 | 154 | oracles_yi = [zero_pad] + oracles_yi 155 | oracles_yi = [oracle_yi.unsqueeze(1) for oracle_yi in oracles_yi] 156 | oracles_yi = torch.cat(oracles_yi, 1) 157 | 158 | logits = self.gnn_obj(nodes, oracles_yi, hidden_layers).squeeze(-1) 159 | outputs = F.sigmoid(logits) 160 | 161 | return outputs, logits 162 | 163 | def forward(self, inputs): 164 | '''input: [batch_x, [batches_xi], [labels_yi]]''' 165 | [z, zi_s, labels_yi, oracles_yi, hidden_labels] = inputs 166 | 167 | if 'gnn_iclr_active' in self.metric_network: 168 | return self.gnn_iclr_active_forward(z, zi_s, labels_yi, oracles_yi, hidden_labels) 169 | elif 'gnn_iclr' in self.metric_network: 170 | return self.gnn_iclr_forward(z, zi_s, labels_yi) 171 | else: 172 | raise NotImplementedError 173 | 174 | 175 | class 
SoftmaxModule(): 176 | def __init__(self): 177 | self.softmax_metric = 'log_softmax' 178 | 179 | def forward(self, outputs): 180 | if self.softmax_metric == 'log_softmax': 181 | return F.log_softmax(outputs) 182 | else: 183 | raise(NotImplementedError) 184 | 185 | 186 | def load_model(model_name, args, io): 187 | try: 188 | model = torch.load('checkpoints/%s/models/%s.t7' % (args.exp_name, model_name)) 189 | io.cprint('Loading Parameters from the last trained %s Model' % model_name) 190 | return model 191 | except: 192 | io.cprint('Initiallize new Network Weights for %s' % model_name) 193 | pass 194 | return None 195 | 196 | 197 | def create_models(args): 198 | print (args.dataset) 199 | 200 | if 'omniglot' == args.dataset: 201 | enc_nn = EmbeddingOmniglot(args, 64) 202 | elif 'mini_imagenet' == args.dataset: 203 | enc_nn = EmbeddingImagenet(args, 128) 204 | else: 205 | raise NameError('Dataset ' + args.dataset + ' not knows') 206 | return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size) 207 | -------------------------------------------------------------------------------- /data/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch.utils.data as data 3 | import os 4 | import os.path 5 | import numpy as np 6 | from PIL import Image as pil_image 7 | import pickle 8 | 9 | 10 | class MiniImagenet(data.Dataset): 11 | def __init__(self, root, dataset='mini_imagenet'): 12 | self.root = root 13 | self.dataset = dataset 14 | if not self._check_exists_(): 15 | self._init_folders_() 16 | if self.check_decompress(): 17 | self._decompress_() 18 | self._preprocess_() 19 | 20 | def _init_folders_(self): 21 | decompress = False 22 | if not os.path.exists(self.root): 23 | os.makedirs(self.root) 24 | if not os.path.exists(os.path.join(self.root, 'mini_imagenet')): 25 | os.makedirs(os.path.join(self.root, 'mini_imagenet')) 26 | decompress = True 27 | if not 
os.path.exists(os.path.join(self.root, 'compacted_datasets')): 28 | os.makedirs(os.path.join(self.root, 'compacted_datasets')) 29 | decompress = True 30 | return decompress 31 | 32 | def check_decompress(self): 33 | return os.listdir('%s/mini_imagenet' % self.root) == [] 34 | 35 | def _decompress_(self): 36 | print("\nDecompressing Images...") 37 | compressed_file = '%s/compressed/mini_imagenet/images.zip' % self.root 38 | if os.path.isfile(compressed_file): 39 | os.system('unzip %s -d %s/mini_imagenet/' % (compressed_file, self.root)) 40 | else: 41 | raise Exception('Missing %s' % compressed_file) 42 | print("Decompressed") 43 | 44 | def _check_exists_(self): 45 | if not os.path.exists(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_train.pickle')) or not \ 46 | os.path.exists(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_test.pickle')): 47 | return False 48 | else: 49 | return True 50 | 51 | def get_image_paths(self, file): 52 | images_path, class_names = [], [] 53 | with open(file, 'r') as f: 54 | f.readline() 55 | for line in f: 56 | name, class_ = line.split(',') 57 | class_ = class_[0:(len(class_)-1)] 58 | path = self.root + '/mini_imagenet/images/'+name 59 | images_path.append(path) 60 | class_names.append(class_) 61 | return class_names, images_path 62 | 63 | def _preprocess_(self): 64 | print('\nPreprocessing Mini-Imagenet images...') 65 | (class_names_train, images_path_train) = self.get_image_paths('%s/mini_imagenet/train.csv' % self.root) 66 | (class_names_test, images_path_test) = self.get_image_paths('%s/mini_imagenet/test.csv' % self.root) 67 | (class_names_val, images_path_val) = self.get_image_paths('%s/mini_imagenet/val.csv' % self.root) 68 | 69 | keys_train = list(set(class_names_train)) 70 | keys_test = list(set(class_names_test)) 71 | keys_val = list(set(class_names_val)) 72 | label_encoder = {} 73 | label_decoder = {} 74 | for i in range(len(keys_train)): 75 | label_encoder[keys_train[i]] = i 76 | label_decoder[i] = 
keys_train[i] 77 | for i in range(len(keys_train), len(keys_train)+len(keys_test)): 78 | label_encoder[keys_test[i-len(keys_train)]] = i 79 | label_decoder[i] = keys_test[i-len(keys_train)] 80 | for i in range(len(keys_train)+len(keys_test), len(keys_train)+len(keys_test)+len(keys_val)): 81 | label_encoder[keys_val[i-len(keys_train) - len(keys_test)]] = i 82 | label_decoder[i] = keys_val[i-len(keys_train)-len(keys_test)] 83 | 84 | counter = 0 85 | train_set = {} 86 | for class_, path in zip(class_names_train, images_path_train): 87 | img = pil_image.open(path) 88 | img = img.convert('RGB') 89 | img = img.resize((84, 84), pil_image.ANTIALIAS) 90 | img = np.array(img, dtype='float32') 91 | if label_encoder[class_] not in train_set: 92 | train_set[label_encoder[class_]] = [] 93 | train_set[label_encoder[class_]].append(img) 94 | counter += 1 95 | if counter % 1000 == 0: 96 | print("Counter "+str(counter) + " from " + str(len(images_path_train) + len(class_names_test) + 97 | len(class_names_val))) 98 | 99 | test_set = {} 100 | for class_, path in zip(class_names_test, images_path_test): 101 | img = pil_image.open(path) 102 | img = img.convert('RGB') 103 | img = img.resize((84, 84), pil_image.ANTIALIAS) 104 | img = np.array(img, dtype='float32') 105 | 106 | if label_encoder[class_] not in test_set: 107 | test_set[label_encoder[class_]] = [] 108 | test_set[label_encoder[class_]].append(img) 109 | counter += 1 110 | if counter % 1000 == 0: 111 | print("Counter " + str(counter) + " from "+str(len(images_path_train) + len(class_names_test) + 112 | len(class_names_val))) 113 | 114 | val_set = {} 115 | for class_, path in zip(class_names_val, images_path_val): 116 | img = pil_image.open(path) 117 | img = img.convert('RGB') 118 | img = img.resize((84, 84), pil_image.ANTIALIAS) 119 | img = np.array(img, dtype='float32') 120 | 121 | if label_encoder[class_] not in val_set: 122 | val_set[label_encoder[class_]] = [] 123 | val_set[label_encoder[class_]].append(img) 124 | counter += 
1 125 | if counter % 1000 == 0: 126 | print("Counter "+str(counter) + " from " + str(len(images_path_train) + len(class_names_test) + 127 | len(class_names_val))) 128 | 129 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_train.pickle'), 'wb') as handle: 130 | pickle.dump(train_set, handle, protocol=2) 131 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_test.pickle'), 'wb') as handle: 132 | pickle.dump(test_set, handle, protocol=2) 133 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_val.pickle'), 'wb') as handle: 134 | pickle.dump(val_set, handle, protocol=2) 135 | 136 | label_encoder = {} 137 | keys = list(train_set.keys()) + list(test_set.keys()) 138 | for id_key, key in enumerate(keys): 139 | label_encoder[key] = id_key 140 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_label_encoder.pickle'), 'wb') as handle: 141 | pickle.dump(label_encoder, handle, protocol=2) 142 | 143 | print('Images preprocessed') 144 | 145 | def load_dataset(self, partition, size=(84, 84)): 146 | print("Loading dataset") 147 | if partition == 'train_val': 148 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_%s.pickle' % 'train'), 149 | 'rb') as handle: 150 | data = pickle.load(handle) 151 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_%s.pickle' % 'val'), 152 | 'rb') as handle: 153 | data_val = pickle.load(handle) 154 | data.update(data_val) 155 | del data_val 156 | else: 157 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_%s.pickle' % partition), 158 | 'rb') as handle: 159 | data = pickle.load(handle) 160 | 161 | with open(os.path.join(self.root, 'compacted_datasets', 'mini_imagenet_label_encoder.pickle'), 162 | 'rb') as handle: 163 | label_encoder = pickle.load(handle) 164 | 165 | # Resize images and normalize 166 | for class_ in data: 167 | for i in range(len(data[class_])): 168 | image2resize = 
pil_image.fromarray(np.uint8(data[class_][i])) 169 | image_resized = image2resize.resize((size[1], size[0])) 170 | image_resized = np.array(image_resized, dtype='float32') 171 | 172 | # Normalize 173 | image_resized = np.transpose(image_resized, (2, 0, 1)) 174 | image_resized[0, :, :] -= 120.45 # R 175 | image_resized[1, :, :] -= 115.74 # G 176 | image_resized[2, :, :] -= 104.65 # B 177 | image_resized /= 127.5 178 | 179 | data[class_][i] = image_resized 180 | 181 | print("Num classes " + str(len(data))) 182 | num_images = 0 183 | for class_ in data: 184 | num_images += len(data[class_]) 185 | print("Num images " + str(num_images)) 186 | return data, label_encoder 187 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from data import generator 9 | from utils import io_utils 10 | import models.models as models 11 | import test 12 | import numpy as np 13 | 14 | # Training settings 15 | parser = argparse.ArgumentParser(description='Few-Shot Learning with Graph Neural Networks') 16 | parser.add_argument('--exp_name', type=str, default='debug_vx', metavar='N', 17 | help='Name of the experiment') 18 | parser.add_argument('--batch_size', type=int, default=10, metavar='batch_size', 19 | help='Size of batch)') 20 | parser.add_argument('--batch_size_test', type=int, default=10, metavar='batch_size', 21 | help='Size of batch)') 22 | parser.add_argument('--iterations', type=int, default=50000, metavar='N', 23 | help='number of epochs to train ') 24 | parser.add_argument('--decay_interval', type=int, default=10000, metavar='N', 25 | help='Learning rate decay interval') 26 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', 27 | 
help='learning rate (default: 0.01)') 28 | parser.add_argument('--momentum', type=float, default=0.5, metavar='M', 29 | help='SGD momentum (default: 0.5)') 30 | parser.add_argument('--no-cuda', action='store_true', default=False, 31 | help='enables CUDA training') 32 | parser.add_argument('--seed', type=int, default=1, metavar='S', 33 | help='random seed (default: 1)') 34 | parser.add_argument('--log-interval', type=int, default=20, metavar='N', 35 | help='how many batches to wait before logging training status') 36 | parser.add_argument('--save_interval', type=int, default=300000, metavar='N', 37 | help='how many batches between each model saving') 38 | parser.add_argument('--test_interval', type=int, default=2000, metavar='N', 39 | help='how many batches between each test') 40 | parser.add_argument('--test_N_way', type=int, default=5, metavar='N', 41 | help='Number of classes for doing each classification run') 42 | parser.add_argument('--train_N_way', type=int, default=5, metavar='N', 43 | help='Number of classes for doing each training comparison') 44 | parser.add_argument('--test_N_shots', type=int, default=1, metavar='N', 45 | help='Number of shots in test') 46 | parser.add_argument('--train_N_shots', type=int, default=1, metavar='N', 47 | help='Number of shots when training') 48 | parser.add_argument('--unlabeled_extra', type=int, default=0, metavar='N', 49 | help='Number of shots when training') 50 | parser.add_argument('--metric_network', type=str, default='gnn_iclr_nl', metavar='N', 51 | help='gnn_iclr_nl' + 'gnn_iclr_active') 52 | parser.add_argument('--active_random', type=int, default=0, metavar='N', 53 | help='random active ? 
') 54 | parser.add_argument('--dataset_root', type=str, default='datasets', metavar='N', 55 | help='Root dataset') 56 | parser.add_argument('--test_samples', type=int, default=30000, metavar='N', 57 | help='Number of shots') 58 | parser.add_argument('--dataset', type=str, default='mini_imagenet', metavar='N', 59 | help='omniglot') 60 | parser.add_argument('--dec_lr', type=int, default=10000, metavar='N', 61 | help='Decreasing the learning rate every x iterations') 62 | args = parser.parse_args() 63 | 64 | 65 | def _init_(): 66 | if not os.path.exists('checkpoints'): 67 | os.makedirs('checkpoints') 68 | if not os.path.exists('checkpoints/'+args.exp_name): 69 | os.makedirs('checkpoints/'+args.exp_name) 70 | if not os.path.exists('checkpoints/'+args.exp_name+'/'+'models'): 71 | os.makedirs('checkpoints/'+args.exp_name+'/'+'models') 72 | os.system('cp main.py checkpoints'+'/'+args.exp_name+'/'+'main.py.backup') 73 | os.system('cp models/models.py checkpoints' + '/' + args.exp_name + '/' + 'models.py.backup') 74 | _init_() 75 | 76 | io = io_utils.IOStream('checkpoints/' + args.exp_name + '/run.log') 77 | io.cprint(str(args)) 78 | 79 | args.cuda = not args.no_cuda and torch.cuda.is_available() 80 | torch.manual_seed(args.seed) 81 | if args.cuda: 82 | io.cprint('Using GPU : ' + str(torch.cuda.current_device())+' from '+str(torch.cuda.device_count())+' devices') 83 | torch.cuda.manual_seed(args.seed) 84 | else: 85 | io.cprint('Using CPU') 86 | 87 | 88 | def train_batch(model, data): 89 | [enc_nn, metric_nn, softmax_module] = model 90 | [batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels] = data 91 | 92 | # Compute embedding from x and xi_s 93 | z = enc_nn(batch_x)[-1] 94 | zi_s = [enc_nn(batch_xi)[-1] for batch_xi in batches_xi] 95 | 96 | # Compute metric from embeddings 97 | out_metric, out_logits = metric_nn(inputs=[z, zi_s, labels_yi, oracles_yi, hidden_labels]) 98 | logsoft_prob = softmax_module.forward(out_logits) 99 | 100 | # Loss 101 | label_x_numpy 
= label_x.cpu().data.numpy() 102 | formatted_label_x = np.argmax(label_x_numpy, axis=1) 103 | formatted_label_x = Variable(torch.LongTensor(formatted_label_x)) 104 | if args.cuda: 105 | formatted_label_x = formatted_label_x.cuda() 106 | loss = F.nll_loss(logsoft_prob, formatted_label_x) 107 | loss.backward() 108 | 109 | return loss 110 | 111 | 112 | def train(): 113 | train_loader = generator.Generator(args.dataset_root, args, partition='train', dataset=args.dataset) 114 | io.cprint('Batch size: '+str(args.batch_size)) 115 | 116 | #Try to load models 117 | enc_nn = models.load_model('enc_nn', args, io) 118 | metric_nn = models.load_model('metric_nn', args, io) 119 | 120 | if enc_nn is None or metric_nn is None: 121 | enc_nn, metric_nn = models.create_models(args=args) 122 | softmax_module = models.SoftmaxModule() 123 | 124 | if args.cuda: 125 | enc_nn.cuda() 126 | metric_nn.cuda() 127 | 128 | io.cprint(str(enc_nn)) 129 | io.cprint(str(metric_nn)) 130 | 131 | weight_decay = 0 132 | if args.dataset == 'mini_imagenet': 133 | print('Weight decay '+str(1e-6)) 134 | weight_decay = 1e-6 135 | opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay) 136 | opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay) 137 | 138 | enc_nn.train() 139 | metric_nn.train() 140 | counter = 0 141 | total_loss = 0 142 | val_acc, val_acc_aux = 0, 0 143 | test_acc = 0 144 | for batch_idx in range(args.iterations): 145 | 146 | #################### 147 | # Train 148 | #################### 149 | data = train_loader.get_task_batch(batch_size=args.batch_size, n_way=args.train_N_way, 150 | unlabeled_extra=args.unlabeled_extra, num_shots=args.train_N_shots, 151 | cuda=args.cuda, variable=True) 152 | [batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi, hidden_labels] = data 153 | 154 | opt_enc_nn.zero_grad() 155 | opt_metric_nn.zero_grad() 156 | 157 | loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module], 158 | 
data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels]) 159 | 160 | opt_enc_nn.step() 161 | opt_metric_nn.step() 162 | 163 | 164 | adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx) 165 | 166 | #################### 167 | # Display 168 | #################### 169 | counter += 1 170 | total_loss += loss_d_metric.item() 171 | if batch_idx % args.log_interval == 0: 172 | display_str = 'Train Iter: {}'.format(batch_idx) 173 | display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss/counter) 174 | io.cprint(display_str) 175 | counter = 0 176 | total_loss = 0 177 | 178 | #################### 179 | # Test 180 | #################### 181 | if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 20: 182 | if batch_idx == 20: 183 | test_samples = 100 184 | else: 185 | test_samples = 3000 186 | if args.dataset == 'mini_imagenet': 187 | val_acc_aux = test.test_one_shot(args, model=[enc_nn, metric_nn, softmax_module], 188 | test_samples=test_samples*5, partition='val') 189 | test_acc_aux = test.test_one_shot(args, model=[enc_nn, metric_nn, softmax_module], 190 | test_samples=test_samples*5, partition='test') 191 | test.test_one_shot(args, model=[enc_nn, metric_nn, softmax_module], 192 | test_samples=test_samples, partition='train') 193 | enc_nn.train() 194 | metric_nn.train() 195 | 196 | if val_acc_aux is not None and val_acc_aux >= val_acc: 197 | test_acc = test_acc_aux 198 | val_acc = val_acc_aux 199 | 200 | if args.dataset == 'mini_imagenet': 201 | io.cprint("Best test accuracy {:.4f} \n".format(test_acc)) 202 | 203 | #################### 204 | # Save model 205 | #################### 206 | if (batch_idx + 1) % args.save_interval == 0: 207 | torch.save(enc_nn, 'checkpoints/%s/models/enc_nn.t7' % args.exp_name) 208 | torch.save(metric_nn, 'checkpoints/%s/models/metric_nn.t7' % args.exp_name) 209 | 210 | # Test after training 211 | test.test_one_shot(args, model=[enc_nn, metric_nn, softmax_module], 212 | 
test_samples=args.test_samples) 213 | 214 | 215 | def adjust_learning_rate(optimizers, lr, iter): 216 | new_lr = lr * (0.5**(int(iter/args.dec_lr))) 217 | 218 | for optimizer in optimizers: 219 | for param_group in optimizer.param_groups: 220 | param_group['lr'] = new_lr 221 | 222 | 223 | if __name__ == "__main__": 224 | train() 225 | 226 | -------------------------------------------------------------------------------- /models/gnn_iclr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | 4 | # Pytorch requirements 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.nn.functional as F 9 | 10 | if torch.cuda.is_available(): 11 | dtype = torch.cuda.FloatTensor 12 | dtype_l = torch.cuda.LongTensor 13 | else: 14 | dtype = torch.FloatTensor 15 | dtype_l = torch.cuda.LongTensor 16 | 17 | 18 | def gmul(input): 19 | W, x = input 20 | # x is a tensor of size (bs, N, num_features) 21 | # W is a tensor of size (bs, N, N, J) 22 | x_size = x.size() 23 | W_size = W.size() 24 | N = W_size[-2] 25 | W = W.split(1, 3) 26 | W = torch.cat(W, 1).squeeze(3) # W is now a tensor of size (bs, J*N, N) 27 | output = torch.bmm(W, x) # output has size (bs, J*N, num_features) 28 | output = output.split(N, 1) 29 | output = torch.cat(output, 2) # output has size (bs, N, J*num_features) 30 | return output 31 | 32 | 33 | class Gconv(nn.Module): 34 | def __init__(self, nf_input, nf_output, J, bn_bool=True): 35 | super(Gconv, self).__init__() 36 | self.J = J 37 | self.num_inputs = J*nf_input 38 | self.num_outputs = nf_output 39 | self.fc = nn.Linear(self.num_inputs, self.num_outputs) 40 | 41 | self.bn_bool = bn_bool 42 | if self.bn_bool: 43 | self.bn = nn.BatchNorm1d(self.num_outputs) 44 | 45 | def forward(self, input): 46 | W = input[0] 47 | x = gmul(input) # out has size (bs, N, num_inputs) 48 | #if self.J == 1: 49 | # x = torch.abs(x) 50 | x_size = x.size() 51 | x = 
x.contiguous() 52 | x = x.view(-1, self.num_inputs) 53 | x = self.fc(x) # has size (bs*N, num_outputs) 54 | 55 | if self.bn_bool: 56 | x = self.bn(x) 57 | 58 | x = x.view(*x_size[:-1], self.num_outputs) 59 | return W, x 60 | 61 | 62 | class Wcompute(nn.Module): 63 | def __init__(self, input_features, nf, operator='J2', activation='softmax', ratio=[2,2,1,1], num_operators=1, drop=False): 64 | super(Wcompute, self).__init__() 65 | self.num_features = nf 66 | self.operator = operator 67 | self.conv2d_1 = nn.Conv2d(input_features, int(nf * ratio[0]), 1, stride=1) 68 | self.bn_1 = nn.BatchNorm2d(int(nf * ratio[0])) 69 | self.drop = drop 70 | if self.drop: 71 | self.dropout = nn.Dropout(0.3) 72 | self.conv2d_2 = nn.Conv2d(int(nf * ratio[0]), int(nf * ratio[1]), 1, stride=1) 73 | self.bn_2 = nn.BatchNorm2d(int(nf * ratio[1])) 74 | self.conv2d_3 = nn.Conv2d(int(nf * ratio[1]), nf*ratio[2], 1, stride=1) 75 | self.bn_3 = nn.BatchNorm2d(nf*ratio[2]) 76 | self.conv2d_4 = nn.Conv2d(nf*ratio[2], nf*ratio[3], 1, stride=1) 77 | self.bn_4 = nn.BatchNorm2d(nf*ratio[3]) 78 | self.conv2d_last = nn.Conv2d(nf, num_operators, 1, stride=1) 79 | self.activation = activation 80 | 81 | def forward(self, x, W_id): 82 | W1 = x.unsqueeze(2) 83 | W2 = torch.transpose(W1, 1, 2) #size: bs x N x N x num_features 84 | W_new = torch.abs(W1 - W2) #size: bs x N x N x num_features 85 | W_new = torch.transpose(W_new, 1, 3) #size: bs x num_features x N x N 86 | 87 | W_new = self.conv2d_1(W_new) 88 | W_new = self.bn_1(W_new) 89 | W_new = F.leaky_relu(W_new) 90 | if self.drop: 91 | W_new = self.dropout(W_new) 92 | 93 | W_new = self.conv2d_2(W_new) 94 | W_new = self.bn_2(W_new) 95 | W_new = F.leaky_relu(W_new) 96 | 97 | W_new = self.conv2d_3(W_new) 98 | W_new = self.bn_3(W_new) 99 | W_new = F.leaky_relu(W_new) 100 | 101 | W_new = self.conv2d_4(W_new) 102 | W_new = self.bn_4(W_new) 103 | W_new = F.leaky_relu(W_new) 104 | 105 | W_new = self.conv2d_last(W_new) 106 | W_new = torch.transpose(W_new, 1, 3) #size: 
bs x N x N x 1 107 | 108 | if self.activation == 'softmax': 109 | W_new = W_new - W_id.expand_as(W_new) * 1e8 110 | W_new = torch.transpose(W_new, 2, 3) 111 | # Applying Softmax 112 | W_new = W_new.contiguous() 113 | W_new_size = W_new.size() 114 | W_new = W_new.view(-1, W_new.size(3)) 115 | W_new = F.softmax(W_new) 116 | W_new = W_new.view(W_new_size) 117 | # Softmax applied 118 | W_new = torch.transpose(W_new, 2, 3) 119 | 120 | elif self.activation == 'sigmoid': 121 | W_new = F.sigmoid(W_new) 122 | W_new *= (1 - W_id) 123 | elif self.activation == 'none': 124 | W_new *= (1 - W_id) 125 | else: 126 | raise (NotImplementedError) 127 | 128 | if self.operator == 'laplace': 129 | W_new = W_id - W_new 130 | elif self.operator == 'J2': 131 | W_new = torch.cat([W_id, W_new], 3) 132 | else: 133 | raise(NotImplementedError) 134 | 135 | return W_new 136 | 137 | 138 | class GNN_nl_omniglot(nn.Module): 139 | def __init__(self, args, input_features, nf, J): 140 | super(GNN_nl_omniglot, self).__init__() 141 | self.args = args 142 | self.input_features = input_features 143 | self.nf = nf 144 | self.J = J 145 | 146 | self.num_layers = 2 147 | for i in range(self.num_layers): 148 | module_w = Wcompute(self.input_features + int(nf / 2) * i, 149 | self.input_features + int(nf / 2) * i, 150 | operator='J2', activation='softmax', ratio=[2, 1.5, 1, 1], drop=False) 151 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 152 | self.add_module('layer_w{}'.format(i), module_w) 153 | self.add_module('layer_l{}'.format(i), module_l) 154 | 155 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, 156 | self.input_features + int(self.nf / 2) * (self.num_layers - 1), 157 | operator='J2', activation='softmax', ratio=[2, 1.5, 1, 1], drop=True) 158 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, args.train_N_way, 2, bn_bool=True) 159 | 160 | def forward(self, x): 161 | W_init = 
Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 162 | if self.args.cuda: 163 | W_init = W_init.cuda() 164 | 165 | for i in range(self.num_layers): 166 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 167 | 168 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 169 | x = torch.cat([x, x_new], 2) 170 | 171 | Wl = self.w_comp_last(x, W_init) 172 | out = self.layer_last([Wl, x])[1] 173 | 174 | return out[:, 0, :] 175 | 176 | 177 | class GNN_nl(nn.Module): 178 | def __init__(self, args, input_features, nf, J): 179 | super(GNN_nl, self).__init__() 180 | self.args = args 181 | self.input_features = input_features 182 | self.nf = nf 183 | self.J = J 184 | 185 | if args.dataset == 'mini_imagenet': 186 | self.num_layers = 2 187 | else: 188 | self.num_layers = 2 189 | 190 | for i in range(self.num_layers): 191 | if i == 0: 192 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 193 | module_l = Gconv(self.input_features, int(nf / 2), 2) 194 | else: 195 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 196 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 197 | self.add_module('layer_w{}'.format(i), module_w) 198 | self.add_module('layer_l{}'.format(i), module_l) 199 | 200 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 201 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, args.train_N_way, 2, bn_bool=False) 202 | 203 | def forward(self, x): 204 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 205 | if self.args.cuda: 206 | W_init = W_init.cuda() 207 | 208 | for i in range(self.num_layers): 209 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 210 | 211 | x_new = 
F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 212 | x = torch.cat([x, x_new], 2) 213 | 214 | Wl=self.w_comp_last(x, W_init) 215 | out = self.layer_last([Wl, x])[1] 216 | 217 | return out[:, 0, :] 218 | 219 | class GNN_active(nn.Module): 220 | def __init__(self, args, input_features, nf, J): 221 | super(GNN_active, self).__init__() 222 | self.args = args 223 | self.input_features = input_features 224 | self.nf = nf 225 | self.J = J 226 | 227 | self.num_layers = 2 228 | for i in range(self.num_layers // 2): 229 | if i == 0: 230 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 231 | module_l = Gconv(self.input_features, int(nf / 2), 2) 232 | else: 233 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 234 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 235 | 236 | self.add_module('layer_w{}'.format(i), module_w) 237 | self.add_module('layer_l{}'.format(i), module_l) 238 | 239 | self.conv_active = nn.Conv1d(self.input_features + int(nf / 2) * 1, 1, 1, bias=False) 240 | nn.init.uniform_(self.conv_active.weight.data) 241 | 242 | for i in range(int(self.num_layers/2), self.num_layers): 243 | if i == 0: 244 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 245 | module_l = Gconv(self.input_features, int(nf / 2), 2) 246 | else: 247 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 248 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 249 | self.add_module('layer_w{}'.format(i), module_w) 250 | self.add_module('layer_l{}'.format(i), module_l) 251 | 252 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 253 | self.layer_last = Gconv(self.input_features + int(self.nf 
/ 2) * self.num_layers, args.train_N_way, 2, bn_bool=False) 254 | 255 | def active(self, x, oracles_yi, hidden_labels): 256 | ''' 257 | :param x: torch.Size([40, 26, 181]) 258 | :param oracles_yi: torch.Size([40, 26, 5]) 259 | :param hidden_labels: torch.Size([40, 26]) 260 | :return: 261 | ''' 262 | 263 | 264 | x_active = torch.transpose(x, 1, 2) 265 | x_to_classify = x_active[:, :, 0:1] 266 | 267 | x_active = - ((x_active - x_to_classify) ** 2).detach() 268 | x_active = self.conv_active(x_active) 269 | x_active = torch.transpose(x_active, 1, 2) 270 | x_active = x_active.squeeze(-1) # torch.Size([40, 26]) 271 | 272 | if self.args.active_random == 1: 273 | x_active.data.fill_(1. / x_active.size(1)) 274 | 275 | # assigning lower prob to uncover the labels we already know 276 | x_active = x_active - (1 - hidden_labels) * 1e8 277 | 278 | if self.args.active_random == 1: 279 | mapping = F.gumbel_softmax(x_active, hard=True).unsqueeze(-1) 280 | mapping = mapping.detach() 281 | else: 282 | if self.training: 283 | mapping = F.gumbel_softmax(x_active, hard=True).unsqueeze(-1) 284 | else: 285 | temperature = 1e5 # larger temperature at test to pick the most likely 286 | mapping = F.gumbel_softmax(x_active * temperature, hard=True).unsqueeze(-1) 287 | 288 | 289 | label2add = oracles_yi * mapping 290 | 291 | # add ppadding 292 | padd = torch.zeros(x.size(0), x.size(1), x.size(2) - label2add.size(2)) 293 | padd = Variable(padd).detach() 294 | if self.args.cuda: 295 | padd = padd.cuda() 296 | label2add = torch.cat([label2add, padd], 2) 297 | x = x + label2add 298 | return x 299 | 300 | def forward(self, x, oracles_yi, hidden_labels): 301 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 302 | if self.args.cuda: 303 | W_init = W_init.cuda() 304 | 305 | for i in range(self.num_layers // 2): 306 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 307 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 308 | x = 
torch.cat([x, x_new], 2) 309 | 310 | x = self.active(x, oracles_yi, hidden_labels) 311 | 312 | for i in range(int(self.num_layers/2), self.num_layers): 313 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 314 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 315 | x = torch.cat([x, x_new], 2) 316 | 317 | 318 | Wl=self.w_comp_last(x, W_init) 319 | out = self.layer_last([Wl, x])[1] 320 | 321 | return out[:, 0, :] 322 | 323 | if __name__ == '__main__': 324 | # test modules 325 | bs = 4 326 | nf = 10 327 | num_layers = 5 328 | N = 8 329 | x = torch.ones((bs, N, nf)) 330 | W1 = torch.eye(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) 331 | W2 = torch.ones(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) 332 | J = 2 333 | W = torch.cat((W1, W2), 3) 334 | input = [Variable(W), Variable(x)] 335 | ######################### test gmul ############################## 336 | # feature_maps = [num_features, num_features, num_features] 337 | # out = gmul(input) 338 | # print(out[0, :, num_features:]) 339 | ######################### test gconv ############################## 340 | # feature_maps = [num_features, num_features, num_features] 341 | # gconv = Gconv(feature_maps, J) 342 | # _, out = gconv(input) 343 | # print(out.size()) 344 | ######################### test gnn ############################## 345 | # x = torch.ones((bs, N, 1)) 346 | # input = [Variable(W), Variable(x)] 347 | # gnn = GNN(num_features, num_layers, J) 348 | # out = gnn(input) 349 | # print(out.size()) 350 | 351 | 352 | --------------------------------------------------------------------------------