├── ImageNet_iNat ├── ResNet.py ├── data │ ├── ImageNet_LT_test.txt │ ├── ImageNet_LT_train.txt │ ├── ImageNet_LT_val.txt │ ├── iNaturalist18_train.txt │ └── iNaturalist18_val.txt ├── data_utils.py ├── dataloader.py ├── loss.py ├── resnet_meta.py ├── scripts │ ├── test.sh │ └── train.sh ├── test.py ├── train.py └── utils.py ├── LICENSE ├── MetaSAug_LDAM_train.py ├── MetaSAug_test.py ├── README.md ├── assets └── illustration.png ├── data_utils.py ├── loss.py ├── resnet.py └── scripts ├── MetaSAug_CE_test.sh ├── MetaSAug_LDAM_test.sh └── MetaSAug_LDAM_train.sh /ImageNet_iNat/ResNet.py: -------------------------------------------------------------------------------- 1 | from resnet_meta import * 2 | from utils import * 3 | from os import path 4 | 5 | 6 | def create_model(use_selfatt=False, use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, *args): 7 | 8 | print('Loading Scratch ResNet 50 Feature Model.') 9 | if not use_fc: 10 | resnet50 = FeatureMeta(BottleneckMeta, [3, 4, 6, 3], dropout=None) 11 | else: 12 | resnet50 = FCMeta(2048, 1000) 13 | if not test: 14 | if stage1_weights: 15 | assert dataset 16 | print('Loading %s Stage 1 ResNet 10 Weights.' % dataset) 17 | if log_dir is not None: 18 | # subdir = log_dir.strip('/').split('/')[-1] 19 | # subdir = subdir.replace('stage2', 'stage1') 20 | # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir) 21 | #weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 22 | weight_dir = log_dir 23 | else: 24 | weight_dir = './logs/%s/stage1' % dataset 25 | print('==> Loading weights from %s' % weight_dir) 26 | if not use_fc: 27 | resnet50 = init_weights(model=resnet50, 28 | weights_path=weight_dir) 29 | else: 30 | resnet50 = init_weights(model=resnet50, weights_path=weight_dir, classifier=True) 31 | #resnet50.load_state_dict(torch.load(weight_dir)) 32 | else: 33 | print('No Pretrained Weights For Feature Model.') 34 | 35 | return resnet50 36 | -------------------------------------------------------------------------------- /ImageNet_iNat/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim 7 | import torch.utils.data 8 | import torchvision.transforms as transforms 9 | import torchvision 10 | import numpy as np 11 | import copy 12 | 13 | np.random.seed(6) 14 | #random.seed(2) 15 | def build_dataset(dataset,num_meta,num_classes): 16 | 17 | 18 | img_num_list = [num_meta] * num_classes 19 | 20 | # print(train_dataset.targets) 21 | data_list_val = {} 22 | for j in range(num_classes): 23 | data_list_val[j] = [i for i, label in enumerate(dataset.labels) if label == j] 24 | 25 | idx_to_meta = [] 26 | idx_to_train = [] 27 | #print(img_num_list) 28 | 29 | for cls_idx, img_id_list in data_list_val.items(): 30 | np.random.shuffle(img_id_list) 31 | img_num = img_num_list[int(cls_idx)] 32 | idx_to_meta.extend(img_id_list[:img_num]) 33 | idx_to_train.extend(img_id_list[img_num:]) 34 | train_data = copy.deepcopy(dataset) 35 | train_data_meta = copy.deepcopy(dataset) 36 | 37 | train_data_meta.img_path = np.delete(dataset.img_path,idx_to_train,axis=0) 38 | train_data_meta.labels = np.delete(dataset.labels, idx_to_train, axis=0) 39 | train_data.img_path = np.delete(dataset.img_path, idx_to_meta, axis=0) 40 | train_data.labels = np.delete(dataset.labels, idx_to_meta, axis=0) 41 | 42 | return train_data_meta, train_data 43 | 44 | 45 | def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None): 46 | """ 47 | Get a list of image numbers for each class, given cifar version 48 | Num of imgs follows emponential distribution 49 | img max: 5000 / 500 * e^(-lambda * 0); 50 | img min: 5000 / 500 * e^(-lambda * int(cifar_version - 1)) 51 | exp(-lambda * (int(cifar_version) - 1)) = img_max / img_min 52 | args: 53 | cifar_version: str, '10', '100', '20' 54 | imb_factor: float, imbalance factor: img_min/img_max, 55 | None if geting default cifar data number 56 | output: 57 | img_num_per_cls: a list of number of images per class 58 | """ 59 | if dataset == 'cifar10': 60 | img_max = (50000-num_meta)/10 61 | cls_num = 10 62 | 63 | if dataset == 'cifar100': 64 | img_max = (50000-num_meta)/100 65 | cls_num = 100 66 | 67 | if imb_factor is None: 68 | return [img_max] * cls_num 69 | img_num_per_cls = [] 70 | for cls_idx in range(cls_num): 71 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 72 | img_num_per_cls.append(int(num)) 73 | return img_num_per_cls 74 | 75 | 76 | # This function is used to generate imbalanced test set 77 | ''' 78 | def get_img_num_per_cls_test(dataset,imb_factor=None,num_meta=None): 79 | """ 80 | Get a list of image numbers for each class, given cifar version 81 | Num of imgs follows emponential distribution 82 | img max: 5000 / 500 * e^(-lambda * 0); 83 | img min: 5000 / 500 * e^(-lambda * int(cifar_version - 1)) 84 | exp(-lambda * (int(cifar_version) - 1)) = img_max / img_min 85 | args: 86 | cifar_version: str, '10', '100', '20' 87 | imb_factor: float, imbalance factor: img_min/img_max, 88 | None if geting default cifar data number 89 | output: 90 | img_num_per_cls: a list of number of images per class 91 | """ 92 | if dataset == 'cifar10': 93 | img_max = (10000-num_meta)/10 94 | cls_num = 10 95 | 96 | if dataset == 'cifar100': 97 | img_max = (10000-num_meta)/100 98 | cls_num = 100 99 | 100 | if imb_factor is None: 101 | return [img_max] * cls_num 102 | img_num_per_cls = [] 103 | for cls_idx in range(cls_num): 104 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 105 | img_num_per_cls.append(int(num)) 106 | return img_num_per_cls 107 | ''' 108 | -------------------------------------------------------------------------------- /ImageNet_iNat/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision 3 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 4 | from torchvision import transforms 5 | import os 6 | from PIL import Image 7 | import json 8 | 9 | # Image statistics 10 | RGB_statistics = { 11 | 'iNaturalist18': { 12 | 'mean': [0.466, 0.471, 0.380], 13 | 'std': [0.195, 0.194, 0.192] 14 | }, 15 | 'default': { 16 | 'mean': [0.485, 0.456, 0.406], 17 | 'std': [0.229, 0.224, 0.225] 18 | } 19 | } 20 | 21 | 22 | def get_data_transform(split, rgb_mean, rbg_std, key='default'): 23 | data_transforms = { 24 | 'train': transforms.Compose([ 25 | transforms.RandomResizedCrop(224), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | transforms.Normalize(rgb_mean, rbg_std) 29 | ]) if key == 'iNaturalist18' else transforms.Compose([ 30 | transforms.RandomResizedCrop(224), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 33 | transforms.ToTensor(), 34 | transforms.Normalize(rgb_mean, rbg_std) 35 | ]), 36 | 'val': transforms.Compose([ 37 | transforms.Resize(256), 38 | transforms.CenterCrop(224), 39 | transforms.ToTensor(), 40 | transforms.Normalize(rgb_mean, rbg_std) 41 | ]), 42 | 'test': transforms.Compose([ 43 | transforms.Resize(256), 44 | transforms.CenterCrop(224), 45 | transforms.ToTensor(), 46 | transforms.Normalize(rgb_mean, rbg_std) 47 | ]) 48 | } 49 | return data_transforms[split] 50 | 51 | 52 | class LT_Dataset(Dataset): 53 | 54 | def __init__(self, root, txt, transform=None): 55 | self.img_path = [] 56 | self.labels = [] 57 | self.transform = transform 58 | print("--------------------------------------------") 59 | print(root) 60 | with open(txt) as f: 61 | for line in f: 62 | self.img_path.append(os.path.join(root, line.split()[0])) 63 | self.labels.append(int(line.split()[1])) 64 | 65 | def __len__(self): 66 | return len(self.labels) 67 | 68 | def __getitem__(self, index): 69 | 70 | path = self.img_path[index] 71 | label = self.labels[index] 72 | 73 | with open(path, 'rb') as f: 74 | sample = Image.open(f).convert('RGB') 75 | 76 | if self.transform is not None: 77 | sample = self.transform(sample) 78 | 79 | return sample, label 80 | 81 | class LT_Dataset_iNat17(Dataset): 82 | ''' 83 | Reading the Json file of iNaturalist17 84 | ''' 85 | def __init__(self, root, txt, transform=None): 86 | self.img_path = [] 87 | self.labels = [] 88 | self.transform = transform 89 | 90 | with open(txt,'r',encoding='utf8')as fp: 91 | json_data = json.load(fp) 92 | images = json_data["images"] 93 | labels = json_data["annotations"] 94 | for i in range(len(images)): 95 | 96 | self.img_path.append(os.path.join(root, images[i]["file_name"])) 97 | self.labels.append(int(labels[i]["category_id"])) 98 | 99 | def __len__(self): 100 | return len(self.labels) 101 | 102 | def __getitem__(self, index): 103 | 104 | path = self.img_path[index] 105 | label = self.labels[index] 106 | 107 | with open(path, 'rb') as f: 108 | sample = Image.open(f).convert('RGB') 109 | 110 | if self.transform is not None: 111 | sample = self.transform(sample) 112 | 113 | return sample, label 114 | 115 | 116 | def load_data_distributed(data_root, dataset, phase, batch_size, sampler_dic=None, num_workers=4, test_open=False, shuffle=True): 117 | 118 | # if phase == 'train_plain': 119 | # txt_split = 'train' 120 | # elif phase == 'train_val': 121 | # txt_split = 'val' 122 | # phase = 'train' 123 | # else: 124 | # txt_split = phase 125 | if dataset == "iNaturalist17": 126 | txt = 'data/%s_%s.json' % (dataset, phase) 127 | else: 128 | txt = 'data/%s_%s.txt' % (dataset, phase) 129 | 130 | print('Loading data from %s' % txt) 131 | 132 | if dataset == 'iNaturalist18': 133 | print('===> Loading iNaturalist18 statistics') 134 | key = 'iNaturalist18' 135 | elif dataset == 'iNaturalist17': 136 | print('===> Loading iNaturalist17 statistics') 137 | key = 'iNaturalist18' 138 | else: 139 | key = 'default' 140 | 141 | rgb_mean, rgb_std = RGB_statistics[key]['mean'], RGB_statistics[key]['std'] 142 | 143 | if phase not in ['train', 'val']: 144 | transform = get_data_transform('test', rgb_mean, rgb_std, key) 145 | else: 146 | transform = get_data_transform(phase, rgb_mean, rgb_std, key) 147 | 148 | print('Use data transformation:', transform) 149 | 150 | if dataset == "iNaturalist17": 151 | set_ = LT_Dataset_iNat17(data_root, txt, transform) 152 | else: 153 | set_ = LT_Dataset(data_root, txt, transform) 154 | 155 | 156 | return set_ 157 | -------------------------------------------------------------------------------- /ImageNet_iNat/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -* 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import math 8 | import torch.nn.functional as F 9 | import pdb 10 | 11 | def MI(outputs_target): 12 | batch_size = outputs_target.size(0) 13 | softmax_outs_t = nn.Softmax(dim=1)(outputs_target) 14 | avg_softmax_outs_t = torch.sum(softmax_outs_t, dim=0) / float(batch_size) 15 | log_avg_softmax_outs_t = torch.log(avg_softmax_outs_t) 16 | item1 = -torch.sum(avg_softmax_outs_t * log_avg_softmax_outs_t) 17 | item2 = -torch.sum(softmax_outs_t * torch.log(softmax_outs_t)) / float(batch_size) 18 | return item1, item2 19 | 20 | class EstimatorMean(): 21 | def __init__(self, feature_num, class_num): 22 | super(EstimatorMean, self).__init__() 23 | self.class_num = class_num 24 | self.Ave = torch.zeros(class_num, feature_num).cuda() 25 | self.Amount = torch.zeros(class_num).cuda() 26 | 27 | def update_Mean(self, features, labels): 28 | N = features.size(0) 29 | C = self.class_num 30 | A = features.size(1) 31 | 32 | NxCxFeatures = features.view(N, 1, A).expand(N, C, A) 33 | onehot = torch.zeros(N, C).cuda() 34 | onehot.scatter_(1, labels.view(-1, 1), 1) 35 | 36 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 37 | 38 | features_by_sort = NxCxFeatures.mul(NxCxA_onehot) 39 | 40 | Amount_CxA = NxCxA_onehot.sum(0) 41 | Amount_CxA[Amount_CxA == 0] = 1 42 | 43 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 44 | 45 | sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A) 46 | 47 | sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A) 48 | 49 | weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)) 50 | weight_CV[weight_CV != weight_CV] = 0 51 | 52 | weight_AV = sum_weight_AV.div(sum_weight_AV + self.Amount.view(C, 1).expand(C, A)) 53 | weight_AV[weight_AV != weight_AV] = 0 54 | 55 | self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach() 56 | self.Amount += onehot.sum(0) 57 | 58 | # the estimation of covariance matrix 59 | class EstimatorCV(): 60 | def __init__(self, feature_num, class_num): 61 | super(EstimatorCV, self).__init__() 62 | self.class_num = class_num 63 | self.CoVariance = torch.zeros(class_num, feature_num).cuda() 64 | self.Ave = torch.zeros(class_num, feature_num).cuda() 65 | self.Amount = torch.zeros(class_num).cuda() 66 | 67 | def update_CV(self, features, labels): 68 | N = features.size(0) 69 | C = self.class_num 70 | A = features.size(1) 71 | 72 | NxCxFeatures = features.view(N, 1, A).expand(N, C, A) 73 | # onehot = torch.zeros(N, C).cuda() 74 | onehot = torch.zeros(N, C).cuda() 75 | onehot.scatter_(1, labels.view(-1, 1), 1) 76 | 77 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 78 | 79 | features_by_sort = NxCxFeatures.mul(NxCxA_onehot) 80 | 81 | Amount_CxA = NxCxA_onehot.sum(0) 82 | Amount_CxA[Amount_CxA == 0] = 1 83 | 84 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 85 | 86 | var_temp = features_by_sort - ave_CxA.expand(N, C, A).mul(NxCxA_onehot) 87 | 88 | # var_temp = torch.bmm(var_temp.permute(1, 2, 0), var_temp.permute(1, 0, 2)).div(Amount_CxA.view(C, A, 1).expand(C, A, A)) 89 | var_temp = torch.mul(var_temp.permute(1, 2, 0), var_temp.permute(1, 2, 0)).sum(2).div(Amount_CxA.view(C, A)) 90 | 91 | # sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A) 92 | sum_weight_CV = onehot.sum(0).view(C, 1).expand(C, A) 93 | 94 | sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A) 95 | 96 | # weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)) 97 | weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1).expand(C, A)) 98 | weight_CV[weight_CV != weight_CV] = 0 99 | 100 | weight_AV = sum_weight_AV.div(sum_weight_AV + self.Amount.view(C, 1).expand(C, A)) 101 | weight_AV[weight_AV != weight_AV] = 0 102 | 103 | additional_CV = weight_CV.mul(1 - weight_CV).mul( 104 | torch.mul( 105 | (self.Ave - ave_CxA).view(C, A), 106 | (self.Ave - ave_CxA).view(C, A) 107 | ) 108 | ) 109 | # (self.Ave - ave_CxA).pow(2) 110 | 111 | self.CoVariance = (self.CoVariance.mul(1 - weight_CV) + var_temp.mul( 112 | weight_CV)).detach() + additional_CV.detach() 113 | self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach() 114 | self.Amount += onehot.sum(0) 115 | 116 | class Loss_meta(nn.Module): 117 | def __init__(self, feature_num, class_num): 118 | super(Loss_meta, self).__init__() 119 | self.source_estimator = EstimatorCV(feature_num, class_num) 120 | self.class_num = class_num 121 | self.cross_entropy = nn.CrossEntropyLoss() 122 | 123 | def MetaSAug(self, fc, features, y_s, labels_s, s_cv_matrix, ratio): 124 | N = features.size(0) 125 | C = self.class_num 126 | A = features.size(1) 127 | 128 | weight_m = fc 129 | 130 | NxW_ij = weight_m.expand(N, C, A) 131 | NxW_kj = torch.gather(NxW_ij, 1, labels_s.view(N, 1, 1).expand(N, C, A)) 132 | 133 | CV_temp = s_cv_matrix[labels_s] 134 | sigma2 = ratio * (weight_m - NxW_kj).pow(2).mul(CV_temp.view(N, 1, A).expand(N, C, A)).sum(2) 135 | aug_result = y_s + 0.5 * sigma2 136 | return aug_result 137 | 138 | def forward(self, fc, features_source, y_s, labels_source, ratio, weights, cv, mode): 139 | 140 | aug_y = self.MetaSAug(fc, features_source, y_s, labels_source, cv, \ 141 | ratio) 142 | 143 | if mode == "update": 144 | self.source_estimator.update_CV(features_source.detach(), labels_source) 145 | loss = F.cross_entropy(aug_y, labels_source, weight=weights) 146 | else: 147 | loss = F.cross_entropy(aug_y, labels_source, weight=weights) 148 | return loss 149 | 150 | def get_cv(self): 151 | return self.source_estimator.CoVariance 152 | 153 | def update_cv(self, cv): 154 | self.source_estimator.CoVariance = cv 155 | -------------------------------------------------------------------------------- /ImageNet_iNat/resnet_meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.init as init 7 | 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class MetaModule(nn.Module): 16 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 17 | def params(self): 18 | for name, param in self.named_params(self): 19 | yield param 20 | 21 | def named_leaves(self): 22 | return [] 23 | 24 | def named_submodules(self): 25 | return [] 26 | 27 | def named_params(self, curr_module=None, memo=None, prefix=''): 28 | if memo is None: 29 | memo = set() 30 | 31 | if hasattr(curr_module, 'named_leaves'): 32 | for name, p in curr_module.named_leaves(): 33 | if p is not None and p not in memo: 34 | memo.add(p) 35 | yield prefix + ('.' if prefix else '') + name, p 36 | else: 37 | for name, p in curr_module._parameters.items(): 38 | if p is not None and p not in memo: 39 | memo.add(p) 40 | yield prefix + ('.' if prefix else '') + name, p 41 | 42 | for mname, module in curr_module.named_children(): 43 | submodule_prefix = prefix + ('.' if prefix else '') + mname 44 | for name, p in self.named_params(module, memo, submodule_prefix): 45 | yield name, p 46 | 47 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 48 | if source_params is not None: 49 | for tgt, src in zip(self.named_params(self), source_params): 50 | name_t, param_t = tgt 51 | # name_s, param_s = src 52 | # grad = param_s.grad 53 | # name_s, param_s = src 54 | grad = src 55 | if first_order: 56 | grad = to_var(grad.detach().data) 57 | tmp = param_t - lr_inner * grad 58 | self.set_param(self, name_t, tmp) 59 | else: 60 | 61 | for name, param in self.named_params(self): 62 | if not detach: 63 | grad = param.grad 64 | if first_order: 65 | grad = to_var(grad.detach().data) 66 | tmp = param - lr_inner * grad 67 | self.set_param(self, name, tmp) 68 | else: 69 | param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686 70 | self.set_param(self, name, param) 71 | 72 | def set_param(self, curr_mod, name, param): 73 | if '.' in name: 74 | n = name.split('.') 75 | module_name = n[0] 76 | rest = '.'.join(n[1:]) 77 | for name, mod in curr_mod.named_children(): 78 | if module_name == name: 79 | self.set_param(mod, rest, param) 80 | break 81 | else: 82 | setattr(curr_mod, name, param) 83 | 84 | def detach_params(self): 85 | for name, param in self.named_params(self): 86 | self.set_param(self, name, param.detach()) 87 | 88 | def copy(self, other, same_var=False): 89 | for name, param in other.named_params(): 90 | if not same_var: 91 | param = to_var(param.data.clone(), requires_grad=True) 92 | self.set_param(name, param) 93 | 94 | 95 | class MetaLinear(MetaModule): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__() 98 | ignore = nn.Linear(*args, **kwargs) 99 | 100 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 101 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 102 | 103 | def forward(self, x): 104 | return F.linear(x, self.weight, self.bias) 105 | 106 | def named_leaves(self): 107 | return [('weight', self.weight), ('bias', self.bias)] 108 | 109 | # This layer will be used when the loss function is LDAM 110 | class MetaLinear_Norm(MetaModule): 111 | def __init__(self, *args, **kwargs): 112 | super().__init__() 113 | temp = nn.Linear(*args, **kwargs) 114 | # import pdb; pdb.set_trace() 115 | temp.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 116 | self.register_buffer('weight', to_var(temp.weight.data.t(), requires_grad=True)) 117 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 118 | #self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 119 | 120 | def forward(self, x): 121 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 122 | return out 123 | # return F.linear(x, self.weight, self.bias) 124 | 125 | def named_leaves(self): 126 | return [('weight', self.weight)]#, ('bias', self.bias)] 127 | 128 | 129 | class MetaConv2d(MetaModule): 130 | def __init__(self, *args, **kwargs): 131 | super().__init__() 132 | ignore = nn.Conv2d(*args, **kwargs) 133 | 134 | self.in_channels = ignore.in_channels 135 | self.out_channels = ignore.out_channels 136 | self.stride = ignore.stride 137 | self.padding = ignore.padding 138 | self.dilation = ignore.dilation 139 | self.groups = ignore.groups 140 | self.kernel_size = ignore.kernel_size 141 | 142 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 143 | 144 | if ignore.bias is not None: 145 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 146 | else: 147 | self.register_buffer('bias', None) 148 | 149 | def forward(self, x): 150 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 151 | 152 | def named_leaves(self): 153 | return [('weight', self.weight), ('bias', self.bias)] 154 | 155 | 156 | class MetaConvTranspose2d(MetaModule): 157 | def __init__(self, *args, **kwargs): 158 | super().__init__() 159 | ignore = nn.ConvTranspose2d(*args, **kwargs) 160 | 161 | self.stride = ignore.stride 162 | self.padding = ignore.padding 163 | self.dilation = ignore.dilation 164 | self.groups = ignore.groups 165 | 166 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 167 | 168 | if ignore.bias is not None: 169 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 170 | else: 171 | self.register_buffer('bias', None) 172 | 173 | def forward(self, x, output_size=None): 174 | output_padding = self._output_padding(x, output_size) 175 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding, 176 | output_padding, self.groups, self.dilation) 177 | 178 | def named_leaves(self): 179 | return [('weight', self.weight), ('bias', self.bias)] 180 | 181 | 182 | class MetaBatchNorm1d(MetaModule): 183 | def __init__(self, *args, **kwargs): 184 | super(MetaBatchNorm1d, self).__init__() 185 | ignore = nn.BatchNorm1d(*args, **kwargs) 186 | 187 | self.num_features = ignore.num_features 188 | self.eps = ignore.eps 189 | self.momentum = ignore.momentum 190 | self.affine = ignore.affine 191 | self.track_running_stats = ignore.track_running_stats 192 | 193 | if self.affine: 194 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 195 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 196 | 197 | if self.track_running_stats: 198 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 199 | self.register_buffer('running_var', torch.ones(self.num_features)) 200 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 201 | else: 202 | self.register_parameter('running_mean', None) 203 | self.register_parameter('running_var', None) 204 | 205 | def reset_running_stats(self): 206 | if self.track_running_stats: 207 | self.running_mean.zero_() 208 | self.running_var.fill_(1) 209 | self.num_batches_tracked.zero_() 210 | 211 | def reset_parameters(self): 212 | self.reset_running_stats() 213 | if self.affine: 214 | self.weight.data.uniform_() 215 | self.bias.data.zero_() 216 | 217 | def _check_input_dim(self, input): 218 | if input.dim() != 2 and input.dim() != 3: 219 | raise ValueError('expected 2D or 3D input (got {}D input)' 220 | .format(input.dim())) 221 | 222 | def forward(self, x): 223 | self._check_input_dim(x) 224 | exponential_average_factor = 0.0 225 | 226 | if self.training and self.track_running_stats: 227 | self.num_batches_tracked += 1 228 | if self.momentum is None: # use cumulative moving average 229 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 230 | else: # use exponential moving average 231 | exponential_average_factor = self.momentum 232 | 233 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 234 | self.training or not self.track_running_stats, self.momentum, self.eps) 235 | 236 | def named_leaves(self): 237 | return [('weight', self.weight), ('bias', self.bias)] 238 | 239 | def extra_repr(self): 240 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 241 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 242 | 243 | def _load_from_state_dict(self, state_dict, prefix, metadata, strict, 244 | missing_keys, unexpected_keys, error_msgs): 245 | version = metadata.get('version', None) 246 | 247 | if (version is None or version < 2) and self.track_running_stats: 248 | # at version 2: added num_batches_tracked buffer 249 | # this should have a default value of 0 250 | num_batches_tracked_key = prefix + 'num_batches_tracked' 251 | if num_batches_tracked_key not in state_dict: 252 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 253 | 254 | super(MetaBatchNorm1d, self)._load_from_state_dict( 255 | state_dict, prefix, metadata, strict, 256 | missing_keys, unexpected_keys, error_msgs) 257 | 258 | 259 | class MetaBatchNorm2d(MetaModule): 260 | def __init__(self, *args, **kwargs): 261 | super().__init__() 262 | ignore = nn.BatchNorm2d(*args, **kwargs) 263 | 264 | self.num_features = ignore.num_features 265 | self.eps = ignore.eps 266 | self.momentum = ignore.momentum 267 | self.affine = ignore.affine 268 | self.track_running_stats = ignore.track_running_stats 269 | 270 | if self.affine: 271 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 272 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 273 | 274 | if self.track_running_stats: 275 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 276 | self.register_buffer('running_var', torch.ones(self.num_features)) 277 | else: 278 | self.register_parameter('running_mean', None) 279 | self.register_parameter('running_var', None) 280 | 281 | def forward(self, x): 282 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 283 | self.training or not self.track_running_stats, self.momentum, self.eps) 284 | 285 | def named_leaves(self): 286 | return [('weight', self.weight), ('bias', self.bias)] 287 | 288 | 289 | def _weights_init(m): 290 | classname = m.__class__.__name__ 291 | # print(classname) 292 | if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d): 293 | init.kaiming_normal(m.weight) 294 | 295 | class LambdaLayer(MetaModule): 296 | def __init__(self, lambd): 297 | super(LambdaLayer, self).__init__() 298 | self.lambd = lambd 299 | 300 | def forward(self, x): 301 | return self.lambd(x) 302 | 303 | 304 | class BasicBlock(MetaModule): 305 | expansion = 1 306 | 307 | def __init__(self, in_planes, planes, stride=1, option='A'): 308 | super(BasicBlock, self).__init__() 309 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 310 | self.bn1 = MetaBatchNorm2d(planes) 311 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 312 | self.bn2 = MetaBatchNorm2d(planes) 313 | 314 | self.shortcut = nn.Sequential() 315 | if stride != 1 or in_planes != planes: 316 | if option == 'A': 317 | """ 318 | For CIFAR10 ResNet paper uses option A. 319 | """ 320 | self.shortcut = LambdaLayer(lambda x: 321 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 322 | elif option == 'B': 323 | self.shortcut = nn.Sequential( 324 | MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 325 | MetaBatchNorm2d(self.expansion * planes) 326 | ) 327 | 328 | def forward(self, x): 329 | out = F.relu(self.bn1(self.conv1(x))) 330 | out = self.bn2(self.conv2(out)) 331 | out += self.shortcut(x) 332 | out = F.relu(out) 333 | return out 334 | 335 | 336 | class BottleneckMeta(MetaModule): 337 | expansion = 4 338 | 339 | def __init__(self, in_planes, planes, stride=1, downsample=None): 340 | super(BottleneckMeta, self).__init__() 341 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=1, bias=False) 342 | self.bn1 = MetaBatchNorm2d(planes) 343 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 344 | self.bn2 = MetaBatchNorm2d(planes) 345 | self.conv3 = MetaConv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 346 | self.bn3 = MetaBatchNorm2d(planes * self.expansion) 347 | self.downsample = downsample 348 | self.stride = stride 349 | 350 | def forward(self, x): 351 | residual = x 352 | 353 | out = self.conv1(x) 354 | out = F.relu(self.bn1(out)) 355 | 356 | out = self.conv2(out) 357 | out = F.relu(self.bn2(out)) 358 | 359 | out = self.conv3(out) 360 | out = self.bn3(out) 361 | 362 | if self.downsample is not None: 363 | residual = self.downsample(x) 364 | 365 | out += residual 366 | out = F.relu(out) 367 | return out 368 | 369 | 370 | class ResNet32(MetaModule): 371 | def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]): 372 | super(ResNet32, self).__init__() 373 | self.in_planes = 16 374 | 375 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 376 | self.bn1 = MetaBatchNorm2d(16) 377 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 378 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 379 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 380 | self.linear = MetaLinear(64, num_classes) # MetaLinear_Norm(64,num_classes,bias=False) # 381 | 382 | self.apply(_weights_init) 383 | 384 | def _make_layer(self, block, planes, num_blocks, stride): 385 | strides = [stride] + [1]*(num_blocks-1) 386 | layers = [] 387 | for stride in strides: 388 | layers.append(block(self.in_planes, planes, stride)) 389 | self.in_planes = planes * block.expansion 390 | 391 | return nn.Sequential(*layers) 392 | 393 | def forward(self, x): 394 | out = F.relu(self.bn1(self.conv1(x))) 395 | out = self.layer1(out) 396 | out = self.layer2(out) 397 | out = self.layer3(out) 398 | out = F.avg_pool2d(out, out.size()[3]) 399 | out = out.view(out.size(0), -1) 400 | y = self.linear(out) 401 | return out, y 402 | 403 | 404 | class FeatureMeta(MetaModule): 405 | def __init__(self, block, num_blocks, use_fc=False, dropout=None): 406 | super(FeatureMeta, self).__init__() 407 | self.inplanes = 64 408 | 409 | self.conv1 = MetaConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 410 | self.bn1 = MetaBatchNorm2d(64) 411 | self.relu = nn.ReLU(inplace=True) 412 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 413 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 414 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 415 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 416 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 417 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 418 | # self.avgpool = nn.AvgPool2d(7, stride=1) 419 | self.avgpool = nn.AvgPool2d(7, stride=1) 420 | 421 | self.use_fc = use_fc 422 | self.use_dropout = True if dropout else False 423 | 424 | #if self.use_fc: 425 | #print('Using fc.') 426 | #self.fc = MetaLinear(512*block.expansion, 8142) 427 | 428 | if self.use_dropout: 429 | print('Using dropout') 430 | self.dropout = nn.Dropout(p=dropout) 431 | 432 | for m in self.modules(): 433 | if isinstance(m, MetaConv2d): 434 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 435 | m.weight.data.normal_(0, math.sqrt(2. / n)) 436 | elif isinstance(m, MetaBatchNorm2d): 437 | m.weight.data.fill_(1) 438 | m.bias.data.zero_() 439 | 440 | 441 | def _make_layer(self, block, planes, num_blocks, stride): 442 | downsample = None 443 | if stride != 1 or self.inplanes != planes * block.expansion: 444 | downsample = nn.Sequential( 445 | MetaConv2d(self.inplanes, planes * block.expansion, 446 | kernel_size=1, stride=stride, bias=False), 447 | MetaBatchNorm2d(planes * block.expansion), 448 | ) 449 | 450 | layers = [] 451 | layers.append(block(self.inplanes, planes, stride, downsample)) 452 | self.inplanes = planes * block.expansion 453 | for i in range(1, num_blocks): 454 | layers.append(block(self.inplanes, planes)) 455 | 456 | return nn.Sequential(*layers) 457 | 458 | def forward(self, x): 459 | x = self.conv1(x) 460 | x = self.bn1(x) 461 | x = self.relu(x) 462 | x = self.maxpool(x) 463 | 464 | x = self.layer1(x) 465 | x = self.layer2(x) 466 | x = self.layer3(x) 467 | x = self.layer4(x) 468 | 469 | x = self.avgpool(x) 470 | x = x.view(x.size(0), -1) 471 | 472 | #if self.use_fc: 473 | #y = F.relu(self.fc(x)) 474 | 475 | if self.use_dropout: 476 | x = self.dropout(x) 477 | 478 | return x 479 | 480 | class FCMeta(MetaModule): 481 | def __init__(self, feature_dim=2048, output_dim=1000): 482 | super(FCMeta, self).__init__() 483 | self.fc = MetaLinear(feature_dim, output_dim) 484 | 485 | def forward(self, x): 486 | y = self.fc(x) 487 | return y 488 | 489 | 490 | class FCModel(nn.Module): 491 | def __init__(self, feature_dim, output_dim=1000): 492 | super(FCModel, self).__init__() 493 | self.fc = nn.Linear(feature_dim, output_dim) 494 | 495 | def forward(self, x): 496 | y = self.fc(x) 497 | return y 498 | 499 | 500 | -------------------------------------------------------------------------------- /ImageNet_iNat/scripts/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 python3 test.py --loading_path checkpoints/MetaSAug_ImageNet_LT.tar --dataset ImageNet_LT --num_classes 1000 --data_root ../ImageNet 2 | CUDA_VISIBLE_DEVICES=7 python3 test.py --loading_path checkpoints/MetaSAug_iNat18.tar --dataset iNaturalist18 --num_classes 8142 --data_root ../iNaturalist18 3 | -------------------------------------------------------------------------------- /ImageNet_iNat/scripts/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 53212 train.py --lr 0.0003 --meta_lr 0.1 --workers 0 --batch_size 256 --epochs 20 --dataset ImageNet_LT --num_classes 1000 --data_root ../ImageNet 2 | -------------------------------------------------------------------------------- /ImageNet_iNat/test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import argparse 5 | import random 6 | import copy 7 | import torch 8 | import torchvision 9 | import numpy as np 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import torchvision.transforms as transforms 13 | from data_utils import * 14 | from dataloader import load_data_distributed 15 | import shutil 16 | from ResNet import * 17 | import resnet_meta 18 | 19 | import multiprocessing 20 | import torch.nn.parallel 21 | import torch.nn as nn 22 | from collections import Counter 23 | import time 24 | parser = argparse.ArgumentParser(description='Imbalanced Example') 25 | parser.add_argument('--dataset', default='iNaturalist18', type=str, 26 | help='dataset') 27 | parser.add_argument('--data_root', default='/data1/TL/data/iNaturalist18', type=str) 28 | parser.add_argument('--batch_size', type=int, default=32, metavar='N', 29 | help='input batch size for training (default: 64)') 30 | parser.add_argument('--num_classes', type=int, default=8142) 31 | parser.add_argument('--num_meta', type=int, default=10, 32 | help='The number of meta data for each class.') 33 | parser.add_argument('--test_batch_size', type=int, default=512, metavar='N', 34 | help='input batch size for testing (default: 100)') 35 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 36 | help='number of epochs to train') 37 | parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float, 38 | help='initial learning rate') 39 | parser.add_argument('--workers', default=16, type=int) 40 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 41 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 42 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 43 | help='weight decay (default: 5e-4)') 44 | parser.add_argument('--no-cuda', action='store_true', default=False, 45 | help='disables CUDA training') 46 | parser.add_argument('--split', type=int, default=1000) 47 | parser.add_argument('--seed', type=int, default=42, metavar='S', 48 | help='random seed (default: 42)') 49 | parser.add_argument('--print-freq', '-p', default=1000, type=int, 50 | help='print frequency (default: 10)') 51 | parser.add_argument('--gpu', default=None, type=int) 52 | parser.add_argument('--lam', default=0.25, type=float) 53 | parser.add_argument('--local_rank', default=0, type=int) 54 | parser.add_argument('--meta_lr', default=0.1, type=float) 55 | parser.add_argument('--loading_path', default=None, type=str) 56 | 57 | args = parser.parse_args() 58 | # print(args) 59 | for arg in vars(args): 60 | print("{}={}".format(arg, getattr(args, arg))) 61 | 62 | 63 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 64 | 65 | kwargs = {'num_workers': 1, 'pin_memory': True} 66 | use_cuda = not args.no_cuda and torch.cuda.is_available() 67 | 68 | cudnn.benchmark = True 69 | cudnn.enabled = True 70 | torch.manual_seed(args.seed) 71 | device = torch.device("cuda" if use_cuda else "cpu") 72 | 73 | if args.dataset == 'ImageNet_LT': 74 | val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="test", batch_size=args.test_batch_size, 75 | num_workers=args.workers, test_open=False, shuffle=False) 76 | else: 77 | val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="val", batch_size=args.test_batch_size, 78 | num_workers=args.workers, test_open=False, shuffle=False) 79 | 80 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, num_workers=0, pin_memory=True) 81 | 82 | 83 | np.random.seed(42) 84 | random.seed(42) 85 | torch.manual_seed(args.seed) 86 | 87 | data_list = {} 88 | data_list_num = [] 89 | best_prec1 = 0 90 | 91 | model_dict = {"ImageNet_LT": "models/resnet50_uniform_e90.pth", 92 | "iNaturalist18": "models/iNat18/resnet50_uniform_e200.pth"} 93 | 94 | def main(): 95 | global args, best_prec1 96 | args = parser.parse_args() 97 | 98 | cudnn.benchmark = True 99 | 100 | print(torch.cuda.is_available()) 101 | print(torch.cuda.device_count()) 102 | model = FCModel(2048, args.num_classes) 103 | model = model.cuda() 104 | loading_path = args.loading_path 105 | weights = torch.load(loading_path, map_location=torch.device("cpu")) 106 | weights_c = weights['state_dict']['classifier'] 107 | weights_c = {k: weights_c[k] for k in model.state_dict()} 108 | for k in model.state_dict(): 109 | if k not in weights: 110 | print("Loading Weights Warning.") 111 | 112 | model.load_state_dict(weights_c) 113 | feature_extractor = create_model(stage1_weights=True, dataset=args.dataset, log_dir=model_dict[args.dataset]) 114 | weight_f = weights['state_dict']['feature'] 115 | feature_extractor.load_state_dict(weight_f) 116 | feature_extractor = feature_extractor.cuda() 117 | feature_extractor.eval() 118 | 119 | prec1, preds, gt_labels = validate(val_loader, model, feature_extractor, nn.CrossEntropyLoss()) 120 | 121 | print('Accuracy: ', prec1) 122 | 123 | def validate(val_loader, model, feature_extractor, criterion): 124 | """Perform validation on the validation set""" 125 | batch_time = AverageMeter() 126 | losses = AverageMeter() 127 | top1 = AverageMeter() 128 | 129 | true_labels = [] 130 | preds = [] 131 | 132 | torch.cuda.empty_cache() 133 | model.eval() 134 | end = time.time() 135 | for i, (input, target) in enumerate(val_loader): 136 | input = input.cuda(non_blocking=True) 137 | target = target.cuda(non_blocking=True) 138 | # compute output 139 | with torch.no_grad(): 140 | feature = feature_extractor(input) 141 | output = model(feature) 142 | loss = criterion(output, target) 143 | 144 | output_numpy = output.data.cpu().numpy() 145 | preds_output = list(output_numpy.argmax(axis=1)) 146 | 147 | true_labels += list(target.data.cpu().numpy()) 148 | preds += preds_output 149 | 150 | 151 | # measure accuracy and record loss 152 | prec1 = accuracy(output.data, target, topk=(1,))[0] 153 | losses.update(loss.data.item(), input.size(0)) 154 | top1.update(prec1.item(), input.size(0)) 155 | 156 | # measure elapsed time 157 | batch_time.update(time.time() - end) 158 | end = time.time() 159 | 160 | if i % args.print_freq == 0: 161 | print('Test: [{0}/{1}]\t' 162 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 163 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 164 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 165 | i, len(val_loader), batch_time=batch_time, loss=losses, 166 | top1=top1)) 167 | 168 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 169 | # log to TensorBoard 170 | # import pdb; pdb.set_trace() 171 | 172 | return top1.avg, preds, true_labels 173 | 174 | 175 | 176 | def to_var(x, requires_grad=True): 177 | if torch.cuda.is_available(): 178 | x = x.cuda() 179 | return Variable(x, requires_grad=requires_grad) 180 | 181 | 182 | class AverageMeter(object): 183 | """Computes and stores the average and current value""" 184 | 185 | def __init__(self): 186 | self.reset() 187 | 188 | def reset(self): 189 | self.val = 0 190 | self.avg = 0 191 | self.sum = 0 192 | self.count = 0 193 | 194 | def update(self, val, n=1): 195 | self.val = val 196 | self.sum += val * n 197 | self.count += n 198 | self.avg = self.sum / self.count 199 | 200 | def accuracy(output, target, topk=(1,)): 201 | """Computes the precision@k for the specified values of k""" 202 | maxk = max(topk) 203 | batch_size = target.size(0) 204 | 205 | _, pred = output.topk(maxk, 1, True, True) 206 | pred = pred.t() 207 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 208 | 209 | res = [] 210 | for k in topk: 211 | correct_k = correct[:k].view(-1).float().sum(0) 212 | res.append(correct_k.mul_(100.0 / batch_size)) 213 | return res 214 | 215 | 216 | def save_checkpoint(args, state, is_best, epoch): 217 | filename = 'checkpoint/' + 'train_' + str(args.dataset) + '/' + str(args.lr) + '_' + str(args.batch_size) + '_' + str(args.meta_lr) + 'epoch' + str(epoch) + '_ckpt.pth.tar' 218 | file_root, _ = os.path.split(filename) 219 | if not os.path.exists(file_root): 220 | os.makedirs(file_root) 221 | torch.save(state, filename) 222 | 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /ImageNet_iNat/train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import argparse 5 | import random 6 | import copy 7 | import torch 8 | import torchvision 9 | import numpy as np 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import torchvision.transforms as transforms 13 | from data_utils import * 14 | # import resnet 15 | from dataloader import load_data_distributed 16 | import shutil 17 | from ResNet import * 18 | import loss 19 | import multiprocessing 20 | import torch.nn.parallel 21 | import torch.nn as nn 22 | from collections import Counter 23 | import time 24 | parser = argparse.ArgumentParser(description='Imbalanced Example') 25 | parser.add_argument('--dataset', default='iNaturalist18', type=str, 26 | help='dataset') 27 | parser.add_argument('--data_root', default='/data1/TL/data/iNaturalist18', type=str) 28 | parser.add_argument('--batch_size', type=int, default=32, metavar='N', 29 | help='input batch size for training (default: 64)') 30 | parser.add_argument('--num_classes', type=int, default=8142) 31 | parser.add_argument('--num_meta', type=int, default=10, 32 | help='The number of meta data for each class.') 33 | parser.add_argument('--test_batch_size', type=int, default=512, metavar='N', 34 | help='input batch size for testing (default: 100)') 35 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 36 | help='number of epochs to train') 37 | parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float, 38 | help='initial learning rate') 39 | parser.add_argument('--workers', default=16, type=int) 40 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 41 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 42 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 43 | help='weight decay (default: 5e-4)') 44 | parser.add_argument('--no-cuda', action='store_true', default=False, 45 | help='disables CUDA training') 46 | parser.add_argument('--split', type=int, default=1000) 47 | parser.add_argument('--seed', type=int, default=42, metavar='S', 48 | help='random seed (default: 42)') 49 | parser.add_argument('--print-freq', '-p', default=1000, type=int, 50 | help='print frequency (default: 10)') 51 | parser.add_argument('--gpu', default=None, type=int) 52 | parser.add_argument('--lam', default=0.25, type=float) 53 | parser.add_argument('--local_rank', default=0, type=int) 54 | parser.add_argument('--meta_lr', default=0.1, type=float) 55 | 56 | args = parser.parse_args() 57 | # print(args) 58 | for arg in vars(args): 59 | print("{}={}".format(arg, getattr(args, arg))) 60 | 61 | 62 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 63 | 64 | kwargs = {'num_workers': 1, 'pin_memory': True} 65 | use_cuda = not args.no_cuda and torch.cuda.is_available() 66 | 67 | cudnn.benchmark = True 68 | cudnn.enabled = True 69 | torch.manual_seed(args.seed) 70 | device = torch.device("cuda" if use_cuda else "cpu") 71 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 72 | 73 | print(f'num_gpus: {num_gpus}') 74 | args.distributed = num_gpus > 1 75 | print("ditributed: {args.distributed}") 76 | if args.distributed: 77 | torch.cuda.set_device(args.local_rank) 78 | #torch.distributed.init_process_group(backend="nccl", init_method="env://") 79 | torch.distributed.init_process_group(backend="nccl") 80 | args.batch_size = int(args.batch_size / num_gpus) 81 | 82 | 83 | ######### ImageNet dataset 84 | splits = ["train", "val", "test"] 85 | if args.dataset == 'ImageNet_LT': 86 | train_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="train", batch_size=args.batch_size, 87 | num_workers=args.workers, test_open=False, shuffle=False) 88 | val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="test", batch_size=args.test_batch_size, 89 | num_workers=args.workers, test_open=False, shuffle=False) 90 | 91 | meta_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="val", batch_size=args.batch_size, num_workers=args.workers, test_open=False, shuffle=False) 92 | 93 | else: 94 | train_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="train", batch_size=args.batch_size, 95 | num_workers=args.workers, test_open=False, shuffle=False) 96 | val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="val", batch_size=args.test_batch_size, 97 | num_workers=args.workers, test_open=False, shuffle=False) 98 | meta_set = train_set 99 | 100 | if args.dataset == 'iNaturalist17': 101 | meta_set, _ = build_dataset(meta_set, 5, args.num_classes) 102 | elif args.dataset == 'iNaturalist18': 103 | meta_set, _ = build_dataset(meta_set, 2, args.num_classes) 104 | else: 105 | meta_set, _ = build_dataset(meta_set, 10, args.num_classes) 106 | 107 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) 108 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=0, 109 | pin_memory=True, sampler=train_sampler) 110 | val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, num_workers=0, pin_memory=True) 111 | meta_sampler = torch.utils.data.distributed.DistributedSampler(meta_set) 112 | meta_loader = torch.utils.data.DataLoader(meta_set, batch_size=args.batch_size, shuffle=(meta_sampler is None), num_workers=0, 113 | pin_memory=True, sampler=meta_sampler) 114 | 115 | 116 | np.random.seed(42) 117 | random.seed(42) 118 | torch.manual_seed(args.seed) 119 | classe_labels = range(args.num_classes) 120 | 121 | data_list = {} 122 | data_list_num = [] 123 | num = Counter(train_loader.dataset.labels) 124 | data_list_num = [0] * args.num_classes 125 | for key in num: 126 | data_list_num[key] = num[key] 127 | 128 | beta = 0.9999 129 | effective_num = 1.0 - np.power(beta, data_list_num) 130 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 131 | per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(data_list_num) 132 | per_cls_weights = torch.FloatTensor(per_cls_weights).cuda() 133 | 134 | model_dict = {"ImageNet_LT": "models/resnet50_uniform_e90.pth", 135 | "iNaturalist18": "models/iNat18/resnet50_uniform_e200.pth"} 136 | 137 | 138 | def main(): 139 | global args 140 | args = parser.parse_args() 141 | 142 | cudnn.benchmark = True 143 | print(torch.cuda.is_available()) 144 | print(torch.cuda.device_count()) 145 | print(f'local_rank: {args.local_rank}') 146 | model = FCModel(2048, args.num_classes) 147 | model = model.cuda() 148 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 149 | weights = torch.load(model_dict[args.dataset], map_location=torch.device("cpu")) 150 | weights = weights['state_dict_best']['classifier'] 151 | weights = {k: weights['module.' + k] for k in model.module.state_dict()} 152 | for k in model.module.state_dict(): 153 | if k not in weights: 154 | print("Pretrained Weights Warning.") 155 | 156 | model.module.load_state_dict(weights) 157 | feature_extractor = create_model(stage1_weights=True, dataset=args.dataset, log_dir=model_dict[args.dataset]) 158 | feature_extractor = feature_extractor.cuda() 159 | feature_extractor.eval() 160 | 161 | 162 | torch.autograd.set_detect_anomaly(True) 163 | torch.distributed.barrier() 164 | 165 | optimizer_a = torch.optim.SGD(model.module.parameters(), args.lr, 166 | momentum=args.momentum, nesterov=args.nesterov, 167 | weight_decay=args.weight_decay) 168 | 169 | criterion = loss.Loss_meta(2048, args.num_classes) 170 | for epoch in range(args.epochs): 171 | ratio = args.lam * float(epoch) / float(args.epochs) 172 | train_meta(train_loader, model, feature_extractor, optimizer_a, epoch, criterion, ratio) 173 | 174 | if args.local_rank == 0: 175 | save_checkpoint(args, { 176 | 'epoch': epoch + 1, 177 | 'state_dict': {'feature': feature_extractor.state_dict(), 'classifier': model.module.state_dict()}, 178 | 'optimizer' : optimizer_a.state_dict(), 179 | }, False, epoch) 180 | 181 | def train_meta(train_loader, model, feature_extractor, optimizer_a, epoch, criterion, ratio): 182 | """Experimenting how to train stably in stage-2""" 183 | batch_time = AverageMeter() 184 | losses = AverageMeter() 185 | meta_losses = AverageMeter() 186 | top1 = AverageMeter() 187 | meta_top1 = AverageMeter() 188 | model.train() 189 | weights = torch.tensor(per_cls_weights).float() 190 | for i, (input, target) in enumerate(train_loader): 191 | 192 | input_var = input.cuda(non_blocking=True) 193 | target_var = target.cuda(non_blocking=True) 194 | cv = criterion.get_cv() 195 | cv_var = to_var(cv) 196 | 197 | meta_model = FCMeta(2048, args.num_classes) 198 | meta_model.load_state_dict(model.module.state_dict()) 199 | meta_model.cuda() 200 | 201 | with torch.no_grad(): 202 | feat_hat = feature_extractor(input_var) 203 | y_f_hat = meta_model(feat_hat) 204 | cls_loss_meta = criterion(list(meta_model.fc.named_leaves())[0][1], feat_hat, y_f_hat, target_var, ratio, 205 | weights, cv_var, "none") 206 | meta_model.zero_grad() 207 | grads = torch.autograd.grad(cls_loss_meta, (meta_model.params()), create_graph=True) 208 | meta_lr = args.lr 209 | meta_model.fc.update_params(meta_lr, source_params=grads) 210 | 211 | input_val, target_val = next(iter(meta_loader)) 212 | input_val_var = input_val.cuda(non_blocking=True) 213 | target_val_var = target_val.cuda(non_blocking=True) 214 | 215 | with torch.no_grad(): 216 | feature_val = feature_extractor(input_val_var) 217 | y_val = meta_model(feature_val) 218 | cls_meta = F.cross_entropy(y_val, target_val_var) 219 | grad_cv = torch.autograd.grad(cls_meta, cv_var, only_inputs=True)[0] 220 | new_cv = cv - args.meta_lr * grad_cv 221 | 222 | del grad_cv, grads, meta_model 223 | with torch.no_grad(): 224 | features = feature_extractor(input_var) 225 | predicts = model(features) 226 | cls_loss = criterion(list(model.module.fc.parameters())[0], features, predicts, target_var, ratio, weights, new_cv.detach(), "update") 227 | 228 | prec_train = accuracy(predicts.data, target_var.data, topk=(1,))[0] 229 | 230 | losses.update(cls_loss.item(), input.size(0)) 231 | top1.update(prec_train.item(), input.size(0)) 232 | 233 | optimizer_a.zero_grad() 234 | cls_loss.backward() 235 | optimizer_a.step() 236 | 237 | if i % args.print_freq == 0: 238 | print('Epoch: [{0}][{1}/{2}]\t' 239 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 240 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 241 | epoch, i, len(train_loader), 242 | loss=losses,top1=top1)) 243 | 244 | 245 | 246 | def validate(val_loader, model, feature_extractor, criterion, epoch, local_rank, distributed): 247 | """Perform validation on the validation set""" 248 | batch_time = AverageMeter() 249 | losses = AverageMeter() 250 | top1 = AverageMeter() 251 | 252 | # switch to evaluate mode 253 | 254 | true_labels = [] 255 | preds = [] 256 | if distributed: 257 | model = model.module 258 | torch.cuda.empty_cache() 259 | 260 | model.eval() 261 | end = time.time() 262 | for i, (input, target) in enumerate(val_loader): 263 | input = input.cuda(non_blocking=True) 264 | target = target.cuda(non_blocking=True) 265 | # compute output 266 | with torch.no_grad(): 267 | feature = feature_extractor(input) 268 | output = model(feature) 269 | loss = criterion(output, target) 270 | 271 | output_numpy = output.data.cpu().numpy() 272 | preds_output = list(output_numpy.argmax(axis=1)) 273 | 274 | true_labels += list(target.data.cpu().numpy()) 275 | preds += preds_output 276 | 277 | # measure accuracy and record loss 278 | prec1 = accuracy(output.data, target, topk=(1,))[0] 279 | losses.update(loss.data.item(), input.size(0)) 280 | top1.update(prec1.item(), input.size(0)) 281 | 282 | # measure elapsed time 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | if i % args.print_freq == 0 and local_rank==0: 287 | print('Test: [{0}/{1}]\t' 288 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 289 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 290 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 291 | i, len(val_loader), batch_time=batch_time, loss=losses, 292 | top1=top1)) 293 | 294 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 295 | return top1.avg, preds, true_labels 296 | 297 | 298 | def to_var(x, requires_grad=True): 299 | if torch.cuda.is_available(): 300 | x = x.cuda() 301 | return Variable(x, requires_grad=requires_grad) 302 | 303 | 304 | class AverageMeter(object): 305 | """Computes and stores the average and current value""" 306 | 307 | def __init__(self): 308 | self.reset() 309 | 310 | def reset(self): 311 | self.val = 0 312 | self.avg = 0 313 | self.sum = 0 314 | self.count = 0 315 | 316 | def update(self, val, n=1): 317 | self.val = val 318 | self.sum += val * n 319 | self.count += n 320 | self.avg = self.sum / self.count 321 | 322 | 323 | def accuracy(output, target, topk=(1,)): 324 | """Computes the precision@k for the specified values of k""" 325 | maxk = max(topk) 326 | batch_size = target.size(0) 327 | 328 | _, pred = output.topk(maxk, 1, True, True) 329 | pred = pred.t() 330 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 331 | 332 | res = [] 333 | for k in topk: 334 | correct_k = correct[:k].view(-1).float().sum(0) 335 | res.append(correct_k.mul_(100.0 / batch_size)) 336 | return res 337 | 338 | 339 | def save_checkpoint(args, state, is_best, epoch): 340 | filename = 'checkpoint/' + 'train_' + str(args.dataset) + '/' + str(args.lr) + '_' + str(args.batch_size) + '_' + str(args.meta_lr) + 'epoch' + str(epoch) + '_ckpt.pth.tar' 341 | file_root, _ = os.path.split(filename) 342 | if not os.path.exists(file_root): 343 | os.makedirs(file_root) 344 | torch.save(state, filename) 345 | 346 | 347 | if __name__ == '__main__': 348 | main() 349 | -------------------------------------------------------------------------------- /ImageNet_iNat/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from sklearn.metrics import f1_score 5 | import torch.nn.functional as F 6 | import importlib 7 | 8 | 9 | def source_import(file_path): 10 | """This function imports python module directly from source code using importlib""" 11 | spec = importlib.util.spec_from_file_location('', file_path) 12 | module = importlib.util.module_from_spec(spec) 13 | spec.loader.exec_module(module) 14 | return module 15 | 16 | 17 | def batch_show(inp, title=None): 18 | """Imshow for Tensor.""" 19 | inp = inp.numpy().transpose((1, 2, 0)) 20 | mean = np.array([0.485, 0.456, 0.406]) 21 | std = np.array([0.229, 0.224, 0.225]) 22 | inp = std * inp + mean 23 | inp = np.clip(inp, 0, 1) 24 | plt.figure(figsize=(20,20)) 25 | plt.imshow(inp) 26 | if title is not None: 27 | plt.title(title) 28 | 29 | 30 | def print_write(print_str, log_file): 31 | print(*print_str) 32 | if log_file is None: 33 | return 34 | with open(log_file, 'a') as f: 35 | print(*print_str, file=f) 36 | 37 | 38 | def init_weights(model, weights_path, caffe=False, classifier=False): 39 | """Initialize weights""" 40 | print('Pretrained %s weights path: %s' % ('classifier' if classifier else 'feature model', weights_path)) 41 | weights = torch.load(weights_path, map_location=torch.device("cpu")) 42 | weights_c = weights['state_dict_best']['classifier'] 43 | weights_f = weights['state_dict_best']['feat_model'] 44 | if not classifier: 45 | if caffe: 46 | weights = {k: weights[k] if k in weights else model.state_dict()[k] 47 | for k in model.state_dict()} 48 | else: 49 | weights = weights['state_dict_best']['feat_model'] 50 | weights = {k: weights['module.' + k] for k in model.state_dict()} 51 | else: 52 | weights = weights['state_dict_best']['classifier'] 53 | weights = {k: weights['module.' + k] if 'module.' + k in weights else model.state_dict()[k] 54 | for k in model.state_dict()} 55 | model.load_state_dict(weights) 56 | return model 57 | 58 | 59 | def shot_acc(preds, labels, train_data, many_shot_thr=100, low_shot_thr=20, acc_per_cls=False): 60 | if isinstance(train_data, np.ndarray): 61 | training_labels = np.array(train_data).astype(int) 62 | else: 63 | training_labels = np.array(train_data.dataset.labels).astype(int) 64 | 65 | if isinstance(preds, torch.Tensor): 66 | preds = preds.detach().cpu().numpy() 67 | labels = labels.detach().cpu().numpy() 68 | elif isinstance(preds, np.ndarray): 69 | pass 70 | else: 71 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 72 | train_class_count = [] 73 | test_class_count = [] 74 | class_correct = [] 75 | for l in np.unique(labels): 76 | train_class_count.append(len(training_labels[training_labels == l])) 77 | test_class_count.append(len(labels[labels == l])) 78 | class_correct.append((preds[labels == l] == labels[labels == l]).sum()) 79 | 80 | many_shot = [] 81 | median_shot = [] 82 | low_shot = [] 83 | for i in range(len(train_class_count)): 84 | if train_class_count[i] > many_shot_thr: 85 | many_shot.append((class_correct[i] / test_class_count[i])) 86 | elif train_class_count[i] < low_shot_thr: 87 | low_shot.append((class_correct[i] / test_class_count[i])) 88 | else: 89 | median_shot.append((class_correct[i] / test_class_count[i])) 90 | 91 | if len(many_shot) == 0: 92 | many_shot.append(0) 93 | if len(median_shot) == 0: 94 | median_shot.append(0) 95 | if len(low_shot) == 0: 96 | low_shot.append(0) 97 | 98 | if acc_per_cls: 99 | class_accs = [c / cnt for c, cnt in zip(class_correct, test_class_count)] 100 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot), class_accs 101 | else: 102 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 103 | 104 | 105 | def weighted_shot_acc(preds, labels, ws, train_data, many_shot_thr=100, low_shot_thr=20): 106 | training_labels = np.array(train_data.dataset.labels).astype(int) 107 | 108 | if isinstance(preds, torch.Tensor): 109 | preds = preds.detach().cpu().numpy() 110 | labels = labels.detach().cpu().numpy() 111 | elif isinstance(preds, np.ndarray): 112 | pass 113 | else: 114 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 115 | train_class_count = [] 116 | test_class_count = [] 117 | class_correct = [] 118 | for l in np.unique(labels): 119 | train_class_count.append(len(training_labels[training_labels == l])) 120 | test_class_count.append(ws[labels==l].sum()) 121 | class_correct.append(((preds[labels==l] == labels[labels==l]) * ws[labels==l]).sum()) 122 | 123 | many_shot = [] 124 | median_shot = [] 125 | low_shot = [] 126 | for i in range(len(train_class_count)): 127 | if train_class_count[i] > many_shot_thr: 128 | many_shot.append((class_correct[i] / test_class_count[i])) 129 | elif train_class_count[i] < low_shot_thr: 130 | low_shot.append((class_correct[i] / test_class_count[i])) 131 | else: 132 | median_shot.append((class_correct[i] / test_class_count[i])) 133 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 134 | 135 | 136 | def F_measure(preds, labels, openset=False, theta=None): 137 | if openset: 138 | # f1 score for openset evaluation 139 | true_pos = 0. 140 | false_pos = 0. 141 | false_neg = 0. 142 | 143 | for i in range(len(labels)): 144 | true_pos += 1 if preds[i] == labels[i] and labels[i] != -1 else 0 145 | false_pos += 1 if preds[i] != labels[i] and labels[i] != -1 and preds[i] != -1 else 0 146 | false_neg += 1 if preds[i] != labels[i] and labels[i] == -1 else 0 147 | 148 | precision = true_pos / (true_pos + false_pos) 149 | recall = true_pos / (true_pos + false_neg) 150 | return 2 * ((precision * recall) / (precision + recall + 1e-12)) 151 | else: 152 | # Regular f1 score 153 | return f1_score(labels.detach().cpu().numpy(), preds.detach().cpu().numpy(), average='macro') 154 | 155 | 156 | def mic_acc_cal(preds, labels): 157 | if isinstance(labels, tuple): 158 | assert len(labels) == 3 159 | targets_a, targets_b, lam = labels 160 | acc_mic_top1 = (lam * preds.eq(targets_a.data).cpu().sum().float() \ 161 | + (1 - lam) * preds.eq(targets_b.data).cpu().sum().float()) / len(preds) 162 | else: 163 | acc_mic_top1 = (preds == labels).sum().item() / len(labels) 164 | return acc_mic_top1 165 | 166 | 167 | def weighted_mic_acc_cal(preds, labels, ws): 168 | acc_mic_top1 = ws[preds == labels].sum() / ws.sum() 169 | return acc_mic_top1 170 | 171 | 172 | def class_count(data): 173 | labels = np.array(data.dataset.labels) 174 | class_data_num = [] 175 | for l in np.unique(labels): 176 | class_data_num.append(len(labels[labels == l])) 177 | return class_data_num 178 | 179 | 180 | def torch2numpy(x): 181 | if isinstance(x, torch.Tensor): 182 | return x.detach().cpu().numpy() 183 | elif isinstance(x, (list, tuple)): 184 | return tuple([torch2numpy(xi) for xi in x]) 185 | else: 186 | return x 187 | 188 | 189 | def logits2score(logits, labels): 190 | scores = F.softmax(logits, dim=1) 191 | score = scores.gather(1, labels.view(-1, 1)) 192 | score = score.squeeze().cpu().numpy() 193 | return score 194 | 195 | 196 | def logits2entropy(logits): 197 | scores = F.softmax(logits, dim=1) 198 | scores = scores.cpu().numpy() + 1e-30 199 | ent = -scores * np.log(scores) 200 | ent = np.sum(ent, 1) 201 | return ent 202 | 203 | 204 | def logits2CE(logits, labels): 205 | scores = F.softmax(logits, dim=1) 206 | score = scores.gather(1, labels.view(-1, 1)) 207 | score = score.squeeze().cpu().numpy() + 1e-30 208 | ce = -np.log(score) 209 | return ce 210 | 211 | 212 | def get_priority(ptype, logits, labels): 213 | if ptype == 'score': 214 | ws = 1 - logits2score(logits, labels) 215 | elif ptype == 'entropy': 216 | ws = logits2entropy(logits) 217 | elif ptype == 'CE': 218 | ws = logits2CE(logits, labels) 219 | 220 | return ws 221 | 222 | 223 | def get_value(oldv, newv): 224 | if newv is not None: 225 | return newv 226 | else: 227 | return oldv 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 BIT-DA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MetaSAug_LDAM_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import random 5 | import copy 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torchvision.transforms as transforms 12 | from data_utils import * 13 | from resnet import * 14 | import shutil 15 | from loss import * 16 | 17 | parser = argparse.ArgumentParser(description='Imbalanced Example') 18 | parser.add_argument('--dataset', default='cifar100', type=str, 19 | help='dataset (cifar10 or cifar100[default])') 20 | parser.add_argument('--batch-size', type=int, default=100, metavar='N', 21 | help='input batch size for training (default: 64)') 22 | parser.add_argument('--num_classes', type=int, default=100) 23 | parser.add_argument('--num_meta', type=int, default=10, 24 | help='The number of meta data for each class.') 25 | parser.add_argument('--imb_factor', type=float, default=0.005) 26 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 27 | help='input batch size for testing (default: 100)') 28 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 29 | help='number of epochs to train') 30 | parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float, 31 | help='initial learning rate') 32 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 33 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 34 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 35 | help='weight decay (default: 5e-4)') 36 | parser.add_argument('--no-cuda', action='store_true', default=False, 37 | help='disables CUDA training') 38 | parser.add_argument('--split', type=int, default=1000) 39 | parser.add_argument('--seed', type=int, default=42, metavar='S', 40 | help='random seed (default: 42)') 41 | parser.add_argument('--print-freq', '-p', default=100, type=int, 42 | help='print frequency (default: 10)') 43 | parser.add_argument('--lam', default=0.25, type=float, help='[0.25, 0.5, 0.75, 1.0]') 44 | parser.add_argument('--gpu', default=0, type=int) 45 | parser.add_argument('--meta_lr', default=0.1, type=float) 46 | parser.add_argument('--save_name', default='name', type=str) 47 | parser.add_argument('--idx', default='0', type=str) 48 | 49 | 50 | args = parser.parse_args() 51 | for arg in vars(args): 52 | print("{}={}".format(arg, getattr(args, arg))) 53 | 54 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 55 | os.environ["CUDA_VISIBLE_DEVICES"]= str(args.gpu) 56 | kwargs = {'num_workers': 1, 'pin_memory': False} 57 | use_cuda = not args.no_cuda and torch.cuda.is_available() 58 | 59 | torch.manual_seed(args.seed) 60 | device = torch.device("cuda" if use_cuda else "cpu") 61 | 62 | train_data_meta, train_data, test_dataset = build_dataset(args.dataset, args.num_meta) 63 | 64 | print(f'length of meta dataset:{len(train_data_meta)}') 65 | print(f'length of train dataset: {len(train_data)}') 66 | 67 | train_loader = torch.utils.data.DataLoader( 68 | train_data, batch_size=args.batch_size, shuffle=True, **kwargs) 69 | 70 | np.random.seed(42) 71 | random.seed(42) 72 | torch.manual_seed(args.seed) 73 | classe_labels = range(args.num_classes) 74 | 75 | data_list = {} 76 | 77 | 78 | for j in range(args.num_classes): 79 | data_list[j] = [i for i, label in enumerate(train_loader.dataset.targets) if label == j] 80 | 81 | 82 | img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, args.num_meta*args.num_classes) 83 | print(img_num_list) 84 | print(sum(img_num_list)) 85 | 86 | im_data = {} 87 | idx_to_del = [] 88 | for cls_idx, img_id_list in data_list.items(): 89 | random.shuffle(img_id_list) 90 | img_num = img_num_list[int(cls_idx)] 91 | im_data[cls_idx] = img_id_list[img_num:] 92 | idx_to_del.extend(img_id_list[img_num:]) 93 | 94 | print(len(idx_to_del)) 95 | imbalanced_train_dataset = copy.deepcopy(train_data) 96 | imbalanced_train_dataset.targets = np.delete(train_loader.dataset.targets, idx_to_del, axis=0) 97 | imbalanced_train_dataset.data = np.delete(train_loader.dataset.data, idx_to_del, axis=0) 98 | print(len(imbalanced_train_dataset)) 99 | imbalanced_train_loader = torch.utils.data.DataLoader( 100 | imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) 101 | 102 | validation_loader = torch.utils.data.DataLoader( 103 | train_data_meta, batch_size=args.batch_size, shuffle=True, **kwargs) 104 | 105 | test_loader = torch.utils.data.DataLoader( 106 | test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs) 107 | 108 | best_prec1 = 0 109 | 110 | beta = 0.9999 111 | effective_num = 1.0 - np.power(beta, img_num_list) 112 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 113 | per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(img_num_list) 114 | per_cls_weights = torch.FloatTensor(per_cls_weights).cuda() 115 | weights = torch.tensor(per_cls_weights).float() 116 | 117 | def main(): 118 | global args, best_prec1 119 | args = parser.parse_args() 120 | 121 | model = build_model() 122 | optimizer_a = torch.optim.SGD(model.params(), args.lr, 123 | momentum=args.momentum, nesterov=args.nesterov, 124 | weight_decay=args.weight_decay) 125 | 126 | cudnn.benchmark = True 127 | 128 | criterion = LDAM_meta(64, args.dataset == "cifar10" and 10 or 100, cls_num_list=img_num_list, 129 | max_m=0.5, s=30) 130 | 131 | for epoch in range(args.epochs): 132 | adjust_learning_rate(optimizer_a, epoch + 1) 133 | 134 | ratio = args.lam * float(epoch) / float(args.epochs) 135 | if epoch < 160: 136 | train(imbalanced_train_loader, model, optimizer_a, epoch) 137 | 138 | else: 139 | train_MetaSAug(imbalanced_train_loader, validation_loader, model, optimizer_a, epoch, criterion, ratio) 140 | 141 | 142 | prec1, preds, gt_labels = validate(test_loader, model, nn.CrossEntropyLoss().cuda(), epoch) 143 | 144 | is_best = prec1 > best_prec1 145 | best_prec1 = max(prec1, best_prec1) 146 | 147 | # save_checkpoint(args, { 148 | # 'epoch': epoch + 1, 149 | # 'state_dict': model.state_dict(), 150 | # 'best_acc1': best_prec1, 151 | # 'optimizer': optimizer_a.state_dict(), 152 | # }, is_best) 153 | 154 | print('Best accuracy: ', best_prec1) 155 | 156 | 157 | def train(train_loader, model, optimizer_a, epoch): 158 | 159 | losses = AverageMeter() 160 | top1 = AverageMeter() 161 | model.train() 162 | 163 | for i, (input, target) in enumerate(train_loader): 164 | input_var = to_var(input, requires_grad=False) 165 | target_var = to_var(target, requires_grad=False) 166 | 167 | _, y_f = model(input_var) 168 | del _ 169 | cost_w = F.cross_entropy(y_f, target_var, reduce=False) 170 | l_f = torch.mean(cost_w) 171 | prec_train = accuracy(y_f.data, target_var.data, topk=(1,))[0] 172 | 173 | 174 | losses.update(l_f.item(), input.size(0)) 175 | top1.update(prec_train.item(), input.size(0)) 176 | 177 | optimizer_a.zero_grad() 178 | l_f.backward() 179 | optimizer_a.step() 180 | 181 | if i % args.print_freq == 0: 182 | print('Epoch: [{0}][{1}/{2}]\t' 183 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 184 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 185 | epoch, i, len(train_loader), 186 | loss=losses,top1=top1)) 187 | 188 | 189 | 190 | def train_MetaSAug(train_loader, validation_loader, model,optimizer_a, epoch, criterion, ratio): 191 | 192 | losses = AverageMeter() 193 | top1 = AverageMeter() 194 | model.train() 195 | 196 | for i, (input, target) in enumerate(train_loader): 197 | 198 | input_var = to_var(input, requires_grad=False) 199 | target_var = to_var(target, requires_grad=False) 200 | 201 | cv = criterion.get_cv() 202 | cv_var = to_var(cv) 203 | 204 | meta_model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 205 | meta_model.load_state_dict(model.state_dict()) 206 | meta_model.cuda() 207 | 208 | feat_hat, y_f_hat = meta_model(input_var) 209 | cls_loss_meta = criterion(meta_model.linear, feat_hat, y_f_hat, target_var, ratio, 210 | weights, cv_var, "none") 211 | meta_model.zero_grad() 212 | 213 | grads = torch.autograd.grad(cls_loss_meta, (meta_model.params()), create_graph=True) 214 | meta_lr = args.lr * ((0.01 ** int(epoch >= 160)) * (0.01 ** int(epoch >= 180))) 215 | meta_model.update_params(meta_lr, source_params=grads) 216 | 217 | input_val, target_val = next(iter(validation_loader)) 218 | input_val_var = to_var(input_val, requires_grad=False) 219 | target_val_var = to_var(target_val, requires_grad=False) 220 | 221 | _, y_val = meta_model(input_val_var) 222 | cls_meta = F.cross_entropy(y_val, target_val_var) 223 | grad_cv = torch.autograd.grad(cls_meta, cv_var, only_inputs=True)[0] 224 | new_cv = cv_var - args.meta_lr * grad_cv 225 | 226 | del grad_cv, grads 227 | 228 | #model.train() 229 | features, predicts = model(input_var) 230 | cls_loss = criterion(model.linear, features, predicts, target_var, ratio, weights, new_cv, "update") 231 | 232 | prec_train = accuracy(predicts.data, target_var.data, topk=(1,))[0] 233 | 234 | 235 | losses.update(cls_loss.item(), input.size(0)) 236 | top1.update(prec_train.item(), input.size(0)) 237 | 238 | optimizer_a.zero_grad() 239 | cls_loss.backward() 240 | optimizer_a.step() 241 | 242 | if i % args.print_freq == 0: 243 | print('Epoch: [{0}][{1}/{2}]\t' 244 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 245 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 246 | epoch, i, len(train_loader), 247 | loss=losses,top1=top1)) 248 | 249 | 250 | def validate(val_loader, model, criterion, epoch): 251 | batch_time = AverageMeter() 252 | losses = AverageMeter() 253 | top1 = AverageMeter() 254 | 255 | model.eval() 256 | 257 | true_labels = [] 258 | preds = [] 259 | 260 | end = time.time() 261 | for i, (input, target) in enumerate(val_loader): 262 | target = target.cuda() 263 | input = input.cuda() 264 | input_var = torch.autograd.Variable(input) 265 | target_var = torch.autograd.Variable(target) 266 | 267 | with torch.no_grad(): 268 | _, output = model(input_var) 269 | 270 | output_numpy = output.data.cpu().numpy() 271 | preds_output = list(output_numpy.argmax(axis=1)) 272 | 273 | true_labels += list(target_var.data.cpu().numpy()) 274 | preds += preds_output 275 | 276 | 277 | prec1 = accuracy(output.data, target, topk=(1,))[0] 278 | top1.update(prec1.item(), input.size(0)) 279 | 280 | batch_time.update(time.time() - end) 281 | end = time.time() 282 | 283 | if i % args.print_freq == 0: 284 | print('Test: [{0}/{1}]\t' 285 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 286 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 287 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 288 | i, len(val_loader), batch_time=batch_time, loss=losses, 289 | top1=top1)) 290 | 291 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 292 | 293 | return top1.avg, preds, true_labels 294 | 295 | 296 | def build_model(): 297 | model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 298 | 299 | if torch.cuda.is_available(): 300 | model.cuda() 301 | torch.backends.cudnn.benchmark = True 302 | 303 | 304 | return model 305 | 306 | def to_var(x, requires_grad=True): 307 | if torch.cuda.is_available(): 308 | x = x.cuda() 309 | return Variable(x, requires_grad=requires_grad) 310 | 311 | 312 | class AverageMeter(object): 313 | 314 | def __init__(self): 315 | self.reset() 316 | 317 | def reset(self): 318 | self.val = 0 319 | self.avg = 0 320 | self.sum = 0 321 | self.count = 0 322 | 323 | def update(self, val, n=1): 324 | self.val = val 325 | self.sum += val * n 326 | self.count += n 327 | self.avg = self.sum / self.count 328 | 329 | 330 | def adjust_learning_rate(optimizer, epoch): 331 | lr = args.lr * ((0.01 ** int(epoch >= 160)) * (0.01 ** int(epoch >= 180))) 332 | for param_group in optimizer.param_groups: 333 | param_group['lr'] = lr 334 | 335 | 336 | def accuracy(output, target, topk=(1,)): 337 | maxk = max(topk) 338 | batch_size = target.size(0) 339 | 340 | _, pred = output.topk(maxk, 1, True, True) 341 | pred = pred.t() 342 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 343 | 344 | res = [] 345 | for k in topk: 346 | correct_k = correct[:k].view(-1).float().sum(0) 347 | res.append(correct_k.mul_(100.0 / batch_size)) 348 | return res 349 | 350 | 351 | def save_checkpoint(args, state, is_best): 352 | path = 'checkpoint/ours/' + args.idx + '/' 353 | if not os.path.exists(path): 354 | os.makedirs(path) 355 | filename = path + args.save_name + '_ckpt.pth.tar' 356 | if is_best: 357 | torch.save(state, filename) 358 | 359 | if __name__ == '__main__': 360 | main() 361 | -------------------------------------------------------------------------------- /MetaSAug_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import random 5 | import copy 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | import torchvision.transforms as transforms 13 | from data_utils import * 14 | from resnet import * 15 | import shutil 16 | import gc 17 | 18 | parser = argparse.ArgumentParser(description='Imbalanced Example') 19 | parser.add_argument('--checkpoint_path', default='path.pth.tar', type=str, 20 | help='the path of checkpoint') 21 | parser.add_argument('--dataset', default='cifar10', type=str) 22 | parser.add_argument('--imb_factor', default='0.1', type=float) 23 | 24 | 25 | args = parser.parse_args() 26 | print('checkpoint_path:', args.checkpoint_path) 27 | 28 | params = args.checkpoint_path.split('_') 29 | dataset = args.dataset 30 | imb_factor = args.imb_factor 31 | 32 | 33 | kwargs = {'num_workers': 4, 'pin_memory': False} 34 | use_cuda = torch.cuda.is_available() 35 | 36 | torch.manual_seed(42) 37 | 38 | print('start loading test data') 39 | train_data_meta, train_data, test_dataset = build_dataset(dataset, 10) 40 | test_loader = torch.utils.data.DataLoader( 41 | test_dataset, batch_size=100, shuffle=False, **kwargs) 42 | 43 | print('load test data successfully') 44 | 45 | best_prec1 = 0 46 | 47 | 48 | def main(): 49 | global args, best_prec1 50 | args = parser.parse_args() 51 | 52 | 53 | model = build_model() 54 | 55 | net_dict = torch.load(args.checkpoint_path) 56 | 57 | model.load_state_dict(net_dict['state_dict']) 58 | 59 | prec1, preds, gt_labels = validate( 60 | test_loader, model, nn.CrossEntropyLoss().cuda(), 0) 61 | print('Test result:\n' 62 | 'Dataset: {0}\t' 63 | 'Imb_factor: {1}\t' 64 | 'Accuracy: {2:.2f} \t' 65 | 'Error: {3:.2f} \n'.format( 66 | dataset, int(1 / imb_factor), prec1,100 - prec1)) 67 | 68 | 69 | 70 | def validate(val_loader, model, criterion, epoch): 71 | batch_time = AverageMeter() 72 | losses = AverageMeter() 73 | top1 = AverageMeter() 74 | 75 | model.eval() 76 | 77 | true_labels = [] 78 | preds = [] 79 | 80 | end = time.time() 81 | for i, (input, target) in enumerate(val_loader): 82 | target = target.cuda() 83 | input = input.cuda() 84 | input_var = torch.autograd.Variable(input) 85 | target_var = torch.autograd.Variable(target) 86 | 87 | with torch.no_grad(): 88 | _, output = model(input_var) 89 | 90 | output_numpy = output.data.cpu().numpy() 91 | preds_output = list(output_numpy.argmax(axis=1)) 92 | 93 | true_labels += list(target_var.data.cpu().numpy()) 94 | preds += preds_output 95 | 96 | prec1 = accuracy(output.data, target, topk=(1,))[0] 97 | top1.update(prec1.item(), input.size(0)) 98 | 99 | batch_time.update(time.time() - end) 100 | end = time.time() 101 | 102 | if i % 100 == 0: 103 | print('Test: [{0}/{1}]\t' 104 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 105 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 106 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 107 | i, len(val_loader), batch_time=batch_time, loss=losses, 108 | top1=top1)) 109 | 110 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 111 | 112 | return top1.avg, preds, true_labels 113 | 114 | 115 | def build_model(): 116 | model = ResNet32(dataset == 'cifar10' and 10 or 100) 117 | 118 | if torch.cuda.is_available(): 119 | model.cuda() 120 | torch.backends.cudnn.benchmark = True 121 | 122 | return model 123 | 124 | 125 | 126 | class AverageMeter(object): 127 | 128 | def __init__(self): 129 | self.reset() 130 | 131 | def reset(self): 132 | self.val = 0 133 | self.avg = 0 134 | self.sum = 0 135 | self.count = 0 136 | 137 | def update(self, val, n=1): 138 | self.val = val 139 | self.sum += val * n 140 | self.count += n 141 | self.avg = self.sum / self.count 142 | 143 | 144 | 145 | 146 | def accuracy(output, target, topk=(1,)): 147 | maxk = max(topk) 148 | batch_size = target.size(0) 149 | 150 | _, pred = output.topk(maxk, 1, True, True) 151 | pred = pred.t() 152 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 153 | 154 | res = [] 155 | for k in topk: 156 | correct_k = correct[:k].view(-1).float().sum(0) 157 | res.append(correct_k.mul_(100.0 / batch_size)) 158 | return res 159 | 160 | 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 | # MetaSAug: Meta Semantic Augmentation for Long-Tailed Visual Recognition 5 | 6 | Shuang Li, Kaixiong Gong, et al. 7 | 8 | IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2021. [[CVPR 2021 PDF](https://openaccess.thecvf.com/content/CVPR2021/papers/Li_MetaSAug_Meta_Semantic_Augmentation_for_Long-Tailed_Visual_Recognition_CVPR_2021_paper.pdf)] 9 | 10 | [![Paper](https://img.shields.io/badge/paper-arxiv.2208.01195-B31B1B.svg)](https://arxiv.org/abs/2103.12579) 11 |
12 | 13 | This repository contains the code of our CVPR 2021 work "MetaSAug: Meta Semantic Augmentation for Long-Tailed Visual Recognition". 14 | 15 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/long-tail-learning-on-cifar-100-lt-r-200)](https://paperswithcode.com/sota/long-tail-learning-on-cifar-100-lt-r-200?p=metasaug-meta-semantic-augmentation-for-long) 16 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/long-tail-learning-on-cifar-10-lt-r-200)](https://paperswithcode.com/sota/long-tail-learning-on-cifar-10-lt-r-200?p=metasaug-meta-semantic-augmentation-for-long) 17 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/long-tail-learning-on-cifar-10-lt-r-50)](https://paperswithcode.com/sota/long-tail-learning-on-cifar-10-lt-r-50?p=metasaug-meta-semantic-augmentation-for-long) 18 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/long-tail-learning-on-cifar-10-lt-r-10)](https://paperswithcode.com/sota/long-tail-learning-on-cifar-10-lt-r-10?p=metasaug-meta-semantic-augmentation-for-long) 19 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/long-tail-learning-on-cifar-100-lt-r-50)](https://paperswithcode.com/sota/long-tail-learning-on-cifar-100-lt-r-50?p=metasaug-meta-semantic-augmentation-for-long) 20 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/long-tail-learning-on-cifar-10-lt-r-100)](https://paperswithcode.com/sota/long-tail-learning-on-cifar-10-lt-r-100?p=metasaug-meta-semantic-augmentation-for-long) 21 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/metasaug-meta-semantic-augmentation-for-long/image-classification-on-inaturalist)](https://paperswithcode.com/sota/image-classification-on-inaturalist?p=metasaug-meta-semantic-augmentation-for-long) 22 | 23 | #### Abstract 24 | 25 | Real-world training data usually exhibits long-tailed distribution, where several majority classes have a significantly larger number of samples than the remaining minority classes. This imbalance degrades the performance of typical supervised learning algorithms designed for balanced training sets. In this paper, we address this issue by augmenting minority classes with a recently proposed implicit semantic data augmentation (ISDA) algorithm, which produces diversified augmented samples by translating deep features along many semantically meaningful directions. Importantly, given that ISDA estimates the classconditional statistics to obtain semantic directions, we find it ineffective to do this on minority classes due to the insufficient training data. To this end, we propose a novel approach to learn transformed semantic directions with metalearning automatically. In specific, the augmentation strategy during training is dynamically optimized, aiming to minimize the loss on a small balanced validation set, which is approximated via a meta update step. Extensive empirical results on CIFAR-LT-10/100, ImageNet-LT, and iNaturalist2017/2018 validate the effectiveness of our method 26 | 27 |

28 | drawing 29 |

30 | 31 | 32 | If you find this idea or code useful for your research, please consider citing our paper: 33 | ``` 34 | @inproceedings{li2021metasaug, 35 | title={Metasaug: Meta semantic augmentation for long-tailed visual recognition}, 36 | author={Li, Shuang and Gong, Kaixiong and Liu, Chi Harold and Wang, Yulin and Qiao, Feng and Cheng, Xinjing}, 37 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 38 | pages={5212--5221}, 39 | year={2021} 40 | } 41 | ``` 42 | 43 | ## Prerequisite 44 | 45 | - PyTorch >= 1.2.0 46 | - Python3 47 | - torchvision 48 | - PIL 49 | - argparse 50 | - numpy 51 | 52 | ## Evaluation 53 | 54 | We provide several trained models of MetaSAug for evaluation. 55 | 56 | Testing on CIFAR-LT-10/100: 57 | 58 | - `sh scripts/MetaSAug_CE_test.sh` 59 | - `sh scripts/MetaSAug_LDAM_test.sh` 60 | 61 | Testing on ImageNet and iNaturalist18: 62 | 63 | - `sh ImageNet_iNat/test.sh` 64 | 65 | The trained models are in [Google Drive](https://drive.google.com/drive/folders/1YyE4RAniebDo8KyvdobcRfS0w5ZtMAQt?usp=sharing). 66 | 67 | ## Getting Started 68 | 69 | ### Dataset 70 | - Long-tailed CIFAR10/100: The long-tailed version of CIFAR10/100. Code for coverting to long-tailed version is in [data_utils.py](https://github.com/BIT-DA/MetaSAug/blob/main/data_utils.py). 71 | - ImageNet-LT: The long-tailed version of ImageNet. [[Long-tailed annotations](https://github.com/BIT-DA/MetaSAug/tree/main/ImageNet_iNat/data)] 72 | - [iNaturalist2017](https://github.com/visipedia/inat_comp/tree/master/2017): A natural long-tailed dataset. 73 | - [iNaturalist2018](https://github.com/visipedia/inat_comp/tree/master/2018): A natural long-tailed dataset. 74 | 75 | ### Training 76 | 77 | Training on CIFAR-LT-10/100: 78 | ``` 79 | CIFAR-LT-100, MetaSAug with LDAM loss 80 | python3.6 MetaSAug_LDAM_train.py --gpu 0 --lr 0.1 --lam 0.75 --imb_factor 0.05 --dataset cifar100 --num_classes 100 --save_name MetaSAug_cifar100_LDAM_imb0.05 --idx 1 81 | ``` 82 | 83 | Or run the script: 84 | 85 | ``` 86 | sh scripts/MetaSAug_LDAM_train.sh 87 | ``` 88 | 89 | Training on ImageNet-LT: 90 | 91 | ``` 92 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 53212 train.py --lr 0.0003 --meta_lr 0.1 --workers 0 --batch_size 256 --epochs 20 --dataset ImageNet_LT --num_classes 1000 --data_root ../ImageNet 93 | ``` 94 | 95 | Or run the script: 96 | 97 | ``` 98 | sh ImageNet_iNat/scripts/train.sh 99 | ``` 100 | 101 | **Note**: Training on large scale datasets like ImageNet-LT and iNaturalist2017/2018 involves multiple gpus for faster speed. To achieve better generalizable representations, vanilla CE loss is used for training the network in the early training stage. For convenience, the training starts from the pre-trained models, e.g., [ImageNet-LT](https://dl.fbaipublicfiles.com/classifier-balancing/ImageNet_LT/models/resnet50_uniform_e90.pth), [iNat18](https://dl.fbaipublicfiles.com/classifier-balancing/iNaturalist18/models/resnet50_uniform_e200.pth) (both from project [cRT](https://github.com/facebookresearch/classifier-balancing)). 102 | 103 | ## Results and models 104 | **CIFAR-LT-10** 105 | | Model | Imb.| Top-1 Error | Download |Model | Imb.| Top-1 Error | Download | 106 | | --------- |:--------:|:-----------:|:-------------:|--------- |:--------:|:-----------:|:-------------:| 107 | | MetaSAug+LDAM | 200 |22.65 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 200 |23.11 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 108 | | MetaSAug+LDAM | 100 |19.34 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 100 |19.46 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 109 | | MetaSAug+LDAM | 50 |15.66 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 50 |15.97 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 110 | | MetaSAug+LDAM | 20 |11.90 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 20 |12.36 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 111 | | MetaSAug+LDAM | 10 |10.32 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 10 |10.56 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 112 | 113 | **CIFAR-LT-100** 114 | | Model | Imb.| Top-1 Error | Download |Model | Imb.| Top-1 Error | Download | 115 | | --------- |:--------:|:-----------:|:-------------:|--------- |:--------:|:-----------:|:-------------:| 116 | | MetaSAug+LDAM | 200 |56.91 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 200 |60.06 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 117 | | MetaSAug+LDAM | 100 |51.99 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 100 |53.13 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 118 | | MetaSAug+LDAM | 50 |47.73 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 50 |48.10 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 119 | | MetaSAug+LDAM | 20 |42.47 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 20 |42.15 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 120 | | MetaSAug+LDAM | 10 |38.72 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) |MetaSAug+CE | 10 |38.27 | [ResNet32](https://drive.google.com/drive/folders/1eKGWDXBa1jqOBWXRUVks6iZOn2YhKkET?usp=sharing) | 121 | 122 | **ImageNet-LT** 123 | | Model | Top-1 Error| Download | 124 | | --------- |:--------:|:-----------:| 125 | | MetaSAug | 52.33 | [ResNet50](https://drive.google.com/drive/folders/1HuaMsPCcR4DV1Tev9dHxd4BGU7mxJuqZ?usp=sharing)| 126 | 127 | **iNaturalist18** 128 | | Model | Top-1 Error| Download | 129 | | --------- |:--------:|:-----------:| 130 | | MetaSAug | 30.50 | [ResNet50](https://drive.google.com/drive/folders/1yQDFKDQmgxWArHNc9kvEPxMPcs2mXa6O?usp=sharing)| 131 | 132 | ## Acknowledgements 133 | Some codes in this project are adapted from [Meta-class-weight](https://github.com/abdullahjamal/Longtail_DA) and [cRT](https://github.com/facebookresearch/classifier-balancing). We thank them for their excellent projects. 134 | 135 | 136 | -------------------------------------------------------------------------------- /assets/illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BIT-DA/MetaSAug/9efe54a6a8f752671e77ef500e5c26f7bf772f77/assets/illustration.png -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim 7 | import torch.utils.data 8 | import torchvision.transforms as transforms 9 | import torchvision 10 | import numpy as np 11 | import copy 12 | 13 | np.random.seed(6) 14 | 15 | def build_dataset(dataset,num_meta): 16 | normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], 17 | std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) 18 | 19 | transform_train = transforms.Compose([ 20 | transforms.ToTensor(), 21 | transforms.Lambda(lambda x: F.pad(x.unsqueeze(0), 22 | (4, 4, 4, 4), mode='reflect').squeeze()), 23 | transforms.ToPILImage(), 24 | transforms.RandomCrop(32), 25 | transforms.RandomHorizontalFlip(), 26 | transforms.ToTensor(), 27 | normalize, 28 | ]) 29 | 30 | transform_test = transforms.Compose([ 31 | transforms.ToTensor(), 32 | normalize 33 | ]) 34 | 35 | if dataset == 'cifar10': 36 | train_dataset = torchvision.datasets.CIFAR10(root='../cifar-10', train=True, download=False, transform=transform_train) 37 | test_dataset = torchvision.datasets.CIFAR10('../cifar-10', train=False, transform=transform_test) 38 | img_num_list = [num_meta] * 10 39 | num_classes = 10 40 | 41 | if dataset == 'cifar100': 42 | train_dataset = torchvision.datasets.CIFAR100(root='../cifar-100', train=True, download=True, transform=transform_train) 43 | test_dataset = torchvision.datasets.CIFAR100('../cifar-100', train=False, transform=transform_test) 44 | img_num_list = [num_meta] * 100 45 | num_classes = 100 46 | 47 | data_list_val = {} 48 | for j in range(num_classes): 49 | data_list_val[j] = [i for i, label in enumerate(train_dataset.targets) if label == j] 50 | 51 | idx_to_meta = [] 52 | idx_to_train = [] 53 | print(img_num_list) 54 | 55 | for cls_idx, img_id_list in data_list_val.items(): 56 | np.random.shuffle(img_id_list) 57 | img_num = img_num_list[int(cls_idx)] 58 | idx_to_meta.extend(img_id_list[:img_num]) 59 | idx_to_train.extend(img_id_list[img_num:]) 60 | train_data = copy.deepcopy(train_dataset) 61 | train_data_meta = copy.deepcopy(train_dataset) 62 | 63 | train_data_meta.data = np.delete(train_dataset.data, idx_to_train,axis=0) 64 | train_data_meta.targets = np.delete(train_dataset.targets, idx_to_train, axis=0) 65 | train_data.data = np.delete(train_dataset.data, idx_to_meta, axis=0) 66 | train_data.targets = np.delete(train_dataset.targets, idx_to_meta, axis=0) 67 | 68 | return train_data_meta, train_data, test_dataset 69 | 70 | def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None): 71 | 72 | if dataset == 'cifar10': 73 | img_max = (50000-num_meta)/10 74 | cls_num = 10 75 | 76 | if dataset == 'cifar100': 77 | img_max = (50000-num_meta)/100 78 | cls_num = 100 79 | 80 | if imb_factor is None: 81 | return [img_max] * cls_num 82 | img_num_per_cls = [] 83 | for cls_idx in range(cls_num): 84 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 85 | img_num_per_cls.append(int(num)) 86 | return img_num_per_cls 87 | 88 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | import math 6 | import torch.nn.functional as F 7 | import pdb 8 | 9 | 10 | class EstimatorCV(): 11 | def __init__(self, feature_num, class_num): 12 | super(EstimatorCV, self).__init__() 13 | self.class_num = class_num 14 | self.CoVariance = torch.zeros(class_num, feature_num, feature_num).cuda() 15 | self.Ave = torch.zeros(class_num, feature_num).cuda() 16 | self.Amount = torch.zeros(class_num).cuda() 17 | 18 | def update_CV(self, features, labels): 19 | N = features.size(0) 20 | C = self.class_num 21 | A = features.size(1) 22 | 23 | NxCxFeatures = features.view(N, 1, A).expand(N, C, A) 24 | onehot = torch.zeros(N, C).cuda() 25 | onehot.scatter_(1, labels.view(-1, 1), 1) 26 | 27 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 28 | 29 | features_by_sort = NxCxFeatures.mul(NxCxA_onehot) 30 | 31 | Amount_CxA = NxCxA_onehot.sum(0) 32 | Amount_CxA[Amount_CxA == 0] = 1 33 | 34 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 35 | 36 | var_temp = features_by_sort - ave_CxA.expand(N, C, A).mul(NxCxA_onehot) 37 | 38 | var_temp = torch.bmm(var_temp.permute(1, 2, 0), var_temp.permute(1, 0, 2)).div(Amount_CxA.view(C, A, 1).expand(C, A, A)) 39 | 40 | sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A) 41 | 42 | sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A) 43 | 44 | weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)) 45 | weight_CV[weight_CV != weight_CV] = 0 46 | 47 | weight_AV = sum_weight_AV.div(sum_weight_AV + self.Amount.view(C, 1).expand(C, A)) 48 | weight_AV[weight_AV != weight_AV] = 0 49 | 50 | additional_CV = weight_CV.mul(1 - weight_CV).mul( 51 | torch.bmm( 52 | (self.Ave - ave_CxA).view(C, A, 1), 53 | (self.Ave - ave_CxA).view(C, 1, A) 54 | ) 55 | ) 56 | 57 | self.CoVariance = (self.CoVariance.mul(1 - weight_CV) + var_temp.mul(weight_CV)).detach() + additional_CV.detach() 58 | self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach() 59 | self.Amount += onehot.sum(0) 60 | 61 | 62 | class LDAM_meta(nn.Module): 63 | def __init__(self, feature_num, class_num, cls_num_list, max_m=0.5, s=30): 64 | super(LDAM_meta, self).__init__() 65 | self.estimator = EstimatorCV(feature_num, class_num) 66 | self.class_num = class_num 67 | self.cross_entropy = nn.CrossEntropyLoss() 68 | 69 | m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) 70 | m_list = m_list * (max_m / np.max(m_list)) 71 | m_list = torch.cuda.FloatTensor(m_list) 72 | self.m_list = m_list 73 | assert s > 0 74 | self.s = s 75 | 76 | def MetaSAug(self, fc, features, y_s, labels_s, s_cv_matrix, ratio,): 77 | N = features.size(0) 78 | C = self.class_num 79 | A = features.size(1) 80 | 81 | weight_m = list(fc.named_leaves())[0][1] 82 | 83 | NxW_ij = weight_m.expand(N, C, A) 84 | NxW_kj = torch.gather(NxW_ij, 1, labels_s.view(N, 1, 1).expand(N, C, A)) 85 | 86 | s_CV_temp = s_cv_matrix[labels_s] 87 | 88 | sigma2 = ratio * torch.bmm(torch.bmm(NxW_ij - NxW_kj, s_CV_temp), (NxW_ij - NxW_kj).permute(0, 2, 1)) 89 | sigma2 = sigma2.mul(torch.eye(C).cuda().expand(N, C, C)).sum(2).view(N, C) 90 | 91 | aug_result = y_s + 0.5 * sigma2 92 | index = torch.zeros_like(y_s, dtype=torch.uint8) 93 | index.scatter_(1, labels_s.data.view(-1, 1), 1) 94 | 95 | index_float = index.type(torch.cuda.FloatTensor) 96 | batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0, 1)) 97 | batch_m = batch_m.view((-1, 1)) 98 | aug_result_m = aug_result - batch_m 99 | 100 | output = torch.where(index, aug_result_m, aug_result) 101 | return output 102 | 103 | def forward(self, fc, features, y_s, labels, ratio, weights, cv, manner): 104 | 105 | #self.estimator.update_CV(features.detach(), labels) 106 | aug_y = self.MetaSAug(fc, features, y_s, labels, cv, \ 107 | ratio) 108 | if manner == "update": 109 | self.estimator.update_CV(features.detach(), labels) 110 | loss = F.cross_entropy(aug_y, labels, weight=weights) 111 | else: 112 | loss = F.cross_entropy(aug_y, labels, weight=weights) 113 | return loss 114 | 115 | def get_cv(self): 116 | return self.estimator.CoVariance 117 | 118 | def update_cv(self, cv): 119 | self.estimator.CoVariance = cv 120 | 121 | 122 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.init as init 7 | 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class MetaModule(nn.Module): 16 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 17 | def params(self): 18 | for name, param in self.named_params(self): 19 | yield param 20 | 21 | def named_leaves(self): 22 | return [] 23 | 24 | def named_submodules(self): 25 | return [] 26 | 27 | def named_params(self, curr_module=None, memo=None, prefix=''): 28 | if memo is None: 29 | memo = set() 30 | 31 | if hasattr(curr_module, 'named_leaves'): 32 | for name, p in curr_module.named_leaves(): 33 | if p is not None and p not in memo: 34 | memo.add(p) 35 | yield prefix + ('.' if prefix else '') + name, p 36 | else: 37 | for name, p in curr_module._parameters.items(): 38 | if p is not None and p not in memo: 39 | memo.add(p) 40 | yield prefix + ('.' if prefix else '') + name, p 41 | 42 | for mname, module in curr_module.named_children(): 43 | submodule_prefix = prefix + ('.' if prefix else '') + mname 44 | for name, p in self.named_params(module, memo, submodule_prefix): 45 | yield name, p 46 | 47 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 48 | if source_params is not None: 49 | for tgt, src in zip(self.named_params(self), source_params): 50 | name_t, param_t = tgt 51 | grad = src 52 | if first_order: 53 | grad = to_var(grad.detach().data) 54 | tmp = param_t - lr_inner * grad 55 | self.set_param(self, name_t, tmp) 56 | else: 57 | 58 | for name, param in self.named_params(self): 59 | if not detach: 60 | grad = param.grad 61 | if first_order: 62 | grad = to_var(grad.detach().data) 63 | tmp = param - lr_inner * grad 64 | self.set_param(self, name, tmp) 65 | else: 66 | param = param.detach_() 67 | self.set_param(self, name, param) 68 | 69 | def set_param(self, curr_mod, name, param): 70 | if '.' in name: 71 | n = name.split('.') 72 | module_name = n[0] 73 | rest = '.'.join(n[1:]) 74 | for name, mod in curr_mod.named_children(): 75 | if module_name == name: 76 | self.set_param(mod, rest, param) 77 | break 78 | else: 79 | setattr(curr_mod, name, param) 80 | 81 | def detach_params(self): 82 | for name, param in self.named_params(self): 83 | self.set_param(self, name, param.detach()) 84 | 85 | def copy(self, other, same_var=False): 86 | for name, param in other.named_params(): 87 | if not same_var: 88 | param = to_var(param.data.clone(), requires_grad=True) 89 | self.set_param(name, param) 90 | 91 | 92 | class MetaLinear(MetaModule): 93 | def __init__(self, *args, **kwargs): 94 | super().__init__() 95 | ignore = nn.Linear(*args, **kwargs) 96 | 97 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 98 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 99 | 100 | def forward(self, x): 101 | return F.linear(x, self.weight, self.bias) 102 | 103 | def named_leaves(self): 104 | return [('weight', self.weight), ('bias', self.bias)] 105 | 106 | class MetaLinear_Norm(MetaModule): 107 | def __init__(self, *args, **kwargs): 108 | super().__init__() 109 | temp = nn.Linear(*args, **kwargs) 110 | temp.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 111 | self.register_buffer('weight', to_var(temp.weight.data.t(), requires_grad=True)) 112 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 113 | 114 | def forward(self, x): 115 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 116 | return out 117 | 118 | def named_leaves(self): 119 | return [('weight', self.weight)] 120 | 121 | 122 | class MetaConv2d(MetaModule): 123 | def __init__(self, *args, **kwargs): 124 | super().__init__() 125 | ignore = nn.Conv2d(*args, **kwargs) 126 | 127 | self.in_channels = ignore.in_channels 128 | self.out_channels = ignore.out_channels 129 | self.stride = ignore.stride 130 | self.padding = ignore.padding 131 | self.dilation = ignore.dilation 132 | self.groups = ignore.groups 133 | self.kernel_size = ignore.kernel_size 134 | 135 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 136 | 137 | if ignore.bias is not None: 138 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 139 | else: 140 | self.register_buffer('bias', None) 141 | 142 | def forward(self, x): 143 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 144 | 145 | def named_leaves(self): 146 | return [('weight', self.weight), ('bias', self.bias)] 147 | 148 | 149 | class MetaConvTranspose2d(MetaModule): 150 | def __init__(self, *args, **kwargs): 151 | super().__init__() 152 | ignore = nn.ConvTranspose2d(*args, **kwargs) 153 | 154 | self.stride = ignore.stride 155 | self.padding = ignore.padding 156 | self.dilation = ignore.dilation 157 | self.groups = ignore.groups 158 | 159 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 160 | 161 | if ignore.bias is not None: 162 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 163 | else: 164 | self.register_buffer('bias', None) 165 | 166 | def forward(self, x, output_size=None): 167 | output_padding = self._output_padding(x, output_size) 168 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding, 169 | output_padding, self.groups, self.dilation) 170 | 171 | def named_leaves(self): 172 | return [('weight', self.weight), ('bias', self.bias)] 173 | 174 | 175 | class MetaBatchNorm2d(MetaModule): 176 | def __init__(self, *args, **kwargs): 177 | super().__init__() 178 | ignore = nn.BatchNorm2d(*args, **kwargs) 179 | 180 | self.num_features = ignore.num_features 181 | self.eps = ignore.eps 182 | self.momentum = ignore.momentum 183 | self.affine = ignore.affine 184 | self.track_running_stats = ignore.track_running_stats 185 | 186 | if self.affine: 187 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 188 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 189 | 190 | if self.track_running_stats: 191 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 192 | self.register_buffer('running_var', torch.ones(self.num_features)) 193 | else: 194 | self.register_parameter('running_mean', None) 195 | self.register_parameter('running_var', None) 196 | 197 | def forward(self, x): 198 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 199 | self.training or not self.track_running_stats, self.momentum, self.eps) 200 | 201 | def named_leaves(self): 202 | return [('weight', self.weight), ('bias', self.bias)] 203 | 204 | 205 | def _weights_init(m): 206 | classname = m.__class__.__name__ 207 | if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d): 208 | init.kaiming_normal(m.weight) 209 | 210 | class LambdaLayer(MetaModule): 211 | def __init__(self, lambd): 212 | super(LambdaLayer, self).__init__() 213 | self.lambd = lambd 214 | 215 | def forward(self, x): 216 | return self.lambd(x) 217 | 218 | 219 | class BasicBlock(MetaModule): 220 | expansion = 1 221 | 222 | def __init__(self, in_planes, planes, stride=1, option='A'): 223 | super(BasicBlock, self).__init__() 224 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 225 | self.bn1 = MetaBatchNorm2d(planes) 226 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 227 | self.bn2 = MetaBatchNorm2d(planes) 228 | 229 | self.shortcut = nn.Sequential() 230 | if stride != 1 or in_planes != planes: 231 | if option == 'A': 232 | self.shortcut = LambdaLayer(lambda x: 233 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 234 | elif option == 'B': 235 | self.shortcut = nn.Sequential( 236 | MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 237 | MetaBatchNorm2d(self.expansion * planes) 238 | ) 239 | 240 | def forward(self, x): 241 | out = F.relu(self.bn1(self.conv1(x))) 242 | out = self.bn2(self.conv2(out)) 243 | out += self.shortcut(x) 244 | out = F.relu(out) 245 | return out 246 | 247 | 248 | class ResNet32(MetaModule): 249 | def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]): 250 | super(ResNet32, self).__init__() 251 | self.in_planes = 16 252 | 253 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 254 | self.bn1 = MetaBatchNorm2d(16) 255 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 256 | self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 257 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 258 | self.linear = MetaLinear(64, num_classes) 259 | 260 | self.apply(_weights_init) 261 | 262 | def _make_layer(self, block, planes, num_blocks, stride): 263 | strides = [stride] + [1]*(num_blocks-1) 264 | layers = [] 265 | for stride in strides: 266 | layers.append(block(self.in_planes, planes, stride)) 267 | self.in_planes = planes * block.expansion 268 | 269 | return nn.Sequential(*layers) 270 | 271 | def forward(self, x): 272 | out = F.relu(self.bn1(self.conv1(x))) 273 | out = self.layer1(out) 274 | out = self.layer2(out) 275 | out = self.layer3(out) 276 | out = F.avg_pool2d(out, out.size()[3]) 277 | out = out.view(out.size(0), -1) 278 | y = self.linear(out) 279 | return out, y 280 | -------------------------------------------------------------------------------- /scripts/MetaSAug_CE_test.sh: -------------------------------------------------------------------------------- 1 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.005 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_CE_imb0.005_ckpt.pth.tar 2 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.01 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_CE_imb0.01_ckpt.pth.tar 3 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.02 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_CE_imb0.02_ckpt.pth.tar 4 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.05 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_CE_imb0.05_ckpt.pth.tar 5 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.1 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_CE_imb0.1_ckpt.pth.tar 6 | 7 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.005 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_CE_imb0.005_ckpt.pth.tar 8 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.01 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_CE_imb0.01_ckpt.pth.tar 9 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.02 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_CE_imb0.02_ckpt.pth.tar 10 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.05 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_CE_imb0.05_ckpt.pth.tar 11 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.1 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_CE_imb0.1_ckpt.pth.tar 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /scripts/MetaSAug_LDAM_test.sh: -------------------------------------------------------------------------------- 1 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.005 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_LDAM_imb0.005_ckpt.pth.tar 2 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.01 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_LDAM_imb0.01_ckpt.pth.tar 3 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.02 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_LDAM_imb0.02_ckpt.pth.tar 4 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.05 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_LDAM_imb0.05_ckpt.pth.tar 5 | python3.6 MetaSAug_test.py --dataset cifar10 --imb_factor 0.1 --checkpoint_path checkpoint/ours/MetaSAug_cifar10_LDAM_imb0.1_ckpt.pth.tar 6 | 7 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.005 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_LDAM_imb0.005_ckpt.pth.tar 8 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.01 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_LDAM_imb0.01_ckpt.pth.tar 9 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.02 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_LDAM_imb0.02_ckpt.pth.tar 10 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.05 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_LDAM_imb0.05_ckpt.pth.tar 11 | python3.6 MetaSAug_test.py --dataset cifar100 --imb_factor 0.1 --checkpoint_path checkpoint/ours/MetaSAug_cifar100_LDAM_imb0.1_ckpt.pth.tar 12 | 13 | 14 | 15 | 16 | -------------------------------------------------------------------------------- /scripts/MetaSAug_LDAM_train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3.6 MetaSAug_LDAM_train.py --gpu 0 --lr 0.1 --lam 0.75 --imb_factor 0.05 --dataset cifar100 --num_classes 100 --save_name MetaSAug_cifar100_LDAM_imb0.05 --idx 1 2 | CUDA_VISIBLE_DEVICES=0 python3.6 MetaSAug_LDAM_train.py --gpu 0 --lr 0.1 --lam 0.75 --imb_factor 0.02 --dataset cifar100 --num_classes 100 --save_name MetaSAug_cifar100_LDAM_imb0.02 --idx 1 3 | --------------------------------------------------------------------------------