├── LICENSE
├── README.md
├── classifier.py
├── images
│   └── fig1.jpg
├── main.py
├── utils
│   ├── __init__.py
│   ├── analysis
│   │   ├── __init__.py
│   │   ├── a_distance.py
│   │   └── tsne.py
│   ├── data.py
│   ├── logger.py
│   ├── meter.py
│   ├── metric
│   │   └── __init__.py
│   └── scheduler.py
└── vision
    ├── __init__.py
    ├── datasets
    │   ├── __init__.py
    │   ├── _util.py
    │   ├── cub.py
    │   ├── cub200.py
    │   ├── domainnet.py
    │   ├── imagelist.py
    │   ├── office31.py
    │   └── officehome.py
    ├── models
    │   ├── __init__.py
    │   ├── digits.py
    │   ├── ibn.py
    │   └── resnet.py
    └── transforms
        └── __init__.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Vision and Learning Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [A Broad Study of Pre-training for Domain Generalization and Adaptation (ECCV 2022)](https://arxiv.org/pdf/2203.11819.pdf) 2 | [Donghyun Kim](http://cs-people.bu.edu/donhk/), [Kaihong Wang](https://cs-people.bu.edu/kaiwkh/), [Stan Sclaroff](https://www.cs.bu.edu/fac/sclaroff/), and [Kate Saenko](http://ai.bu.edu/ksaenko.html). 3 | #### [[Project Page]]() [[Paper]](https://arxiv.org/pdf/2203.11819.pdf) 4 | ![Overview](images/fig1.jpg) 5 | 6 | 7 | 8 | ## Introduction 9 | 10 | While domain transfer methods (e.g., domain adaptation, domain generalization) have been 11 | proposed to learn transferable representations across domains, they are 12 | typically applied to ResNet backbones pre-trained on ImageNet. Thus, 13 | existing works pay little attention to the effects of pre-training on domain 14 | transfer tasks. In this paper, we provide a broad study and in-depth analysis of pre-training for domain adaptation and generalization, namely: 15 | network architectures, size, pre-training loss, and datasets. This repository contains a PyTorch implementation of the single-source domain generalization experiments, which can be used as a baseline for domain transfer tasks including domain generalization and adaptation. This implementation is based on [Transfer Learning Library](https://github.com/thuml/Transfer-Learning-Library). 
16 | 17 | **Bibtex** 18 | ``` 19 | @InProceedings{kim2022unified, 20 | title={A Broad Study of Pre-training for Domain Generalization and Adaptation}, 21 | author={Kim, Donghyun and Wang, Kaihong and Sclaroff, Stan and Saenko, Kate}, 22 | booktitle = {The European Conference on Computer Vision (ECCV)}, 23 | year = {2022} 24 | } 25 | ``` 26 | 27 | 28 | 29 | # Usage 30 | 31 | Our code is based on the implementation of [Transfer Learning Library](https://github.com/thuml/Transfer-Learning-Library/tree/master/examples/domain_generalization/image_classification) and [timm](https://github.com/rwightman/pytorch-image-models/). Data will be downloaded automatically, except for the WILDS datasets. WILDS and timm can be installed using pip: 32 | ``` 33 | pip install wilds 34 | ``` 35 | ``` 36 | pip install timm 37 | ``` 38 | 39 | # Baseline training example 40 | ### ConvNeXt 41 | ``` 42 | python main.py data_scc/office-home -d OfficeHome -s Rw -t Cl Ar Pr -a convnext_base_in22ft1k --seed 0 --log logs/baseline/ 43 | 44 | python main.py data_scc/domainnet -d DomainNet -s r -t i p q c s -a convnext_base_in22ft1k --seed 0 --log logs/baseline_domainnet/ 45 | ``` 46 | ### Swin Transformers 47 | ``` 48 | python main.py data_scc/office-home -d OfficeHome -s Rw -t Cl Ar Pr -a swin_base_patch4_window7_224 --seed 0 --log logs/baseline/ 49 | 50 | python main.py data_scc/domainnet -d DomainNet -s r -t c i p q s -a swin_base_patch4_window7_224 --seed 0 --log logs/baseline_domainnet/ 51 | ``` -------------------------------------------------------------------------------- /classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Baixu Chen 3 | @contact: cbx_99_hasta@outlook.com 4 | """ 5 | from typing import Optional, Tuple 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Classifier(nn.Module): 11 | """A generic Classifier class for domain adaptation. 12 | 13 | Args: 14 | backbone (torch.nn.Module): Any backbone to extract 2-d features from data 15 | num_classes (int): Number of classes 16 | bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default 17 | bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1 18 | head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default 19 | finetune (bool): Whether to finetune the classifier or train it from scratch. Default: True 20 | 21 | .. note:: 22 | Different classifiers are used in different domain adaptation algorithms to achieve better accuracy, 23 | and we provide a suggested `Classifier` for each algorithm. 24 | Remember that they are not the core of the algorithms. You can implement your own `Classifier` and combine it with 25 | the domain adaptation algorithm in this algorithm library. 26 | 27 | .. note:: 28 | The learning rate of this classifier is set to 10 times that of the feature extractor by default for better 29 | accuracy. If you have other optimization strategies, please override :meth:`~Classifier.get_parameters`. 
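    Example (a minimal usage sketch; the ``resnet50`` backbone, the 65-way head, and the `images` tensor are illustrative assumptions, not fixed by this class)::

        >>> backbone = utils.get_model('resnet50')  # any backbone exposing `out_features`
        >>> classifier = Classifier(backbone, num_classes=65)
        >>> predictions, features = classifier(images)  # training mode returns both
        >>> classifier.eval()
        >>> predictions = classifier(images)  # eval mode returns predictions only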
30 | 31 | Inputs: 32 | - x (tensor): input data fed to `backbone` 33 | 34 | Outputs: 35 | - predictions: classifier's predictions 36 | - features: features after `bottleneck` layer and before `head` layer 37 | 38 | Shape: 39 | - Inputs: (minibatch, *) where * means any number of additional dimensions 40 | - predictions: (minibatch, `num_classes`) 41 | - features: (minibatch, `features_dim`) 42 | 43 | """ 44 | 45 | def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None, 46 | bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True, pool_layer=None): 47 | super(Classifier, self).__init__() 48 | self.backbone = backbone 49 | self.num_classes = num_classes 50 | if pool_layer is None: 51 | self.pool_layer = nn.Sequential( 52 | nn.AdaptiveAvgPool2d(output_size=(1, 1)), 53 | nn.Flatten() 54 | ) 55 | else: 56 | self.pool_layer = pool_layer 57 | if bottleneck is None: 58 | self.bottleneck = nn.Identity() 59 | self._features_dim = backbone.out_features 60 | else: 61 | self.bottleneck = bottleneck 62 | assert bottleneck_dim > 0 63 | self._features_dim = bottleneck_dim 64 | 65 | if head is None: 66 | self.head = nn.Linear(self._features_dim, num_classes) 67 | else: 68 | self.head = head 69 | self.finetune = finetune 70 | 71 | @property 72 | def features_dim(self) -> int: 73 | """The dimension of features before the final `head` layer""" 74 | return self._features_dim 75 | 76 | def forward(self, x: torch.Tensor): 77 | """""" 78 | f = self.backbone(x) 79 | if len(f.shape) > 2: 80 | f = self.pool_layer(f) 81 | 82 | f = self.bottleneck(f) 83 | predictions = self.head(f) 84 | if self.training: 85 | return predictions, f 86 | else: 87 | return predictions 88 | 89 | def get_parameters(self, base_lr=1.0): 90 | """A parameter list which decides optimization hyper-parameters, 91 | such as the relative learning rate of each layer 92 | """ 93 | params = [ 94 | {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, 95 | {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, 96 | {"params": self.head.parameters(), "lr": 1.0 * base_lr}, 97 | ] 98 | 99 | return params 100 | 101 | def get_head_parameters(self, base_lr=1.0): 102 | """A parameter list which decides optimization hyper-parameters; 103 | only the `head` parameters are returned 104 | """ 105 | params = [ 106 | # {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, 107 | # {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, 108 | {"params": self.head.parameters(), "lr": 1.0 * base_lr}, 109 | ] 110 | 111 | return params 112 | 113 | 114 | 115 | class ImageClassifier(Classifier): 116 | """ImageClassifier specific for reproducing results of `DomainBed <https://github.com/facebookresearch/DomainBed>`_. 117 | You are free to freeze all `BatchNorm2d` layers and insert one additional `Dropout` layer; this can achieve better 118 | results for some datasets like PACS but may be worse for others. 119 | 120 | Args: 121 | backbone (torch.nn.Module): Any backbone to extract features from data 122 | num_classes (int): Number of classes 123 | freeze_bn (bool, optional): whether to freeze all `BatchNorm2d` layers. Default: False 124 | dropout_p (float, optional): dropout ratio for the additional `Dropout` layer; this layer is only used when `freeze_bn` is True. 
Default: 0.1 125 | """ 126 | 127 | def __init__(self, backbone: nn.Module, num_classes: int, freeze_bn: Optional[bool] = False, 128 | dropout_p: Optional[float] = 0.1, **kwargs): 129 | super(ImageClassifier, self).__init__(backbone, num_classes, **kwargs) 130 | self.freeze_bn = freeze_bn 131 | if freeze_bn: 132 | self.feature_dropout = nn.Dropout(p=dropout_p) 133 | 134 | def forward(self, x: torch.Tensor): 135 | f = self.backbone(x) 136 | if len(f.shape) > 2: 137 | f = self.pool_layer(f) 138 | 139 | f = self.bottleneck(f) 140 | if self.freeze_bn: 141 | f = self.feature_dropout(f) 142 | predictions = self.head(f) 143 | if self.training: 144 | return predictions, f 145 | else: 146 | return predictions 147 | 148 | def train(self, mode=True): 149 | super(ImageClassifier, self).train(mode) 150 | if self.freeze_bn: 151 | for m in self.modules(): 152 | if isinstance(m, nn.BatchNorm2d): 153 | m.eval() 154 | -------------------------------------------------------------------------------- /images/fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VisionLearningGroup/Benchmark_Domain_Transfer/b631decfa33cf5ec3c279487e2f97b048058faf4/images/fig1.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import random 3 | import time 4 | import warnings 5 | import sys 6 | import argparse 7 | import shutil 8 | import os.path as osp 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.backends.cudnn as cudnn 13 | from torch.optim import SGD, Adam 14 | from torch.optim.lr_scheduler import CosineAnnealingLR 15 | from torch.utils.data import DataLoader 16 | import torch.nn.functional as F 17 | 18 | sys.path.insert(0, './') 19 | print(sys.path) 20 | from classifier import ImageClassifier as Classifier 21 | from utils.data import ForeverDataIterator 22 | from utils.metric import accuracy 23 | from utils.meter import AverageMeter, ProgressMeter 24 | from utils.logger import CompleteLogger 25 | from utils.analysis import tsne, a_distance 26 | import utils 27 | 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | cudnn.benchmark = True 30 | 31 | def main(args: argparse.Namespace): 32 | 33 | log_name = args.log + args.data + '_' + args.arch + '_src_' + '_'.join(args.sources) 34 | 35 | logger = CompleteLogger(log_name, args.phase) 36 | logger.write(' '.join(f'{k}={v}' for k, v in vars(args).items())) 37 | 38 | print(args) 39 | 40 | if args.data == 'DomainNet': 41 | test_iter = 4 42 | else: 43 | test_iter = 2 44 | 45 | logger.write('gpu count: {}'.format(torch.cuda.device_count())) 46 | if torch.cuda.device_count() >= 1: 47 | logger.write('gpu name: {}'.format(torch.cuda.get_device_name(0))) 48 | 49 | # Data loading code 50 | train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, 51 | random_color_jitter=True, random_gray_scale=True) 52 | val_transform = utils.get_val_transform(args.val_resizing) 53 | logger.write("train_transform: {}".format(train_transform)) 54 | logger.write("val_transform: {}".format(val_transform)) 55 | 56 | train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, 57 | split='train', download=True, transform=train_transform, 58 | seed=args.seed) 59 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, 60 | shuffle=True, num_workers=args.workers, 
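                              # drop the final incomplete batch so every training step sees a full mini-batch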
drop_last=True) 61 | val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val', 62 | download=True, transform=val_transform, seed=args.seed) 63 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 64 | test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test', 65 | download=True, transform=val_transform, seed=args.seed) 66 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 67 | 68 | logger.write("train_dataset_size: {}".format(len(train_dataset))) 69 | logger.write('val_dataset_size: {}'.format(len(val_dataset))) 70 | logger.write("test_dataset_size: {}".format(len(test_dataset))) 71 | train_iter = ForeverDataIterator(train_loader) 72 | 73 | # create model 74 | logger.write("=> using pre-trained model '{}'".format(args.arch)) 75 | 76 | test_val_acc1 = 0. 77 | global_best_val_acc1 = 0. 78 | 79 | for opt in ['sgd', 'adam']: 80 | 81 | if opt == 'sgd': 82 | lr_list = [1e-2, 1e-3, 1e-3] 83 | if opt == 'adam': 84 | lr_list = [1e-3, 1e-4, 1e-5] 85 | 86 | for lr in lr_list: 87 | args.lr = lr 88 | 89 | for seed in [0]: 90 | args.seed = seed 91 | if args.seed is not None: 92 | random.seed(args.seed) 93 | torch.manual_seed(args.seed) 94 | 95 | logger.write('opt:{} lr:{} seed:{}'.format(opt, lr, seed)) 96 | model_name = args.arch 97 | 98 | if model_name == 'dino_vits16': 99 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16') 100 | backbone.out_features = 384 101 | elif model_name == 'dino_vitb16': 102 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') 103 | backbone.out_features = 768 104 | elif model_name == 'dino_resnet50': 105 | backbone = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50') 106 | backbone.out_features = 2048 107 | 108 | else: 109 | backbone = utils.get_model(args.arch) 110 | 111 | 112 | pool_layer = nn.Identity() if args.no_pool else None 113 | classifier = Classifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p, 114 | finetune=args.finetune, pool_layer=pool_layer).to(device) 115 | 116 | # define optimizer and lr scheduler 117 | if opt == 'sgd': 118 | optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd, 119 | nesterov=True) 120 | else: 121 | optimizer = Adam(classifier.get_parameters(base_lr=args.lr), args.lr) 122 | 123 | lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch) 124 | 125 | # resume from the best checkpoint 126 | if args.phase != 'train': 127 | checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') 128 | classifier.load_state_dict(checkpoint) 129 | 130 | 131 | # analyze the model 132 | if args.phase == 'analysis': 133 | # extract features from both domains 134 | feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) 135 | source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100) 136 | target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100) 137 | print(len(source_feature), len(target_feature)) 138 | # plot t-SNE 139 | tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') 140 | tsne.visualize(source_feature, target_feature, tSNE_filename) 141 | logger.write("Saving t-SNE to {}".format(tSNE_filename)) 142 | # calculate 
A-distance, which is a measure for distribution discrepancy 143 | A_distance = a_distance.calculate(source_feature, target_feature, device) 144 | logger.write("A-distance = {}".format(A_distance)) 145 | return 146 | 147 | if args.phase == 'test': 148 | acc_list, acc1 = utils.validate_each_domain(test_loader, classifier, args, device, logger) 149 | return 150 | 151 | 152 | best_val_acc1 = 0. 153 | best_test_acc1 = 0. 154 | for epoch in range(args.epochs): 155 | print(lr_scheduler.get_last_lr()) 156 | # train for one epoch 157 | train(train_iter, classifier, optimizer, lr_scheduler, epoch, args) 158 | 159 | # evaluate on validation set 160 | print("Evaluate on validation set...") 161 | 162 | if epoch % test_iter == 1: 163 | acc1 = utils.validate(val_loader, classifier, args, device) 164 | acc1_list, test_acc = utils.validate_domains(test_loader, classifier, args, device, logger) 165 | 166 | logger.write("Evaluate on test set...") 167 | if acc1 > best_val_acc1: 168 | best_val_acc1 = max(acc1, best_val_acc1) 169 | 170 | 171 | if acc1 > global_best_val_acc1: 172 | global_best_val_acc1 = acc1 173 | test_val_acc1 = test_acc 174 | test_val_acc1_list = acc1_list 175 | torch.save({"state_dict": classifier.state_dict(), "lr":args.lr, "opt":opt}, logger.get_checkpoint_path('global_best')) 176 | 177 | best_test_acc1 = max(best_test_acc1, test_acc) 178 | 179 | 180 | 181 | max_acc = test_val_acc1 182 | 183 | logger.write("{}".format(args.arch)) 184 | logger.write('{:.2f}'.format(max_acc)) 185 | 186 | str_list = "" 187 | for target, acc in zip(args.targets, test_val_acc1_list): 188 | str_list += "{}: {} ".format(target, '{:.2f}'.format(acc)) 189 | 190 | logger.write(str_list) 191 | logger.write('Source: {} acc: {:.2f}'.format(args.sources, global_best_val_acc1)) 192 | 193 | logger.close() 194 | 195 | 196 | 197 | def train(train_iter: ForeverDataIterator, model: Classifier, optimizer, 198 | lr_scheduler: CosineAnnealingLR, epoch: int, args: argparse.Namespace): 199 | batch_time = AverageMeter('Time', ':4.2f') 200 | data_time = AverageMeter('Data', ':3.1f') 201 | losses = AverageMeter('Loss', ':3.2f') 202 | cls_accs = AverageMeter('Cls Acc', ':3.1f') 203 | 204 | progress = ProgressMeter( 205 | args.iters_per_epoch, 206 | [batch_time, data_time, losses, cls_accs], 207 | prefix="Epoch: [{}]".format(epoch)) 208 | 209 | # switch to train mode 210 | model.train() 211 | 212 | end = time.time() 213 | for i in range(args.iters_per_epoch): 214 | x, labels, _ = next(train_iter) 215 | x = x.to(device) 216 | labels = labels.to(device) 217 | 218 | # measure data loading time 219 | data_time.update(time.time() - end) 220 | 221 | # compute output 222 | y, _ = model(x) 223 | 224 | loss = F.cross_entropy(y, labels) 225 | 226 | cls_acc = accuracy(y, labels)[0] 227 | losses.update(loss.item(), x.size(0)) 228 | cls_accs.update(cls_acc.item(), x.size(0)) 229 | 230 | # compute gradient and do SGD step 231 | optimizer.zero_grad() 232 | loss.backward() 233 | optimizer.step() 234 | lr_scheduler.step() 235 | 236 | # measure elapsed time 237 | batch_time.update(time.time() - end) 238 | end = time.time() 239 | 240 | if i % args.print_freq == 0: 241 | progress.display(i) 242 | 243 | 244 | if __name__ == '__main__': 245 | parser = argparse.ArgumentParser(description='Baseline for Domain Generalization') 246 | # dataset parameters 247 | parser.add_argument('root', metavar='DIR', 248 | help='root path of dataset') 249 | parser.add_argument('-d', '--data', metavar='DATA', default='PACS') 250 | # help='dataset: ' + ' | 
'.join(utils.get_dataset_names()) + 251 | # ' (default: PACS)') 252 | parser.add_argument('-s', '--sources', nargs='+', default=None, 253 | help='source domain(s)') 254 | parser.add_argument('-t', '--targets', nargs='+', default=None, 255 | help='target domain(s)') 256 | parser.add_argument('--train-resizing', type=str, default='default') 257 | parser.add_argument('--val-resizing', type=str, default='default') 258 | # model parameters 259 | parser.add_argument('-a', '--arch', metavar='ARCH', default='deit_small_patch16_224') 260 | # choices=utils.get_model_names(), 261 | # help='backbone architecture: ' + 262 | # ' | '.join(utils.get_model_names()) + 263 | # ' (default: resnet50)') 264 | # deit_small_patch16_224 265 | parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') 266 | parser.add_argument('--finetune', default=True, action='store_true', help='whether to use a 10x smaller lr for the backbone') 267 | parser.add_argument('--freeze-bn', action='store_true', help='whether to freeze all bn layers') 268 | parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True') 269 | # training parameters 270 | parser.add_argument('-b', '--batch-size', default=64, type=int, 271 | metavar='N', 272 | help='mini-batch size (default: 64)') 273 | parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, 274 | metavar='LR', help='initial learning rate', dest='lr') 275 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 276 | help='momentum') 277 | parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, 278 | metavar='W', help='weight decay (default: 5e-4)') 279 | parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 280 | help='number of data loading workers (default: 2)') 281 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 282 | help='number of total epochs to run') 283 | parser.add_argument('-i', '--iters-per-epoch', default=250, type=int, 284 | help='Number of iterations per epoch') 285 | parser.add_argument('-p', '--print-freq', default=100, type=int, 286 | metavar='N', help='print frequency (default: 100)') 287 | parser.add_argument('--seed', default=0, type=int, 288 | help='seed for initializing training') 289 | parser.add_argument("--log", type=str, default='baseline', 290 | help="Where to save logs, checkpoints and debugging images.") 291 | parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], 292 | help="When phase is 'test', only test the model." 
293 | "When phase is 'analysis', only analyze the model.") 294 | args = parser.parse_args() 295 | main(args) 296 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import CompleteLogger 2 | from .meter import * 3 | from .data import ForeverDataIterator 4 | 5 | __all__ = ['metric', 'analysis', 'meter', 'data', 'logger'] 6 | 7 | import sys 8 | import time 9 | import timm 10 | import tqdm 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.transforms as T 14 | import torch.nn.functional as F 15 | import numpy as np 16 | from torch.utils.data.dataset import Subset, ConcatDataset 17 | import wilds 18 | from torch.utils.data import DataLoader 19 | import vision.datasets as datasets 20 | import vision.models as models 21 | from vision.transforms import ResizeImage 22 | from utils.metric import accuracy 23 | from utils.meter import AverageMeter, ProgressMeter 24 | 25 | 26 | def get_model_names(): 27 | return sorted( 28 | name for name in models.__dict__ 29 | if name.islower() and not name.startswith("__") 30 | and callable(models.__dict__[name]) 31 | ) + timm.list_models() 32 | 33 | 34 | def get_model(model_name): 35 | if model_name in models.__dict__: 36 | # load models from common.vision.models 37 | backbone = models.__dict__[model_name](pretrained=True) 38 | else: 39 | 40 | # load models from pytorch-image-models 41 | backbone = timm.create_model(model_name, pretrained=True) 42 | if 'resnetv2' in model_name: 43 | backbone.head.in_features = backbone.num_features 44 | try: 45 | backbone.out_features = backbone.get_classifier().in_features 46 | backbone.reset_classifier(0, '') 47 | except Exception: 48 | backbone.out_features = backbone.head.in_features 49 | backbone.head = nn.Identity() 50 | return backbone 51 | 52 | 53 | def get_dataset_names(): 54 | return sorted( 55 | name for name in datasets.__dict__ 56 | if not name.startswith("__") and callable(datasets.__dict__[name]) 57 | ) + wilds.supported_datasets 58 | 59 | 60 | class ConcatDatasetWithDomainLabel(ConcatDataset): 61 | """ConcatDataset with domain label""" 62 | 63 | def __init__(self, *args, **kwargs): 64 | super(ConcatDatasetWithDomainLabel, self).__init__(*args, **kwargs) 65 | self.index_to_domain_id = {} 66 | domain_id = 0 67 | start = 0 68 | for end in self.cumulative_sizes: 69 | for idx in range(start, end): 70 | self.index_to_domain_id[idx] = domain_id 71 | start = end 72 | domain_id += 1 73 | 74 | def __getitem__(self, index): 75 | img, target = super(ConcatDatasetWithDomainLabel, self).__getitem__(index) 76 | domain_id = self.index_to_domain_id[index] 77 | return img, target, domain_id 78 | 79 | 80 | def convert_from_wilds_dataset(dataset_name, wild_dataset): 81 | metadata_array = wild_dataset.metadata_array 82 | sample_idxes_per_domain = {} 83 | for idx, metadata in enumerate(metadata_array): 84 | if dataset_name == 'iwildcam': 85 | # In iwildcam dataset, domain id is specified by location 86 | domain = metadata[0].item() 87 | elif dataset_name == 'camelyon17': 88 | # In camelyon17 dataset, domain id is specified by hospital 89 | domain = metadata[0].item() 90 | elif dataset_name == 'fmow': 91 | # In fmow dataset, domain id is specified by (region, year) tuple 92 | domain = (metadata[0].item(), metadata[1].item()) 93 | 94 | if domain not in sample_idxes_per_domain: 95 | sample_idxes_per_domain[domain] = [] 96 | sample_idxes_per_domain[domain].append(idx) 97 | 98 | 
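    # minimal wrapper that drops the WILDS metadata field, so each sample is just (x, y)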
class Dataset: 99 | def __init__(self): 100 | self.dataset = wild_dataset 101 | 102 | def __getitem__(self, idx): 103 | x, y, metadata = self.dataset[idx] 104 | return x, y 105 | 106 | def __len__(self): 107 | return len(self.dataset) 108 | 109 | dataset = Dataset() 110 | concat_dataset = ConcatDatasetWithDomainLabel( 111 | [Subset(dataset, sample_idxes_per_domain[domain]) for domain in sample_idxes_per_domain]) 112 | return concat_dataset 113 | 114 | 115 | def get_dataset(dataset_name, root, task_list, split='train', download=True, transform=None, seed=0, split_ratio=0.8): 116 | assert split in ['train', 'val', 'test'] 117 | if dataset_name in datasets.__dict__: 118 | # load datasets from common.vision.datasets 119 | # currently only PACS, OfficeHome, DomainNet and CUB are supported 120 | supported_dataset = ['PACS', 'OfficeHome', 'DomainNet', 'CUB'] 121 | assert dataset_name in supported_dataset 122 | 123 | dataset = datasets.__dict__[dataset_name] 124 | 125 | train_split_list = [] 126 | val_split_list = [] 127 | test_split_list = [] 128 | # following DomainBed, we randomly split each dataset into two parts with 80% and 20% of the samples; 129 | # the former (larger) part is used as the training set and the latter as the validation set 130 | 131 | num_classes = 0 132 | 133 | # under domain generalization setting, we use all samples in target domain as test set 134 | for task in task_list: 135 | if dataset_name == 'PACS': 136 | all_split = dataset(root=root, task=task, split='all', download=download, transform=transform) 137 | num_classes = all_split.num_classes 138 | elif dataset_name == 'OfficeHome': 139 | all_split = dataset(root=root, task=task, download=download, transform=transform) 140 | num_classes = all_split.num_classes 141 | elif dataset_name == 'CUB': 142 | all_split = dataset(root=root, task=task, download=download, transform=transform) 143 | num_classes = all_split.num_classes 144 | elif dataset_name == 'DomainNet': 145 | train_split = dataset(root=root, task=task, split='train', download=download, transform=transform) 146 | test_split = dataset(root=root, task=task, split='test', download=download, transform=transform) 147 | num_classes = train_split.num_classes 148 | all_split = ConcatDataset([train_split, test_split]) 149 | 150 | train_split, val_split = split_dataset(all_split, int(len(all_split) * split_ratio), seed) 151 | 152 | train_split_list.append(train_split) 153 | val_split_list.append(val_split) 154 | test_split_list.append(all_split) 155 | 156 | train_dataset = ConcatDatasetWithDomainLabel(train_split_list) 157 | val_dataset = ConcatDatasetWithDomainLabel(val_split_list) 158 | test_dataset = ConcatDatasetWithDomainLabel(test_split_list) 159 | 160 | dataset_dict = { 161 | 'train': train_dataset, 162 | 'val': val_dataset, 163 | 'test': test_dataset 164 | } 165 | return dataset_dict[split], num_classes 166 | else: 167 | # load datasets from wilds 168 | # currently only iwildcam, camelyon17 and fmow are supported 169 | supported_dataset = ['iwildcam', 'camelyon17', 'fmow'] 170 | assert dataset_name in supported_dataset 171 | 172 | dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) 173 | num_classes = dataset.n_classes 174 | return convert_from_wilds_dataset(dataset_name, 175 | dataset.get_subset(split=split, transform=transform)), num_classes 176 | 177 | 178 | 179 | def get_dataset_class_list(dataset_name, root, task_list, split='train', download=True, transform=None, seed=0, split_ratio=0.8): 
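    """Same as :meth:`get_dataset`, but additionally returns the dataset's class-name list
    (``None`` for WILDS datasets, whose class names are not exposed here)."""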
assert split in ['train', 'val', 'test'] 181 | if dataset_name in datasets.__dict__: 182 | # load datasets from common.vision.datasets 183 | # currently only PACS, OfficeHome, DomainNet and CUB are supported 184 | supported_dataset = ['PACS', 'OfficeHome', 'DomainNet', 'CUB'] 185 | assert dataset_name in supported_dataset 186 | 187 | dataset = datasets.__dict__[dataset_name] 188 | 189 | train_split_list = [] 190 | val_split_list = [] 191 | test_split_list = [] 192 | # following DomainBed, we randomly split each dataset into two parts with 80% and 20% of the samples; 193 | # the former (larger) part is used as the training set and the latter as the validation set 194 | 195 | num_classes = 0 196 | 197 | # under domain generalization setting, we use all samples in target domain as test set 198 | for task in task_list: 199 | if dataset_name == 'PACS': 200 | all_split = dataset(root=root, task=task, split='all', download=download, transform=transform) 201 | num_classes = all_split.num_classes 202 | class_list = all_split.CLASSES 203 | elif dataset_name == 'OfficeHome': 204 | all_split = dataset(root=root, task=task, download=download, transform=transform) 205 | num_classes = all_split.num_classes 206 | class_list = all_split.CLASSES 207 | elif dataset_name == 'CUB': 208 | all_split = dataset(root=root, task=task, download=download, transform=transform) 209 | num_classes = all_split.num_classes 210 | class_list = all_split.CLASSES 211 | elif dataset_name == 'DomainNet': 212 | train_split = dataset(root=root, task=task, split='train', download=download, transform=transform) 213 | test_split = dataset(root=root, task=task, split='test', download=download, transform=transform) 214 | num_classes = train_split.num_classes 215 | class_list = train_split.CLASSES 216 | all_split = ConcatDataset([train_split, test_split]) 217 | 218 | 219 | train_split, val_split = split_dataset(all_split, int(len(all_split) * split_ratio), seed) 220 | 221 | train_split_list.append(train_split) 222 | val_split_list.append(val_split) 223 | test_split_list.append(all_split) 224 | 225 | train_dataset = ConcatDatasetWithDomainLabel(train_split_list) 226 | val_dataset = ConcatDatasetWithDomainLabel(val_split_list) 227 | test_dataset = ConcatDatasetWithDomainLabel(test_split_list) 228 | 229 | dataset_dict = { 230 | 'train': train_dataset, 231 | 'val': val_dataset, 232 | 'test': test_dataset 233 | } 234 | return dataset_dict[split], num_classes, class_list 235 | else: 236 | # load datasets from wilds 237 | # currently only iwildcam, camelyon17 and fmow are supported 238 | supported_dataset = ['iwildcam', 'camelyon17', 'fmow'] 239 | assert dataset_name in supported_dataset 240 | 241 | dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) 242 | num_classes = dataset.n_classes 243 | # class_list = dataset.CLASSES 244 | class_list = None 245 | return convert_from_wilds_dataset(dataset_name, 246 | dataset.get_subset(split=split, transform=transform)), num_classes, class_list 247 | 248 | 249 | 250 | def get_dataset_class(dataset_name, root, task_list, split='train', download=True, transform=None, seed=0, split_ratio=0.8): 251 | assert split in ['train', 'val', 'test'] 252 | if dataset_name in datasets.__dict__: 253 | # load datasets from common.vision.datasets 254 | # currently only PACS, OfficeHome, DomainNet and CUB are supported 255 | supported_dataset = ['PACS', 'OfficeHome', 'DomainNet', 'CUB'] 256 | assert dataset_name in supported_dataset 257 | 258 | dataset = 
datasets.__dict__[dataset_name] 259 | 260 | train_split_list = [] 261 | val_split_list = [] 262 | test_split_list = [] 263 | # following DomainBed, we randomly split each dataset into two parts with 80% and 20% of the samples; 264 | # the former (larger) part is used as the training set and the latter as the validation set 265 | 266 | num_classes = 0 267 | 268 | # under domain generalization setting, we use all samples in target domain as test set 269 | for task in task_list: 270 | if dataset_name == 'PACS': 271 | all_split = dataset(root=root, task=task, split='all', download=download, transform=transform) 272 | num_classes = all_split.num_classes 273 | elif dataset_name == 'OfficeHome': 274 | all_split = dataset(root=root, task=task, download=download, transform=transform) 275 | num_classes = all_split.num_classes 276 | elif dataset_name == 'CUB': 277 | all_split = dataset(root=root, task=task, download=download, transform=transform) 278 | num_classes = all_split.num_classes 279 | elif dataset_name == 'DomainNet': 280 | train_split = dataset(root=root, task=task, split='train', download=download, transform=transform) 281 | test_split = dataset(root=root, task=task, split='test', download=download, transform=transform) 282 | num_classes = train_split.num_classes 283 | all_split = ConcatDataset([train_split, test_split]) 284 | 285 | train_split, val_split = split_dataset_class(all_split, int(len(all_split) * split_ratio), seed) 286 | 287 | train_split_list.append(train_split) 288 | val_split_list.append(val_split) 289 | test_split_list.append(all_split) 290 | 291 | train_dataset = ConcatDatasetWithDomainLabel(train_split_list) 292 | val_dataset = ConcatDatasetWithDomainLabel(val_split_list) 293 | test_dataset = ConcatDatasetWithDomainLabel(test_split_list) 294 | 295 | dataset_dict = { 296 | 'train': train_dataset, 297 | 'val': val_dataset, 298 | 'test': test_dataset 299 | } 300 | return dataset_dict[split], num_classes 301 | else: 302 | # load datasets from wilds 303 | # currently only iwildcam, camelyon17 and fmow are supported 304 | supported_dataset = ['iwildcam', 'camelyon17', 'fmow'] 305 | assert dataset_name in supported_dataset 306 | 307 | dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) 308 | num_classes = dataset.n_classes 309 | return convert_from_wilds_dataset(dataset_name, 310 | dataset.get_subset(split=split, transform=transform)), num_classes 311 | 312 | 313 | 314 | def split_dataset(dataset, n, seed=0): 315 | """ 316 | Return a pair of datasets corresponding to a random split of the given 317 | dataset, with n data points in the first dataset and the rest in the second, 318 | using the given random seed 319 | """ 320 | assert (n <= len(dataset)) 321 | idxes = list(range(len(dataset))) 322 | np.random.RandomState(seed).shuffle(idxes) 323 | subset_1 = idxes[:n] 324 | subset_2 = idxes[n:] 325 | return Subset(dataset, subset_1), Subset(dataset, subset_2) 326 | 327 | 328 | def split_dataset_class(dataset, n, seed=0): 329 | """ 330 | Return a pair of datasets corresponding to a random split of the given 331 | dataset, with n data points in the first dataset and the rest in the second, 332 | using the given random seed. This is currently identical to :meth:`split_dataset`. 333 | """ 334 | assert (n <= len(dataset)) 335 | idxes = list(range(len(dataset))) 336 | np.random.RandomState(seed).shuffle(idxes) 337 | subset_1 = idxes[:n] 338 | subset_2 = idxes[n:] 339 | return Subset(dataset, subset_1), Subset(dataset, subset_2) 340 | 341 | 342 | 343 | def validate(val_loader, 
model, args, device) -> float: 344 | batch_time = AverageMeter('Time', ':6.3f') 345 | losses = AverageMeter('Loss', ':.4e') 346 | top1 = AverageMeter('Acc@1', ':6.2f') 347 | progress = ProgressMeter( 348 | len(val_loader), 349 | [batch_time, losses, top1], 350 | prefix='Test: ') 351 | 352 | # switch to evaluate mode 353 | model.eval() 354 | 355 | with torch.no_grad(): 356 | end = time.time() 357 | for i, (images, target, domain_label) in enumerate(val_loader): 358 | 359 | images = images.to(device) 360 | target = target.to(device) 361 | 362 | # compute output 363 | output = model(images) 364 | loss = F.cross_entropy(output, target) 365 | 366 | # measure accuracy and record loss 367 | acc1 = accuracy(output, target)[0] 368 | losses.update(loss.item(), images.size(0)) 369 | top1.update(acc1.item(), images.size(0)) 370 | 371 | # measure elapsed time 372 | batch_time.update(time.time() - end) 373 | end = time.time() 374 | 375 | if i % args.print_freq == 0: 376 | progress.display(i) 377 | 378 | print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1)) 379 | 380 | return top1.avg 381 | 382 | 383 | 384 | def validate_domains(val_loader, model, args, device, logger): 385 | batch_time = AverageMeter('Time', ':6.3f') 386 | losses = AverageMeter('Loss', ':.4e') 387 | 388 | top1 = AverageMeter('Acc@1', ':6.2f') 389 | 390 | # switch to evaluate mode 391 | model.eval() 392 | 393 | num_domains = len(args.targets) 394 | 395 | top1_domains = [AverageMeter('Acc@1', ':6.2f') for i in range(num_domains)] 396 | acc_list = [] 397 | progress = ProgressMeter( 398 | len(val_loader), 399 | [batch_time, losses, top1], 400 | prefix='Test: ') 401 | 402 | with torch.no_grad(): 403 | end = time.time() 404 | for i, (images, target, domain_label) in enumerate(val_loader): 405 | 406 | images = images.to(device) 407 | target = target.to(device) 408 | 409 | # compute output 410 | output = model(images) 411 | loss = F.cross_entropy(output, target) 412 | 413 | # measure accuracy and record loss 414 | acc1 = accuracy(output, target)[0] 415 | losses.update(loss.item(), images.size(0)) 416 | top1.update(acc1.item(), target.size(0)) 417 | for d in range(num_domains): 418 | select = (domain_label == d) 419 | if select.sum() > 0: 420 | acc1_domains = accuracy(output[select], target[select])[0] 421 | top1_domains[d].update(acc1_domains.item(), target[select].size(0)) 422 | 423 | # measure elapsed time 424 | batch_time.update(time.time() - end) 425 | end = time.time() 426 | 427 | if i % args.print_freq == 0: 428 | progress.display(i) 429 | 430 | print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1)) 431 | 432 | for d in range(num_domains): 433 | logger.write("Domain: {domain}, Acc: {top1.avg:.3f}".format(domain=args.targets[d], top1=top1_domains[d])) 434 | acc_list.append(top1_domains[d].avg) 435 | 436 | return acc_list, top1.avg 437 | 438 | 439 | 440 | 441 | 442 | 443 | def validate_each_domain(val_loader, model, args, device, logger): 444 | 445 | 446 | # switch to evaluate mode 447 | model.eval() 448 | 449 | # test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) 450 | 451 | datasets = val_loader.dataset.datasets 452 | acc_list = [] 453 | top1_all = AverageMeter('Acc@1', ':6.2f') 454 | with torch.no_grad(): 455 | end = time.time() 456 | 457 | for dataset in datasets: 458 | 459 | batch_time = AverageMeter('Time', ':6.3f') 460 | losses = AverageMeter('Loss', ':.4e') 461 | top1 = AverageMeter('Acc@1', ':6.2f') 462 | 463 | 464 | domain_val_loader = DataLoader(dataset, 
batch_size=72, shuffle=False, num_workers=4) 465 | progress = ProgressMeter( 466 | len(domain_val_loader), 467 | [batch_time, losses, top1], 468 | prefix='Test: ') 469 | 470 | for i, (images, target) in enumerate(domain_val_loader): 471 | images = images.to(device) 472 | target = target.to(device) 473 | 474 | # compute output 475 | output = model(images) 476 | loss = F.cross_entropy(output, target) 477 | 478 | # measure accuracy and record loss 479 | acc1 = accuracy(output, target)[0] 480 | losses.update(loss.item(), images.size(0)) 481 | top1.update(acc1.item(), images.size(0)) 482 | top1_all.update(acc1.item(), images.size(0)) 483 | 484 | # measure elapsed time 485 | batch_time.update(time.time() - end) 486 | end = time.time() 487 | 488 | if i % args.print_freq == 0: 489 | progress.display(i) 490 | 491 | print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1)) 492 | logger.write("Domain: {domain}, Acc: {top1.avg:.3f}".format(domain=dataset.data_list_file, top1=top1)) 493 | acc_list.append(top1.avg) 494 | 495 | return acc_list, top1_all.avg 496 | 497 | 498 | 499 | def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=True, 500 | random_gray_scale=True): 501 | """ 502 | resizing mode: 503 | - default: random resized crop with scale factor (0.7, 1.0) and size 224; 504 | - cen.crop: take the center crop of 224; 505 | - res.|cen.crop: resize the image to 256 and take the center crop of size 224; 506 | - res: resize the image to 224; 507 | - res2x: resize the image to 448; 508 | - res.|crop: resize the image to 256 and take a random crop of size 224; 509 | - res.sma|crop: resize the image keeping its aspect ratio such that the 510 | smaller side is 256, then take a random crop of size 224; 511 | - inc.crop: "inception crop" from (Szegedy et al., 2015); 512 | - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224. 
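    Example (a minimal usage sketch; the mode string must be one of the options listed above)::

        >>> transform = get_train_transform('res.|crop', random_color_jitter=False)
        >>> # Resize(256) -> RandomCrop(224) -> RandomHorizontalFlip -> RandomGrayscale -> ToTensor -> Normalize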
513 | """ 514 | if resizing == 'default': 515 | transform = T.RandomResizedCrop(224, scale=(0.7, 1.0)) 516 | elif resizing == 'default_256': 517 | transform = T.RandomResizedCrop(256, scale=(0.7, 1.0)) 518 | elif resizing == 'cen.crop': 519 | transform = T.CenterCrop(224) 520 | elif resizing == 'res.|cen.crop': 521 | transform = T.Compose([ 522 | ResizeImage(256), 523 | T.CenterCrop(224) 524 | ]) 525 | elif resizing == 'res': 526 | transform = ResizeImage(224) 527 | elif 'res2x' in resizing: 528 | transform = ResizeImage(448) 529 | elif resizing == 'res.|crop': 530 | transform = T.Compose([ 531 | T.Resize((256, 256)), 532 | T.RandomCrop(224) 533 | ]) 534 | elif resizing == "res.sma|crop": 535 | transform = T.Compose([ 536 | T.Resize(256), 537 | T.RandomCrop(224) 538 | ]) 539 | elif resizing == 'inc.crop': 540 | transform = T.RandomResizedCrop(224) 541 | elif resizing == 'cif.crop': 542 | transform = T.Compose([ 543 | T.Resize((224, 224)), 544 | T.Pad(28), 545 | T.RandomCrop(224), 546 | ]) 547 | else: 548 | raise NotImplementedError(resizing) 549 | transforms = [transform] 550 | if random_horizontal_flip: 551 | transforms.append(T.RandomHorizontalFlip()) 552 | if random_color_jitter: 553 | transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)) 554 | if random_gray_scale: 555 | transforms.append(T.RandomGrayscale()) 556 | transforms.extend([ 557 | T.ToTensor(), 558 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 559 | ]) 560 | return T.Compose(transforms) 561 | 562 | 563 | def get_val_transform(resizing='default'): 564 | """ 565 | resizing mode: 566 | - default: resize the image to 224; 567 | - res2x: resize the image to 448; 568 | - res.|cen.crop: resize the image to 256 and take the center crop of size 224; 569 | """ 570 | if resizing == 'default': 571 | transform = ResizeImage(224) 572 | elif resizing == 'default_256': 573 | transform = ResizeImage(256) 574 | elif 'res2x' in resizing: 575 | transform = ResizeImage(448) 576 | elif resizing == 'res.|cen.crop': 577 | transform = T.Compose([ 578 | ResizeImage(256), 579 | T.CenterCrop(224), 580 | ]) 581 | else: 582 | raise NotImplementedError(resizing) 583 | return T.Compose([ 584 | transform, 585 | T.ToTensor(), 586 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 587 | ]) 588 | 589 | 590 | def collect_feature(data_loader, feature_extractor: nn.Module, device: torch.device, 591 | max_num_features=None, return_label=False) -> torch.Tensor: 592 | """ 593 | Fetch data from `data_loader`, and then use `feature_extractor` to collect features. This function is 594 | specific for domain generalization because each element in data_loader is a tuple 595 | (images, labels, domain_labels). 596 | 597 | Args: 598 | data_loader (torch.utils.data.DataLoader): Data loader. 599 | feature_extractor (torch.nn.Module): A feature extractor. 600 | device (torch.device) 601 | max_num_features (int): The max number of mini-batches to collect features from 602 | return_label (bool): if True, also return the concatenated labels. Default: False 603 | Returns: 604 | Features in shape (min(len(dataset), max_num_features * mini-batch size), :math:`|\mathcal{F}|`). 
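    Example (a sketch; `val_loader`, `feature_extractor` and `device` are assumed to exist)::

        >>> features, labels = collect_feature(val_loader, feature_extractor, device,
        ...                                    max_num_features=100, return_label=True)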
605 | """ 606 | feature_extractor.eval() 607 | all_features = [] 608 | labels = [] 609 | with torch.no_grad(): 610 | for i, (images, target, domain_labels) in enumerate(tqdm.tqdm(data_loader)): 611 | if max_num_features is not None and i >= max_num_features: 612 | break 613 | images = images.to(device) 614 | feature = feature_extractor(images).cpu() 615 | all_features.append(feature) 616 | labels.append(target) 617 | 618 | if return_label: 619 | return torch.cat(all_features, dim=0), torch.cat(labels, dim=0) 620 | else: 621 | return torch.cat(all_features, dim=0) 622 | 623 | 624 | import math, warnings 625 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 626 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 627 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 628 | def norm_cdf(x): 629 | # Computes standard normal cumulative distribution function 630 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 631 | 632 | if (mean < a - 2 * std) or (mean > b + 2 * std): 633 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 634 | "The distribution of values may be incorrect.", 635 | stacklevel=2) 636 | 637 | with torch.no_grad(): 638 | # Values are generated by using a truncated uniform distribution and 639 | # then using the inverse CDF for the normal distribution. 640 | # Get upper and lower cdf values 641 | l = norm_cdf((a - mean) / std) 642 | u = norm_cdf((b - mean) / std) 643 | 644 | # Uniformly fill tensor with values from [l, u], then translate to 645 | # [2l-1, 2u-1]. 646 | tensor.uniform_(2 * l - 1, 2 * u - 1) 647 | 648 | # Use inverse cdf transform for normal distribution to get truncated 649 | # standard normal 650 | tensor.erfinv_() 651 | 652 | # Transform to proper mean, std 653 | tensor.mul_(std * math.sqrt(2.)) 654 | tensor.add_(mean) 655 | 656 | # Clamp to ensure it's in the proper range 657 | tensor.clamp_(min=a, max=b) 658 | return tensor 659 | 660 | 661 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 662 | # type: (Tensor, float, float, float, float) -> Tensor 663 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 664 | 665 | -------------------------------------------------------------------------------- /utils/analysis/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch.utils.data import DataLoader 4 | import torch.nn as nn 5 | import tqdm 6 | 7 | 8 | def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module, 9 | device: torch.device, max_num_features=None) -> torch.Tensor: 10 | """ 11 | Fetch data from `data_loader`, and then use `feature_extractor` to collect features 12 | 13 | Args: 14 | data_loader (torch.utils.data.DataLoader): Data loader. 15 | feature_extractor (torch.nn.Module): A feature extractor. 16 | device (torch.device) 17 | max_num_features (int): The max number of mini-batches to collect features from 18 | 19 | Returns: 20 | Features in shape (min(len(dataset), max_num_features * mini-batch size), :math:`|\mathcal{F}|`). 
21 | """ 22 | feature_extractor.eval() 23 | all_features = [] 24 | with torch.no_grad(): 25 | for i, (images, target) in enumerate(tqdm.tqdm(data_loader)): 26 | if max_num_features is not None and i >= max_num_features: 27 | break 28 | images = images.to(device) 29 | feature = feature_extractor(images).cpu() 30 | all_features.append(feature) 31 | return torch.cat(all_features, dim=0) 32 | -------------------------------------------------------------------------------- /utils/analysis/a_distance.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | from torch.utils.data import TensorDataset 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | from torch.optim import SGD 11 | from ..meter import AverageMeter 12 | from ..metric import binary_accuracy 13 | 14 | 15 | class ANet(nn.Module): 16 | def __init__(self, in_feature): 17 | super(ANet, self).__init__() 18 | self.layer = nn.Linear(in_feature, 1) 19 | self.sigmoid = nn.Sigmoid() 20 | 21 | def forward(self, x): 22 | x = self.layer(x) 23 | x = self.sigmoid(x) 24 | return x 25 | 26 | 27 | def calculate(source_feature: torch.Tensor, target_feature: torch.Tensor, 28 | device, progress=True, training_epochs=10): 29 | """ 30 | Calculate the :math:`\mathcal{A}`-distance, which is a measure for distribution discrepancy. 31 | 32 | The definition is :math:`dist_\mathcal{A} = 2 (1-2\epsilon)`, where :math:`\epsilon` is the 33 | test error of a classifier trained to discriminate the source from the target. 34 | 35 | Args: 36 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)` 37 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)` 38 | device (torch.device) 39 | progress (bool): if True, displays the progress of training the A-Net 40 | training_epochs (int): the number of epochs when training the classifier 41 | 42 | Returns: 43 | :math:`\mathcal{A}`-distance 44 | """ 45 | source_label = torch.ones((source_feature.shape[0], 1)) 46 | target_label = torch.zeros((target_feature.shape[0], 1)) 47 | feature = torch.cat([source_feature, target_feature], dim=0) 48 | label = torch.cat([source_label, target_label], dim=0) 49 | 50 | dataset = TensorDataset(feature, label) 51 | length = len(dataset) 52 | train_size = int(0.8 * length) 53 | val_size = length - train_size 54 | train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size]) 55 | train_loader = DataLoader(train_set, batch_size=2, shuffle=True) 56 | val_loader = DataLoader(val_set, batch_size=8, shuffle=False) 57 | 58 | anet = ANet(feature.shape[1]).to(device) 59 | optimizer = SGD(anet.parameters(), lr=0.01) 60 | a_distance = 2.0 61 | for epoch in range(training_epochs): 62 | anet.train() 63 | for (x, label) in train_loader: 64 | x = x.to(device) 65 | label = label.to(device) 66 | anet.zero_grad() 67 | y = anet(x) 68 | loss = F.binary_cross_entropy(y, label) 69 | loss.backward() 70 | optimizer.step() 71 | 72 | anet.eval() 73 | meter = AverageMeter("accuracy", ":4.2f") 74 | with torch.no_grad(): 75 | for (x, label) in val_loader: 76 | x = x.to(device) 77 | label = label.to(device) 78 | y = anet(x) 79 | acc = binary_accuracy(y, label) 80 | meter.update(acc, x.shape[0]) 81 | error = 1 - meter.avg / 100 82 | a_distance = 2 * (1 - 2 * error) 83 | if progress: 84 | print("epoch {} accuracy: {} A-dist: {}".format(epoch, 
meter.avg, a_distance)) 85 | 86 | return a_distance 87 | 88 | -------------------------------------------------------------------------------- /utils/analysis/tsne.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Junguang Jiang 3 | @contact: JiangJunguang1123@outlook.com 4 | """ 5 | import torch 6 | import matplotlib 7 | 8 | matplotlib.use('Agg') 9 | from sklearn.manifold import TSNE 10 | import numpy as np 11 | import matplotlib.pyplot as plt 12 | import matplotlib.colors as col 13 | 14 | 15 | def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor, 16 | filename: str, source_color='r', target_color='b'): 17 | """ 18 | Visualize features from different domains using t-SNE. 19 | 20 | Args: 21 | source_feature (tensor): features from source domain in shape :math:`(minibatch, F)` 22 | target_feature (tensor): features from target domain in shape :math:`(minibatch, F)` 23 | filename (str): the file name to save t-SNE 24 | source_color (str): the color of the source features. Default: 'r' 25 | target_color (str): the color of the target features. Default: 'b' 26 | 27 | """ 28 | source_feature = source_feature.numpy() 29 | target_feature = target_feature.numpy() 30 | features = np.concatenate([source_feature, target_feature], axis=0) 31 | 32 | # map features to 2-d using TSNE 33 | X_tsne = TSNE(n_components=2, random_state=33).fit_transform(features) 34 | 35 | # domain labels, 1 represents source while 0 represents target 36 | domains = np.concatenate((np.ones(len(source_feature)), np.zeros(len(target_feature)))) 37 | 38 | # visualize using matplotlib 39 | fig, ax = plt.subplots(figsize=(10, 10)) 40 | ax.spines['top'].set_visible(False) 41 | ax.spines['right'].set_visible(False) 42 | ax.spines['bottom'].set_visible(False) 43 | ax.spines['left'].set_visible(False) 44 | plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=domains, cmap=col.ListedColormap([target_color, source_color]), s=20) 45 | 46 | plt.tight_layout() 47 | plt.xticks([]) 48 | plt.yticks([]) 49 | plt.savefig(filename) 50 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import random 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Sampler 7 | from torch.utils.data import DataLoader, Dataset 8 | from typing import TypeVar, Iterable, Dict, List 9 | 10 | T_co = TypeVar('T_co', covariant=True) 11 | T = TypeVar('T') 12 | 13 | 14 | def send_to_device(tensor, device): 15 | """ 16 | Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device. 17 | 18 | Args: 19 | tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`): 20 | The data to send to a given device. 21 | device (:obj:`torch.device`): 22 | The device to send the data to 23 | 24 | Returns: 25 | The same data structure as :obj:`tensor` with all tensors sent to the proper device. 
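    Example (a minimal sketch; the nested batch structure is illustrative)::

        >>> batch = {'x': torch.zeros(2, 3), 'y': [torch.ones(2)]}
        >>> batch = send_to_device(batch, torch.device('cpu'))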
26 | """ 27 | if isinstance(tensor, (list, tuple)): 28 | return type(tensor)(send_to_device(t, device) for t in tensor) 29 | elif isinstance(tensor, dict): 30 | return type(tensor)({k: send_to_device(v, device) for k, v in tensor.items()}) 31 | elif not hasattr(tensor, "to"): 32 | return tensor 33 | return tensor.to(device) 34 | 35 | 36 | class ForeverDataIterator: 37 | r"""A data iterator that will never stop producing data""" 38 | 39 | def __init__(self, data_loader: DataLoader, device=None): 40 | self.data_loader = data_loader 41 | self.iter = iter(self.data_loader) 42 | self.device = device 43 | 44 | def __next__(self): 45 | try: 46 | data = next(self.iter) 47 | if self.device is not None: 48 | data = send_to_device(data, self.device) 49 | except StopIteration: 50 | self.iter = iter(self.data_loader) 51 | data = next(self.iter) 52 | if self.device is not None: 53 | data = send_to_device(data, self.device) 54 | return data 55 | 56 | def __len__(self): 57 | return len(self.data_loader) 58 | 59 | 60 | class RandomMultipleGallerySampler(Sampler): 61 | r"""Sampler from `In defense of the Triplet Loss for Person Re-Identification 62 | (ICCV 2017) <https://arxiv.org/abs/1703.07737>`_. Assuming there are :math:`N` identities in the dataset, this 63 | implementation simply samples :math:`K` images for every identity to form an iteration of size :math:`N\times K`. During 64 | training, the ``__iter__`` method of the PyTorch dataloader is called again once we reach a ``StopIteration``; this guarantees 65 | that every image in the dataset will eventually be selected and no training data is wasted. 66 | 67 | Args: 68 | dataset(list): each element of this list is a tuple (image_path, person_id, camera_id) 69 | num_instances(int, optional): number of images to sample for every identity (:math:`K` here) 70 | """ 71 | 72 | def __init__(self, dataset, num_instances=4): 73 | super(RandomMultipleGallerySampler, self).__init__(dataset) 74 | self.dataset = dataset 75 | self.num_instances = num_instances 76 | 77 | self.idx_to_pid = {} 78 | self.cid_list_per_pid = {} 79 | self.idx_list_per_pid = {} 80 | 81 | for idx, (_, pid, cid) in enumerate(dataset): 82 | if pid not in self.cid_list_per_pid: 83 | self.cid_list_per_pid[pid] = [] 84 | self.idx_list_per_pid[pid] = [] 85 | 86 | self.idx_to_pid[idx] = pid 87 | self.cid_list_per_pid[pid].append(cid) 88 | self.idx_list_per_pid[pid].append(idx) 89 | 90 | self.pid_list = list(self.idx_list_per_pid.keys()) 91 | self.num_samples = len(self.pid_list) 92 | 93 | def __len__(self): 94 | return self.num_samples * self.num_instances 95 | 96 | def __iter__(self): 97 | def select_idxes(element_list, target_element): 98 | assert isinstance(element_list, list) 99 | return [i for i, element in enumerate(element_list) if element != target_element] 100 | 101 | pid_idxes = torch.randperm(len(self.pid_list)).tolist() 102 | final_idxes = [] 103 | 104 | for perm_id in pid_idxes: 105 | i = random.choice(self.idx_list_per_pid[self.pid_list[perm_id]]) 106 | _, _, cid = self.dataset[i] 107 | 108 | final_idxes.append(i) 109 | 110 | pid_i = self.idx_to_pid[i] 111 | cid_list = self.cid_list_per_pid[pid_i] 112 | idx_list = self.idx_list_per_pid[pid_i] 113 | selected_cid_list = select_idxes(cid_list, cid) 114 | 115 | if selected_cid_list: 116 | if len(selected_cid_list) >= self.num_instances: 117 | cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=False) 118 | else: 119 | cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=True) 120 | for cid_idx in cid_idxes: 
final_idxes.append(idx_list[cid_idx])
122 |             else:
123 |                 selected_idxes = select_idxes(idx_list, i)
124 |                 if not selected_idxes:
125 |                     continue
126 |                 if len(selected_idxes) >= self.num_instances:
127 |                     sampled_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=False)
128 |                 else:
129 |                     sampled_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=True)
130 | 
131 |                 for sampled_idx in sampled_idxes:
132 |                     final_idxes.append(idx_list[sampled_idx])
133 | 
134 |         return iter(final_idxes)
135 | 
136 | 
137 | class CombineDataset(Dataset[T_co]):
138 |     r"""Dataset as a combination of multiple datasets.
139 |     The element of each dataset must be a list, and the i-th element of the combined dataset
140 |     is the concatenation of the i-th elements of all sub datasets.
141 |     The length of the combined dataset is the minimum of the lengths of all sub datasets.
142 | 
143 |     Arguments:
144 |         datasets (sequence): List of datasets to be combined
145 |     """
146 | 
147 |     def __init__(self, datasets: Iterable[Dataset]) -> None:
148 |         super(CombineDataset, self).__init__()
149 |         # Cannot verify that datasets is Sized
150 |         assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
151 |         self.datasets = list(datasets)
152 | 
153 |     def __len__(self):
154 |         return min([len(d) for d in self.datasets])
155 | 
156 |     def __getitem__(self, idx):
157 |         return list(itertools.chain(*[d[idx] for d in self.datasets]))
158 | 
159 | 
160 | def concatenate(tensors):
161 |     """Concatenate multiple batches into one batch.
162 |     ``tensors`` can be :class:`torch.Tensor`, List or Dict, but they must all share the same data format.
163 |     """
164 |     if isinstance(tensors[0], torch.Tensor):
165 |         return torch.cat(tensors, dim=0)
166 |     elif isinstance(tensors[0], List):
167 |         ret = []
168 |         for i in range(len(tensors[0])):
169 |             ret.append(concatenate([t[i] for t in tensors]))
170 |         return ret
171 |     elif isinstance(tensors[0], Dict):
172 |         ret = dict()
173 |         for k in tensors[0].keys():
174 |             ret[k] = concatenate([t[k] for t in tensors])
175 |         return ret
176 | 
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import sys
4 | import time
5 | 
6 | class TextLogger(object):
7 |     """Writes stream output to an external text file.
8 | 
9 |     Args:
10 |         filename (str): the file to write stream output
11 |         stream: the stream to read from. Default: sys.stdout
12 |     """
13 |     def __init__(self, filename, stream=sys.stdout):
14 |         self.terminal = stream
15 |         self.log = open(filename, 'a')
16 | 
17 |     def write(self, message):
18 |         self.terminal.write(message)
19 |         self.log.write(message)
20 |         self.flush()
21 | 
22 |     def flush(self):
23 |         self.terminal.flush()
24 |         self.log.flush()
25 | 
26 |     def close(self):
27 |         self.terminal.close()
28 |         self.log.close()
29 | 
30 | 
31 | class CompleteLogger:
32 |     """
33 |     A useful logger that
34 | 
35 |     - writes outputs to files and displays them on the console at the same time.
36 |     - manages the directory of checkpoints and debugging images.
37 | 
38 |     Args:
39 |         root (str): the root directory of the logger
40 |         phase (str): the phase of training. Default: 'train'.
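
    Example (illustrative; the root path is arbitrary, and the exact checkpoint
    directory depends on the host, see ``__init__`` below)::

        >>> logger = CompleteLogger('logs/baseline')
        >>> logger.set_epoch(0)
        >>> logger.write('starting epoch 0')
        >>> logger.get_checkpoint_path()  # e.g. 'checkpoints/baseline/checkpoints/0.pth'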
41 | 42 | """ 43 | 44 | def __init__(self, root, phase='train'): 45 | self.root = root 46 | self.phase = phase 47 | self.visualize_directory = os.path.join(self.root.replace('logs', 'output_viz'), "visualize") 48 | 49 | self.epoch = 0 50 | 51 | os.makedirs(self.root, exist_ok=True) 52 | os.makedirs(self.visualize_directory, exist_ok=True) 53 | 54 | # redirect std out 55 | now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time())) 56 | hname = os.uname() 57 | if 'kruskal' in hname.nodename: 58 | self.checkpoint_directory = os.path.join(self.root.replace('logs', 'checkpoints_yoda'), "checkpoints") 59 | else: 60 | self.checkpoint_directory = os.path.join(self.root.replace('logs', 'checkpoints'), "checkpoints") 61 | os.makedirs(self.checkpoint_directory, exist_ok=True) 62 | log_filename = os.path.join(self.root, "{}-{}.txt".format(phase, now)) 63 | 64 | if os.path.exists(log_filename): 65 | os.remove(log_filename) 66 | self.logger = TextLogger(log_filename) 67 | # sys.stdout = self.logger 68 | # sys.stderr = self.logger 69 | if phase != 'train': 70 | self.set_epoch(phase) 71 | 72 | print(self.visualize_directory) 73 | print(self.checkpoint_directory) 74 | print(log_filename) 75 | 76 | self.write(self.visualize_directory) 77 | self.write(self.checkpoint_directory) 78 | self.write(log_filename) 79 | 80 | 81 | def set_epoch(self, epoch): 82 | """Set the epoch number. Please use it during training.""" 83 | os.makedirs(os.path.join(self.visualize_directory, str(epoch)), exist_ok=True) 84 | self.epoch = epoch 85 | 86 | 87 | def _get_phase_or_epoch(self): 88 | if self.phase == 'train': 89 | return str(self.epoch) 90 | else: 91 | return self.phase 92 | 93 | def get_image_path(self, filename: str): 94 | """ 95 | Get the full image path for a specific filename 96 | """ 97 | return os.path.join(self.visualize_directory, self._get_phase_or_epoch(), filename) 98 | 99 | def get_checkpoint_path(self, name=None): 100 | """ 101 | Get the full checkpoint path. 102 | 103 | Args: 104 | name (optional): the filename (without file extension) to save checkpoint. 105 | If None, when the phase is ``train``, checkpoint will be saved to ``{epoch}.pth``. 106 | Otherwise, will be saved to ``{phase}.pth``. 107 | 108 | """ 109 | if name is None: 110 | name = self._get_phase_or_epoch() 111 | name = str(name) 112 | return os.path.join(self.checkpoint_directory, name + ".pth") 113 | 114 | def write(self, str): 115 | self.logger.write(str + '\n') 116 | 117 | 118 | def close(self): 119 | self.logger.close() 120 | 121 | -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Optional, List 3 | 4 | 5 | class AverageMeter(object): 6 | r"""Computes and stores the average and current value. 
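    Internally the meter keeps a running ``sum`` and ``count`` of every value passed to
    :meth:`update`, so ``avg`` is always the mean over the entire history since :meth:`reset`.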
7 | 
8 |     Examples::
9 | 
10 |         >>> # Initialize a meter to record loss
11 |         >>> losses = AverageMeter('loss')
12 |         >>> # Update meter after every minibatch update
13 |         >>> losses.update(loss_value, batch_size)
14 |     """
15 |     def __init__(self, name: str, fmt: Optional[str] = ':f'):
16 |         self.name = name
17 |         self.fmt = fmt
18 |         self.reset()
19 | 
20 |     def reset(self):
21 |         self.val = 0
22 |         self.avg = 0
23 |         self.sum = 0
24 |         self.count = 0
25 | 
26 |     def update(self, val, n=1):
27 |         self.val = val
28 |         self.sum += val * n
29 |         self.count += n
30 |         if self.count > 0:
31 |             self.avg = self.sum / self.count
32 | 
33 |     def __str__(self):
34 |         fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
35 |         return fmtstr.format(**self.__dict__)
36 | 
37 | 
38 | class AverageMeterDict(object):
39 |     def __init__(self, names: List, fmt: Optional[str] = ':f'):
40 |         self.dict = {
41 |             name: AverageMeter(name, fmt) for name in names
42 |         }
43 | 
44 |     def reset(self):
45 |         for meter in self.dict.values():
46 |             meter.reset()
47 | 
48 |     def update(self, accuracies, n=1):
49 |         for name, acc in accuracies.items():
50 |             self.dict[name].update(acc, n)
51 | 
52 |     def average(self):
53 |         return {
54 |             name: meter.avg for name, meter in self.dict.items()
55 |         }
56 | 
57 |     def __getitem__(self, item):
58 |         return self.dict[item]
59 | 
60 | 
61 | class Meter(object):
62 |     """Computes and stores the current value."""
63 |     def __init__(self, name: str, fmt: Optional[str] = ':f'):
64 |         self.name = name
65 |         self.fmt = fmt
66 |         self.reset()
67 | 
68 |     def reset(self):
69 |         self.val = 0
70 | 
71 |     def update(self, val):
72 |         self.val = val
73 | 
74 |     def __str__(self):
75 |         fmtstr = '{name} {val' + self.fmt + '}'
76 |         return fmtstr.format(**self.__dict__)
77 | 
78 | 
79 | class ProgressMeter(object):
80 |     def __init__(self, num_batches, meters, prefix=""):
81 |         self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
82 |         self.meters = meters
83 |         self.prefix = prefix
84 | 
85 |     def display(self, batch):
86 |         entries = [self.prefix + self.batch_fmtstr.format(batch)]
87 |         entries += [str(meter) for meter in self.meters]
88 |         print('\t'.join(entries))
89 | 
90 |     def _get_batch_fmtstr(self, num_batches):
91 |         num_digits = len(str(num_batches))
92 |         fmt = '{:' + str(num_digits) + 'd}'
93 |         return '[' + fmt + '/' + fmt.format(num_batches) + ']'
94 | 
95 | 
96 | 
--------------------------------------------------------------------------------
/utils/metric/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import prettytable
3 | 
4 | __all__ = ['binary_accuracy', 'accuracy', 'ConfusionMatrix']
5 | 
6 | def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
7 |     """Computes the accuracy for binary classification"""
8 |     with torch.no_grad():
9 |         batch_size = target.size(0)
10 |         pred = (output >= 0.5).float().t().view(-1)
11 |         correct = pred.eq(target.view(-1)).float().sum()
12 |         correct.mul_(100. / batch_size)
13 |         return correct
14 | 
15 | 
16 | def accuracy(output, target, topk=(1,)):
17 |     r"""
18 |     Computes the accuracy over the k top predictions for the specified values of k
19 | 
20 |     Args:
21 |         output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes`
22 |         target (tensor): :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`
23 |         topk (sequence[int]): A sequence of top-N numbers.
24 | 
25 |     Returns:
26 |         Top-N accuracies (N :math:`\in` topK).
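
    Example (illustrative)::

        >>> output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
        >>> target = torch.tensor([1, 1])
        >>> accuracy(output, target, topk=(1,))  # only the first prediction is correct
        [tensor(50.)]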
27 |     """
28 |     with torch.no_grad():
29 |         maxk = max(topk)
30 |         batch_size = target.size(0)
31 | 
32 |         _, pred = output.topk(maxk, 1, True, True)
33 |         pred = pred.t()
34 |         correct = pred.eq(target[None])
35 | 
36 |         res = []
37 |         for k in topk:
38 |             correct_k = correct[:k].flatten().sum(dtype=torch.float32)
39 |             res.append(correct_k * (100.0 / batch_size))
40 |         return res
41 | 
42 | 
43 | class ConfusionMatrix(object):
44 |     def __init__(self, num_classes):
45 |         self.num_classes = num_classes
46 |         self.mat = None
47 | 
48 |     def update(self, target, output):
49 |         """
50 |         Update confusion matrix.
51 | 
52 |         Args:
53 |             target: ground truth
54 |             output: predictions of models
55 | 
56 |         Shape:
57 |             - target: :math:`(minibatch,)` where each value is a class index in :math:`[0, C-1]`.
58 |             - output: :math:`(minibatch,)` where each value is a predicted class index in :math:`[0, C-1]`.
59 |         """
60 |         n = self.num_classes
61 |         if self.mat is None:
62 |             self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
63 |         with torch.no_grad():
64 |             k = (target >= 0) & (target < n)
65 |             inds = n * target[k].to(torch.int64) + output[k]
66 |             self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
67 | 
68 |     def reset(self):
69 |         self.mat.zero_()
70 | 
71 |     def compute(self):
72 |         """compute global accuracy, per-class accuracy and per-class IoU"""
73 |         h = self.mat.float()
74 |         acc_global = torch.diag(h).sum() / h.sum()
75 |         acc = torch.diag(h) / h.sum(1)
76 |         iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
77 |         return acc_global, acc, iu
78 | 
79 |     # def reduce_from_all_processes(self):
80 |     #     if not torch.distributed.is_available():
81 |     #         return
82 |     #     if not torch.distributed.is_initialized():
83 |     #         return
84 |     #     torch.distributed.barrier()
85 |     #     torch.distributed.all_reduce(self.mat)
86 | 
87 |     def __str__(self):
88 |         acc_global, acc, iu = self.compute()
89 |         return (
90 |             'global correct: {:.1f}\n'
91 |             'average row correct: {}\n'
92 |             'IoU: {}\n'
93 |             'mean IoU: {:.1f}').format(
94 |             acc_global.item() * 100,
95 |             ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
96 |             ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
97 |             iu.mean().item() * 100)
98 | 
99 |     def format(self, classes: list):
100 |         """Get the accuracy and IoU for each class in the table format"""
101 |         acc_global, acc, iu = self.compute()
102 | 
103 |         table = prettytable.PrettyTable(["class", "acc", "iou"])
104 |         for i, class_name, per_acc, per_iu in zip(range(len(classes)), classes, (acc * 100).tolist(), (iu * 100).tolist()):
105 |             table.add_row([class_name, per_acc, per_iu])
106 | 
107 |         return 'global correct: {:.1f}\nmean correct: {:.1f}\nmean IoU: {:.1f}\n{}'.format(
108 |             acc_global.item() * 100, acc.mean().item() * 100, iu.mean().item() * 100, table.get_string())
109 | 
110 | 
--------------------------------------------------------------------------------
/utils/scheduler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/yxgeee/MMT
3 | @author: Baixu Chen
4 | @contact: cbx_99_hasta@outlook.com
5 | """
6 | import torch
7 | from bisect import bisect_right
8 | 
9 | 
10 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
11 |     r"""Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the
12 |     number of epochs reaches one of the milestones. When last_epoch=-1, sets initial lr as lr.
13 | 
14 |     Args:
15 |         optimizer (Optimizer): Wrapped optimizer.
16 |         milestones (list): List of epoch indices. Must be increasing.
17 |         gamma (float): Multiplicative factor of learning rate decay.
18 |             Default: 0.1.
19 |         warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of the warmup phase
20 |             will be set to :math:`k*initial\_lr`
21 |         warmup_steps (int): number of warm-up steps.
22 |         warmup_method (str): "constant" denotes a constant learning rate during the warm-up phase and "linear" denotes a
23 |             linearly increasing learning rate during the warm-up phase.
24 |         last_epoch (int): The index of last epoch. Default: -1.
25 |     """
26 | 
27 |     def __init__(
28 |             self,
29 |             optimizer,
30 |             milestones,
31 |             gamma=0.1,
32 |             warmup_factor=1.0 / 3,
33 |             warmup_steps=500,
34 |             warmup_method="linear",
35 |             last_epoch=-1,
36 |     ):
37 |         if not list(milestones) == sorted(milestones):
38 |             raise ValueError(
39 |                 "Milestones should be a list of"
40 |                 " increasing integers. Got {}".format(milestones)
41 |             )
42 | 
43 |         if warmup_method not in ("constant", "linear"):
44 |             raise ValueError(
45 |                 "Only 'constant' or 'linear' warmup_method accepted, "
46 |                 "got {}".format(warmup_method)
47 |             )
48 |         self.milestones = milestones
49 |         self.gamma = gamma
50 |         self.warmup_factor = warmup_factor
51 |         self.warmup_steps = warmup_steps
52 |         self.warmup_method = warmup_method
53 |         super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
54 | 
55 |     def get_lr(self):
56 |         warmup_factor = 1
57 |         if self.last_epoch < self.warmup_steps:
58 |             if self.warmup_method == "constant":
59 |                 warmup_factor = self.warmup_factor
60 |             elif self.warmup_method == "linear":
61 |                 alpha = float(self.last_epoch) / float(self.warmup_steps)
62 |                 warmup_factor = self.warmup_factor * (1 - alpha) + alpha
63 |         return [
64 |             base_lr
65 |             * warmup_factor
66 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
67 |             for base_lr in self.base_lrs
68 |         ]
69 | 
--------------------------------------------------------------------------------
/vision/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['datasets', 'models', 'transforms']
2 | 
--------------------------------------------------------------------------------
/vision/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .imagelist import ImageList
2 | from .office31 import Office31
3 | from .officehome import OfficeHome
4 | from .domainnet import DomainNet
5 | from .cub200 import CUB200
6 | from .cub import CUB
7 | 
8 | __all__ = ['ImageList', 'Office31', 'OfficeHome', "DomainNet", "CUB200", "CUB"]
9 | 
--------------------------------------------------------------------------------
/vision/datasets/_util.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import os
6 | from typing import List
7 | from torchvision.datasets.utils import download_and_extract_archive
8 | 
9 | 
10 | def download(root: str, file_name: str, archive_name: str, url_link: str):
11 |     """
12 |     Download file from internet url link.
13 | 
14 |     Args:
15 |         root (str): The directory to put downloaded files.
16 |         file_name (str): The name of the unzipped file.
17 |         archive_name (str): The name of the archive (zipped file) downloaded.
18 |         url_link (str): The url link to download data.
19 | 
20 |     .. note::
21 |         If `file_name` already exists under path `root`, then it is not downloaded again.
22 |         Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`.
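
    Example (hypothetical file names and URL)::

        >>> download(root='data/office31', file_name='amazon',
        ...          archive_name='amazon.tgz', url_link='https://example.com/amazon.tgz')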
23 |     """
24 |     if not os.path.exists(os.path.join(root, file_name)):
25 |         print("Downloading {}".format(file_name))
26 |         # if os.path.exists(os.path.join(root, archive_name)):
27 |         #     os.remove(os.path.join(root, archive_name))
28 |         try:
29 |             download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False)
30 |         except Exception:
31 |             print("Failed to download {} from url link {}".format(archive_name, url_link))
32 |             print('Please check your internet connection. '
33 |                   "Simply trying again may be fine.")
34 |             exit(0)
35 | 
36 | 
37 | def check_exits(root: str, file_name: str):
38 |     """Check whether `file_name` exists under directory `root`. """
39 |     if not os.path.exists(os.path.join(root, file_name)):
40 |         print("Dataset directory {} not found under {}".format(file_name, root))
41 |         exit(-1)
42 | 
43 | 
44 | def read_list_from_file(file_name: str) -> List[str]:
45 |     """Read data from a file and convert each line into an element in the list"""
46 |     result = []
47 |     with open(file_name, "r") as f:
48 |         for line in f.readlines():
49 |             result.append(line.strip())
50 |     return result
51 | 
--------------------------------------------------------------------------------
/vision/datasets/cub.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import os
6 | from typing import Optional
7 | from .imagelist import ImageList
8 | from ._util import download as download_data, check_exits
9 | 
10 | 
11 | class CUB(ImageList):
12 |     """`CUB `_ Dataset.
13 | 
14 |     Args:
15 |         root (str): Root directory of dataset
16 |         task (str): The task (domain) to create dataset. Choices include ``'Rw'``: CUB-200-2011 \
17 |             photographs and ``'Pr'``: CUB-200-Painting (see ``image_list`` below).
18 |         download (bool, optional): If true, downloads the dataset from the internet and puts it \
19 |             in root directory. If dataset is already downloaded, it is not downloaded again.
20 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a \
21 |             transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
22 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
23 | 
24 |     .. note:: In `root`, there will exist following files after downloading.
25 |         ::
26 |             Art/
27 |                 Alarm_Clock/*.jpg
28 |                 ...
29 | Clipart/ 30 | Product/ 31 | Real_World/ 32 | image_list/ 33 | Art.txt 34 | Clipart.txt 35 | Product.txt 36 | Real_World.txt 37 | """ 38 | download_list = [ 39 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/ca3a3b6a8d554905b4cd/?dl=1"), 40 | ("Art", "Art.tgz", "https://cloud.tsinghua.edu.cn/f/4691878067d04755beab/?dl=1"), 41 | ("Clipart", "Clipart.tgz", "https://cloud.tsinghua.edu.cn/f/0d41e7da4558408ea5aa/?dl=1"), 42 | ("Product", "Product.tgz", "https://cloud.tsinghua.edu.cn/f/76186deacd7c4fa0a679/?dl=1"), 43 | ("Real_World", "Real_World.tgz", "https://cloud.tsinghua.edu.cn/f/dee961894cc64b1da1d7/?dl=1") 44 | ] 45 | image_list = { 46 | "Rw": "image_list/CUB_200_2011.txt", 47 | "Pr": "image_list/CUB_painting.txt", 48 | # "Pr": "image_list/Product.txt", 49 | # "Rw": "image_list/Real_World.txt", 50 | } 51 | CLASSES = ['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', 52 | '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', 53 | '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', 54 | '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', 55 | '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', 56 | '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', 57 | '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', 58 | '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', 59 | '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', 60 | '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', 61 | '041.Scissor_tailed_Flycatcher', '042.Vermilion_Flycatcher', '043.Yellow_bellied_Flycatcher', 62 | '044.Frigatebird', '045.Northern_Fulmar', '046.Gadwall', '047.American_Goldfinch', 63 | '048.European_Goldfinch', '049.Boat_tailed_Grackle', '050.Eared_Grebe', 64 | '051.Horned_Grebe', '052.Pied_billed_Grebe', '053.Western_Grebe', '054.Blue_Grosbeak', 65 | '055.Evening_Grosbeak', '056.Pine_Grosbeak', '057.Rose_breasted_Grosbeak', '058.Pigeon_Guillemot', 66 | '059.California_Gull', '060.Glaucous_winged_Gull', '061.Heermann_Gull', '062.Herring_Gull', 67 | '063.Ivory_Gull', '064.Ring_billed_Gull', '065.Slaty_backed_Gull', '066.Western_Gull', 68 | '067.Anna_Hummingbird', '068.Ruby_throated_Hummingbird', '069.Rufous_Hummingbird', '070.Green_Violetear', 69 | '071.Long_tailed_Jaeger', '072.Pomarine_Jaeger', '073.Blue_Jay', '074.Florida_Jay', '075.Green_Jay', 70 | '076.Dark_eyed_Junco', '077.Tropical_Kingbird', '078.Gray_Kingbird', '079.Belted_Kingfisher', 71 | '080.Green_Kingfisher', '081.Pied_Kingfisher', '082.Ringed_Kingfisher', '083.White_breasted_Kingfisher', 72 | '084.Red_legged_Kittiwake', '085.Horned_Lark', '086.Pacific_Loon', '087.Mallard', 73 | '088.Western_Meadowlark', '089.Hooded_Merganser', '090.Red_breasted_Merganser', '091.Mockingbird', 74 | '092.Nighthawk', '093.Clark_Nutcracker', '094.White_breasted_Nuthatch', '095.Baltimore_Oriole', 75 | '096.Hooded_Oriole', '097.Orchard_Oriole', '098.Scott_Oriole', '099.Ovenbird', '100.Brown_Pelican', 76 | '101.White_Pelican', '102.Western_Wood_Pewee', '103.Sayornis', '104.American_Pipit', 77 | '105.Whip_poor_Will', '106.Horned_Puffin', '107.Common_Raven', '108.White_necked_Raven', 78 | '109.American_Redstart', 
'110.Geococcyx', '111.Loggerhead_Shrike', '112.Great_Grey_Shrike', 79 | '113.Baird_Sparrow', '114.Black_throated_Sparrow', '115.Brewer_Sparrow', '116.Chipping_Sparrow', 80 | '117.Clay_colored_Sparrow', '118.House_Sparrow', '119.Field_Sparrow', '120.Fox_Sparrow', 81 | '121.Grasshopper_Sparrow', '122.Harris_Sparrow', '123.Henslow_Sparrow', '124.Le_Conte_Sparrow', 82 | '125.Lincoln_Sparrow', '126.Nelson_Sharp_tailed_Sparrow', '127.Savannah_Sparrow', '128.Seaside_Sparrow', 83 | '129.Song_Sparrow', '130.Tree_Sparrow', '131.Vesper_Sparrow', '132.White_crowned_Sparrow', 84 | '133.White_throated_Sparrow', '134.Cape_Glossy_Starling', '135.Bank_Swallow', '136.Barn_Swallow', 85 | '137.Cliff_Swallow', '138.Tree_Swallow', '139.Scarlet_Tanager', '140.Summer_Tanager', '141.Artic_Tern', 86 | '142.Black_Tern', '143.Caspian_Tern', '144.Common_Tern', '145.Elegant_Tern', '146.Forsters_Tern', 87 | '147.Least_Tern', '148.Green_tailed_Towhee', '149.Brown_Thrasher', '150.Sage_Thrasher', 88 | '151.Black_capped_Vireo', '152.Blue_headed_Vireo', '153.Philadelphia_Vireo', '154.Red_eyed_Vireo', 89 | '155.Warbling_Vireo', '156.White_eyed_Vireo', '157.Yellow_throated_Vireo', '158.Bay_breasted_Warbler', 90 | '159.Black_and_white_Warbler', '160.Black_throated_Blue_Warbler', '161.Blue_winged_Warbler', 91 | '162.Canada_Warbler', '163.Cape_May_Warbler', '164.Cerulean_Warbler', '165.Chestnut_sided_Warbler', 92 | '166.Golden_winged_Warbler', '167.Hooded_Warbler', '168.Kentucky_Warbler', '169.Magnolia_Warbler', 93 | '170.Mourning_Warbler', '171.Myrtle_Warbler', '172.Nashville_Warbler', '173.Orange_crowned_Warbler', 94 | '174.Palm_Warbler', '175.Pine_Warbler', '176.Prairie_Warbler', '177.Prothonotary_Warbler', 95 | '178.Swainson_Warbler', '179.Tennessee_Warbler', '180.Wilson_Warbler', '181.Worm_eating_Warbler', 96 | '182.Yellow_Warbler', '183.Northern_Waterthrush', '184.Louisiana_Waterthrush', '185.Bohemian_Waxwing', 97 | '186.Cedar_Waxwing', '187.American_Three_toed_Woodpecker', '188.Pileated_Woodpecker', 98 | '189.Red_bellied_Woodpecker', '190.Red_cockaded_Woodpecker', '191.Red_headed_Woodpecker', 99 | '192.Downy_Woodpecker', '193.Bewick_Wren', '194.Cactus_Wren', '195.Carolina_Wren', '196.House_Wren', 100 | '197.Marsh_Wren', '198.Rock_Wren', '199.Winter_Wren', '200.Common_Yellowthroat'] 101 | 102 | def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs): 103 | assert task in self.image_list 104 | data_list_file = os.path.join(root, self.image_list[task]) 105 | 106 | # if download: 107 | # list(map(lambda args: download_data(root, *args), self.download_list)) 108 | # else: 109 | # list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) 110 | 111 | super(CUB, self).__init__(root, CUB.CLASSES, data_list_file=data_list_file, **kwargs) 112 | 113 | 114 | def parse_data_file(self, file_name): 115 | """Parse file to data list 116 | 117 | Args: 118 | file_name (str): The path of data file 119 | return (list): List of (image path, class_index) tuples 120 | """ 121 | hname = os.uname() 122 | with open(file_name, "r") as f: 123 | data_list = [] 124 | for line in f.readlines(): 125 | line = line.replace('masaito/cub/', 'donhk/dataset/data/') 126 | line = line.replace('/images/', '/') 127 | if 'kruskal' not in hname.nodename and 'yoda' not in hname.nodename and 'goat' not in hname.nodename: 128 | line = line.replace('/research/', '//net/ivcfs4/mnt/data/') 129 | split_line = line.split() 130 | target = split_line[-1] 131 | path = ' '.join(split_line[:-1]) 132 | if not os.path.isabs(path): 
133 | path = os.path.join(self.root, path) 134 | target = int(target) 135 | data_list.append((path, target)) 136 | return data_list 137 | 138 | @classmethod 139 | def domains(cls): 140 | return list(cls.image_list.keys()) -------------------------------------------------------------------------------- /vision/datasets/cub200.py: -------------------------------------------------------------------------------- 1 | """ 2 | @author: Yifei Ji 3 | @contact: jiyf990330@163.com 4 | """ 5 | import os 6 | from typing import Optional 7 | from .imagelist import ImageList 8 | from ._util import download as download_data, check_exits 9 | 10 | 11 | class CUB200(ImageList): 12 | """`Caltech-UCSD Birds-200-2011 `_ \ 13 | is a dataset for fine-grained visual recognition with 11,788 images in 200 bird species. \ 14 | It is an extended version of the CUB-200 dataset, roughly doubling the number of images. 15 | 16 | Args: 17 | root (str): Root directory of dataset 18 | split (str, optional): The dataset split, supports ``train``, or ``test``. 19 | sample_rate (int): The sampling rates to sample random ``training`` images for each category. 20 | Choices include 100, 50, 30, 15. Default: 100. 21 | download (bool, optional): If true, downloads the dataset from the internet and puts it \ 22 | in root directory. If dataset is already downloaded, it is not downloaded again. 23 | transform (callable, optional): A function/transform that takes in an PIL image and returns a \ 24 | transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. 25 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 26 | 27 | .. note:: In `root`, there will exist following files after downloading. 28 | :: 29 | train/ 30 | test/ 31 | image_list/ 32 | train_100.txt 33 | train_50.txt 34 | train_30.txt 35 | train_15.txt 36 | test.txt 37 | """ 38 | download_list = [ 39 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/fa398a321b094a24a347/?dl=1"), 40 | ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/521ba92bafc04ee69c20/?dl=1"), 41 | ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/cc7ef72081e64bc7a218/?dl=1"), 42 | ] 43 | image_list = { 44 | "train": "image_list/train_100.txt", 45 | "train100": "image_list/train_100.txt", 46 | "train50": "image_list/train_50.txt", 47 | "train30": "image_list/train_30.txt", 48 | "train15": "image_list/train_15.txt", 49 | "test": "image_list/test.txt", 50 | "test100": "image_list/test.txt", 51 | } 52 | CLASSES = ['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Vermilion_Flycatcher', 
'043.Yellow_bellied_Flycatcher', '044.Frigatebird', '045.Northern_Fulmar', '046.Gadwall', '047.American_Goldfinch', '048.European_Goldfinch', '049.Boat_tailed_Grackle', '050.Eared_Grebe', 53 | '051.Horned_Grebe', '052.Pied_billed_Grebe', '053.Western_Grebe', '054.Blue_Grosbeak', '055.Evening_Grosbeak', '056.Pine_Grosbeak', '057.Rose_breasted_Grosbeak', '058.Pigeon_Guillemot', '059.California_Gull', '060.Glaucous_winged_Gull', '061.Heermann_Gull', '062.Herring_Gull', '063.Ivory_Gull', '064.Ring_billed_Gull', '065.Slaty_backed_Gull', '066.Western_Gull', '067.Anna_Hummingbird', '068.Ruby_throated_Hummingbird', '069.Rufous_Hummingbird', '070.Green_Violetear', '071.Long_tailed_Jaeger', '072.Pomarine_Jaeger', '073.Blue_Jay', '074.Florida_Jay', '075.Green_Jay', '076.Dark_eyed_Junco', '077.Tropical_Kingbird', '078.Gray_Kingbird', '079.Belted_Kingfisher', '080.Green_Kingfisher', '081.Pied_Kingfisher', '082.Ringed_Kingfisher', '083.White_breasted_Kingfisher', '084.Red_legged_Kittiwake', '085.Horned_Lark', '086.Pacific_Loon', '087.Mallard', '088.Western_Meadowlark', '089.Hooded_Merganser', '090.Red_breasted_Merganser', '091.Mockingbird', '092.Nighthawk', '093.Clark_Nutcracker', '094.White_breasted_Nuthatch', '095.Baltimore_Oriole', '096.Hooded_Oriole', '097.Orchard_Oriole', '098.Scott_Oriole', '099.Ovenbird', '100.Brown_Pelican', 54 | '101.White_Pelican', '102.Western_Wood_Pewee', '103.Sayornis', '104.American_Pipit', '105.Whip_poor_Will', '106.Horned_Puffin', '107.Common_Raven', '108.White_necked_Raven', '109.American_Redstart', '110.Geococcyx', '111.Loggerhead_Shrike', '112.Great_Grey_Shrike', '113.Baird_Sparrow', '114.Black_throated_Sparrow', '115.Brewer_Sparrow', '116.Chipping_Sparrow', '117.Clay_colored_Sparrow', '118.House_Sparrow', '119.Field_Sparrow', '120.Fox_Sparrow', '121.Grasshopper_Sparrow', '122.Harris_Sparrow', '123.Henslow_Sparrow', '124.Le_Conte_Sparrow', '125.Lincoln_Sparrow', '126.Nelson_Sharp_tailed_Sparrow', '127.Savannah_Sparrow', '128.Seaside_Sparrow', '129.Song_Sparrow', '130.Tree_Sparrow', '131.Vesper_Sparrow', '132.White_crowned_Sparrow', '133.White_throated_Sparrow', '134.Cape_Glossy_Starling', '135.Bank_Swallow', '136.Barn_Swallow', '137.Cliff_Swallow', '138.Tree_Swallow', '139.Scarlet_Tanager', '140.Summer_Tanager', '141.Artic_Tern', '142.Black_Tern', '143.Caspian_Tern', '144.Common_Tern', '145.Elegant_Tern', '146.Forsters_Tern', '147.Least_Tern', '148.Green_tailed_Towhee', '149.Brown_Thrasher', '150.Sage_Thrasher', 55 | '151.Black_capped_Vireo', '152.Blue_headed_Vireo', '153.Philadelphia_Vireo', '154.Red_eyed_Vireo', '155.Warbling_Vireo', '156.White_eyed_Vireo', '157.Yellow_throated_Vireo', '158.Bay_breasted_Warbler', '159.Black_and_white_Warbler', '160.Black_throated_Blue_Warbler', '161.Blue_winged_Warbler', '162.Canada_Warbler', '163.Cape_May_Warbler', '164.Cerulean_Warbler', '165.Chestnut_sided_Warbler', '166.Golden_winged_Warbler', '167.Hooded_Warbler', '168.Kentucky_Warbler', '169.Magnolia_Warbler', '170.Mourning_Warbler', '171.Myrtle_Warbler', '172.Nashville_Warbler', '173.Orange_crowned_Warbler', '174.Palm_Warbler', '175.Pine_Warbler', '176.Prairie_Warbler', '177.Prothonotary_Warbler', '178.Swainson_Warbler', '179.Tennessee_Warbler', '180.Wilson_Warbler', '181.Worm_eating_Warbler', '182.Yellow_Warbler', '183.Northern_Waterthrush', '184.Louisiana_Waterthrush', '185.Bohemian_Waxwing', '186.Cedar_Waxwing', '187.American_Three_toed_Woodpecker', '188.Pileated_Woodpecker', '189.Red_bellied_Woodpecker', '190.Red_cockaded_Woodpecker', '191.Red_headed_Woodpecker', 
'192.Downy_Woodpecker', '193.Bewick_Wren', '194.Cactus_Wren', '195.Carolina_Wren', '196.House_Wren', '197.Marsh_Wren', '198.Rock_Wren', '199.Winter_Wren', '200.Common_Yellowthroat']
56 | 
57 |     def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False, **kwargs):
58 | 
59 |         if split == 'train':
60 |             list_name = 'train' + str(sample_rate)
61 |             assert list_name in self.image_list
62 |             data_list_file = os.path.join(root, self.image_list[list_name])
63 |         else:
64 |             data_list_file = os.path.join(root, self.image_list['test'])
65 | 
66 |         if download:
67 |             list(map(lambda args: download_data(root, *args), self.download_list))
68 |         else:
69 |             list(map(lambda args: check_exits(root, args[0]), self.download_list))
70 | 
71 |         super(CUB200, self).__init__(root, CUB200.CLASSES, data_list_file=data_list_file, **kwargs)
72 | 
--------------------------------------------------------------------------------
/vision/datasets/domainnet.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import os
6 | from typing import Optional
7 | from .imagelist import ImageList
8 | from ._util import download as download_data, check_exits
9 | 
10 | 
11 | class DomainNet(ImageList):
12 |     """`DomainNet `_ (cleaned version, recommended)
13 | 
14 |     See `Moment Matching for Multi-Source Domain Adaptation `_ for details.
15 | 
16 |     Args:
17 |         root (str): Root directory of dataset
18 |         task (str): The task (domain) to create dataset. Choices include ``'c'``: clipart, \
19 |             ``'i'``: infograph, ``'p'``: painting, ``'q'``: quickdraw, ``'r'``: real, ``'s'``: sketch
20 |         split (str, optional): The dataset split, supports ``train``, or ``test``.
21 |         download (bool, optional): If true, downloads the dataset from the internet and puts it \
22 |             in root directory. If dataset is already downloaded, it is not downloaded again.
23 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a \
24 |             transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
25 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
26 | 
27 |     .. note:: In `root`, there will exist following files after downloading.
28 |         ::
29 |             clipart/
30 |             infograph/
31 |             painting/
32 |             quickdraw/
33 |             real/
34 |             sketch/
35 |             image_list/
36 |                 clipart.txt
37 |                 ...
38 | """ 39 | download_list = [ 40 | ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/90ecb35bbd374e5e8c41/?dl=1"), 41 | ("clipart", "clipart.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip"), 42 | ("infograph", "infograph.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip"), 43 | ("painting", "painting.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip"), 44 | ("quickdraw", "quickdraw.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip"), 45 | ("real", "real.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip"), 46 | ("sketch", "sketch.zip", "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip"), 47 | ] 48 | image_list = { 49 | "c": "clipart", 50 | "i": "infograph", 51 | "p": "painting", 52 | "q": "quickdraw", 53 | "r": "real", 54 | "s": "sketch", 55 | } 56 | CLASSES = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil', 57 | 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat', 58 | 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 59 | 'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang', 60 | 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 61 | 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 62 | 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan', 63 | 'cello', 'cell_phone', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 64 | 'coffee_cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 65 | 'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut', 66 | 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 67 | 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant', 68 | 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower', 69 | 'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee', 70 | 'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 71 | 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital', 72 | 'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream', 73 | 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 74 | 'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 75 | 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 76 | 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 77 | 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush', 78 | 'paint_can', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut', 79 | 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow', 80 | 'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 81 | 'power_outlet', 
'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control',
82 |                'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw',
83 |                'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark',
84 |                'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag',
85 |                'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat',
86 |                'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo',
87 |                'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 'string_bean', 'submarine',
88 |                'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 'table', 'teapot', 'teddy-bear',
89 |                'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China',
90 |                'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado',
91 |                'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt',
92 |                'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide',
93 |                'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']
94 | 
95 |     def __init__(self, root: str, task: str, split: Optional[str] = 'train', download: Optional[bool] = False, **kwargs):
96 |         assert task in self.image_list
97 |         assert split in ['train', 'test']
98 |         data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))
99 |         print("loading {}".format(data_list_file))
100 | 
101 |         if download:
102 |             list(map(lambda args: download_data(root, *args), self.download_list))
103 |         else:
104 |             list(map(lambda args: check_exits(root, args[0]), self.download_list))
105 | 
106 |         super(DomainNet, self).__init__(root, DomainNet.CLASSES, data_list_file=data_list_file, **kwargs)
107 | 
108 |     @classmethod
109 |     def domains(cls):
110 |         return list(cls.image_list.keys())
111 | 
--------------------------------------------------------------------------------
/vision/datasets/imagelist.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import os
6 | from typing import Optional, Callable, Tuple, Any, List
7 | import torchvision.datasets as datasets
8 | from torchvision.datasets.folder import default_loader
9 | 
10 | 
11 | class ImageList(datasets.VisionDataset):
12 |     """A generic Dataset class for image classification
13 | 
14 |     Args:
15 |         root (str): Root directory of dataset
16 |         classes (list[str]): The names of all the classes
17 |         data_list_file (str): File to read the image list from.
18 |         transform (callable, optional): A function/transform that takes in a PIL image \
19 |             and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
20 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
21 | 
22 |     .. note:: In `data_list_file`, each line has 2 values in the following format.
23 |         ::
24 |             source_dir/dog_xxx.png 0
25 |             source_dir/cat_123.png 1
26 |             target_dir/dog_xxy.png 0
27 |             target_dir/cat_nsdf3.png 1
28 | 
29 |         The first value is the relative path of an image, and the second value is the label of the corresponding image.
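        Image paths may themselves contain spaces: everything before the last whitespace-separated
        token on a line is treated as the path, and the last token as the integer label
        (see :meth:`parse_data_file` below).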
30 |         If your data_list_file has different formats, please over-ride :meth:`~ImageList.parse_data_file`.
31 |     """
32 | 
33 |     def __init__(self, root: str, classes: List[str], data_list_file: str,
34 |                  transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
35 |         super().__init__(root, transform=transform, target_transform=target_transform)
36 |         self.samples = self.parse_data_file(data_list_file)
37 |         self.classes = classes
38 |         self.class_to_idx = {cls: idx
39 |                              for idx, cls in enumerate(self.classes)}
40 |         self.loader = default_loader
41 |         self.data_list_file = data_list_file
42 | 
43 |     def __getitem__(self, index: int) -> Tuple[Any, int]:
44 |         """
45 |         Args:
46 |             index (int): Index
47 |             return (tuple): (image, target) where target is index of the target class.
48 |         """
49 |         path, target = self.samples[index]
50 |         img = self.loader(path)
51 |         if self.transform is not None:
52 |             img = self.transform(img)
53 |         if self.target_transform is not None and target is not None:
54 |             target = self.target_transform(target)
55 |         return img, target
56 | 
57 |     def __len__(self) -> int:
58 |         return len(self.samples)
59 | 
60 |     def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
61 |         """Parse file to data list
62 | 
63 |         Args:
64 |             file_name (str): The path of data file
65 |             return (list): List of (image path, class_index) tuples
66 |         """
67 |         with open(file_name, "r") as f:
68 |             data_list = []
69 |             for line in f.readlines():
70 |                 split_line = line.split()
71 |                 target = split_line[-1]
72 |                 path = ' '.join(split_line[:-1])
73 |                 if not os.path.isabs(path):
74 |                     path = os.path.join(self.root, path)
75 |                 target = int(target)
76 |                 data_list.append((path, target))
77 |         return data_list
78 | 
79 |     @property
80 |     def num_classes(self) -> int:
81 |         """Number of classes"""
82 |         return len(self.classes)
83 | 
84 |     @classmethod
85 |     def domains(cls):
86 |         """All possible domains in this dataset"""
87 |         raise NotImplementedError
--------------------------------------------------------------------------------
/vision/datasets/office31.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | from typing import Optional
6 | import os
7 | from .imagelist import ImageList
8 | from ._util import download as download_data, check_exits
9 | 
10 | 
11 | class Office31(ImageList):
12 |     """Office31 Dataset.
13 | 
14 |     Args:
15 |         root (str): Root directory of dataset
16 |         task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \
17 |             ``'D'``: dslr and ``'W'``: webcam.
18 |         download (bool, optional): If true, downloads the dataset from the internet and puts it \
19 |             in root directory. If dataset is already downloaded, it is not downloaded again.
20 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a \
21 |             transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
22 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
23 | 
24 |     .. note:: In `root`, there will exist following files after downloading.
25 |         ::
26 |             amazon/
27 |                 images/
28 |                     backpack/
29 |                         *.jpg
30 |                         ...
31 |             dslr/
32 |             webcam/
33 |             image_list/
34 |                 amazon.txt
35 |                 dslr.txt
36 |                 webcam.txt
37 |     """
38 |     download_list = [
39 |         ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/d9bca681c71249f19da2/?dl=1"),
40 |         ("amazon", "amazon.tgz", "https://cloud.tsinghua.edu.cn/f/edc8d1bba1c740dc821c/?dl=1"),
41 |         ("dslr", "dslr.tgz", "https://cloud.tsinghua.edu.cn/f/ca6df562b7e64850ad7f/?dl=1"),
42 |         ("webcam", "webcam.tgz", "https://cloud.tsinghua.edu.cn/f/82b24ed2e08f4a3c8888/?dl=1"),
43 |     ]
44 |     image_list = {
45 |         "A": "image_list/amazon.txt",
46 |         "D": "image_list/dslr.txt",
47 |         "W": "image_list/webcam.txt"
48 |     }
49 |     CLASSES = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle', 'calculator', 'desk_chair', 'desk_lamp',
50 |                'desktop_computer', 'file_cabinet', 'headphones', 'keyboard', 'laptop_computer', 'letter_tray',
51 |                'mobile_phone', 'monitor', 'mouse', 'mug', 'paper_notebook', 'pen', 'phone', 'printer', 'projector',
52 |                'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']
53 | 
54 |     def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
55 |         assert task in self.image_list
56 |         data_list_file = os.path.join(root, self.image_list[task])
57 | 
58 |         if download:
59 |             list(map(lambda args: download_data(root, *args), self.download_list))
60 |         else:
61 |             list(map(lambda args: check_exits(root, args[0]), self.download_list))
62 | 
63 |         super(Office31, self).__init__(root, Office31.CLASSES, data_list_file=data_list_file, **kwargs)
64 | 
65 |     @classmethod
66 |     def domains(cls):
67 |         return list(cls.image_list.keys())
--------------------------------------------------------------------------------
/vision/datasets/officehome.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import os
6 | from typing import Optional
7 | from .imagelist import ImageList
8 | from ._util import download as download_data, check_exits
9 | 
10 | 
11 | class OfficeHome(ImageList):
12 |     """`OfficeHome `_ Dataset.
13 | 
14 |     Args:
15 |         root (str): Root directory of dataset
16 |         task (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art, \
17 |             ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World.
18 |         download (bool, optional): If true, downloads the dataset from the internet and puts it \
19 |             in root directory. If dataset is already downloaded, it is not downloaded again.
20 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a \
21 |             transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
22 |         target_transform (callable, optional): A function/transform that takes in the target and transforms it.
23 | 
24 |     .. note:: In `root`, there will exist following files after downloading.
25 |         ::
26 |             Art/
27 |                 Alarm_Clock/*.jpg
28 |                 ...
29 |             Clipart/
30 |             Product/
31 |             Real_World/
32 |             image_list/
33 |                 Art.txt
34 |                 Clipart.txt
35 |                 Product.txt
36 |                 Real_World.txt
37 |     """
38 |     download_list = [
39 |         ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/ca3a3b6a8d554905b4cd/?dl=1"),
40 |         ("Art", "Art.tgz", "https://cloud.tsinghua.edu.cn/f/4691878067d04755beab/?dl=1"),
41 |         ("Clipart", "Clipart.tgz", "https://cloud.tsinghua.edu.cn/f/0d41e7da4558408ea5aa/?dl=1"),
42 |         ("Product", "Product.tgz", "https://cloud.tsinghua.edu.cn/f/76186deacd7c4fa0a679/?dl=1"),
43 |         ("Real_World", "Real_World.tgz", "https://cloud.tsinghua.edu.cn/f/dee961894cc64b1da1d7/?dl=1")
44 |     ]
45 |     image_list = {
46 |         "Ar": "image_list/Art.txt",
47 |         "Cl": "image_list/Clipart.txt",
48 |         "Pr": "image_list/Product.txt",
49 |         "Rw": "image_list/Real_World.txt",
50 |     }
51 |     CLASSES = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf', 'Toys', 'Sink',
52 |                'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil', 'Bed', 'Hammer', 'ToothBrush', 'Couch',
53 |                'Bike', 'Postit_Notes', 'Mug', 'Webcam', 'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor',
54 |                'Mop', 'Sneakers', 'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries', 'Radio',
55 |                'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair',
56 |                'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV',
57 |                'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker']
58 | 
59 |     def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
60 |         assert task in self.image_list
61 |         data_list_file = os.path.join(root, self.image_list[task])
62 | 
63 |         if download:
64 |             list(map(lambda args: download_data(root, *args), self.download_list))
65 |         else:
66 |             list(map(lambda args: check_exits(root, args[0]), self.download_list))
67 | 
68 |         super(OfficeHome, self).__init__(root, OfficeHome.CLASSES, data_list_file=data_list_file, **kwargs)
69 | 
70 |     @classmethod
71 |     def domains(cls):
72 |         return list(cls.image_list.keys())
--------------------------------------------------------------------------------
/vision/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet import *
2 | from .ibn import *
3 | from .digits import *
4 | 
5 | __all__ = ['resnet', 'digits', 'ibn']
6 | 
--------------------------------------------------------------------------------
/vision/models/digits.py:
--------------------------------------------------------------------------------
1 | """
2 | @author: Junguang Jiang
3 | @contact: JiangJunguang1123@outlook.com
4 | """
5 | import torch.nn as nn
6 | 
7 | 
8 | class LeNet(nn.Sequential):
9 |     def __init__(self, num_classes=10):
10 |         super(LeNet, self).__init__(
11 |             nn.Conv2d(1, 20, kernel_size=5),
12 |             nn.MaxPool2d(2),
13 |             nn.ReLU(),
14 |             nn.Conv2d(20, 50, kernel_size=5),
15 |             nn.Dropout2d(p=0.5),
16 |             nn.MaxPool2d(2),
17 |             nn.ReLU(),
18 |             nn.Flatten(start_dim=1),
19 |             nn.Linear(50 * 4 * 4, 500),
20 |             nn.ReLU(),
21 |             nn.Dropout(p=0.5),
22 |         )
23 |         self.num_classes = num_classes
24 |         self.out_features = 500
25 | 
26 |     def copy_head(self):
27 |         return nn.Linear(500, self.num_classes)
28 | 
29 | 
30 | class DTN(nn.Sequential):
31 |     def __init__(self, num_classes=10):
32 |         super(DTN, self).__init__(
33 |             nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
34 |             nn.BatchNorm2d(64),
35 |             nn.Dropout2d(0.1),
36 |             nn.ReLU(),
37 | 
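            # each 5x5 convolution with stride 2 and padding 2 halves the spatial
            # resolution of the required 32x32 input (32 -> 16 -> 8 -> 4), which is
            # why the flattened feature size below is 256 * 4 * 4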
nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), 38 | nn.BatchNorm2d(128), 39 | nn.Dropout2d(0.3), 40 | nn.ReLU(), 41 | nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), 42 | nn.BatchNorm2d(256), 43 | nn.Dropout2d(0.5), 44 | nn.ReLU(), 45 | nn.Flatten(start_dim=1), 46 | nn.Linear(256 * 4 * 4, 512), 47 | nn.BatchNorm1d(512), 48 | nn.ReLU(), 49 | nn.Dropout(), 50 | ) 51 | self.num_classes = num_classes 52 | self.out_features = 512 53 | 54 | def copy_head(self): 55 | return nn.Linear(512, self.num_classes) 56 | 57 | 58 | 59 | def lenet(pretrained=False, **kwargs): 60 | """LeNet model from 61 | `"Gradient-based learning applied to document recognition" `_ 62 | 63 | Args: 64 | num_classes (int): number of classes. Default: 10 65 | 66 | .. note:: 67 | The input image size must be 28 x 28. 68 | 69 | """ 70 | return LeNet(**kwargs) 71 | 72 | 73 | def dtn(pretrained=False, **kwargs): 74 | """ DTN model 75 | 76 | Args: 77 | num_classes (int): number of classes. Default: 10 78 | 79 | .. note:: 80 | The input image size must be 32 x 32. 81 | 82 | """ 83 | return DTN(**kwargs) -------------------------------------------------------------------------------- /vision/models/ibn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/XingangPan/IBN-Net 3 | @author: Baixu Chen 4 | @contact: cbx_99_hasta@outlook.com 5 | """ 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | 10 | __all__ = ['resnet18_ibn_a', 'resnet18_ibn_b', 'resnet34_ibn_a', 'resnet34_ibn_b', 'resnet50_ibn_a', 'resnet50_ibn_b', 11 | 'resnet101_ibn_a', 'resnet101_ibn_b'] 12 | 13 | model_urls = { 14 | 'resnet18_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth', 15 | 'resnet34_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth', 16 | 'resnet50_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth', 17 | 'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth', 18 | 'resnet18_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_b-bc2f3c11.pth', 19 | 'resnet34_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_b-04134c37.pth', 20 | 'resnet50_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_b-9ca61e85.pth', 21 | 'resnet101_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_b-c55f6dba.pth', 22 | } 23 | 24 | 25 | class IBN(nn.Module): 26 | r"""Instance-Batch Normalization layer from 27 | `Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (ECCV 2018) 28 | `_. 29 | 30 | Given input feature map :math:`f\_input` of dimension :math:`(C,H,W)`, we first split :math:`f\_input` into 31 | two parts along `channel` dimension. They are denoted as :math:`f_1` of dimension :math:`(C_1,H,W)` and 32 | :math:`f_2` of dimension :math:`(C_2,H,W)`, where :math:`C_1+C_2=C`. Then we pass :math:`f_1` and :math:`f_2` 33 | through IN and BN layer, respectively, to get :math:`IN(f_1)` and :math:`BN(f_2)`. Last, we concat them along 34 | `channel` dimension to create :math:`f\_output=concat(IN(f_1), BN(f_2))`. 
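    For example, with :math:`C=64` input channels and the default ``ratio=0.5``,
    :math:`f_1` and :math:`f_2` each contain 32 channels.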
35 | 36 | Args: 37 | planes (int): Number of channels for the input tensor 38 | ratio (float): Ratio of instance normalization in the IBN layer 39 | """ 40 | 41 | def __init__(self, planes, ratio=0.5): 42 | super(IBN, self).__init__() 43 | self.half = int(planes * ratio) 44 | self.IN = nn.InstanceNorm2d(self.half, affine=True) 45 | self.BN = nn.BatchNorm2d(planes - self.half) 46 | 47 | def forward(self, x): 48 | split = torch.split(x, self.half, 1) 49 | out1 = self.IN(split[0].contiguous()) 50 | out2 = self.BN(split[1].contiguous()) 51 | out = torch.cat((out1, out2), 1) 52 | return out 53 | 54 | 55 | class BasicBlock_IBN(nn.Module): 56 | expansion = 1 57 | 58 | def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None): 59 | super(BasicBlock_IBN, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 61 | padding=1, bias=False) 62 | if ibn == 'a': 63 | self.bn1 = IBN(planes) 64 | else: 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.relu = nn.ReLU(inplace=True) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) 68 | self.bn2 = nn.BatchNorm2d(planes) 69 | self.IN = nn.InstanceNorm2d(planes, affine=True) if ibn == 'b' else None 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | out += residual 87 | if self.IN is not None: 88 | out = self.IN(out) 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Bottleneck_IBN(nn.Module): 95 | expansion = 4 96 | 97 | def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None): 98 | super(Bottleneck_IBN, self).__init__() 99 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 100 | if ibn == 'a': 101 | self.bn1 = IBN(planes) 102 | else: 103 | self.bn1 = nn.BatchNorm2d(planes) 104 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 105 | padding=1, bias=False) 106 | self.bn2 = nn.BatchNorm2d(planes) 107 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 108 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 109 | self.IN = nn.InstanceNorm2d(planes * 4, affine=True) if ibn == 'b' else None 110 | self.relu = nn.ReLU(inplace=True) 111 | self.downsample = downsample 112 | self.stride = stride 113 | 114 | def forward(self, x): 115 | residual = x 116 | 117 | out = self.conv1(x) 118 | out = self.bn1(out) 119 | out = self.relu(out) 120 | 121 | out = self.conv2(out) 122 | out = self.bn2(out) 123 | out = self.relu(out) 124 | 125 | out = self.conv3(out) 126 | out = self.bn3(out) 127 | 128 | if self.downsample is not None: 129 | residual = self.downsample(x) 130 | 131 | out += residual 132 | if self.IN is not None: 133 | out = self.IN(out) 134 | out = self.relu(out) 135 | 136 | return out 137 | 138 | 139 | class ResNet_IBN(nn.Module): 140 | r""" 141 | ResNets-IBN without fully connected layer 142 | """ 143 | 144 | def __init__(self, block, layers, ibn_cfg=('a', 'a', 'a', None)): 145 | self.inplanes = 64 146 | super(ResNet_IBN, self).__init__() 147 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 148 | bias=False) 149 | if ibn_cfg[0] == 'b': 150 | self.bn1 = nn.InstanceNorm2d(64, affine=True) 151 | else: 152 | self.bn1 = nn.BatchNorm2d(64) 153 | self.relu = nn.ReLU(inplace=True) 154 | 
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 155 | self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0]) 156 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1]) 157 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2]) 158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3]) 159 | self._out_features = 512 * block.expansion 160 | 161 | for m in self.modules(): 162 | if isinstance(m, nn.Conv2d): 163 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 164 | m.weight.data.normal_(0, math.sqrt(2. / n)) 165 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): 166 | m.weight.data.fill_(1) 167 | m.bias.data.zero_() 168 | 169 | def _make_layer(self, block, planes, blocks, stride=1, ibn=None): 170 | downsample = None 171 | if stride != 1 or self.inplanes != planes * block.expansion: 172 | downsample = nn.Sequential( 173 | nn.Conv2d(self.inplanes, planes * block.expansion, 174 | kernel_size=1, stride=stride, bias=False), 175 | nn.BatchNorm2d(planes * block.expansion), 176 | ) 177 | 178 | layers = [] 179 | layers.append(block(self.inplanes, planes, 180 | None if ibn == 'b' else ibn, 181 | stride, downsample)) 182 | self.inplanes = planes * block.expansion 183 | for i in range(1, blocks): 184 | layers.append(block(self.inplanes, planes, 185 | None if (ibn == 'b' and i < blocks - 1) else ibn)) 186 | 187 | return nn.Sequential(*layers) 188 | 189 | def forward(self, x): 190 | """""" 191 | x = self.conv1(x) 192 | x = self.bn1(x) 193 | x = self.relu(x) 194 | x = self.maxpool(x) 195 | 196 | x = self.layer1(x) 197 | x = self.layer2(x) 198 | x = self.layer3(x) 199 | x = self.layer4(x) 200 | 201 | return x 202 | 203 | @property 204 | def out_features(self) -> int: 205 | """The dimension of output features""" 206 | return self._out_features 207 | 208 | 209 | def resnet18_ibn_a(pretrained=False): 210 | """Constructs a ResNet-18-IBN-a model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet_IBN(block=BasicBlock_IBN, 216 | layers=[2, 2, 2, 2], 217 | ibn_cfg=('a', 'a', 'a', None)) 218 | if pretrained: 219 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet18_ibn_a']), strict=False) 220 | return model 221 | 222 | 223 | def resnet34_ibn_a(pretrained=False): 224 | """Constructs a ResNet-34-IBN-a model. 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = ResNet_IBN(block=BasicBlock_IBN, 230 | layers=[3, 4, 6, 3], 231 | ibn_cfg=('a', 'a', 'a', None)) 232 | if pretrained: 233 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet34_ibn_a']), strict=False) 234 | return model 235 | 236 | 237 | def resnet50_ibn_a(pretrained=False): 238 | """Constructs a ResNet-50-IBN-a model. 239 | 240 | Args: 241 | pretrained (bool): If True, returns a model pre-trained on ImageNet 242 | """ 243 | model = ResNet_IBN(block=Bottleneck_IBN, 244 | layers=[3, 4, 6, 3], 245 | ibn_cfg=('a', 'a', 'a', None)) 246 | if pretrained: 247 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_a']), strict=False) 248 | return model 249 | 250 | 251 | def resnet101_ibn_a(pretrained=False): 252 | """Constructs a ResNet-101-IBN-a model. 
253 | 254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | """ 257 | model = ResNet_IBN(block=Bottleneck_IBN, 258 | layers=[3, 4, 23, 3], 259 | ibn_cfg=('a', 'a', 'a', None)) 260 | if pretrained: 261 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet101_ibn_a']), strict=False) 262 | return model 263 | 264 | 265 | def resnet18_ibn_b(pretrained=False): 266 | """Constructs a ResNet-18-IBN-b model. 267 | 268 | Args: 269 | pretrained (bool): If True, returns a model pre-trained on ImageNet 270 | """ 271 | model = ResNet_IBN(block=BasicBlock_IBN, 272 | layers=[2, 2, 2, 2], 273 | ibn_cfg=('b', 'b', None, None)) 274 | if pretrained: 275 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet18_ibn_b']), strict=False) 276 | return model 277 | 278 | 279 | def resnet34_ibn_b(pretrained=False): 280 | """Constructs a ResNet-34-IBN-b model. 281 | 282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | """ 285 | model = ResNet_IBN(block=BasicBlock_IBN, 286 | layers=[3, 4, 6, 3], 287 | ibn_cfg=('b', 'b', None, None)) 288 | if pretrained: 289 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet34_ibn_b']), strict=False) 290 | return model 291 | 292 | 293 | def resnet50_ibn_b(pretrained=False): 294 | """Constructs a ResNet-50-IBN-b model. 295 | 296 | Args: 297 | pretrained (bool): If True, returns a model pre-trained on ImageNet 298 | """ 299 | model = ResNet_IBN(block=Bottleneck_IBN, 300 | layers=[3, 4, 6, 3], 301 | ibn_cfg=('b', 'b', None, None)) 302 | if pretrained: 303 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_b']), strict=False) 304 | return model 305 | 306 | 307 | def resnet101_ibn_b(pretrained=False): 308 | """Constructs a ResNet-101-IBN-b model. 309 | 310 | Args: 311 | pretrained (bool): If True, returns a model pre-trained on ImageNet 312 | """ 313 | model = ResNet_IBN(block=Bottleneck_IBN, 314 | layers=[3, 4, 23, 3], 315 | ibn_cfg=('b', 'b', None, None)) 316 | if pretrained: 317 | model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet101_ibn_b']), strict=False) 318 | return model 319 | -------------------------------------------------------------------------------- /vision/models/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified based on torchvision.models.resnet. 
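The ResNet variants below keep only the convolutional trunk: ``forward`` stops
before average pooling and the fully connected layer and returns spatial feature
maps, while ``copy_head`` exposes a copy of the original classification head.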
3 | @author: Junguang Jiang
4 | @contact: JiangJunguang1123@outlook.com
5 | """
6 | 
7 | import torch.nn as nn
8 | from torchvision import models
9 | from torchvision.models.utils import load_state_dict_from_url
10 | from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
11 | import copy
12 | 
13 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
14 |            'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
15 |            'wide_resnet50_2', 'wide_resnet101_2']
16 | 
17 | 
18 | class ResNet(models.ResNet):
19 |     """ResNets without fully connected layer"""
20 | 
21 |     def __init__(self, *args, **kwargs):
22 |         super(ResNet, self).__init__(*args, **kwargs)
23 |         self._out_features = self.fc.in_features
24 | 
25 |     def forward(self, x):
26 |         """"""
27 |         x = self.conv1(x)
28 |         x = self.bn1(x)
29 |         x = self.relu(x)
30 |         x = self.maxpool(x)
31 | 
32 |         x = self.layer1(x)
33 |         x = self.layer2(x)
34 |         x = self.layer3(x)
35 |         x = self.layer4(x)
36 | 
37 |         # x = self.avgpool(x)
38 |         # x = torch.flatten(x, 1)
39 |         # x = x.view(-1, self._out_features)
40 |         return x
41 | 
42 |     @property
43 |     def out_features(self) -> int:
44 |         """The dimension of output features"""
45 |         return self._out_features
46 | 
47 |     def copy_head(self) -> nn.Module:
48 |         """Copy the original fully connected layer"""
49 |         return copy.deepcopy(self.fc)
50 | 
51 | 
52 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
53 |     model = ResNet(block, layers, **kwargs)
54 |     if pretrained:
55 |         model_dict = model.state_dict()
56 |         pretrained_dict = load_state_dict_from_url(model_urls[arch],
57 |                                                    progress=progress)
58 |         # remove keys from the pretrained dict that don't appear in the model dict (e.g., the fc head)
59 |         pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
60 |         model.load_state_dict(pretrained_dict, strict=False)
61 |     return model
62 | 
63 | 
64 | def resnet18(pretrained=False, progress=True, **kwargs):
65 |     r"""ResNet-18 model from
66 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
67 | 
68 |     Args:
69 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
70 |         progress (bool): If True, displays a progress bar of the download to stderr
71 |     """
72 |     return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
73 |                    **kwargs)
74 | 
75 | 
76 | def resnet34(pretrained=False, progress=True, **kwargs):
77 |     r"""ResNet-34 model from
78 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
79 | 
80 |     Args:
81 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
82 |         progress (bool): If True, displays a progress bar of the download to stderr
83 |     """
84 |     return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
85 |                    **kwargs)
86 | 
87 | 
88 | def resnet50(pretrained=False, progress=True, **kwargs):
89 |     r"""ResNet-50 model from
90 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
91 | 
92 |     Args:
93 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
94 |         progress (bool): If True, displays a progress bar of the download to stderr
95 |     """
96 |     return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
97 |                    **kwargs)
98 | 
99 | 
100 | def resnet101(pretrained=False, progress=True, **kwargs):
101 |     r"""ResNet-101 model from
102 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
103 | 
104 |     Args:
105 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
106 |         progress (bool): If True, displays a progress bar of the download to stderr
107 |     """
108 |     return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
109 |                    **kwargs)
110 | 
111 | 
112 | def resnet152(pretrained=False, progress=True, **kwargs):
113 |     r"""ResNet-152 model from
114 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
115 | 
116 |     Args:
117 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
118 |         progress (bool): If True, displays a progress bar of the download to stderr
119 |     """
120 |     return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
121 |                    **kwargs)
122 | 
123 | 
124 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
125 |     r"""ResNeXt-50 32x4d model from
126 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
127 | 
128 |     Args:
129 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
130 |         progress (bool): If True, displays a progress bar of the download to stderr
131 |     """
132 |     kwargs['groups'] = 32
133 |     kwargs['width_per_group'] = 4
134 |     return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
135 |                    pretrained, progress, **kwargs)
136 | 
137 | 
138 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
139 |     r"""ResNeXt-101 32x8d model from
140 |     `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
141 | 
142 |     Args:
143 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
144 |         progress (bool): If True, displays a progress bar of the download to stderr
145 |     """
146 |     kwargs['groups'] = 32
147 |     kwargs['width_per_group'] = 8
148 |     return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
149 |                    pretrained, progress, **kwargs)
150 | 
151 | 
152 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
153 |     r"""Wide ResNet-50-2 model from
154 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
155 | 
156 |     The model is the same as ResNet except that the number of bottleneck channels
157 |     is twice as large in every block. The number of channels in outer 1x1
158 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
159 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
160 | 
161 |     Args:
162 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
163 |         progress (bool): If True, displays a progress bar of the download to stderr
164 |     """
165 |     kwargs['width_per_group'] = 64 * 2
166 |     return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
167 |                    pretrained, progress, **kwargs)
168 | 
169 | 
170 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
171 |     r"""Wide ResNet-101-2 model from
172 |     `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
173 | 
174 |     The model is the same as ResNet except that the number of bottleneck channels
175 |     is twice as large in every block. The number of channels in outer 1x1
176 |     convolutions is the same, e.g. the last block in ResNet-50 has 2048-512-2048
177 |     channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
178 | 
179 |     Args:
180 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
181 |         progress (bool): If True, displays a progress bar of the download to stderr
182 |     """
183 |     kwargs['width_per_group'] = 64 * 2
184 |     return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
185 |                    pretrained, progress, **kwargs)
186 | 
--------------------------------------------------------------------------------
/vision/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | from PIL import Image
4 | import numpy as np
5 | import torch
6 | from torchvision.transforms import Normalize
7 | 
8 | 
9 | class ResizeImage(object):
10 |     """Resize the input PIL Image to the given size.
11 | 
12 |     Args:
13 |         size (sequence or int): Desired output size. If size is a sequence like
14 |             (h, w), output size will be matched to this. If size is an int,
15 |             output size will be (size, size)
16 |     """
17 | 
18 |     def __init__(self, size):
19 |         if isinstance(size, int):
20 |             self.size = (int(size), int(size))
21 |         else:
22 |             self.size = size
23 | 
24 |     def __call__(self, img):
25 |         th, tw = self.size
26 |         return img.resize((tw, th))  # PIL's resize expects (width, height); size is stored as (h, w)
27 | 
28 |     def __repr__(self):
29 |         return self.__class__.__name__ + '(size={0})'.format(self.size)
30 | 
31 | 
32 | class MultipleApply:
33 |     """Apply a list of transformations to an image and get multiple transformed images.
34 | 
35 |     Args:
36 |         transforms (list or tuple): list of transformations
37 | 
38 |     Example:
39 | 
40 |         >>> transform1 = T.Compose([
41 |         ...     ResizeImage(256),
42 |         ...     T.RandomCrop(224)
43 |         ... ])
44 |         >>> transform2 = T.Compose([
45 |         ...     ResizeImage(256),
46 |         ...     T.RandomCrop(224),
47 |         ... ])
48 |         >>> multiple_transform = MultipleApply([transform1, transform2])
49 |     """
50 | 
51 |     def __init__(self, transforms):
52 |         self.transforms = transforms
53 | 
54 |     def __call__(self, image):
55 |         return [t(image) for t in self.transforms]
56 | 
57 |     def __repr__(self):
58 |         format_string = self.__class__.__name__ + '('
59 |         for t in self.transforms:
60 |             format_string += '\n'
61 |             format_string += '    {0}'.format(t)
62 |         format_string += '\n)'
63 |         return format_string
64 | 
65 | 
66 | class Denormalize(Normalize):
67 |     """Denormalize a tensor image with mean and standard deviation.
68 |     Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
69 |     channels, this transform will denormalize each channel of the input
70 |     ``torch.*Tensor`` i.e.,
71 |     ``output[channel] = input[channel] * std[channel] + mean[channel]``
72 | 
73 |     .. note::
74 |         This transform acts out of place, i.e., it does not mutate the input tensor.
75 | 
76 |     Args:
77 |         mean (sequence): Sequence of means for each channel.
78 |         std (sequence): Sequence of standard deviations for each channel.
79 | 
80 |     """
81 | 
82 |     def __init__(self, mean, std):
83 |         mean = np.array(mean)
84 |         std = np.array(std)
85 |         super().__init__((-mean / std).tolist(), (1 / std).tolist())
86 | 
87 | 
88 | class NormalizeAndTranspose:
89 |     """
90 |     First, convert the image from RGB to BGR and subtract the per-channel mean (no division by std).
91 |     Then, convert the shape (H x W x C) to shape (C x H x W).
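    For instance (an illustrative sketch; the default mean below is the Caffe-style
    ImageNet BGR channel mean):

    >>> t = NormalizeAndTranspose()
    >>> out = t(Image.new('RGB', (224, 224)))  # numpy array of shape (3, 224, 224)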
92 | """ 93 | 94 | def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)): 95 | self.mean = np.array(mean, dtype=np.float32) 96 | 97 | def __call__(self, image): 98 | if isinstance(image, Image.Image): 99 | image = np.asarray(image, np.float32) 100 | # change to BGR 101 | image = image[:, :, ::-1] 102 | # normalize 103 | image -= self.mean 104 | image = image.transpose((2, 0, 1)).copy() 105 | elif isinstance(image, torch.Tensor): 106 | # change to BGR 107 | image = image[:, :, [2, 1, 0]] 108 | # normalize 109 | image -= torch.from_numpy(self.mean).to(image.device) 110 | image = image.permute((2, 0, 1)) 111 | else: 112 | raise NotImplementedError(type(image)) 113 | return image 114 | 115 | 116 | class DeNormalizeAndTranspose: 117 | """ 118 | First, convert a tensor image from the shape (C x H x W ) to shape (H x W x C). 119 | Then, denormalize it with mean and standard deviation. 120 | """ 121 | 122 | def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)): 123 | self.mean = np.array(mean, dtype=np.float32) 124 | 125 | def __call__(self, image): 126 | image = image.transpose((1, 2, 0)) 127 | # denormalize 128 | image += self.mean 129 | # change to RGB 130 | image = image[:, :, ::-1] 131 | return image 132 | 133 | 134 | class RandomErasing(object): 135 | """Random erasing augmentation from `Random Erasing Data Augmentation (CVPR 2017) 136 | `_. This augmentation randomly selects a rectangle region in an image 137 | and erases its pixels. 138 | 139 | Args: 140 | probability (float): The probability that the Random Erasing operation will be performed. 141 | sl (float): Minimum proportion of erased area against input image. 142 | sh (float): Maximum proportion of erased area against input image. 143 | r1 (float): Minimum aspect ratio of erased area. 144 | mean (sequence): Value to fill the erased area. 145 | """ 146 | 147 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 148 | self.probability = probability 149 | self.mean = mean 150 | self.sl = sl 151 | self.sh = sh 152 | self.r1 = r1 153 | 154 | def __call__(self, img): 155 | 156 | if random.uniform(0, 1) >= self.probability: 157 | return img 158 | 159 | for attempt in range(100): 160 | area = img.size()[1] * img.size()[2] 161 | 162 | target_area = random.uniform(self.sl, self.sh) * area 163 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 164 | 165 | h = int(round(math.sqrt(target_area * aspect_ratio))) 166 | w = int(round(math.sqrt(target_area / aspect_ratio))) 167 | 168 | if w < img.size()[2] and h < img.size()[1]: 169 | x1 = random.randint(0, img.size()[1] - h) 170 | y1 = random.randint(0, img.size()[2] - w) 171 | if img.size()[0] == 3: 172 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 173 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 174 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 175 | else: 176 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 177 | return img 178 | 179 | return img 180 | 181 | def __repr__(self): 182 | return self.__class__.__name__ + '(p={})'.format(self.probability) 183 | --------------------------------------------------------------------------------