├── LICENSE ├── README.md ├── datasets.py ├── engine.py ├── hubconf.py ├── losses.py ├── main.py ├── requirements.txt ├── samplers.py ├── sima.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 UCDvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple Softmax-free Attention for Vision Transformer (SimA) 2 | 3 | 4 | Official PyTorch implementation and pretrained models for SimA models. 
([arXiv](https://arxiv.org/abs/2206.08898)) 5 | 6 | --- 7 | 8 | ``` 9 | @misc{https://doi.org/10.48550/arxiv.2206.08898, 10 | doi = {10.48550/ARXIV.2206.08898}, 11 | url = {https://arxiv.org/abs/2206.08898}, 12 | author = {Koohpayegani, Soroush Abbasi and Pirsiavash, Hamed}, 13 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences}, 14 | title = {SimA: Simple Softmax-free Attention for Vision Transformers}, 15 | publisher = {arXiv}, 16 | year = {2022}, 17 | copyright = {arXiv.org perpetual, non-exclusive license} 18 | } 19 | ``` 20 | 21 | # Getting Started 22 | 23 | 24 | Install the required packages, including [PyTorch](https://pytorch.org/) version 1.8.0, [torchvision](https://pytorch.org/vision/stable/index.html) version 0.9.0 and [timm](https://github.com/rwightman/pytorch-image-models) version 0.4.9 (the versions pinned in `requirements.txt`): 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | Download and extract the [ImageNet](https://imagenet.stanford.edu/) dataset. Afterwards, set the ```--data-path``` argument to the corresponding extracted ImageNet path. 30 | 31 | 32 | 33 | 34 | 35 | ### Training 36 | 37 | To train on 8 GPUs, use the following command: 38 | 39 | ``` 40 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --model sima_small_12_p16 --epochs 400 --batch-size 128 --drop-path 0.05 --output_dir [OUTPUT_PATH] --data-path [DATA_PATH] 41 | ``` 42 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
class INatDataset(ImageFolder):
    """iNaturalist dataset whose sample list is read from the JSON annotation files.

    The files ``{train|val}{year}.json`` and ``categories.json`` are read from
    ``root``. Class indices are always derived from the *train* annotations
    (in order of first appearance) so that the train and val splits share the
    same label mapping.

    Args:
        root: directory containing the JSON files and the image folders.
        train: load the train split if true, otherwise the val split.
        year: dataset year used in the annotation file names.
        transform: transform applied to each loaded image.
        target_transform: transform applied to each target.
        category: key of ``categories.json`` entries used as the label
            (e.g. ``'name'``, ``'genus'``, ``'family'``).
        loader: callable that loads an image from a path.
    """

    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        # ImageFolder.__init__ is not called: the sample list comes from the
        # JSON files instead of a directory scan. ``root`` is still recorded
        # because inherited helpers (e.g. the dataset repr) are expected to
        # read it -- NOTE(review): confirm against the installed torchvision.
        self.root = root
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        if train:
            # the split being loaded *is* the train split; no need to parse it twice
            data_for_targeter = data
        else:
            path_json_for_targeter = os.path.join(root, f"train{year}.json")
            with open(path_json_for_targeter) as json_file:
                data_for_targeter = json.load(json_file)

        # map each label value to a contiguous index, in order of first appearance
        targeter = {}
        for elem in data_for_targeter['annotations']:
            label = data_catg[int(elem['category_id'])][category]
            targeter.setdefault(label, len(targeter))
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            # file_name is split on '/'; component 2 is the integer category id
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


# ImageFolder-style datasets: name -> (evaluation sub-directory, number of classes)
_IMAGE_FOLDER_DATASETS = {
    'IMNET': ('val', 1000),
    'CARS': ('test', 196),
    'FLOWERS': ('test', 102),
}

# iNaturalist variants: name -> dataset year
_INAT_YEARS = {'INAT': 2018, 'INAT19': 2019}


def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set``.

    Args:
        is_train: build the training split if true, otherwise the evaluation split.
        args: namespace providing ``data_set``, ``data_path``, ``inat_category``
            and the attributes read by ``build_transform``.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: if ``args.data_set`` is not one of the handled datasets.
            Previously this fell through to the ``return`` and failed with an
            unbound local variable (e.g. for the CLI choice ``'IMNET22k'``).
    """
    transform = build_transform(is_train, args)
    name = args.data_set

    if name == 'CIFAR100':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif name == 'CIFAR10':
        dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform)
        nb_classes = 10
    elif name in _IMAGE_FOLDER_DATASETS:
        eval_dir, nb_classes = _IMAGE_FOLDER_DATASETS[name]
        root = os.path.join(args.data_path, 'train' if is_train else eval_dir)
        dataset = datasets.ImageFolder(root, transform=transform)
    elif name in _INAT_YEARS:
        dataset = INatDataset(args.data_path, train=is_train, year=_INAT_YEARS[name],
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:
        raise ValueError(f"Unsupported dataset: {name!r}")

    return dataset, nb_classes


def build_transform(is_train, args):
    """Build the image transform pipeline for training or evaluation.

    Args:
        is_train: build the (augmented) training pipeline if true.
        args: namespace providing ``input_size``, ``full_crop`` and, for
            training, ``color_jitter``, ``aa``, ``train_interpolation``,
            ``reprob``, ``remode`` and ``recount``.

    Returns:
        The transform to apply to each image.
    """
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    if args.full_crop:
        # crop ratio 1.0: resize straight to the input size, then center-crop
        return transforms.Compose([
            transforms.Resize(args.input_size, interpolation=3),  # 3 is the PIL bicubic code
            transforms.CenterCrop(args.input_size),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
        ])

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, surgery=None):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: the network to train (optionally wrapped in DistributedDataParallel).
        criterion: loss called as ``criterion(samples, outputs, targets)``.
        data_loader: iterable of batches; element 0 is the input, element 1 the target.
        optimizer: optimizer whose gradients are reset before every step.
        device: device the batches are moved to.
        epoch: epoch index, used only for the log header.
        loss_scaler: callable performing backward + optimizer step
            (called with ``clip_grad``, ``parameters`` and ``create_graph``).
        max_norm: gradient clipping value forwarded to ``loss_scaler``.
        model_ema: optional EMA wrapper updated after every step.
        mixup_fn: optional callable applied to ``(samples, targets)``.
        set_training_mode: value passed to ``model.train``.
        surgery: if truthy, the patch embedding is kept in eval mode.

    Returns:
        Dict mapping each meter name to its global average for the epoch.
    """
    model.train(set_training_mode)

    if surgery:
        # The patch embedding is frozen by the caller; keep it in eval mode.
        # Unwrap DDP only when the wrapper is actually present so that
        # single-process runs (no ``.module`` attribute) work too.
        getattr(model, 'module', model).patch_embed.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 100

    for batch in metric_logger.log_every(data_loader, print_freq, header):
        samples, targets = batch[0], batch[1]

        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast():
            outputs = model(samples)

            loss = criterion(samples, outputs, targets)

        loss_value = loss.item()

        # abort instead of continuing to train on NaN/inf losses
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=is_second_order)

        # synchronize() raises on machines without CUDA; skip it there
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device):
    """Evaluate ``model`` on ``data_loader`` with cross-entropy and top-1/top-5 accuracy.

    Args:
        data_loader: iterable of batches; element 0 is the images, element 1 the targets.
        model: the network to evaluate; it is switched to eval mode.
        device: device the batches are moved to.

    Returns:
        Dict mapping each meter name (``loss``, ``acc1``, ``acc5``) to its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for batch in metric_logger.log_every(data_loader, 100, header):
        images, target = batch[0], batch[1]

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images)

            loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # accuracies are weighted by batch size so a smaller last batch is averaged correctly
        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.

    Args:
        base_criterion: loss applied to the model's main output and the labels.
        teacher_model: model producing the teacher predictions; unused when
            ``distillation_type`` is ``'none'``.
        distillation_type: one of ``'none'``, ``'soft'`` or ``'hard'``.
        alpha: weight of the distillation term; the base loss gets ``1 - alpha``.
        tau: softmax temperature used by the ``'soft'`` distillation loss.

    Raises:
        ValueError: if ``distillation_type`` is not a supported value.
    """
    _DISTILLATION_TYPES = ('none', 'soft', 'hard')

    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O`` and an invalid type would then only surface later in
        # forward() as an unbound ``distillation_loss``.
        if distillation_type not in self._DISTILLATION_TYPES:
            raise ValueError(
                f"distillation_type must be one of {self._DISTILLATION_TYPES}, "
                f"got {distillation_type!r}")
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are feed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion

        Returns:
            The base loss when distillation is disabled, otherwise
            ``base_loss * (1 - alpha) + distillation_loss * alpha``.

        Raises:
            ValueError: if distillation is enabled but ``outputs`` is a single Tensor.
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs

        if isinstance(outputs, tuple):
            # the main output is itself a tuple: the criterion also receives the inputs
            base_loss = self.base_criterion(outputs, labels, inputs)
        else:
            base_loss = self.base_criterion(outputs, labels)

        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop throught the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                F.log_softmax(teacher_outputs / T, dim=1),
                reduction='sum',
                log_target=True
            ) * (T * T) / outputs_kd.numel()
        else:
            # 'hard' (the only remaining validated type): the teacher's argmax is the target
            distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

        loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
        return loss
def get_args_parser():
    """Build the command-line parser for training and evaluation.

    The parser is created with ``add_help=False`` so it can be used as a
    parent of the top-level parser.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser('SimA training and evaluation script', add_help=False)
    parser.add_argument('--batch-size', default=64, type=int)
    parser.add_argument('--epochs', default=400, type=int)

    # Model parameters
    parser.add_argument('--model', default='sima_s_12', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=0.05,
                        help='weight decay (default: 0.05)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
                        help='learning rate (default: 5e-4)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
                        help='warmup learning rate (default: 1e-6)')
    parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')

    parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=True)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
                        help='Random erase prob (default: 0.25)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0.8,
                        help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
    parser.add_argument('--cutmix', type=float, default=1.0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Distillation parameters
    parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
                        help='Name of teacher model to train (default: "regnety_160")')
    parser.add_argument('--teacher-path', type=str, default='')
    parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
    parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
    parser.add_argument('--distillation-tau', default=1.0, type=float, help="")

    # Dataset parameters
    parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
                        help='dataset path')
    parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET',
                                                                'INAT', 'INAT19', 'CARS', 'FLOWERS',
                                                                'IMNET22k'],
                        type=str, help='dataset to use')
    parser.add_argument('--inat-category', default='name',
                        choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
                        type=str, help='semantic granularity')

    parser.add_argument('--output_dir', default='',
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
    parser.add_argument('--num_workers', default=10, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--test-freq', default=1, type=int,
                        help='Number of epochs between validation runs.')

    parser.add_argument('--full_crop', action='store_true',
                        help='use crop_ratio=1.0 instead of the default 0.875 (Used by CaiT).')
    parser.add_argument("--pretrained", default=None, type=str, help='Path to pre-trained checkpoint')
    parser.add_argument('--surgery', default=None, type=str,
                        help='Path to checkpoint to copy the patch projection from. '
                             'Can improve stability for very large models.')

    return parser


def _checkpoint_state(model, optimizer, lr_scheduler, epoch, model_ema, loss_scaler, args):
    """Collect everything needed to resume training into one serialisable dict."""
    return {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
        # with --no-model-ema there is no EMA model to serialise
        'model_ema': get_state_dict(model_ema) if model_ema is not None else None,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }


def main(args):
    """Run training (or evaluation only, with ``--eval``) as configured by ``args``.

    Args:
        args: namespace produced by the parser from ``get_args_parser``.
            ``args.nb_classes``, ``args.lr`` and ``args.start_epoch`` are
            updated in place.
    """
    utils.init_distributed_mode(args)

    print(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
    dataset_val, _ = build_dataset(is_train=False, args=args)

    # The distributed-style samplers are used unconditionally (also for
    # single-process runs), exactly as before; the former ``if True:`` wrapper
    # and its unreachable ``else`` branch were removed.
    num_tasks = utils.get_world_size()
    global_rank = utils.get_rank()
    if args.repeated_aug:
        sampler_train = RASampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    else:
        sampler_train = torch.utils.data.DistributedSampler(
            dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
        )
    if args.dist_eval:
        if len(dataset_val) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')
        sampler_val = torch.utils.data.DistributedSampler(
            dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
    else:
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)

    data_loader_train = torch.utils.data.DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )

    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=int(1.5 * args.batch_size),
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False
    )

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.nb_classes)

    print(f"Creating model: {args.model}")

    model = create_model(
        args.model,
        pretrained=False,
        num_classes=args.nb_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=None
    )

    if args.pretrained:
        if args.pretrained.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.pretrained, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.pretrained, map_location='cpu')

        checkpoint_model = checkpoint['model']
        state_dict = model.state_dict()
        head_dropped = False
        for k in ['head.weight', 'head.bias']:
            if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]
                head_dropped = True

        # A strict load cannot succeed once the classifier keys were removed,
        # so relax it only in that case.
        model.load_state_dict(checkpoint_model, strict=not head_dropped)

    model.to(device)

    if args.surgery:
        # copy the patch projection from another checkpoint and freeze it
        checkpoint = torch.load(args.surgery, map_location='cpu')
        patch_embed_weights = {key.replace("patch_embed.", ""): value for key,
                               value in checkpoint['model'].items() if 'patch_embed' in key}

        model.patch_embed.load_state_dict(patch_embed_weights)
        for p in model.patch_embed.parameters():
            p.requires_grad = False

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume='')

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # linear learning-rate scaling with the global batch size (reference batch: 512)
    linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
    args.lr = linear_scaled_lr
    optimizer = create_optimizer(args, model_without_ddp)
    loss_scaler = NativeScaler()

    lr_scheduler, _ = create_scheduler(args, optimizer)

    if mixup_active:
        # Mixup/CutMix produce soft targets whenever either is enabled (not
        # only when --mixup > 0); smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    teacher_model = None
    if args.distillation_type != 'none':
        if not args.teacher_path:
            raise ValueError('need to specify teacher-path when using distillation')
        print(f"Creating teacher model: {args.teacher_model}")
        teacher_model = create_model(
            args.teacher_model,
            pretrained=False,
            num_classes=args.nb_classes,
            global_pool='avg',
        )
        if args.teacher_path.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.teacher_path, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.teacher_path, map_location='cpu')

        teacher_model.load_state_dict(checkpoint['model'])

        teacher_model.to(device)
        teacher_model.eval()

    # wrap the criterion in our custom DistillationLoss, which
    # just dispatches to the original criterion if args.distillation_type is 'none'
    criterion = DistillationLoss(
        criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
    )

    output_dir = Path(args.output_dir)
    if args.output_dir:
        # exist_ok avoids a race between ranks creating the same directory
        output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: a non-URL --resume value only acts as a switch; the checkpoint is
    # always read from <output_dir>/checkpoint.pth, and nothing is loaded
    # (not even a URL) unless that file exists.
    resume_path = os.path.join(output_dir, 'checkpoint.pth')
    if args.resume and os.path.exists(resume_path):
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            print("Loading from checkpoint ...")
            checkpoint = torch.load(resume_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if args.model_ema and checkpoint.get('model_ema') is not None:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # Both sampler types seed their shuffle with the epoch and are used
        # even in single-process runs, so the epoch must always be set.
        data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn,
            surgery=args.surgery
        )

        lr_scheduler.step(epoch)
        if args.output_dir:
            utils.save_on_master(
                _checkpoint_state(model_without_ddp, optimizer, lr_scheduler, epoch,
                                  model_ema, loss_scaler, args),
                output_dir / 'checkpoint.pth')

        # only epochs that were actually evaluated contribute test_* entries to the log
        test_stats = {}
        if (epoch % args.test_freq == 0) or (epoch == args.epochs - 1):
            test_stats = evaluate(data_loader_val, model, device)

            # "empty for no saving": do not drop best_model.pth into the
            # working directory when no output directory was requested
            if args.output_dir and test_stats["acc1"] >= max_accuracy:
                utils.save_on_master(
                    _checkpoint_state(model_without_ddp, optimizer, lr_scheduler, epoch,
                                      model_ema, loss_scaler, args),
                    os.path.join(output_dir, 'best_model.pth'))

            print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
            max_accuracy = max(max_accuracy, test_stats["acc1"])
            print(f'Max accuracy: {max_accuracy:.2f}%')

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('SimA training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
class RASampler(torch.utils.data.Sampler):
    """Distributed sampler implementing repeated augmentation.

    Each dataset index is emitted three times in a row, the repeated list is
    padded to a multiple of ``num_replicas`` and then dealt out with a stride
    of ``num_replicas`` starting at ``rank``. The intent (per the original
    description) is that the augmented copies of a sample are seen by
    different processes. Every rank finally keeps only its first
    ``num_selected_samples`` indices.

    Args:
        dataset: Sized dataset to draw indices from.
        num_replicas: Number of participating processes; defaults to
            ``dist.get_world_size()``.
        rank: Index of the current process; defaults to ``dist.get_rank()``.
        shuffle: When true, indices are permuted with a generator seeded by
            the current epoch (see ``set_epoch``).

    Raises:
        RuntimeError: If ``num_replicas`` or ``rank`` must be inferred but
            ``torch.distributed`` is not available.
    """

    @staticmethod
    def _require_distributed():
        # Shared guard for the two "infer from the process group" paths.
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            self._require_distributed()
            num_replicas = dist.get_world_size()
        if rank is None:
            self._require_distributed()
            rank = dist.get_rank()

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0

        # Per-rank share of the 3x repeated index list, rounded up.
        self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        # Indices actually yielded per rank: the dataset length is first
        # rounded down to a multiple of 256, then split across the replicas.
        # (An earlier variant, left commented out in the source, used
        # ceil(len(dataset) / num_replicas) instead.)
        self.num_selected_samples = int(
            math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)
        )
        self.shuffle = shuffle

    def __iter__(self):
        # The permutation depends only on the epoch, so it is reproducible
        # for a given epoch value.
        rng = torch.Generator()
        rng.manual_seed(self.epoch)
        n_items = len(self.dataset)
        if self.shuffle:
            order = torch.randperm(n_items, generator=rng).tolist()
        else:
            order = list(range(n_items))

        # Repeat every index three times, keeping the copies adjacent.
        repeated = []
        for idx in order:
            repeated.extend((idx, idx, idx))

        # Pad from the front of the list so it splits evenly across replicas.
        shortfall = self.total_size - len(repeated)
        repeated += repeated[:shortfall]
        assert len(repeated) == self.total_size

        # Strided slice: this rank's share of the repeated list.
        mine = repeated[self.rank:self.total_size:self.num_replicas]
        assert len(mine) == self.num_samples

        return iter(mine[:self.num_selected_samples])

    def __len__(self):
        return self.num_selected_samples

    def set_epoch(self, epoch):
        """Record the epoch used to seed the shuffle in ``__iter__``."""
        self.epoch = epoch
class ConvPatchEmbed(nn.Module):
    """Image-to-patch embedding built from a stack of strided 3x3 convolutions.

    The projection is a sequence of ``conv3x3(..., stride=2)`` stages with a
    GELU between consecutive stages (none after the last). Four stages are
    used for patch size 16 and three for patch size 8, with the channel width
    doubling at every stage until it reaches ``embed_dim``.

    Args:
        img_size (int | tuple): Input image size; only used to compute
            ``num_patches``.
        patch_size (int | tuple): Patch size; the first entry must be 8 or 16.
        in_chans (int): Number of channels of the input image.
        embed_dim (int): Channel width of the produced tokens.

    Raises:
        ValueError: If the patch size is neither 8 nor 16.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        # Channel widths at the input of the stem and after every stride-2 stage.
        # The first width is `in_chans` (previously hard-coded to 3, which
        # silently ignored the constructor argument).
        if patch_size[0] == 16:
            widths = [in_chans, embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim]
        elif patch_size[0] == 8:
            widths = [in_chans, embed_dim // 4, embed_dim // 2, embed_dim]
        else:
            # A bare string cannot be raised; use a real exception type so the
            # caller sees the intended message instead of a TypeError.
            raise ValueError("For convolutional projection, patch size has to be in [8, 16]")

        # Layout is conv, GELU, conv, GELU, ..., conv -- identical module
        # indices to the hand-written Sequential, so checkpoints keep loading.
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            if layers:
                layers.append(nn.GELU())
            layers.append(conv3x3(c_in, c_out, 2))
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, x, padding_size=None):
        """Embed an image batch.

        Args:
            x: Image tensor fed to the convolutional stem (indexed as a 4-D
                tensor after projection).
            padding_size: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(tokens, (Hp, Wp))`` where ``Hp``/``Wp`` are the spatial
            sizes of the projected feature map and ``tokens`` is that map with
            the spatial dimensions flattened and moved in front of the channel
            dimension.
        """
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)

        return x, (Hp, Wp)
116 | Implemented using 2 layers of separable 3x3 convolutions with GeLU and BatchNorm2d 117 | """ 118 | 119 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, 120 | drop=0., kernel_size=3): 121 | super().__init__() 122 | out_features = out_features or in_features 123 | 124 | padding = kernel_size // 2 125 | 126 | self.conv1 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size, 127 | padding=padding, groups=out_features) 128 | self.act = act_layer() 129 | self.bn = nn.SyncBatchNorm(in_features) 130 | self.conv2 = torch.nn.Conv2d(in_features, out_features, kernel_size=kernel_size, 131 | padding=padding, groups=out_features) 132 | 133 | def forward(self, x, H, W): 134 | B, N, C = x.shape 135 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 136 | x = self.conv1(x) 137 | x = self.act(x) 138 | x = self.bn(x) 139 | x = self.conv2(x) 140 | x = x.reshape(B, C, N).permute(0, 2, 1) 141 | 142 | return x 143 | 144 | 145 | class ClassAttention(nn.Module): 146 | """Class Attention Layer as in CaiT https://arxiv.org/abs/2103.17239 147 | """ 148 | 149 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 150 | super().__init__() 151 | self.num_heads = num_heads 152 | head_dim = dim // num_heads 153 | self.scale = qk_scale or head_dim ** -0.5 154 | 155 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 156 | self.attn_drop = nn.Dropout(attn_drop) 157 | self.proj = nn.Linear(dim, dim) 158 | self.proj_drop = nn.Dropout(proj_drop) 159 | 160 | def forward(self, x): 161 | B, N, C = x.shape 162 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) 163 | qkv = qkv.permute(2, 0, 3, 1, 4) 164 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 165 | 166 | k = F.normalize(k, p=1.0, dim=-2, eps=1e-6) 167 | q = F.normalize(q, p=1.0, dim=-2, eps=1e-6) 168 | qc = q[:, :, 0:1] # CLS token 169 | attn_cls = (qc * k).sum(dim=-1) 170 | 171 | 
class ClassAttentionBlock(nn.Module):
    """Class-attention block as in CaiT (https://arxiv.org/abs/2103.17239).

    Applies ``ClassAttention`` with a residual connection, normalizes either
    all tokens or only the first (class) token, and then runs the MLP on the
    class token only.

    Args:
        dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: Forwarded to ``ClassAttention``.
        drop: Dropout rate used for the attention projection and the MLP.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation class for the MLP.
        norm_layer: Normalization layer class.
        eta: LayerScale initial value; ``None`` disables LayerScale (a plain
            factor of 1.0 is used instead of learnable parameters).
        tokens_norm: If true ``norm2`` is applied to every token, otherwise
            only to the class token.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, eta=None,
                 tokens_norm=False):
        super().__init__()
        self.norm1 = norm_layer(dim)

        self.attn = ClassAttention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)

        if eta is not None:  # LayerScale Initialization (no layerscale when None)
            self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        else:
            self.gamma1, self.gamma2 = 1.0, 1.0

        # FIXME: A hack for models pre-trained with layernorm over all the tokens not just the CLS
        self.tokens_norm = tokens_norm

    def forward(self, x, H, W, mask=None):
        """Run the block on a token sequence whose first token is the class token.

        Args:
            x: Token tensor indexed as ``[batch, token, channel]``.
            H, W, mask: Unused; accepted so the block can be called like the
                other blocks of the model.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        if self.tokens_norm:
            x = self.norm2(x)
        else:
            # Normalize only the class token. This is done out of place:
            # writing the result back with `x[:, 0:1] = ...` overwrites the
            # very slice LayerNorm was given, which autograd can reject during
            # backward ("modified by an inplace operation"). The values
            # produced are the same.
            x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)

        x_res = x
        cls_token = x[:, 0:1]
        # Only the class token goes through the MLP; patch tokens pass through.
        cls_token = self.gamma2 * self.mlp(cls_token)
        x = torch.cat([cls_token, x[:, 1:]], dim=1)
        x = x_res + self.drop_path(x)
        return x
class SimABlock(nn.Module):
    """Transformer block built around SimA attention.

    The forward pass applies three residual branches in order: SimA
    attention, local patch interaction (``LPI``) and an MLP. Each branch is
    pre-normalized, scaled by its own LayerScale factor and passed through
    stochastic depth.

    Args:
        dim: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: Forwarded to ``SimA``.
        drop: Dropout rate used for the attention projection and the MLP.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        act_layer: Activation class for the MLP and ``LPI``.
        norm_layer: Normalization layer class.
        num_tokens: Unused; kept for interface compatibility.
        eta: LayerScale initial value. ``None`` disables LayerScale and uses
            a constant factor of 1.0, as ``ClassAttentionBlock`` does.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 num_tokens=196, eta=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SimA(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)

        self.norm3 = norm_layer(dim)
        self.local_mp = LPI(in_features=dim, act_layer=act_layer)

        if eta is not None:  # LayerScale initialization
            self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
            self.gamma3 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        else:
            # The default `eta=None` used to fail on `None * tensor`; fall back
            # to an un-scaled residual branch instead.
            self.gamma1, self.gamma2, self.gamma3 = 1.0, 1.0, 1.0

    def forward(self, x, H, W):
        """Apply attention, local patch interaction and MLP residual branches.

        Args:
            x: Token tensor passed to the attention and MLP branches.
            H, W: Spatial size of the token grid, forwarded to ``LPI``.

        Returns:
            Tensor produced by adding the three residual branches to ``x``.
        """
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        # gamma3/norm3 belong to the local-patch-interaction branch, which
        # runs between attention and the MLP.
        x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W))
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x
if set 315 | drop_rate (float): dropout rate 316 | attn_drop_rate (float): attention dropout rate 317 | drop_path_rate (float): stochastic depth rate 318 | norm_layer: (nn.Module): normalization layer 319 | cls_attn_layers: (int) Depth of Class attention layers 320 | use_pos: (bool) whether to use positional encoding 321 | eta: (float) layerscale initialization value 322 | tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA 323 | """ 324 | super().__init__() 325 | self.num_classes = num_classes 326 | self.num_features = self.embed_dim = embed_dim 327 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 328 | 329 | self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim, 330 | patch_size=patch_size) 331 | 332 | num_patches = self.patch_embed.num_patches 333 | 334 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 335 | self.pos_drop = nn.Dropout(p=drop_rate) 336 | 337 | dpr = [drop_path_rate for i in range(depth)] 338 | self.blocks = nn.ModuleList([ 339 | SimABlock( 340 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 341 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], 342 | norm_layer=norm_layer, num_tokens=num_patches, eta=eta) 343 | for i in range(depth)]) 344 | 345 | self.cls_attn_blocks = nn.ModuleList([ 346 | ClassAttentionBlock( 347 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, 348 | qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, norm_layer=norm_layer, 349 | eta=eta, tokens_norm=tokens_norm) 350 | for i in range(cls_attn_layers)]) 351 | self.norm = norm_layer(embed_dim) 352 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 353 | 354 | self.pos_embeder = PositionalEncodingFourier(dim=embed_dim) 355 | self.use_pos = use_pos 356 | 357 | # Classifier head 358 | trunc_normal_(self.cls_token, std=.02) 359 | self.apply(self._init_weights) 360 | 361 | 
@register_model
def sima_nano_12_p16(pretrained=False, **kwargs):
    """Build the SimA "nano" model: 12 blocks, 16x16 patches, width 128, 4 heads.

    LayerScale is initialized to 1.0 and only the class token is normalized
    in the class-attention blocks (``tokens_norm=False``).

    Args:
        pretrained: Accepted for registry compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to
            ``SimAVisionTransformer``.

    Returns:
        The constructed model with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    net = SimAVisionTransformer(
        patch_size=16,
        embed_dim=128,
        depth=12,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        eta=1.0,
        tokens_norm=False,
        **kwargs
    )
    net.default_cfg = _cfg()
    return net
@register_model
def sima_tiny_24_p16(pretrained=False, **kwargs):
    """Build the SimA "tiny" model: 24 blocks, 16x16 patches, width 192, 4 heads.

    LayerScale is initialized to 1e-5 and all tokens are normalized in the
    class-attention blocks (``tokens_norm=True``).

    Args:
        pretrained: Accepted for registry compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to
            ``SimAVisionTransformer``.

    Returns:
        The constructed model with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    net = SimAVisionTransformer(
        patch_size=16,
        embed_dim=192,
        depth=24,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        eta=1e-5,
        tokens_norm=True,
        **kwargs
    )
    net.default_cfg = _cfg()
    return net
@register_model
def sima_tiny_12_p8(pretrained=False, **kwargs):
    """Build the SimA "tiny" model: 12 blocks, 8x8 patches, width 192, 4 heads.

    LayerScale is initialized to 1.0 and all tokens are normalized in the
    class-attention blocks (``tokens_norm=True``).

    Args:
        pretrained: Accepted for registry compatibility; not used here.
        **kwargs: Extra keyword arguments forwarded to
            ``SimAVisionTransformer``.

    Returns:
        The constructed model with ``default_cfg`` set to ``_cfg()``.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    net = SimAVisionTransformer(
        patch_size=8,
        embed_dim=192,
        depth=12,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=layer_norm,
        eta=1.0,
        tokens_norm=True,
        **kwargs
    )
    net.default_cfg = _cfg()
    return net
class SmoothedValue(object):
    """Running statistics over a stream of scalar values.

    Keeps the most recent ``window_size`` values for windowed statistics
    (median, average, max, latest value) together with a running total and
    count for the average over the whole series.

    Args:
        window_size: Maximum number of recent values kept in the window.
        fmt: Format string used by ``str()``; it may reference ``median``,
            ``avg``, ``global_avg``, ``max`` and ``value``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value``, weighting it ``n`` times in the global average."""
        # The window stores the value once regardless of its weight `n`.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        synced_count, synced_total = packed.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (computed in float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every value recorded so far."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Feed an already-loaded checkpoint object to ``model_ema._load_checkpoint``.

    The object is serialized with ``torch.save`` into an in-memory buffer,
    the buffer is rewound, and the buffer is handed to
    ``model_ema._load_checkpoint`` in place of a file (workaround for
    ModelEma._load_checkpoint, per the original note).
    """
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)
    # Rewind so the consumer reads from the start of the serialized data.
    buffer.seek(0)
    model_ema._load_checkpoint(buffer)
def _find_free_port():
    """Ask the OS for a currently unused TCP port and return its number.

    Returns:
        int: The port number the OS assigned to a temporary socket.

    Note:
        The socket is closed before returning, so another process may grab
        the port before the caller binds to it.
    """
    import socket

    # The context manager closes the socket even if bind()/getsockname()
    # raises, so no descriptor is leaked on the error path.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port