├── README.md ├── Sinkhorn_distance_fl.py ├── Sinkhorn_distance.py ├── data_utils.py ├── resnet.py └── OT_train.py /README.md: -------------------------------------------------------------------------------- 1 | # reweight-imbalance-classification-with-OT 2 | 3 | Code for "Learning to Re-weight Examples with Optimal Transport for Imbalanced Classification", in NeurIPS 2022. 4 | 5 | 6 | Requirements: 7 | 8 | Python 3.6 9 | PyTorch 1.7.1 10 | tqdm 4.19.9 11 | torchvision 0.8.2 12 | numpy 1.19.2 13 | 14 | 15 | 16 | Stage1: 17 | 18 | Pretrain the backbone with the imbalanced training set. See paper for more details. 19 | 20 | Adjust your file path according to the code. 21 | 22 | Stage2: 23 | 24 | Learn the weight vector by optimizing OT loss and update the recognition model. 25 | Run: OT_train.py 26 | 27 | 28 | Abstract: Imbalanced data pose challenges for deep learning based classification models. One 29 | of the most widely-used approaches for tackling imbalanced data is re-weighting, 30 | where training samples are associated with different weights in the loss function. 31 | Most of existing re-weighting approaches treat the example weights as the learnable 32 | parameter and optimize the weights on the meta set, entailing expensive bilevel 33 | optimization. In this paper, we propose a novel re-weighting method based on 34 | optimal transport (OT) from a distributional point of view. Specifically, we view 35 | the training set as an imbalanced distribution over its samples, which is transported 36 | by OT to a balanced distribution obtained from the meta set. The weights of 37 | the training samples are the probability mass of the imbalanced distribution and 38 | learned by minimizing the OT distance between the two distributions. Compared 39 | with existing methods, our proposed one disengages the dependence of the weight 40 | learning on the concerned classifier at each iteration. 
Experiments on image, 41 | text and point cloud datasets demonstrate that our proposed re-weighting method 42 | has excellent performance, achieving state-of-the-art results in many cases and 43 | providing a promising tool for addressing the imbalanced classification issue. 44 | 45 | 46 | 47 | @inproceedings{Guo2022reweight, 48 | title={Learning to Re-weight Examples with Optimal Transport for Imbalanced Classification}, 49 | author={Guo, Dandan and Li, Zhuo and Zheng, Meixi and Zhao, He and Zhou, Mingyuan and Zha, Hongyuan}, 50 | booktitle={Proceedings of the Advances in Neural Information Processing Systems (NeurIPS)}, 51 | year={2022} 52 | } 53 | -------------------------------------------------------------------------------- /Sinkhorn_distance_fl.py: -------------------------------------------------------------------------------- 1 | 2 | import sys 3 | 4 | import random 5 | import time 6 | import math 7 | import numpy as np 8 | from sklearn.neighbors import NearestNeighbors 9 | from scipy.spatial import KDTree 10 | from scipy.stats import wasserstein_distance 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.nn import init 15 | import torch.nn.functional as F 16 | from torch.utils.data import Dataset, DataLoader 17 | from torch.autograd.variable import Variable 18 | from torch.utils.data import DataLoader 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | d_cosine = nn.CosineSimilarity(dim=-1, eps=1e-8) 21 | 22 | 23 | class SinkhornDistance(nn.Module): 24 | r""" 25 | Given two empirical measures each with :math:`P_1` locations 26 | :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`, 27 | outputs an approximation of the regularized OT cost for point clouds. 28 | Args: 29 | eps (float): regularization coefficient 30 | max_iter (int): maximum number of Sinkhorn iterations 31 | reduction (string, optional): Specifies the reduction to apply to the output: 32 | 'none' | 'mean' | 'sum'. 
def forward(self, x, y, x1, y1, nu):
    """Return the entropic OT cost for a blended feature/label ground cost.

    The ground cost is ``0.5 * cos_cost(x, y) + 0.5 * euc_cost(x1, y1)``.
    The source marginal is uniform over the points of ``x``; the target
    marginal is ``nu``. Sinkhorn updates run in the log domain for at most
    ``self.max_iter`` iterations and stop early once the L1 change of the
    source potential falls below 0.1.

    Args:
        x: first point set for the cosine cost; its second-to-last
            dimension is the number of source points.
        y: second point set for the cosine cost.
        x1: first point set for the ``'euc'`` cost (same number of points
            as ``x``).
        y1: second point set for the ``'euc'`` cost (same number of points
            as ``y``).
        nu: target marginal over the points of ``y``. It is used
            un-detached, so gradients of the returned cost reach it.

    Returns:
        The transport cost ``sum(pi * C)`` over the last two dimensions,
        reduced with ``mean``/``sum`` when ``self.reduction`` says so and
        left untouched for any other value (including ``None``).
    """
    cost_feature = self._cost_matrix(x, y, dis='cos')
    cost_label = self._cost_matrix(x1, y1, dis='euc')
    C = 0.5 * cost_feature + 0.5 * cost_label

    x_points = x.shape[-2]
    batch_size = 1 if x.dim() == 2 else x.shape[0]

    # Allocate on the cost matrix's device instead of a module-level global,
    # so the module works wherever its inputs live.
    mu = torch.full((batch_size, x_points), 1.0 / x_points,
                    dtype=torch.float, device=C.device).squeeze()

    # Dual potentials start at zero and inherit device/dtype from the marginals.
    u = torch.zeros_like(mu)
    v = torch.zeros_like(nu)
    thresh = 1e-1

    for _ in range(self.max_iter):
        u_prev = u
        # 1e-8 keeps the logarithm finite when a marginal entry is zero.
        u = self.eps * (torch.log(mu + 1e-8)
                        - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
        v = self.eps * (torch.log(nu + 1e-8)
                        - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
        err = (u - u_prev).abs().sum(-1).mean()
        if err.item() < thresh:
            break

    # Transport plan implied by the final potentials.
    pi = torch.exp(self.M(C, u, v))
    cost = torch.sum(pi * C, dim=(-2, -1))

    if self.reduction == 'mean':
        cost = cost.mean()
    elif self.reduction == 'sum':
        cost = cost.sum()
    return cost
class SinkhornDistance(nn.Module):
    r"""Entropy-regularised optimal-transport cost via log-domain Sinkhorn.

    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    The source marginal is uniform; the target marginal ``nu`` is supplied
    by the caller on every forward pass.

    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        dis (str): ground cost, ``'cos'`` (one minus cosine similarity) or
            ``'euc'`` (mean of ``|x_i - y_j|**p`` over the feature axis)
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Any
            other value (e.g. ``None``) behaves like 'none'. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """

    def __init__(self, eps, max_iter, dis, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction
        self.dis = dis

    def forward(self, x, y, nu):
        """Return the OT cost between a uniform measure on ``x`` and ``nu`` on ``y``.

        Iterations stop after ``max_iter`` steps or as soon as the L1 change
        of the source potential drops below 0.1.
        """
        C = self._cost_matrix(x, y, self.dis)
        x_points = x.shape[-2]
        batch_size = 1 if x.dim() == 2 else x.shape[0]

        # Allocate on the cost matrix's device rather than a module-level
        # global, so the loss follows its inputs.
        mu = torch.full((batch_size, x_points), 1.0 / x_points,
                        dtype=torch.float, device=C.device).squeeze()

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        thresh = 1e-1

        for _ in range(self.max_iter):
            u_prev = u
            # 1e-8 keeps the logarithm finite when a marginal entry is zero.
            u = self.eps * (torch.log(mu + 1e-8)
                            - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu + 1e-8)
                            - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u_prev).abs().sum(-1).mean()
            if err.item() < thresh:
                break

        # Transport plan implied by the final potentials.
        pi = torch.exp(self.M(C, u, v))
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()
        return cost

    def M(self, C, u, v):
        r"""Modified cost for logarithmic updates.

        :math:`M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon`
        """
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, dis, p=2):
        """Return the pairwise ground-cost matrix between ``x`` and ``y``.

        Must be a static method: ``forward`` calls it through ``self``, and
        without the decorator ``self`` would be bound to ``x``.

        Raises:
            ValueError: if ``dis`` is neither ``'cos'`` nor ``'euc'``.
        """
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        if dis == 'cos':
            return 1 - F.cosine_similarity(x_col, y_lin, dim=-1, eps=1e-8)
        if dis == 'euc':
            return torch.mean((torch.abs(x_col - y_lin)) ** p, -1)
        raise ValueError("dis must be 'cos' or 'euc', got {!r}".format(dis))

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1
def get_img_num_per_cls(dataset, imb_factor=None, num_meta=None):
    """Return the number of training images to keep for each class.

    The largest class keeps ``(50000 - num_meta) / cls_num`` images and class
    ``k`` keeps that amount scaled by ``imb_factor ** (k / (cls_num - 1))``,
    truncated to an integer.

    Args:
        dataset: ``'cifar10'`` (10 classes) or ``'cifar100'`` (100 classes).
        imb_factor: ratio between the smallest and the largest class size;
            ``None`` keeps every class at the maximum size.
        num_meta: total number of images held out for the meta set;
            ``None`` is treated as 0.

    Returns:
        A list of ``cls_num`` integers, one count per class index.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    classes_per_dataset = {'cifar10': 10, 'cifar100': 100}
    if dataset not in classes_per_dataset:
        raise ValueError(
            "dataset must be 'cifar10' or 'cifar100', got {!r}".format(dataset))
    cls_num = classes_per_dataset[dataset]

    held_out = 0 if num_meta is None else num_meta
    img_max = (50000 - held_out) / cls_num

    if imb_factor is None:
        # Counts are used as slice bounds by callers, so keep them integral.
        return [int(img_max)] * cls_num

    return [int(img_max * (imb_factor ** (cls_idx / (cls_num - 1.0))))
            for cls_idx in range(cls_num)]
class MetaModule(nn.Module):
    """``nn.Module`` base whose learnable tensors can be replaced in place.

    Subclasses expose their learnable tensors through ``named_leaves``;
    plain ``nn.Module`` children are handled through their ``_parameters``.
    ``update_params``/``set_param`` overwrite those tensors by attribute
    assignment, which allows a differentiable (non-leaf) update step.
    """

    def params(self):
        """Yield every tensor reported by ``named_params`` for this module."""
        for _, param in self.named_params(self):
            yield param

    def named_leaves(self):
        """Return ``(name, tensor)`` pairs owned directly by this module."""
        return []

    def named_submodules(self):
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        """Yield ``(dotted_name, tensor)`` for ``curr_module`` and its children.

        Args:
            curr_module: module to walk; defaults to ``self`` when omitted.
            memo: set of tensors already yielded, used to skip shared ones.
            prefix: dotted path prepended to every yielded name.
        """
        if curr_module is None:
            curr_module = self
        if memo is None:
            memo = set()

        if hasattr(curr_module, 'named_leaves'):
            own = curr_module.named_leaves()
        else:
            own = curr_module._parameters.items()
        for name, p in own:
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p

        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            for name, p in self.named_params(module, memo, submodule_prefix):
                yield name, p

    def update_params(self, lr_inner, first_order=False, source_params=None, detach=False):
        """Apply one gradient-descent step ``p <- p - lr_inner * grad``.

        Args:
            lr_inner: step size.
            first_order: when true, gradients are detached before use.
            source_params: optional iterable of gradients, consumed in the
                order of ``named_params``; when omitted each tensor's own
                ``.grad`` is used.
            detach: only consulted when ``source_params`` is ``None``; when
                true no step is taken and every tensor is detached in place.
        """
        if source_params is not None:
            for (name_t, param_t), grad in zip(self.named_params(self), source_params):
                if first_order:
                    grad = to_var(grad.detach().data)
                self.set_param(self, name_t, param_t - lr_inner * grad)
        else:
            for name, param in self.named_params(self):
                if not detach:
                    grad = param.grad
                    if first_order:
                        grad = to_var(grad.detach().data)
                    self.set_param(self, name, param - lr_inner * grad)
                else:
                    param = param.detach_()
                    self.set_param(self, name, param)

    def set_param(self, curr_mod, name, param):
        """Assign ``param`` to the attribute at dotted path ``name`` under ``curr_mod``."""
        if '.' in name:
            module_name, rest = name.split('.', 1)
            for child_name, mod in curr_mod.named_children():
                if module_name == child_name:
                    self.set_param(mod, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        """Replace every tensor with a detached view of itself."""
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        """Copy the tensors of ``other`` into this module, matched by name.

        Args:
            other: module to read from; it is walked with ``named_params``.
            same_var: when true the very same tensor objects are shared;
                otherwise each tensor is cloned and wrapped by ``to_var``
                with ``requires_grad=True``.
        """
        # Walk ``other`` explicitly and pass ``self`` as the root to set_param;
        # the previous argument-less calls raised before copying anything.
        for name, param in other.named_params(other):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)
class MetaConvTranspose2d(MetaModule):
    """Transposed 2-D convolution whose weight/bias are replaceable meta tensors.

    A throw-away ``nn.ConvTranspose2d`` is built only to obtain the
    initial weights and the normalised geometry tuples.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.ConvTranspose2d(*args, **kwargs)

        self.stride = ignore.stride
        self.padding = ignore.padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        # Needed to resolve an explicit ``output_size`` in ``forward``.
        self.kernel_size = ignore.kernel_size
        self.output_padding = ignore.output_padding

        self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True))

        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def _output_padding(self, x, output_size):
        """Return the per-dimension output padding for the requested size.

        With ``output_size=None`` the padding given at construction is
        used. Otherwise the last two entries of ``output_size`` are taken
        as the target height/width and the padding is the difference to
        the smallest size the layer can produce for ``x``.

        Raises:
            ValueError: if the requested size cannot be reached.
        """
        if output_size is None:
            return self.output_padding

        target = tuple(output_size)[-2:]
        if len(target) != 2:
            raise ValueError(
                "output_size must have at least 2 elements, got {}".format(len(target)))

        extra = []
        for d in range(2):
            min_size = ((x.size(d - 2) - 1) * self.stride[d]
                        - 2 * self.padding[d]
                        + self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            pad = target[d] - min_size
            if not 0 <= pad < max(self.stride[d], self.dilation[d]):
                raise ValueError(
                    "requested output size {} is not reachable from input size {}".format(
                        target, tuple(x.shape[-2:])))
            extra.append(pad)
        return tuple(extra)

    def forward(self, x, output_size=None):
        output_padding = self._output_padding(x, output_size)
        return F.conv_transpose2d(x, self.weight, self.bias, self.stride, self.padding,
                                  output_padding, self.groups, self.dilation)

    def named_leaves(self):
        return [('weight', self.weight), ('bias', self.bias)]
def _weights_init(m):
    """Kaiming-normal initialise the weight of meta linear and conv layers.

    Intended for ``module.apply``; modules of any other type are left
    untouched.
    """
    if isinstance(m, (MetaLinear, MetaConv2d)):
        # In-place variant; the un-suffixed ``kaiming_normal`` is deprecated.
        init.kaiming_normal_(m.weight)
class ResNet32(MetaModule):
    """CIFAR-style ResNet built from meta layers.

    A 3x3 stem convolution is followed by three stages of ``block`` with
    16, 32 and 64 planes, then a linear classifier on the 64-dimensional
    pooled feature.

    Args:
        num_classes: size of the classifier output.
        block: residual block class; must expose ``expansion``.
        num_blocks: number of blocks in each of the three stages.
    """

    def __init__(self, num_classes, block=BasicBlock, num_blocks=(5, 5, 5)):
        super(ResNet32, self).__init__()
        self.in_planes = 16

        self.conv1 = MetaConv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = MetaLinear(64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(features, logits)``: the flattened pooled feature and the classifier output."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Pool over the full spatial extent (kernel = feature-map width).
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        y = self.linear(out)
        return out, y
| import shutil 15 | from Sinkhorn_distance import SinkhornDistance 16 | from Sinkhorn_distance_fl import SinkhornDistance as SinkhornDistance_fl 17 | from torch.utils.data import TensorDataset, DataLoader 18 | 19 | parser = argparse.ArgumentParser(description='Imbalanced Example') 20 | parser.add_argument('--dataset', default='cifar10', type=str, 21 | help='dataset (cifar10[default] or cifar100)') 22 | parser.add_argument('--cost', default='combined', type=str, 23 | help='[combined, label, feature, twoloss]') 24 | parser.add_argument('--meta_set', default='prototype', type=str, 25 | help='[whole, prototype]') 26 | parser.add_argument('--batch-size', type=int, default=16, metavar='N', 27 | help='input batch size for training (default: 16)') 28 | parser.add_argument('--num_classes', type=int, default=10) 29 | parser.add_argument('--num_meta', type=int, default=10, 30 | help='The number of meta data for each class.') 31 | parser.add_argument('--imb_factor', type=float, default=0.005) 32 | parser.add_argument('--epochs', type=int, default=250, metavar='N', 33 | help='number of epochs to train') 34 | parser.add_argument('--lr', '--learning-rate', default=2e-5, type=float, 35 | help='initial learning rate') 36 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 37 | parser.add_argument('--nesterov', default=True, type=bool, help='nesterov momentum') 38 | parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float, 39 | help='weight decay (default: 5e-4)') 40 | parser.add_argument('--no-cuda', action='store_true', default=False, 41 | help='disables CUDA training') 42 | parser.add_argument('--seed', type=int, default=42, metavar='S', 43 | help='random seed (default: 42)') 44 | parser.add_argument('--print-freq', '-p', default=100, type=int, 45 | help='print frequency (default: 10)') 46 | parser.add_argument('--gpu', default=0, type=int) 47 | parser.add_argument('--save_name', default='OT_cifar10_imb0.005', type=str) 48 | 
parser.add_argument('--idx', default='ours', type=str) 49 | 50 | 51 | args = parser.parse_args() 52 | for arg in vars(args): 53 | print("{}={}".format(arg, getattr(args, arg))) 54 | 55 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 56 | os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu) 57 | kwargs = {'num_workers': 0, 'pin_memory': False} 58 | use_cuda = not args.no_cuda and torch.cuda.is_available() 59 | 60 | torch.manual_seed(args.seed) 61 | device = torch.device("cuda" if use_cuda else "cpu") 62 | 63 | train_data_meta, train_data, test_dataset = build_dataset(args.dataset, args.num_meta) 64 | 65 | print(f'length of meta dataset:{len(train_data_meta)}') 66 | print(f'length of train dataset: {len(train_data)}') 67 | 68 | train_loader = torch.utils.data.DataLoader( 69 | train_data, batch_size=args.batch_size, shuffle=True, **kwargs) 70 | 71 | np.random.seed(42) 72 | random.seed(42) 73 | torch.manual_seed(args.seed) 74 | classe_labels = range(args.num_classes) 75 | 76 | data_list = {} 77 | 78 | 79 | for j in range(args.num_classes): 80 | data_list[j] = [i for i, label in enumerate(train_loader.dataset.targets) if label == j] 81 | 82 | 83 | img_num_list = get_img_num_per_cls(args.dataset, args.imb_factor, args.num_meta*args.num_classes) 84 | print(img_num_list) 85 | print(sum(img_num_list)) 86 | 87 | im_data = {} 88 | idx_to_del = [] 89 | for cls_idx, img_id_list in data_list.items(): 90 | random.shuffle(img_id_list) 91 | img_num = img_num_list[int(cls_idx)] 92 | im_data[cls_idx] = img_id_list[img_num:] 93 | idx_to_del.extend(img_id_list[img_num:]) 94 | 95 | print(len(idx_to_del)) 96 | imbalanced_train_dataset = copy.deepcopy(train_data) 97 | imbalanced_train_dataset.targets = np.delete(train_loader.dataset.targets, idx_to_del, axis=0) 98 | imbalanced_train_dataset.data = np.delete(train_loader.dataset.data, idx_to_del, axis=0) 99 | print(len(imbalanced_train_dataset)) 100 | 101 | imbalanced_train_loader = DataLoader(new_dataset(imbalanced_train_dataset, 
def main():
    """Run stage 2: learn sample weights with the OT loss and fine-tune the classifier.

    Loads the pretrained backbone, optimises only the final linear layer,
    evaluates on the test loader after every epoch and checkpoints through
    ``save_checkpoint`` (with the per-class weights attached when the
    epoch is the best so far).

    Raises:
        ValueError: if no pretrained checkpoint path is configured for the
            selected dataset / imbalance factor.
    """
    global args, best_prec1
    args = parser.parse_args()

    # Placeholder checkpoint paths: edit these to point at the stage-1 model.
    ckpt_path = None
    if args.dataset == 'cifar10':
        if args.imb_factor == 0.005:
            ckpt_path = r'checkpoint/ours/pretrain/..'
    else:
        if args.imb_factor == 0.005:
            ckpt_path = r'checkpoint/ours/pretrain/..'
    if ckpt_path is None:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            'no pretrained checkpoint configured for dataset={!r}, imb_factor={}'.format(
                args.dataset, args.imb_factor))

    model = build_model(load_pretrain=True, ckpt_path=ckpt_path)
    # Only the classifier head is trained in this stage.
    optimizer_a = torch.optim.SGD([model.linear.weight, model.linear.bias], args.lr,
                                  momentum=args.momentum, nesterov=args.nesterov,
                                  weight_decay=args.weight_decay)

    cudnn.benchmark = True
    criterion_classifier = nn.CrossEntropyLoss(reduction='none').cuda()

    # NOTE(review): the start epoch 160 is hard-coded; presumably it is the
    # length of stage-1 pretraining — confirm.
    for epoch in range(160, args.epochs):

        train_OT(imbalanced_train_loader, validation_loader, weightsbuffer,
                 model, optimizer_a, epoch, criterion_classifier)

        prec1, preds, gt_labels = validate(test_loader, model)

        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)

        if is_best:
            # Snapshot the learned sample weights grouped by class index.
            weightsbuffer_bycls = []
            for i_cls in range(args.num_classes):
                weightsbuffer_bycls.extend(
                    weightsbuffer[imbalanced_train_dataset.targets == i_cls].data.cpu().numpy())

        save_checkpoint(args, {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc1': best_prec1,
            'optimizer': optimizer_a.state_dict(),
            'weights': weightsbuffer_bycls
        }, is_best)

    print('Best accuracy: ', best_prec1)
val_labels_bycls = torch.tensor([i_l for i_l in range(args.num_classes)]).cuda() 188 | 189 | val_labels_onehot = to_categorical(val_labels_bycls).cuda() 190 | feature_val, _ = model(val_data_bycls) 191 | 192 | for i, batch in enumerate(train_loader): 193 | 194 | inputs, labels, ids = tuple(t.to('cuda') for t in batch) 195 | labels = labels.squeeze() 196 | labels_onehot = to_categorical(labels).cuda() 197 | 198 | 199 | weights = to_var(weightsbuffer[ids]) 200 | model.eval() 201 | Attoptimizer = torch.optim.SGD([weights], lr=0.01, momentum=0.9, weight_decay=5e-4) 202 | 203 | for ot_epoch in range(1): 204 | feature_train, _ = model(inputs) 205 | probability_train = softmax_normalize(weights) 206 | 207 | if args.cost == 'feature': 208 | OTloss = criterion(feature_val.detach(), feature_train.detach(), probability_train.squeeze()) 209 | elif args.cost == 'label': 210 | OTloss = criterion_label(torch.tensor(val_labels_onehot, dtype=float).cuda(), 211 | torch.tensor(labels_onehot, dtype=float).cuda(), 212 | probability_train.squeeze()) 213 | elif args.cost == 'combined': 214 | OTloss = criterion_fl(feature_val.detach(), feature_train.detach(), 215 | torch.tensor(val_labels_onehot, dtype=float).cuda(), 216 | torch.tensor(labels_onehot, dtype=float).cuda(), 217 | probability_train.squeeze()) 218 | elif args.cost == 'twoloss': 219 | OTloss1 = criterion(feature_val.detach(), feature_train.detach(), probability_train.squeeze()) 220 | OTloss2 = criterion_label(torch.tensor(val_labels_onehot, dtype=float).cuda(), 221 | torch.tensor(labels_onehot, dtype=float).cuda(), 222 | probability_train.squeeze()) 223 | OTloss = OTloss1 + OTloss2 224 | 225 | Attoptimizer.zero_grad() 226 | OTloss.backward() 227 | Attoptimizer.step() 228 | 229 | weightsbuffer[ids] = weights.data 230 | 231 | 232 | model.train() 233 | optimizer.zero_grad() 234 | _, logits = model(inputs) 235 | loss_train = criterion_classifier(logits, labels.long()) 236 | _, logits_val = model(val_data) 237 | loss_val = 
F.cross_entropy(logits_val, val_labels.long(), reduction='none') 238 | 239 | loss = torch.sum(loss_train * weights.data) + 10*torch.mean(loss_val) 240 | loss.backward() 241 | optimizer.step() 242 | 243 | prec_train = accuracy(logits.data, labels, topk=(1,))[0] 244 | 245 | losses.update(loss.item(), inputs.size(0)) 246 | top1.update(prec_train.item(), inputs.size(0)) 247 | 248 | if i==len(train_loader)-1 or i % args.print_freq == 0: 249 | print('Epoch: [{0}][{1}/{2}]\t' 250 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 251 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 252 | epoch, i, len(train_loader), 253 | loss=losses, top1=top1)) 254 | 255 | 256 | 257 | def validate(val_loader, model): 258 | batch_time = AverageMeter() 259 | losses = AverageMeter() 260 | top1 = AverageMeter() 261 | 262 | model.eval() 263 | 264 | true_labels = [] 265 | preds = [] 266 | 267 | end = time.time() 268 | for i, batch in enumerate(val_loader): 269 | input, target, _ = tuple(t.to('cuda') for t in batch) 270 | target = target.cuda() 271 | input = input.cuda() 272 | input_var = torch.autograd.Variable(input) 273 | target_var = torch.autograd.Variable(target) 274 | 275 | with torch.no_grad(): 276 | _, output = model(input_var) 277 | 278 | output_numpy = output.data.cpu().numpy() 279 | preds_output = list(output_numpy.argmax(axis=1)) 280 | 281 | true_labels += list(target_var.data.cpu().numpy()) 282 | preds += preds_output 283 | 284 | 285 | prec1 = accuracy(output.data, target, topk=(1,))[0] 286 | top1.update(prec1.item(), input.size(0)) 287 | 288 | batch_time.update(time.time() - end) 289 | end = time.time() 290 | 291 | if i==len(val_loader)-1: #i % args.print_freq == 0: 292 | print('Test: [{0}/{1}]\t' 293 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 294 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 295 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( 296 | i, len(val_loader), batch_time=batch_time, loss=losses, 297 | top1=top1)) 298 | 299 | print(' * Prec@1 
{top1.avg:.3f}'.format(top1=top1)) 300 | 301 | return top1.avg, preds, true_labels 302 | 303 | 304 | def build_model(load_pretrain, ckpt_path=None): 305 | model = ResNet32(args.dataset == 'cifar10' and 10 or 100) 306 | 307 | if load_pretrain == True: 308 | checkpoint = torch.load(ckpt_path) 309 | model.load_state_dict(checkpoint['model_state_dict']) 310 | 311 | if torch.cuda.is_available(): 312 | model.cuda() 313 | torch.backends.cudnn.benchmark = True 314 | 315 | return model 316 | 317 | 318 | def to_var(x, requires_grad=True): 319 | if torch.cuda.is_available(): 320 | x = x.cuda() 321 | return Variable(x, requires_grad=requires_grad) 322 | 323 | 324 | def linear_normalize(weights): 325 | weights = torch.max(weights, torch.zeros_like(weights)) 326 | if torch.sum(weights) > 1e-8: 327 | return weights / torch.sum(weights) 328 | return torch.zeros_like(weights) 329 | 330 | 331 | def softmax_normalize(weights, temperature=1.): 332 | return nn.functional.softmax(weights / temperature, dim=0) 333 | 334 | 335 | class AverageMeter(object): 336 | 337 | def __init__(self): 338 | self.reset() 339 | 340 | def reset(self): 341 | self.val = 0 342 | self.avg = 0 343 | self.sum = 0 344 | self.count = 0 345 | 346 | def update(self, val, n=1): 347 | self.val = val 348 | self.sum += val * n 349 | self.count += n 350 | self.avg = self.sum / self.count 351 | 352 | 353 | 354 | def accuracy(output, target, topk=(1,)): 355 | maxk = max(topk) 356 | batch_size = target.size(0) 357 | 358 | _, pred = output.topk(maxk, 1, True, True) 359 | pred = pred.t() 360 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 361 | 362 | res = [] 363 | for k in topk: 364 | correct_k = correct[:k].view(-1).float().sum(0) 365 | res.append(correct_k.mul_(100.0 / batch_size)) 366 | return res 367 | 368 | 369 | def save_checkpoint(args, state, is_best): 370 | path = 'checkpoint/ours/' 371 | save_name = args.save_name 372 | if not os.path.exists(path): 373 | os.makedirs(path) 374 | filename = path + save_name 
+ '_ckpt.pth.tar' 375 | if is_best: 376 | torch.save(state, filename) 377 | 378 | def to_categorical(labels): 379 | labels_onehot = torch.zeros([labels.shape[0], args.num_classes]) 380 | for label_epoch in range(labels.shape[0]): 381 | labels_onehot[label_epoch, labels[label_epoch]] = 1 382 | 383 | return labels_onehot 384 | 385 | if __name__ == '__main__': 386 | main() 387 | --------------------------------------------------------------------------------