├── LICENSE
├── README.md
├── convnet.py
├── materials
│   ├── test.csv
│   ├── train.csv
│   └── val.csv
├── mini_imagenet.py
├── samplers.py
├── test.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Yinbo Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Prototypical Network

A re-implementation of [Prototypical Network](https://arxiv.org/abs/1703.05175) with a ConvNet-4 backbone on miniImageNet.

***For deep backbones (ResNet), see [Meta-Baseline](https://github.com/cyvius96/few-shot-meta-baseline).***

### Results

1-shot: 49.1% (49.4% in the paper)

5-shot: 66.9% (68.2% in the paper)

## Environment

* python 3
* pytorch 0.4.0

## Instructions

1. Download the images: https://drive.google.com/open?id=0B3Irx3uQNoBMQ1FlNXJsZUdYWEE

2. Make a folder `materials/images` and put those images into it.

Use `--gpu` to specify which GPU the scripts should run on.

### 1-shot Train

`python train.py`

### 1-shot Test

`python test.py`

### 5-shot Train

`python train.py --shot 5 --train-way 20 --save-path ./save/proto-5`

### 5-shot Test

`python test.py --load ./save/proto-5/max-acc.pth --shot 5`
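The commands above all run the same episodic procedure from the paper: embed the support images, average them per class into prototypes, and classify each query by negative squared Euclidean distance to the prototypes. A minimal sketch of that computation, using random tensors in place of real encoder outputs (the sizes mirror the default 5-way 1-shot flags, and 1600 is the Convnet feature size defined below):

import torch

# Episode sizes mirroring the default flags; random stand-ins for embeddings.
way, shot, query, dim = 5, 1, 15, 1600

support = torch.randn(shot * way, dim)   # ordered shot-major, as the sampler yields
queries = torch.randn(query * way, dim)

# Prototype of each class = mean of its support embeddings.
proto = support.reshape(shot, way, dim).mean(dim=0)           # (way, dim)

# Logit of (query, class) = negative squared Euclidean distance.
logits = -((queries.unsqueeze(1) - proto.unsqueeze(0)) ** 2).sum(dim=2)
print(logits.shape)  # torch.Size([75, 5]); argmax along dim 1 is the prediction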
--------------------------------------------------------------------------------
/convnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn


def conv_block(in_channels, out_channels):
    bn = nn.BatchNorm2d(out_channels)
    # PyTorch >= 1.2 initializes BatchNorm weights to 1; restore the earlier
    # default uniform init that this implementation was tuned with.
    nn.init.uniform_(bn.weight)
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        bn,
        nn.ReLU(),
        nn.MaxPool2d(2)
    )


class Convnet(nn.Module):

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )
        # 64 channels * 5 * 5 spatial positions after four 2x poolings of 84x84.
        self.out_channels = 1600

    def forward(self, x):
        x = self.encoder(x)
        return x.view(x.size(0), -1)  # flatten to (batch, 1600)
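A quick way to confirm the hard-coded `out_channels = 1600`: each of the four blocks halves the spatial resolution with `MaxPool2d(2)`, so an 84x84 input shrinks to 5x5 over 64 channels. A small shape check, assuming `convnet.py` is importable and using an arbitrary dummy batch of two images:

import torch
from convnet import Convnet

model = Convnet()
x = torch.randn(2, 3, 84, 84)   # dummy batch of two 84x84 RGB images
feats = model(x)
# Four MaxPool2d(2) layers: 84 -> 42 -> 21 -> 10 -> 5,
# so each image becomes 64 * 5 * 5 = 1600 features.
print(feats.shape)              # torch.Size([2, 1600])
assert feats.shape == (2, model.out_channels)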
--------------------------------------------------------------------------------
/mini_imagenet.py:
--------------------------------------------------------------------------------
import os.path as osp
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms


ROOT_PATH = './materials/'


class MiniImageNet(Dataset):

    def __init__(self, setname):
        csv_path = osp.join(ROOT_PATH, setname + '.csv')
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]  # skip the header row

        data = []
        label = []
        lb = -1

        self.wnids = []

        # Each csv row is "filename,wnid"; assign a consecutive integer
        # label to every new WordNet id as it appears.
        for line in lines:
            name, wnid = line.split(',')
            path = osp.join(ROOT_PATH, 'images', name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            data.append(path)
            label.append(lb)

        self.data = data
        self.label = label

        self.transform = transforms.Compose([
            transforms.Resize(84),
            transforms.CenterCrop(84),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, label = self.data[i], self.label[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, label
--------------------------------------------------------------------------------
/samplers.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


class CategoriesSampler():

    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch  # episodes per epoch
        self.n_cls = n_cls      # classes ("ways") per episode
        self.n_per = n_per      # examples per class (shot + query)

        label = np.array(label)
        self.m_ind = []
        for i in range(max(label) + 1):
            ind = np.argwhere(label == i).reshape(-1)
            ind = torch.from_numpy(ind)
            self.m_ind.append(ind)  # indices of every example of class i

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for i_batch in range(self.n_batch):
            batch = []
            classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            for c in classes:
                l = self.m_ind[c]
                pos = torch.randperm(len(l))[:self.n_per]
                batch.append(l[pos])
            # (n_cls, n_per) -> transpose -> flatten, so the yielded indices
            # are grouped one-example-per-class: the first n_cls entries form
            # the first "slot", the next n_cls the second, and so on.
            batch = torch.stack(batch).t().reshape(-1)
            yield batch
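The transpose-then-flatten step at the end of `__iter__` is what lets `train.py` and `test.py` split an episode with a single slice. A toy demonstration with fabricated labels (4 classes of 6 examples; sizes kept small just for printing):

import torch
from samplers import CategoriesSampler

# Fabricated labels: 4 classes with 6 examples each.
labels = [c for c in range(4) for _ in range(6)]

# One episode, 3-way, 2 examples per class (think shot=1, query=1).
sampler = CategoriesSampler(labels, n_batch=1, n_cls=3, n_per=2)

for batch in sampler:
    # Row 0 is the "support" slot, row 1 the "query" slot: each row holds
    # one example from each sampled class, in the same class order, which
    # is exactly what the slicing and the arange(way).repeat(query) labels
    # in train.py and test.py assume.
    print(batch.reshape(2, 3))
    print([labels[i] for i in batch.tolist()])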
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse

import torch
from torch.utils.data import DataLoader

from mini_imagenet import MiniImageNet
from samplers import CategoriesSampler
from convnet import Convnet
from utils import pprint, set_gpu, count_acc, Averager, euclidean_metric


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', default='0')
    parser.add_argument('--load', default='./save/proto-1/max-acc.pth')
    parser.add_argument('--batch', type=int, default=2000)
    parser.add_argument('--way', type=int, default=5)
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=30)
    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)

    dataset = MiniImageNet('test')
    sampler = CategoriesSampler(dataset.label,
                                args.batch, args.way, args.shot + args.query)
    loader = DataLoader(dataset, batch_sampler=sampler,
                        num_workers=8, pin_memory=True)

    model = Convnet().cuda()
    model.load_state_dict(torch.load(args.load))
    model.eval()

    ave_acc = Averager()

    for i, batch in enumerate(loader, 1):
        data, _ = [_.cuda() for _ in batch]
        # The first way * shot images are the support set (see samplers.py).
        k = args.way * args.shot
        data_shot, data_query = data[:k], data[k:]

        # Prototypes: per-class mean of the support embeddings.
        x = model(data_shot)
        x = x.reshape(args.shot, args.way, -1).mean(dim=0)
        p = x

        logits = euclidean_metric(model(data_query), p)

        # Queries are also grouped one-per-class, so labels cycle 0..way-1.
        label = torch.arange(args.way).repeat(args.query)
        label = label.type(torch.cuda.LongTensor)

        acc = count_acc(logits, label)
        ave_acc.add(acc)
        print('batch {}: {:.2f}({:.2f})'.format(i, ave_acc.item() * 100, acc * 100))

        # Drop references so the computation graph can be freed between episodes.
        x = None; p = None; logits = None
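`test.py` reports only a running mean over the sampled episodes, while few-shot results (including the paper's) are conventionally given with a 95% confidence interval across episodes. A hedged sketch of how such an interval could be computed, assuming the per-episode accuracies were appended to a list inside the loop above (the values below are placeholders, and the 1.96 factor assumes a normal approximation):

import numpy as np

# Hypothetical per-episode accuracies, e.g. collected by appending `acc`
# to a list inside the test loop; placeholder values here.
accs = np.array([0.48, 0.52, 0.46, 0.55])

mean = accs.mean()
# Normal approximation: 1.96 standard errors on each side.
ci95 = 1.96 * accs.std(ddof=1) / np.sqrt(len(accs))
print('acc: {:.2f} +- {:.2f}'.format(mean * 100, ci95 * 100))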
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from mini_imagenet import MiniImageNet
from samplers import CategoriesSampler
from convnet import Convnet
from utils import pprint, set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--max-epoch', type=int, default=200)
    parser.add_argument('--save-epoch', type=int, default=20)
    parser.add_argument('--shot', type=int, default=1)
    parser.add_argument('--query', type=int, default=15)
    parser.add_argument('--train-way', type=int, default=30)
    parser.add_argument('--test-way', type=int, default=5)
    parser.add_argument('--save-path', default='./save/proto-1')
    parser.add_argument('--gpu', default='0')
    args = parser.parse_args()
    pprint(vars(args))

    set_gpu(args.gpu)
    ensure_path(args.save_path)

    trainset = MiniImageNet('train')
    train_sampler = CategoriesSampler(trainset.label, 100,
                                      args.train_way, args.shot + args.query)
    train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
                              num_workers=8, pin_memory=True)

    valset = MiniImageNet('val')
    val_sampler = CategoriesSampler(valset.label, 400,
                                    args.test_way, args.shot + args.query)
    val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                            num_workers=8, pin_memory=True)

    model = Convnet().cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Halve the learning rate every 20 epochs.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

    def save_model(name):
        torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))

    trlog = {}
    trlog['args'] = vars(args)
    trlog['train_loss'] = []
    trlog['val_loss'] = []
    trlog['train_acc'] = []
    trlog['val_acc'] = []
    trlog['max_acc'] = 0.0

    timer = Timer()

    for epoch in range(1, args.max_epoch + 1):
        # Stepping at the start of the epoch matches the pytorch 0.4
        # convention; from 1.1 on, step() is expected after optimizer.step().
        lr_scheduler.step()

        model.train()

        tl = Averager()
        ta = Averager()

        for i, batch in enumerate(train_loader, 1):
            data, _ = [_.cuda() for _ in batch]
            p = args.shot * args.train_way
            data_shot, data_query = data[:p], data[p:]

            # Prototypes: per-class mean of the support embeddings.
            proto = model(data_shot)
            proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)

            label = torch.arange(args.train_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)

            logits = euclidean_metric(model(data_query), proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)
            print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'
                  .format(epoch, i, len(train_loader), loss.item(), acc))

            tl.add(loss.item())
            ta.add(acc)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            proto = None; logits = None; loss = None

        tl = tl.item()
        ta = ta.item()

        model.eval()

        vl = Averager()
        va = Averager()

        for i, batch in enumerate(val_loader, 1):
            data, _ = [_.cuda() for _ in batch]
            p = args.shot * args.test_way
            data_shot, data_query = data[:p], data[p:]

            proto = model(data_shot)
            proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)

            label = torch.arange(args.test_way).repeat(args.query)
            label = label.type(torch.cuda.LongTensor)

            logits = euclidean_metric(model(data_query), proto)
            loss = F.cross_entropy(logits, label)
            acc = count_acc(logits, label)

            vl.add(loss.item())
            va.add(acc)

            proto = None; logits = None; loss = None

        vl = vl.item()
        va = va.item()
        print('epoch {}, val, loss={:.4f} acc={:.4f}'.format(epoch, vl, va))

        if va > trlog['max_acc']:
            trlog['max_acc'] = va
            save_model('max-acc')

        trlog['train_loss'].append(tl)
        trlog['train_acc'].append(ta)
        trlog['val_loss'].append(vl)
        trlog['val_acc'].append(va)

        torch.save(trlog, osp.join(args.save_path, 'trlog'))

        save_model('epoch-last')

        if epoch % args.save_epoch == 0:
            save_model('epoch-{}'.format(epoch))

        print('ETA:{}/{}'.format(timer.measure(), timer.measure(epoch / args.max_epoch)))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import shutil
import time
import pprint

import torch


def set_gpu(x):
    os.environ['CUDA_VISIBLE_DEVICES'] = x
    print('using gpu:', x)


def ensure_path(path):
    if os.path.exists(path):
        if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
            shutil.rmtree(path)
            os.makedirs(path)
    else:
        os.makedirs(path)


class Averager():

    def __init__(self):
        self.n = 0
        self.v = 0

    def add(self, x):
        # Running mean without storing the individual values.
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v


def count_acc(logits, label):
    pred = torch.argmax(logits, dim=1)
    return (pred == label).type(torch.cuda.FloatTensor).mean().item()


def dot_metric(a, b):
    return torch.mm(a, b.t())


def euclidean_metric(a, b):
    # a: (n, d) queries, b: (m, d) prototypes -> (n, m) negative squared
    # distances, so larger logits mean closer prototypes.
    n = a.shape[0]
    m = b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)
    b = b.unsqueeze(0).expand(n, m, -1)
    logits = -((a - b)**2).sum(dim=2)
    return logits


class Timer():

    def __init__(self):
        self.o = time.time()

    def measure(self, p=1):
        x = (time.time() - self.o) / p
        x = int(x)
        if x >= 3600:
            return '{:.1f}h'.format(x / 3600)
        if x >= 60:
            return '{}m'.format(round(x / 60))
        return '{}s'.format(x)


_utils_pp = pprint.PrettyPrinter()
def pprint(x):
    _utils_pp.pprint(x)


def l2_loss(pred, label):
    return ((pred - label)**2).sum() / len(pred) / 2
--------------------------------------------------------------------------------
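A closing note on `euclidean_metric` in utils.py: it broadcasts an (n, d) batch of queries against an (m, d) batch of prototypes and returns (n, m) negative squared distances, so the largest logit marks the nearest prototype. A quick sanity check against `torch.cdist`, which computes the same pairwise distances (it requires a newer pytorch than the 0.4.0 listed in the README):

import torch
from utils import euclidean_metric

a = torch.randn(6, 4)   # 6 query embeddings of dimension 4
b = torch.randn(3, 4)   # 3 prototypes

logits = euclidean_metric(a, b)        # (6, 3)
reference = -torch.cdist(a, b) ** 2    # same quantity via pairwise distances
print(torch.allclose(logits, reference, atol=1e-5))  # True, up to float error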