├── .gitignore
├── README.md
├── isda.py
├── meta.py
├── models
│   ├── mlp.py
│   └── resnet.py
├── train.py
└── utils
    ├── __init__.py
    ├── autoaugment.py
    ├── data_utils.py
    └── dataset.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
pretrained_models/*
*/__pycache__/*
output/*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fine-grained Recognition with Learnable Semantic Data Augmentation

Accepted by IEEE Transactions on Image Processing (IEEE TIP).

Authors: [Yifan Pu](https://github.com/yifanpu001/)\*, [Yizeng Han](https://yizenghan.top/)\*, [Yulin Wang](https://www.wyl.cool/), [Junlan Feng](https://scholar.google.com/citations?user=rBjPtmQAAAAJ&hl=en&oi=ao), Chao Deng, [Gao Huang](http://www.gaohuang.net/)\#.

*: Equal contribution, #: Corresponding author.


## Get Started
1. Prepare the environment
```
conda create --name learnable_isda python=3.8
conda activate learnable_isda
pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118
pip install scipy pandas matplotlib imageio
```

2. Prepare the data

Download CUB-200-2011 from the [official website](https://www.vision.caltech.edu/datasets/cub_200_2011/). The expected directory layout is shown at the end of this README.

3. Prepare the pretrained checkpoint
```
mkdir pretrained_models
cd pretrained_models
wget https://download.pytorch.org/models/resnet50-0676ba61.pth
cd ..
```

## Usage

Training:

```
CUDA_VISIBLE_DEVICES=0,1,2,3 python train.py \
    --data_root YOUR_DATA_PATH --output_dir_root ./ --output_dir output/ \
    --model_type resnet50 --pretrained_dir ./pretrained_models/resnet50-0676ba61.pth \
    --dataset CUB_200_2011 --train_batch_size 128 --lr 3e-2 --eval_batch_size 64 --workers 1 \
    --meta_lr 1e-3 --meta_net_hidden_size 512 --meta_net_num_layers 1 --lambda_0 10.0 \
    --epochs 100 --warmup_epochs 5
```

## Citation

If you find our work useful in your research, please consider citing:

```
@article{pu2023fine,
  title={Fine-grained recognition with learnable semantic data augmentation},
  author={Pu, Yifan and Han, Yizeng and Wang, Yulin and Feng, Junlan and Deng, Chao and Huang, Gao},
  journal={IEEE Transactions on Image Processing},
  year={2023}
}
```

## Contact
If you have any questions, please feel free to contact the authors.

Yifan Pu: pyf20@mails.tsinghua.edu.cn, yifanpu98@126.com.

Yizeng Han: hanyz18@mails.tsinghua.edu.cn, yizeng38@gmail.com.
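
## Expected Data Layout

`train.py` appends the dataset name to `--data_root`, and the CUB loader in `utils/dataset.py` reads the standard CUB-200-2011 annotation files, so the extracted dataset should look like the sketch below (these are the files shipped in the official archive, not extras added by this repo):

```
YOUR_DATA_PATH/
└── CUB_200_2011/
    ├── images/                  # one sub-folder per class
    ├── images.txt               # <image_id> <relative image path>
    ├── image_class_labels.txt   # <image_id> <class_id in 1..200>
    └── train_test_split.txt     # <image_id> <1 = train, 0 = test>
```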
--------------------------------------------------------------------------------
/isda.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class ISDALoss(nn.Module):
    def __init__(self, feature_num, class_num, local_rank):
        super(ISDALoss, self).__init__()
        self.class_num = class_num
        self.cross_entropy = nn.CrossEntropyLoss().cuda(local_rank)
        self.local_rank = local_rank

    def isda_aug(self, linear_layer, features, labels, cv_matrix, ratio):

        N = features.size(0)  # batch size
        C = self.class_num    # number of classes (200 for CUB-200-2011)
        A = features.size(1)  # feature dimension (2048 for ResNet-50)

        weight_m = list(linear_layer.parameters())[0]  # weight of the Linear head, shape = [C, A] = [200, 2048]

        NxW_ij = weight_m.expand(N, C, A)  # shape = [N, C, A] (weight_m replicated N times)

        NxW_kj = torch.gather(
            NxW_ij,
            1,
            labels.view(N, 1, 1).expand(N, C, A)
        )  # shape = [N, C, A]; every class row holds the ground-truth class weight

        CV_temp = cv_matrix

        # sigma2[i, c] = ratio * sum_a (w_c - w_y)_a^2 * cv[i, a], shape = [N, C]
        sigma2 = ratio \
            * torch.mul(
                (weight_m - NxW_kj).pow(2),
                CV_temp.view(N, 1, A).expand(N, C, A),
            ).sum(2)

        logits = F.linear(features, weight=linear_layer.weight, bias=linear_layer.bias)
        aug_logits = logits + 0.5 * sigma2  # ISDA upper-bound surrogate logits

        return logits, aug_logits

    def forward(self, linear_layer, features, labels, ratio, cv_matrix):

        logits, aug_logits = self.isda_aug(linear_layer, features, labels, cv_matrix, ratio)

        loss = self.cross_entropy(aug_logits, labels)

        return loss, logits
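
# Illustrative sketch (not part of the original file): how train.py wires this
# loss together; `model`, `meta_net`, `images`, `labels`, `epoch`, and `args`
# stand in for the corresponding objects there.
#
#     ratio = args.lambda_0 * (epoch / args.epochs)   # augmentation strength ramps up over training
#     features = model(images)                        # pooled features, [N, 2048] for ResNet-50
#     cv_matrix = meta_net(features)                  # learned per-sample variances in (0, 1), [N, 2048]
#     loss, logits = criterion_isda(model.module.head, features, labels, ratio, cv_matrix)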
--------------------------------------------------------------------------------
/meta.py:
--------------------------------------------------------------------------------
import torch
from torch.optim.sgd import SGD


class MetaSGD(SGD):
    """SGD whose update step keeps the autograd graph alive, so gradients can
    flow from the updated (pseudo) network back to the meta-network."""

    def __init__(self, net, *args, **kwargs):
        super(MetaSGD, self).__init__(*args, **kwargs)
        self.net = net

    def set_parameter(self, current_module, name, parameters):
        if '.' in name:
            # recurse into sub-modules along the dotted parameter name
            name_split = name.split('.')
            module_name = name_split[0]
            rest_name = '.'.join(name_split[1:])
            for children_name, children in current_module.named_children():
                if module_name == children_name:
                    self.set_parameter(children, rest_name, parameters)
                    break
        else:
            current_module._parameters[name] = parameters

    def meta_step(self, grads):
        group = self.param_groups[0]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']
        lr = group['lr']

        for (name, parameter), grad in zip(self.net.named_parameters(), grads):

            if grad is None:  # parameters with requires_grad=False have no gradient
                continue

            parameter.detach_()
            if weight_decay != 0:
                grad_wd = grad.add(parameter, alpha=weight_decay)
            else:
                grad_wd = grad
            if momentum != 0 and 'momentum_buffer' in self.state[parameter]:
                buffer = self.state[parameter]['momentum_buffer']
                grad_b = buffer.mul(momentum).add(grad_wd, alpha=1 - dampening)
            else:
                grad_b = grad_wd
            if nesterov:
                grad_n = grad_wd.add(grad_b, alpha=momentum)
            else:
                grad_n = grad_b
            self.set_parameter(self.net, name, parameter.add(grad_n, alpha=-lr))
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class HiddenLayer(nn.Module):
    def __init__(self, input_size, output_size):
        super(HiddenLayer, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        return F.relu(self.fc(x))


class MLP_sigmoid(nn.Module):
    def __init__(self, input_size=1, hidden_size=100, num_layers=1, output_size=1):
        super(MLP_sigmoid, self).__init__()
        self.first_hidden_layer = HiddenLayer(input_size, hidden_size)
        self.rest_hidden_layers = nn.Sequential(*[HiddenLayer(hidden_size, hidden_size) for _ in range(num_layers - 1)])
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.first_hidden_layer(x)
        x = self.rest_hidden_layers(x)
        x = self.output_layer(x)
        return torch.sigmoid(x)
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
try:
    from torch.hub import load_state_dict_from_url  # noqa: F401
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url  # noqa: F401


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
           'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d':
'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 22 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 23 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 24 | } 25 | 26 | 27 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 28 | """3x3 convolution with padding""" 29 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 30 | padding=dilation, groups=groups, bias=False, dilation=dilation) 31 | 32 | 33 | def conv1x1(in_planes, out_planes, stride=1): 34 | """1x1 convolution""" 35 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 36 | 37 | 38 | class BasicBlock(nn.Module): 39 | expansion = 1 40 | 41 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 42 | base_width=64, dilation=1, norm_layer=None): 43 | super(BasicBlock, self).__init__() 44 | if norm_layer is None: 45 | norm_layer = nn.BatchNorm2d 46 | if groups != 1 or base_width != 64: 47 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 48 | if dilation > 1: 49 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 50 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 51 | self.conv1 = conv3x3(inplanes, planes, stride) 52 | self.bn1 = norm_layer(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes) 55 | self.bn2 = norm_layer(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | identity = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | identity = self.downsample(x) 71 | 72 | out += identity 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 80 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 81 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 82 | # This variant is also known as ResNet V1.5 and improves accuracy according to 83 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
84 | 85 | expansion = 4 86 | 87 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 88 | base_width=64, dilation=1, norm_layer=None): 89 | super(Bottleneck, self).__init__() 90 | if norm_layer is None: 91 | norm_layer = nn.BatchNorm2d 92 | width = int(planes * (base_width / 64.)) * groups 93 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 94 | self.conv1 = conv1x1(inplanes, width) 95 | self.bn1 = norm_layer(width) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = norm_layer(width) 98 | self.conv3 = conv1x1(width, planes * self.expansion) 99 | self.bn3 = norm_layer(planes * self.expansion) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.downsample = downsample 102 | self.stride = stride 103 | 104 | def forward(self, x): 105 | identity = x 106 | 107 | out = self.conv1(x) 108 | out = self.bn1(out) 109 | out = self.relu(out) 110 | 111 | out = self.conv2(out) 112 | out = self.bn2(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv3(out) 116 | out = self.bn3(out) 117 | 118 | if self.downsample is not None: 119 | identity = self.downsample(x) 120 | 121 | out += identity 122 | out = self.relu(out) 123 | 124 | return out 125 | 126 | 127 | class ResNet(nn.Module): 128 | 129 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 130 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 131 | norm_layer=None): 132 | super(ResNet, self).__init__() 133 | self.feature_num = 512 * block.expansion 134 | self.num_classes = num_classes 135 | 136 | if norm_layer is None: 137 | norm_layer = nn.BatchNorm2d 138 | self._norm_layer = norm_layer 139 | 140 | self.inplanes = 64 141 | self.dilation = 1 142 | if replace_stride_with_dilation is None: 143 | # each element in the tuple indicates if we should replace 144 | # the 2x2 stride with a dilated convolution instead 145 | replace_stride_with_dilation = [False, False, False] 146 | if len(replace_stride_with_dilation) != 3: 147 | raise ValueError("replace_stride_with_dilation should be None " 148 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 149 | self.groups = groups 150 | self.base_width = width_per_group 151 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 152 | bias=False) 153 | self.bn1 = norm_layer(self.inplanes) 154 | self.relu = nn.ReLU(inplace=True) 155 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 156 | self.layer1 = self._make_layer(block, 64, layers[0]) 157 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 158 | dilate=replace_stride_with_dilation[0]) 159 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 160 | dilate=replace_stride_with_dilation[1]) 161 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 162 | dilate=replace_stride_with_dilation[2]) 163 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 164 | self.head = nn.Linear(512 * block.expansion, num_classes) 165 | 166 | 167 | 168 | for m in self.modules(): 169 | if isinstance(m, nn.Conv2d): 170 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 171 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 172 | nn.init.constant_(m.weight, 1) 173 | nn.init.constant_(m.bias, 0) 174 | 175 | # Zero-initialize the last BN in each residual branch, 176 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        features = torch.flatten(x, 1)
        # NOTE: unlike torchvision, forward returns the pooled features only; the
        # classifier `self.head` is applied externally (see train.py and isda.py).
        # x = self.head(features)
        # x = torch.nn.functional.linear(features, weight=self.head.weight.detach(), bias=self.head.bias.detach())

        return features

    def forward(self, x):
        return self._forward_impl(x)


def _resnet(arch, block, layers, pretrained, progress, pretrained_dir=None, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        if pretrained_dir is not None:
            state_dict = torch.load(pretrained_dir, map_location='cpu')
        else:
            # fall back to downloading the checkpoint (constructors that do not
            # accept pretrained_dir, e.g. the ResNeXt/Wide variants, end up here)
            state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        # the classifier here is named `head`, so drop the original `fc` weights
        state_dict.pop('fc.weight')
        state_dict.pop('fc.bias')
        model.load_state_dict(state_dict, strict=False)
    return model
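
# Illustrative sketch (not part of the original file): constructing the backbone
# the way train.py does, loading the ImageNet weights from a local checkpoint.
# Calling the model returns pooled features; logits come from `model.head`:
#
#     from models.resnet import resnet50
#     model = resnet50(pretrained=True, num_classes=200,
#                      pretrained_dir='./pretrained_models/resnet50-0676ba61.pth')
#     features = model(images)       # [N, 2048] pooled features
#     logits = model.head(features)  # [N, 200] class logits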

def resnet18(pretrained=False, progress=True, pretrained_dir=None, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, pretrained_dir,
                   **kwargs)


def resnet34(pretrained=False, progress=True, pretrained_dir=None, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, pretrained_dir,
                   **kwargs)


def resnet50(pretrained=False, progress=True, pretrained_dir=None, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, pretrained_dir,
                   **kwargs)


def resnet101(pretrained=False, progress=True, pretrained_dir=None, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, pretrained_dir,
                   **kwargs)


def resnet152(pretrained=False, progress=True, pretrained_dir=None, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, pretrained_dir,
                   **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
358 | 359 | Args: 360 | pretrained (bool): If True, returns a model pre-trained on ImageNet 361 | progress (bool): If True, displays a progress bar of the download to stderr 362 | """ 363 | kwargs['width_per_group'] = 64 * 2 364 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 365 | pretrained, progress, **kwargs) 366 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import sys 5 | import math 6 | import time 7 | import copy 8 | import pickle 9 | import socket 10 | import random 11 | import logging 12 | import argparse 13 | import numpy as np 14 | from enum import Enum 15 | 16 | import torch 17 | import torch.distributed as dist 18 | import torch.multiprocessing as mp 19 | import torch.backends.cudnn as cudnn 20 | 21 | import torchvision 22 | 23 | # from models.ViT import VisionTransformer, CONFIGS 24 | from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 25 | from models.mlp import MLP_sigmoid 26 | from utils.data_utils import get_loader 27 | from isda import ISDALoss 28 | from meta import MetaSGD 29 | 30 | logger = logging.getLogger(__name__) 31 | best_acc1, best_epoch = 0.0, 0 32 | 33 | 34 | """ some tools """ 35 | class Summary(Enum): 36 | NONE = 0 37 | AVERAGE = 1 38 | SUM = 2 39 | COUNT = 3 40 | 41 | 42 | class AverageMeter(object): 43 | """Computes and stores the average and current value""" 44 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 45 | self.name = name 46 | self.fmt = fmt 47 | self.summary_type = summary_type 48 | self.reset() 49 | 50 | def reset(self): 51 | self.val = 0 52 | self.avg = 0 53 | self.sum = 0 54 | self.count = 0 55 | 56 | def update(self, val, n=1): 57 | self.val = val 58 | self.sum += val * n 59 | self.count += n 60 | self.avg = self.sum / self.count 61 | 62 | def __str__(self): 63 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 64 | return fmtstr.format(**self.__dict__) 65 | 66 | def summary(self): 67 | fmtstr = '' 68 | if self.summary_type is Summary.NONE: 69 | fmtstr = '' 70 | elif self.summary_type is Summary.AVERAGE: 71 | fmtstr = '{name} {avg:.3f}' 72 | elif self.summary_type is Summary.SUM: 73 | fmtstr = '{name} {sum:.3f}' 74 | elif self.summary_type is Summary.COUNT: 75 | fmtstr = '{name} {count:.3f}' 76 | else: 77 | raise ValueError('invalid summary type %r' % self.summary_type) 78 | 79 | return fmtstr.format(**self.__dict__) 80 | 81 | 82 | class ProgressMeter(object): 83 | def __init__(self, num_batches, meters, prefix=""): 84 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 85 | self.meters = meters 86 | self.prefix = prefix 87 | 88 | def display(self, batch): 89 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 90 | entries += [str(meter) for meter in self.meters] 91 | print('\t'.join(entries)) 92 | logger.info('\t'.join(entries)) 93 | 94 | def display_summary(self): 95 | entries = [" *"] 96 | entries += [meter.summary() for meter in self.meters] 97 | print(' '.join(entries)) 98 | logger.info(' '.join(entries)) 99 | 100 | def _get_batch_fmtstr(self, num_batches): 101 | num_digits = len(str(num_batches // 1)) 102 | fmt = '{:' + str(num_digits) + 'd}' 103 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 104 | 105 | 106 | def accuracy(output, target, topk=(1,)): 107 | """Computes the accuracy over the k top predictions for the 
specified values of k""" 108 | with torch.no_grad(): 109 | maxk = max(topk) 110 | batch_size = target.size(0) 111 | 112 | _, pred = output.topk(maxk, 1, True, True) 113 | pred = pred.t() 114 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 115 | 116 | res = [] 117 | for k in topk: 118 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 119 | res.append(correct_k.mul_(100.0 / batch_size)) 120 | return res 121 | 122 | 123 | def adjust_learning_rate(optimizer, init_lr, epoch_total, warmup_epochs, epoch_cur, num_iter_per_epoch, i_iter): 124 | """ 125 | cosine learning rate with warm-up 126 | """ 127 | if epoch_cur < warmup_epochs: 128 | # T_cur = 1, 2, 3, ..., (T_total - 1) 129 | T_cur = 1 + epoch_cur * num_iter_per_epoch + i_iter 130 | T_total = 1 + warmup_epochs * num_iter_per_epoch 131 | lr = (T_cur / T_total) * init_lr 132 | else: 133 | # T_cur = 0, 1, 2, 3, ..., (T_total - 1) 134 | T_cur = (epoch_cur - warmup_epochs) * num_iter_per_epoch + i_iter 135 | T_total = (epoch_total - warmup_epochs) * num_iter_per_epoch 136 | lr = 0.5 * init_lr * (1 + math.cos(math.pi * T_cur / T_total)) 137 | for param_group in optimizer.param_groups: 138 | param_group['lr'] = lr 139 | return lr 140 | 141 | 142 | def adjust_meta_learning_rate(optimizer, init_lr, epoch_total, warmup_epochs, epoch_cur, num_iter_per_epoch, i_iter): 143 | """ 144 | cosine learning rate with warm-up 145 | """ 146 | if epoch_cur < warmup_epochs: 147 | # T_cur = 1, 2, 3, ..., (T_total - 1) 148 | T_cur = 1 + epoch_cur * num_iter_per_epoch + i_iter 149 | T_total = 1 + warmup_epochs * num_iter_per_epoch 150 | lr = (T_cur / T_total) * init_lr 151 | else: 152 | # T_cur = 0, 1, 2, 3, ..., (T_total - 1) 153 | T_cur = (epoch_cur - warmup_epochs) * num_iter_per_epoch + i_iter 154 | T_total = (epoch_total - warmup_epochs) * num_iter_per_epoch 155 | lr = 0.5 * init_lr * (1 + math.cos(math.pi * T_cur / T_total)) 156 | for param_group in optimizer.param_groups: 157 | param_group['lr'] = lr 158 | return lr 159 | 160 | 161 | def get_lr(optimizer): 162 | for param_group in optimizer.param_groups: 163 | return param_group['lr'] 164 | 165 | 166 | def save_checkpoint(state, is_best, save_dir, filename='checkpoint.pth.tar'): 167 | if not os.path.exists(save_dir): 168 | os.makedirs(save_dir) 169 | torch.save(state, os.path.join(save_dir, filename)) 170 | if is_best: 171 | torch.save(state, os.path.join(save_dir, 'model_best.pth.tar')) 172 | 173 | 174 | def estimated_time(t_start, cur_epoch, start_epoch, total_epoch): 175 | t_curr = time.time() 176 | eta_total = (t_curr - t_start) / (cur_epoch + 1 - start_epoch) * (total_epoch - cur_epoch - 1) 177 | eta_hour = int(eta_total // 3600) 178 | eta_min = int((eta_total - eta_hour * 3600) // 60) 179 | eta_sec = int(eta_total - eta_hour * 3600 - eta_min * 60) 180 | return f'Finished epoch:{cur_epoch:05d}/{total_epoch:05d}; ETA {eta_hour:02d} h {eta_min:02d} m {eta_sec:02d} s' 181 | 182 | 183 | def count_parameters(model): 184 | params = sum(p.numel() for p in model.parameters() if p.requires_grad) 185 | return params/1000000 186 | 187 | 188 | def get_free_port(): 189 | sock = socket.socket() 190 | sock.bind(('', 0)) 191 | free_port = sock.getsockname()[1] 192 | return free_port 193 | 194 | 195 | """ part of main """ 196 | def get_args(): 197 | parser = argparse.ArgumentParser() 198 | # Model Related 199 | parser.add_argument("--model_type", choices=["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"], 200 | default="resnet50", 201 | help="Which variant to use.") 202 | 
parser.add_argument("--pretrained_dir", type=str, default="./pretrained_models/resnet50-0676ba61.pth", 203 | help="Where to search for pretrained models.") 204 | # Data Related 205 | parser.add_argument("--dataset", choices=["CUB_200_2011", "car", "dog", "nabirds", "Aircraft", "INat2017"], default="CUB_200_2011", 206 | help="Which dataset.") 207 | parser.add_argument('--data_root', type=str, default='/home/data') 208 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 209 | help='number of data loading workers (default: 4)') 210 | parser.add_argument("--img_size", default=448, type=int, 211 | help="Resolution size") 212 | parser.add_argument("--train_batch_size", default=16, type=int, 213 | help="Total batch size for training.") 214 | parser.add_argument("--eval_batch_size", default=8, type=int, 215 | help="Total batch size for eval.") 216 | # Directory Related 217 | parser.add_argument("--output_dir_root", default="./", type=str, 218 | help="output_dir's root") 219 | parser.add_argument("--output_dir", default="output", type=str, 220 | help="The output directory where checkpoints will be written.") 221 | # Optimizer & Learning Schedule 222 | parser.add_argument("--lr", "--learning_rate", default=3e-2, type=float, 223 | help="The initial learning rate.") 224 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 225 | help='momentum') 226 | parser.add_argument("--weight_decay", default=0, type=float, 227 | help="Weight deay if we apply some.") 228 | parser.add_argument('--epochs', default=100, type=int, metavar='N', 229 | help='number of total epochs to run') 230 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 231 | help='manual epoch number (useful on restarts)') 232 | parser.add_argument("--warmup_epochs", default=5, type=int, 233 | help="Step of training to perform learning rate warmup for.") 234 | # For a Specific Experiment 235 | parser.add_argument('--seed', type=int, default=42, 236 | help="random seed for initialization") 237 | parser.add_argument('-p', '--print-freq', default=10, type=int, 238 | metavar='N', help='print frequency (default: 10)') 239 | parser.add_argument('--round', type=int, help="repeat same hyperparameter round") 240 | 241 | # ISDA 242 | parser.add_argument('--lambda_0', type=float, required=True, 243 | help='The hyper-parameter \lambda_0 for ISDA, select from {1, 2.5, 5, 7.5, 10}. 
' 244 | 'We adopt 1 for DenseNets and 7.5 for ResNets and ResNeXts, except for using 5 for ResNet-101.') 245 | # Meta 246 | parser.add_argument("--meta_lr", default=1e-3, type=float, help="The initial meta learning rate.") 247 | parser.add_argument('--meta_net_hidden_size', default=512, type=int, required=True) 248 | parser.add_argument('--meta_net_num_layers', default=1, type=int, required=True) 249 | parser.add_argument('--meta_weight_decay', type=float, default=0.0) 250 | args = parser.parse_args() 251 | return args 252 | 253 | 254 | def set_seed(args): 255 | random.seed(args.seed) 256 | np.random.seed(args.seed) 257 | torch.manual_seed(args.seed) 258 | torch.cuda.manual_seed_all(args.seed) 259 | 260 | 261 | def setup_model(args): 262 | 263 | if args.dataset == "CUB_200_2011": 264 | args.num_classes = 200 265 | elif args.dataset == "car": 266 | args.num_classes = 196 267 | elif args.dataset == "nabirds": 268 | args.num_classes = 555 269 | elif args.dataset == "dog": 270 | args.num_classes = 120 271 | elif args.dataset == "Aircraft": 272 | args.num_classes = 100 273 | elif args.dataset == "INat2017": 274 | args.num_classes = 5089 275 | 276 | model = eval(args.model_type)(pretrained=True, num_classes=args.num_classes, pretrained_dir=args.pretrained_dir) 277 | 278 | return args, model 279 | 280 | 281 | def main(): 282 | # Get args 283 | args = get_args() 284 | 285 | # Setup data_root 286 | args.data_root = '{}/{}'.format(args.data_root, args.dataset) 287 | 288 | # Setup save path 289 | args.output_dir = os.path.join( 290 | args.output_dir_root, 291 | args.output_dir, 292 | f'{args.dataset}_{args.model_type}_bs{args.train_batch_size}_lr{args.lr}_wd{args.weight_decay}_epochs{args.epochs}_wmsteps{args.warmup_epochs}_mlr{args.meta_lr}_mhs{args.meta_net_hidden_size}_mlyer{args.meta_net_num_layers}_mdw{args.meta_weight_decay}_lbd{args.lambda_0}_round{args.round}/' 293 | ) 294 | if not os.path.exists(args.output_dir): 295 | os.makedirs(args.output_dir) 296 | 297 | # Set seed 298 | set_seed(args) 299 | 300 | # get free port 301 | args.port = get_free_port() 302 | 303 | # Start Multiprocessing 304 | mp.spawn(main_worker, nprocs=torch.cuda.device_count(), args=(torch.cuda.device_count(), args)) 305 | 306 | 307 | def main_worker(local_rank, ngpus_per_node, args): 308 | global best_acc1, best_epoch 309 | args.local_rank = local_rank 310 | 311 | 312 | # Setup logging 313 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 314 | datefmt='%m/%d/%Y %H:%M:%S', 315 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 316 | filename=os.path.join(args.output_dir, 'screen_output.log')) 317 | 318 | 319 | # Multiprocessing 320 | ip = '127.0.0.1' 321 | port = args.port 322 | hosts = 1 323 | rank = 0 324 | args.ngpus_per_node = ngpus_per_node 325 | args.world_size = hosts * args.ngpus_per_node 326 | args.world_rank = rank * args.ngpus_per_node + args.local_rank 327 | dist.init_process_group(backend='nccl', init_method=f'tcp://{ip}:{port}', world_size=args.world_size, rank=args.world_rank) 328 | args.is_main_proc = (args.world_rank == 0) 329 | 330 | 331 | # Model Setup 332 | args, model = setup_model(args) 333 | meta_net = MLP_sigmoid( 334 | input_size=model.feature_num, 335 | hidden_size=args.meta_net_hidden_size, 336 | num_layers=args.meta_net_num_layers, 337 | output_size=model.feature_num, 338 | ) 339 | 340 | 341 | # DistributedDataParallel 342 | args.train_batch_size = int(args.train_batch_size / args.ngpus_per_node) 343 | args.workers = int((args.workers + 
args.ngpus_per_node - 1) / args.ngpus_per_node) 344 | torch.cuda.set_device(args.local_rank) 345 | model.cuda(args.local_rank) 346 | meta_net.cuda(args.local_rank) 347 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank]) 348 | meta_net = torch.nn.parallel.DistributedDataParallel(meta_net, device_ids=[args.local_rank]) 349 | 350 | 351 | # Prepare optimizer 352 | criterion_ce = torch.nn.CrossEntropyLoss().cuda(args.local_rank) 353 | criterion_isda = ISDALoss(model.module.feature_num, args.num_classes, args.local_rank).cuda(args.local_rank) 354 | optimizer = torch.optim.SGD(model.parameters(), 355 | lr=args.lr, 356 | momentum=args.momentum, 357 | weight_decay=args.weight_decay) 358 | meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=args.meta_lr, weight_decay=args.meta_weight_decay) 359 | 360 | 361 | cudnn.benchmark = True 362 | 363 | 364 | # Prepare dataset 365 | train_loader, test_loader, train_sampler = get_loader(args) 366 | 367 | 368 | # Init scores_all.csv 369 | if args.is_main_proc: 370 | if not os.path.exists(args.output_dir + '/scores_all.csv'): 371 | with open(args.output_dir + '/scores_all.csv', "a") as f: 372 | f.write(f'epoch, lr, loss_train, loss_meta, acc1_train, loss_test, acc1_test, acc1_test_best,\n') 373 | 374 | 375 | # Auto Resume 376 | resume_dir = os.path.join(args.output_dir, "save_models", "checkpoint.pth.tar") 377 | if os.path.exists(resume_dir): 378 | logger.info(f'[INFO] resume dir: {resume_dir}') 379 | ckpt = torch.load(resume_dir, map_location='cpu') 380 | args.start_epoch = ckpt['epoch'] 381 | model.module.load_state_dict(ckpt['state_dict']) 382 | optimizer.load_state_dict(ckpt['optimizer']) 383 | meta_net.module.load_state_dict(ckpt['meta_state_dict']) 384 | meta_optimizer.load_state_dict(ckpt['meta_optimizer']) 385 | curr_acc1 = ckpt['curr_acc1'] 386 | best_acc1 = ckpt['best_acc1'] 387 | logger.info(f'[INFO] Auto Resume from {resume_dir}, from finished epoch {args.start_epoch}, with acc_best{best_acc1}, acc_curr {curr_acc1}.') 388 | 389 | 390 | # Start Train 391 | logger.info("***** Running training *****") 392 | 393 | 394 | pseudo_net_init = eval(args.model_type)(pretrained=False, num_classes=args.num_classes) 395 | 396 | 397 | start_time = time.time() 398 | for epoch in range(args.start_epoch, args.epochs): 399 | train_sampler.set_epoch(epoch) 400 | loss_train, loss_meta, acc1_train = train_meta(train_loader, model, meta_net, pseudo_net_init, criterion_isda, criterion_ce, optimizer, meta_optimizer, epoch, args) 401 | 402 | if (epoch % 5 == 0) or (epoch >= args.epochs - 15) or (epoch <= 15): 403 | loss_test, acc1_test = validate(test_loader, model, criterion_ce, args) 404 | 405 | if args.is_main_proc: 406 | 407 | is_best = acc1_test > best_acc1 408 | best_acc1 = max(acc1_test, best_acc1) 409 | if is_best: 410 | best_epoch = epoch 411 | 412 | with open(args.output_dir + '/scores_all.csv', "a") as f: 413 | f.write( 414 | f"{epoch:3d}, {get_lr(optimizer):15.12f}, {loss_train:9.8f}, {loss_meta:9.8f}, {acc1_train:6.3f}, {loss_test:9.8f}, {acc1_test:6.3f}, {best_acc1:6.3f},\n" 415 | ) 416 | 417 | save_checkpoint( 418 | {'epoch': epoch + 1, 419 | 'state_dict': model.module.state_dict(), 420 | 'optimizer' : optimizer.state_dict(), 421 | 'meta_state_dict': meta_net.module.state_dict(), 422 | 'meta_optimizer': meta_optimizer.state_dict(), 423 | 'curr_acc1': acc1_test, 424 | 'best_acc1': best_acc1, 425 | 'best_epoch': best_epoch, 426 | }, is_best, save_dir=os.path.join(args.output_dir, 'save_models') 427 | ) 428 | 429 | if 
args.is_main_proc: 430 | logger.info(estimated_time(start_time, epoch, args.start_epoch, args.epochs)) 431 | 432 | # record final result 433 | if args.is_main_proc: 434 | with open(args.output_dir + '/scores_final.csv', "a") as f: 435 | f.write(f'epoch, lr, loss_train, loss_meta, acc1_train, loss_test, acc1_test, acc1_test_best,\n') 436 | f.write(f"{epoch:3d}, {get_lr(optimizer):15.12f}, {loss_train:9.8f}, {loss_meta:9.8f}, {acc1_train:6.3f}, {loss_test:9.8f}, {acc1_test:6.3f}, {best_acc1:6.3f},\n") 437 | 438 | logger.info("Best Accuracy: \t%f" % best_acc1) 439 | logger.info("Last Accuracy: \t%f" % acc1_test) 440 | logger.info("Training Complete.") 441 | 442 | 443 | def train_meta(train_loader, model, meta_net, pseudo_net_init, criterion_isda, criterion_ce, optimizer, meta_optimizer, epoch, args): 444 | batch_time = AverageMeter('Time', ':6.3f') 445 | data_time = AverageMeter('Data', ':6.3f') 446 | losses = AverageMeter('Loss', ':.4e') 447 | meta_losses = AverageMeter('MetaLoss', ':.4e') 448 | top1 = AverageMeter('Acc@1', ':6.2f') 449 | top5 = AverageMeter('Acc@5', ':6.2f') 450 | progress = ProgressMeter( 451 | len(train_loader), 452 | [batch_time, data_time, losses, meta_losses, top1, top5], 453 | prefix="Epoch: [{}]".format(epoch)) 454 | 455 | # switch to train mode 456 | model.train() 457 | 458 | pseudo_net_init.cuda(args.local_rank) 459 | 460 | end = time.time() 461 | for i, (images, target) in enumerate(train_loader): 462 | images = images.cuda(args.local_rank, non_blocking=True) 463 | target = target.cuda(args.local_rank, non_blocking=True) 464 | 465 | images_p1, images_p2 = images.chunk(2, dim=0) 466 | target_p1, target_p2 = target.chunk(2, dim=0) 467 | 468 | data_time.update(time.time() - end) 469 | 470 | # adjust learning rate 471 | lr = adjust_learning_rate(optimizer, init_lr=args.lr, 472 | epoch_total=args.epochs, warmup_epochs=args.warmup_epochs, epoch_cur=epoch, 473 | num_iter_per_epoch=len(train_loader), i_iter=i) 474 | meta_lr = adjust_meta_learning_rate(meta_optimizer, init_lr=args.meta_lr, 475 | epoch_total=args.epochs, warmup_epochs=args.warmup_epochs, epoch_cur=epoch, 476 | num_iter_per_epoch=len(train_loader), i_iter=i) 477 | ratio = args.lambda_0 * (epoch / args.epochs) 478 | 479 | ################################################### 480 | ## part 1: images_p1 as train, images_p2 as meta ## 481 | ################################################### 482 | pseudo_net = pickle.loads(pickle.dumps(pseudo_net_init)) 483 | pseudo_net.load_state_dict(model.module.state_dict()) 484 | pseudo_net.train() 485 | 486 | pseudo_outputs_features = pseudo_net(images_p1) 487 | pseudo_cv_matrix = meta_net(pseudo_outputs_features.detach()) 488 | pseudo_loss, pseudo_outputs_logits = criterion_isda(pseudo_net.head, pseudo_outputs_features, target_p1, ratio, pseudo_cv_matrix) 489 | 490 | pseudo_grads = torch.autograd.grad(pseudo_loss, pseudo_net.parameters(), create_graph=True, allow_unused=True) 491 | 492 | pseudo_optimizer = MetaSGD(pseudo_net, pseudo_net.parameters(), lr=lr) 493 | pseudo_optimizer.load_state_dict(optimizer.state_dict()) 494 | pseudo_optimizer.meta_step(pseudo_grads) 495 | 496 | del pseudo_grads 497 | 498 | 499 | meta_outputs_features = pseudo_net(images_p2) 500 | meta_outputs_logits = pseudo_net.head(meta_outputs_features) 501 | meta_loss = criterion_ce(meta_outputs_logits, target_p2) 502 | meta_losses.update(meta_loss.item(), images_p2.size(0)) 503 | 504 | meta_optimizer.zero_grad() 505 | meta_loss.backward() 506 | meta_optimizer.step() 507 | 508 | 509 | outputs_features = 
model(images_p1) 510 | cv_matrix = meta_net(outputs_features) 511 | loss, outputs_logits = criterion_isda(model.module.head, outputs_features, target_p1, ratio, cv_matrix) 512 | 513 | 514 | # measure accuracy and record loss 515 | acc1, acc5 = accuracy(outputs_logits, target_p1, topk=(1, 5)) 516 | losses.update(loss.item(), images_p1.size(0)) 517 | top1.update(acc1[0].item(), images_p1.size(0)) 518 | top5.update(acc5[0].item(), images_p1.size(0)) 519 | 520 | # compute gradient and do SGD step 521 | optimizer.zero_grad() 522 | loss.backward() 523 | optimizer.step() 524 | 525 | 526 | ################################################### 527 | ## part 2: images_p2 as train, images_p1 as meta ## 528 | ################################################### 529 | pseudo_net = pickle.loads(pickle.dumps(pseudo_net_init)) 530 | pseudo_net.load_state_dict(model.module.state_dict()) 531 | pseudo_net.train() 532 | 533 | pseudo_outputs_features = pseudo_net(images_p2) 534 | pseudo_cv_matrix = meta_net(pseudo_outputs_features.detach()) 535 | pseudo_loss, pseudo_outputs_logits = criterion_isda(pseudo_net.head, pseudo_outputs_features, target_p2, ratio, pseudo_cv_matrix) 536 | 537 | pseudo_grads = torch.autograd.grad(pseudo_loss, pseudo_net.parameters(), create_graph=True, allow_unused=True) 538 | 539 | pseudo_optimizer = MetaSGD(pseudo_net, pseudo_net.parameters(), lr=lr) 540 | pseudo_optimizer.load_state_dict(optimizer.state_dict()) 541 | pseudo_optimizer.meta_step(pseudo_grads) 542 | 543 | del pseudo_grads 544 | 545 | 546 | meta_outputs_features = pseudo_net(images_p1) 547 | meta_outputs_logits = pseudo_net.head(meta_outputs_features) 548 | meta_loss = criterion_ce(meta_outputs_logits, target_p1) 549 | meta_losses.update(meta_loss.item(), images_p1.size(0)) 550 | 551 | meta_optimizer.zero_grad() 552 | meta_loss.backward() 553 | meta_optimizer.step() 554 | 555 | 556 | outputs_features = model(images_p2) 557 | cv_matrix = meta_net(outputs_features) 558 | loss, outputs_logits = criterion_isda(model.module.head, outputs_features, target_p2, ratio, cv_matrix) 559 | 560 | 561 | # measure accuracy and record loss 562 | acc1, acc5 = accuracy(outputs_logits, target_p2, topk=(1, 5)) 563 | losses.update(loss.item(), images_p2.size(0)) 564 | top1.update(acc1[0].item(), images_p2.size(0)) 565 | top5.update(acc5[0].item(), images_p2.size(0)) 566 | 567 | # compute gradient and do SGD step 568 | optimizer.zero_grad() 569 | loss.backward() 570 | optimizer.step() 571 | 572 | ################################################### 573 | ## finish exchange ## 574 | ################################################### 575 | 576 | 577 | # measure elapsed time 578 | batch_time.update(time.time() - end) 579 | end = time.time() 580 | 581 | if ((i % args.print_freq == 0) or (i == len(train_loader) - 1)) and args.is_main_proc: 582 | progress.display(i) 583 | 584 | return losses.avg, meta_losses.avg, top1.avg 585 | 586 | 587 | def validate(val_loader, model, criterion, args): 588 | batch_time = AverageMeter('Time', ':6.3f', Summary.NONE) 589 | losses = AverageMeter('Loss', ':.4e', Summary.NONE) 590 | top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE) 591 | top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE) 592 | progress = ProgressMeter( 593 | len(val_loader), 594 | [batch_time, losses, top1, top5], 595 | prefix='Test: ') 596 | 597 | # switch to evaluate mode 598 | model.eval() 599 | 600 | with torch.no_grad(): 601 | end = time.time() 602 | for i, (images, target) in enumerate(val_loader): 603 | images = 
images.cuda(args.local_rank, non_blocking=True) 604 | target = target.cuda(args.local_rank, non_blocking=True) 605 | 606 | # compute output 607 | features = model(images) 608 | logits = model.module.head(features) 609 | loss = criterion(logits, target) 610 | 611 | # measure accuracy and record loss 612 | acc1, acc5 = accuracy(logits, target, topk=(1, 5)) 613 | 614 | dist.all_reduce(acc1) 615 | acc1 /= args.world_size 616 | dist.all_reduce(acc5) 617 | acc5 /= args.world_size 618 | dist.all_reduce(loss) 619 | loss /= args.world_size 620 | 621 | losses.update(loss.item(), images.size(0)) 622 | top1.update(acc1[0].item(), images.size(0)) 623 | top5.update(acc5[0].item(), images.size(0)) 624 | 625 | # measure elapsed time 626 | batch_time.update(time.time() - end) 627 | end = time.time() 628 | 629 | if ((i % args.print_freq == 0) or (i == len(val_loader) - 1)) and args.is_main_proc: 630 | progress.display(i) 631 | 632 | if args.is_main_proc: 633 | progress.display_summary() 634 | 635 | return losses.avg, top1.avg 636 | 637 | 638 | if __name__ == '__main__': 639 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LeapLabTHU/LearnableISDA/9473a481d9639efd64ebec2452ce2ec4f5dc2223/utils/__init__.py -------------------------------------------------------------------------------- /utils/autoaugment.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py 3 | """ 4 | 5 | import random 6 | import numpy as np 7 | from PIL import Image, ImageEnhance, ImageOps 8 | 9 | 10 | __all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy'] 11 | 12 | 13 | class AutoAugImageNetPolicy(object): 14 | def __init__(self, fillcolor=(128, 128, 128)): 15 | self.policies = [ 16 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 17 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 18 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 19 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 20 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 21 | 22 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 23 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 24 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 25 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 26 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 27 | 28 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 29 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 30 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 31 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 32 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 33 | 34 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 35 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 36 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 37 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 38 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 39 | 40 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 41 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 42 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 43 | SubPolicy(0.6, "color", 
4, 1.0, "contrast", 8, fillcolor) 44 | ] 45 | 46 | def __call__(self, img): 47 | policy_idx = random.randint(0, len(self.policies) - 1) 48 | return self.policies[policy_idx](img) 49 | 50 | def __repr__(self): 51 | return "AutoAugment ImageNet Policy" 52 | 53 | 54 | class AutoAugCIFAR10Policy(object): 55 | def __init__(self, fillcolor=(128, 128, 128)): 56 | self.policies = [ 57 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 58 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 59 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 60 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 61 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 62 | 63 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 64 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 65 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 66 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 67 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 68 | 69 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 70 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 71 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 72 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 73 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 74 | 75 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 76 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 77 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 78 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 79 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 80 | 81 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 82 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 83 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 84 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 85 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 86 | ] 87 | 88 | def __call__(self, img): 89 | policy_idx = random.randint(0, len(self.policies) - 1) 90 | return self.policies[policy_idx](img) 91 | 92 | def __repr__(self): 93 | return "AutoAugment CIFAR10 Policy" 94 | 95 | 96 | class AutoAugSVHNPolicy(object): 97 | def __init__(self, fillcolor=(128, 128, 128)): 98 | self.policies = [ 99 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 100 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 101 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 102 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 103 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 104 | 105 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 106 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 107 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 108 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 109 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 110 | 111 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 112 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 113 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 114 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 115 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 116 | 117 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 118 | SubPolicy(0.7, "shearY", 6, 
0.4, "solarize", 8, fillcolor), 119 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 120 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 121 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 122 | 123 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 124 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 125 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 126 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 127 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 128 | ] 129 | 130 | def __call__(self, img): 131 | policy_idx = random.randint(0, len(self.policies) - 1) 132 | return self.policies[policy_idx](img) 133 | 134 | def __repr__(self): 135 | return "AutoAugment SVHN Policy" 136 | 137 | 138 | class SubPolicy(object): 139 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 140 | ranges = { 141 | "shearX": np.linspace(0, 0.3, 10), 142 | "shearY": np.linspace(0, 0.3, 10), 143 | "translateX": np.linspace(0, 150 / 331, 10), 144 | "translateY": np.linspace(0, 150 / 331, 10), 145 | "rotate": np.linspace(0, 30, 10), 146 | "color": np.linspace(0.0, 0.9, 10), 147 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 148 | "solarize": np.linspace(256, 0, 10), 149 | "contrast": np.linspace(0.0, 0.9, 10), 150 | "sharpness": np.linspace(0.0, 0.9, 10), 151 | "brightness": np.linspace(0.0, 0.9, 10), 152 | "autocontrast": [0] * 10, 153 | "equalize": [0] * 10, 154 | "invert": [0] * 10 155 | } 156 | 157 | def rotate_with_fill(img, magnitude): 158 | rot = img.convert("RGBA").rotate(magnitude) 159 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 160 | 161 | func = { 162 | "shearX": lambda img, magnitude: img.transform( 163 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 164 | Image.BICUBIC, fillcolor=fillcolor), 165 | "shearY": lambda img, magnitude: img.transform( 166 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 167 | Image.BICUBIC, fillcolor=fillcolor), 168 | "translateX": lambda img, magnitude: img.transform( 169 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 170 | fillcolor=fillcolor), 171 | "translateY": lambda img, magnitude: img.transform( 172 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 173 | fillcolor=fillcolor), 174 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 175 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 176 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 177 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 178 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 179 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 180 | 1 + magnitude * random.choice([-1, 1])), 181 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 182 | 1 + magnitude * random.choice([-1, 1])), 183 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 184 | 1 + magnitude * random.choice([-1, 1])), 185 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 186 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 187 | "invert": lambda img, magnitude: ImageOps.invert(img) 188 | } 189 | 190 | self.p1 = 
p1 191 | self.operation1 = func[operation1] 192 | self.magnitude1 = ranges[operation1][magnitude_idx1] 193 | self.p2 = p2 194 | self.operation2 = func[operation2] 195 | self.magnitude2 = ranges[operation2][magnitude_idx2] 196 | 197 | def __call__(self, img): 198 | if random.random() < self.p1: 199 | img = self.operation1(img, self.magnitude1) 200 | if random.random() < self.p2: 201 | img = self.operation2(img, self.magnitude2) 202 | return img 203 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from PIL import Image 4 | 5 | import torch 6 | 7 | from torchvision import transforms 8 | from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler 9 | 10 | from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017 11 | from .autoaugment import AutoAugImageNetPolicy 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_loader(args): 17 | 18 | if args.dataset == 'CUB_200_2011': 19 | train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 20 | transforms.RandomCrop((448, 448)), 21 | transforms.RandomHorizontalFlip(), 22 | transforms.ToTensor(), 23 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 24 | test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 25 | transforms.CenterCrop((448, 448)), 26 | transforms.ToTensor(), 27 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 28 | trainset = CUB(root=args.data_root, is_train=True, transform=train_transform, args=args) 29 | testset = CUB(root=args.data_root, is_train=False, transform = test_transform, args=args) 30 | elif args.dataset == 'car': 31 | trainset = CarsDataset(os.path.join(args.data_root,'devkit/cars_train_annos.mat'), 32 | os.path.join(args.data_root,'cars_train'), 33 | os.path.join(args.data_root,'devkit/cars_meta.mat'), 34 | transform=transforms.Compose([ 35 | transforms.Resize((600, 600), Image.BILINEAR), 36 | transforms.RandomCrop((448, 448)), 37 | transforms.RandomHorizontalFlip(), 38 | AutoAugImageNetPolicy(), 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 41 | ) 42 | testset = CarsDataset(os.path.join(args.data_root,'cars_test_annos_withlabels.mat'), 43 | os.path.join(args.data_root,'cars_test'), 44 | os.path.join(args.data_root,'devkit/cars_meta.mat'), 45 | # cleaned=os.path.join(data_dir,'cleaned_test.dat'), 46 | transform=transforms.Compose([ 47 | transforms.Resize((600, 600), Image.BILINEAR), 48 | transforms.CenterCrop((448, 448)), 49 | transforms.ToTensor(), 50 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 51 | ) 52 | elif args.dataset == 'dog': 53 | train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 54 | transforms.RandomCrop((448, 448)), 55 | transforms.RandomHorizontalFlip(), 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 58 | test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), 59 | transforms.CenterCrop((448, 448)), 60 | transforms.ToTensor(), 61 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 62 | trainset = dogs(root=args.data_root, 63 | train=True, 64 | cropped=False, 65 | transform=train_transform, 66 | download=False 67 | ) 68 | testset = dogs(root=args.data_root, 69 | train=False, 
 70 |                        cropped=False,
 71 |                        transform=test_transform,
 72 |                        download=False
 73 |                        )
 74 |     elif args.dataset == 'nabirds':
 75 |         train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
 76 |                                             transforms.RandomCrop((448, 448)),
 77 |                                             transforms.RandomHorizontalFlip(),
 78 |                                             transforms.ToTensor(),
 79 |                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 80 |         test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
 81 |                                            transforms.CenterCrop((448, 448)),
 82 |                                            transforms.ToTensor(),
 83 |                                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 84 |         trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
 85 |         testset = NABirds(root=args.data_root, train=False, transform=test_transform)
 86 |     elif args.dataset == 'INat2017':
 87 |         train_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
 88 |                                             transforms.RandomCrop((304, 304)),
 89 |                                             transforms.RandomHorizontalFlip(),
 90 |                                             AutoAugImageNetPolicy(),
 91 |                                             transforms.ToTensor(),
 92 |                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 93 |         test_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
 94 |                                            transforms.CenterCrop((304, 304)),
 95 |                                            transforms.ToTensor(),
 96 |                                            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
 97 |         trainset = INat2017(args.data_root, 'train', train_transform)
 98 |         testset = INat2017(args.data_root, 'val', test_transform)
 99 |     else:
100 |         raise ValueError('Unsupported dataset: %s' % args.dataset)
101 | 
102 |     train_sampler = DistributedSampler(trainset)
103 |     test_sampler = DistributedSampler(testset)
104 | 
105 |     train_loader = DataLoader(trainset,
106 |                               sampler=train_sampler,
107 |                               batch_size=args.train_batch_size,
108 |                               num_workers=args.workers,
109 |                               drop_last=True,
110 |                               pin_memory=True)
111 |     test_loader = DataLoader(testset,
112 |                              sampler=test_sampler,
113 |                              batch_size=args.eval_batch_size,
114 |                              num_workers=args.workers,
115 |                              pin_memory=True) if testset is not None else None
116 | 
117 |     return train_loader, test_loader, train_sampler
118 | 
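For reference, `get_loader` returns the training sampler precisely so the caller can reshuffle it each epoch: without `set_epoch`, `DistributedSampler` reuses the same permutation every epoch. A minimal consumption sketch (not part of the repository; assumes a `torchrun` launch and an `args` object carrying the fields used above):

```
import torch.distributed as dist
from utils.data_utils import get_loader

def run(args):
    dist.init_process_group(backend="nccl")  # torchrun provides rank/world size via env vars
    train_loader, test_loader, train_sampler = get_loader(args)
    for epoch in range(args.epochs):
        train_sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle across ranks
        for images, targets in train_loader:
            ...  # training step
```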
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import json
  3 | import scipy
  4 | import imageio
  5 | # import scipy.misc  # unused here; scipy.misc is removed in modern SciPy
  6 | import numpy as np
  7 | import pandas as pd
  8 | from scipy import io
  9 | from PIL import Image
 10 | from os.path import join
 11 | import matplotlib.pyplot as plt
 12 | 
 13 | import torch
 14 | from torch.utils.data import Dataset
 15 | from torchvision.datasets import VisionDataset
 16 | from torchvision.datasets.folder import default_loader
 17 | from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
 18 | 
 19 | 
 20 | class CUB():
 21 | 
 22 |     def __init__(self, root, is_train=True, data_len=None, transform=None, args=None):
 23 |         self.root = root
 24 |         self.is_train = is_train
 25 |         self.transform = transform
 26 |         img_txt_file = open(os.path.join(self.root, 'images.txt'))
 27 |         label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
 28 |         train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
 29 |         img_name_list = []
 30 |         for line in img_txt_file:
 31 |             img_name_list.append(line[:-1].split(' ')[-1])
 32 |         label_list = []
 33 |         for line in label_txt_file:
 34 |             label_list.append(int(line[:-1].split(' ')[-1]) - 1)
 35 |         train_test_list = []
 36 |         for line in train_val_file:
 37 |             train_test_list.append(int(line[:-1].split(' ')[-1]))
 38 |         train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
 39 |         test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
 40 |         if self.is_train:
 41 |             self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
 42 |                               train_file_list[:data_len]]  # NOTE: decodes the whole training split into memory up front
 43 |             self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
 44 |             self.train_imgname = [x for x in train_file_list[:data_len]]
 45 |         else:
 46 |             self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
 47 |                              test_file_list[:data_len]]
 48 |             self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
 49 |             self.test_imgname = [x for x in test_file_list[:data_len]]
 50 | 
 51 |     def __getitem__(self, index):
 52 |         if self.is_train:
 53 |             img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
 54 |             if len(img.shape) == 2:  # promote grayscale images to 3-channel RGB
 55 |                 img = np.stack([img] * 3, 2)
 56 |             img = Image.fromarray(img, mode='RGB')
 57 |             if self.transform is not None:
 58 |                 img = self.transform(img)
 59 |         else:
 60 |             img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
 61 |             if len(img.shape) == 2:  # promote grayscale images to 3-channel RGB
 62 |                 img = np.stack([img] * 3, 2)
 63 |             img = Image.fromarray(img, mode='RGB')
 64 |             if self.transform is not None:
 65 |                 img = self.transform(img)
 66 | 
 67 |         return img, target
 68 | 
 69 |     def __len__(self):
 70 |         if self.is_train:
 71 |             return len(self.train_label)
 72 |         else:
 73 |             return len(self.test_label)
 74 | 
 75 | 
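Note that `CUB.__init__` above decodes every image of the chosen split into memory up front, trading several GB of RAM for zero decode latency during training. A lazy alternative, sketched below (not part of the repository; `LazyCUB` is a hypothetical variant), keeps only file names and reads images on demand:

```
import os
from PIL import Image
from torch.utils.data import Dataset

class LazyCUB(Dataset):
    """Hypothetical variant of CUB that stores (filename, label) pairs only."""
    def __init__(self, root, file_list, labels, transform=None):
        self.root, self.transform = root, transform
        self.samples = list(zip(file_list, labels))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        name, target = self.samples[index]
        # convert('RGB') also covers the grayscale case handled manually above
        img = Image.open(os.path.join(self.root, 'images', name)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target
```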
 76 | class CarsDataset(Dataset):
 77 | 
 78 |     def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
 79 |         """
 80 |         Args:
 81 |             mat_anno (string): Path to the MATLAB annotation file.
 82 |             data_dir (string): Directory with all the images.
 83 |             transform (callable, optional): Optional transform to be applied
 84 |                 on a sample.
 85 |         """
 86 | 
 87 |         self.full_data_set = io.loadmat(mat_anno)
 88 |         self.car_annotations = self.full_data_set['annotations']
 89 |         self.car_annotations = self.car_annotations[0]
 90 | 
 91 |         if cleaned is not None:
 92 |             cleaned_annos = []
 93 |             print("Cleaning up data set (keeping only RGB images)...")
 94 |             clean_files = np.loadtxt(cleaned, dtype=str)
 95 |             for c in self.car_annotations:
 96 |                 if c[-1][0] in clean_files:
 97 |                     cleaned_annos.append(c)
 98 |             self.car_annotations = cleaned_annos
 99 | 
100 |         self.car_names = scipy.io.loadmat(car_names)['class_names']
101 |         self.car_names = np.array(self.car_names[0])
102 | 
103 |         self.data_dir = data_dir
104 |         self.transform = transform
105 | 
106 |     def __len__(self):
107 |         return len(self.car_annotations)
108 | 
109 |     def __getitem__(self, idx):
110 |         img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
111 |         image = Image.open(img_name).convert('RGB')
112 |         car_class = self.car_annotations[idx][-2][0][0]
113 |         car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
114 |         assert car_class < 196  # Stanford Cars has 196 classes
115 | 
116 |         if self.transform:
117 |             image = self.transform(image)
118 | 
119 |         # return image, car_class, img_name
120 |         return image, car_class
121 | 
122 |     def map_class(self, class_id):  # renamed from `id` to avoid shadowing the builtin
123 |         class_id = np.ravel(class_id)
124 |         ret = self.car_names[class_id - 1][0][0]
125 |         return ret
126 | 
127 |     def show_batch(self, img_batch, class_batch):
128 | 
129 |         for i in range(img_batch.shape[0]):
130 |             ax = plt.subplot(1, img_batch.shape[0], i + 1)
131 |             title_str = self.map_class(int(class_batch[i]))
132 |             img = np.transpose(img_batch[i, ...], (1, 2, 0))
133 |             ax.imshow(img)
134 |             ax.set_title(title_str.__str__(), {'fontsize': 5})
135 |         plt.tight_layout()
136 | 
137 | 
138 | def make_dataset(dir, image_ids, targets):
139 |     assert(len(image_ids) == len(targets))
140 |     images = []
141 |     dir = os.path.expanduser(dir)
142 |     for i in range(len(image_ids)):
143 |         item = (os.path.join(dir, 'data', 'images',
144 |                              '%s.jpg' % image_ids[i]), targets[i])
145 |         images.append(item)
146 |     return images
147 | 
148 | 
149 | def find_classes(classes_file):
150 |     # read classes file, separating out image IDs and class names
151 |     image_ids = []
152 |     targets = []
153 |     f = open(classes_file, 'r')
154 |     for line in f:
155 |         split_line = line.split(' ')
156 |         image_ids.append(split_line[0])
157 |         targets.append(' '.join(split_line[1:]))
158 |     f.close()
159 | 
160 |     # index class names
161 |     classes = np.unique(targets)
162 |     class_to_idx = {classes[i]: i for i in range(len(classes))}
163 |     targets = [class_to_idx[c] for c in targets]
164 | 
165 |     return (image_ids, targets, classes, class_to_idx)
166 | 
167 | 
168 | class dogs(Dataset):
169 |     """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
170 |     Args:
171 |         root (string): Root directory of the dataset, containing the
172 |             ``Images`` and ``Annotation`` folders
173 |         cropped (bool, optional): If true, the images will be cropped into the bounding box specified
174 |             in the annotations
175 |         transform (callable, optional): A function/transform that takes in a PIL image
176 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
177 |         target_transform (callable, optional): A function/transform that takes in the
178 |             target and transforms it.
179 |         download (bool, optional): If true, downloads the dataset tar files from the internet and
180 |             puts them in the root directory. If the tar files are already downloaded, they are not
181 |             downloaded again.
182 |     """
183 |     folder = 'dog'
184 |     download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
185 | 
186 |     def __init__(self,
187 |                  root,
188 |                  train=True,
189 |                  cropped=False,
190 |                  transform=None,
191 |                  target_transform=None,
192 |                  download=False):
193 | 
194 |         # self.root = join(os.path.expanduser(root), self.folder)
195 |         self.root = root
196 |         self.train = train
197 |         self.cropped = cropped
198 |         self.transform = transform
199 |         self.target_transform = target_transform
200 | 
201 |         if download:
202 |             self.download()
203 | 
204 |         split = self.load_split()
205 | 
206 |         self.images_folder = join(self.root, 'Images')
207 |         self.annotations_folder = join(self.root, 'Annotation')
208 |         self._breeds = list_dir(self.images_folder)
209 | 
210 |         if self.cropped:
211 |             self._breed_annotations = [[(annotation, box, idx)
212 |                                         for box in self.get_boxes(join(self.annotations_folder, annotation))]
213 |                                        for annotation, idx in split]
214 |             self._flat_breed_annotations = sum(self._breed_annotations, [])
215 | 
216 |             self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
217 |         else:
218 |             self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split]
219 | 
220 |             self._flat_breed_images = self._breed_images
221 | 
222 |         self.classes = ["Chihuahua",
223 |                         "Japanese Spaniel",
224 |                         "Maltese Dog",
225 |                         "Pekinese",
226 |                         "Shih-Tzu",
227 |                         "Blenheim Spaniel",
228 |                         "Papillon",
229 |                         "Toy Terrier",
230 |                         "Rhodesian Ridgeback",
231 |                         "Afghan Hound",
232 |                         "Basset Hound",
233 |                         "Beagle",
234 |                         "Bloodhound",
235 |                         "Bluetick",
236 |                         "Black-and-tan Coonhound",
237 |                         "Walker Hound",
238 |                         "English Foxhound",
239 |                         "Redbone",
240 |                         "Borzoi",
241 |                         "Irish Wolfhound",
242 |                         "Italian Greyhound",
243 |                         "Whippet",
244 |                         "Ibizan Hound",
245 |                         "Norwegian Elkhound",
246 |                         "Otterhound",
247 |                         "Saluki",
248 |                         "Scottish Deerhound",
249 |                         "Weimaraner",
250 |                         "Staffordshire Bullterrier",
251 |                         "American Staffordshire Terrier",
252 |                         "Bedlington Terrier",
253 |                         "Border Terrier",
254 |                         "Kerry Blue Terrier",
255 |                         "Irish Terrier",
256 |                         "Norfolk Terrier",
257 |                         "Norwich Terrier",
258 |                         "Yorkshire Terrier",
259 |                         "Wirehaired Fox Terrier",
260 |                         "Lakeland Terrier",
261 |                         "Sealyham Terrier",
262 |                         "Airedale",
263 |                         "Cairn",
264 |                         "Australian Terrier",
265 |                         "Dandie Dinmont",
266 |                         "Boston Bull",
267 |                         "Miniature Schnauzer",
268 |                         "Giant Schnauzer",
269 |                         "Standard Schnauzer",
270 |                         "Scotch Terrier",
271 |                         "Tibetan Terrier",
272 |                         "Silky Terrier",
273 |                         "Soft-coated Wheaten Terrier",
274 |                         "West Highland White Terrier",
275 |                         "Lhasa",
276 |                         "Flat-coated Retriever",
277 |                         "Curly-coated Retriever",
278 |                         "Golden Retriever",
279 |                         "Labrador Retriever",
280 |                         "Chesapeake Bay Retriever",
281 |                         "German Short-haired Pointer",
282 |                         "Vizsla",
283 |                         "English Setter",
284 |                         "Irish Setter",
285 |                         "Gordon Setter",
286 |                         "Brittany",
287 |                         "Clumber",
288 |                         "English Springer Spaniel",
289 |                         "Welsh Springer Spaniel",
290 |                         "Cocker Spaniel",
291 |                         "Sussex Spaniel",
292 |                         "Irish Water Spaniel",
293 |                         "Kuvasz",
294 |                         "Schipperke",
295 |                         "Groenendael",
296 |                         "Malinois",
297 |                         "Briard",
298 |                         "Kelpie",
299 |                         "Komondor",
300 |                         "Old English Sheepdog",
301 |                         "Shetland Sheepdog",
302 |                         "Collie",
303 |                         "Border Collie",
304 |                         "Bouvier des Flandres",
305 |                         "Rottweiler",
306 |                         "German Shepherd",
307 |                         "Doberman",
308 |                         "Miniature Pinscher",
309 |                         "Greater Swiss Mountain Dog",
310 |                         "Bernese Mountain Dog",
311 |                         "Appenzeller",
312 |                         "EntleBucher",
313 |                         "Boxer",
314 |                         "Bull Mastiff",
315 |                         "Tibetan Mastiff",
316 |                         "French Bulldog",
317 |                         "Great Dane",
318 |                         "Saint Bernard",
319 |                         "Eskimo Dog",
320 |                         "Malamute",
321 |                         "Siberian Husky",
322 |                         "Affenpinscher",
323 |                         "Basenji",
324 |                         "Pug",
325 |                         "Leonberg",
326 |                         "Newfoundland",
327 |                         "Great Pyrenees",
328 |                         "Samoyed",
329 |                         "Pomeranian",
330 |                         "Chow",
331 |                         "Keeshond",
332 |                         "Brabancon Griffon",
333 |                         "Pembroke",
334 |                         "Cardigan",
335 |                         "Toy Poodle",
336 |                         "Miniature Poodle",
337 |                         "Standard Poodle",
338 |                         "Mexican Hairless",
339 |                         "Dingo",
340 |                         "Dhole",
341 |                         "African Hunting Dog"]
342 | 
343 |     def __len__(self):
344 |         return len(self._flat_breed_images)
345 | 
346 |     def __getitem__(self, index):
347 |         """
348 |         Args:
349 |             index (int): Index
350 |         Returns:
351 |             tuple: (image, target) where target is the index of the target breed class.
352 |         """
353 |         image_name, target_class = self._flat_breed_images[index]
354 |         image_path = join(self.images_folder, image_name)
355 |         image = Image.open(image_path).convert('RGB')
356 | 
357 |         if self.cropped:
358 |             image = image.crop(self._flat_breed_annotations[index][1])
359 | 
360 |         if self.transform:
361 |             image = self.transform(image)
362 | 
363 |         if self.target_transform:
364 |             target_class = self.target_transform(target_class)
365 | 
366 |         return image, target_class
367 | 
368 |     def download(self):
369 |         import tarfile
370 | 
371 |         if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
372 |             if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
373 |                 print('Files already downloaded and verified')
374 |                 return
375 | 
376 |         for filename in ['images', 'annotation', 'lists']:
377 |             tar_filename = filename + '.tar'
378 |             url = self.download_url_prefix + '/' + tar_filename
379 |             download_url(url, self.root, tar_filename, None)
380 |             print('Extracting downloaded file: ' + join(self.root, tar_filename))
381 |             with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
382 |                 tar_file.extractall(self.root)
383 |             os.remove(join(self.root, tar_filename))
384 | 
385 |     @staticmethod
386 |     def get_boxes(path):
387 |         import xml.etree.ElementTree
388 |         e = xml.etree.ElementTree.parse(path).getroot()
389 |         boxes = []
390 |         for objs in e.iter('object'):
391 |             boxes.append([int(objs.find('bndbox').find('xmin').text),
392 |                           int(objs.find('bndbox').find('ymin').text),
393 |                           int(objs.find('bndbox').find('xmax').text),
394 |                           int(objs.find('bndbox').find('ymax').text)])
395 |         return boxes
396 | 
397 |     def load_split(self):
398 |         if self.train:
399 |             split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
400 |             labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
401 |         else:
402 |             split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
403 |             labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']
404 | 
405 |         split = [item[0][0] for item in split]
406 |         labels = [item[0]-1 for item in labels]
407 |         return list(zip(split, labels))
408 | 
409 |     def stats(self):
410 |         counts = {}
411 |         for index in range(len(self._flat_breed_images)):
412 |             image_name, target_class = self._flat_breed_images[index]
413 |             if target_class not in counts:
414 |                 counts[target_class] = 1
415 |             else:
416 |                 counts[target_class] += 1
417 | 
418 |         print("%d samples spanning %d classes (avg %f per class)"%(len(self._flat_breed_images), len(counts), float(len(self._flat_breed_images))/float(len(counts))))
419 | 
420 |         return counts
421 | 
422 | 
423 | class NABirds(Dataset):
424 |     """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.
425 | 
426 |     Args:
427 |         root (string): Root directory of the dataset.
428 |         train (bool, optional): If True, creates dataset from training set, otherwise
429 |             creates from test set.
430 |         transform (callable, optional): A function/transform that takes in a PIL image
431 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
432 |     Note:
433 |         Unlike the other classes in this file, ``NABirds`` accepts no
434 |         ``target_transform`` or ``download`` argument; the data is expected
435 |         to sit under ``root/nabirds``, downloaded manually from the
436 |         official website.
437 |     """
438 |     base_folder = 'nabirds/images'
439 | 
440 |     def __init__(self, root, train=True, transform=None):
441 |         dataset_path = os.path.join(root, 'nabirds')
442 |         self.root = root
443 |         self.loader = default_loader
444 |         self.train = train
445 |         self.transform = transform
446 | 
447 |         image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
448 |                                   sep=' ', names=['img_id', 'filepath'])
449 |         image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
450 |                                          sep=' ', names=['img_id', 'target'])
451 |         # Since the raw labels are non-continuous, map them to new ones
452 |         self.label_map = get_continuous_class_map(image_class_labels['target'])
453 |         train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
454 |                                        sep=' ', names=['img_id', 'is_training_img'])
455 |         data = image_paths.merge(image_class_labels, on='img_id')
456 |         self.data = data.merge(train_test_split, on='img_id')
457 |         # Load in the train / test split
458 |         if self.train:
459 |             self.data = self.data[self.data.is_training_img == 1]
460 |         else:
461 |             self.data = self.data[self.data.is_training_img == 0]
462 | 
463 |         # Load in the class data
464 |         self.class_names = load_class_names(dataset_path)
465 |         self.class_hierarchy = load_hierarchy(dataset_path)
466 | 
467 |     def __len__(self):
468 |         return len(self.data)
469 | 
470 |     def __getitem__(self, idx):
471 |         sample = self.data.iloc[idx]
472 |         path = os.path.join(self.root, self.base_folder, sample.filepath)
473 |         target = self.label_map[sample.target]
474 |         img = self.loader(path)
475 | 
476 |         if self.transform is not None:
477 |             img = self.transform(img)
478 |         return img, target
479 | 
480 | 
481 | def get_continuous_class_map(class_labels):
482 |     label_set = set(class_labels)
483 |     return {k: i for i, k in enumerate(sorted(label_set))}  # sorted for a deterministic mapping across runs/processes
484 | 
485 | 
486 | def load_class_names(dataset_path=''):
487 |     names = {}
488 | 
489 |     with open(os.path.join(dataset_path, 'classes.txt')) as f:
490 |         for line in f:
491 |             pieces = line.strip().split()
492 |             class_id = pieces[0]
493 |             names[class_id] = ' '.join(pieces[1:])
494 | 
495 |     return names
496 | 
497 | 
498 | def load_hierarchy(dataset_path=''):
499 |     parents = {}
500 | 
501 |     with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
502 |         for line in f:
503 |             pieces = line.strip().split()
504 |             child_id, parent_id = pieces
505 |             parents[child_id] = parent_id
506 | 
507 |     return parents
508 | 
509 | 
510 | class INat2017(VisionDataset):
511 |     """`iNaturalist 2017 <https://github.com/visipedia/inat_comp>`_ Dataset.
512 |     Args:
513 |         root (string): Root directory of the dataset.
514 |         split (string, optional): The dataset split, supports ``train``, or ``val``.
515 |         transform (callable, optional): A function/transform that takes in a PIL image
516 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
517 |         target_transform (callable, optional): A function/transform that takes in the
518 |             target and transforms it.
519 |         download (bool, optional): If true, downloads the dataset from the internet and
520 |             puts it in the root directory. If the dataset is already downloaded, it is not
521 |             downloaded again.
522 |     """
523 |     base_folder = 'train_val_images/'
524 |     file_list = {
525 |         'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
526 |                  'train_val_images.tar.gz',
527 |                  '7c784ea5e424efaec655bd392f87301f'),
528 |         'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
529 |                   'train_val2017.zip',
530 |                   '444c835f6459867ad69fcb36478786e7')
531 |     }
532 | 
533 |     def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
534 |         super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
535 |         self.loader = default_loader
536 |         self.split = verify_str_arg(split, "split", ("train", "val",))
537 | 
538 |         if self._check_exists():
539 |             print('Files already downloaded and verified.')
540 |         elif download:
541 |             if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
542 |                     and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
543 |                 print('Downloading...')
544 |                 self._download()
545 |             print('Extracting...')
546 |             extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
547 |             extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
548 |         else:
549 |             raise RuntimeError(
550 |                 'Dataset not found. You can use download=True to download it.')
551 |         anno_filename = split + '2017.json'
552 |         with open(os.path.join(self.root, anno_filename), 'r') as fp:
553 |             all_annos = json.load(fp)
554 | 
555 |         self.annos = all_annos['annotations']
556 |         self.images = all_annos['images']
557 | 
558 |     def __getitem__(self, index):
559 |         path = os.path.join(self.root, self.images[index]['file_name'])
560 |         target = self.annos[index]['category_id']
561 | 
562 |         image = self.loader(path)
563 |         if self.transform is not None:
564 |             image = self.transform(image)
565 |         if self.target_transform is not None:
566 |             target = self.target_transform(target)
567 | 
568 |         return image, target
569 | 
570 |     def __len__(self):
571 |         return len(self.images)
572 | 
573 |     def _check_exists(self):
574 |         return os.path.exists(os.path.join(self.root, self.base_folder))
575 | 
576 |     def _download(self):
577 |         for url, filename, md5 in self.file_list.values():
578 |             download_url(url, root=self.root, filename=filename)
579 |             if not check_integrity(os.path.join(self.root, filename), md5):
580 |                 raise RuntimeError("File not found or corrupted.")
581 | 
--------------------------------------------------------------------------------
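Every dataset class in `utils/dataset.py` yields `(image, target)` pairs, so a quick smoke test only needs a matching transform. A sketch (not part of the repository; the path is a placeholder) reusing the evaluation transform from `utils/data_utils.py`:

```
from PIL import Image
from torchvision import transforms
from utils.dataset import CUB

test_transform = transforms.Compose([
    transforms.Resize((600, 600), Image.BILINEAR),
    transforms.CenterCrop((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Note: CUB decodes the whole split at construction time, so this takes a while.
testset = CUB(root='/data/CUB_200_2011', is_train=False, transform=test_transform)
img, target = testset[0]
print(img.shape, target)  # torch.Size([3, 448, 448]) and an int in [0, 200)
```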