├── LICENSE
├── README.md
├── data
│   └── readme.txt
├── datasets
│   ├── cifarfs.py
│   ├── mini_imagenet.py
│   ├── samplers.py
│   └── tiered_imagenet.py
├── models
│   ├── convnet.py
│   ├── distill.py
│   └── resnet.py
├── save
│   └── readme.txt
├── test.py
├── train_stage1.py
├── train_stage2.py
├── train_stage3.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Jit Yan
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSL-ProtoNet: Self-supervised Learning Prototypical Networks for Few-shot Learning
2 | 
3 | This repository contains the **PyTorch** code for the paper "[SSL-ProtoNet: Self-supervised Learning Prototypical Networks for Few-shot Learning](https://doi.org/10.1016/j.eswa.2023.122173)" by Jit Yan Lim, Kian Ming Lim, Chin Poo Lee, and Yong Xuan Tan.
4 | 
5 | ## Environment
6 | The code was tested on Windows 10 with Anaconda3 and the following packages:
7 | - python 3.7.4
8 | - pytorch 1.3.1
9 | 
10 | ## Preparation
11 | 1. Change the ROOT_PATH value in the following files to your own dataset locations:
12 | - `datasets/mini_imagenet.py`
13 | - `datasets/tiered_imagenet.py`
14 | - `datasets/cifarfs.py`
15 | 
16 | 2. Download the datasets and put them into the corresponding folders referenced by ROOT_PATH:<br>
17 | - ***mini*ImageNet**: download from [CSS](https://github.com/anyuexuan/CSS) and put it in the `data/mini-imagenet` folder.
18 | 
19 | - ***tiered*ImageNet**: download from [RFS](https://www.dropbox.com/sh/6yd1ygtyc3yd981/AABVeEqzC08YQv4UZk7lNHvya?dl=0) and put it in the `data/tiered-imagenet` folder.
20 | 
21 | - **CIFARFS**: download from [MetaOptNet](https://github.com/kjunelee/MetaOptNet) and put it in the `data/cifar-fs` folder.
22 | 
23 | 
24 | ## Pre-trained Models
25 | [Optional] The pre-trained models can be downloaded from [here](https://drive.google.com/file/d/14IOHnVfVACpkhjj1o3ZjwG7YD4p6ULLM/view?usp=sharing). Extract and put the content in the `save` folder. To evaluate a model, run `test.py` with the corresponding save path, as shown in the next section.
26 | 
27 | 
28 | ## Experiments
29 | To train on 1-shot and 5-shot CIFAR-FS:<br>
30 | ``` 31 | python train_stage1.py --dataset cifarfs --train-way 50 --train-batch 100 --save-path ./save/cifarfs-stage1 32 | 33 | python train_stage2.py --dataset cifarfs --shot 1 --save-path ./save/cifarfs-stage2-1s --stage1-path ./save/cifarfs-stage1 --train-way 20 34 | python train_stage2.py --dataset cifarfs --shot 5 --save-path ./save/cifarfs-stage2-5s --stage1-path ./save/cifarfs-stage1 --train-way 10 35 | 36 | python train_stage3.py --kd-coef 0.7 --dataset cifarfs --shot 1 --train-way 20 --stage1-path ./save/cifarfs-stage1 --stage2-path ./save/cifarfs-stage2-1s --save-path ./save/cifarfs-stage3-1s 37 | python train_stage3.py --kd-coef 0.1 --dataset cifarfs --shot 5 --train-way 10 --stage1-path ./save/cifarfs-stage1 --stage2-path ./save/cifarfs-stage2-5s --save-path ./save/cifarfs-stage3-5s 38 | ``` 39 | To evaluate on 5-way 1-shot and 5-way 5-shot CIFAR-FS:
40 | ```
41 | python test.py --dataset cifarfs --shot 1 --save-path ./save/cifarfs-stage3-1s
42 | python test.py --dataset cifarfs --shot 5 --save-path ./save/cifarfs-stage3-5s
43 | ```
44 | 
45 | 
46 | ## Citation
47 | If you find this repo useful for your research, please consider citing the paper:
48 | ```
49 | @article{LIM2023122173,
50 | title = {SSL-ProtoNet: Self-supervised Learning Prototypical Networks for few-shot learning},
51 | journal = {Expert Systems with Applications},
52 | pages = {122173},
53 | year = {2023},
54 | issn = {0957-4174},
55 | doi = {https://doi.org/10.1016/j.eswa.2023.122173},
56 | author = {Jit Yan Lim and Kian Ming Lim and Chin Poo Lee and Yong Xuan Tan}
57 | }
58 | ```
59 | 
60 | ## Contacts
61 | For any questions, please contact:<br>
62 | 63 | Jit Yan Lim (jityan95@gmail.com)
Kian Ming Lim (Kian-Ming.Lim@nottingham.edu.cn)
65 | 
66 | ## Acknowledgements
67 | This repo is based on **[Prototypical Networks](https://github.com/yinboc/prototypical-network-pytorch)**, **[RFS](https://github.com/WangYueFt/rfs)**, and **[SKD](https://github.com/brjathu/SKD)**.
68 | 
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
1 | datasets location
--------------------------------------------------------------------------------
/datasets/cifarfs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import torch
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | import numpy as np
8 | 
9 | ROOT_PATH = './data/cifar-fs'
10 | 
11 | def load_data(file):
12 |     try:
13 |         with open(file, 'rb') as fo:
14 |             data = pickle.load(fo)
15 |         return data
16 |     except:  # the pickle may have been written by Python 2; retry with latin1 encoding
17 |         with open(file, 'rb') as f:
18 |             u = pickle._Unpickler(f)
19 |             u.encoding = 'latin1'
20 |             data = u.load()
21 |         return data
22 | 
23 | 
24 | class CIFAR_FS(Dataset):
25 | 
26 |     def __init__(self, phase='train', size=32, transform=None):
27 | 
28 |         filepath = os.path.join(ROOT_PATH, 'CIFAR_FS_' + phase + ".pickle")
29 |         datafile = load_data(filepath)
30 | 
31 |         data = datafile['data']
32 |         label = datafile['labels']
33 | 
34 |         data = [Image.fromarray(x) for x in data]
35 | 
36 |         min_label = min(label)
37 |         label = [x - min_label for x in label]
38 | 
39 |         newlabel = []
40 |         classlabel = 0
41 |         for i in range(len(label)):
42 |             if (i > 0) and (label[i] != label[i-1]):
43 |                 classlabel += 1
44 |             newlabel.append(classlabel)
45 | 
46 |         self.data = data
47 |         self.label = newlabel
48 | 
49 |         if transform is None:
50 |             self.transform = transforms.Compose([
51 |                 transforms.Resize(size),
52 |                 transforms.CenterCrop(size),
53 |                 transforms.ToTensor(),
54 |                 transforms.Normalize(
55 |                     np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
56 |                     np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])
57 |                 )
58 |             ])
59 |         else:
60 |             self.transform = transform
61 | 
62 |     def __len__(self):
63 |         return len(self.data)
64 | 
65 |     def __getitem__(self, i):
66 |         return self.transform(self.data[i]), self.label[i]
67 | 
68 | 
69 | class SSLCifarFS(Dataset):
70 | 
71 |     def __init__(self, phase, args):
72 |         filepath = os.path.join(ROOT_PATH, 'CIFAR_FS_' + phase + ".pickle")
73 |         datafile = load_data(filepath)
74 | 
75 |         data = datafile['data']
76 |         label = datafile['labels']
77 | 
78 |         data = [Image.fromarray(x) for x in data]
79 | 
80 |         min_label = min(label)
81 |         label = [x - min_label for x in label]
82 | 
83 |         newlabel = []
84 |         classlabel = 0
85 |         for i in range(len(label)):
86 |             if (i > 0) and (label[i] != label[i-1]):
87 |                 classlabel += 1
88 |             newlabel.append(classlabel)
89 | 
90 |         self.data = data
91 |         self.label = newlabel
92 |         self.args = args
93 | 
94 |         color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
95 |                                               saturation=0.4, hue=0.1)
96 |         self.augmentation_transform = transforms.Compose([transforms.RandomResizedCrop(size=(args.size, args.size)[-2:],
97 |                                                           scale=(0.5, 1.0)),
98 |                                                           transforms.RandomHorizontalFlip(p=0.5),
99 |                                                           transforms.RandomVerticalFlip(p=0.5),
100 |                                                           transforms.RandomApply([color_jitter], p=0.8),
101 |                                                           transforms.RandomGrayscale(p=0.2),
102 |                                                           transforms.ToTensor(),
103 |                                                           transforms.Normalize(
104 |                                                               np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
105 |                                                               np.array([x / 255.0 for x in 
[63.0, 62.1, 66.7]]) 106 | ) 107 | ]) 108 | # 109 | self.identity_transform = transforms.Compose([ 110 | transforms.Resize(args.size), 111 | transforms.CenterCrop(args.size), 112 | transforms.ToTensor(), 113 | transforms.Normalize( 114 | np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]), 115 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]]) 116 | ) 117 | ]) 118 | 119 | def __len__(self): 120 | return len(self.data) 121 | 122 | def __getitem__(self, i): 123 | img, label = self.data[i], self.label[i] 124 | image = [] 125 | for _ in range(self.args.shot): 126 | image.append(self.identity_transform(img).unsqueeze(0)) 127 | for i in range(self.args.train_query): 128 | image.append(self.augmentation_transform(img).unsqueeze(0)) 129 | return dict(data=torch.cat(image)), label 130 | 131 | -------------------------------------------------------------------------------- /datasets/mini_imagenet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from PIL import Image 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | ROOT_PATH = './data/mini-imagenet' 10 | 11 | 12 | class MiniImageNet(Dataset): 13 | 14 | def __init__(self, setname, size, transform=None): 15 | csv_path = osp.join(ROOT_PATH, setname + '.csv') 16 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 17 | 18 | data = [] 19 | label = [] 20 | lb = -1 21 | 22 | self.wnids = [] 23 | 24 | for l in lines: 25 | name, wnid = l.split(',') 26 | path = osp.join(ROOT_PATH, 'images', name) 27 | if wnid not in self.wnids: 28 | self.wnids.append(wnid) 29 | lb += 1 30 | data.append(path) 31 | label.append(lb) 32 | 33 | self.data = data 34 | self.label = label 35 | 36 | if transform is None: 37 | self.transform = transforms.Compose([ 38 | transforms.Resize(size), 39 | transforms.CenterCrop(size), 40 | transforms.ToTensor(), 41 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 42 | std=[0.229, 0.224, 0.225]) 43 | ]) 44 | else: 45 | self.transform = transform 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, i): 51 | path, label = self.data[i], self.label[i] 52 | image = self.transform(Image.open(path).convert('RGB')) 53 | return image, label 54 | 55 | 56 | class SSLMiniImageNet(Dataset): 57 | 58 | def __init__(self, setname, args): 59 | csv_path = osp.join(ROOT_PATH, setname + '.csv') 60 | lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:] 61 | 62 | data = [] 63 | label = [] 64 | lb = -1 65 | 66 | self.wnids = [] 67 | self.args = args 68 | 69 | for l in lines: 70 | name, wnid = l.split(',') 71 | path = osp.join(ROOT_PATH, 'images', name) 72 | if wnid not in self.wnids: 73 | self.wnids.append(wnid) 74 | lb += 1 75 | data.append(path) 76 | label.append(lb) 77 | 78 | self.data = data 79 | self.label = label 80 | 81 | color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, 82 | saturation=0.4, hue=0.1) 83 | self.augmentation_transform = transforms.Compose([transforms.RandomResizedCrop(size=(args.size, args.size)[-2:], 84 | scale=(0.5, 1.0)), 85 | transforms.RandomHorizontalFlip(p=0.5), 86 | transforms.RandomVerticalFlip(p=0.5), 87 | transforms.RandomApply([color_jitter], p=0.8), 88 | transforms.RandomGrayscale(p=0.2), 89 | transforms.ToTensor(), 90 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 91 | std=[0.229, 0.224, 0.225]), 92 | ]) 93 | # 94 | self.identity_transform = transforms.Compose([ 95 | transforms.Resize(args.size), 96 | 
transforms.CenterCrop(args.size), 97 | transforms.ToTensor(), 98 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 99 | std=[0.229, 0.224, 0.225]) 100 | ]) 101 | 102 | def __len__(self): 103 | return len(self.data) 104 | 105 | def __getitem__(self, i): 106 | path, label = self.data[i], self.label[i] 107 | img = Image.open(path).convert('RGB') 108 | image = [] 109 | for _ in range(self.args.shot): 110 | image.append(self.identity_transform(img).unsqueeze(0)) 111 | for i in range(self.args.train_query): 112 | image.append(self.augmentation_transform(img).unsqueeze(0)) 113 | return dict(data=torch.cat(image)), label 114 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class CategoriesSampler(): 6 | 7 | def __init__(self, label, n_batch, n_cls, n_per): 8 | self.n_batch = n_batch 9 | self.n_cls = n_cls 10 | self.n_per = n_per 11 | 12 | label = np.array(label) 13 | self.m_ind = [] 14 | for i in range(max(label) + 1): 15 | ind = np.argwhere(label == i).reshape(-1) 16 | ind = torch.from_numpy(ind) 17 | self.m_ind.append(ind) 18 | 19 | def __len__(self): 20 | return self.n_batch 21 | 22 | def __iter__(self): 23 | for i_batch in range(self.n_batch): 24 | batch = [] 25 | classes = torch.randperm(len(self.m_ind))[:self.n_cls] 26 | for c in classes: 27 | l = self.m_ind[c] 28 | pos = torch.randperm(len(l))[:self.n_per] 29 | batch.append(l[pos]) 30 | batch = torch.stack(batch).t().reshape(-1) 31 | yield batch 32 | 33 | -------------------------------------------------------------------------------- /datasets/tiered_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from PIL import Image 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset 8 | from torchvision import transforms 9 | 10 | 11 | ROOT_PATH = './data/tiered-imagenet-kwon' 12 | 13 | class TieredImageNet(Dataset): 14 | 15 | def __init__(self, split='train', size=84, transform=None): 16 | split_tag = split 17 | data = np.load(os.path.join( 18 | ROOT_PATH, '{}_images.npz'.format(split_tag)), 19 | allow_pickle=True)['images'] 20 | data = data[:, :, :, ::-1] 21 | 22 | with open(os.path.join( 23 | ROOT_PATH, '{}_labels.pkl'.format(split_tag)), 'rb') as f: 24 | label = pickle.load(f)['labels'] 25 | 26 | data = [Image.fromarray(x) for x in data] 27 | 28 | min_label = min(label) 29 | label = [x - min_label for x in label] 30 | 31 | self.data = data 32 | self.label = label 33 | self.n_classes = max(self.label) + 1 34 | 35 | if transform is None: 36 | if split in ['train', 'trainval']: 37 | self.transform = transforms.Compose([ 38 | transforms.Resize(size+12), 39 | transforms.RandomCrop(size, padding=8), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 43 | std=[0.229, 0.224, 0.225]), 44 | ]) 45 | else: 46 | self.transform = transforms.Compose([ 47 | transforms.Resize(size), 48 | transforms.CenterCrop(size), 49 | transforms.ToTensor(), 50 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 51 | std=[0.229, 0.224, 0.225]) 52 | ]) 53 | else: 54 | self.transform = transform 55 | 56 | def __len__(self): 57 | return len(self.data) 58 | 59 | def __getitem__(self, i): 60 | return self.transform(self.data[i]), self.label[i] 61 | 62 | 63 | class SSLTieredImageNet(Dataset): 64 | 65 | def 
__init__(self, split='train', args=None): 66 | split_tag = split 67 | data = np.load(os.path.join( 68 | ROOT_PATH, '{}_images.npz'.format(split_tag)), 69 | allow_pickle=True)['images'] 70 | data = data[:, :, :, ::-1] 71 | 72 | with open(os.path.join( 73 | ROOT_PATH, '{}_labels.pkl'.format(split_tag)), 'rb') as f: 74 | label = pickle.load(f)['labels'] 75 | 76 | data = [Image.fromarray(x) for x in data] 77 | 78 | min_label = min(label) 79 | label = [x - min_label for x in label] 80 | 81 | self.data = data 82 | self.label = label 83 | self.n_classes = max(self.label) + 1 84 | 85 | color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, 86 | saturation=0.4, hue=0.1) 87 | self.augmentation_transform = transforms.Compose([transforms.RandomResizedCrop(size=(args.size, args.size)[-2:], 88 | scale=(0.5, 1.0)), 89 | transforms.RandomHorizontalFlip(p=0.5), 90 | transforms.RandomVerticalFlip(p=0.5), 91 | transforms.RandomApply([color_jitter], p=0.8), 92 | transforms.RandomGrayscale(p=0.2), 93 | transforms.ToTensor(), 94 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 95 | std=[0.229, 0.224, 0.225]), 96 | ]) 97 | # 98 | self.identity_transform = transforms.Compose([ 99 | transforms.Resize(args.size), 100 | transforms.CenterCrop(args.size), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 103 | std=[0.229, 0.224, 0.225]) 104 | ]) 105 | 106 | def __len__(self): 107 | return len(self.data) 108 | 109 | def __getitem__(self, i): 110 | img, label = self.data[i], self.label[i] 111 | image = [] 112 | for _ in range(1): 113 | image.append(self.identity_transform(img).unsqueeze(0)) 114 | for i in range(3): 115 | image.append(self.augmentation_transform(img).unsqueeze(0)) 116 | return dict(data=torch.cat(image)), label 117 | -------------------------------------------------------------------------------- /models/convnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | def conv_block(in_channels, out_channels): 4 | bn = nn.BatchNorm2d(out_channels) 5 | nn.init.uniform_(bn.weight) 6 | return nn.Sequential( 7 | nn.Conv2d(in_channels, out_channels, 3, padding=1), 8 | bn, 9 | #nn.BatchNorm2d(out_channels), 10 | nn.ReLU(), 11 | nn.MaxPool2d(2) 12 | ) 13 | 14 | 15 | class Convnet(nn.Module): 16 | 17 | def __init__(self, x_dim=3, hid_dim=64, z_dim=64): 18 | super().__init__() 19 | self.encoder = nn.Sequential( 20 | conv_block(x_dim, hid_dim), 21 | conv_block(hid_dim, hid_dim), 22 | conv_block(hid_dim, hid_dim), 23 | conv_block(hid_dim, z_dim), 24 | ) 25 | 26 | def forward(self, x): 27 | x = self.encoder(x) 28 | return x.view(x.size(0), -1) 29 | -------------------------------------------------------------------------------- /models/distill.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class DistillKL(nn.Module): 5 | 6 | def __init__(self, T): 7 | super(DistillKL, self).__init__() 8 | self.T = T 9 | 10 | def forward(self, y_s, y_t): 11 | p_s = F.log_softmax(y_s/self.T, dim=1) 12 | p_t = F.softmax(y_t/self.T, dim=1) 13 | loss = F.kl_div(p_s, p_t, reduction='sum')*(self.T**2)/y_s.shape[0] 14 | return loss 15 | 16 | class HintLoss(nn.Module): 17 | def __init__(self): 18 | super(HintLoss, self).__init__() 19 | self.crit = nn.MSELoss() 20 | def forward(self, fs, ft): 21 | return self.crit(fs, ft) -------------------------------------------------------------------------------- /models/resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.distributions import Bernoulli 5 | 6 | # ======== 2D RESNET =========== 7 | 8 | def conv3x3(in_planes, out_planes, stride=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 11 | padding=1, bias=False) 12 | 13 | 14 | class SELayer(nn.Module): 15 | def __init__(self, channel, reduction=16): 16 | super(SELayer, self).__init__() 17 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 18 | self.fc = nn.Sequential( 19 | nn.Linear(channel, channel // reduction), 20 | nn.ReLU(inplace=True), 21 | nn.Linear(channel // reduction, channel), 22 | nn.Sigmoid() 23 | ) 24 | 25 | def forward(self, x): 26 | b, c, _, _ = x.size() 27 | y = self.avg_pool(x).view(b, c) 28 | y = self.fc(y).view(b, c, 1, 1) 29 | return x * y 30 | 31 | 32 | class DropBlock(nn.Module): 33 | def __init__(self, block_size): 34 | super(DropBlock, self).__init__() 35 | 36 | self.block_size = block_size 37 | #self.gamma = gamma 38 | #self.bernouli = Bernoulli(gamma) 39 | 40 | def forward(self, x, gamma): 41 | # shape: (bsize, channels, height, width) 42 | 43 | if self.training: 44 | batch_size, channels, height, width = x.shape 45 | 46 | bernoulli = Bernoulli(gamma) 47 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda() 48 | block_mask = self._compute_block_mask(mask) 49 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3] 50 | count_ones = block_mask.sum() 51 | 52 | return block_mask * x * (countM / count_ones) 53 | else: 54 | return x 55 | 56 | def _compute_block_mask(self, mask): 57 | left_padding = int((self.block_size-1) / 2) 58 | right_padding = int(self.block_size / 2) 59 | 60 | batch_size, channels, height, width = mask.shape 61 | #print ("mask", mask[0][0]) 62 | non_zero_idxs = mask.nonzero() 63 | nr_blocks = non_zero_idxs.shape[0] 64 | 65 | offsets = torch.stack( 66 | [ 67 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding, 68 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding 69 | ] 70 | ).t().cuda() 71 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1) 72 | 73 | if nr_blocks > 0: 74 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1) 75 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 76 | offsets = offsets.long() 77 | 78 | block_idxs = non_zero_idxs + offsets 79 | #block_idxs += left_padding 80 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 81 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1. 
82 | else: 83 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding)) 84 | 85 | block_mask = 1 - padded_mask#[:height, :width] 86 | return block_mask 87 | 88 | 89 | class BasicBlock(nn.Module): 90 | expansion = 1 91 | 92 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False, 93 | block_size=1, use_se=False): 94 | super(BasicBlock, self).__init__() 95 | self.conv1 = conv3x3(inplanes, planes) 96 | self.bn1 = nn.BatchNorm2d(planes) 97 | self.relu = nn.LeakyReLU(0.1) 98 | self.conv2 = conv3x3(planes, planes) 99 | self.bn2 = nn.BatchNorm2d(planes) 100 | self.conv3 = conv3x3(planes, planes) 101 | self.bn3 = nn.BatchNorm2d(planes) 102 | self.maxpool = nn.MaxPool2d(stride) 103 | self.downsample = downsample 104 | self.stride = stride 105 | self.drop_rate = drop_rate 106 | self.num_batches_tracked = 0 107 | self.drop_block = drop_block 108 | self.block_size = block_size 109 | self.DropBlock = DropBlock(block_size=self.block_size) 110 | self.use_se = use_se 111 | if self.use_se: 112 | self.se = SELayer(planes, 4) 113 | 114 | def forward(self, x): 115 | self.num_batches_tracked += 1 116 | 117 | residual = x 118 | 119 | out = self.conv1(x) 120 | out = self.bn1(out) 121 | out = self.relu(out) 122 | 123 | out = self.conv2(out) 124 | out = self.bn2(out) 125 | out = self.relu(out) 126 | 127 | out = self.conv3(out) 128 | out = self.bn3(out) 129 | if self.use_se: 130 | out = self.se(out) 131 | 132 | if self.downsample is not None: 133 | residual = self.downsample(x) 134 | out += residual 135 | out = self.relu(out) 136 | out = self.maxpool(out) 137 | 138 | if self.drop_rate > 0: 139 | if self.drop_block == True: 140 | feat_size = out.size()[2] 141 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate) 142 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2 143 | out = self.DropBlock(out, gamma=gamma) 144 | else: 145 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True) 146 | 147 | return out 148 | 149 | 150 | class ResNet2d(nn.Module): 151 | 152 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0, 153 | dropblock_size=5, num_classes=-1, use_se=False): 154 | super(ResNet2d, self).__init__() 155 | 156 | self.inplanes = 3 157 | self.use_se = use_se 158 | self.layer1 = self._make_layer(block, n_blocks[0], 64, 159 | stride=2, drop_rate=drop_rate) 160 | self.layer2 = self._make_layer(block, n_blocks[1], 160, 161 | stride=2, drop_rate=drop_rate) 162 | self.layer3 = self._make_layer(block, n_blocks[2], 320, 163 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 164 | self.layer4 = self._make_layer(block, n_blocks[3], 640, 165 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size) 166 | if avg_pool: 167 | # self.avgpool = nn.AvgPool2d(5, stride=1) 168 | self.avgpool = nn.AdaptiveAvgPool2d(1) 169 | self.keep_prob = keep_prob 170 | self.keep_avg_pool = avg_pool 171 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False) 172 | self.drop_rate = drop_rate 173 | 174 | for m in self.modules(): 175 | if isinstance(m, nn.Conv2d): 176 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') 177 | elif isinstance(m, nn.BatchNorm2d): 178 | nn.init.constant_(m.weight, 1) 179 | nn.init.constant_(m.bias, 0) 180 | 181 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1): 182 | 
downsample = None
183 |         if stride != 1 or self.inplanes != planes * block.expansion:
184 |             downsample = nn.Sequential(
185 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
186 |                           kernel_size=1, stride=1, bias=False),
187 |                 nn.BatchNorm2d(planes * block.expansion),
188 |             )
189 | 
190 |         layers = []
191 |         if n_block == 1:
192 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
193 |         else:
194 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)
195 |         layers.append(layer)
196 |         self.inplanes = planes * block.expansion
197 | 
198 |         for i in range(1, n_block):
199 |             if i == n_block - 1:
200 |                 layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
201 |                               block_size=block_size, use_se=self.use_se)
202 |             else:
203 |                 layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
204 |             layers.append(layer)
205 | 
206 |         return nn.Sequential(*layers)
207 | 
208 |     def forward(self, x):
209 |         x = self.layer1(x)
210 |         x = self.layer2(x)
211 |         x = self.layer3(x)
212 |         x = self.layer4(x)
213 |         if self.keep_avg_pool:
214 |             x = self.avgpool(x)
215 |         x = x.view(x.size(0), -1)
216 |         return x
217 | 
218 | def resnet12(keep_prob=1.0, avg_pool=False, **kwargs):
219 |     model = ResNet2d(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
220 |     return model
--------------------------------------------------------------------------------
/save/readme.txt:
--------------------------------------------------------------------------------
1 | Saved model/checkpoint location
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os.path as osp
4 | import torch
5 | from torch.utils.data import DataLoader
6 | 
7 | from datasets.mini_imagenet import MiniImageNet
8 | from datasets.tiered_imagenet import TieredImageNet
9 | from datasets.cifarfs import CIFAR_FS
10 | from datasets.samplers import CategoriesSampler
11 | from models.convnet import Convnet
12 | from models.resnet import resnet12
13 | from utils import set_gpu, Averager, count_acc, euclidean_metric, seed_torch, compute_confidence_interval
14 | 
15 | 
16 | def final_evaluate(args):
17 |     if args.dataset == 'mini':
18 |         valset = MiniImageNet('test', args.size)
19 |     elif args.dataset == 'tiered':
20 |         valset = TieredImageNet('test', args.size)
21 |     elif args.dataset == "cifarfs":
22 |         valset = CIFAR_FS('test', args.size)
23 |     else:
24 |         print("Invalid dataset...")
25 |         exit()
26 |     val_sampler = CategoriesSampler(valset.label, args.test_batch,
27 |                                     args.test_way, args.shot + args.test_query)
28 |     loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
29 |                         num_workers=args.worker, pin_memory=True)
30 | 
31 |     if args.model == 'convnet':
32 |         model = Convnet().cuda()
33 |         print("=> Convnet architecture...")
34 |     else:
35 |         if args.dataset in ['mini', 'tiered']:
36 |             model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda()
37 |         else:
38 |             model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda()
39 |         print("=> Resnet architecture...")
40 | 
41 |     model.load_state_dict(torch.load(osp.join(args.save_path, 'max-acc.pth')))
42 |     print("=> Model loaded...")
43 |     model.eval()
44 | 
45 |     ave_acc = Averager()
46 |     acc_list = []
47 | 
48 |     for i, batch in enumerate(loader, 1):
49 |         data, _ = [_.cuda() for _ in batch]
50 |         k = args.test_way * 
args.shot 51 | data_shot, data_query = data[:k], data[k:] 52 | 53 | x = model(data_shot) 54 | x = x.reshape(args.shot, args.test_way, -1).mean(dim=0) 55 | p = x 56 | 57 | logits = euclidean_metric(model(data_query), p) 58 | 59 | label = torch.arange(args.test_way).repeat(args.test_query) 60 | label = label.type(torch.cuda.LongTensor) 61 | 62 | acc = count_acc(logits, label) 63 | ave_acc.add(acc) 64 | acc_list.append(acc*100) 65 | 66 | x = None; p = None; logits = None 67 | 68 | a, b = compute_confidence_interval(acc_list) 69 | print("Final accuracy with 95% interval : {:.2f}±{:.2f}".format(a, b)) 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--shot', type=int, default=1) 75 | parser.add_argument('--test-query', type=int, default=15) 76 | parser.add_argument('--test-way', type=int, default=5) 77 | parser.add_argument('--save-path', default='') 78 | parser.add_argument('--gpu', default='0') 79 | parser.add_argument('--size', type=int, default=84) 80 | parser.add_argument('--test-batch', type=int, default=2000) 81 | parser.add_argument('--worker', type=int, default=8) 82 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet']) 83 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs']) 84 | args = parser.parse_args() 85 | 86 | start_time = datetime.datetime.now() 87 | 88 | # fix seed 89 | seed_torch(1) 90 | set_gpu(args.gpu) 91 | 92 | if args.dataset in ['mini', 'tiered']: 93 | args.size = 84 94 | elif args.dataset in ['cifarfs']: 95 | args.size = 32 96 | args.worker = 0 97 | else: 98 | args.size = 28 99 | 100 | final_evaluate(args) 101 | 102 | end_time = datetime.datetime.now() 103 | print("Total executed time :", end_time - start_time) -------------------------------------------------------------------------------- /train_stage1.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os.path as osp 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | 8 | from datasets.mini_imagenet import MiniImageNet, SSLMiniImageNet 9 | from datasets.tiered_imagenet import TieredImageNet, SSLTieredImageNet 10 | from datasets.cifarfs import CIFAR_FS, SSLCifarFS 11 | from datasets.samplers import CategoriesSampler 12 | from models.convnet import Convnet 13 | from models.resnet import resnet12 14 | from utils import set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, seed_torch, compute_confidence_interval 15 | 16 | 17 | def get_dataset(args): 18 | if args.dataset == 'mini': 19 | trainset = SSLMiniImageNet('train', args) 20 | valset = MiniImageNet('test', args.size) 21 | print("=> MiniImageNet...") 22 | elif args.dataset == 'tiered': 23 | trainset = SSLTieredImageNet('train', args) 24 | valset = TieredImageNet('test', args.size) 25 | print("=> TieredImageNet...") 26 | elif args.dataset == 'cifarfs': 27 | trainset = SSLCifarFS('train', args) 28 | valset = CIFAR_FS('test', args.size) 29 | print("=> CIFAR FS...") 30 | else: 31 | print("Invalid dataset...") 32 | exit() 33 | 34 | train_loader = DataLoader(dataset=trainset, batch_size=args.train_way * args.shot, 35 | shuffle=True, drop_last=True, 36 | num_workers=args.worker, pin_memory=True) 37 | 38 | val_sampler = CategoriesSampler(valset.label, args.test_batch, 39 | args.test_way, args.shot + args.test_query) 40 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, 41 
| num_workers=args.worker, pin_memory=True)
42 |     return train_loader, val_loader
43 | 
44 | def training(args):
45 |     ensure_path(args.save_path)
46 | 
47 |     train_loader, val_loader = get_dataset(args)
48 | 
49 |     if args.model == 'convnet':
50 |         model = Convnet().cuda()
51 |         print("=> Convnet architecture...")
52 |     else:
53 |         if args.dataset in ['mini', 'tiered']:
54 |             model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda()
55 |         else:
56 |             model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda()
57 |         print("=> Resnet architecture...")
58 | 
59 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
60 |     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
61 | 
62 |     def save_model(name):
63 |         torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
64 | 
65 |     trlog = {}
66 |     trlog['args'] = vars(args)
67 |     trlog['train_loss'] = []
68 |     trlog['val_loss'] = []
69 |     trlog['train_acc'] = []
70 |     trlog['val_acc'] = []
71 |     trlog['max_acc'] = 0.0
72 | 
73 |     timer = Timer()
74 |     best_epoch = 0
75 |     cmi = [0.0, 0.0]
76 | 
77 |     for epoch in range(1, args.max_epoch + 1):
78 | 
79 |         tl, ta = train(args, model, train_loader, optimizer)
80 |         lr_scheduler.step()
81 |         vl, va, aa, bb = validate(args, model, val_loader)
82 | 
83 |         if va > trlog['max_acc']:
84 |             trlog['max_acc'] = va
85 |             save_model('max-acc')
86 |             best_epoch = epoch
87 |             cmi[0] = aa
88 |             cmi[1] = bb
89 | 
90 |         trlog['train_loss'].append(tl)
91 |         trlog['train_acc'].append(ta)
92 |         trlog['val_loss'].append(vl)
93 |         trlog['val_acc'].append(va)
94 | 
95 |         torch.save(trlog, osp.join(args.save_path, 'trlog'))
96 | 
97 |         save_model('epoch-last')
98 |         ot, ots = timer.measure()
99 |         tt, _ = timer.measure(epoch / args.max_epoch)
100 | 
101 |         print('Epoch {}/{}, train loss={:.4f} - acc={:.4f} - val loss={:.4f} - acc={:.4f} - max acc={:.4f} - ETA:{}/{}'.format(
102 |             epoch, args.max_epoch, tl, ta, vl, va, trlog['max_acc'], ots, timer.tts(tt-ot)))
103 | 
104 |         if epoch == args.max_epoch:
105 |             print("Best Epoch is {} with acc={:.2f}±{:.2f}%...".format(best_epoch, cmi[0], cmi[1]))
106 |             print("---------------------------------------------------")
107 | 
108 | def preprocess_data(data):
109 |     for idxx, img in enumerate(data):
110 |         # img: (shot + train_query, 3, size, size) = (4, 3, 84, 84); slot 0 is the identity view, slots 1-3 are augmented views
111 |         supportimg = img.data[0].unsqueeze(0)
112 |         x90 = img.data[1].unsqueeze(0).transpose(2,3).flip(2)    # rotate 90 degrees
113 |         x180 = img.data[2].unsqueeze(0).flip(2).flip(3)          # rotate 180 degrees
114 |         x270 = img.data[3].unsqueeze(0).flip(2).transpose(2,3)   # rotate 270 degrees
115 |         queryimg = torch.cat((x90, x180, x270), 0)
116 |         queryimg = queryimg.unsqueeze(0)
117 |         if idxx <= 0:
118 |             # support
119 |             dshot = supportimg
120 |             # query
121 |             dquery = queryimg
122 |         else:
123 |             dshot = torch.cat((dshot, supportimg), 0)
124 |             dquery = torch.cat((dquery, queryimg), 0)
125 |     dquery = torch.transpose(dquery, 0, 1)
126 |     dquery = dquery.reshape(args.train_way*args.train_query, 3, args.size, args.size)
127 |     return dshot.cuda(), dquery.cuda()
128 | 
129 | def train(args, model, train_loader, optimizer):
130 |     model.train()
131 | 
132 |     tl = Averager()
133 |     ta = Averager()
134 | 
135 |     for i, batch in enumerate(train_loader, 1):
136 |         data, _ = batch
137 |         data_shot, data_query = preprocess_data(data['data'])
138 | 
139 |         proto = model(data_shot)
140 |         proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
141 | 
142 |         label = torch.arange(args.train_way).repeat(args.train_query)
143 |         label = label.type(torch.cuda.LongTensor)
144 | 
145 |         logits = 
euclidean_metric(model(data_query), proto) 146 | loss = F.cross_entropy(logits, label) 147 | acc = count_acc(logits, label) 148 | 149 | tl.add(loss.item()) 150 | ta.add(acc) 151 | 152 | optimizer.zero_grad() 153 | loss.backward() 154 | optimizer.step() 155 | 156 | proto = None; logits = None; loss = None 157 | 158 | if (args.train_batch > 0) and (i >= args.train_batch): 159 | break 160 | 161 | return tl.item(), ta.item() 162 | 163 | 164 | def validate(args, model, val_loader): 165 | model.eval() 166 | 167 | vl = Averager() 168 | va = Averager() 169 | acc_list = [] 170 | 171 | for i, batch in enumerate(val_loader, 1): 172 | data, _ = [_.cuda() for _ in batch] 173 | p = args.shot * args.test_way 174 | data_shot, data_query = data[:p], data[p:] 175 | 176 | proto = model(data_shot) 177 | proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0) 178 | 179 | label = torch.arange(args.test_way).repeat(args.test_query) 180 | label = label.type(torch.cuda.LongTensor) 181 | 182 | logits = euclidean_metric(model(data_query), proto) 183 | loss = F.cross_entropy(logits, label) 184 | acc = count_acc(logits, label) 185 | 186 | vl.add(loss.item()) 187 | va.add(acc) 188 | acc_list.append(acc*100) 189 | 190 | proto = None; logits = None; loss = None 191 | 192 | a, b = compute_confidence_interval(acc_list) 193 | return vl.item(), va.item(), a, b 194 | 195 | 196 | if __name__ == '__main__': 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument('--max-epoch', type=int, default=200) 199 | parser.add_argument('--shot', type=int, default=1) 200 | parser.add_argument('--train-query', type=int, default=3) 201 | parser.add_argument('--test-query', type=int, default=15) 202 | parser.add_argument('--train-way', type=int, default=50) 203 | parser.add_argument('--test-way', type=int, default=5) 204 | parser.add_argument('--save-path', default='') 205 | parser.add_argument('--gpu', default='0') 206 | parser.add_argument('--size', type=int, default=84) 207 | parser.add_argument('--lr', type=float, default=0.001) 208 | parser.add_argument('--wd', type=float, default=0.001) 209 | parser.add_argument('--step-size', type=int, default=20) 210 | parser.add_argument('--train-batch', type=int, default=-1) 211 | parser.add_argument('--test-batch', type=int, default=2000) 212 | parser.add_argument('--worker', type=int, default=8) 213 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet']) 214 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs']) 215 | args = parser.parse_args() 216 | 217 | start_time = datetime.datetime.now() 218 | 219 | # fix seed 220 | seed_torch(1) 221 | set_gpu(args.gpu) 222 | 223 | if args.dataset in ['mini', 'tiered']: 224 | args.size = 84 225 | elif args.dataset in ['cifarfs']: 226 | args.size = 32 227 | args.worker = 0 228 | else: 229 | args.size = 28 230 | 231 | training(args) 232 | 233 | end_time = datetime.datetime.now() 234 | print("Total executed time :", end_time - start_time) 235 | 236 | -------------------------------------------------------------------------------- /train_stage2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os.path as osp 4 | import torch 5 | import torch.nn.functional as F 6 | from torch.utils.data import DataLoader 7 | 8 | from datasets.mini_imagenet import MiniImageNet 9 | from datasets.tiered_imagenet import TieredImageNet 10 | from datasets.cifarfs import CIFAR_FS 11 | from 
datasets.samplers import CategoriesSampler 12 | from models.convnet import Convnet 13 | from models.resnet import resnet12 14 | from utils import set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, seed_torch, compute_confidence_interval 15 | 16 | def get_dataset(args): 17 | if args.dataset == 'mini': 18 | trainset = MiniImageNet('train', args.size) 19 | valset = MiniImageNet('test', args.size) 20 | print("=> MiniImageNet...") 21 | elif args.dataset == 'tiered': 22 | trainset = TieredImageNet('train', args.size) 23 | valset = TieredImageNet('test', args.size) 24 | print("=> TieredImageNet...") 25 | elif args.dataset == 'cifarfs': 26 | trainset = CIFAR_FS('train', args.size) 27 | valset = CIFAR_FS('test', args.size) 28 | print("=> CIFAR FS...") 29 | else: 30 | print("Invalid dataset...") 31 | exit() 32 | train_sampler = CategoriesSampler(trainset.label, args.train_batch, 33 | args.train_way, args.shot + args.train_query) 34 | train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, 35 | num_workers=args.worker, pin_memory=True) 36 | 37 | val_sampler = CategoriesSampler(valset.label, args.test_batch, 38 | args.test_way, args.shot + args.test_query) 39 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, 40 | num_workers=args.worker, pin_memory=True) 41 | return train_loader, val_loader 42 | 43 | def training(args): 44 | ensure_path(args.save_path) 45 | 46 | train_loader, val_loader = get_dataset(args) 47 | 48 | if args.model == 'convnet': 49 | model = Convnet().cuda() 50 | print("=> Convnet architecture...") 51 | else: 52 | if args.dataset in ['mini', 'tiered','cub']: 53 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda() 54 | print("=> Large block resnet architecture...") 55 | else: 56 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda() 57 | print("=> Small block resnet architecture...") 58 | 59 | if args.stage1_path: 60 | model.load_state_dict(torch.load(osp.join(args.stage1_path, 'max-acc.pth'))) 61 | print("=> Pretrain model loaded...") 62 | 63 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd) 64 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5) 65 | 66 | def save_model(name): 67 | torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth')) 68 | 69 | trlog = {} 70 | trlog['args'] = vars(args) 71 | trlog['train_loss'] = [] 72 | trlog['val_loss'] = [] 73 | trlog['train_acc'] = [] 74 | trlog['val_acc'] = [] 75 | trlog['max_acc'] = 0.0 76 | 77 | timer = Timer() 78 | best_epoch = 0 79 | cmi = [0.0, 0.0] 80 | 81 | for epoch in range(1, args.max_epoch + 1): 82 | 83 | tl, ta = train(args, model, train_loader, optimizer) 84 | # 85 | lr_scheduler.step() 86 | 87 | vl, va, aa, bb = validate(args, model, val_loader) 88 | 89 | if va > trlog['max_acc']: 90 | trlog['max_acc'] = va 91 | save_model('max-acc') 92 | best_epoch = epoch 93 | cmi[0] = aa 94 | cmi[1] = bb 95 | 96 | trlog['train_loss'].append(tl) 97 | trlog['train_acc'].append(ta) 98 | trlog['val_loss'].append(vl) 99 | trlog['val_acc'].append(va) 100 | 101 | torch.save(trlog, osp.join(args.save_path, 'trlog')) 102 | 103 | save_model('epoch-last') 104 | ot, ots = timer.measure() 105 | tt, _ = timer.measure(epoch / args.max_epoch) 106 | 107 | print('Epoch {}/{}, train loss={:.4f} - acc={:.4f} - val loss={:.4f} - acc={:.4f} - max acc={:.4f} - ETA:{}/{}'.format( 108 | epoch, args.max_epoch, tl, ta, vl, va, trlog['max_acc'], ots, timer.tts(tt-ot))) 109 | 110 | if 
epoch == args.max_epoch:
111 |             print("Best Epoch is {} with acc={:.2f}±{:.2f}%...".format(best_epoch, cmi[0], cmi[1]))
112 |             print("---------------------------------------------------")
113 | 
114 | def ssl_loss(args, model, data_shot):
115 |     # supports are the original support images; queries are their 90/180/270-degree rotations (assumes pre_query == 3)
116 |     x_90 = data_shot.transpose(2,3).flip(2)
117 |     x_180 = data_shot.flip(2).flip(3)
118 |     x_270 = data_shot.flip(2).transpose(2,3)
119 |     data_query = torch.cat((x_90, x_180, x_270),0)
120 | 
121 |     proto = model(data_shot)
122 |     proto = proto.reshape(1, args.train_way*args.shot, -1).mean(dim=0)
123 |     query = model(data_query)
124 | 
125 |     label = torch.arange(args.train_way*args.shot).repeat(args.pre_query)
126 |     label = label.type(torch.cuda.LongTensor)
127 | 
128 |     logits = euclidean_metric(query, proto)
129 |     loss = F.cross_entropy(logits, label)
130 | 
131 |     return loss
132 | 
133 | def train(args, model, train_loader, optimizer):
134 |     model.train()
135 | 
136 |     tl = Averager()
137 |     ta = Averager()
138 | 
139 |     for i, batch in enumerate(train_loader, 1):
140 |         data, _ = [_.cuda() for _ in batch]
141 |         p = args.shot * args.train_way
142 |         data_shot, data_query = data[:p], data[p:]
143 | 
144 |         proto = model(data_shot) # (30, 1600)
145 |         proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
146 |         query = model(data_query)
147 | 
148 |         label = torch.arange(args.train_way).repeat(args.train_query)
149 |         label = label.type(torch.cuda.LongTensor)
150 | 
151 |         logits = euclidean_metric(query, proto)
152 |         loss_ss = ssl_loss(args, model, data_shot)
153 |         loss = F.cross_entropy(logits, label) + args.beta * loss_ss
154 |         acc = count_acc(logits, label)
155 | 
156 |         tl.add(loss.item())
157 |         ta.add(acc)
158 | 
159 |         optimizer.zero_grad()
160 |         loss.backward()
161 |         optimizer.step()
162 | 
163 |         proto = None; query = None; logits = None; loss = None
164 | 
165 |     return tl.item(), ta.item()
166 | 
167 | def validate(args, model, val_loader):
168 |     model.eval()
169 | 
170 |     vl = Averager()
171 |     va = Averager()
172 |     acc_list = []
173 | 
174 |     for i, batch in enumerate(val_loader, 1):
175 |         data, _ = [_.cuda() for _ in batch]
176 |         p = args.shot * args.test_way
177 |         data_shot, data_query = data[:p], data[p:]
178 | 
179 |         proto = model(data_shot)
180 |         proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)
181 |         query = model(data_query)
182 | 
183 |         label = torch.arange(args.test_way).repeat(args.test_query)
184 |         label = label.type(torch.cuda.LongTensor)
185 | 
186 |         logits = euclidean_metric(query, proto)
187 |         loss = F.cross_entropy(logits, label)
188 |         acc = count_acc(logits, label)
189 | 
190 |         vl.add(loss.item())
191 |         va.add(acc)
192 |         acc_list.append(acc*100)
193 | 
194 |         proto = None; query = None; logits = None; loss = None
195 |     a, b = compute_confidence_interval(acc_list)
196 |     return vl.item(), va.item(), a, b
197 | 
198 | 
199 | if __name__ == '__main__':
200 |     parser = argparse.ArgumentParser()
201 |     parser.add_argument('--max-epoch', type=int, default=200)
202 |     parser.add_argument('--shot', type=int, default=1)
203 |     parser.add_argument('--pre-query', type=int, default=3)
204 |     parser.add_argument('--train-query', type=int, default=15)
205 |     parser.add_argument('--test-query', type=int, default=15)
206 |     parser.add_argument('--train-way', type=int, default=5)
207 |     parser.add_argument('--test-way', type=int, default=5)
208 |     parser.add_argument('--save-path', default='')
209 |     parser.add_argument('--gpu', default='0')
210 |     parser.add_argument('--size', type=int, default=84)
211 |     parser.add_argument('--lr', type=float, 
default=0.001) 212 | parser.add_argument('--wd', type=float, default=0.001) 213 | parser.add_argument('--step-size', type=int, default=20) 214 | parser.add_argument('--train-batch', type=int, default=100) 215 | parser.add_argument('--test-batch', type=int, default=2000) 216 | parser.add_argument('--worker', type=int, default=8) 217 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet']) 218 | parser.add_argument('--mode', type=int, default=0, choices=[0,1]) 219 | parser.add_argument('--stage1-path', default='') 220 | parser.add_argument('--beta', type=float, default=0.1) 221 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs']) 222 | args = parser.parse_args() 223 | 224 | start_time = datetime.datetime.now() 225 | 226 | # fix seed 227 | seed_torch(1) 228 | set_gpu(args.gpu) 229 | 230 | if args.dataset in ['mini', 'tiered']: 231 | args.size = 84 232 | elif args.dataset in ['cifarfs']: 233 | args.size = 32 234 | args.worker = 0 235 | else: 236 | args.size = 28 237 | 238 | training(args) 239 | 240 | end_time = datetime.datetime.now() 241 | print("Total executed time :", end_time - start_time) 242 | 243 | -------------------------------------------------------------------------------- /train_stage3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os.path as osp 4 | import copy 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.utils.data import DataLoader 8 | 9 | from datasets.mini_imagenet import MiniImageNet 10 | from datasets.tiered_imagenet import TieredImageNet 11 | from datasets.cifarfs import CIFAR_FS 12 | from datasets.samplers import CategoriesSampler 13 | from models.convnet import Convnet 14 | from models.distill import DistillKL, HintLoss 15 | from models.resnet import resnet12 16 | from utils import set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, seed_torch, compute_confidence_interval 17 | 18 | 19 | def get_dataset(args): 20 | if args.dataset == 'mini': 21 | trainset = MiniImageNet('train', args.size) 22 | valset = MiniImageNet('test', args.size) 23 | print("=> MiniImageNet...") 24 | elif args.dataset == 'tiered': 25 | trainset = TieredImageNet('train', args.size) 26 | valset = TieredImageNet('test', args.size) 27 | print("=> TieredImageNet...") 28 | elif args.dataset == 'cifarfs': 29 | trainset = CIFAR_FS('train', args.size) 30 | valset = CIFAR_FS('test', args.size) 31 | print("=> CIFAR FS...") 32 | else: 33 | print("Invalid dataset...") 34 | exit() 35 | train_sampler = CategoriesSampler(trainset.label, args.train_batch, 36 | args.train_way, args.shot + args.train_query) 37 | train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler, 38 | num_workers=args.worker, pin_memory=True) 39 | 40 | val_sampler = CategoriesSampler(valset.label, args.test_batch, 41 | args.test_way, args.shot + args.test_query) 42 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler, 43 | num_workers=args.worker, pin_memory=True) 44 | return train_loader, val_loader 45 | 46 | def training(args): 47 | ensure_path(args.save_path) 48 | 49 | train_loader, val_loader = get_dataset(args) 50 | 51 | if args.model == 'convnet': 52 | teacher = Convnet().cuda() 53 | print("=> Convnet architecture...") 54 | else: 55 | if args.dataset in ['mini', 'tiered']: 56 | teacher = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda() 57 | else: 58 | teacher = resnet12(avg_pool=True, 
drop_rate=0.1, dropblock_size=2).cuda()
59 |         print("=> Resnet architecture...")
60 | 
61 |     if args.kd_mode != 0:
62 |         # create a student model with the same architecture as the teacher, without its knowledge
63 |         model = copy.deepcopy(teacher)
64 |         if args.stage1_path:
65 |             model.load_state_dict(torch.load(osp.join(args.stage1_path, 'max-acc.pth')))
66 |             print("=> Student loaded with pretrain knowledge...")
67 | 
68 |     teacher.load_state_dict(torch.load(osp.join(args.stage2_path, 'max-acc.pth')))
69 |     print("=> Teacher model loaded...")
70 | 
71 |     if args.kd_mode == 0:
72 |         # initialize the student with the same knowledge as the teacher
73 |         model = copy.deepcopy(teacher)
74 |         print("=> Student obtains teacher's knowledge...")
75 | 
76 |     if args.kd_type == 'kd':
77 |         criterion_kd = DistillKL(args.temperature).cuda()
78 |     else:
79 |         criterion_kd = HintLoss().cuda()
80 | 
81 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
82 |     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
83 | 
84 |     def save_model(name):
85 |         torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
86 | 
87 |     trlog = {}
88 |     trlog['args'] = vars(args)
89 |     trlog['train_loss'] = []
90 |     trlog['val_loss'] = []
91 |     trlog['train_acc'] = []
92 |     trlog['val_acc'] = []
93 |     trlog['max_acc'] = 0.0
94 | 
95 |     timer = Timer()
96 |     best_epoch = 0
97 |     cmi = [0.0, 0.0]
98 | 
99 |     for epoch in range(1, args.max_epoch + 1):
100 | 
101 |         tl, ta = train(args, teacher, model, train_loader, optimizer, criterion_kd)
102 |         lr_scheduler.step()
103 |         vl, va, aa, bb = validate(args, model, val_loader)
104 | 
105 |         if va > trlog['max_acc']:
106 |             trlog['max_acc'] = va
107 |             save_model('max-acc')
108 |             best_epoch = epoch
109 |             cmi[0] = aa
110 |             cmi[1] = bb
111 | 
112 |         trlog['train_loss'].append(tl)
113 |         trlog['train_acc'].append(ta)
114 |         trlog['val_loss'].append(vl)
115 |         trlog['val_acc'].append(va)
116 | 
117 |         torch.save(trlog, osp.join(args.save_path, 'trlog'))
118 | 
119 |         save_model('epoch-last')
120 |         ot, ots = timer.measure()
121 |         tt, _ = timer.measure(epoch / args.max_epoch)
122 | 
123 |         print('Epoch {}/{}, train loss={:.4f} - acc={:.4f} - val loss={:.4f} - acc={:.4f} - max acc={:.4f} - ETA:{}/{}'.format(
124 |             epoch, args.max_epoch, tl, ta, vl, va, trlog['max_acc'], ots, timer.tts(tt-ot)))
125 | 
126 |         if epoch == args.max_epoch:
127 |             print("Best Epoch is {} with acc={:.2f}±{:.2f}%...".format(best_epoch, cmi[0], cmi[1]))
128 |             print("---------------------------------------------------")
129 | 
130 | def ssl_loss(args, model, data_shot):
131 |     # supports are the original support images; queries are their 90/180/270-degree rotations (assumes pre_query == 3)
132 |     x_90 = data_shot.transpose(2,3).flip(2)
133 |     x_180 = data_shot.flip(2).flip(3)
134 |     x_270 = data_shot.flip(2).transpose(2,3)
135 |     data_query = torch.cat((x_90, x_180, x_270),0)
136 | 
137 |     proto = model(data_shot)
138 |     proto = proto.reshape(1, args.shot*args.train_way, -1).mean(dim=0)
139 | 
140 |     label = torch.arange(args.train_way * args.shot).repeat(args.pre_query)
141 |     label = label.type(torch.cuda.LongTensor)
142 | 
143 |     logits = euclidean_metric(model(data_query), proto)
144 |     loss = F.cross_entropy(logits, label)
145 | 
146 |     return loss
147 | 
148 | def train(args, teacher, model, train_loader, optimizer, criterion_kd):
149 |     teacher.eval()
150 |     model.train()
151 | 
152 |     tl = Averager()
153 |     ta = Averager()
154 | 
155 |     for i, batch in enumerate(train_loader, 1):
156 |         data, _ = [_.cuda() for _ in batch]
157 |         p = args.shot * args.train_way
158 |         data_shot, data_query = data[:p], data[p:] # 
datashot (30, 3, 84, 84) 159 | 160 | # teacher 161 | with torch.no_grad(): 162 | tproto = teacher(data_shot) 163 | ft = tproto 164 | ft = [f.detach() for f in ft] 165 | tproto = tproto.reshape(args.shot, args.train_way, -1).mean(dim=0) 166 | # soft target from teacher 167 | tlogits = euclidean_metric(teacher(data_query), tproto) 168 | 169 | proto = model(data_shot) # (30, 1600) 170 | fs = proto 171 | proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0) 172 | 173 | label = torch.arange(args.train_way).repeat(args.train_query) 174 | label = label.type(torch.cuda.LongTensor) 175 | 176 | logits = euclidean_metric(model(data_query), proto) 177 | acc = count_acc(logits, label) 178 | 179 | if args.kd_mode != 0: 180 | # few-shot loss from student 181 | clsloss = F.cross_entropy(logits, label) 182 | 183 | # distillation loss 184 | if args.kd_type == 'kd': 185 | kdloss = criterion_kd(logits, tlogits) 186 | else: 187 | kdloss = criterion_kd(fs[-1], ft[-1]) 188 | 189 | # self-supervised loss signal 190 | loss_ss = ssl_loss(args, model, data_shot) 191 | 192 | if args.kd_mode != 0: 193 | loss = ((1.0 - args.kd_coef) * clsloss) + (args.kd_coef * kdloss) + (args.ssl_coef * loss_ss) 194 | else: 195 | loss = kdloss + (args.ssl_coef * loss_ss) 196 | 197 | tl.add(loss.item()) 198 | ta.add(acc) 199 | 200 | optimizer.zero_grad() 201 | loss.backward() 202 | optimizer.step() 203 | 204 | proto = None; logits = None; loss = None 205 | 206 | return tl.item(), ta.item() 207 | 208 | def validate(args, model, val_loader): 209 | model.eval() 210 | 211 | vl = Averager() 212 | va = Averager() 213 | acc_list = [] 214 | 215 | for i, batch in enumerate(val_loader, 1): 216 | data, _ = [_.cuda() for _ in batch] 217 | p = args.shot * args.test_way 218 | data_shot, data_query = data[:p], data[p:] 219 | 220 | proto = model(data_shot) 221 | proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0) 222 | 223 | label = torch.arange(args.test_way).repeat(args.test_query) 224 | label = label.type(torch.cuda.LongTensor) 225 | 226 | logits = euclidean_metric(model(data_query), proto) 227 | loss = F.cross_entropy(logits, label) 228 | acc = count_acc(logits, label) 229 | 230 | vl.add(loss.item()) 231 | va.add(acc) 232 | acc_list.append(acc*100) 233 | 234 | proto = None; logits = None; loss = None 235 | a,b = compute_confidence_interval(acc_list) 236 | return vl.item(), va.item(), a, b 237 | 238 | 239 | if __name__ == '__main__': 240 | parser = argparse.ArgumentParser() 241 | parser.add_argument('--max-epoch', type=int, default=200) 242 | parser.add_argument('--shot', type=int, default=1) 243 | parser.add_argument('--pre-query', type=int, default=3) # for self-supervised process: the number of query image generated based on support image 244 | parser.add_argument('--train-query', type=int, default=15) 245 | parser.add_argument('--test-query', type=int, default=15) 246 | parser.add_argument('--train-way', type=int, default=5) 247 | parser.add_argument('--test-way', type=int, default=5) 248 | parser.add_argument('--save-path', default='') 249 | parser.add_argument('--gpu', default='0') 250 | parser.add_argument('--size', type=int, default=84) 251 | parser.add_argument('--lr', type=float, default=0.001) 252 | parser.add_argument('--wd', type=float, default=0.001) 253 | parser.add_argument('--step-size', type=int, default=20) 254 | parser.add_argument('--train-batch', type=int, default=100) 255 | parser.add_argument('--test-batch', type=int, default=2000) 256 | parser.add_argument('--worker', type=int, default=8) 257 | 
parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet'])
258 |     parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs'])
259 |     parser.add_argument('--ssl-coef', type=float, default=0.1, help='The beta coefficient for self-supervised loss')
260 |     # self-distillation stage parameter
261 |     parser.add_argument('--temperature', type=int, default=4)
262 |     parser.add_argument('--kd-coef', type=float, default=0.1, help="The gamma coefficient for distillation loss")
263 |     # kd-mode 0: initialize the student as a copy of the teacher and train with the distillation loss only; 1: standard KD combined with the few-shot loss
264 |     parser.add_argument('--kd-mode', type=int, default=1, choices=[0,1])
265 |     parser.add_argument('--kd-type', type=str, default='kd', choices=['kd', 'hint'])
266 |     parser.add_argument('--stage1-path', default='')
267 |     parser.add_argument('--stage2-path', default='')
268 |     args = parser.parse_args()
269 | 
270 |     start_time = datetime.datetime.now()
271 | 
272 |     # fix seed
273 |     seed_torch(1)
274 |     set_gpu(args.gpu)
275 | 
276 |     if args.dataset in ['mini', 'tiered']:
277 |         args.size = 84
278 |     elif args.dataset in ['cifarfs']:
279 |         args.size = 32
280 |         args.worker = 0
281 |     else:
282 |         args.size = 28
283 | 
284 |     training(args)
285 | 
286 |     end_time = datetime.datetime.now()
287 |     print("Total executed time :", end_time - start_time)
288 | 
289 | 
--------------------------------------------------------------------------------
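For readers skimming the dump, here is a minimal, self-contained sketch of the `--kd-mode 1` objective that `train_stage3.py` optimizes: the few-shot cross-entropy and the `DistillKL` criterion (reproduced from `models/distill.py` above) blended by `--kd-coef`. The random logits, the coefficient value, and the omission of the `--ssl-coef` rotation term are illustrative assumptions; in the script the logits come from `euclidean_metric` between query embeddings and class prototypes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillKL(nn.Module):
    """Temperature-scaled KL divergence (same as models/distill.py)."""
    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        return F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]

# Toy 5-way episode with 15 queries per class: stand-ins for the student's
# and the frozen teacher's episode logits.
student_logits = torch.randn(75, 5, requires_grad=True)
teacher_logits = torch.randn(75, 5)
label = torch.arange(5).repeat(15)

criterion_kd = DistillKL(T=4)                      # --temperature default
cls_loss = F.cross_entropy(student_logits, label)
kd_loss = criterion_kd(student_logits, teacher_logits.detach())

kd_coef = 0.7                                      # e.g. the 1-shot CIFAR-FS run in the README
loss = (1.0 - kd_coef) * cls_loss + kd_coef * kd_loss   # + args.ssl_coef * ssl_loss in the script
loss.backward()
```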
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import random
5 | import numpy as np
6 | 
7 | import torch
8 | 
9 | def seed_torch(seed=1337):
10 |     random.seed(seed)
11 |     os.environ['PYTHONHASHSEED'] = str(seed)
12 |     np.random.seed(seed)
13 |     torch.manual_seed(seed)
14 |     torch.cuda.manual_seed(seed)
15 |     torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
16 |     torch.backends.cudnn.benchmark = False
17 |     torch.backends.cudnn.deterministic = True
18 | 
19 | 
20 | def set_gpu(x):
21 |     os.environ['CUDA_VISIBLE_DEVICES'] = x
22 |     print('using gpu:', x)
23 | 
24 | 
25 | def ensure_path(path):
26 |     if os.path.exists(path):
27 |         #if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
28 |         shutil.rmtree(path)
29 |         os.makedirs(path)
30 |     else:
31 |         os.makedirs(path)
32 | 
33 | 
34 | class Averager():
35 | 
36 |     def __init__(self):
37 |         self.n = 0
38 |         self.v = 0
39 | 
40 |     def add(self, x):
41 |         self.v = (self.v * self.n + x) / (self.n + 1)
42 |         self.n += 1
43 | 
44 |     def item(self):
45 |         return self.v
46 | 
47 | 
48 | def count_acc(logits, label):
49 |     pred = torch.argmax(logits, dim=1)
50 |     return (pred == label).type(torch.cuda.FloatTensor).mean().item()
51 | 
52 | 
53 | def dot_metric(a, b):
54 |     return torch.mm(a, b.t())
55 | 
56 | import torch.nn.functional as F
57 | def cos_metric(a, b):
58 |     return torch.mm(F.normalize(a, dim=-1), F.normalize(b, dim=-1).t())
59 | 
60 | def euclidean_metric(a, b):
61 |     n = a.shape[0]
62 |     m = b.shape[0]
63 |     a = a.unsqueeze(1).expand(n, m, -1)
64 |     b = b.unsqueeze(0).expand(n, m, -1)
65 |     logits = -((a - b)**2).sum(dim=2)
66 |     return logits
67 | 
68 | 
69 | class Timer():
70 | 
71 |     def __init__(self):
72 |         self.o = time.time()
73 | 
74 |     def measure(self, p=1):
75 |         x = (time.time() - self.o) / p
76 |         x = int(x)
77 |         return x, self.tts(x)
78 | 
79 |     def tts(self, x=0):
80 |         if x >= 3600:
81 |             return '{:.1f}h'.format(x / 3600)
82 |         if x >= 60:
83 |             return '{}m'.format(round(x / 60))
84 |         return '{}s'.format(x)
85 | 
86 | 
87 | def compute_confidence_interval(data):
88 |     """
89 |     Compute 95% confidence interval
90 |     :param data: An array of mean accuracy (or mAP) across a number of sampled episodes.
91 |     :return: the 95% confidence interval for this data.
92 |     """
93 |     a = 1.0 * np.array(data)
94 |     m = np.mean(a)
95 |     std = np.std(a)
96 |     pm = 1.96 * (std / np.sqrt(len(a)))
97 |     return m, pm
98 | 
--------------------------------------------------------------------------------
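To make the episodic pipeline above concrete, here is a small sketch of the prototype classification shared by `test.py` and the training scripts, assuming the repository root is on `PYTHONPATH`. The random tensors are stand-ins for `Convnet`/`resnet12` embeddings, and accuracy is computed inline because `count_acc` above assumes CUDA tensors.

```python
import torch
from utils import euclidean_metric, compute_confidence_interval

way, shot, query, dim = 5, 1, 15, 1600      # a 5-way 1-shot episode with Convnet features
support = torch.randn(way * shot, dim)      # stand-in embeddings of the support images
queries = torch.randn(way * query, dim)     # stand-in embeddings of the query images

# CategoriesSampler interleaves classes (one image per class, repeated shot + query
# times), so reshape(shot, way, -1) groups the support embeddings by class.
proto = support.reshape(shot, way, -1).mean(dim=0)   # class prototypes, shape (way, dim)
logits = euclidean_metric(queries, proto)            # (75, 5) negative squared distances
label = torch.arange(way).repeat(query)              # labels in the sampler's ordering
acc = (logits.argmax(dim=1) == label).float().mean().item()

# test.py repeats this over --test-batch episodes and reports mean ± 95% interval.
episode_accs = [acc * 100]                           # placeholder list of episode accuracies
m, pm = compute_confidence_interval(episode_accs)
print('{:.2f}±{:.2f}'.format(m, pm))
```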