├── LICENSE
├── README.md
├── data
│   └── readme.txt
├── datasets
│   ├── cifarfs.py
│   ├── mini_imagenet.py
│   ├── samplers.py
│   └── tiered_imagenet.py
├── models
│   ├── convnet.py
│   ├── distill.py
│   └── resnet.py
├── save
│   └── readme.txt
├── test.py
├── train_stage1.py
├── train_stage2.py
├── train_stage3.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jit Yan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SSL-ProtoNet: Self-supervised Learning Prototypical Networks for Few-shot Learning
2 |
3 | This repository contains the **PyTorch** code for the paper: "[SSL-ProtoNet: Self-supervised Learning Prototypical Networks for Few-shot Learning](https://doi.org/10.1016/j.eswa.2023.122173)" by Jit Yan Lim, Kian Ming Lim, Chin Poo Lee, and Yong Xuan Tan.
4 |
5 | ## Environment
6 | The code is tested on Windows 10 with Anaconda3 and the following packages:
7 | - python 3.7.4
8 | - pytorch 1.3.1
9 |
10 | ## Preparation
11 | 1. Change the `ROOT_PATH` value in the following files to your own path:
12 | - `datasets/mini_imagenet.py`
13 | - `datasets/tiered_imagenet.py`
14 | - `datasets/cifarfs.py`
15 |
16 | 2. Download the datasets and put them into the corresponding folders specified by `ROOT_PATH`:
17 | - ***mini*ImageNet**: download from [CSS](https://github.com/anyuexuan/CSS) and put in `data/mini-imagenet` folder.
18 |
19 | - ***tiered*ImageNet**: download from [RFS](https://www.dropbox.com/sh/6yd1ygtyc3yd981/AABVeEqzC08YQv4UZk7lNHvya?dl=0) and put in `data/tiered-imagenet` folder.
20 |
21 | - **CIFARFS**: download from [MetaOptNet](https://github.com/kjunelee/MetaOptNet) and put in `data/cifar-fs` folder.
22 |
23 |
24 | ## Pre-trained Models
25 | [Optional] The pre-trained models can be downloaded from [here](https://drive.google.com/file/d/14IOHnVfVACpkhjj1o3ZjwG7YD4p6ULLM/view?usp=sharing). Extract and put the contents in the `save` folder. To evaluate a model, run `test.py` with the corresponding save path, as shown in the next section.
26 |
27 |
28 | ## Experiments
29 | To train on 1-shot and 5-shot CIFAR-FS (a *mini*ImageNet sketch appears at the end of this README):
30 | ```
31 | python train_stage1.py --dataset cifarfs --train-way 50 --train-batch 100 --save-path ./save/cifarfs-stage1
32 |
33 | python train_stage2.py --dataset cifarfs --shot 1 --save-path ./save/cifarfs-stage2-1s --stage1-path ./save/cifarfs-stage1 --train-way 20
34 | python train_stage2.py --dataset cifarfs --shot 5 --save-path ./save/cifarfs-stage2-5s --stage1-path ./save/cifarfs-stage1 --train-way 10
35 |
36 | python train_stage3.py --kd-coef 0.7 --dataset cifarfs --shot 1 --train-way 20 --stage1-path ./save/cifarfs-stage1 --stage2-path ./save/cifarfs-stage2-1s --save-path ./save/cifarfs-stage3-1s
37 | python train_stage3.py --kd-coef 0.1 --dataset cifarfs --shot 5 --train-way 10 --stage1-path ./save/cifarfs-stage1 --stage2-path ./save/cifarfs-stage2-5s --save-path ./save/cifarfs-stage3-5s
38 | ```
39 | To evaluate on 5-way 1-shot and 5-way 5-shot CIFAR-FS:
40 | ```
41 | python test.py --dataset cifarfs --shot 1 --save-path ./save/cifarfs-stage3-1s
42 | python test.py --dataset cifarfs --shot 5 --save-path ./save/cifarfs-stage3-5s
43 | ```
44 |
45 |
46 | ## Citation
47 | If you find this repo useful for your research, please consider citing the paper:
48 | ```
49 | @article{LIM2023122173,
50 | title = {SSL-ProtoNet: Self-supervised Learning Prototypical Networks for few-shot learning},
51 | journal = {Expert Systems with Applications},
52 | pages = {122173},
53 | year = {2023},
54 | issn = {0957-4174},
55 | doi = {https://doi.org/10.1016/j.eswa.2023.122173},
56 | author = {Jit Yan Lim and Kian Ming Lim and Chin Poo Lee and Yong Xuan Tan}
57 | }
58 | ```
59 |
60 | ## Contacts
61 | For any questions, please contact:
62 |
63 | - Jit Yan Lim (jityan95@gmail.com)
64 | - Kian Ming Lim (Kian-Ming.Lim@nottingham.edu.cn)
65 |
66 | ## Acknowledgements
67 | This repo is based on **[Prototypical Networks](https://github.com/yinboc/prototypical-network-pytorch)**, **[RFS](https://github.com/WangYueFt/rfs)**, and **[SKD](https://github.com/brjathu/SKD)**.
68 |
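69 | ## Running on Other Datasets
70 | The same three-stage pipeline applies to *mini*ImageNet and *tiered*ImageNet through the `--dataset` flag. A minimal 1-shot sketch for *mini*ImageNet is shown below; the way/batch/coefficient values mirror the CIFAR-FS example above and are illustrative placeholders, not the paper's tuned settings:
71 | ```
72 | python train_stage1.py --dataset mini --train-way 50 --train-batch 100 --save-path ./save/mini-stage1
73 | python train_stage2.py --dataset mini --shot 1 --train-way 20 --stage1-path ./save/mini-stage1 --save-path ./save/mini-stage2-1s
74 | python train_stage3.py --kd-coef 0.7 --dataset mini --shot 1 --train-way 20 --stage1-path ./save/mini-stage1 --stage2-path ./save/mini-stage2-1s --save-path ./save/mini-stage3-1s
75 | python test.py --dataset mini --shot 1 --save-path ./save/mini-stage3-1s
76 | ```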
--------------------------------------------------------------------------------
/data/readme.txt:
--------------------------------------------------------------------------------
1 | datasets location
--------------------------------------------------------------------------------
/datasets/cifarfs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import torch
4 | from torch.utils.data import Dataset
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | import numpy as np
8 |
9 | ROOT_PATH = './data/cifar-fs'
10 |
11 | def load_data(file):
12 | try:
13 | with open(file, 'rb') as fo:
14 | data = pickle.load(fo)
15 | return data
16 |     except UnicodeDecodeError:  # pickle written by Python 2; retry with latin1 encoding
17 | with open(file, 'rb') as f:
18 | u = pickle._Unpickler(f)
19 | u.encoding = 'latin1'
20 | data = u.load()
21 | return data
22 |
23 |
24 | class CIFAR_FS(Dataset):
25 |
26 | def __init__(self, phase='train', size=32, transform=None):
27 |
28 | filepath = os.path.join(ROOT_PATH, 'CIFAR_FS_' + phase + ".pickle")
29 | datafile = load_data(filepath)
30 |
31 | data = datafile['data']
32 | label = datafile['labels']
33 |
34 | data = [Image.fromarray(x) for x in data]
35 |
36 | min_label = min(label)
37 | label = [x - min_label for x in label]
38 |
39 | newlabel = []
40 | classlabel = 0
41 | for i in range(len(label)):
42 | if (i > 0) and (label[i] != label[i-1]):
43 | classlabel += 1
44 | newlabel.append(classlabel)
45 |
46 | self.data = data
47 | self.label = newlabel
48 |
49 | if transform is None:
50 | self.transform = transforms.Compose([
51 | transforms.Resize(size),
52 | transforms.CenterCrop(size),
53 | transforms.ToTensor(),
54 | transforms.Normalize(
55 | np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
56 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])
57 | )
58 | ])
59 | else:
60 | self.transform = transform
61 |
62 | def __len__(self):
63 | return len(self.data)
64 |
65 | def __getitem__(self, i):
66 | return self.transform(self.data[i]), self.label[i]
67 |
68 |
69 | class SSLCifarFS(Dataset):
70 |
71 | def __init__(self, phase, args):
72 | filepath = os.path.join(ROOT_PATH, 'CIFAR_FS_' + phase + ".pickle")
73 | datafile = load_data(filepath)
74 |
75 | data = datafile['data']
76 | label = datafile['labels']
77 |
78 | data = [Image.fromarray(x) for x in data]
79 |
80 | min_label = min(label)
81 | label = [x - min_label for x in label]
82 |
83 | newlabel = []
84 | classlabel = 0
85 | for i in range(len(label)):
86 | if (i > 0) and (label[i] != label[i-1]):
87 | classlabel += 1
88 | newlabel.append(classlabel)
89 |
90 | self.data = data
91 | self.label = newlabel
92 | self.args = args
93 |
94 | color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
95 | saturation=0.4, hue=0.1)
96 | self.augmentation_transform = transforms.Compose([transforms.RandomResizedCrop(size=(args.size, args.size)[-2:],
97 | scale=(0.5, 1.0)),
98 | transforms.RandomHorizontalFlip(p=0.5),
99 | transforms.RandomVerticalFlip(p=0.5),
100 | transforms.RandomApply([color_jitter], p=0.8),
101 | transforms.RandomGrayscale(p=0.2),
102 | transforms.ToTensor(),
103 | transforms.Normalize(
104 | np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
105 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])
106 | )
107 | ])
108 | #
109 | self.identity_transform = transforms.Compose([
110 | transforms.Resize(args.size),
111 | transforms.CenterCrop(args.size),
112 | transforms.ToTensor(),
113 | transforms.Normalize(
114 | np.array([x / 255.0 for x in [125.3, 123.0, 113.9]]),
115 | np.array([x / 255.0 for x in [63.0, 62.1, 66.7]])
116 | )
117 | ])
118 |
119 | def __len__(self):
120 | return len(self.data)
121 |
122 | def __getitem__(self, i):
123 | img, label = self.data[i], self.label[i]
124 | image = []
125 | for _ in range(self.args.shot):
126 | image.append(self.identity_transform(img).unsqueeze(0))
127 |         for _ in range(self.args.train_query):
128 | image.append(self.augmentation_transform(img).unsqueeze(0))
129 | return dict(data=torch.cat(image)), label
130 |
131 |
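132 | if __name__ == '__main__':
133 |     # Usage sketch (illustrative, not part of the original file): each
134 |     # SSLCifarFS item packs `shot` identity views followed by
135 |     # `train_query` augmented views of the same image, so `data` has
136 |     # shape (shot + train_query, 3, size, size). Requires
137 |     # CIFAR_FS_train.pickle under ROOT_PATH.
138 |     import argparse
139 |     ssl_args = argparse.Namespace(size=32, shot=1, train_query=3)
140 |     ds = SSLCifarFS('train', ssl_args)
141 |     item, label = ds[0]
142 |     print(item['data'].shape)  # torch.Size([4, 3, 32, 32])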
--------------------------------------------------------------------------------
/datasets/mini_imagenet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from PIL import Image
3 |
4 | import torch
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | ROOT_PATH = './data/mini-imagenet'
10 |
11 |
12 | class MiniImageNet(Dataset):
13 |
14 | def __init__(self, setname, size, transform=None):
15 | csv_path = osp.join(ROOT_PATH, setname + '.csv')
16 |         lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]  # drop the CSV header row
17 |
18 | data = []
19 | label = []
20 | lb = -1
21 |
22 | self.wnids = []
23 |
24 | for l in lines:
25 | name, wnid = l.split(',')
26 | path = osp.join(ROOT_PATH, 'images', name)
27 | if wnid not in self.wnids:
28 | self.wnids.append(wnid)
29 | lb += 1
30 | data.append(path)
31 | label.append(lb)
32 |
33 | self.data = data
34 | self.label = label
35 |
36 | if transform is None:
37 | self.transform = transforms.Compose([
38 | transforms.Resize(size),
39 | transforms.CenterCrop(size),
40 | transforms.ToTensor(),
41 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
42 | std=[0.229, 0.224, 0.225])
43 | ])
44 | else:
45 | self.transform = transform
46 |
47 | def __len__(self):
48 | return len(self.data)
49 |
50 | def __getitem__(self, i):
51 | path, label = self.data[i], self.label[i]
52 | image = self.transform(Image.open(path).convert('RGB'))
53 | return image, label
54 |
55 |
56 | class SSLMiniImageNet(Dataset):
57 |
58 | def __init__(self, setname, args):
59 | csv_path = osp.join(ROOT_PATH, setname + '.csv')
60 |         lines = [x.strip() for x in open(csv_path, 'r').readlines()][1:]  # drop the CSV header row
61 |
62 | data = []
63 | label = []
64 | lb = -1
65 |
66 | self.wnids = []
67 | self.args = args
68 |
69 | for l in lines:
70 | name, wnid = l.split(',')
71 | path = osp.join(ROOT_PATH, 'images', name)
72 | if wnid not in self.wnids:
73 | self.wnids.append(wnid)
74 | lb += 1
75 | data.append(path)
76 | label.append(lb)
77 |
78 | self.data = data
79 | self.label = label
80 |
81 | color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
82 | saturation=0.4, hue=0.1)
83 | self.augmentation_transform = transforms.Compose([transforms.RandomResizedCrop(size=(args.size, args.size)[-2:],
84 | scale=(0.5, 1.0)),
85 | transforms.RandomHorizontalFlip(p=0.5),
86 | transforms.RandomVerticalFlip(p=0.5),
87 | transforms.RandomApply([color_jitter], p=0.8),
88 | transforms.RandomGrayscale(p=0.2),
89 | transforms.ToTensor(),
90 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
91 | std=[0.229, 0.224, 0.225]),
92 | ])
93 | #
94 | self.identity_transform = transforms.Compose([
95 | transforms.Resize(args.size),
96 | transforms.CenterCrop(args.size),
97 | transforms.ToTensor(),
98 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
99 | std=[0.229, 0.224, 0.225])
100 | ])
101 |
102 | def __len__(self):
103 | return len(self.data)
104 |
105 | def __getitem__(self, i):
106 | path, label = self.data[i], self.label[i]
107 | img = Image.open(path).convert('RGB')
108 | image = []
109 | for _ in range(self.args.shot):
110 | image.append(self.identity_transform(img).unsqueeze(0))
111 |         for _ in range(self.args.train_query):
112 | image.append(self.augmentation_transform(img).unsqueeze(0))
113 | return dict(data=torch.cat(image)), label
114 |
--------------------------------------------------------------------------------
/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class CategoriesSampler():
6 |
7 | def __init__(self, label, n_batch, n_cls, n_per):
8 | self.n_batch = n_batch
9 | self.n_cls = n_cls
10 | self.n_per = n_per
11 |
12 | label = np.array(label)
13 | self.m_ind = []
14 | for i in range(max(label) + 1):
15 | ind = np.argwhere(label == i).reshape(-1)
16 | ind = torch.from_numpy(ind)
17 | self.m_ind.append(ind)
18 |
19 | def __len__(self):
20 | return self.n_batch
21 |
22 | def __iter__(self):
23 | for i_batch in range(self.n_batch):
24 | batch = []
25 | classes = torch.randperm(len(self.m_ind))[:self.n_cls]
26 | for c in classes:
27 | l = self.m_ind[c]
28 | pos = torch.randperm(len(l))[:self.n_per]
29 | batch.append(l[pos])
30 |             batch = torch.stack(batch).t().reshape(-1)  # (n_cls, n_per) -> (n_per, n_cls) -> flat: n_per blocks, one index per class each
31 | yield batch
32 |
33 |
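34 | if __name__ == '__main__':
35 |     # Usage sketch (illustrative, not part of the original file):
36 |     # 100 episodes of 5-way, 16 images per class (e.g. 1 support + 15 query).
37 |     # Each yielded batch is a flat index tensor of n_cls * n_per entries,
38 |     # interleaved so the first n_cls entries hold one image per class,
39 |     # the next n_cls the second image per class, and so on.
40 |     labels = [0] * 20 + [1] * 20 + [2] * 20 + [3] * 20 + [4] * 20
41 |     sampler = CategoriesSampler(labels, n_batch=100, n_cls=5, n_per=16)
42 |     episode = next(iter(sampler))
43 |     print(episode.shape)  # torch.Size([80])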
--------------------------------------------------------------------------------
/datasets/tiered_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from PIL import Image
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data import Dataset
8 | from torchvision import transforms
9 |
10 |
11 | ROOT_PATH = './data/tiered-imagenet-kwon'
12 |
13 | class TieredImageNet(Dataset):
14 |
15 | def __init__(self, split='train', size=84, transform=None):
16 | split_tag = split
17 | data = np.load(os.path.join(
18 | ROOT_PATH, '{}_images.npz'.format(split_tag)),
19 | allow_pickle=True)['images']
20 | data = data[:, :, :, ::-1]
21 |
22 | with open(os.path.join(
23 | ROOT_PATH, '{}_labels.pkl'.format(split_tag)), 'rb') as f:
24 | label = pickle.load(f)['labels']
25 |
26 | data = [Image.fromarray(x) for x in data]
27 |
28 | min_label = min(label)
29 | label = [x - min_label for x in label]
30 |
31 | self.data = data
32 | self.label = label
33 | self.n_classes = max(self.label) + 1
34 |
35 | if transform is None:
36 | if split in ['train', 'trainval']:
37 | self.transform = transforms.Compose([
38 | transforms.Resize(size+12),
39 | transforms.RandomCrop(size, padding=8),
40 | transforms.RandomHorizontalFlip(),
41 | transforms.ToTensor(),
42 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
43 | std=[0.229, 0.224, 0.225]),
44 | ])
45 | else:
46 | self.transform = transforms.Compose([
47 | transforms.Resize(size),
48 | transforms.CenterCrop(size),
49 | transforms.ToTensor(),
50 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
51 | std=[0.229, 0.224, 0.225])
52 | ])
53 | else:
54 | self.transform = transform
55 |
56 | def __len__(self):
57 | return len(self.data)
58 |
59 | def __getitem__(self, i):
60 | return self.transform(self.data[i]), self.label[i]
61 |
62 |
63 | class SSLTieredImageNet(Dataset):
64 |
65 | def __init__(self, split='train', args=None):
66 | split_tag = split
67 | data = np.load(os.path.join(
68 | ROOT_PATH, '{}_images.npz'.format(split_tag)),
69 | allow_pickle=True)['images']
70 | data = data[:, :, :, ::-1]
71 |
72 | with open(os.path.join(
73 | ROOT_PATH, '{}_labels.pkl'.format(split_tag)), 'rb') as f:
74 | label = pickle.load(f)['labels']
75 |
76 | data = [Image.fromarray(x) for x in data]
77 |
78 | min_label = min(label)
79 | label = [x - min_label for x in label]
80 |
81 | self.data = data
82 | self.label = label
83 | self.n_classes = max(self.label) + 1
84 |
85 | color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
86 | saturation=0.4, hue=0.1)
87 | self.augmentation_transform = transforms.Compose([transforms.RandomResizedCrop(size=(args.size, args.size)[-2:],
88 | scale=(0.5, 1.0)),
89 | transforms.RandomHorizontalFlip(p=0.5),
90 | transforms.RandomVerticalFlip(p=0.5),
91 | transforms.RandomApply([color_jitter], p=0.8),
92 | transforms.RandomGrayscale(p=0.2),
93 | transforms.ToTensor(),
94 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
95 | std=[0.229, 0.224, 0.225]),
96 | ])
97 | #
98 | self.identity_transform = transforms.Compose([
99 | transforms.Resize(args.size),
100 | transforms.CenterCrop(args.size),
101 | transforms.ToTensor(),
102 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
103 | std=[0.229, 0.224, 0.225])
104 | ])
105 |
106 | def __len__(self):
107 | return len(self.data)
108 |
109 | def __getitem__(self, i):
110 | img, label = self.data[i], self.label[i]
111 | image = []
112 |         for _ in range(1):  # identity views; hard-coded, assumes shot=1 (cf. SSLCifarFS)
113 |             image.append(self.identity_transform(img).unsqueeze(0))
114 |         for _ in range(3):  # augmented views; hard-coded, assumes train_query=3
115 |             image.append(self.augmentation_transform(img).unsqueeze(0))
116 | return dict(data=torch.cat(image)), label
117 |
--------------------------------------------------------------------------------
/models/convnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | def conv_block(in_channels, out_channels):
4 | bn = nn.BatchNorm2d(out_channels)
5 | nn.init.uniform_(bn.weight)
6 | return nn.Sequential(
7 | nn.Conv2d(in_channels, out_channels, 3, padding=1),
8 | bn,
9 | #nn.BatchNorm2d(out_channels),
10 | nn.ReLU(),
11 | nn.MaxPool2d(2)
12 | )
13 |
14 |
15 | class Convnet(nn.Module):
16 |
17 | def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
18 | super().__init__()
19 | self.encoder = nn.Sequential(
20 | conv_block(x_dim, hid_dim),
21 | conv_block(hid_dim, hid_dim),
22 | conv_block(hid_dim, hid_dim),
23 | conv_block(hid_dim, z_dim),
24 | )
25 |
26 | def forward(self, x):
27 | x = self.encoder(x)
28 | return x.view(x.size(0), -1)
29 |
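30 | if __name__ == '__main__':
31 |     # Shape sketch (illustrative, not part of the original file): each
32 |     # conv block halves the spatial size via MaxPool2d(2), so an 84x84
33 |     # input yields 64 x 5 x 5 = 1600-dim features (the "(30, 1600)" noted
34 |     # in train_stage2.py) and a 32x32 input yields 64 x 2 x 2 = 256.
35 |     import torch
36 |     net = Convnet()
37 |     print(net(torch.randn(2, 3, 84, 84)).shape)  # torch.Size([2, 1600])
38 |     print(net(torch.randn(2, 3, 32, 32)).shape)  # torch.Size([2, 256])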
--------------------------------------------------------------------------------
/models/distill.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | class DistillKL(nn.Module):
5 |
6 | def __init__(self, T):
7 | super(DistillKL, self).__init__()
8 | self.T = T
9 |
10 | def forward(self, y_s, y_t):
11 | p_s = F.log_softmax(y_s/self.T, dim=1)
12 | p_t = F.softmax(y_t/self.T, dim=1)
13 |         loss = F.kl_div(p_s, p_t, reduction='sum')*(self.T**2)/y_s.shape[0]  # T^2 keeps gradient magnitudes comparable across temperatures
14 | return loss
15 |
16 | class HintLoss(nn.Module):
17 |     """FitNets-style hint: regress student features onto teacher features."""
18 | 
19 |     def __init__(self):
20 |         super(HintLoss, self).__init__()
21 |         self.crit = nn.MSELoss()
22 | 
23 |     def forward(self, fs, ft):
24 |         return self.crit(fs, ft)
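25 | 
26 | if __name__ == '__main__':
27 |     # Usage sketch (illustrative, not part of the original file):
28 |     # temperature-scaled KD between random student/teacher logits,
29 |     # as wired up in train_stage3.py with --temperature 4.
30 |     import torch
31 |     torch.manual_seed(0)
32 |     y_s, y_t = torch.randn(8, 5), torch.randn(8, 5)
33 |     print(DistillKL(T=4)(y_s, y_t).item())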
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.distributions import Bernoulli
5 |
6 | # ======== 2D RESNET ===========
7 |
8 | def conv3x3(in_planes, out_planes, stride=1):
9 | """3x3 convolution with padding"""
10 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
11 | padding=1, bias=False)
12 |
13 |
14 | class SELayer(nn.Module):
15 | def __init__(self, channel, reduction=16):
16 | super(SELayer, self).__init__()
17 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
18 | self.fc = nn.Sequential(
19 | nn.Linear(channel, channel // reduction),
20 | nn.ReLU(inplace=True),
21 | nn.Linear(channel // reduction, channel),
22 | nn.Sigmoid()
23 | )
24 |
25 | def forward(self, x):
26 | b, c, _, _ = x.size()
27 | y = self.avg_pool(x).view(b, c)
28 | y = self.fc(y).view(b, c, 1, 1)
29 | return x * y
30 |
31 |
32 | class DropBlock(nn.Module):
33 | def __init__(self, block_size):
34 | super(DropBlock, self).__init__()
35 |
36 | self.block_size = block_size
37 | #self.gamma = gamma
38 | #self.bernouli = Bernoulli(gamma)
39 |
40 | def forward(self, x, gamma):
41 | # shape: (bsize, channels, height, width)
42 |
43 | if self.training:
44 | batch_size, channels, height, width = x.shape
45 |
46 | bernoulli = Bernoulli(gamma)
47 | mask = bernoulli.sample((batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
48 | block_mask = self._compute_block_mask(mask)
49 | countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
50 | count_ones = block_mask.sum()
51 |
52 | return block_mask * x * (countM / count_ones)
53 | else:
54 | return x
55 |
56 | def _compute_block_mask(self, mask):
57 | left_padding = int((self.block_size-1) / 2)
58 | right_padding = int(self.block_size / 2)
59 |
60 | batch_size, channels, height, width = mask.shape
61 | #print ("mask", mask[0][0])
62 | non_zero_idxs = mask.nonzero()
63 | nr_blocks = non_zero_idxs.shape[0]
64 |
65 | offsets = torch.stack(
66 | [
67 | torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1), # - left_padding,
68 | torch.arange(self.block_size).repeat(self.block_size), #- left_padding
69 | ]
70 | ).t().cuda()
71 | offsets = torch.cat((torch.zeros(self.block_size**2, 2).cuda().long(), offsets.long()), 1)
72 |
73 | if nr_blocks > 0:
74 | non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
75 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
76 | offsets = offsets.long()
77 |
78 | block_idxs = non_zero_idxs + offsets
79 | #block_idxs += left_padding
80 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
81 | padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
82 | else:
83 | padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
84 |
85 | block_mask = 1 - padded_mask#[:height, :width]
86 | return block_mask
87 |
88 |
89 | class BasicBlock(nn.Module):
90 | expansion = 1
91 |
92 | def __init__(self, inplanes, planes, stride=1, downsample=None, drop_rate=0.0, drop_block=False,
93 | block_size=1, use_se=False):
94 | super(BasicBlock, self).__init__()
95 | self.conv1 = conv3x3(inplanes, planes)
96 | self.bn1 = nn.BatchNorm2d(planes)
97 | self.relu = nn.LeakyReLU(0.1)
98 | self.conv2 = conv3x3(planes, planes)
99 | self.bn2 = nn.BatchNorm2d(planes)
100 | self.conv3 = conv3x3(planes, planes)
101 | self.bn3 = nn.BatchNorm2d(planes)
102 | self.maxpool = nn.MaxPool2d(stride)
103 | self.downsample = downsample
104 | self.stride = stride
105 | self.drop_rate = drop_rate
106 | self.num_batches_tracked = 0
107 | self.drop_block = drop_block
108 | self.block_size = block_size
109 | self.DropBlock = DropBlock(block_size=self.block_size)
110 | self.use_se = use_se
111 | if self.use_se:
112 | self.se = SELayer(planes, 4)
113 |
114 | def forward(self, x):
115 | self.num_batches_tracked += 1
116 |
117 | residual = x
118 |
119 | out = self.conv1(x)
120 | out = self.bn1(out)
121 | out = self.relu(out)
122 |
123 | out = self.conv2(out)
124 | out = self.bn2(out)
125 | out = self.relu(out)
126 |
127 | out = self.conv3(out)
128 | out = self.bn3(out)
129 | if self.use_se:
130 | out = self.se(out)
131 |
132 | if self.downsample is not None:
133 | residual = self.downsample(x)
134 | out += residual
135 | out = self.relu(out)
136 | out = self.maxpool(out)
137 |
138 | if self.drop_rate > 0:
139 | if self.drop_block == True:
140 | feat_size = out.size()[2]
141 | keep_rate = max(1.0 - self.drop_rate / (20*2000) * (self.num_batches_tracked), 1.0 - self.drop_rate)
142 | gamma = (1 - keep_rate) / self.block_size**2 * feat_size**2 / (feat_size - self.block_size + 1)**2
143 | out = self.DropBlock(out, gamma=gamma)
144 | else:
145 | out = F.dropout(out, p=self.drop_rate, training=self.training, inplace=True)
146 |
147 | return out
148 |
149 |
150 | class ResNet2d(nn.Module):
151 |
152 | def __init__(self, block, n_blocks, keep_prob=1.0, avg_pool=False, drop_rate=0.0,
153 | dropblock_size=5, num_classes=-1, use_se=False):
154 | super(ResNet2d, self).__init__()
155 |
156 | self.inplanes = 3
157 | self.use_se = use_se
158 | self.layer1 = self._make_layer(block, n_blocks[0], 64,
159 | stride=2, drop_rate=drop_rate)
160 | self.layer2 = self._make_layer(block, n_blocks[1], 160,
161 | stride=2, drop_rate=drop_rate)
162 | self.layer3 = self._make_layer(block, n_blocks[2], 320,
163 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
164 | self.layer4 = self._make_layer(block, n_blocks[3], 640,
165 | stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
166 | if avg_pool:
167 | # self.avgpool = nn.AvgPool2d(5, stride=1)
168 | self.avgpool = nn.AdaptiveAvgPool2d(1)
169 | self.keep_prob = keep_prob
170 | self.keep_avg_pool = avg_pool
171 | self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
172 | self.drop_rate = drop_rate
173 |
174 | for m in self.modules():
175 | if isinstance(m, nn.Conv2d):
176 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
177 | elif isinstance(m, nn.BatchNorm2d):
178 | nn.init.constant_(m.weight, 1)
179 | nn.init.constant_(m.bias, 0)
180 |
181 | def _make_layer(self, block, n_block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
182 | downsample = None
183 | if stride != 1 or self.inplanes != planes * block.expansion:
184 | downsample = nn.Sequential(
185 | nn.Conv2d(self.inplanes, planes * block.expansion,
186 | kernel_size=1, stride=1, bias=False),
187 | nn.BatchNorm2d(planes * block.expansion),
188 | )
189 |
190 | layers = []
191 | if n_block == 1:
192 | layer = block(self.inplanes, planes, stride, downsample, drop_rate, drop_block, block_size, self.use_se)
193 | else:
194 |             layer = block(self.inplanes, planes, stride, downsample, drop_rate, use_se=self.use_se)  # use_se must be passed by keyword, not into the drop_block slot
195 | layers.append(layer)
196 | self.inplanes = planes * block.expansion
197 |
198 | for i in range(1, n_block):
199 | if i == n_block - 1:
200 | layer = block(self.inplanes, planes, drop_rate=drop_rate, drop_block=drop_block,
201 | block_size=block_size, use_se=self.use_se)
202 | else:
203 | layer = block(self.inplanes, planes, drop_rate=drop_rate, use_se=self.use_se)
204 | layers.append(layer)
205 |
206 | return nn.Sequential(*layers)
207 |
208 | def forward(self, x):
209 | x = self.layer1(x)
210 | x = self.layer2(x)
211 | x = self.layer3(x)
212 | x = self.layer4(x)
213 | if self.keep_avg_pool:
214 | x = self.avgpool(x)
215 | x = x.view(x.size(0), -1)
216 | return x
217 |
218 | def resnet12(keep_prob=1.0, avg_pool=False, **kwargs):
219 | model = ResNet2d(BasicBlock, [1, 1, 1, 1], keep_prob=keep_prob, avg_pool=avg_pool, **kwargs)
220 | return model
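221 | 
222 | if __name__ == '__main__':
223 |     # Shape sketch (illustrative, not part of the original file): each of
224 |     # the four stages downsamples by 2, so an 84x84 input reaches a 5x5
225 |     # map with 640 channels; avg_pool=True then gives a 640-dim embedding.
226 |     # eval() keeps DropBlock inactive, so this also runs on CPU.
227 |     model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).eval()
228 |     x = torch.randn(2, 3, 84, 84)
229 |     print(model(x).shape)  # torch.Size([2, 640])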
--------------------------------------------------------------------------------
/save/readme.txt:
--------------------------------------------------------------------------------
1 | Saved model/checkpoint location
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os.path as osp
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from datasets.mini_imagenet import MiniImageNet
8 | from datasets.tiered_imagenet import TieredImageNet
9 | from datasets.cifarfs import CIFAR_FS
10 | from datasets.samplers import CategoriesSampler
11 | from models.convnet import Convnet
12 | from models.resnet import resnet12
13 | from utils import set_gpu, Averager, count_acc, euclidean_metric, seed_torch, compute_confidence_interval
14 |
15 |
16 | def final_evaluate(args):
17 | if args.dataset == 'mini':
18 | valset = MiniImageNet('test', args.size)
19 | elif args.dataset == 'tiered':
20 | valset = TieredImageNet('test', args.size)
21 | elif args.dataset == "cifarfs":
22 | valset = CIFAR_FS('test', args.size)
23 | else:
24 | print("Invalid dataset...")
25 | exit()
26 | val_sampler = CategoriesSampler(valset.label, args.test_batch,
27 | args.test_way, args.shot + args.test_query)
28 | loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
29 | num_workers=args.worker, pin_memory=True)
30 |
31 | if args.model == 'convnet':
32 | model = Convnet().cuda()
33 | print("=> Convnet architecture...")
34 | else:
35 | if args.dataset in ['mini', 'tiered']:
36 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda()
37 | else:
38 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda()
39 | print("=> Resnet architecture...")
40 |
41 | model.load_state_dict(torch.load(osp.join(args.save_path, 'max-acc.pth')))
42 | print("=> Model loaded...")
43 | model.eval()
44 |
45 | ave_acc = Averager()
46 | acc_list = []
47 |
48 | for i, batch in enumerate(loader, 1):
49 | data, _ = [_.cuda() for _ in batch]
50 | k = args.test_way * args.shot
51 | data_shot, data_query = data[:k], data[k:]
52 |
53 | x = model(data_shot)
54 |         x = x.reshape(args.shot, args.test_way, -1).mean(dim=0)  # sampler interleaves classes, so each row holds one image per class
55 | p = x
56 |
57 | logits = euclidean_metric(model(data_query), p)
58 |
59 | label = torch.arange(args.test_way).repeat(args.test_query)
60 | label = label.type(torch.cuda.LongTensor)
61 |
62 | acc = count_acc(logits, label)
63 | ave_acc.add(acc)
64 | acc_list.append(acc*100)
65 |
66 | x = None; p = None; logits = None
67 |
68 | a, b = compute_confidence_interval(acc_list)
69 | print("Final accuracy with 95% interval : {:.2f}±{:.2f}".format(a, b))
70 |
71 |
72 | if __name__ == '__main__':
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument('--shot', type=int, default=1)
75 | parser.add_argument('--test-query', type=int, default=15)
76 | parser.add_argument('--test-way', type=int, default=5)
77 | parser.add_argument('--save-path', default='')
78 | parser.add_argument('--gpu', default='0')
79 | parser.add_argument('--size', type=int, default=84)
80 | parser.add_argument('--test-batch', type=int, default=2000)
81 | parser.add_argument('--worker', type=int, default=8)
82 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet'])
83 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs'])
84 | args = parser.parse_args()
85 |
86 | start_time = datetime.datetime.now()
87 |
88 | # fix seed
89 | seed_torch(1)
90 | set_gpu(args.gpu)
91 |
92 | if args.dataset in ['mini', 'tiered']:
93 | args.size = 84
94 | elif args.dataset in ['cifarfs']:
95 | args.size = 32
96 | args.worker = 0
97 | else:
98 | args.size = 28
99 |
100 | final_evaluate(args)
101 |
102 | end_time = datetime.datetime.now()
103 | print("Total executed time :", end_time - start_time)
--------------------------------------------------------------------------------
/train_stage1.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os.path as osp
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 |
8 | from datasets.mini_imagenet import MiniImageNet, SSLMiniImageNet
9 | from datasets.tiered_imagenet import TieredImageNet, SSLTieredImageNet
10 | from datasets.cifarfs import CIFAR_FS, SSLCifarFS
11 | from datasets.samplers import CategoriesSampler
12 | from models.convnet import Convnet
13 | from models.resnet import resnet12
14 | from utils import set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, seed_torch, compute_confidence_interval
15 |
16 |
17 | def get_dataset(args):
18 | if args.dataset == 'mini':
19 | trainset = SSLMiniImageNet('train', args)
20 | valset = MiniImageNet('test', args.size)
21 | print("=> MiniImageNet...")
22 | elif args.dataset == 'tiered':
23 | trainset = SSLTieredImageNet('train', args)
24 | valset = TieredImageNet('test', args.size)
25 | print("=> TieredImageNet...")
26 | elif args.dataset == 'cifarfs':
27 | trainset = SSLCifarFS('train', args)
28 | valset = CIFAR_FS('test', args.size)
29 | print("=> CIFAR FS...")
30 | else:
31 | print("Invalid dataset...")
32 | exit()
33 |
34 | train_loader = DataLoader(dataset=trainset, batch_size=args.train_way * args.shot,
35 | shuffle=True, drop_last=True,
36 | num_workers=args.worker, pin_memory=True)
37 |
38 | val_sampler = CategoriesSampler(valset.label, args.test_batch,
39 | args.test_way, args.shot + args.test_query)
40 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
41 | num_workers=args.worker, pin_memory=True)
42 | return train_loader, val_loader
43 |
44 | def training(args):
45 | ensure_path(args.save_path)
46 |
47 | train_loader, val_loader = get_dataset(args)
48 |
49 | if args.model == 'convnet':
50 | model = Convnet().cuda()
51 | print("=> Convnet architecture...")
52 | else:
53 | if args.dataset in ['mini', 'tiered']:
54 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda()
55 | else:
56 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda()
57 | print("=> Resnet architecture...")
58 |
59 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
60 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
61 |
62 | def save_model(name):
63 | torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
64 |
65 | trlog = {}
66 | trlog['args'] = vars(args)
67 | trlog['train_loss'] = []
68 | trlog['val_loss'] = []
69 | trlog['train_acc'] = []
70 | trlog['val_acc'] = []
71 | trlog['max_acc'] = 0.0
72 |
73 | timer = Timer()
74 | best_epoch = 0
75 | cmi = [0.0, 0.0]
76 |
77 | for epoch in range(1, args.max_epoch + 1):
78 |
79 | tl, ta = train(args, model, train_loader, optimizer)
80 | lr_scheduler.step()
81 | vl, va, aa, bb = validate(args, model, val_loader)
82 |
83 | if va > trlog['max_acc']:
84 | trlog['max_acc'] = va
85 | save_model('max-acc')
86 | best_epoch = epoch
87 | cmi[0] = aa
88 | cmi[1] = bb
89 |
90 | trlog['train_loss'].append(tl)
91 | trlog['train_acc'].append(ta)
92 | trlog['val_loss'].append(vl)
93 | trlog['val_acc'].append(va)
94 |
95 | torch.save(trlog, osp.join(args.save_path, 'trlog'))
96 |
97 | save_model('epoch-last')
98 | ot, ots = timer.measure()
99 | tt, _ = timer.measure(epoch / args.max_epoch)
100 |
101 | print('Epoch {}/{}, train loss={:.4f} - acc={:.4f} - val loss={:.4f} - acc={:.4f} - max acc={:.4f} - ETA:{}/{}'.format(
102 | epoch, args.max_epoch, tl, ta, vl, va, trlog['max_acc'], ots, timer.tts(tt-ot)))
103 |
104 | if epoch == args.max_epoch:
105 | print("Best Epoch is {} with acc={:.2f}±{:.2f}%...".format(best_epoch, cmi[0], cmi[1]))
106 | print("---------------------------------------------------")
107 |
108 | def preprocess_data(data):
109 | for idxx, img in enumerate(data):
110 |         # img: (shot + train_query, 3, size, size), e.g. (4, 3, 84, 84); assumes shot=1, train_query=3
111 |         supportimg = img.data[0].unsqueeze(0)                    # identity view
112 |         x90 = img.data[1].unsqueeze(0).transpose(2,3).flip(2)    # 90-degree rotation
113 |         x180 = img.data[2].unsqueeze(0).flip(2).flip(3)          # 180-degree rotation
114 |         x270 = img.data[3].unsqueeze(0).flip(2).transpose(2,3)   # 270-degree rotation
115 | queryimg = torch.cat((x90, x180, x270), 0)
116 | queryimg = queryimg.unsqueeze(0)
117 | if idxx <= 0:
118 | # support
119 | dshot = supportimg
120 | # query
121 | dquery = queryimg
122 | else:
123 | dshot = torch.cat((dshot, supportimg), 0)
124 | dquery = torch.cat((dquery, queryimg), 0)
125 | dquery = torch.transpose(dquery, 0, 1)
126 | dquery = dquery.reshape(args.train_way*args.train_query, 3, args.size, args.size)
127 | return dshot.cuda(), dquery.cuda()
128 |
129 | def train(args, model, train_loader, optimizer):
130 | model.train()
131 |
132 | tl = Averager()
133 | ta = Averager()
134 |
135 | for i, batch in enumerate(train_loader, 1):
136 | data, _ = batch
137 | data_shot, data_query = preprocess_data(data['data'])
138 |
139 | proto = model(data_shot)
140 | proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
141 |
142 | label = torch.arange(args.train_way).repeat(args.train_query)
143 | label = label.type(torch.cuda.LongTensor)
144 |
145 | logits = euclidean_metric(model(data_query), proto)
146 | loss = F.cross_entropy(logits, label)
147 | acc = count_acc(logits, label)
148 |
149 | tl.add(loss.item())
150 | ta.add(acc)
151 |
152 | optimizer.zero_grad()
153 | loss.backward()
154 | optimizer.step()
155 |
156 | proto = None; logits = None; loss = None
157 |
158 | if (args.train_batch > 0) and (i >= args.train_batch):
159 | break
160 |
161 | return tl.item(), ta.item()
162 |
163 |
164 | def validate(args, model, val_loader):
165 | model.eval()
166 |
167 | vl = Averager()
168 | va = Averager()
169 | acc_list = []
170 |
171 | for i, batch in enumerate(val_loader, 1):
172 | data, _ = [_.cuda() for _ in batch]
173 | p = args.shot * args.test_way
174 | data_shot, data_query = data[:p], data[p:]
175 |
176 | proto = model(data_shot)
177 | proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)
178 |
179 | label = torch.arange(args.test_way).repeat(args.test_query)
180 | label = label.type(torch.cuda.LongTensor)
181 |
182 | logits = euclidean_metric(model(data_query), proto)
183 | loss = F.cross_entropy(logits, label)
184 | acc = count_acc(logits, label)
185 |
186 | vl.add(loss.item())
187 | va.add(acc)
188 | acc_list.append(acc*100)
189 |
190 | proto = None; logits = None; loss = None
191 |
192 | a, b = compute_confidence_interval(acc_list)
193 | return vl.item(), va.item(), a, b
194 |
195 |
196 | if __name__ == '__main__':
197 | parser = argparse.ArgumentParser()
198 | parser.add_argument('--max-epoch', type=int, default=200)
199 | parser.add_argument('--shot', type=int, default=1)
200 | parser.add_argument('--train-query', type=int, default=3)
201 | parser.add_argument('--test-query', type=int, default=15)
202 | parser.add_argument('--train-way', type=int, default=50)
203 | parser.add_argument('--test-way', type=int, default=5)
204 | parser.add_argument('--save-path', default='')
205 | parser.add_argument('--gpu', default='0')
206 | parser.add_argument('--size', type=int, default=84)
207 | parser.add_argument('--lr', type=float, default=0.001)
208 | parser.add_argument('--wd', type=float, default=0.001)
209 | parser.add_argument('--step-size', type=int, default=20)
210 | parser.add_argument('--train-batch', type=int, default=-1)
211 | parser.add_argument('--test-batch', type=int, default=2000)
212 | parser.add_argument('--worker', type=int, default=8)
213 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet'])
214 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs'])
215 | args = parser.parse_args()
216 |
217 | start_time = datetime.datetime.now()
218 |
219 | # fix seed
220 | seed_torch(1)
221 | set_gpu(args.gpu)
222 |
223 | if args.dataset in ['mini', 'tiered']:
224 | args.size = 84
225 | elif args.dataset in ['cifarfs']:
226 | args.size = 32
227 | args.worker = 0
228 | else:
229 | args.size = 28
230 |
231 | training(args)
232 |
233 | end_time = datetime.datetime.now()
234 | print("Total executed time :", end_time - start_time)
235 |
236 |
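237 | def check_rotations():
238 |     # Sanity check (illustrative, not used by the script): the
239 |     # transpose/flip pairs in preprocess_data are 90/180/270-degree
240 |     # rotations of the (H, W) plane, i.e. torch.rot90 with k = 1, 2, 3.
241 |     x = torch.randn(1, 3, 8, 8)
242 |     assert torch.equal(x.transpose(2, 3).flip(2), torch.rot90(x, 1, (2, 3)))
243 |     assert torch.equal(x.flip(2).flip(3), torch.rot90(x, 2, (2, 3)))
244 |     assert torch.equal(x.flip(2).transpose(2, 3), torch.rot90(x, 3, (2, 3)))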
--------------------------------------------------------------------------------
/train_stage2.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os.path as osp
4 | import torch
5 | import torch.nn.functional as F
6 | from torch.utils.data import DataLoader
7 |
8 | from datasets.mini_imagenet import MiniImageNet
9 | from datasets.tiered_imagenet import TieredImageNet
10 | from datasets.cifarfs import CIFAR_FS
11 | from datasets.samplers import CategoriesSampler
12 | from models.convnet import Convnet
13 | from models.resnet import resnet12
14 | from utils import set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, seed_torch, compute_confidence_interval
15 |
16 | def get_dataset(args):
17 | if args.dataset == 'mini':
18 | trainset = MiniImageNet('train', args.size)
19 | valset = MiniImageNet('test', args.size)
20 | print("=> MiniImageNet...")
21 | elif args.dataset == 'tiered':
22 | trainset = TieredImageNet('train', args.size)
23 | valset = TieredImageNet('test', args.size)
24 | print("=> TieredImageNet...")
25 | elif args.dataset == 'cifarfs':
26 | trainset = CIFAR_FS('train', args.size)
27 | valset = CIFAR_FS('test', args.size)
28 | print("=> CIFAR FS...")
29 | else:
30 | print("Invalid dataset...")
31 | exit()
32 | train_sampler = CategoriesSampler(trainset.label, args.train_batch,
33 | args.train_way, args.shot + args.train_query)
34 | train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
35 | num_workers=args.worker, pin_memory=True)
36 |
37 | val_sampler = CategoriesSampler(valset.label, args.test_batch,
38 | args.test_way, args.shot + args.test_query)
39 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
40 | num_workers=args.worker, pin_memory=True)
41 | return train_loader, val_loader
42 |
43 | def training(args):
44 | ensure_path(args.save_path)
45 |
46 | train_loader, val_loader = get_dataset(args)
47 |
48 | if args.model == 'convnet':
49 | model = Convnet().cuda()
50 | print("=> Convnet architecture...")
51 | else:
52 |         if args.dataset in ['mini', 'tiered']:
53 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda()
54 | print("=> Large block resnet architecture...")
55 | else:
56 | model = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda()
57 | print("=> Small block resnet architecture...")
58 |
59 | if args.stage1_path:
60 | model.load_state_dict(torch.load(osp.join(args.stage1_path, 'max-acc.pth')))
61 | print("=> Pretrain model loaded...")
62 |
63 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
64 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
65 |
66 | def save_model(name):
67 | torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
68 |
69 | trlog = {}
70 | trlog['args'] = vars(args)
71 | trlog['train_loss'] = []
72 | trlog['val_loss'] = []
73 | trlog['train_acc'] = []
74 | trlog['val_acc'] = []
75 | trlog['max_acc'] = 0.0
76 |
77 | timer = Timer()
78 | best_epoch = 0
79 | cmi = [0.0, 0.0]
80 |
81 | for epoch in range(1, args.max_epoch + 1):
82 |
83 | tl, ta = train(args, model, train_loader, optimizer)
84 | #
85 | lr_scheduler.step()
86 |
87 | vl, va, aa, bb = validate(args, model, val_loader)
88 |
89 | if va > trlog['max_acc']:
90 | trlog['max_acc'] = va
91 | save_model('max-acc')
92 | best_epoch = epoch
93 | cmi[0] = aa
94 | cmi[1] = bb
95 |
96 | trlog['train_loss'].append(tl)
97 | trlog['train_acc'].append(ta)
98 | trlog['val_loss'].append(vl)
99 | trlog['val_acc'].append(va)
100 |
101 | torch.save(trlog, osp.join(args.save_path, 'trlog'))
102 |
103 | save_model('epoch-last')
104 | ot, ots = timer.measure()
105 | tt, _ = timer.measure(epoch / args.max_epoch)
106 |
107 | print('Epoch {}/{}, train loss={:.4f} - acc={:.4f} - val loss={:.4f} - acc={:.4f} - max acc={:.4f} - ETA:{}/{}'.format(
108 | epoch, args.max_epoch, tl, ta, vl, va, trlog['max_acc'], ots, timer.tts(tt-ot)))
109 |
110 | if epoch == args.max_epoch:
111 | print("Best Epoch is {} with acc={:.2f}±{:.2f}%...".format(best_epoch, cmi[0], cmi[1]))
112 | print("---------------------------------------------------")
113 |
114 | def ssl_loss(args, model, data_shot):
115 |     # build rotated copies (90/180/270 degrees) of the support images as SSL queries
116 | x_90 = data_shot.transpose(2,3).flip(2)
117 | x_180 = data_shot.flip(2).flip(3)
118 | x_270 = data_shot.flip(2).transpose(2,3)
119 | data_query = torch.cat((x_90, x_180, x_270),0)
120 |
121 | proto = model(data_shot)
122 |     proto = proto.reshape(1, args.train_way*args.shot, -1).mean(dim=0)  # each support image is its own "class" prototype
123 | query = model(data_query)
124 |
125 | label = torch.arange(args.train_way*args.shot).repeat(args.pre_query)
126 | label = label.type(torch.cuda.LongTensor)
127 |
128 | logits = euclidean_metric(query, proto)
129 | loss = F.cross_entropy(logits, label)
130 |
131 | return loss
132 |
133 | def train(args, model, train_loader, optimizer):
134 | model.train()
135 |
136 | tl = Averager()
137 | ta = Averager()
138 |
139 | for i, batch in enumerate(train_loader, 1):
140 | data, _ = [_.cuda() for _ in batch]
141 | p = args.shot * args.train_way
142 | data_shot, data_query = data[:p], data[p:]
143 |
144 | proto = model(data_shot) # (30, 1600)
145 | proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
146 | query = model(data_query)
147 |
148 | label = torch.arange(args.train_way).repeat(args.train_query)
149 | label = label.type(torch.cuda.LongTensor)
150 |
151 | logits = euclidean_metric(query, proto)
152 | loss_ss = ssl_loss(args, model, data_shot)
153 | loss = F.cross_entropy(logits, label) + args.beta * loss_ss
154 | acc = count_acc(logits, label)
155 |
156 | tl.add(loss.item())
157 | ta.add(acc)
158 |
159 | optimizer.zero_grad()
160 | loss.backward()
161 | optimizer.step()
162 |
163 | proto = None; query = None; logits = None; loss = None
164 |
165 | return tl.item(), ta.item()
166 |
167 | def validate(args, model, val_loader):
168 | model.eval()
169 |
170 | vl = Averager()
171 | va = Averager()
172 | acc_list = []
173 |
174 | for i, batch in enumerate(val_loader, 1):
175 | data, _ = [_.cuda() for _ in batch]
176 | p = args.shot * args.test_way
177 | data_shot, data_query = data[:p], data[p:]
178 |
179 | proto = model(data_shot)
180 | proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)
181 | query = model(data_query)
182 |
183 | label = torch.arange(args.test_way).repeat(args.test_query)
184 | label = label.type(torch.cuda.LongTensor)
185 |
186 | logits = euclidean_metric(query, proto)
187 | loss = F.cross_entropy(logits, label)
188 | acc = count_acc(logits, label)
189 |
190 | vl.add(loss.item())
191 | va.add(acc)
192 | acc_list.append(acc*100)
193 |
194 | proto = None; query = None; logits = None; loss = None
195 | a,b = compute_confidence_interval(acc_list)
196 | return vl.item(), va.item(), a, b
197 |
198 |
199 | if __name__ == '__main__':
200 | parser = argparse.ArgumentParser()
201 | parser.add_argument('--max-epoch', type=int, default=200)
202 | parser.add_argument('--shot', type=int, default=1)
203 | parser.add_argument('--pre-query', type=int, default=3)
204 | parser.add_argument('--train-query', type=int, default=15)
205 | parser.add_argument('--test-query', type=int, default=15)
206 | parser.add_argument('--train-way', type=int, default=5)
207 | parser.add_argument('--test-way', type=int, default=5)
208 | parser.add_argument('--save-path', default='')
209 | parser.add_argument('--gpu', default='0')
210 | parser.add_argument('--size', type=int, default=84)
211 | parser.add_argument('--lr', type=float, default=0.001)
212 | parser.add_argument('--wd', type=float, default=0.001)
213 | parser.add_argument('--step-size', type=int, default=20)
214 | parser.add_argument('--train-batch', type=int, default=100)
215 | parser.add_argument('--test-batch', type=int, default=2000)
216 | parser.add_argument('--worker', type=int, default=8)
217 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet'])
218 | parser.add_argument('--mode', type=int, default=0, choices=[0,1])
219 | parser.add_argument('--stage1-path', default='')
220 | parser.add_argument('--beta', type=float, default=0.1)
221 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs'])
222 | args = parser.parse_args()
223 |
224 | start_time = datetime.datetime.now()
225 |
226 | # fix seed
227 | seed_torch(1)
228 | set_gpu(args.gpu)
229 |
230 | if args.dataset in ['mini', 'tiered']:
231 | args.size = 84
232 | elif args.dataset in ['cifarfs']:
233 | args.size = 32
234 | args.worker = 0
235 | else:
236 | args.size = 28
237 |
238 | training(args)
239 |
240 | end_time = datetime.datetime.now()
241 | print("Total executed time :", end_time - start_time)
242 |
243 |
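244 | def check_ssl_labels(n=5, pre_query=3):
245 |     # Sanity check (illustrative, not used by the script): ssl_loss builds
246 |     # queries as torch.cat((x_90, x_180, x_270)), i.e. all 90-degree views
247 |     # of supports 0..n-1, then all 180-degree, then all 270-degree views.
248 |     # torch.arange(n).repeat(pre_query) yields the matching targets
249 |     # [0..n-1, 0..n-1, 0..n-1], classifying each rotation to its source image.
250 |     return torch.arange(n).repeat(pre_query)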
--------------------------------------------------------------------------------
/train_stage3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import os.path as osp
4 | import copy
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.utils.data import DataLoader
8 |
9 | from datasets.mini_imagenet import MiniImageNet
10 | from datasets.tiered_imagenet import TieredImageNet
11 | from datasets.cifarfs import CIFAR_FS
12 | from datasets.samplers import CategoriesSampler
13 | from models.convnet import Convnet
14 | from models.distill import DistillKL, HintLoss
15 | from models.resnet import resnet12
16 | from utils import set_gpu, ensure_path, Averager, Timer, count_acc, euclidean_metric, seed_torch, compute_confidence_interval
17 |
18 |
19 | def get_dataset(args):
20 | if args.dataset == 'mini':
21 | trainset = MiniImageNet('train', args.size)
22 | valset = MiniImageNet('test', args.size)
23 | print("=> MiniImageNet...")
24 | elif args.dataset == 'tiered':
25 | trainset = TieredImageNet('train', args.size)
26 | valset = TieredImageNet('test', args.size)
27 | print("=> TieredImageNet...")
28 | elif args.dataset == 'cifarfs':
29 | trainset = CIFAR_FS('train', args.size)
30 | valset = CIFAR_FS('test', args.size)
31 | print("=> CIFAR FS...")
32 | else:
33 | print("Invalid dataset...")
34 | exit()
35 | train_sampler = CategoriesSampler(trainset.label, args.train_batch,
36 | args.train_way, args.shot + args.train_query)
37 | train_loader = DataLoader(dataset=trainset, batch_sampler=train_sampler,
38 | num_workers=args.worker, pin_memory=True)
39 |
40 | val_sampler = CategoriesSampler(valset.label, args.test_batch,
41 | args.test_way, args.shot + args.test_query)
42 | val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
43 | num_workers=args.worker, pin_memory=True)
44 | return train_loader, val_loader
45 |
46 | def training(args):
47 | ensure_path(args.save_path)
48 |
49 | train_loader, val_loader = get_dataset(args)
50 |
51 | if args.model == 'convnet':
52 | teacher = Convnet().cuda()
53 | print("=> Convnet architecture...")
54 | else:
55 | if args.dataset in ['mini', 'tiered']:
56 | teacher = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=5).cuda()
57 | else:
58 | teacher = resnet12(avg_pool=True, drop_rate=0.1, dropblock_size=2).cuda()
59 | print("=> Resnet architecture...")
60 |
61 | if args.kd_mode != 0:
62 |         # produce a student model with the same structure as the teacher model, without its knowledge
63 | model = copy.deepcopy(teacher)
64 | if args.stage1_path:
65 | model.load_state_dict(torch.load(osp.join(args.stage1_path, 'max-acc.pth')))
66 | print("=> Student loaded with pretrain knowledge...")
67 |
68 | teacher.load_state_dict(torch.load(osp.join(args.stage2_path, 'max-acc.pth')))
69 | print("=> Teacher model loaded...")
70 |
71 | if args.kd_mode == 0:
72 |         # initialize the student with the same knowledge as the teacher
73 | model = copy.deepcopy(teacher)
74 | print("=> Student obtain teacher's knowledge...")
75 |
76 | if args.kd_type == 'kd':
77 | criterion_kd = DistillKL(args.temperature).cuda()
78 | else:
79 | criterion_kd = HintLoss().cuda()
80 |
81 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
82 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)
83 |
84 | def save_model(name):
85 | torch.save(model.state_dict(), osp.join(args.save_path, name + '.pth'))
86 |
87 | trlog = {}
88 | trlog['args'] = vars(args)
89 | trlog['train_loss'] = []
90 | trlog['val_loss'] = []
91 | trlog['train_acc'] = []
92 | trlog['val_acc'] = []
93 | trlog['max_acc'] = 0.0
94 |
95 | timer = Timer()
96 | best_epoch = 0
97 | cmi = [0.0, 0.0]
98 |
99 | for epoch in range(1, args.max_epoch + 1):
100 |
101 | tl, ta = train(args, teacher, model, train_loader, optimizer, criterion_kd)
102 | lr_scheduler.step()
103 | vl, va, aa, bb = validate(args, model, val_loader)
104 |
105 | if va > trlog['max_acc']:
106 | trlog['max_acc'] = va
107 | save_model('max-acc')
108 | best_epoch = epoch
109 | cmi[0] = aa
110 | cmi[1] = bb
111 |
112 | trlog['train_loss'].append(tl)
113 | trlog['train_acc'].append(ta)
114 | trlog['val_loss'].append(vl)
115 | trlog['val_acc'].append(va)
116 |
117 | torch.save(trlog, osp.join(args.save_path, 'trlog'))
118 |
119 | save_model('epoch-last')
120 | ot, ots = timer.measure()
121 | tt, _ = timer.measure(epoch / args.max_epoch)
122 |
123 | print('Epoch {}/{}, train loss={:.4f} - acc={:.4f} - val loss={:.4f} - acc={:.4f} - max acc={:.4f} - ETA:{}/{}'.format(
124 | epoch, args.max_epoch, tl, ta, vl, va, trlog['max_acc'], ots, timer.tts(tt-ot)))
125 |
126 | if epoch == args.max_epoch:
127 | print("Best Epoch is {} with acc={:.2f}±{:.2f}%...".format(best_epoch, cmi[0], cmi[1]))
128 | print("---------------------------------------------------")
129 |
130 | def ssl_loss(args, model, data_shot):
131 |     # build rotated copies (90/180/270 degrees) of the support images as SSL queries
132 | x_90 = data_shot.transpose(2,3).flip(2)
133 | x_180 = data_shot.flip(2).flip(3)
134 | x_270 = data_shot.flip(2).transpose(2,3)
135 | data_query = torch.cat((x_90, x_180, x_270),0)
136 |
137 | proto = model(data_shot)
138 | proto = proto.reshape(1, args.shot*args.train_way, -1).mean(dim=0)
139 |
140 | label = torch.arange(args.train_way * args.shot).repeat(args.pre_query)
141 | label = label.type(torch.cuda.LongTensor)
142 |
143 | logits = euclidean_metric(model(data_query), proto)
144 | loss = F.cross_entropy(logits, label)
145 |
146 | return loss
147 |
148 | def train(args, teacher, model, train_loader, optimizer, criterion_kd):
149 | teacher.eval()
150 | model.train()
151 |
152 | tl = Averager()
153 | ta = Averager()
154 |
155 | for i, batch in enumerate(train_loader, 1):
156 | data, _ = [_.cuda() for _ in batch]
157 | p = args.shot * args.train_way
158 | data_shot, data_query = data[:p], data[p:] # datashot (30, 3, 84, 84)
159 |
160 | # teacher
161 | with torch.no_grad():
162 | tproto = teacher(data_shot)
163 | ft = tproto
164 |             ft = [f.detach() for f in ft]  # list of per-image teacher features; HintLoss uses ft[-1]
165 | tproto = tproto.reshape(args.shot, args.train_way, -1).mean(dim=0)
166 | # soft target from teacher
167 | tlogits = euclidean_metric(teacher(data_query), tproto)
168 |
169 | proto = model(data_shot) # (30, 1600)
170 | fs = proto
171 | proto = proto.reshape(args.shot, args.train_way, -1).mean(dim=0)
172 |
173 | label = torch.arange(args.train_way).repeat(args.train_query)
174 | label = label.type(torch.cuda.LongTensor)
175 |
176 | logits = euclidean_metric(model(data_query), proto)
177 | acc = count_acc(logits, label)
178 |
179 | if args.kd_mode != 0:
180 | # few-shot loss from student
181 | clsloss = F.cross_entropy(logits, label)
182 |
183 | # distillation loss
184 | if args.kd_type == 'kd':
185 | kdloss = criterion_kd(logits, tlogits)
186 | else:
187 | kdloss = criterion_kd(fs[-1], ft[-1])
188 |
189 | # self-supervised loss signal
190 | loss_ss = ssl_loss(args, model, data_shot)
191 |
192 | if args.kd_mode != 0:
193 | loss = ((1.0 - args.kd_coef) * clsloss) + (args.kd_coef * kdloss) + (args.ssl_coef * loss_ss)
194 | else:
195 | loss = kdloss + (args.ssl_coef * loss_ss)
196 |
197 | tl.add(loss.item())
198 | ta.add(acc)
199 |
200 | optimizer.zero_grad()
201 | loss.backward()
202 | optimizer.step()
203 |
204 | proto = None; logits = None; loss = None
205 |
206 | return tl.item(), ta.item()
207 |
208 | def validate(args, model, val_loader):
209 | model.eval()
210 |
211 | vl = Averager()
212 | va = Averager()
213 | acc_list = []
214 |
215 | for i, batch in enumerate(val_loader, 1):
216 | data, _ = [_.cuda() for _ in batch]
217 | p = args.shot * args.test_way
218 | data_shot, data_query = data[:p], data[p:]
219 |
220 | proto = model(data_shot)
221 | proto = proto.reshape(args.shot, args.test_way, -1).mean(dim=0)
222 |
223 | label = torch.arange(args.test_way).repeat(args.test_query)
224 | label = label.type(torch.cuda.LongTensor)
225 |
226 | logits = euclidean_metric(model(data_query), proto)
227 | loss = F.cross_entropy(logits, label)
228 | acc = count_acc(logits, label)
229 |
230 | vl.add(loss.item())
231 | va.add(acc)
232 | acc_list.append(acc*100)
233 |
234 | proto = None; logits = None; loss = None
235 | a,b = compute_confidence_interval(acc_list)
236 | return vl.item(), va.item(), a, b
237 |
238 |
239 | if __name__ == '__main__':
240 | parser = argparse.ArgumentParser()
241 | parser.add_argument('--max-epoch', type=int, default=200)
242 | parser.add_argument('--shot', type=int, default=1)
243 | parser.add_argument('--pre-query', type=int, default=3) # for self-supervised process: the number of query image generated based on support image
244 | parser.add_argument('--train-query', type=int, default=15)
245 | parser.add_argument('--test-query', type=int, default=15)
246 | parser.add_argument('--train-way', type=int, default=5)
247 | parser.add_argument('--test-way', type=int, default=5)
248 | parser.add_argument('--save-path', default='')
249 | parser.add_argument('--gpu', default='0')
250 | parser.add_argument('--size', type=int, default=84)
251 | parser.add_argument('--lr', type=float, default=0.001)
252 | parser.add_argument('--wd', type=float, default=0.001)
253 | parser.add_argument('--step-size', type=int, default=20)
254 | parser.add_argument('--train-batch', type=int, default=100)
255 | parser.add_argument('--test-batch', type=int, default=2000)
256 | parser.add_argument('--worker', type=int, default=8)
257 | parser.add_argument('--model', type=str, default='convnet', choices=['convnet', 'resnet'])
258 | parser.add_argument('--dataset', type=str, default='mini', choices=['mini','tiered','cifarfs'])
259 | parser.add_argument('--ssl-coef', type=float, default=0.1, help='The beta coefficient for self-supervised loss')
260 | # self-distillation stage parameter
261 | parser.add_argument('--temperature', type=int, default=4)
262 | parser.add_argument('--kd-coef', type=float, default=0.1, help="The gamma coefficient for distillation loss")
263 | # 0: copy teacher and only KD 1: common KD
264 | parser.add_argument('--kd-mode', type=int, default=1, choices=[0,1])
265 | parser.add_argument('--kd-type', type=str, default='kd', choices=['kd', 'hint'])
266 | parser.add_argument('--stage1-path', default='')
267 | parser.add_argument('--stage2-path', default='')
268 | args = parser.parse_args()
269 |
270 | start_time = datetime.datetime.now()
271 |
272 | # fix seed
273 | seed_torch(1)
274 | set_gpu(args.gpu)
275 |
276 | if args.dataset in ['mini', 'tiered']:
277 | args.size = 84
278 | elif args.dataset in ['cifarfs']:
279 | args.size = 32
280 | args.worker = 0
281 | else:
282 | args.size = 28
283 |
284 | training(args)
285 |
286 | end_time = datetime.datetime.now()
287 | print("Total executed time :", end_time - start_time)
288 |
289 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import random
5 | import numpy as np
6 |
7 | import torch
8 |
9 | def seed_torch(seed=1337):
10 | random.seed(seed)
11 | os.environ['PYTHONHASHSEED'] = str(seed)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 | torch.cuda.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
16 | torch.backends.cudnn.benchmark = False
17 | torch.backends.cudnn.deterministic = True
18 |
19 |
20 | def set_gpu(x):
21 | os.environ['CUDA_VISIBLE_DEVICES'] = x
22 | print('using gpu:', x)
23 |
24 |
25 | def ensure_path(path):
26 | if os.path.exists(path):
27 | #if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
28 | shutil.rmtree(path)
29 | os.makedirs(path)
30 | else:
31 | os.makedirs(path)
32 |
33 |
34 | class Averager():
35 |
36 | def __init__(self):
37 | self.n = 0
38 | self.v = 0
39 |
40 | def add(self, x):
41 | self.v = (self.v * self.n + x) / (self.n + 1)
42 | self.n += 1
43 |
44 | def item(self):
45 | return self.v
46 |
47 |
48 | def count_acc(logits, label):
49 | pred = torch.argmax(logits, dim=1)
50 | return (pred == label).type(torch.cuda.FloatTensor).mean().item()
51 |
52 |
53 | def dot_metric(a, b):
54 | return torch.mm(a, b.t())
55 |
56 | import torch.nn.functional as F
57 | def cos_metric(a, b):
58 | return torch.mm(F.normalize(a, dim=-1), F.normalize(b, dim=-1).t())
59 |
60 | def euclidean_metric(a, b):
61 | n = a.shape[0]
62 | m = b.shape[0]
63 | a = a.unsqueeze(1).expand(n, m, -1)
64 | b = b.unsqueeze(0).expand(n, m, -1)
65 | logits = -((a - b)**2).sum(dim=2)
66 | return logits
67 |
68 |
69 | class Timer():
70 |
71 | def __init__(self):
72 | self.o = time.time()
73 |
74 | def measure(self, p=1):
75 | x = (time.time() - self.o) / p
76 | x = int(x)
77 | return x, self.tts(x)
78 |
79 | def tts(self, x=0):
80 | if x >= 3600:
81 | return '{:.1f}h'.format(x / 3600)
82 | if x >= 60:
83 | return '{}m'.format(round(x / 60))
84 | return '{}s'.format(x)
85 |
86 |
87 | def compute_confidence_interval(data):
88 | """
89 | Compute 95% confidence interval
90 | :param data: An array of mean accuracy (or mAP) across a number of sampled episodes.
91 | :return: the 95% confidence interval for this data.
92 | """
93 | a = 1.0 * np.array(data)
94 | m = np.mean(a)
95 | std = np.std(a)
96 | pm = 1.96 * (std / np.sqrt(len(a)))
97 | return m, pm
98 |
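99 | if __name__ == '__main__':
100 |     # Usage sketch (illustrative, not part of the original file):
101 |     # euclidean_metric returns negative squared distances, so argmax
102 |     # over columns picks the nearest of the 5 prototypes.
103 |     q = torch.randn(15, 64)      # 15 query embeddings
104 |     protos = torch.randn(5, 64)  # 5 class prototypes
105 |     print(euclidean_metric(q, protos).shape)  # torch.Size([15, 5])
106 |     m, pm = compute_confidence_interval([61.2, 63.5, 58.9, 60.4])
107 |     print('{:.2f}±{:.2f}'.format(m, pm))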
--------------------------------------------------------------------------------