├── CNN.py ├── LICENSE ├── OS-CNN.py ├── README.md ├── _config.yml └── cross-entropy loss and accuracy.png /CNN.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | import os 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 5 | import time 6 | import torch 7 | import logging 8 | import argparse 9 | import torchvision 10 | import torch.nn as nn 11 | import numpy as np 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | import torch.backends.cudnn as cudnn 16 | from torch.nn.modules.module import Module 17 | from torch.nn.parameter import Parameter 18 | import torchvision.transforms as transforms 19 | from itertools import combinations, permutations 20 | #from utils import progress_bar 21 | logging.basicConfig(level=logging.INFO) 22 | parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Training') 23 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 24 | args = parser.parse_args() 25 | logging.info(args) 26 | 27 | store_name = "CNN" 28 | nb_epoch = 400 29 | # setup output 30 | 31 | 32 | 33 | use_cuda = torch.cuda.is_available() 34 | 35 | 36 | # Data 37 | print('==> Preparing data..') 38 | transform_train = transforms.Compose([ 39 | transforms.RandomCrop(32, padding=4), 40 | transforms.RandomHorizontalFlip(), 41 | transforms.ToTensor(), 42 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 43 | ]) 44 | 45 | 46 | 47 | transform_test = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 50 | ]) 51 | 52 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 53 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8) 54 | 55 | 56 | 57 | testset = 
torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 58 | testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=8) 59 | 60 | 61 | 62 | cfg = { 63 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 64 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 65 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 66 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 67 | } 68 | 69 | 70 | 71 | class VGG(nn.Module): 72 | def __init__(self, vgg_name): 73 | super(VGG, self).__init__() 74 | self.features = self._make_layers(cfg[vgg_name]) 75 | 76 | 77 | self.classifier = nn.Sequential( 78 | nn.Linear(512,256), 79 | nn.Linear(256, 100) 80 | ) 81 | 82 | def forward(self, x): 83 | out = self.features(x) 84 | out = out.view(out.size(0), -1) 85 | out = self.classifier(out) 86 | return out 87 | 88 | def _make_layers(self, cfg): 89 | layers = [] 90 | in_channels = 3 91 | for x in cfg: 92 | if x == 'M': 93 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 94 | else: 95 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 96 | nn.BatchNorm2d(x), 97 | nn.ReLU(inplace=True)] 98 | in_channels = x 99 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 100 | return nn.Sequential(*layers) 101 | 102 | # Model 103 | 104 | 105 | print('==> Building model..') 106 | 107 | 108 | 109 | net = VGG('VGG16') 110 | 111 | 112 | if use_cuda: 113 | net.cuda() 114 | 115 | cudnn.benchmark = True 116 | 117 | 118 | criterion = nn.CrossEntropyLoss() 119 | 120 | 121 | def train(epoch): 122 | print('\nEpoch: %d' % epoch) 123 | net.train() 124 | train_loss = 0 125 | correct = 0 126 | total = 0 127 | idx = 0 128 | 129 | 130 | for batch_idx, (inputs, targets) in enumerate(trainloader): 131 | pass 132 | idx = batch_idx 133 | if use_cuda: 134 | 
inputs, targets = inputs.cuda(), targets.cuda() 135 | optimizer.zero_grad() 136 | inputs, targets = Variable(inputs), Variable(targets) 137 | outputs = net(inputs) 138 | 139 | 140 | loss = criterion(outputs, targets) 141 | 142 | 143 | loss.backward() 144 | optimizer.step() 145 | 146 | train_loss += loss.item() 147 | _, predicted = torch.max(outputs.data, 1) 148 | total += targets.size(0) 149 | correct += predicted.eq(targets.data).cpu().sum().item() 150 | 151 | train_acc = 100.*correct/total 152 | train_loss = train_loss/(idx+1) 153 | logging.info('Iteration %d, train_acc = %.5f,train_loss = %.6f' % (epoch, train_acc,train_loss)) 154 | 155 | 156 | def test(epoch): 157 | net.eval() 158 | test_loss = 0 159 | correct = 0 160 | total = 0 161 | idx = 0 162 | for batch_idx, (inputs, targets) in enumerate(testloader): 163 | with torch.no_grad(): 164 | idx = batch_idx 165 | if use_cuda: 166 | inputs, targets = inputs.cuda(), targets.cuda() 167 | inputs, targets = Variable(inputs), Variable(targets) 168 | outputs = net(inputs) 169 | 170 | loss = criterion(outputs, targets) 171 | 172 | test_loss += loss.item() 173 | _, predicted = torch.max(outputs.data, 1) 174 | total += targets.size(0) 175 | correct += predicted.eq(targets.data).cpu().sum().item() 176 | 177 | 178 | test_acc = 100.*correct/total 179 | test_loss = test_loss/(idx+1) 180 | logging.info('Iteration %d, test_acc = %.4f,test_loss = %.4f' % (epoch, test_acc,test_loss)) 181 | return test_acc 182 | 183 | 184 | 185 | def cosine_anneal_schedule(t): 186 | cos_inner = np.pi * (t % (nb_epoch )) # t - 1 is used when t has 1-based indexing. 
187 | cos_inner /= (nb_epoch ) 188 | cos_out = np.cos(cos_inner) + 1 189 | return float(args.lr / 2 * cos_out) 190 | 191 | 192 | 193 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 194 | 195 | max_val_acc = 0 196 | for epoch in range(nb_epoch): 197 | lr = cosine_anneal_schedule(epoch) 198 | for param_group in optimizer.param_groups: 199 | print(param_group['lr']) 200 | param_group['lr'] = lr 201 | train(epoch) 202 | test_acc = test(epoch) 203 | 204 | if test_acc >max_val_acc: 205 | max_val_acc = test_acc 206 | print("max_val_acc", max_val_acc) 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Dongliang Chang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /OS-CNN.py: -------------------------------------------------------------------------------- 1 | '''Train CIFAR10 with PyTorch.''' 2 | from __future__ import print_function 3 | import os 4 | os.environ["CUDA_VISIBLE_DEVICES"] = "2" 5 | import time 6 | import math 7 | import torch 8 | import logging 9 | import argparse 10 | import torchvision 11 | # from models import * 12 | import torch.nn as nn 13 | import numpy as np 14 | import torch.optim as optim 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | import torch.backends.cudnn as cudnn 18 | from torch.nn.modules.module import Module 19 | from torch.nn.parameter import Parameter 20 | import torchvision.transforms as transforms 21 | from itertools import combinations, permutations 22 | #from utils import progress_bar 23 | logging.basicConfig(level=logging.INFO) 24 | parser = argparse.ArgumentParser(description='PyTorch CIFAR100 Training') 25 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate') 26 | args = parser.parse_args() 27 | logging.info(args) 28 | 29 | store_name = "OS-CNN" 30 | nb_epoch = 400 31 | # setup output 32 | 33 | 34 | use_cuda = torch.cuda.is_available() 35 | 36 | 37 | # Data 38 | print('==> Preparing data..') 39 | transform_train = transforms.Compose([ 40 | transforms.RandomCrop(32, padding=4), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 44 | ]) 45 | 46 | 47 | 48 | transform_test = transforms.Compose([ 49 | transforms.ToTensor(), 50 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 51 | ]) 52 | 53 | trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train) 54 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=8) 55 | 56 | 57 | 58 | 
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test) 59 | testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=8) 60 | 61 | 62 | def design_tensor_C(previous_hidden_num =256,next_hidden_num=100,classes=100): 63 | tensor_C = np.zeros(previous_hidden_num*next_hidden_num).reshape(previous_hidden_num,next_hidden_num) 64 | 65 | top_left_nums = int(math.floor(previous_hidden_num / classes)) 66 | column =top_left_nums 67 | row = int(math.floor(next_hidden_num /classes)) 68 | top_left = [[i * column, i * row] for i in range(classes)] 69 | 70 | remainder_1 = previous_hidden_num % classes 71 | remainder_2 = next_hidden_num % classes 72 | 73 | base_matrix = [] 74 | for i in range(column): 75 | for j in range(row): 76 | base_matrix.append([i,j]) 77 | base_matrix_1 = np.array(base_matrix) 78 | 79 | base_matrix = [] 80 | for i in range(column+remainder_1): 81 | for j in range(row+remainder_2): 82 | base_matrix.append([i,j]) 83 | base_matrix_2 = np.array(base_matrix) 84 | 85 | 86 | matrix_one_1 = [(base_matrix_1 + i).tolist() for i in top_left[:-1]] 87 | 88 | matrix_one_1_1 = [] 89 | for item in matrix_one_1: 90 | matrix_one_1_1 = matrix_one_1_1 + item 91 | 92 | matrix_one_2 = (base_matrix_2 + top_left[-1]).tolist() 93 | matrix_one = matrix_one_1_1 + matrix_one_2 94 | 95 | for item in range(len(matrix_one)): 96 | tensor_C[matrix_one[item][0],matrix_one[item][1]] = 1 97 | 98 | tensor_C = Variable(torch.from_numpy(tensor_C.astype("float32")).cuda()) 99 | return tensor_C 100 | 101 | tensor_C = design_tensor_C() 102 | 103 | class OS_Linear(Module): 104 | 105 | def __init__(self, in_features, out_features, bias=True): 106 | super(OS_Linear, self).__init__() 107 | self.in_features = in_features 108 | self.out_features = out_features 109 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 110 | if bias: 111 | self.bias = Parameter(torch.Tensor(out_features)) 112 | else: 113 | 
self.register_parameter('bias', None) 114 | self.reset_parameters() 115 | 116 | def reset_parameters(self): 117 | stdv = 1. / math.sqrt(self.weight.size(1)) 118 | self.weight.data.uniform_(-stdv, stdv) 119 | if self.bias is not None: 120 | self.bias.data.uniform_(-stdv, stdv) 121 | 122 | def forward(self, input): 123 | if input.dim() == 2 and self.bias is not None: 124 | 125 | output = input.matmul(self.weight.t()* tensor_C) 126 | 127 | if self.bias is not None: 128 | output += self.bias 129 | 130 | return output 131 | 132 | 133 | def __repr__(self): 134 | return self.__class__.__name__ + ' (' \ 135 | + str(self.in_features) + ' -> ' \ 136 | + str(self.out_features) + ')' 137 | 138 | 139 | 140 | cfg = { 141 | 'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 142 | 'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 143 | 'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 144 | 'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 145 | } 146 | 147 | 148 | class OS_VGG(nn.Module): 149 | def __init__(self, vgg_name): 150 | super(OS_VGG, self).__init__() 151 | self.features = self._make_layers(cfg[vgg_name]) 152 | self.classifier = nn.Sequential( 153 | nn.Linear(512,256), 154 | OS_Linear(256, 100) 155 | ) 156 | def forward(self, x): 157 | out = self.features(x) 158 | out = out.view(out.size(0), -1) 159 | out = self.classifier(out) 160 | return out 161 | 162 | def _make_layers(self, cfg): 163 | layers = [] 164 | in_channels = 3 165 | for x in cfg: 166 | if x == 'M': 167 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 168 | else: 169 | layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1), 170 | nn.BatchNorm2d(x), 171 | nn.ReLU(inplace=True)] 172 | in_channels = x 173 | layers += [nn.AvgPool2d(kernel_size=1, stride=1)] 174 | return nn.Sequential(*layers) 175 | 176 | # Model 177 | 178 | 179 | 
print('==> Building model..') 180 | 181 | net = OS_VGG('VGG16') 182 | 183 | if use_cuda: 184 | net.cuda() 185 | 186 | cudnn.benchmark = True 187 | 188 | criterion = nn.CrossEntropyLoss() 189 | 190 | def train(epoch): 191 | print('\nEpoch: %d' % epoch) 192 | net.train() 193 | train_loss = 0 194 | correct = 0 195 | total = 0 196 | idx = 0 197 | 198 | 199 | for batch_idx, (inputs, targets) in enumerate(trainloader): 200 | idx = batch_idx 201 | if use_cuda: 202 | inputs, targets = inputs.cuda(), targets.cuda() 203 | optimizer.zero_grad() 204 | inputs, targets = Variable(inputs), Variable(targets) 205 | outputs = net(inputs) 206 | 207 | 208 | loss = criterion(outputs, targets) 209 | 210 | 211 | loss.backward() 212 | optimizer.step() 213 | 214 | train_loss += loss.item() 215 | _, predicted = torch.max(outputs.data, 1) 216 | total += targets.size(0) 217 | correct += predicted.eq(targets.data).cpu().sum().item() 218 | 219 | 220 | train_acc = 100.*correct/total 221 | train_loss = train_loss/(idx+1) 222 | logging.info('Iteration %d, train_acc = %.5f,train_loss = %.6f' % (epoch, train_acc,train_loss)) 223 | 224 | 225 | def test(epoch): 226 | net.eval() 227 | test_loss = 0 228 | correct = 0 229 | total = 0 230 | idx = 0 231 | for batch_idx, (inputs, targets) in enumerate(testloader): 232 | with torch.no_grad(): 233 | idx = batch_idx 234 | if use_cuda: 235 | inputs, targets = inputs.cuda(), targets.cuda() 236 | inputs, targets = Variable(inputs), Variable(targets) 237 | outputs = net(inputs) 238 | 239 | 240 | loss = criterion(outputs, targets) 241 | 242 | test_loss += loss.item() 243 | _, predicted = torch.max(outputs.data, 1) 244 | total += targets.size(0) 245 | correct += predicted.eq(targets.data).cpu().sum().item() 246 | 247 | 248 | test_acc = 100.*correct/total 249 | test_loss = test_loss/(idx+1) 250 | logging.info('Iteration %d, test_acc = %.4f,test_loss = %.4f' % (epoch, test_acc,test_loss)) 251 | return test_acc 252 | 253 | 254 | def cosine_anneal_schedule(t): 255 | 
cos_inner = np.pi * (t % (nb_epoch )) # t - 1 is used when t has 1-based indexing. 256 | cos_inner /= (nb_epoch ) 257 | cos_out = np.cos(cos_inner) + 1 258 | return float(args.lr / 2 * cos_out) 259 | 260 | 261 | 262 | optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) 263 | 264 | 265 | max_val_acc = 0 266 | for epoch in range(nb_epoch): 267 | lr = cosine_anneal_schedule(epoch) 268 | for param_group in optimizer.param_groups: 269 | print(param_group['lr']) 270 | param_group['lr'] = lr 271 | train(epoch) 272 | test_acc = test(epoch) 273 | 274 | if test_acc >max_val_acc: 275 | max_val_acc = test_acc 276 | print("max_val_acc", max_val_acc) 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OSLNet: Deep Small-Sample Classification with an Orthogonal Softmax Layer 2 | 3 | Code release for OSLNet: Deep Small-Sample Classification with an Orthogonal Softmax Layer (TIP2020) 4 | [DOI](https://doi.org/10.1109/TIP.2020.2990277 "DOI") 5 | 6 | 7 | ## Changelog 8 | - 2020/04/21 upload the code. 9 | 10 | ## Dataset 11 | ### CIFAR-100 12 | 13 | ## Requirements 14 | 15 | - python 3.6 16 | - PyTorch 1.2.0 17 | - torchvision 18 | 19 | ## Training 20 | - Download datasets 21 | - Train: `python OS-CNN.py` or `python CNN.py` 22 | - Description : PyTorch CIFAR-100 Training with OSNet or PyTorch CIFAR-100 Training with Vanilla Model. 23 | 24 | 25 | ## Accuracy and Cross-entropy loss 26 | ![AccuracyandCross-entropyloss](https://github.com/dongliangchang/OSLNet/blob/master/cross-entropy%20loss%20and%20accuracy.png) 27 | ## Citation 28 | If you find this paper useful in your research, please consider citing: 29 | ``` 30 | @ARTICLE{9088302, 31 | 32 | author={X. {Li} and D. {Chang} and Z. {Ma} and Z. {Tan} and J. {Xue} and J. {Cao} and J. {Yu} and J. 
{Guo}}, 33 | journal={IEEE Transactions on Image Processing}, 34 | title={OSLNet: Deep Small-Sample Classification with an Orthogonal Softmax Layer}, 35 | year={2020}, 36 | volume={}, 37 | number={}, 38 | pages={1-1}, 39 | } 40 | 41 | ``` 42 | 43 | 44 | ## Contact 45 | Thanks for your attention! 46 | If you have any suggestions or questions, you can leave a message here or contact us directly: 47 | - mazhanyu@bupt.edu.cn 48 | - xiaoxulilut@gmail.com 49 | - changdongliang@bupt.edu.cn 50 | 51 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-modernist -------------------------------------------------------------------------------- /cross-entropy loss and accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRIS-CV/OSLNet/f9df95cac256c4108c3e49c9016513c825e23a8a/cross-entropy loss and accuracy.png --------------------------------------------------------------------------------