├── LICENSE ├── README.md ├── assets ├── acc.png └── backbone.png ├── cal_flops.py ├── datasets.py ├── engine.py ├── exp └── b1 │ └── run.sh ├── losses.py ├── main.py ├── model_utils.py ├── samplers.py ├── uninet.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 SenseTime X-Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## PyTorch implementation of [UniNet (ECCV 2022)](https://arxiv.org/abs/2207.05420) 2 | 3 | ![teaser](assets/backbone.png) 4 | ![performance](assets/acc.png) 5 | 6 | This repo is the official implementation of the paper [UniNet: Unified Architecture Search with Convolution, Transformer, and MLP](https://arxiv.org/abs/2207.05420) 7 | 8 | ``` 9 | @article{UniNet, 10 | author = {Jihao Liu and Xin Huang and Guanglu Song and Yu Liu and Hongsheng Li}, 11 | journal = {arXiv:2207.05420}, 12 | title = {UniNet: Unified Architecture Search with Convolution, Transformer, and MLP}, 13 | year = {2022}, 14 | } 15 | ``` 16 | 17 | ### Update 18 | 20/12/2022 Update pretrained models. 19 | 20 | 25/10/2022 Update the source code. 21 | 22 | #### Environment 23 | The code is tested with ```torch==1.11``` and ```timm==0.5.4```. 24 | 25 | 26 | ### Available models 27 | |Models | Params (M) | FLOPs (G) | Pretrain Epochs | Top-1 Acc. | ckpt | 28 | | :---: | :---: | :---: | :---: | :---: | :---: | 29 | | UniNet-B1 | 11.5 | 1.1 | 300 | 81.0 | [ckpt](https://drive.google.com/drive/folders/14gp-Vtmtd3MNNlrmYtF5FcUi0rm4CaGi?usp=share_link)| 30 | 31 | ### Run experiments 32 | 33 | Currently, we support running experiments with Slurm.
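If a Slurm cluster is not available, the same entry point can likely be launched on a single node with `torchrun` instead; the sketch below is untested and assumes `utils.init_distributed_mode` falls back to the standard environment-variable (`env://`) initialization:

```sh
torchrun --nproc_per_node=8 main.py \
    --model UniNetB1 \
    --input-size 224 \
    --batch-size 128 \
    --epochs 300 \
    --data-path /path/to/imagenet \
    --output_dir ./ckpt
```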
34 | You can reproduce the results of UniNet-B1 as follows: 35 | 36 | ```sh exp/b1/run.sh partition 8``` 37 | -------------------------------------------------------------------------------- /assets/acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sense-X/UniNet/81ef33218caca5fb62e85b9cbf85283963f4cc00/assets/acc.png -------------------------------------------------------------------------------- /assets/backbone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sense-X/UniNet/81ef33218caca5fb62e85b9cbf85283963f4cc00/assets/backbone.png -------------------------------------------------------------------------------- /cal_flops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from fvcore.nn import FlopCountAnalysis 4 | from fvcore.nn import flop_count_table 5 | 6 | from timm.models import create_model 7 | from uninet import * 8 | 9 | input_size = 224 10 | 11 | model = create_model('UniNetB1') 12 | model.eval() 13 | 14 | flops = FlopCountAnalysis(model, torch.rand(1, 3, input_size, input_size)) 15 | print(flop_count_table(flops, max_depth=2)) 16 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import numpy as np 5 | from PIL import ImageDraw 6 | 7 | from torchvision import datasets, transforms 8 | from torchvision.datasets.folder import ImageFolder, default_loader 9 | 10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, \ 11 | IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 12 | from timm.data import create_transform 13 | from timm.data.transforms import str_to_interp_mode 14 | 15 | 16 | class INatDataset(ImageFolder): 17 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None, 18 | category='name', loader=default_loader): 19 | self.transform = transform 20 | self.loader = loader 21 | self.target_transform = target_transform 22 | self.year = year 23 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name'] 24 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json') 25 | with open(path_json) as json_file: 26 | data = json.load(json_file) 27 | 28 | with open(os.path.join(root, 'categories.json')) as json_file: 29 | data_catg = json.load(json_file) 30 | 31 | path_json_for_targeter = os.path.join(root, f"train{year}.json") 32 | 33 | with open(path_json_for_targeter) as json_file: 34 | data_for_targeter = json.load(json_file) 35 | 36 | targeter = {} 37 | indexer = 0 38 | for elem in data_for_targeter['annotations']: 39 | king = [] 40 | king.append(data_catg[int(elem['category_id'])][category]) 41 | if king[0] not in targeter.keys(): 42 | targeter[king[0]] = indexer 43 | indexer += 1 44 | self.nb_classes = len(targeter) 45 | 46 | self.samples = [] 47 | for elem in data['images']: 48 | cut = elem['file_name'].split('/') 49 | target_current = int(cut[2]) 50 | path_current = os.path.join(root, cut[0], cut[2], cut[3]) 51 | 52 | categors = data_catg[target_current] 53 | target_current_true = targeter[categors[category]] 54 | self.samples.append((path_current, target_current_true)) 55 | 56 | # __getitem__ and __len__ inherited from ImageFolder 57 | 58 | 59 | def 
build_dataset(is_train, args): 60 | transform = build_transform(is_train, args) 61 | 62 | if args.data_set == 'CIFAR': 63 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform) 64 | nb_classes = 100 65 | elif args.data_set == 'IMNET': 66 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 67 | dataset = datasets.ImageFolder(root, transform=transform) 68 | nb_classes = 1000 69 | elif args.data_set == 'INAT': 70 | dataset = INatDataset(args.data_path, train=is_train, year=2018, 71 | category=args.inat_category, transform=transform) 72 | nb_classes = dataset.nb_classes 73 | elif args.data_set == 'INAT19': 74 | dataset = INatDataset(args.data_path, train=is_train, year=2019, 75 | category=args.inat_category, transform=transform) 76 | nb_classes = dataset.nb_classes 77 | 78 | return dataset, nb_classes 79 | 80 | 81 | def build_transform(is_train, args): 82 | resize_im = args.input_size > 32 83 | scale = getattr(args, 'scale', None) 84 | imagenet_default_mean_and_std = getattr(args, 'imagenet_default_mean_and_std', True) 85 | if is_train: 86 | # this should always dispatch to transforms_imagenet_train 87 | transform = create_transform( 88 | input_size=args.input_size, 89 | is_training=True, 90 | color_jitter=args.color_jitter, 91 | auto_augment=args.aa, 92 | interpolation=args.train_interpolation, 93 | re_prob=args.reprob, 94 | re_mode=args.remode, 95 | re_count=args.recount, 96 | scale=scale, 97 | mean=IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN, 98 | std=IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 99 | ) 100 | if not resize_im: 101 | # replace RandomResizedCropAndInterpolation with 102 | # RandomCrop 103 | transform.transforms[0] = transforms.RandomCrop( 104 | args.input_size, padding=4) 105 | return transform 106 | 107 | t = [] 108 | test_size = args.input_size 109 | crop = test_size < 320 110 | test_interpolation = str_to_interp_mode(getattr(args, 'test_interpolation', 'bicubic')) 111 | if resize_im: 112 | if crop: 113 | size = int((256 / 224) * test_size) 114 | t.append( 115 | transforms.Resize(size, interpolation=test_interpolation), # to maintain same ratio w.r.t. 224 images 116 | ) 117 | t.append(transforms.CenterCrop(test_size)) 118 | else: 119 | t.append( 120 | transforms.Resize((test_size,test_size), interpolation=test_interpolation), # to maintain same ratio w.r.t. 
224 images 121 | ) 122 | 123 | t.append(transforms.ToTensor()) 124 | # t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 125 | if imagenet_default_mean_and_std: 126 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 127 | else: 128 | t.append(transforms.Normalize(IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD)) 129 | return transforms.Compose(t) 130 | 131 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable, Optional 4 | from contextlib import suppress 5 | 6 | import torch 7 | 8 | from timm.data import Mixup 9 | from timm.utils import accuracy, ModelEma 10 | 11 | from losses import DistillationLoss 12 | import utils 13 | 14 | 15 | def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss, 16 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 17 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 18 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, 19 | set_training_mode=True, tb_logger=None, start_idx=0, mimic=False, amp_autocast=suppress): 20 | model.train(set_training_mode) 21 | metric_logger = utils.MetricLogger(delimiter=" ") 22 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 23 | header = 'Epoch: [{}]'.format(epoch) 24 | print_freq = 10 25 | torch.cuda.synchronize() 26 | 27 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 28 | samples = samples.to(device, non_blocking=True) 29 | targets = targets.to(device, non_blocking=True) 30 | 31 | if mixup_fn is not None: 32 | samples, targets = mixup_fn(samples, targets) 33 | 34 | with amp_autocast(): 35 | outputs = model(samples) 36 | if mimic: 37 | loss = criterion(samples, outputs, targets) 38 | else: 39 | loss = criterion(outputs, targets) 40 | 41 | loss_value = loss.item() 42 | 43 | if not math.isfinite(loss_value): 44 | print("Loss is {}, stopping training".format(loss_value), flush=True) 45 | sys.exit(1) 46 | 47 | optimizer.zero_grad() 48 | 49 | # this attribute is added by timm on one optimizer (adahessian) 50 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 51 | loss_scaler(loss, optimizer, clip_grad=max_norm, 52 | parameters=model.parameters(), create_graph=is_second_order) 53 | 54 | torch.cuda.synchronize() 55 | if model_ema is not None: 56 | model_ema.update(model) 57 | 58 | metric_logger.update(loss=loss_value) 59 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 60 | if tb_logger is not None and utils.get_rank() == 0 and start_idx % 50 == 0: 61 | for k, meter in metric_logger.meters.items(): 62 | tb_logger.add_scalar('train/{}_avg'.format(k), meter.global_avg, start_idx) 63 | tb_logger.add_scalar('train/{}_val'.format(k), meter.value, start_idx) 64 | start_idx += 1 65 | # gather the stats from all processes 66 | metric_logger.synchronize_between_processes() 67 | print("Averaged stats:", metric_logger) 68 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 69 | 70 | 71 | @torch.no_grad() 72 | def evaluate(data_loader, model, device, model_ema=None, fp16=True, amp_autocast=suppress): 73 | criterion = torch.nn.CrossEntropyLoss() 74 | 75 | metric_logger = utils.MetricLogger(delimiter=" ") 76 | header = 'Test:' 77 | 78 | # switch to evaluation mode 79 | model.eval() 80 | 81 | for images, target 
in metric_logger.log_every(data_loader, 10, header): 82 | images = images.to(device, non_blocking=True) 83 | target = target.to(device, non_blocking=True) 84 | 85 | # compute output 86 | with amp_autocast(): 87 | output = model(images) 88 | loss = criterion(output, target) 89 | 90 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 91 | 92 | batch_size = images.shape[0] 93 | metric_logger.update(loss=loss.item()) 94 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 95 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 96 | # gather the stats from all processes 97 | metric_logger.synchronize_between_processes() 98 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 99 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 100 | 101 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 102 | -------------------------------------------------------------------------------- /exp/b1/run.sh: -------------------------------------------------------------------------------- 1 | work_path=$(dirname $0) 2 | filename=$(basename $work_path) 3 | partition=$1 4 | gpus=$2 5 | OMP_NUM_THREADS=1 \ 6 | srun -p ${partition} -n ${gpus} --ntasks-per-node=8 --cpus-per-task=16 --gres=gpu:8 \ 7 | python -u main.py \ 8 | --model UniNetB1 \ 9 | --input-size 224 \ 10 | --batch-size 128 \ 11 | --output_dir ${work_path}/ckpt \ 12 | --epochs 300 \ 13 | --dist-eval \ 14 | --drop-path 0.0 \ 15 | --reprob 0.0 \ 16 | --mixup 0.8 \ 17 | --cutmix 1.0 \ 18 | --num_workers 8 \ 19 | --port 29522 \ 20 | --resume ${work_path}/ckpt/checkpoint.pth 21 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from timm.loss import BinaryCrossEntropy 4 | 5 | def ce_loss(logit_p, logit_q): 6 | p = torch.softmax(logit_p, dim=1) 7 | log_q = torch.log_softmax(logit_q, dim=1) 8 | loss = (-p * log_q).sum(dim=1).mean() 9 | return loss 10 |
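# note on ce_loss: it computes the soft cross-entropy H(p, q) = -sum_c p_c * log q_c with
# p = softmax(logit_p) and q = softmax(logit_q), averaged over the batch; if the two logits
# are identical it reduces to the entropy of p. The 'soft' distillation branch below calls
# ce_loss(teacher_logits / T, student_logits / T) * T * T; the T * T factor compensates for
# the 1/T^2 shrinkage of soft-target gradients (Hinton et al., 2015), keeping the
# distillation term on the same scale as the hard-label loss.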
11 | class DistillationLoss(torch.nn.Module): 12 | """ 13 | This module wraps a standard criterion and adds an extra knowledge distillation loss by 14 | taking a teacher model prediction and using it as additional supervision. 15 | """ 16 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module, 17 | distillation_type: str, alpha: float, tau: float, use_ce=False, distill_token=True): 18 | super().__init__() 19 | self.base_criterion = base_criterion 20 | self.teacher_model = teacher_model 21 | assert distillation_type in ['none', 'soft', 'hard'] 22 | self.distillation_type = distillation_type 23 | self.alpha = alpha 24 | self.tau = tau 25 | self.use_ce = use_ce 26 | self.distill_token = distill_token 27 | 28 | def forward(self, inputs, outputs, labels): 29 | """ 30 | Args: 31 | inputs: The original inputs that are fed to the teacher model 32 | outputs: the outputs of the model to be trained. It is expected to be 33 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output 34 | in the first position and the distillation predictions as the second output 35 | labels: the labels for the base criterion 36 | """ 37 | outputs_kd = None 38 | if self.distill_token: 39 | if not isinstance(outputs, torch.Tensor): 40 | # assume that the model outputs a tuple of [outputs, outputs_kd] 41 | outputs, outputs_kd = outputs 42 | else: 43 | outputs_kd = outputs 44 | base_loss = self.base_criterion(outputs, labels) 45 | if self.distillation_type == 'none': 46 | return base_loss 47 | 48 | if outputs_kd is None: 49 | raise ValueError("When knowledge distillation is enabled, the model is " 50 | "expected to return a Tuple[Tensor, Tensor] with the output of the " 51 | "class_token and the dist_token") 52 | # don't backprop through the teacher 53 | with torch.no_grad(): 54 | teacher_outputs = self.teacher_model(inputs) 55 | 56 | if self.distillation_type == 'soft': 57 | if self.use_ce: 58 | # distillation_loss = BinaryCrossEntropy(smoothing=0)(outputs_kd, teacher_outputs) 59 | T = self.tau 60 | distillation_loss = ce_loss(teacher_outputs / T, outputs_kd / T) * T * T 61 | else: 62 | T = self.tau 63 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100 64 | # with slight modifications 65 | distillation_loss = F.kl_div( 66 | F.log_softmax(outputs_kd / T, dim=1), 67 | F.log_softmax(teacher_outputs / T, dim=1), 68 | reduction='sum', 69 | log_target=True 70 | ) * (T * T) / outputs_kd.numel() 71 | elif self.distillation_type == 'hard': 72 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 73 | 74 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 75 | return loss 76 | 77 | 78 | class BCELossSmooth(torch.nn.Module): 79 | def __init__(self, base_criterion, smooth=0): 80 | super(BCELossSmooth, self).__init__() 81 | self.base_criterion = base_criterion 82 | self.smooth = smooth 83 | 84 | def forward(self, inputs, outputs, labels): 85 | batch_size, num_classes = outputs.shape 86 | labels = labels.unsqueeze(1) 87 | if self.smooth <= 0.0: 88 | targets = torch.zeros(batch_size, num_classes).cuda().scatter_(1, labels, 1) 89 | loss = self.base_criterion(outputs, targets) 90 | else: 91 | targets = torch.zeros(batch_size, num_classes).cuda().scatter_(1, labels, 1) 92 | targets = (targets + self.smooth).clamp(0, 1) 93 | loss = self.base_criterion(outputs, targets) 94 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # ref: https://github.com/facebookresearch/deit 2 | import argparse 3 | import datetime 4 | import logging 5 | import os 6 | import random 7 | import numpy as np 8 | import time 9 | from contextlib import suppress 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 13 | from tensorboardX import SummaryWriter 14 | import json 15 | import shutil 16 | import warnings 17 | warnings.filterwarnings("ignore", category=UserWarning) 18 | 19 | from pathlib import Path 20 | 21 | from timm.data import Mixup 22 | from timm.models import create_model 23 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy 24 | from timm.scheduler import create_scheduler 25 | from timm.optim import create_optimizer 26 | from timm.utils import NativeScaler, get_state_dict, ModelEma, ApexScaler 27 | import timm.models 28 | from fvcore.nn import FlopCountAnalysis 29 | from fvcore.nn import flop_count_table 30 | 31 | from datasets import build_dataset 32 | from engine import train_one_epoch, evaluate 33 | from samplers import RASampler 34 | import utils 35 | import uninet 36 | 37 | try: 38 | from apex import amp 39 | from apex.parallel import DistributedDataParallel as ApexDDP 40 | 41 | has_apex = False 42 | except ImportError: 43 | has_apex = False 44 |
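# note: both branches above leave has_apex False, so main() always takes the native AMP
# path (torch.cuda.amp.autocast with NativeScaler and torch's DistributedDataParallel);
# the apex O1 branch in main() is effectively disabled unless has_apex is set to True here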
45 | def get_args_parser(): 46 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 47 | parser.add_argument('--batch-size', default=64, type=int) 48 | parser.add_argument('--epochs', default=300, type=int) 49 | 50 | # Model parameters 51 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 52 | help='Name of model to train') 53 | parser.add_argument('--input-size', default=224, type=int, help='images input size') 54 | 55 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 56 | help='Dropout rate (default: 0.)') 57 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 58 | help='Drop path rate (default: 0.1)') 59 | 60 | parser.add_argument('--model-ema', action='store_true') 61 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 62 | parser.set_defaults(model_ema=True) 63 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 64 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 65 | 66 | # Optimizer parameters 67 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 68 | help='Optimizer (default: "adamw")') 69 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 70 | help='Optimizer Epsilon (default: 1e-8)') 71 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 72 | help='Optimizer Betas (default: None, use opt default)') 73 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 74 | help='Clip gradient norm (default: None, no clipping)') 75 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 76 | help='SGD momentum (default: 0.9)') 77 | parser.add_argument('--weight-decay', type=float, default=0.05, 78 | help='weight decay (default: 0.05)') 79 | # Learning rate schedule parameters 80 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 81 | help='LR scheduler (default: "cosine")') 82 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 83 | help='learning rate (default: 5e-4)')
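    # note: main() rescales this base LR linearly with the global batch size as
    # lr * batch_size * world_size / 512; e.g. the default 5e-4 with --batch-size 128
    # on 8 GPUs trains with an effective peak LR of 5e-4 * 1024 / 512 = 1e-3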
84 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 85 | help='learning rate noise on/off epoch percentages') 86 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 87 | help='learning rate noise limit percent (default: 0.67)') 88 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 89 | help='learning rate noise std-dev (default: 1.0)') 90 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 91 | help='warmup learning rate (default: 1e-6)') 92 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 93 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 94 | 95 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 96 | help='epoch interval to decay LR') 97 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 98 | help='epochs to warmup LR, if scheduler supports') 99 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 100 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 101 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 102 | help='patience epochs for Plateau LR scheduler (default: 10)') 103 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 104 | help='LR decay rate (default: 0.1)') 105 | 106 | # Augmentation parameters 107 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 108 | help='Color jitter factor (default: 0.4)') 109 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 110 | help='Use AutoAugment policy. "v0" or "original" ' 111 | '(default: rand-m9-mstd0.5-inc1)') 112 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 113 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 114 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")') 115 | 116 | parser.add_argument('--repeated-aug', action='store_true') 117 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 118 | parser.set_defaults(repeated_aug=False) 119 | 120 | parser.add_argument('--imagenet_default_mean_and_std', action='store_true') 121 | parser.add_argument('--no_imagenet_default_mean_and_std', action='store_false', dest='imagenet_default_mean_and_std') 122 | parser.set_defaults(imagenet_default_mean_and_std=True) 123 | 124 | # * Random Erase params 125 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 126 | help='Random erase prob (default: 0.25)') 127 | parser.add_argument('--remode', type=str, default='pixel', 128 | help='Random erase mode (default: "pixel")') 129 | parser.add_argument('--recount', type=int, default=1, 130 | help='Random erase count (default: 1)') 131 | parser.add_argument('--resplit', action='store_true', default=False, 132 | help='Do not random erase first (clean) augmentation split') 133 | 134 | parser.add_argument('--use-bce', action='store_true', default=False, 135 | help='use bce loss for mixup or cutmix') 136 | 137 | # * Mixup params 138 | parser.add_argument('--mixup', type=float, default=0.8, 139 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 140 | parser.add_argument('--cutmix', type=float, default=1.0, 141 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 142 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 143 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 144 | parser.add_argument('--mixup-prob', type=float, default=1.0, 145 | help='Probability of performing mixup or cutmix when either/both is enabled') 146 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 147 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 148 | parser.add_argument('--mixup-mode', type=str, default='batch', 149 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 150 |
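    # note: these flags are forwarded to timm.data.Mixup in main(); with the defaults
    # (--mixup 0.8, --cutmix 1.0, --mixup-prob 1.0, --mixup-switch-prob 0.5) every batch
    # is mixed, choosing between mixup and cutmix with equal probability, and label
    # smoothing is folded into the resulting soft targets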
Per "batch", "pair", or "elem"') 150 | 151 | # * Finetuning params 152 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 153 | parser.add_argument('--ema-finetune', action='store_true', default=False, 154 | help='Enable tracking moving average of model weights') 155 | 156 | # Dataset parameters 157 | parser.add_argument('--data-path', default='/path/to/imagenet/', type=str, 158 | help='dataset path') 159 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 160 | type=str, help='Image Net dataset path') 161 | parser.add_argument('--inat-category', default='name', 162 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 163 | type=str, help='semantic granularity') 164 | 165 | parser.add_argument('--output_dir', default='', 166 | help='path where to save, empty for no saving') 167 | parser.add_argument('--device', default='cuda', 168 | help='device to use for training / testing') 169 | parser.add_argument('--seed', default=0, type=int) 170 | parser.add_argument('--resume', default='', help='resume from checkpoint') 171 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 172 | help='start epoch') 173 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 174 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 175 | parser.add_argument('--num_workers', default=8, type=int) 176 | parser.add_argument('--pin-mem', action='store_true', 177 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 178 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 179 | help='') 180 | parser.set_defaults(pin_mem=True) 181 | 182 | # distributed training parameters 183 | parser.add_argument('--world_size', default=1, type=int, 184 | help='number of distributed processes') 185 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 186 | parser.add_argument('--port', default=29529, type=int, help='port') 187 | return parser 188 | 189 | 190 | def main(args): 191 | utils.init_distributed_mode(args, verbose=True) 192 | output_dir = Path(args.output_dir) 193 | tb_logger = None 194 | if utils.get_rank() == 0: 195 | tensorboard_path = os.path.join(output_dir, 'events') 196 | tb_logger = SummaryWriter(tensorboard_path) 197 | 198 | utils.init_log(__name__, log_file=os.path.join(output_dir, 'full_log.txt')) 199 | logger = logging.getLogger(__name__) 200 | print = logger.info 201 | 202 | print(args) 203 | 204 | device = torch.device(args.device) 205 | 206 | # fix the seed for reproducibility 207 | seed = args.seed + utils.get_rank() 208 | torch.manual_seed(seed) 209 | np.random.seed(seed) 210 | random.seed(seed) 211 | 212 | cudnn.benchmark = True 213 | 214 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 215 | dataset_val, _ = build_dataset(is_train=False, args=args) 216 | 217 | num_tasks = utils.get_world_size() 218 | global_rank = utils.get_rank() 219 | if args.repeated_aug: 220 | sampler_train = RASampler( 221 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 222 | ) 223 | else: 224 | sampler_train = torch.utils.data.DistributedSampler( 225 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 226 | ) 227 | if args.dist_eval: 228 | if len(dataset_val) % num_tasks != 0: 229 | print('Warning: Enabling distributed evaluation with an 
eval dataset not divisible by process number. ' 230 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 231 | 'equal num of samples per-process.') 232 | sampler_val = torch.utils.data.DistributedSampler( 233 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 234 | else: 235 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 236 | 237 | data_loader_train = torch.utils.data.DataLoader( 238 | dataset_train, sampler=sampler_train, 239 | batch_size=args.batch_size, 240 | num_workers=args.num_workers, 241 | pin_memory=args.pin_mem, 242 | drop_last=True, 243 | persistent_workers=True 244 | ) 245 | 246 | data_loader_val = torch.utils.data.DataLoader( 247 | dataset_val, sampler=sampler_val, 248 | batch_size=int(1.5 * args.batch_size), 249 | num_workers=args.num_workers, 250 | pin_memory=args.pin_mem, 251 | drop_last=False, 252 | persistent_workers=True 253 | ) 254 | 255 | mixup_fn = None 256 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 257 | if mixup_active: 258 | mixup_fn = Mixup( 259 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 260 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 261 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 262 | 263 | print(f"Creating model: {args.model}") 264 | model = create_model( 265 | args.model, 266 | pretrained=False, 267 | num_classes=args.nb_classes, 268 | drop_rate=args.drop, 269 | drop_path_rate=args.drop_path, 270 | drop_block_rate=None, 271 | ) 272 | 273 | if args.finetune: 274 | if args.finetune.startswith('https'): 275 | checkpoint = torch.hub.load_state_dict_from_url( 276 | args.finetune, map_location='cpu', check_hash=True) 277 | else: 278 | checkpoint = torch.load(args.finetune, map_location='cpu') 279 | 280 | if args.ema_finetune: 281 | checkpoint_model = checkpoint['state_dict_ema'] 282 | else: 283 | checkpoint_model = checkpoint['model'] 284 | state_dict = model.state_dict() 285 | # for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 286 | # if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 287 | # print(f"Removing key {k} from pretrained checkpoint") 288 | # del checkpoint_model[k] 289 | # interpolate position embedding 290 | msg = model.load_state_dict(checkpoint_model, strict=False) 291 | print(msg) 292 | 293 | model.to(device) 294 | 295 | if utils.get_rank() == 0: 296 | model.eval() 297 | flops = FlopCountAnalysis(model, torch.rand(1, 3, args.input_size, args.input_size).to(device)) 298 | if args.rank == 0: 299 | print(flop_count_table(flops)) 300 | model.train() 301 | torch.distributed.barrier() 302 | 303 | model_ema = None 304 | if args.model_ema: 305 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 306 | model_ema = ModelEma( 307 | model, 308 | decay=args.model_ema_decay, 309 | device='cpu' if args.model_ema_force_cpu else 'cuda', 310 | resume=args.resume if os.path.isfile(args.resume) else '' 311 | ) 312 | 313 | model_without_ddp = model 314 | 315 | print(f'batch size {args.batch_size}, world size {utils.get_world_size()}') 316 | print(f'ori lr {args.lr}') 317 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 318 | args.lr = linear_scaled_lr 319 | print(f'scaled lr {args.lr}') 320 | optimizer = create_optimizer(args, model_without_ddp) 321 | 322 | amp_autocast = suppress 323 | if has_apex: 324 | model, 
optimizer = amp.initialize(model, optimizer, opt_level='O1') 325 | loss_scaler = ApexScaler() 326 | model = ApexDDP(model, delay_allreduce=True) 327 | print('Using NVIDIA APEX AMP. Training in mixed precision.') 328 | else: 329 | amp_autocast = torch.cuda.amp.autocast 330 | loss_scaler = NativeScaler() 331 | model = NativeDDP(model, device_ids=[args.gpu], find_unused_parameters=False) 332 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 333 | print(f'number of params: {n_parameters}') 334 | 335 | torch.distributed.barrier() 336 | 337 | lr_scheduler, total_epochs = create_scheduler(args, optimizer) 338 | args.epochs = total_epochs 339 | 340 | criterion = LabelSmoothingCrossEntropy() 341 | 342 | if args.mixup > 0. or args.cutmix > 0.: 343 | # smoothing is handled with mixup label transform 344 | criterion = SoftTargetCrossEntropy() 345 | if args.use_bce: 346 | criterion = BinaryCrossEntropy() 347 | elif args.smoothing: 348 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 349 | else: 350 | criterion = torch.nn.CrossEntropyLoss() 351 | 352 | 353 | if args.resume and os.path.isfile(args.resume): 354 | print('>>>>>> resume from {}'.format(args.resume)) 355 | if args.resume.startswith('https'): 356 | checkpoint = torch.hub.load_state_dict_from_url( 357 | args.resume, map_location='cpu', check_hash=True) 358 | else: 359 | checkpoint = torch.load(args.resume, map_location='cpu') 360 | model_without_ddp.load_state_dict(checkpoint['model']) 361 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 362 | optimizer.load_state_dict(checkpoint['optimizer']) 363 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 364 | args.start_epoch = checkpoint['epoch'] + 1 365 | if 'scaler' in checkpoint: 366 | loss_scaler.load_state_dict(checkpoint['scaler']) 367 | 368 | if args.eval: 369 | test_stats = evaluate(data_loader_val, model, device, amp_autocast=amp_autocast) 370 | print(f"results: ") 371 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 372 | test_stats = evaluate(data_loader_val, model_ema.ema, device, amp_autocast=amp_autocast) 373 | print(f"Ema results: ") 374 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 375 | return 376 | 377 | print(f"Start training for {args.epochs} epochs") 378 | start_time = time.time() 379 | max_accuracy = 0.0 380 | start_idx = args.start_epoch * len(data_loader_train) 381 | for epoch in range(args.start_epoch, args.epochs): 382 | if args.distributed: 383 | data_loader_train.sampler.set_epoch(epoch) 384 | 385 | train_stats = train_one_epoch( 386 | model, criterion, data_loader_train, 387 | optimizer, device, epoch, loss_scaler, 388 | args.clip_grad, model_ema, mixup_fn, 389 | # set_training_mode=args.finetune == '', # keep in eval mode during finetuning 390 | tb_logger=tb_logger, start_idx=start_idx, 391 | amp_autocast=amp_autocast 392 | ) 393 | start_idx += len(data_loader_train) 394 | 395 | lr_scheduler.step(epoch) 396 | if args.output_dir: 397 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 398 | for checkpoint_path in checkpoint_paths: 399 | utils.save_on_master({ 400 | 'model': model_without_ddp.state_dict(), 401 | 'optimizer': optimizer.state_dict(), 402 | 'lr_scheduler': lr_scheduler.state_dict(), 403 | 'epoch': epoch, 404 | 'state_dict_ema': get_state_dict(model_ema), 405 | 'scaler': loss_scaler.state_dict(), 406 | 'args': args, 407 | }, 
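                    # keys read back on restart: the --resume branch above consumes 'model',
                    # 'optimizer', 'lr_scheduler', 'epoch' and 'scaler', and ModelEma(resume=...)
                    # restores 'state_dict_ema'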
checkpoint_path) 408 | 409 | # if epoch % 5 == 0: 410 | test_stats = evaluate(data_loader_val, model, device, amp_autocast=amp_autocast) 411 | print(f"[Epoch {epoch}] Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 412 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 413 | 414 | test_stats_ema = evaluate(data_loader_val, model_ema.ema, device, amp_autocast=amp_autocast) 415 | print( 416 | f"[Epoch {epoch}] [EMA result] Accuracy of the network on the {len(dataset_val)} test images: {test_stats_ema['acc1']:.1f}%") 417 | max_accuracy = max(max_accuracy, test_stats_ema["acc1"]) 418 | print(f'Max accuracy: {max_accuracy:.2f}%') 419 | 420 | # save best ckpt 421 | if (max_accuracy == test_stats["acc1"] or max_accuracy == test_stats_ema["acc1"]) and utils.get_rank() == 0: 422 | checkpoint_path = output_dir / 'checkpoint.pth' 423 | shutil.copy2(checkpoint_path, output_dir / 'ckpt_best.pth') 424 | 425 | if utils.get_rank() == 0: 426 | for k, v in test_stats.items(): 427 | tb_logger.add_scalar('test/{}'.format(k), v, epoch) 428 | for k, v in test_stats_ema.items(): 429 | tb_logger.add_scalar('test_ema/{}'.format(k), v, epoch) 430 | tb_logger.add_scalar('test/max_accuracy', max_accuracy, epoch) 431 | tb_logger.flush() 432 | 433 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 434 | **{f'test_{k}': v for k, v in test_stats.items()}, 435 | **{f'test_ema_{k}': v for k, v in test_stats_ema.items()}, 436 | 'epoch': epoch, 437 | 'n_parameters': n_parameters} 438 | 439 | if args.output_dir and utils.is_main_process(): 440 | with open(os.path.join(output_dir, "log.txt"), 'a') as f: 441 | f.write(json.dumps(log_stats) + "\n") 442 | 443 | total_time = time.time() - start_time 444 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 445 | print('Training time {}'.format(total_time_str)) 446 | 447 | 448 | if __name__ == '__main__': 449 | parser = argparse.ArgumentParser('UniNet training and evaluation script', parents=[get_args_parser()]) 450 | args = parser.parse_args() 451 | if args.output_dir: 452 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 453 | main(args) 454 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def _make_divisible(v, divisor, min_value=None): 7 | if min_value is None: 8 | min_value = divisor 9 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 10 | # Make sure that round down does not go down by more than 10%. 
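    # e.g. _make_divisible(30, 8) -> 32 (int(34) // 8 * 8 = 32, and 32 >= 0.9 * 30), while
    # _make_divisible(10, 8) first rounds down to 8, which is < 0.9 * 10 = 9, so the divisor
    # is added back and 16 is returned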
11 | if new_v < 0.9 * v: 12 | new_v += divisor 13 | return new_v 14 | 15 | def to4d(x): 16 | if len(x.shape) == 4: 17 | return x 18 | B, N, C = x.shape 19 | h = int(N ** 0.5) 20 | return x.transpose(1, 2).reshape(B, C, h, h) 21 | 22 | 23 | def to3d(x): 24 | if len(x.shape) == 3: 25 | return x 26 | B, C, h, w = x.shape 27 | N = h * w 28 | return x.reshape(B, C, N).transpose(1, 2) 29 | 30 | 31 | class SqueezeExcitation(nn.Module): 32 | 33 | def __init__(self, input_channels: int, squeeze_factor: int = 4): 34 | super().__init__() 35 | squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) 36 | self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) 39 | 40 | def _scale(self, input, inplace): 41 | scale = F.adaptive_avg_pool2d(input, 1) 42 | scale = self.fc1(scale) 43 | scale = self.relu(scale) 44 | scale = self.fc2(scale) 45 | # return F.hardsigmoid(scale, inplace=inplace) 46 | return hard_sigmoid(scale, inplace=inplace) 47 | 48 | def forward(self, input): 49 | scale = self._scale(input, True) 50 | return scale * input 51 | 52 | 53 | def hard_sigmoid(x, inplace=False): 54 | return F.relu6(x + 3, inplace) / 6 55 | -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import math 4 | 5 | 6 | class RASampler(torch.utils.data.Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset for distributed, 8 | with repeated augmentation. 9 | It ensures that each augmented version of a sample is visible to a 10 | different process (GPU). 11 | Heavily based on torch.utils.data.DistributedSampler 12 | """ 13 | 14 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 15 | if num_replicas is None: 16 | if not dist.is_available(): 17 | raise RuntimeError("Requires distributed package to be available") 18 | num_replicas = dist.get_world_size() 19 | if rank is None: 20 | if not dist.is_available(): 21 | raise RuntimeError("Requires distributed package to be available") 22 | rank = dist.get_rank() 23 | self.dataset = dataset 24 | self.num_replicas = num_replicas 25 | self.rank = rank 26 | self.epoch = 0 27 | self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas)) 28 | self.total_size = self.num_samples * self.num_replicas 29 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 30 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) 31 | self.shuffle = shuffle 32 | 33 | def __iter__(self): 34 | # deterministically shuffle based on epoch 35 | g = torch.Generator() 36 | g.manual_seed(self.epoch) 37 | if self.shuffle: 38 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 39 | else: 40 | indices = list(range(len(self.dataset))) 41 | 42 | # add extra samples to make it evenly divisible 43 | indices = [ele for ele in indices for i in range(3)] 44 | indices += indices[:(self.total_size - len(indices))] 45 | assert len(indices) == self.total_size 46 | 47 | # subsample 48 | indices = indices[self.rank:self.total_size:self.num_replicas] 49 | assert len(indices) == self.num_samples 50 | 51 | return iter(indices[:self.num_selected_samples]) 52 | 53 | def __len__(self): 54 | return self.num_selected_samples 55 | 56 | def set_epoch(self, epoch): 
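        # the epoch value seeds the torch.Generator in __iter__ above, so every process draws
        # the same permutation for a given epoch and the rank-strided slices stay disjoint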
self.epoch = epoch 58 | -------------------------------------------------------------------------------- /uninet.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import logging 4 | from functools import partial 5 | from collections import OrderedDict 6 | from copy import deepcopy 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.init import zeros_ 12 | import torch.utils.checkpoint as checkpoint 13 | 14 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 15 | from timm.models.registry import register_model 16 | from model_utils import _make_divisible, SqueezeExcitation, to3d, to4d 17 | 18 | 19 | class LocalDSM(nn.Module): 20 | def __init__(self, in_features, out_features, stride=1, mlp_ratio=4, head_dim=32, 21 | qkv_bias=True, qk_scale=None, drop=0., drop_path=0., attn_drop=0., seq_l=196): 22 | super().__init__() 23 | h = w = int(seq_l ** 0.5) 24 | new_h = new_w = math.ceil(h / stride) 25 | self.h = h 26 | self.new_h = new_h 27 | self.new_N = self.new_h * self.new_h 28 | if stride == 1: 29 | self.residual = nn.Sequential( 30 | nn.Conv2d(in_features, out_features, kernel_size=1) 31 | ) 32 | else: 33 | self.residual = nn.Sequential( 34 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1), 35 | nn.Conv2d(in_features, out_features, kernel_size=1) 36 | ) 37 | self.downsample = nn.Sequential( 38 | nn.Conv2d(in_features, in_features, kernel_size=(3, 3), padding=(1, 1), 39 | groups=in_features, stride=stride, bias=True), 40 | nn.BatchNorm2d(in_features), 41 | nn.Conv2d(in_features, out_features, kernel_size=1, bias=True), 42 | ) 43 | 44 | def forward(self, x): 45 | x_shape = x.shape 46 | if len(x_shape) == 3: 47 | x = to4d(x) 48 | return self.downsample(x) + self.residual(x) 49 | 50 | 51 | class LocalGlobalDSM(nn.Module): 52 | def __init__(self, in_features, out_features, stride=1, mlp_ratio=4, head_dim=32, 53 | qkv_bias=True, qk_scale=None, drop=0., drop_path=0., attn_drop=0., seq_l=196): 54 | super().__init__() 55 | out_dim = out_features or in_features 56 | self.num_heads = out_features // head_dim 57 | self.head_dim = head_dim 58 | self.out_features = out_features 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | h = w = int(seq_l ** 0.5) 62 | new_h = new_w = math.ceil(h / stride) 63 | self.h = h 64 | self.new_h = new_h 65 | self.new_N = self.new_h * self.new_h 66 | 67 | if stride == 1: 68 | self.residual = nn.Sequential( 69 | nn.Conv2d(in_features, out_features, kernel_size=1) 70 | ) 71 | else: 72 | self.residual = nn.Sequential( 73 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1), 74 | nn.Conv2d(in_features, out_features, kernel_size=1) 75 | ) 76 | 77 | self.q = nn.Sequential( 78 | nn.Conv2d(in_features, in_features, kernel_size=(3, 3), padding=(1, 1), 79 | groups=in_features, stride=stride, bias=True), 80 | nn.BatchNorm2d(in_features), 81 | nn.Conv2d(in_features, out_features, kernel_size=1, bias=True), 82 | ) 83 | 84 | self.q_norm = nn.LayerNorm(out_features) 85 | self.kv_norm = nn.LayerNorm(in_features) 86 | self.kv = nn.Linear(in_features, out_features * 2, bias=qkv_bias) 87 | self.proj = nn.Linear(out_features, out_features) 88 | 89 | def forward(self, x): 90 | x_shape = x.shape 91 | if len(x.shape) == 3: 92 | B, N, C = x_shape 93 | else: 94 | B, C, H, W = x_shape 95 | N = H * W 96 | 97 | x = to4d(x) 98 | residual = to3d(self.residual(x)) 99 | q = to3d(self.q(x)) 100 | x = to3d(x) 101 | 102 | q = self.q_norm(q) 103 | x = 
self.kv_norm(x) 104 | q = q.reshape(B, self.new_N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 105 | kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) 106 | k, v = kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple) 107 | 108 | attn = (q @ k.transpose(-2, -1)) * self.scale 109 | attn = attn.softmax(dim=-1) 110 | 111 | x = (attn @ v).transpose(1, 2).reshape(B, self.new_N, self.out_features) 112 | 113 | x = self.proj(x) 114 | return x + residual 115 | 116 | 117 | class DWConvBlock(nn.Module): 118 | def __init__(self, in_features, out_features=None, stride=1, 119 | mlp_ratio=4, use_se=True, drop=0., drop_path=0., 120 | seq_l=196, head_dim=32, init_values=1e-6, **kwargs): 121 | super().__init__() 122 | out_features = out_features or in_features 123 | hidden_features = int(in_features * mlp_ratio) 124 | if in_features != out_features or stride != 1: 125 | self.residual = nn.Sequential( 126 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1), 127 | nn.Conv2d(in_features, out_features, kernel_size=1) 128 | ) 129 | else: 130 | self.residual = nn.Identity() 131 | 132 | self.b1 = None 133 | if in_features != hidden_features or stride != 1: 134 | layers_b1 = [] 135 | layers_b1.append(nn.BatchNorm2d(in_features)) 136 | layers_b1.append(nn.Conv2d(in_features, hidden_features, kernel_size=(1, 1), 137 | stride=1, padding=(0, 0), bias=False)) 138 | layers_b1.append(nn.BatchNorm2d(hidden_features)) 139 | layers_b1.append(nn.GELU()) 140 | self.b1 = nn.Sequential(*layers_b1) 141 | 142 | layers = [] 143 | layers.append(nn.Conv2d(hidden_features, hidden_features, kernel_size=(3, 3), padding=(1, 1), 144 | groups=hidden_features, stride=stride, bias=False)) 145 | layers.append(nn.BatchNorm2d(hidden_features)) 146 | layers.append(nn.GELU()) 147 | if use_se: 148 | layers.append(SqueezeExcitation(hidden_features)) 149 | 150 | layers.append(nn.Conv2d(hidden_features, out_features, kernel_size=(1, 1), padding=(0, 0))) 151 | layers.append(nn.BatchNorm2d(out_features)) 152 | self.b2 = nn.Sequential(*layers) 153 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 154 | if init_values != -1: 155 | zeros_(self.b2[-1].weight) 156 | 157 | def forward(self, x): 158 | residual = self.residual(x) 159 | if self.b1 is not None: 160 | x = self.b1(x) 161 | x = self.b2(x) 162 | 163 | return residual + self.drop_path(x) 164 | 165 | 166 | class Attention(nn.Module): 167 | def __init__(self, dim, out_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 168 | seq_l=196, fp32_attn=False): 169 | super().__init__() 170 | out_dim = out_dim or dim 171 | self.num_heads = num_heads 172 | head_dim = dim // num_heads 173 | self.scale = qk_scale or head_dim ** -0.5 174 | 175 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 176 | self.attn_drop = nn.Dropout(attn_drop) 177 | self.proj = nn.Linear(dim, out_dim) 178 | self.proj_drop = nn.Dropout(proj_drop) 179 | self.custom_flops = 2 * seq_l * seq_l * dim 180 | self.fp32_attn = fp32_attn 181 | 182 | def forward(self, x, head=0, mask_type=None): 183 | B, N, C = x.shape 184 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 185 | if self.fp32_attn: 186 | q, k, v = qkv[0].float(), qkv[1].float(), qkv[2].float() 187 | else: 188 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 189 | 190 | attn = (q @ k.transpose(-2, -1)) * self.scale 191 | attn = attn.softmax(dim=-1) 192 | attn = self.attn_drop(attn) 193 | 194 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 195 | if self.fp32_attn: 196 | x = x.to(self.proj.weight.dtype) 197 | x = self.proj(x) 198 | x = self.proj_drop(x) 199 | return x 200 | 201 | 202 | class AttentionBlock(nn.Module): 203 | 204 | def __init__(self, in_features, out_features, stride=1, mlp_ratio=4, head_dim=32, 205 | qkv_bias=True, qk_scale=None, drop=0., drop_path=0., attn_drop=0., seq_l=196, 206 | init_values=1e-6, fp32_attn=False,): 207 | super().__init__() 208 | self.norm1 = nn.LayerNorm(in_features) 209 | self.stride = stride 210 | self.in_features = in_features 211 | self.out_features = out_features 212 | mlp_hidden_dim = int(in_features * mlp_ratio) 213 | num_heads = in_features // head_dim 214 | self.init_values = init_values 215 | self.pos_embed = nn.Conv2d(in_features, in_features, 3, padding=1, groups=in_features) 216 | self.attn = Attention(in_features, out_features, num_heads=num_heads, qkv_bias=qkv_bias, 217 | qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, seq_l=seq_l, 218 | fp32_attn=fp32_attn) 219 | if stride != 1 or in_features != out_features: 220 | self.ds = nn.MaxPool2d(kernel_size=3, stride=stride, padding=1) 221 | self.residual = nn.Sequential( 222 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1), 223 | nn.Conv2d(in_features, out_features, kernel_size=1) 224 | ) 225 | if init_values != -1: 226 | self.gamma_1 = nn.Parameter(init_values * torch.ones((out_features)), requires_grad=True) 227 | else: 228 | self.norm2 = nn.LayerNorm(in_features) 229 | self.mlp = Mlp(in_features=in_features, hidden_features=mlp_hidden_dim, act_layer=nn.GELU, drop=drop) 230 | if init_values != -1: 231 | self.gamma_1 = nn.Parameter(init_values * torch.ones((out_features)), requires_grad=True) 232 | self.gamma_2 = nn.Parameter(init_values * torch.ones((out_features)), requires_grad=True) 233 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 234 | 235 | 236 | def forward(self, x, head=0, mask_type=None): 237 | x = x + to3d(self.pos_embed(to4d(x))) 238 | if self.stride == 1 and self.in_features == self.out_features: 239 | if self.init_values != -1: 240 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), head=head, mask_type=mask_type)) 241 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 242 | else: 243 | x = x + self.drop_path(self.attn(self.norm1(x), head=head, mask_type=mask_type)) 244 | x = x + self.drop_path(self.mlp(self.norm2(x))) 245 | else: 246 | residual = to3d(self.residual(to4d(x))) 247 | x = self.norm1(x) 248 | x = to3d(self.ds(to4d(x))) 249 | x = self.attn(x) 250 | if self.init_values != -1: 251 | x = residual + self.gamma_1 * x 252 | else: 253 | x = residual + x 254 | return x 255 | 256 | 257 | class VisionTransformer(nn.Module): 258 | def __init__(self, repeats, expansion, channels, strides=[1, 2, 2, 2, 1, 2], num_classes=1000, drop_path_rate=0., 259 | input_size=224, weight_init='', head_dim=32, final_head_dim=1280, final_drop=0.0, init_values=1e-6, 260 | block_ops=[DWConvBlock] * 3 + [AttentionBlock] * 3, checkpoint=0, stem_dim=32, 261 | ds_ops=[LocalDSM] * 3 + [LocalGlobalDSM] * 2, **kwargs): 262 | super().__init__() 263 | self.num_classes = num_classes 264 | self.checkpoint = checkpoint 265 | 266 | # stem_dim = 32 267 | h = w = input_size 268 | self.stem = nn.Sequential( 269 | nn.Conv2d(3, stem_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), 270 | nn.BatchNorm2d(stem_dim), 271 | nn.GELU(), 272 | ) 273 | 274 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(repeats))] # stochastic depth decay rule 275 | dpr.reverse() 276 | 277 | cin = stem_dim 278 | blocks = [] 279 | h = w = math.ceil(h / 2) 280 | seq_l = h * w 281 | for stage in range(len(strides)): 282 | cout = channels[stage] 283 | block_op = block_ops[stage] 284 | # print(f'stage {stage}, cin {cin}, cout {cout}, s {strides[stage]}, e {expansion[stage]} b {block_op}') 285 | 286 | if stage != 0: 287 | blocks.append(ds_ops[stage - 1](cin, cout, stride=strides[stage], seq_l=seq_l, head_dim=head_dim)) 288 | h = w = math.ceil(h / strides[stage]) 289 | seq_l = h * w 290 | cin = cout 291 | 292 | # cin = cout 293 | for i in range(repeats[stage]): 294 | stride = strides[stage] if i == 0 else 1 295 | blocks.append(block_op(cin, cout, stride=1, mlp_ratio=expansion[stage], 296 | drop_path=dpr.pop(), seq_l=seq_l, head_dim=head_dim, 297 | init_values=init_values)) 298 | cin = cout 299 | self.blocks = nn.Sequential(*blocks) 300 | 301 | head_dim = final_head_dim 302 | self.head = nn.Sequential( 303 | nn.Conv2d(cout, head_dim, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=False), 304 | nn.BatchNorm2d(head_dim), 305 | nn.GELU(), 306 | ) 307 | self.final_drop = nn.Dropout(final_drop) if final_drop > 0.0 else nn.Identity() 308 | self.avgpool = nn.AdaptiveAvgPool2d(1) 309 | self.classifier = nn.Linear(head_dim, num_classes) 310 | 311 | head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. 
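        # 'nlhb' stands for negative-log head bias: with num_classes=1000 it gives
        # head_bias = -math.log(1000) ≈ -6.91, intended to start the classification head
        # at low initial confidence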
312 | # Weight init 313 | assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') 314 | if weight_init.startswith('jax'): 315 | # leave cls token as zeros to match jax impl 316 | for n, m in self.named_modules(): 317 | _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) 318 | else: 319 | # trunc_normal_(self.cls_token, std=.02) 320 | self.apply(_init_vit_weights) 321 | 322 | def _init_weights(self, m): 323 | # this fn left here for compat with downstream users 324 | _init_vit_weights(m) 325 | 326 | @torch.jit.ignore 327 | def no_weight_decay(self): 328 | return {'pos_embed', 'dist_token'} 329 | 330 | def forward_features(self, x): 331 | 332 | x = self.stem(x) 333 | for i, blk in enumerate(self.blocks): 334 | if isinstance(blk, DWConvBlock): 335 | x = to4d(x) 336 | if isinstance(blk, AttentionBlock): 337 | x = to3d(x) 338 | if i < self.checkpoint and x.requires_grad: 339 | x = checkpoint.checkpoint(blk, x) 340 | else: 341 | x = blk(x) 342 | x = to4d(x) 343 | x = self.head(x) 344 | x = self.avgpool(x) 345 | return torch.flatten(x, 1) 346 | 347 | def forward(self, x): 348 | x = self.forward_features(x) 349 | x = self.final_drop(x) 350 | x = self.classifier(x) 351 | return x 352 | 353 | 354 | def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): 355 | if isinstance(m, nn.Linear): 356 | if n.startswith('head'): 357 | nn.init.zeros_(m.weight) 358 | nn.init.constant_(m.bias, head_bias) 359 | elif n.startswith('pre_logits'): 360 | lecun_normal_(m.weight) 361 | nn.init.zeros_(m.bias) 362 | else: 363 | if jax_impl: 364 | nn.init.xavier_uniform_(m.weight) 365 | if m.bias is not None: 366 | if 'mlp' in n: 367 | nn.init.normal_(m.bias, std=1e-6) 368 | else: 369 | nn.init.zeros_(m.bias) 370 | else: 371 | trunc_normal_(m.weight, std=.02) 372 | if m.bias is not None: 373 | nn.init.zeros_(m.bias) 374 | elif jax_impl and isinstance(m, nn.Conv2d): 375 | # NOTE conv was left to pytorch default in my original init 376 | lecun_normal_(m.weight) 377 | if m.bias is not None: 378 | nn.init.zeros_(m.bias) 379 | elif isinstance(m, nn.LayerNorm): 380 | nn.init.zeros_(m.bias) 381 | nn.init.ones_(m.weight) 382 | 383 | 384 | @register_model 385 | def UniNetB0(**kwargs): # 11.451M, 0.555G, 160 386 | repeats = [1, 2, 4, 4, 4, 8] 387 | expansion = [1, 4, 6, 3, 2, 5] 388 | channels = [24, 48, 80, 128, 128, 256] 389 | final_drop = 0.0 390 | block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2 391 | input_size = 160 392 | 393 | print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}') 394 | model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels, 395 | block_ops=block_ops, final_drop=final_drop, 396 | input_size=input_size, **kwargs) 397 | model = VisionTransformer(**model_kwargs) 398 | return model 399 | 400 | @register_model 401 | def UniNetB1(**kwargs): # 11.451M, 1.118G, 224 402 | repeats = [1, 2, 4, 4, 4, 8] 403 | expansion = [1, 4, 6, 3, 2, 5] 404 | channels = [24, 48, 80, 128, 128, 256] 405 | final_drop = 0.0 406 | block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2 407 | input_size = 224 408 | 409 | print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}') 410 | model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels, 411 | block_ops=block_ops, final_drop=final_drop, 412 | input_size=input_size, **kwargs) 413 | model = VisionTransformer(**model_kwargs) 414 | return model 415 | 416 | @register_model 417 | def UniNetB2(**kwargs): # 16.211M, 2.159G, 256 418 | repeats = 
[2, 3, 6, 6, 6, 12]
419 |     expansion = [1, 4, 6, 3, 2, 5]
420 |     channels = [24, 48, 80, 128, 128, 256]
421 |     final_drop = 0.0
422 |     block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2
423 |     input_size = 256
424 | 
425 |     print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}')
426 |     model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels,
427 |                         block_ops=block_ops, final_drop=final_drop,
428 |                         input_size=input_size, **kwargs)
429 |     model = VisionTransformer(**model_kwargs)
430 |     return model
431 | 
432 | @register_model
433 | def UniNetB3(**kwargs): # 24.02M, 4.258G, 288
434 |     repeats = [2, 3, 7, 7, 7, 14]
435 |     expansion = [1, 4, 6, 3, 2, 5]
436 |     channels = [24, 56, 96, 160, 160, 288]
437 |     final_drop = 0.0
438 |     block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2
439 |     input_size = 288
440 | 
441 |     print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}')
442 |     model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels,
443 |                         block_ops=block_ops, final_drop=final_drop,
444 |                         input_size=input_size, **kwargs)
445 |     model = VisionTransformer(**model_kwargs)
446 |     return model
447 | 
448 | 
449 | @register_model
450 | def UniNetB4(**kwargs): # 43.796M, 9.429G, 320
451 |     repeats = [2, 4, 9, 9, 9, 18]
452 |     expansion = [1, 4, 6, 3, 2, 5]
453 |     channels = [32, 64, 112, 192, 192, 352]
454 |     final_drop = 0.0
455 |     block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2
456 |     input_size = 320
457 | 
458 |     print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}')
459 |     model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels,
460 |                         block_ops=block_ops, final_drop=final_drop,
461 |                         input_size=input_size, **kwargs)
462 |     model = VisionTransformer(**model_kwargs)
463 |     return model
464 | 
465 | @register_model
466 | def UniNetB5(**kwargs): # 72.883M, *, 384
467 |     repeats = [3, 5, 10, 10, 10, 20]
468 |     expansion = [1, 4, 6, 3, 2, 5]
469 |     channels = [32, 64, 112, 224, 224, 448]
470 |     final_drop = 0.0
471 |     block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2
472 |     input_size = 384
473 | 
474 |     print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}')
475 |     model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels,
476 |                         block_ops=block_ops, final_drop=final_drop,
477 |                         input_size=input_size, **kwargs)
478 |     model = VisionTransformer(**model_kwargs)
479 |     return model
480 | 
481 | @register_model
482 | def UniNetB6(**kwargs): # 117M, *, 448
483 |     repeats = [4, 6, 12, 12, 12, 24]
484 |     expansion = [1, 4, 6, 3, 2, 5]
485 |     channels = [48, 96, 160, 256, 256, 512]
486 |     final_drop = 0.0
487 |     block_ops = [DWConvBlock] * 4 + [AttentionBlock] * 2
488 |     input_size = 448
489 | 
490 |     print(f'channels {channels}, repeats {repeats}, expansion {expansion}, block_ops {block_ops}')
491 |     model_kwargs = dict(repeats=repeats, expansion=expansion, channels=channels,
492 |                         block_ops=block_ops, final_drop=final_drop,
493 |                         input_size=input_size, **kwargs)
494 |     model = VisionTransformer(**model_kwargs)
495 |     return model
496 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc functions, including distributed helpers.
3 | 
4 | Mostly copy-paste from torchvision references.
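Includes logging setup, smoothed metric tracking, and distributed (torchrun / slurm) init helpers.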
5 | """ 6 | import io 7 | import os 8 | import time 9 | import logging 10 | import random 11 | import subprocess 12 | from collections import defaultdict, deque 13 | import datetime 14 | 15 | import torch 16 | import torch.distributed as dist 17 | 18 | logs = set() 19 | 20 | 21 | def init_log(name, level=logging.INFO, log_file=None): 22 | if (name, level) in logs: 23 | return 24 | logs.add((name, level)) 25 | logger = logging.getLogger(name) 26 | logger.setLevel(level) 27 | ch = logging.StreamHandler() 28 | ch.setLevel(level) 29 | rank = get_rank() 30 | logger.addFilter(lambda record: rank == 0) 31 | format_str = f'%(asctime)s-rk{rank}-%(filename)s#%(lineno)d:%(message)s' 32 | formatter = logging.Formatter(format_str) 33 | print('****** init log ', __name__) 34 | if log_file and rank == 0: 35 | print('[rank {}] log to {}'.format(rank, log_file)) 36 | fileHandler = logging.FileHandler(log_file, 'a') 37 | fileHandler.setFormatter(formatter) 38 | logger.addHandler(fileHandler) 39 | 40 | ch.setFormatter(formatter) 41 | logger.addHandler(ch) 42 | 43 | class SmoothedValue(object): 44 | """Track a series of values and provide access to smoothed values over a 45 | window or the global series average. 46 | """ 47 | 48 | def __init__(self, window_size=20, fmt=None): 49 | if fmt is None: 50 | fmt = "{median:.4f} ({global_avg:.4f})" 51 | self.deque = deque(maxlen=window_size) 52 | self.total = 0.0 53 | self.count = 0 54 | self.fmt = fmt 55 | 56 | def update(self, value, n=1): 57 | self.deque.append(value) 58 | self.count += n 59 | self.total += value * n 60 | 61 | def synchronize_between_processes(self): 62 | """ 63 | Warning: does not synchronize the deque! 64 | """ 65 | if not is_dist_avail_and_initialized(): 66 | return 67 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 68 | dist.barrier() 69 | dist.all_reduce(t) 70 | t = t.tolist() 71 | self.count = int(t[0]) 72 | self.total = t[1] 73 | 74 | @property 75 | def median(self): 76 | d = torch.tensor(list(self.deque)) 77 | return d.median().item() 78 | 79 | @property 80 | def avg(self): 81 | d = torch.tensor(list(self.deque), dtype=torch.float32) 82 | return d.mean().item() 83 | 84 | @property 85 | def global_avg(self): 86 | return self.total / self.count 87 | 88 | @property 89 | def max(self): 90 | return max(self.deque) 91 | 92 | @property 93 | def value(self): 94 | return self.deque[-1] 95 | 96 | def __str__(self): 97 | return self.fmt.format( 98 | median=self.median, 99 | avg=self.avg, 100 | global_avg=self.global_avg, 101 | max=self.max, 102 | value=self.value) 103 | 104 | 105 | class MetricLogger(object): 106 | def __init__(self, delimiter="\t"): 107 | self.meters = defaultdict(SmoothedValue) 108 | self.delimiter = delimiter 109 | 110 | def update(self, **kwargs): 111 | for k, v in kwargs.items(): 112 | if isinstance(v, torch.Tensor): 113 | v = v.item() 114 | assert isinstance(v, (float, int)) 115 | self.meters[k].update(v) 116 | 117 | def __getattr__(self, attr): 118 | if attr in self.meters: 119 | return self.meters[attr] 120 | if attr in self.__dict__: 121 | return self.__dict__[attr] 122 | raise AttributeError("'{}' object has no attribute '{}'".format( 123 | type(self).__name__, attr)) 124 | 125 | def __str__(self): 126 | loss_str = [] 127 | for name, meter in self.meters.items(): 128 | loss_str.append( 129 | "{}: {}".format(name, str(meter)) 130 | ) 131 | return self.delimiter.join(loss_str) 132 | 133 | def synchronize_between_processes(self): 134 | for meter in self.meters.values(): 135 | 
meter.synchronize_between_processes()
136 | 
137 |     def add_meter(self, name, meter):
138 |         self.meters[name] = meter
139 | 
140 |     def log_every(self, iterable, print_freq, header=None):
141 |         i = 0
142 |         if not header:
143 |             header = ''
144 |         start_time = time.time()
145 |         end = time.time()
146 |         iter_time = SmoothedValue(fmt='{avg:.4f}')
147 |         data_time = SmoothedValue(fmt='{avg:.4f}')
148 |         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
149 |         log_msg = [
150 |             header,
151 |             '[{0' + space_fmt + '}/{1}]',
152 |             'eta: {eta}',
153 |             '{meters}',
154 |             'time: {time}',
155 |             'data: {data}'
156 |         ]
157 |         if torch.cuda.is_available():
158 |             log_msg.append('max mem: {memory:.0f}')
159 |         log_msg = self.delimiter.join(log_msg)
160 |         MB = 1024.0 * 1024.0
161 |         for obj in iterable:
162 |             data_time.update(time.time() - end)
163 |             yield obj
164 |             iter_time.update(time.time() - end)
165 |             if i % print_freq == 0 or i == len(iterable) - 1:
166 |                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
167 |                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
168 |                 if torch.cuda.is_available():
169 |                     print(log_msg.format(
170 |                         i, len(iterable), eta=eta_string,
171 |                         meters=str(self),
172 |                         time=str(iter_time), data=str(data_time),
173 |                         memory=torch.cuda.max_memory_allocated() / MB))
174 |                 else:
175 |                     print(log_msg.format(
176 |                         i, len(iterable), eta=eta_string,
177 |                         meters=str(self),
178 |                         time=str(iter_time), data=str(data_time)))
179 |             i += 1
180 |             end = time.time()
181 |         total_time = time.time() - start_time
182 |         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
183 |         print('{} Total time: {} ({:.4f} s / it)'.format(
184 |             header, total_time_str, total_time / len(iterable)))
185 | 
186 | 
187 | def _load_checkpoint_for_ema(model_ema, checkpoint):
188 |     """
189 |     Workaround for ModelEma._load_checkpoint to accept an already-loaded object
190 |     """
191 |     mem_file = io.BytesIO()
192 |     torch.save(checkpoint, mem_file)
193 |     mem_file.seek(0)
194 |     model_ema._load_checkpoint(mem_file)
195 | 
196 | 
197 | def setup_for_distributed(is_master):
198 |     """
199 |     Disable printing when not in the master process (pass force=True to override).
200 |     """
201 |     import builtins as __builtin__
202 |     builtin_print = __builtin__.print
203 | 
204 |     def print(*args, **kwargs):
205 |         force = kwargs.pop('force', False)
206 |         if is_master or force:
207 |             builtin_print(*args, **kwargs)
208 | 
209 |     __builtin__.print = print
210 | 
211 | def setup_for_distributed_print(rank):
212 |     """
213 |     Prefix every print with the process rank instead of silencing non-master processes.
214 |     """
215 |     import builtins as __builtin__
216 |     builtin_print = __builtin__.print
217 | 
218 |     def print(*args, **kwargs):
219 |         builtin_print(f'{rank} ', *args, **kwargs)
220 | 
221 |     __builtin__.print = print
222 | 
223 | def is_dist_avail_and_initialized():
224 |     if not dist.is_available():
225 |         return False
226 |     if not dist.is_initialized():
227 |         return False
228 |     return True
229 | 
230 | 
231 | def get_world_size():
232 |     if not is_dist_avail_and_initialized():
233 |         return 1
234 |     return dist.get_world_size()
235 | 
236 | 
237 | def get_rank():
238 |     if not is_dist_avail_and_initialized():
239 |         return 0
240 |     return dist.get_rank()
241 | 
242 | 
243 | def is_main_process():
244 |     return get_rank() == 0
245 | 
246 | 
247 | def save_on_master(*args, **kwargs):
248 |     if is_main_process():
249 |         torch.save(*args, **kwargs)
250 | 
251 | 
252 | def init_distributed_mode(args, verbose=False):
253 |     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
254 |         args.rank = 
int(os.environ["RANK"]) 255 | args.world_size = int(os.environ['WORLD_SIZE']) 256 | args.gpu = int(os.environ['LOCAL_RANK']) 257 | os.environ['MASTER_PORT'] = str(getattr(args, 'port', '29529')) 258 | elif 'SLURM_PROCID' in os.environ: 259 | args.rank = int(os.environ['SLURM_PROCID']) 260 | args.gpu = args.rank % torch.cuda.device_count() 261 | args.world_size = int(os.environ['SLURM_NTASKS']) 262 | node_list = os.environ['SLURM_NODELIST'] 263 | num_gpus = torch.cuda.device_count() 264 | addr = subprocess.getoutput( 265 | f'scontrol show hostname {node_list} | head -n1') 266 | os.environ['MASTER_PORT'] = str(getattr(args, 'port', '29529')) 267 | os.environ['MASTER_ADDR'] = addr 268 | os.environ['WORLD_SIZE'] = str(args.world_size) 269 | os.environ['LOCAL_RANK'] = str(args.rank % num_gpus) 270 | os.environ['RANK'] = str(args.rank) 271 | 272 | else: 273 | print('Not using distributed mode') 274 | args.distributed = False 275 | return 276 | 277 | args.distributed = True 278 | 279 | torch.cuda.set_device(args.gpu) 280 | args.dist_backend = 'nccl' 281 | print('| distributed init (rank {}): {}, gpu {}'.format( 282 | args.rank, args.dist_url, args.gpu), flush=True) 283 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 284 | world_size=args.world_size, rank=args.rank) 285 | torch.distributed.barrier() 286 | if verbose: 287 | setup_for_distributed(args.rank == 0) 288 | else: 289 | setup_for_distributed_print(args.rank) 290 | 291 | 292 | def init_distributed_slurm(args): 293 | port = None 294 | args.rank = proc_id = int(os.environ['SLURM_PROCID']) 295 | args.world_size = ntasks = int(os.environ['SLURM_NTASKS']) 296 | node_list = os.environ['SLURM_NODELIST'] 297 | num_gpus = torch.cuda.device_count() 298 | print(node_list, num_gpus) 299 | args.gpu = args.rank % torch.cuda.device_count() 300 | args.distributed = True 301 | 302 | torch.cuda.set_device(args.gpu) 303 | args.dist_backend = 'nccl' 304 | 305 | addr = subprocess.getoutput( 306 | f'scontrol show hostname {node_list} | head -n1') 307 | print(f'addr {addr}') 308 | # specify master port 309 | if port is not None: 310 | os.environ['MASTER_PORT'] = str(port) 311 | elif 'MASTER_PORT' in os.environ: 312 | pass # use MASTER_PORT in the environment variable 313 | else: 314 | # 29500 is torch.distributed default port 315 | # os.environ['MASTER_PORT'] = str(29501) 316 | os.environ['MASTER_PORT'] = str(getattr(args, 'port', 29501)) 317 | # use MASTER_ADDR in the environment variable if it already exists 318 | if 'MASTER_ADDR' not in os.environ: 319 | os.environ['MASTER_ADDR'] = addr 320 | os.environ['WORLD_SIZE'] = str(ntasks) 321 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 322 | os.environ['RANK'] = str(proc_id) 323 | 324 | print('| distributed init (rank {}): {}'.format( 325 | args.rank, args.dist_url), flush=True) 326 | torch.distributed.init_process_group(backend=args.dist_backend) 327 | torch.distributed.barrier() 328 | setup_for_distributed(args.rank == 0) 329 | 330 | --------------------------------------------------------------------------------