├── ImageNet_iNat ├── ResNet.py ├── data │ ├── ImageNet_LT_test.txt │ ├── ImageNet_LT_train.txt │ ├── ImageNet_LT_val.txt │ ├── iNaturalist18_train.txt │ └── iNaturalist18_val.txt ├── data_utils.py ├── dataloader.py ├── loss.py ├── resnet_meta.py ├── scripts │ ├── test.sh │ └── train.sh ├── test.py ├── train.py └── utils.py ├── LICENSE ├── MetaSAug_LDAM_train.py ├── MetaSAug_test.py ├── README.md ├── assets └── illustration.png ├── data_utils.py ├── loss.py ├── resnet.py └── scripts ├── MetaSAug_CE_test.sh ├── MetaSAug_LDAM_test.sh └── MetaSAug_LDAM_train.sh /ImageNet_iNat/ResNet.py: -------------------------------------------------------------------------------- 1 | from resnet_meta import * 2 | from utils import * 3 | from os import path 4 | 5 | 6 | def create_model(use_selfatt=False, use_fc=False, dropout=None, stage1_weights=False, dataset=None, log_dir=None, test=False, *args): 7 | 8 | print('Loading Scratch ResNet 50 Feature Model.') 9 | if not use_fc: 10 | resnet50 = FeatureMeta(BottleneckMeta, [3, 4, 6, 3], dropout=None) 11 | else: 12 | resnet50 = FCMeta(2048, 1000) 13 | if not test: 14 | if stage1_weights: 15 | assert dataset 16 | print('Loading %s Stage 1 ResNet 10 Weights.' 
% dataset) 17 | if log_dir is not None: 18 | # subdir = log_dir.strip('/').split('/')[-1] 19 | # subdir = subdir.replace('stage2', 'stage1') 20 | # weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), subdir) 21 | #weight_dir = path.join('/'.join(log_dir.split('/')[:-1]), 'stage1') 22 | weight_dir = log_dir 23 | else: 24 | weight_dir = './logs/%s/stage1' % dataset 25 | print('==> Loading weights from %s' % weight_dir) 26 | if not use_fc: 27 | resnet50 = init_weights(model=resnet50, 28 | weights_path=weight_dir) 29 | else: 30 | resnet50 = init_weights(model=resnet50, weights_path=weight_dir, classifier=True) 31 | #resnet50.load_state_dict(torch.load(weight_dir)) 32 | else: 33 | print('No Pretrained Weights For Feature Model.') 34 | 35 | return resnet50 36 | -------------------------------------------------------------------------------- /ImageNet_iNat/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.parallel 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim 7 | import torch.utils.data 8 | import torchvision.transforms as transforms 9 | import torchvision 10 | import numpy as np 11 | import copy 12 | 13 | np.random.seed(6) 14 | #random.seed(2) 15 | def build_dataset(dataset,num_meta,num_classes): 16 | 17 | 18 | img_num_list = [num_meta] * num_classes 19 | 20 | # print(train_dataset.targets) 21 | data_list_val = {} 22 | for j in range(num_classes): 23 | data_list_val[j] = [i for i, label in enumerate(dataset.labels) if label == j] 24 | 25 | idx_to_meta = [] 26 | idx_to_train = [] 27 | #print(img_num_list) 28 | 29 | for cls_idx, img_id_list in data_list_val.items(): 30 | np.random.shuffle(img_id_list) 31 | img_num = img_num_list[int(cls_idx)] 32 | idx_to_meta.extend(img_id_list[:img_num]) 33 | idx_to_train.extend(img_id_list[img_num:]) 34 | train_data = copy.deepcopy(dataset) 35 | train_data_meta = 
copy.deepcopy(dataset) 36 | 37 | train_data_meta.img_path = np.delete(dataset.img_path,idx_to_train,axis=0) 38 | train_data_meta.labels = np.delete(dataset.labels, idx_to_train, axis=0) 39 | train_data.img_path = np.delete(dataset.img_path, idx_to_meta, axis=0) 40 | train_data.labels = np.delete(dataset.labels, idx_to_meta, axis=0) 41 | 42 | return train_data_meta, train_data 43 | 44 | 45 | def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None): 46 | """ 47 | Get a list of image numbers for each class, given cifar version 48 | Num of imgs follows emponential distribution 49 | img max: 5000 / 500 * e^(-lambda * 0); 50 | img min: 5000 / 500 * e^(-lambda * int(cifar_version - 1)) 51 | exp(-lambda * (int(cifar_version) - 1)) = img_max / img_min 52 | args: 53 | cifar_version: str, '10', '100', '20' 54 | imb_factor: float, imbalance factor: img_min/img_max, 55 | None if geting default cifar data number 56 | output: 57 | img_num_per_cls: a list of number of images per class 58 | """ 59 | if dataset == 'cifar10': 60 | img_max = (50000-num_meta)/10 61 | cls_num = 10 62 | 63 | if dataset == 'cifar100': 64 | img_max = (50000-num_meta)/100 65 | cls_num = 100 66 | 67 | if imb_factor is None: 68 | return [img_max] * cls_num 69 | img_num_per_cls = [] 70 | for cls_idx in range(cls_num): 71 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 72 | img_num_per_cls.append(int(num)) 73 | return img_num_per_cls 74 | 75 | 76 | # This function is used to generate imbalanced test set 77 | ''' 78 | def get_img_num_per_cls_test(dataset,imb_factor=None,num_meta=None): 79 | """ 80 | Get a list of image numbers for each class, given cifar version 81 | Num of imgs follows emponential distribution 82 | img max: 5000 / 500 * e^(-lambda * 0); 83 | img min: 5000 / 500 * e^(-lambda * int(cifar_version - 1)) 84 | exp(-lambda * (int(cifar_version) - 1)) = img_max / img_min 85 | args: 86 | cifar_version: str, '10', '100', '20' 87 | imb_factor: float, imbalance factor: img_min/img_max, 
88 | None if geting default cifar data number 89 | output: 90 | img_num_per_cls: a list of number of images per class 91 | """ 92 | if dataset == 'cifar10': 93 | img_max = (10000-num_meta)/10 94 | cls_num = 10 95 | 96 | if dataset == 'cifar100': 97 | img_max = (10000-num_meta)/100 98 | cls_num = 100 99 | 100 | if imb_factor is None: 101 | return [img_max] * cls_num 102 | img_num_per_cls = [] 103 | for cls_idx in range(cls_num): 104 | num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0))) 105 | img_num_per_cls.append(int(num)) 106 | return img_num_per_cls 107 | ''' 108 | -------------------------------------------------------------------------------- /ImageNet_iNat/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision 3 | from torch.utils.data import Dataset, DataLoader, ConcatDataset 4 | from torchvision import transforms 5 | import os 6 | from PIL import Image 7 | import json 8 | 9 | # Image statistics 10 | RGB_statistics = { 11 | 'iNaturalist18': { 12 | 'mean': [0.466, 0.471, 0.380], 13 | 'std': [0.195, 0.194, 0.192] 14 | }, 15 | 'default': { 16 | 'mean': [0.485, 0.456, 0.406], 17 | 'std': [0.229, 0.224, 0.225] 18 | } 19 | } 20 | 21 | 22 | def get_data_transform(split, rgb_mean, rbg_std, key='default'): 23 | data_transforms = { 24 | 'train': transforms.Compose([ 25 | transforms.RandomResizedCrop(224), 26 | transforms.RandomHorizontalFlip(), 27 | transforms.ToTensor(), 28 | transforms.Normalize(rgb_mean, rbg_std) 29 | ]) if key == 'iNaturalist18' else transforms.Compose([ 30 | transforms.RandomResizedCrop(224), 31 | transforms.RandomHorizontalFlip(), 32 | transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0), 33 | transforms.ToTensor(), 34 | transforms.Normalize(rgb_mean, rbg_std) 35 | ]), 36 | 'val': transforms.Compose([ 37 | transforms.Resize(256), 38 | transforms.CenterCrop(224), 39 | transforms.ToTensor(), 40 | transforms.Normalize(rgb_mean, 
rbg_std) 41 | ]), 42 | 'test': transforms.Compose([ 43 | transforms.Resize(256), 44 | transforms.CenterCrop(224), 45 | transforms.ToTensor(), 46 | transforms.Normalize(rgb_mean, rbg_std) 47 | ]) 48 | } 49 | return data_transforms[split] 50 | 51 | 52 | class LT_Dataset(Dataset): 53 | 54 | def __init__(self, root, txt, transform=None): 55 | self.img_path = [] 56 | self.labels = [] 57 | self.transform = transform 58 | print("--------------------------------------------") 59 | print(root) 60 | with open(txt) as f: 61 | for line in f: 62 | self.img_path.append(os.path.join(root, line.split()[0])) 63 | self.labels.append(int(line.split()[1])) 64 | 65 | def __len__(self): 66 | return len(self.labels) 67 | 68 | def __getitem__(self, index): 69 | 70 | path = self.img_path[index] 71 | label = self.labels[index] 72 | 73 | with open(path, 'rb') as f: 74 | sample = Image.open(f).convert('RGB') 75 | 76 | if self.transform is not None: 77 | sample = self.transform(sample) 78 | 79 | return sample, label 80 | 81 | class LT_Dataset_iNat17(Dataset): 82 | ''' 83 | Reading the Json file of iNaturalist17 84 | ''' 85 | def __init__(self, root, txt, transform=None): 86 | self.img_path = [] 87 | self.labels = [] 88 | self.transform = transform 89 | 90 | with open(txt,'r',encoding='utf8')as fp: 91 | json_data = json.load(fp) 92 | images = json_data["images"] 93 | labels = json_data["annotations"] 94 | for i in range(len(images)): 95 | 96 | self.img_path.append(os.path.join(root, images[i]["file_name"])) 97 | self.labels.append(int(labels[i]["category_id"])) 98 | 99 | def __len__(self): 100 | return len(self.labels) 101 | 102 | def __getitem__(self, index): 103 | 104 | path = self.img_path[index] 105 | label = self.labels[index] 106 | 107 | with open(path, 'rb') as f: 108 | sample = Image.open(f).convert('RGB') 109 | 110 | if self.transform is not None: 111 | sample = self.transform(sample) 112 | 113 | return sample, label 114 | 115 | 116 | def load_data_distributed(data_root, dataset, 
phase, batch_size, sampler_dic=None, num_workers=4, test_open=False, shuffle=True): 117 | 118 | # if phase == 'train_plain': 119 | # txt_split = 'train' 120 | # elif phase == 'train_val': 121 | # txt_split = 'val' 122 | # phase = 'train' 123 | # else: 124 | # txt_split = phase 125 | if dataset == "iNaturalist17": 126 | txt = 'data/%s_%s.json' % (dataset, phase) 127 | else: 128 | txt = 'data/%s_%s.txt' % (dataset, phase) 129 | 130 | print('Loading data from %s' % txt) 131 | 132 | if dataset == 'iNaturalist18': 133 | print('===> Loading iNaturalist18 statistics') 134 | key = 'iNaturalist18' 135 | elif dataset == 'iNaturalist17': 136 | print('===> Loading iNaturalist17 statistics') 137 | key = 'iNaturalist18' 138 | else: 139 | key = 'default' 140 | 141 | rgb_mean, rgb_std = RGB_statistics[key]['mean'], RGB_statistics[key]['std'] 142 | 143 | if phase not in ['train', 'val']: 144 | transform = get_data_transform('test', rgb_mean, rgb_std, key) 145 | else: 146 | transform = get_data_transform(phase, rgb_mean, rgb_std, key) 147 | 148 | print('Use data transformation:', transform) 149 | 150 | if dataset == "iNaturalist17": 151 | set_ = LT_Dataset_iNat17(data_root, txt, transform) 152 | else: 153 | set_ = LT_Dataset(data_root, txt, transform) 154 | 155 | 156 | return set_ 157 | -------------------------------------------------------------------------------- /ImageNet_iNat/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -* 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import math 8 | import torch.nn.functional as F 9 | import pdb 10 | 11 | def MI(outputs_target): 12 | batch_size = outputs_target.size(0) 13 | softmax_outs_t = nn.Softmax(dim=1)(outputs_target) 14 | avg_softmax_outs_t = torch.sum(softmax_outs_t, dim=0) / float(batch_size) 15 | log_avg_softmax_outs_t = torch.log(avg_softmax_outs_t) 16 | item1 = -torch.sum(avg_softmax_outs_t * 
log_avg_softmax_outs_t) 17 | item2 = -torch.sum(softmax_outs_t * torch.log(softmax_outs_t)) / float(batch_size) 18 | return item1, item2 19 | 20 | class EstimatorMean(): 21 | def __init__(self, feature_num, class_num): 22 | super(EstimatorMean, self).__init__() 23 | self.class_num = class_num 24 | self.Ave = torch.zeros(class_num, feature_num).cuda() 25 | self.Amount = torch.zeros(class_num).cuda() 26 | 27 | def update_Mean(self, features, labels): 28 | N = features.size(0) 29 | C = self.class_num 30 | A = features.size(1) 31 | 32 | NxCxFeatures = features.view(N, 1, A).expand(N, C, A) 33 | onehot = torch.zeros(N, C).cuda() 34 | onehot.scatter_(1, labels.view(-1, 1), 1) 35 | 36 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 37 | 38 | features_by_sort = NxCxFeatures.mul(NxCxA_onehot) 39 | 40 | Amount_CxA = NxCxA_onehot.sum(0) 41 | Amount_CxA[Amount_CxA == 0] = 1 42 | 43 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 44 | 45 | sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A) 46 | 47 | sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A) 48 | 49 | weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)) 50 | weight_CV[weight_CV != weight_CV] = 0 51 | 52 | weight_AV = sum_weight_AV.div(sum_weight_AV + self.Amount.view(C, 1).expand(C, A)) 53 | weight_AV[weight_AV != weight_AV] = 0 54 | 55 | self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach() 56 | self.Amount += onehot.sum(0) 57 | 58 | # the estimation of covariance matrix 59 | class EstimatorCV(): 60 | def __init__(self, feature_num, class_num): 61 | super(EstimatorCV, self).__init__() 62 | self.class_num = class_num 63 | self.CoVariance = torch.zeros(class_num, feature_num).cuda() 64 | self.Ave = torch.zeros(class_num, feature_num).cuda() 65 | self.Amount = torch.zeros(class_num).cuda() 66 | 67 | def update_CV(self, features, labels): 68 | N = features.size(0) 69 | C = self.class_num 70 | A = features.size(1) 71 | 72 | NxCxFeatures = 
features.view(N, 1, A).expand(N, C, A) 73 | # onehot = torch.zeros(N, C).cuda() 74 | onehot = torch.zeros(N, C).cuda() 75 | onehot.scatter_(1, labels.view(-1, 1), 1) 76 | 77 | NxCxA_onehot = onehot.view(N, C, 1).expand(N, C, A) 78 | 79 | features_by_sort = NxCxFeatures.mul(NxCxA_onehot) 80 | 81 | Amount_CxA = NxCxA_onehot.sum(0) 82 | Amount_CxA[Amount_CxA == 0] = 1 83 | 84 | ave_CxA = features_by_sort.sum(0) / Amount_CxA 85 | 86 | var_temp = features_by_sort - ave_CxA.expand(N, C, A).mul(NxCxA_onehot) 87 | 88 | # var_temp = torch.bmm(var_temp.permute(1, 2, 0), var_temp.permute(1, 0, 2)).div(Amount_CxA.view(C, A, 1).expand(C, A, A)) 89 | var_temp = torch.mul(var_temp.permute(1, 2, 0), var_temp.permute(1, 2, 0)).sum(2).div(Amount_CxA.view(C, A)) 90 | 91 | # sum_weight_CV = onehot.sum(0).view(C, 1, 1).expand(C, A, A) 92 | sum_weight_CV = onehot.sum(0).view(C, 1).expand(C, A) 93 | 94 | sum_weight_AV = onehot.sum(0).view(C, 1).expand(C, A) 95 | 96 | # weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1, 1).expand(C, A, A)) 97 | weight_CV = sum_weight_CV.div(sum_weight_CV + self.Amount.view(C, 1).expand(C, A)) 98 | weight_CV[weight_CV != weight_CV] = 0 99 | 100 | weight_AV = sum_weight_AV.div(sum_weight_AV + self.Amount.view(C, 1).expand(C, A)) 101 | weight_AV[weight_AV != weight_AV] = 0 102 | 103 | additional_CV = weight_CV.mul(1 - weight_CV).mul( 104 | torch.mul( 105 | (self.Ave - ave_CxA).view(C, A), 106 | (self.Ave - ave_CxA).view(C, A) 107 | ) 108 | ) 109 | # (self.Ave - ave_CxA).pow(2) 110 | 111 | self.CoVariance = (self.CoVariance.mul(1 - weight_CV) + var_temp.mul( 112 | weight_CV)).detach() + additional_CV.detach() 113 | self.Ave = (self.Ave.mul(1 - weight_AV) + ave_CxA.mul(weight_AV)).detach() 114 | self.Amount += onehot.sum(0) 115 | 116 | class Loss_meta(nn.Module): 117 | def __init__(self, feature_num, class_num): 118 | super(Loss_meta, self).__init__() 119 | self.source_estimator = EstimatorCV(feature_num, class_num) 120 | self.class_num = 
class_num 121 | self.cross_entropy = nn.CrossEntropyLoss() 122 | 123 | def MetaSAug(self, fc, features, y_s, labels_s, s_cv_matrix, ratio): 124 | N = features.size(0) 125 | C = self.class_num 126 | A = features.size(1) 127 | 128 | weight_m = fc 129 | 130 | NxW_ij = weight_m.expand(N, C, A) 131 | NxW_kj = torch.gather(NxW_ij, 1, labels_s.view(N, 1, 1).expand(N, C, A)) 132 | 133 | CV_temp = s_cv_matrix[labels_s] 134 | sigma2 = ratio * (weight_m - NxW_kj).pow(2).mul(CV_temp.view(N, 1, A).expand(N, C, A)).sum(2) 135 | aug_result = y_s + 0.5 * sigma2 136 | return aug_result 137 | 138 | def forward(self, fc, features_source, y_s, labels_source, ratio, weights, cv, mode): 139 | 140 | aug_y = self.MetaSAug(fc, features_source, y_s, labels_source, cv, \ 141 | ratio) 142 | 143 | if mode == "update": 144 | self.source_estimator.update_CV(features_source.detach(), labels_source) 145 | loss = F.cross_entropy(aug_y, labels_source, weight=weights) 146 | else: 147 | loss = F.cross_entropy(aug_y, labels_source, weight=weights) 148 | return loss 149 | 150 | def get_cv(self): 151 | return self.source_estimator.CoVariance 152 | 153 | def update_cv(self, cv): 154 | self.source_estimator.CoVariance = cv 155 | -------------------------------------------------------------------------------- /ImageNet_iNat/resnet_meta.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch.autograd import Variable 6 | import torch.nn.init as init 7 | 8 | 9 | def to_var(x, requires_grad=True): 10 | if torch.cuda.is_available(): 11 | x = x.cuda() 12 | return Variable(x, requires_grad=requires_grad) 13 | 14 | 15 | class MetaModule(nn.Module): 16 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 17 | def params(self): 18 | for name, param in self.named_params(self): 19 | yield param 20 | 21 | def named_leaves(self): 22 | return [] 23 | 24 | def 
named_submodules(self): 25 | return [] 26 | 27 | def named_params(self, curr_module=None, memo=None, prefix=''): 28 | if memo is None: 29 | memo = set() 30 | 31 | if hasattr(curr_module, 'named_leaves'): 32 | for name, p in curr_module.named_leaves(): 33 | if p is not None and p not in memo: 34 | memo.add(p) 35 | yield prefix + ('.' if prefix else '') + name, p 36 | else: 37 | for name, p in curr_module._parameters.items(): 38 | if p is not None and p not in memo: 39 | memo.add(p) 40 | yield prefix + ('.' if prefix else '') + name, p 41 | 42 | for mname, module in curr_module.named_children(): 43 | submodule_prefix = prefix + ('.' if prefix else '') + mname 44 | for name, p in self.named_params(module, memo, submodule_prefix): 45 | yield name, p 46 | 47 | def update_params(self, lr_inner, first_order=False, source_params=None, detach=False): 48 | if source_params is not None: 49 | for tgt, src in zip(self.named_params(self), source_params): 50 | name_t, param_t = tgt 51 | # name_s, param_s = src 52 | # grad = param_s.grad 53 | # name_s, param_s = src 54 | grad = src 55 | if first_order: 56 | grad = to_var(grad.detach().data) 57 | tmp = param_t - lr_inner * grad 58 | self.set_param(self, name_t, tmp) 59 | else: 60 | 61 | for name, param in self.named_params(self): 62 | if not detach: 63 | grad = param.grad 64 | if first_order: 65 | grad = to_var(grad.detach().data) 66 | tmp = param - lr_inner * grad 67 | self.set_param(self, name, tmp) 68 | else: 69 | param = param.detach_() # https://blog.csdn.net/qq_39709535/article/details/81866686 70 | self.set_param(self, name, param) 71 | 72 | def set_param(self, curr_mod, name, param): 73 | if '.' 
in name: 74 | n = name.split('.') 75 | module_name = n[0] 76 | rest = '.'.join(n[1:]) 77 | for name, mod in curr_mod.named_children(): 78 | if module_name == name: 79 | self.set_param(mod, rest, param) 80 | break 81 | else: 82 | setattr(curr_mod, name, param) 83 | 84 | def detach_params(self): 85 | for name, param in self.named_params(self): 86 | self.set_param(self, name, param.detach()) 87 | 88 | def copy(self, other, same_var=False): 89 | for name, param in other.named_params(): 90 | if not same_var: 91 | param = to_var(param.data.clone(), requires_grad=True) 92 | self.set_param(name, param) 93 | 94 | 95 | class MetaLinear(MetaModule): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__() 98 | ignore = nn.Linear(*args, **kwargs) 99 | 100 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 101 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 102 | 103 | def forward(self, x): 104 | return F.linear(x, self.weight, self.bias) 105 | 106 | def named_leaves(self): 107 | return [('weight', self.weight), ('bias', self.bias)] 108 | 109 | # This layer will be used when the loss function is LDAM 110 | class MetaLinear_Norm(MetaModule): 111 | def __init__(self, *args, **kwargs): 112 | super().__init__() 113 | temp = nn.Linear(*args, **kwargs) 114 | # import pdb; pdb.set_trace() 115 | temp.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 116 | self.register_buffer('weight', to_var(temp.weight.data.t(), requires_grad=True)) 117 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 118 | #self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 119 | 120 | def forward(self, x): 121 | out = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 122 | return out 123 | # return F.linear(x, self.weight, self.bias) 124 | 125 | def named_leaves(self): 126 | return [('weight', self.weight)]#, ('bias', self.bias)] 127 | 128 | 129 | class MetaConv2d(MetaModule): 130 | def 
__init__(self, *args, **kwargs): 131 | super().__init__() 132 | ignore = nn.Conv2d(*args, **kwargs) 133 | 134 | self.in_channels = ignore.in_channels 135 | self.out_channels = ignore.out_channels 136 | self.stride = ignore.stride 137 | self.padding = ignore.padding 138 | self.dilation = ignore.dilation 139 | self.groups = ignore.groups 140 | self.kernel_size = ignore.kernel_size 141 | 142 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 143 | 144 | if ignore.bias is not None: 145 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 146 | else: 147 | self.register_buffer('bias', None) 148 | 149 | def forward(self, x): 150 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 151 | 152 | def named_leaves(self): 153 | return [('weight', self.weight), ('bias', self.bias)] 154 | 155 | 156 | class MetaConvTranspose2d(MetaModule): 157 | def __init__(self, *args, **kwargs): 158 | super().__init__() 159 | ignore = nn.ConvTranspose2d(*args, **kwargs) 160 | 161 | self.stride = ignore.stride 162 | self.padding = ignore.padding 163 | self.dilation = ignore.dilation 164 | self.groups = ignore.groups 165 | 166 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 167 | 168 | if ignore.bias is not None: 169 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 170 | else: 171 | self.register_buffer('bias', None) 172 | 173 | def forward(self, x, output_size=None): 174 | output_padding = self._output_padding(x, output_size) 175 | return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding, 176 | output_padding, self.groups, self.dilation) 177 | 178 | def named_leaves(self): 179 | return [('weight', self.weight), ('bias', self.bias)] 180 | 181 | 182 | class MetaBatchNorm1d(MetaModule): 183 | def __init__(self, *args, **kwargs): 184 | super(MetaBatchNorm1d, self).__init__() 185 | ignore = nn.BatchNorm1d(*args, 
**kwargs) 186 | 187 | self.num_features = ignore.num_features 188 | self.eps = ignore.eps 189 | self.momentum = ignore.momentum 190 | self.affine = ignore.affine 191 | self.track_running_stats = ignore.track_running_stats 192 | 193 | if self.affine: 194 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 195 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 196 | 197 | if self.track_running_stats: 198 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 199 | self.register_buffer('running_var', torch.ones(self.num_features)) 200 | self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) 201 | else: 202 | self.register_parameter('running_mean', None) 203 | self.register_parameter('running_var', None) 204 | 205 | def reset_running_stats(self): 206 | if self.track_running_stats: 207 | self.running_mean.zero_() 208 | self.running_var.fill_(1) 209 | self.num_batches_tracked.zero_() 210 | 211 | def reset_parameters(self): 212 | self.reset_running_stats() 213 | if self.affine: 214 | self.weight.data.uniform_() 215 | self.bias.data.zero_() 216 | 217 | def _check_input_dim(self, input): 218 | if input.dim() != 2 and input.dim() != 3: 219 | raise ValueError('expected 2D or 3D input (got {}D input)' 220 | .format(input.dim())) 221 | 222 | def forward(self, x): 223 | self._check_input_dim(x) 224 | exponential_average_factor = 0.0 225 | 226 | if self.training and self.track_running_stats: 227 | self.num_batches_tracked += 1 228 | if self.momentum is None: # use cumulative moving average 229 | exponential_average_factor = 1.0 / self.num_batches_tracked.item() 230 | else: # use exponential moving average 231 | exponential_average_factor = self.momentum 232 | 233 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 234 | self.training or not self.track_running_stats, self.momentum, self.eps) 235 | 236 | def named_leaves(self): 237 | return [('weight', 
self.weight), ('bias', self.bias)] 238 | 239 | def extra_repr(self): 240 | return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \ 241 | 'track_running_stats={track_running_stats}'.format(**self.__dict__) 242 | 243 | def _load_from_state_dict(self, state_dict, prefix, metadata, strict, 244 | missing_keys, unexpected_keys, error_msgs): 245 | version = metadata.get('version', None) 246 | 247 | if (version is None or version < 2) and self.track_running_stats: 248 | # at version 2: added num_batches_tracked buffer 249 | # this should have a default value of 0 250 | num_batches_tracked_key = prefix + 'num_batches_tracked' 251 | if num_batches_tracked_key not in state_dict: 252 | state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) 253 | 254 | super(MetaBatchNorm1d, self)._load_from_state_dict( 255 | state_dict, prefix, metadata, strict, 256 | missing_keys, unexpected_keys, error_msgs) 257 | 258 | 259 | class MetaBatchNorm2d(MetaModule): 260 | def __init__(self, *args, **kwargs): 261 | super().__init__() 262 | ignore = nn.BatchNorm2d(*args, **kwargs) 263 | 264 | self.num_features = ignore.num_features 265 | self.eps = ignore.eps 266 | self.momentum = ignore.momentum 267 | self.affine = ignore.affine 268 | self.track_running_stats = ignore.track_running_stats 269 | 270 | if self.affine: 271 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 272 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 273 | 274 | if self.track_running_stats: 275 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 276 | self.register_buffer('running_var', torch.ones(self.num_features)) 277 | else: 278 | self.register_parameter('running_mean', None) 279 | self.register_parameter('running_var', None) 280 | 281 | def forward(self, x): 282 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 283 | self.training or not self.track_running_stats, 
self.momentum, self.eps) 284 | 285 | def named_leaves(self): 286 | return [('weight', self.weight), ('bias', self.bias)] 287 | 288 | 289 | def _weights_init(m): 290 | classname = m.__class__.__name__ 291 | # print(classname) 292 | if isinstance(m, MetaLinear) or isinstance(m, MetaConv2d): 293 | init.kaiming_normal(m.weight) 294 | 295 | class LambdaLayer(MetaModule): 296 | def __init__(self, lambd): 297 | super(LambdaLayer, self).__init__() 298 | self.lambd = lambd 299 | 300 | def forward(self, x): 301 | return self.lambd(x) 302 | 303 | 304 | class BasicBlock(MetaModule): 305 | expansion = 1 306 | 307 | def __init__(self, in_planes, planes, stride=1, option='A'): 308 | super(BasicBlock, self).__init__() 309 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 310 | self.bn1 = MetaBatchNorm2d(planes) 311 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 312 | self.bn2 = MetaBatchNorm2d(planes) 313 | 314 | self.shortcut = nn.Sequential() 315 | if stride != 1 or in_planes != planes: 316 | if option == 'A': 317 | """ 318 | For CIFAR10 ResNet paper uses option A. 
319 | """ 320 | self.shortcut = LambdaLayer(lambda x: 321 | F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0)) 322 | elif option == 'B': 323 | self.shortcut = nn.Sequential( 324 | MetaConv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 325 | MetaBatchNorm2d(self.expansion * planes) 326 | ) 327 | 328 | def forward(self, x): 329 | out = F.relu(self.bn1(self.conv1(x))) 330 | out = self.bn2(self.conv2(out)) 331 | out += self.shortcut(x) 332 | out = F.relu(out) 333 | return out 334 | 335 | 336 | class BottleneckMeta(MetaModule): 337 | expansion = 4 338 | 339 | def __init__(self, in_planes, planes, stride=1, downsample=None): 340 | super(BottleneckMeta, self).__init__() 341 | self.conv1 = MetaConv2d(in_planes, planes, kernel_size=1, bias=False) 342 | self.bn1 = MetaBatchNorm2d(planes) 343 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 344 | self.bn2 = MetaBatchNorm2d(planes) 345 | self.conv3 = MetaConv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 346 | self.bn3 = MetaBatchNorm2d(planes * self.expansion) 347 | self.downsample = downsample 348 | self.stride = stride 349 | 350 | def forward(self, x): 351 | residual = x 352 | 353 | out = self.conv1(x) 354 | out = F.relu(self.bn1(out)) 355 | 356 | out = self.conv2(out) 357 | out = F.relu(self.bn2(out)) 358 | 359 | out = self.conv3(out) 360 | out = self.bn3(out) 361 | 362 | if self.downsample is not None: 363 | residual = self.downsample(x) 364 | 365 | out += residual 366 | out = F.relu(out) 367 | return out 368 | 369 | 370 | class ResNet32(MetaModule): 371 | def __init__(self, num_classes, block=BasicBlock, num_blocks=[5, 5, 5]): 372 | super(ResNet32, self).__init__() 373 | self.in_planes = 16 374 | 375 | self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False) 376 | self.bn1 = MetaBatchNorm2d(16) 377 | self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1) 378 | 
self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2) 379 | self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2) 380 | self.linear = MetaLinear(64, num_classes) # MetaLinear_Norm(64,num_classes,bias=False) # 381 | 382 | self.apply(_weights_init) 383 | 384 | def _make_layer(self, block, planes, num_blocks, stride): 385 | strides = [stride] + [1]*(num_blocks-1) 386 | layers = [] 387 | for stride in strides: 388 | layers.append(block(self.in_planes, planes, stride)) 389 | self.in_planes = planes * block.expansion 390 | 391 | return nn.Sequential(*layers) 392 | 393 | def forward(self, x): 394 | out = F.relu(self.bn1(self.conv1(x))) 395 | out = self.layer1(out) 396 | out = self.layer2(out) 397 | out = self.layer3(out) 398 | out = F.avg_pool2d(out, out.size()[3]) 399 | out = out.view(out.size(0), -1) 400 | y = self.linear(out) 401 | return out, y 402 | 403 | 404 | class FeatureMeta(MetaModule): 405 | def __init__(self, block, num_blocks, use_fc=False, dropout=None): 406 | super(FeatureMeta, self).__init__() 407 | self.inplanes = 64 408 | 409 | self.conv1 = MetaConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 410 | self.bn1 = MetaBatchNorm2d(64) 411 | self.relu = nn.ReLU(inplace=True) 412 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 413 | # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 414 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 415 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 416 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 417 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 418 | # self.avgpool = nn.AvgPool2d(7, stride=1) 419 | self.avgpool = nn.AvgPool2d(7, stride=1) 420 | 421 | self.use_fc = use_fc 422 | self.use_dropout = True if dropout else False 423 | 424 | #if self.use_fc: 425 | #print('Using fc.') 426 | #self.fc = MetaLinear(512*block.expansion, 8142) 427 | 428 | if self.use_dropout: 
429 | print('Using dropout') 430 | self.dropout = nn.Dropout(p=dropout) 431 | 432 | for m in self.modules(): 433 | if isinstance(m, MetaConv2d): 434 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 435 | m.weight.data.normal_(0, math.sqrt(2. / n)) 436 | elif isinstance(m, MetaBatchNorm2d): 437 | m.weight.data.fill_(1) 438 | m.bias.data.zero_() 439 | 440 | 441 | def _make_layer(self, block, planes, num_blocks, stride): 442 | downsample = None 443 | if stride != 1 or self.inplanes != planes * block.expansion: 444 | downsample = nn.Sequential( 445 | MetaConv2d(self.inplanes, planes * block.expansion, 446 | kernel_size=1, stride=stride, bias=False), 447 | MetaBatchNorm2d(planes * block.expansion), 448 | ) 449 | 450 | layers = [] 451 | layers.append(block(self.inplanes, planes, stride, downsample)) 452 | self.inplanes = planes * block.expansion 453 | for i in range(1, num_blocks): 454 | layers.append(block(self.inplanes, planes)) 455 | 456 | return nn.Sequential(*layers) 457 | 458 | def forward(self, x): 459 | x = self.conv1(x) 460 | x = self.bn1(x) 461 | x = self.relu(x) 462 | x = self.maxpool(x) 463 | 464 | x = self.layer1(x) 465 | x = self.layer2(x) 466 | x = self.layer3(x) 467 | x = self.layer4(x) 468 | 469 | x = self.avgpool(x) 470 | x = x.view(x.size(0), -1) 471 | 472 | #if self.use_fc: 473 | #y = F.relu(self.fc(x)) 474 | 475 | if self.use_dropout: 476 | x = self.dropout(x) 477 | 478 | return x 479 | 480 | class FCMeta(MetaModule): 481 | def __init__(self, feature_dim=2048, output_dim=1000): 482 | super(FCMeta, self).__init__() 483 | self.fc = MetaLinear(feature_dim, output_dim) 484 | 485 | def forward(self, x): 486 | y = self.fc(x) 487 | return y 488 | 489 | 490 | class FCModel(nn.Module): 491 | def __init__(self, feature_dim, output_dim=1000): 492 | super(FCModel, self).__init__() 493 | self.fc = nn.Linear(feature_dim, output_dim) 494 | 495 | def forward(self, x): 496 | y = self.fc(x) 497 | return y 498 | 499 | 500 | 
-------------------------------------------------------------------------------- /ImageNet_iNat/scripts/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=7 python3 test.py --loading_path checkpoints/MetaSAug_ImageNet_LT.tar --dataset ImageNet_LT --num_classes 1000 --data_root ../ImageNet 2 | CUDA_VISIBLE_DEVICES=7 python3 test.py --loading_path checkpoints/MetaSAug_iNat18.tar --dataset iNaturalist18 --num_classes 8142 --data_root ../iNaturalist18 3 | -------------------------------------------------------------------------------- /ImageNet_iNat/scripts/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --nproc_per_node=4 --master_port 53212 train.py --lr 0.0003 --meta_lr 0.1 --workers 0 --batch_size 256 --epochs 20 --dataset ImageNet_LT --num_classes 1000 --data_root ../ImageNet 2 | -------------------------------------------------------------------------------- /ImageNet_iNat/test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import argparse 5 | import random 6 | import copy 7 | import torch 8 | import torchvision 9 | import numpy as np 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import torchvision.transforms as transforms 13 | from data_utils import * 14 | from dataloader import load_data_distributed 15 | import shutil 16 | from ResNet import * 17 | import resnet_meta 18 | 19 | import multiprocessing 20 | import torch.nn.parallel 21 | import torch.nn as nn 22 | from collections import Counter 23 | import time 24 | parser = argparse.ArgumentParser(description='Imbalanced Example') 25 | parser.add_argument('--dataset', default='iNaturalist18', type=str, 26 | help='dataset') 27 | parser.add_argument('--data_root', default='/data1/TL/data/iNaturalist18', type=str) 28 | 
# ----- command-line options (continued) -----
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--num_classes', type=int, default=8142)
parser.add_argument('--num_meta', type=int, default=10,
                    help='The number of meta data for each class.')
parser.add_argument('--test_batch_size', type=int, default=512, metavar='N',
                    help='input batch size for testing (default: 100)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float,
                    help='initial learning rate')
parser.add_argument('--workers', default=16, type=int)
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--split', type=int, default=1000)
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--print-freq', '-p', default=1000, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--gpu', default=None, type=int)
parser.add_argument('--lam', default=0.25, type=float)
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--meta_lr', default=0.1, type=float)
parser.add_argument('--loading_path', default=None, type=str)

args = parser.parse_args()
# print(args)
for arg in vars(args):
    print("{}={}".format(arg, getattr(args, arg)))


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

kwargs = {'num_workers': 1, 'pin_memory': True}
use_cuda = not args.no_cuda and torch.cuda.is_available()

cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")

# ImageNet_LT evaluates on its "test" split; other datasets (iNaturalist18)
# only ship a "val" split here.
if args.dataset == 'ImageNet_LT':
    val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="test", batch_size=args.test_batch_size,
                                    num_workers=args.workers, test_open=False, shuffle=False)
else:
    val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="val", batch_size=args.test_batch_size,
                                    num_workers=args.workers, test_open=False, shuffle=False)

val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)


np.random.seed(42)
random.seed(42)
torch.manual_seed(args.seed)

data_list = {}
data_list_num = []
best_prec1 = 0

# Stage-1 (uniformly trained) backbone checkpoints per dataset.
model_dict = {"ImageNet_LT": "models/resnet50_uniform_e90.pth",
              "iNaturalist18": "models/iNat18/resnet50_uniform_e200.pth"}


def main():
    """Load classifier + frozen feature extractor from --loading_path and evaluate."""
    global args, best_prec1
    args = parser.parse_args()

    cudnn.benchmark = True

    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    model = FCModel(2048, args.num_classes)
    model = model.cuda()
    loading_path = args.loading_path
    weights = torch.load(loading_path, map_location=torch.device("cpu"))
    weights_c = weights['state_dict']['classifier']
    weights_c = {k: weights_c[k] for k in model.state_dict()}
    # BUG FIX: the sanity check previously tested `k not in weights` (the raw
    # checkpoint dict, whose keys are e.g. 'state_dict'), so it warned for
    # every model key and could never catch a real mismatch. Check the
    # classifier state-dict that is actually loaded.
    for k in model.state_dict():
        if k not in weights_c:
            print("Loading Weights Warning.")

    model.load_state_dict(weights_c)
    feature_extractor = create_model(stage1_weights=True, dataset=args.dataset, log_dir=model_dict[args.dataset])
    weight_f = weights['state_dict']['feature']
    feature_extractor.load_state_dict(weight_f)
    feature_extractor = feature_extractor.cuda()
    feature_extractor.eval()

    prec1, preds, gt_labels = validate(val_loader, model, feature_extractor, nn.CrossEntropyLoss())

    print('Accuracy: ', prec1)


def validate(val_loader, model, feature_extractor, criterion):
    """Perform validation on the validation set.

    Returns (top-1 accuracy, predicted labels, ground-truth labels).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    true_labels = []
    preds = []

    torch.cuda.empty_cache()
    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        # compute output (no gradients needed at eval time)
        with torch.no_grad():
            feature = feature_extractor(input)
            output = model(feature)
            loss = criterion(output, target)

        output_numpy = output.data.cpu().numpy()
        preds_output = list(output_numpy.argmax(axis=1))

        true_labels += list(target.data.cpu().numpy())
        preds += preds_output

        # measure accuracy and record loss
        prec1 = accuracy(output.data, target, topk=(1,))[0]
        losses.update(loss.data.item(), input.size(0))
        top1.update(prec1.item(), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    # log to TensorBoard
    # import pdb; pdb.set_trace()

    return top1.avg, preds, true_labels


def to_var(x, requires_grad=True):
    # Move to GPU when available and wrap for autograd.
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x, requires_grad=requires_grad)
| 182 | class AverageMeter(object): 183 | """Computes and stores the average and current value""" 184 | 185 | def __init__(self): 186 | self.reset() 187 | 188 | def reset(self): 189 | self.val = 0 190 | self.avg = 0 191 | self.sum = 0 192 | self.count = 0 193 | 194 | def update(self, val, n=1): 195 | self.val = val 196 | self.sum += val * n 197 | self.count += n 198 | self.avg = self.sum / self.count 199 | 200 | def accuracy(output, target, topk=(1,)): 201 | """Computes the precision@k for the specified values of k""" 202 | maxk = max(topk) 203 | batch_size = target.size(0) 204 | 205 | _, pred = output.topk(maxk, 1, True, True) 206 | pred = pred.t() 207 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 208 | 209 | res = [] 210 | for k in topk: 211 | correct_k = correct[:k].view(-1).float().sum(0) 212 | res.append(correct_k.mul_(100.0 / batch_size)) 213 | return res 214 | 215 | 216 | def save_checkpoint(args, state, is_best, epoch): 217 | filename = 'checkpoint/' + 'train_' + str(args.dataset) + '/' + str(args.lr) + '_' + str(args.batch_size) + '_' + str(args.meta_lr) + 'epoch' + str(epoch) + '_ckpt.pth.tar' 218 | file_root, _ = os.path.split(filename) 219 | if not os.path.exists(file_root): 220 | os.makedirs(file_root) 221 | torch.save(state, filename) 222 | 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /ImageNet_iNat/train.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import argparse 5 | import random 6 | import copy 7 | import torch 8 | import torchvision 9 | import numpy as np 10 | import torch.nn.functional as F 11 | from torch.autograd import Variable 12 | import torchvision.transforms as transforms 13 | from data_utils import * 14 | # import resnet 15 | from dataloader import load_data_distributed 16 | import shutil 17 | from ResNet import * 18 | import loss 19 | import multiprocessing 
import torch.nn.parallel
import torch.nn as nn
from collections import Counter
import time

# ----- command-line options -----
parser = argparse.ArgumentParser(description='Imbalanced Example')
parser.add_argument('--dataset', default='iNaturalist18', type=str,
                    help='dataset')
parser.add_argument('--data_root', default='/data1/TL/data/iNaturalist18', type=str)
parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--num_classes', type=int, default=8142)
parser.add_argument('--num_meta', type=int, default=10,
                    help='The number of meta data for each class.')
parser.add_argument('--test_batch_size', type=int, default=512, metavar='N',
                    help='input batch size for testing (default: 100)')
parser.add_argument('--epochs', type=int, default=200, metavar='N',
                    help='number of epochs to train')
parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float,
                    help='initial learning rate')
parser.add_argument('--workers', default=16, type=int)
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--split', type=int, default=1000)
parser.add_argument('--seed', type=int, default=42, metavar='S',
                    help='random seed (default: 42)')
parser.add_argument('--print-freq', '-p', default=1000, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--gpu', default=None, type=int)
parser.add_argument('--lam', default=0.25, type=float)
parser.add_argument('--local_rank', default=0, type=int)
parser.add_argument('--meta_lr', default=0.1, type=float)

args = parser.parse_args()
# print(args)
for arg in vars(args):
    print("{}={}".format(arg, getattr(args, arg)))


os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

kwargs = {'num_workers': 1, 'pin_memory': True}
use_cuda = not args.no_cuda and torch.cuda.is_available()

cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
# torch.distributed.launch exports WORLD_SIZE; fall back to a single process.
num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

print(f'num_gpus: {num_gpus}')
args.distributed = num_gpus > 1
# BUG FIX: this was a plain string literal with a typo
# ('print("ditributed: {args.distributed}")'), so the flag was never
# interpolated; use an f-string like the num_gpus print above.
print(f"distributed: {args.distributed}")
if args.distributed:
    torch.cuda.set_device(args.local_rank)
    #torch.distributed.init_process_group(backend="nccl", init_method="env://")
    torch.distributed.init_process_group(backend="nccl")
    # Keep the effective global batch size constant across processes.
    args.batch_size = int(args.batch_size / num_gpus)


######### ImageNet dataset
splits = ["train", "val", "test"]
if args.dataset == 'ImageNet_LT':
    train_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="train", batch_size=args.batch_size,
                                      num_workers=args.workers, test_open=False, shuffle=False)
    val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="test", batch_size=args.test_batch_size,
                                    num_workers=args.workers, test_open=False, shuffle=False)

    meta_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="val", batch_size=args.batch_size, num_workers=args.workers, test_open=False, shuffle=False)

else:
    train_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="train", batch_size=args.batch_size,
                                      num_workers=args.workers, test_open=False, shuffle=False)
    val_set = load_data_distributed(data_root=args.data_root, dataset=args.dataset, phase="val", batch_size=args.test_batch_size,
                                    num_workers=args.workers, test_open=False, shuffle=False)
    meta_set = train_set

# Carve a small balanced meta set out of meta_set; the per-class sample
# counts differ per dataset.
if args.dataset == 'iNaturalist17':
    meta_set, _ = build_dataset(meta_set, 5, args.num_classes)
elif args.dataset == 'iNaturalist18':
    meta_set, _ = build_dataset(meta_set, 2, args.num_classes)
else:
    meta_set, _ = build_dataset(meta_set, 10, args.num_classes)

train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=0,
                                           pin_memory=True, sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, num_workers=0, pin_memory=True)
meta_sampler = torch.utils.data.distributed.DistributedSampler(meta_set)
meta_loader = torch.utils.data.DataLoader(meta_set, batch_size=args.batch_size, shuffle=(meta_sampler is None), num_workers=0,
                                          pin_memory=True, sampler=meta_sampler)


np.random.seed(42)
random.seed(42)
torch.manual_seed(args.seed)
classe_labels = range(args.num_classes)

data_list = {}
data_list_num = []
# Per-class training-sample counts, indexed by class id.
num = Counter(train_loader.dataset.labels)
data_list_num = [0] * args.num_classes
for key in num:
    data_list_num[key] = num[key]

# Class-balanced weights via the "effective number of samples" scheme,
# normalized so the weights sum to num_classes.
beta = 0.9999
effective_num = 1.0 - np.power(beta, data_list_num)
per_cls_weights = (1.0 - beta) / np.array(effective_num)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(data_list_num)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

# Stage-1 (uniformly trained) backbone checkpoints per dataset.
model_dict = {"ImageNet_LT": "models/resnet50_uniform_e90.pth",
              "iNaturalist18": "models/iNat18/resnet50_uniform_e200.pth"}


def main():
    global args
    args = parser.parse_args()

    cudnn.benchmark = True
143 | print(torch.cuda.is_available()) 144 | print(torch.cuda.device_count()) 145 | print(f'local_rank: {args.local_rank}') 146 | model = FCModel(2048, args.num_classes) 147 | model = model.cuda() 148 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True) 149 | weights = torch.load(model_dict[args.dataset], map_location=torch.device("cpu")) 150 | weights = weights['state_dict_best']['classifier'] 151 | weights = {k: weights['module.' + k] for k in model.module.state_dict()} 152 | for k in model.module.state_dict(): 153 | if k not in weights: 154 | print("Pretrained Weights Warning.") 155 | 156 | model.module.load_state_dict(weights) 157 | feature_extractor = create_model(stage1_weights=True, dataset=args.dataset, log_dir=model_dict[args.dataset]) 158 | feature_extractor = feature_extractor.cuda() 159 | feature_extractor.eval() 160 | 161 | 162 | torch.autograd.set_detect_anomaly(True) 163 | torch.distributed.barrier() 164 | 165 | optimizer_a = torch.optim.SGD(model.module.parameters(), args.lr, 166 | momentum=args.momentum, nesterov=args.nesterov, 167 | weight_decay=args.weight_decay) 168 | 169 | criterion = loss.Loss_meta(2048, args.num_classes) 170 | for epoch in range(args.epochs): 171 | ratio = args.lam * float(epoch) / float(args.epochs) 172 | train_meta(train_loader, model, feature_extractor, optimizer_a, epoch, criterion, ratio) 173 | 174 | if args.local_rank == 0: 175 | save_checkpoint(args, { 176 | 'epoch': epoch + 1, 177 | 'state_dict': {'feature': feature_extractor.state_dict(), 'classifier': model.module.state_dict()}, 178 | 'optimizer' : optimizer_a.state_dict(), 179 | }, False, epoch) 180 | 181 | def train_meta(train_loader, model, feature_extractor, optimizer_a, epoch, criterion, ratio): 182 | """Experimenting how to train stably in stage-2""" 183 | batch_time = AverageMeter() 184 | losses = AverageMeter() 185 | meta_losses = AverageMeter() 186 | top1 = 
AverageMeter() 187 | meta_top1 = AverageMeter() 188 | model.train() 189 | weights = torch.tensor(per_cls_weights).float() 190 | for i, (input, target) in enumerate(train_loader): 191 | 192 | input_var = input.cuda(non_blocking=True) 193 | target_var = target.cuda(non_blocking=True) 194 | cv = criterion.get_cv() 195 | cv_var = to_var(cv) 196 | 197 | meta_model = FCMeta(2048, args.num_classes) 198 | meta_model.load_state_dict(model.module.state_dict()) 199 | meta_model.cuda() 200 | 201 | with torch.no_grad(): 202 | feat_hat = feature_extractor(input_var) 203 | y_f_hat = meta_model(feat_hat) 204 | cls_loss_meta = criterion(list(meta_model.fc.named_leaves())[0][1], feat_hat, y_f_hat, target_var, ratio, 205 | weights, cv_var, "none") 206 | meta_model.zero_grad() 207 | grads = torch.autograd.grad(cls_loss_meta, (meta_model.params()), create_graph=True) 208 | meta_lr = args.lr 209 | meta_model.fc.update_params(meta_lr, source_params=grads) 210 | 211 | input_val, target_val = next(iter(meta_loader)) 212 | input_val_var = input_val.cuda(non_blocking=True) 213 | target_val_var = target_val.cuda(non_blocking=True) 214 | 215 | with torch.no_grad(): 216 | feature_val = feature_extractor(input_val_var) 217 | y_val = meta_model(feature_val) 218 | cls_meta = F.cross_entropy(y_val, target_val_var) 219 | grad_cv = torch.autograd.grad(cls_meta, cv_var, only_inputs=True)[0] 220 | new_cv = cv - args.meta_lr * grad_cv 221 | 222 | del grad_cv, grads, meta_model 223 | with torch.no_grad(): 224 | features = feature_extractor(input_var) 225 | predicts = model(features) 226 | cls_loss = criterion(list(model.module.fc.parameters())[0], features, predicts, target_var, ratio, weights, new_cv.detach(), "update") 227 | 228 | prec_train = accuracy(predicts.data, target_var.data, topk=(1,))[0] 229 | 230 | losses.update(cls_loss.item(), input.size(0)) 231 | top1.update(prec_train.item(), input.size(0)) 232 | 233 | optimizer_a.zero_grad() 234 | cls_loss.backward() 235 | optimizer_a.step() 236 | 237 | 
if i % args.print_freq == 0: 238 | print('Epoch: [{0}][{1}/{2}]\t' 239 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 240 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 241 | epoch, i, len(train_loader), 242 | loss=losses,top1=top1)) 243 | 244 | 245 | 246 | def validate(val_loader, model, feature_extractor, criterion, epoch, local_rank, distributed): 247 | """Perform validation on the validation set""" 248 | batch_time = AverageMeter() 249 | losses = AverageMeter() 250 | top1 = AverageMeter() 251 | 252 | # switch to evaluate mode 253 | 254 | true_labels = [] 255 | preds = [] 256 | if distributed: 257 | model = model.module 258 | torch.cuda.empty_cache() 259 | 260 | model.eval() 261 | end = time.time() 262 | for i, (input, target) in enumerate(val_loader): 263 | input = input.cuda(non_blocking=True) 264 | target = target.cuda(non_blocking=True) 265 | # compute output 266 | with torch.no_grad(): 267 | feature = feature_extractor(input) 268 | output = model(feature) 269 | loss = criterion(output, target) 270 | 271 | output_numpy = output.data.cpu().numpy() 272 | preds_output = list(output_numpy.argmax(axis=1)) 273 | 274 | true_labels += list(target.data.cpu().numpy()) 275 | preds += preds_output 276 | 277 | # measure accuracy and record loss 278 | prec1 = accuracy(output.data, target, topk=(1,))[0] 279 | losses.update(loss.data.item(), input.size(0)) 280 | top1.update(prec1.item(), input.size(0)) 281 | 282 | # measure elapsed time 283 | batch_time.update(time.time() - end) 284 | end = time.time() 285 | 286 | if i % args.print_freq == 0 and local_rank==0: 287 | print('Test: [{0}/{1}]\t' 288 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 289 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 290 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 291 | i, len(val_loader), batch_time=batch_time, loss=losses, 292 | top1=top1)) 293 | 294 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 295 | return top1.avg, preds, true_labels 296 | 297 | 298 | def to_var(x, 
requires_grad=True): 299 | if torch.cuda.is_available(): 300 | x = x.cuda() 301 | return Variable(x, requires_grad=requires_grad) 302 | 303 | 304 | class AverageMeter(object): 305 | """Computes and stores the average and current value""" 306 | 307 | def __init__(self): 308 | self.reset() 309 | 310 | def reset(self): 311 | self.val = 0 312 | self.avg = 0 313 | self.sum = 0 314 | self.count = 0 315 | 316 | def update(self, val, n=1): 317 | self.val = val 318 | self.sum += val * n 319 | self.count += n 320 | self.avg = self.sum / self.count 321 | 322 | 323 | def accuracy(output, target, topk=(1,)): 324 | """Computes the precision@k for the specified values of k""" 325 | maxk = max(topk) 326 | batch_size = target.size(0) 327 | 328 | _, pred = output.topk(maxk, 1, True, True) 329 | pred = pred.t() 330 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 331 | 332 | res = [] 333 | for k in topk: 334 | correct_k = correct[:k].view(-1).float().sum(0) 335 | res.append(correct_k.mul_(100.0 / batch_size)) 336 | return res 337 | 338 | 339 | def save_checkpoint(args, state, is_best, epoch): 340 | filename = 'checkpoint/' + 'train_' + str(args.dataset) + '/' + str(args.lr) + '_' + str(args.batch_size) + '_' + str(args.meta_lr) + 'epoch' + str(epoch) + '_ckpt.pth.tar' 341 | file_root, _ = os.path.split(filename) 342 | if not os.path.exists(file_root): 343 | os.makedirs(file_root) 344 | torch.save(state, filename) 345 | 346 | 347 | if __name__ == '__main__': 348 | main() 349 | -------------------------------------------------------------------------------- /ImageNet_iNat/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | from sklearn.metrics import f1_score 5 | import torch.nn.functional as F 6 | import importlib 7 | 8 | 9 | def source_import(file_path): 10 | """This function imports python module directly from source code using importlib""" 11 | spec = 
importlib.util.spec_from_file_location('', file_path) 12 | module = importlib.util.module_from_spec(spec) 13 | spec.loader.exec_module(module) 14 | return module 15 | 16 | 17 | def batch_show(inp, title=None): 18 | """Imshow for Tensor.""" 19 | inp = inp.numpy().transpose((1, 2, 0)) 20 | mean = np.array([0.485, 0.456, 0.406]) 21 | std = np.array([0.229, 0.224, 0.225]) 22 | inp = std * inp + mean 23 | inp = np.clip(inp, 0, 1) 24 | plt.figure(figsize=(20,20)) 25 | plt.imshow(inp) 26 | if title is not None: 27 | plt.title(title) 28 | 29 | 30 | def print_write(print_str, log_file): 31 | print(*print_str) 32 | if log_file is None: 33 | return 34 | with open(log_file, 'a') as f: 35 | print(*print_str, file=f) 36 | 37 | 38 | def init_weights(model, weights_path, caffe=False, classifier=False): 39 | """Initialize weights""" 40 | print('Pretrained %s weights path: %s' % ('classifier' if classifier else 'feature model', weights_path)) 41 | weights = torch.load(weights_path, map_location=torch.device("cpu")) 42 | weights_c = weights['state_dict_best']['classifier'] 43 | weights_f = weights['state_dict_best']['feat_model'] 44 | if not classifier: 45 | if caffe: 46 | weights = {k: weights[k] if k in weights else model.state_dict()[k] 47 | for k in model.state_dict()} 48 | else: 49 | weights = weights['state_dict_best']['feat_model'] 50 | weights = {k: weights['module.' + k] for k in model.state_dict()} 51 | else: 52 | weights = weights['state_dict_best']['classifier'] 53 | weights = {k: weights['module.' + k] if 'module.' 
+ k in weights else model.state_dict()[k] 54 | for k in model.state_dict()} 55 | model.load_state_dict(weights) 56 | return model 57 | 58 | 59 | def shot_acc(preds, labels, train_data, many_shot_thr=100, low_shot_thr=20, acc_per_cls=False): 60 | if isinstance(train_data, np.ndarray): 61 | training_labels = np.array(train_data).astype(int) 62 | else: 63 | training_labels = np.array(train_data.dataset.labels).astype(int) 64 | 65 | if isinstance(preds, torch.Tensor): 66 | preds = preds.detach().cpu().numpy() 67 | labels = labels.detach().cpu().numpy() 68 | elif isinstance(preds, np.ndarray): 69 | pass 70 | else: 71 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 72 | train_class_count = [] 73 | test_class_count = [] 74 | class_correct = [] 75 | for l in np.unique(labels): 76 | train_class_count.append(len(training_labels[training_labels == l])) 77 | test_class_count.append(len(labels[labels == l])) 78 | class_correct.append((preds[labels == l] == labels[labels == l]).sum()) 79 | 80 | many_shot = [] 81 | median_shot = [] 82 | low_shot = [] 83 | for i in range(len(train_class_count)): 84 | if train_class_count[i] > many_shot_thr: 85 | many_shot.append((class_correct[i] / test_class_count[i])) 86 | elif train_class_count[i] < low_shot_thr: 87 | low_shot.append((class_correct[i] / test_class_count[i])) 88 | else: 89 | median_shot.append((class_correct[i] / test_class_count[i])) 90 | 91 | if len(many_shot) == 0: 92 | many_shot.append(0) 93 | if len(median_shot) == 0: 94 | median_shot.append(0) 95 | if len(low_shot) == 0: 96 | low_shot.append(0) 97 | 98 | if acc_per_cls: 99 | class_accs = [c / cnt for c, cnt in zip(class_correct, test_class_count)] 100 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot), class_accs 101 | else: 102 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 103 | 104 | 105 | def weighted_shot_acc(preds, labels, ws, train_data, many_shot_thr=100, low_shot_thr=20): 106 | training_labels = 
np.array(train_data.dataset.labels).astype(int) 107 | 108 | if isinstance(preds, torch.Tensor): 109 | preds = preds.detach().cpu().numpy() 110 | labels = labels.detach().cpu().numpy() 111 | elif isinstance(preds, np.ndarray): 112 | pass 113 | else: 114 | raise TypeError('Type ({}) of preds not supported'.format(type(preds))) 115 | train_class_count = [] 116 | test_class_count = [] 117 | class_correct = [] 118 | for l in np.unique(labels): 119 | train_class_count.append(len(training_labels[training_labels == l])) 120 | test_class_count.append(ws[labels==l].sum()) 121 | class_correct.append(((preds[labels==l] == labels[labels==l]) * ws[labels==l]).sum()) 122 | 123 | many_shot = [] 124 | median_shot = [] 125 | low_shot = [] 126 | for i in range(len(train_class_count)): 127 | if train_class_count[i] > many_shot_thr: 128 | many_shot.append((class_correct[i] / test_class_count[i])) 129 | elif train_class_count[i] < low_shot_thr: 130 | low_shot.append((class_correct[i] / test_class_count[i])) 131 | else: 132 | median_shot.append((class_correct[i] / test_class_count[i])) 133 | return np.mean(many_shot), np.mean(median_shot), np.mean(low_shot) 134 | 135 | 136 | def F_measure(preds, labels, openset=False, theta=None): 137 | if openset: 138 | # f1 score for openset evaluation 139 | true_pos = 0. 140 | false_pos = 0. 141 | false_neg = 0. 
142 | 143 | for i in range(len(labels)): 144 | true_pos += 1 if preds[i] == labels[i] and labels[i] != -1 else 0 145 | false_pos += 1 if preds[i] != labels[i] and labels[i] != -1 and preds[i] != -1 else 0 146 | false_neg += 1 if preds[i] != labels[i] and labels[i] == -1 else 0 147 | 148 | precision = true_pos / (true_pos + false_pos) 149 | recall = true_pos / (true_pos + false_neg) 150 | return 2 * ((precision * recall) / (precision + recall + 1e-12)) 151 | else: 152 | # Regular f1 score 153 | return f1_score(labels.detach().cpu().numpy(), preds.detach().cpu().numpy(), average='macro') 154 | 155 | 156 | def mic_acc_cal(preds, labels): 157 | if isinstance(labels, tuple): 158 | assert len(labels) == 3 159 | targets_a, targets_b, lam = labels 160 | acc_mic_top1 = (lam * preds.eq(targets_a.data).cpu().sum().float() \ 161 | + (1 - lam) * preds.eq(targets_b.data).cpu().sum().float()) / len(preds) 162 | else: 163 | acc_mic_top1 = (preds == labels).sum().item() / len(labels) 164 | return acc_mic_top1 165 | 166 | 167 | def weighted_mic_acc_cal(preds, labels, ws): 168 | acc_mic_top1 = ws[preds == labels].sum() / ws.sum() 169 | return acc_mic_top1 170 | 171 | 172 | def class_count(data): 173 | labels = np.array(data.dataset.labels) 174 | class_data_num = [] 175 | for l in np.unique(labels): 176 | class_data_num.append(len(labels[labels == l])) 177 | return class_data_num 178 | 179 | 180 | def torch2numpy(x): 181 | if isinstance(x, torch.Tensor): 182 | return x.detach().cpu().numpy() 183 | elif isinstance(x, (list, tuple)): 184 | return tuple([torch2numpy(xi) for xi in x]) 185 | else: 186 | return x 187 | 188 | 189 | def logits2score(logits, labels): 190 | scores = F.softmax(logits, dim=1) 191 | score = scores.gather(1, labels.view(-1, 1)) 192 | score = score.squeeze().cpu().numpy() 193 | return score 194 | 195 | 196 | def logits2entropy(logits): 197 | scores = F.softmax(logits, dim=1) 198 | scores = scores.cpu().numpy() + 1e-30 199 | ent = -scores * np.log(scores) 200 | ent = 
np.sum(ent, 1) 201 | return ent 202 | 203 | 204 | def logits2CE(logits, labels): 205 | scores = F.softmax(logits, dim=1) 206 | score = scores.gather(1, labels.view(-1, 1)) 207 | score = score.squeeze().cpu().numpy() + 1e-30 208 | ce = -np.log(score) 209 | return ce 210 | 211 | 212 | def get_priority(ptype, logits, labels): 213 | if ptype == 'score': 214 | ws = 1 - logits2score(logits, labels) 215 | elif ptype == 'entropy': 216 | ws = logits2entropy(logits) 217 | elif ptype == 'CE': 218 | ws = logits2CE(logits, labels) 219 | 220 | return ws 221 | 222 | 223 | def get_value(oldv, newv): 224 | if newv is not None: 225 | return newv 226 | else: 227 | return oldv 228 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 BIT-DA 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /MetaSAug_LDAM_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import random 5 | import copy 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torchvision.transforms as transforms 12 | from data_utils import * 13 | from resnet import * 14 | import shutil 15 | from loss import * 16 | 17 | parser = argparse.ArgumentParser(description='Imbalanced Example') 18 | parser.add_argument('--dataset', default='cifar100', type=str, 19 | help='dataset (cifar10 or cifar100[default])') 20 | parser.add_argument('--batch-size', type=int, default=100, metavar='N', 21 | help='input batch size for training (default: 64)') 22 | parser.add_argument('--num_classes', type=int, default=100) 23 | parser.add_argument('--num_meta', type=int, default=10, 24 | help='The number of meta data for each class.') 25 | parser.add_argument('--imb_factor', type=float, default=0.005) 26 | parser.add_argument('--test-batch-size', type=int, default=100, metavar='N', 27 | help='input batch size for testing (default: 100)') 28 | parser.add_argument('--epochs', type=int, default=200, metavar='N', 29 | help='number of epochs to train') 30 | parser.add_argument('--lr', '--learning-rate', default=1e-1, type=float, 31 | help='initial learning rate') 32 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 33 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 34 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 35 | help='weight decay (default: 5e-4)') 36 | parser.add_argument('--no-cuda', action='store_true', default=False, 37 | help='disables CUDA training') 38 | parser.add_argument('--split', type=int, default=1000) 39 | 
parser.add_argument('--seed', type=int, default=42, metavar='S', 40 | help='random seed (default: 42)') 41 | parser.add_argument('--print-freq', '-p', default=100, type=int, 42 | help='print frequency (default: 10)') 43 | parser.add_argument('--lam', default=0.25, type=float, help='[0.25, 0.5, 0.75, 1.0]') 44 | parser.add_argument('--gpu', default=0, type=int) 45 | parser.add_argument('--meta_lr', default=0.1, type=float) 46 | parser.add_argument('--save_name', default='name', type=str) 47 | parser.add_argument('--idx', default='0', type=str) 48 | 49 | 50 | args = parser.parse_args() 51 | for arg in vars(args): 52 | print("{}={}".format(arg, getattr(args, arg))) 53 | 54 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 55 | os.environ["CUDA_VISIBLE_DEVICES"]= str(args.gpu) 56 | kwargs = {'num_workers': 1, 'pin_memory': False} 57 | use_cuda = not args.no_cuda and torch.cuda.is_available() 58 | 59 | torch.manual_seed(args.seed) 60 | device = torch.device("cuda" if use_cuda else "cpu") 61 | 62 | train_data_meta, train_data, test_dataset = build_dataset(args.dataset, args.num_meta) 63 | 64 | print(f'length of meta dataset:{len(train_data_meta)}') 65 | print(f'length of train dataset: {len(train_data)}') 66 | 67 | train_loader = torch.utils.data.DataLoader( 68 | train_data, batch_size=args.batch_size, shuffle=True, **kwargs) 69 | 70 | np.random.seed(42) 71 | random.seed(42) 72 | torch.manual_seed(args.seed) 73 | classe_labels = range(args.num_classes) 74 | 75 | data_list = {} 76 | 77 | 78 | for j in range(args.num_classes): 79 | data_list[j] = [i for i, label in enumerate(train_loader.dataset.targets) if label == j] 80 | 81 | 82 | img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, args.num_meta*args.num_classes) 83 | print(img_num_list) 84 | print(sum(img_num_list)) 85 | 86 | im_data = {} 87 | idx_to_del = [] 88 | for cls_idx, img_id_list in data_list.items(): 89 | random.shuffle(img_id_list) 90 | img_num = img_num_list[int(cls_idx)] 91 | im_data[cls_idx] = 
img_id_list[img_num:] 92 | idx_to_del.extend(img_id_list[img_num:]) 93 | 94 | print(len(idx_to_del)) 95 | imbalanced_train_dataset = copy.deepcopy(train_data) 96 | imbalanced_train_dataset.targets = np.delete(train_loader.dataset.targets, idx_to_del, axis=0) 97 | imbalanced_train_dataset.data = np.delete(train_loader.dataset.data, idx_to_del, axis=0) 98 | print(len(imbalanced_train_dataset)) 99 | imbalanced_train_loader = torch.utils.data.DataLoader( 100 | imbalanced_train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs) 101 | 102 | validation_loader = torch.utils.data.DataLoader( 103 | train_data_meta, batch_size=args.batch_size, shuffle=True, **kwargs) 104 | 105 | test_loader = torch.utils.data.DataLoader( 106 | test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs) 107 | 108 | best_prec1 = 0 109 | 110 | beta = 0.9999 111 | effective_num = 1.0 - np.power(beta, img_num_list) 112 | per_cls_weights = (1.0 - beta) / np.array(effective_num) 113 | per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(img_num_list) 114 | per_cls_weights = torch.FloatTensor(per_cls_weights).cuda() 115 | weights = torch.tensor(per_cls_weights).float() 116 | 117 | def main(): 118 | global args, best_prec1 119 | args = parser.parse_args() 120 | 121 | model = build_model() 122 | optimizer_a = torch.optim.SGD(model.params(), args.lr, 123 | momentum=args.momentum, nesterov=args.nesterov, 124 | weight_decay=args.weight_decay) 125 | 126 | cudnn.benchmark = True 127 | 128 | criterion = LDAM_meta(64, args.dataset == "cifar10" and 10 or 100, cls_num_list=img_num_list, 129 | max_m=0.5, s=30) 130 | 131 | for epoch in range(args.epochs): 132 | adjust_learning_rate(optimizer_a, epoch + 1) 133 | 134 | ratio = args.lam * float(epoch) / float(args.epochs) 135 | if epoch < 160: 136 | train(imbalanced_train_loader, model, optimizer_a, epoch) 137 | 138 | else: 139 | train_MetaSAug(imbalanced_train_loader, validation_loader, model, optimizer_a, epoch, criterion, ratio) 140 
| 141 | 142 | prec1, preds, gt_labels = validate(test_loader, model, nn.CrossEntropyLoss().cuda(), epoch) 143 | 144 | is_best = prec1 > best_prec1 145 | best_prec1 = max(prec1, best_prec1) 146 | 147 | # save_checkpoint(args, { 148 | # 'epoch': epoch + 1, 149 | # 'state_dict': model.state_dict(), 150 | # 'best_acc1': best_prec1, 151 | # 'optimizer': optimizer_a.state_dict(), 152 | # }, is_best) 153 | 154 | print('Best accuracy: ', best_prec1) 155 | 156 | 157 | def train(train_loader, model, optimizer_a, epoch): 158 | 159 | losses = AverageMeter() 160 | top1 = AverageMeter() 161 | model.train() 162 | 163 | for i, (input, target) in enumerate(train_loader): 164 | input_var = to_var(input, requires_grad=False) 165 | target_var = to_var(target, requires_grad=False) 166 | 167 | _, y_f = model(input_var) 168 | del _ 169 | cost_w = F.cross_entropy(y_f, target_var, reduce=False) 170 | l_f = torch.mean(cost_w) 171 | prec_train = accuracy(y_f.data, target_var.data, topk=(1,))[0] 172 | 173 | 174 | losses.update(l_f.item(), input.size(0)) 175 | top1.update(prec_train.item(), input.size(0)) 176 | 177 | optimizer_a.zero_grad() 178 | l_f.backward() 179 | optimizer_a.step() 180 | 181 | if i % args.print_freq == 0: 182 | print('Epoch: [{0}][{1}/{2}]\t' 183 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 184 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 185 | epoch, i, len(train_loader), 186 | loss=losses,top1=top1)) 187 | 188 | 189 | 190 | def train_MetaSAug(train_loader, validation_loader, model,optimizer_a, epoch, criterion, ratio): 191 | 192 | losses = AverageMeter() 193 | top1 = AverageMeter() 194 | model.train() 195 | 196 | for i, (input, target) in enumerate(train_loader): 197 | 198 | input_var = to_var(input, requires_grad=False) 199 | target_var = to_var(target, requires_grad=False) 200 | 201 | cv = criterion.get_cv() 202 | cv_var = to_var(cv) 203 | 204 | meta_model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 205 | meta_model.load_state_dict(model.state_dict()) 206 | 
meta_model.cuda() 207 | 208 | feat_hat, y_f_hat = meta_model(input_var) 209 | cls_loss_meta = criterion(meta_model.linear, feat_hat, y_f_hat, target_var, ratio, 210 | weights, cv_var, "none") 211 | meta_model.zero_grad() 212 | 213 | grads = torch.autograd.grad(cls_loss_meta, (meta_model.params()), create_graph=True) 214 | meta_lr = args.lr * ((0.01 ** int(epoch >= 160)) * (0.01 ** int(epoch >= 180))) 215 | meta_model.update_params(meta_lr, source_params=grads) 216 | 217 | input_val, target_val = next(iter(validation_loader)) 218 | input_val_var = to_var(input_val, requires_grad=False) 219 | target_val_var = to_var(target_val, requires_grad=False) 220 | 221 | _, y_val = meta_model(input_val_var) 222 | cls_meta = F.cross_entropy(y_val, target_val_var) 223 | grad_cv = torch.autograd.grad(cls_meta, cv_var, only_inputs=True)[0] 224 | new_cv = cv_var - args.meta_lr * grad_cv 225 | 226 | del grad_cv, grads 227 | 228 | #model.train() 229 | features, predicts = model(input_var) 230 | cls_loss = criterion(model.linear, features, predicts, target_var, ratio, weights, new_cv, "update") 231 | 232 | prec_train = accuracy(predicts.data, target_var.data, topk=(1,))[0] 233 | 234 | 235 | losses.update(cls_loss.item(), input.size(0)) 236 | top1.update(prec_train.item(), input.size(0)) 237 | 238 | optimizer_a.zero_grad() 239 | cls_loss.backward() 240 | optimizer_a.step() 241 | 242 | if i % args.print_freq == 0: 243 | print('Epoch: [{0}][{1}/{2}]\t' 244 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 245 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 246 | epoch, i, len(train_loader), 247 | loss=losses,top1=top1)) 248 | 249 | 250 | def validate(val_loader, model, criterion, epoch): 251 | batch_time = AverageMeter() 252 | losses = AverageMeter() 253 | top1 = AverageMeter() 254 | 255 | model.eval() 256 | 257 | true_labels = [] 258 | preds = [] 259 | 260 | end = time.time() 261 | for i, (input, target) in enumerate(val_loader): 262 | target = target.cuda() 263 | input = input.cuda() 264 | 
input_var = torch.autograd.Variable(input) 265 | target_var = torch.autograd.Variable(target) 266 | 267 | with torch.no_grad(): 268 | _, output = model(input_var) 269 | 270 | output_numpy = output.data.cpu().numpy() 271 | preds_output = list(output_numpy.argmax(axis=1)) 272 | 273 | true_labels += list(target_var.data.cpu().numpy()) 274 | preds += preds_output 275 | 276 | 277 | prec1 = accuracy(output.data, target, topk=(1,))[0] 278 | top1.update(prec1.item(), input.size(0)) 279 | 280 | batch_time.update(time.time() - end) 281 | end = time.time() 282 | 283 | if i % args.print_freq == 0: 284 | print('Test: [{0}/{1}]\t' 285 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 286 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 287 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 288 | i, len(val_loader), batch_time=batch_time, loss=losses, 289 | top1=top1)) 290 | 291 | print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 292 | 293 | return top1.avg, preds, true_labels 294 | 295 | 296 | def build_model(): 297 | model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 298 | 299 | if torch.cuda.is_available(): 300 | model.cuda() 301 | torch.backends.cudnn.benchmark = True 302 | 303 | 304 | return model 305 | 306 | def to_var(x, requires_grad=True): 307 | if torch.cuda.is_available(): 308 | x = x.cuda() 309 | return Variable(x, requires_grad=requires_grad) 310 | 311 | 312 | class AverageMeter(object): 313 | 314 | def __init__(self): 315 | self.reset() 316 | 317 | def reset(self): 318 | self.val = 0 319 | self.avg = 0 320 | self.sum = 0 321 | self.count = 0 322 | 323 | def update(self, val, n=1): 324 | self.val = val 325 | self.sum += val * n 326 | self.count += n 327 | self.avg = self.sum / self.count 328 | 329 | 330 | def adjust_learning_rate(optimizer, epoch): 331 | lr = args.lr * ((0.01 ** int(epoch >= 160)) * (0.01 ** int(epoch >= 180))) 332 | for param_group in optimizer.param_groups: 333 | param_group['lr'] = lr 334 | 335 | 336 | def accuracy(output, target, 
topk=(1,)): 337 | maxk = max(topk) 338 | batch_size = target.size(0) 339 | 340 | _, pred = output.topk(maxk, 1, True, True) 341 | pred = pred.t() 342 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 343 | 344 | res = [] 345 | for k in topk: 346 | correct_k = correct[:k].view(-1).float().sum(0) 347 | res.append(correct_k.mul_(100.0 / batch_size)) 348 | return res 349 | 350 | 351 | def save_checkpoint(args, state, is_best): 352 | path = 'checkpoint/ours/' + args.idx + '/' 353 | if not os.path.exists(path): 354 | os.makedirs(path) 355 | filename = path + args.save_name + '_ckpt.pth.tar' 356 | if is_best: 357 | torch.save(state, filename) 358 | 359 | if __name__ == '__main__': 360 | main() 361 | -------------------------------------------------------------------------------- /MetaSAug_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import random 5 | import copy 6 | import torch 7 | import torchvision 8 | import numpy as np 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | import torch.nn as nn 12 | import torchvision.transforms as transforms 13 | from data_utils import * 14 | from resnet import * 15 | import shutil 16 | import gc 17 | 18 | parser = argparse.ArgumentParser(description='Imbalanced Example') 19 | parser.add_argument('--checkpoint_path', default='path.pth.tar', type=str, 20 | help='the path of checkpoint') 21 | parser.add_argument('--dataset', default='cifar10', type=str) 22 | parser.add_argument('--imb_factor', default='0.1', type=float) 23 | 24 | 25 | args = parser.parse_args() 26 | print('checkpoint_path:', args.checkpoint_path) 27 | 28 | params = args.checkpoint_path.split('_') 29 | dataset = args.dataset 30 | imb_factor = args.imb_factor 31 | 32 | 33 | kwargs = {'num_workers': 4, 'pin_memory': False} 34 | use_cuda = torch.cuda.is_available() 35 | 36 | torch.manual_seed(42) 37 | 38 | print('start loading test data') 39 
| train_data_meta, train_data, test_dataset = build_dataset(dataset, 10) 40 | test_loader = torch.utils.data.DataLoader( 41 | test_dataset, batch_size=100, shuffle=False, **kwargs) 42 | 43 | print('load test data successfully') 44 | 45 | best_prec1 = 0 46 | 47 | 48 | def main(): 49 | global args, best_prec1 50 | args = parser.parse_args() 51 | 52 | 53 | model = build_model() 54 | 55 | net_dict = torch.load(args.checkpoint_path) 56 | 57 | model.load_state_dict(net_dict['state_dict']) 58 | 59 | prec1, preds, gt_labels = validate( 60 | test_loader, model, nn.CrossEntropyLoss().cuda(), 0) 61 | print('Test result:\n' 62 | 'Dataset: {0}\t' 63 | 'Imb_factor: {1}\t' 64 | 'Accuracy: {2:.2f} \t' 65 | 'Error: {3:.2f} \n'.format( 66 | dataset, int(1 / imb_factor), prec1,100 - prec1)) 67 | 68 | 69 | 70 | def validate(val_loader, model, criterion, epoch): 71 | batch_time = AverageMeter() 72 | losses = AverageMeter() 73 | top1 = AverageMeter() 74 | 75 | model.eval() 76 | 77 | true_labels = [] 78 | preds = [] 79 | 80 | end = time.time() 81 | for i, (input, target) in enumerate(val_loader): 82 | target = target.cuda() 83 | input = input.cuda() 84 | input_var = torch.autograd.Variable(input) 85 | target_var = torch.autograd.Variable(target) 86 | 87 | with torch.no_grad(): 88 | _, output = model(input_var) 89 | 90 | output_numpy = output.data.cpu().numpy() 91 | preds_output = list(output_numpy.argmax(axis=1)) 92 | 93 | true_labels += list(target_var.data.cpu().numpy()) 94 | preds += preds_output 95 | 96 | prec1 = accuracy(output.data, target, topk=(1,))[0] 97 | top1.update(prec1.item(), input.size(0)) 98 | 99 | batch_time.update(time.time() - end) 100 | end = time.time() 101 | 102 | if i % 100 == 0: 103 | print('Test: [{0}/{1}]\t' 104 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 105 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 106 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 107 | i, len(val_loader), batch_time=batch_time, loss=losses, 108 | top1=top1)) 109 | 110 | 
print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1)) 111 | 112 | return top1.avg, preds, true_labels 113 | 114 | 115 | def build_model(): 116 | model = ResNet32(dataset == 'cifar10' and 10 or 100) 117 | 118 | if torch.cuda.is_available(): 119 | model.cuda() 120 | torch.backends.cudnn.benchmark = True 121 | 122 | return model 123 | 124 | 125 | 126 | class AverageMeter(object): 127 | 128 | def __init__(self): 129 | self.reset() 130 | 131 | def reset(self): 132 | self.val = 0 133 | self.avg = 0 134 | self.sum = 0 135 | self.count = 0 136 | 137 | def update(self, val, n=1): 138 | self.val = val 139 | self.sum += val * n 140 | self.count += n 141 | self.avg = self.sum / self.count 142 | 143 | 144 | 145 | 146 | def accuracy(output, target, topk=(1,)): 147 | maxk = max(topk) 148 | batch_size = target.size(0) 149 | 150 | _, pred = output.topk(maxk, 1, True, True) 151 | pred = pred.t() 152 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 153 | 154 | res = [] 155 | for k in topk: 156 | correct_k = correct[:k].view(-1).float().sum(0) 157 | res.append(correct_k.mul_(100.0 / batch_size)) 158 | return res 159 | 160 | 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
28 |
29 |