├── .gitattributes
├── LICENSE
├── README.md
├── augment.py
├── datasets.py
├── engine_scala.py
├── eval.sh
├── fig
│   ├── dino.png
│   ├── gran_bound.png
│   ├── hybrid.png
│   ├── intro.png
│   ├── meta.png
│   ├── neu.png
│   ├── slim.png
│   ├── snap.png
│   └── transfer.png
├── losses.py
├── main_scala.py
├── models_scala.py
├── run.sh
├── samplers.py
├── scheduler.py
└── utils.py

/.gitattributes:
--------------------------------------------------------------------------------
# Auto detect text files and perform LF normalization
* text=auto

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2024 BeSpontaneous

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Slicing Vision Transformer for Flexible Inference (NeurIPS 2024)

Primary contact: [Yitian Zhang](mailto:markcheung9248@gmail.com)

## TL;DR
- `Background`: ViTs of different sizes share the same architecture and differ only in embedding dimension, so a large ViT can be transformed into smaller models by uniformly slicing the weight matrix at each layer; e.g., ViT-B (r=0.5) equals ViT-S.
- `Target`: a broad slicing bound to ensure the diversity of sub-networks; a fine-grained slicing granularity to ensure a large number of sub-networks; and uniform slicing to align with the inherent design of ViTs, which vary in width.
- `Contribution`:
    - (1) Detailed analysis of the slimmable ability across different architectures
    - (2) Scala, a method that learns slimmable representations for flexible inference


## Requirements
- python 3.7
- pytorch 1.8.1
- torchvision 0.9.1
- timm 0.3.2


## Datasets
Please follow the [DeiT instructions](https://github.com/facebookresearch/deit/blob/main/README_deit.md#data-preparation) to prepare the ImageNet-1K dataset.


## Pretrained Models
Here we provide pretrained Scala models built on top of DeiT-S and trained on ImageNet-1K for 100 epochs ($X$ is the number of sub-networks):

| Model | Top-1 Acc. ($r=0.25$) | Top-1 Acc. ($r=0.50$) | Top-1 Acc. ($r=0.75$) | Top-1 Acc. ($r=1.00$) |
| ---- | ---- | ---- | ---- | ---- |
| Separate Training | 45.8% | 65.1% | 70.7% | 75.0% |
| [Scala-S (X=25)](https://drive.google.com/file/d/1-xQFweDA3MUTslDyfs5zqvdRhuRtSN99/view?usp=drive_link) | 58.4% | 67.8% | 73.1% | 76.2% |
| [Scala-S (X=13)](https://drive.google.com/file/d/1D2KZ5_1VAKB8_NTCH35Xu8IdsIT5i3hB/view?usp=drive_link) | 58.7% | 68.3% | 73.3% | 76.1% |
| [Scala-S (X=7)](https://drive.google.com/file/d/1DtA21C6VL4Qe8joHXl8yaaLrQ7mbHnEZ/view?usp=drive_link) | 59.8% | 70.3% | 74.2% | 76.5% |
| [Scala-S (X=4)](https://drive.google.com/file/d/1ZBzFeaMYubr4lBajiyO4QYpzDAG7bp6i/view?usp=drive_link) | 59.8% | 72.0% | 75.6% | 76.7% |

We also provide Scala models built on top of DeiT-B and trained on ImageNet-1K for 300 epochs:

| Model | Top-1 Acc. ($r=0.25$) | Top-1 Acc. ($r=0.50$) | Top-1 Acc. ($r=0.75$) | Top-1 Acc. ($r=1.00$) |
| ---- | ---- | ---- | ---- | ---- |
| Separate Training | 72.2% | 79.9% | 81.0% | 81.8% |
| [Scala-B (X=13)](https://drive.google.com/file/d/1g58ace9cfFUoooqP6n1Xy0mWqntSGhxE/view?usp=drive_link) | 75.3% | 79.3% | 81.2% | 82.0% |
| [Scala-B (X=7)](https://drive.google.com/file/d/1LIgPj8TAzmrFvJcQS_QmIyUy8CTDDNeF/view?usp=drive_link) | 75.3% | 79.7% | 81.4% | 82.0% |
| [Scala-B (X=4)](https://drive.google.com/file/d/1Usy-LevoYqAXdggUT-jvRiWY41Bw93hf/view?usp=drive_link) | 75.6% | 80.9% | 81.9% | 82.2% |


## Results

- Slicing Granularity and Bound
- Application to Hybrid and Lightweight Structures
- Slimmable Ability across Architectures
- Transferability
    - **Can the slimmable representation be transferred to downstream tasks?** We first pre-train on ImageNet-1K for 300 epochs and then conduct linear probing on the video recognition dataset UCF101. We also make the classification head slimmable so that it fits features of various dimensions; the results indicate strong transferability of the slimmable representation.
    - **Can the slimmable representation maintain its generalization ability?** When using the vision foundation model DINOv2 as the teacher network, we follow prior work [Proteus](https://github.com/BeSpontaneous/Proteus-pytorch) and remove all Cross-Entropy losses during training to alleviate dataset bias and inherit the strong generalization ability of the teacher. The results are shown in the table, and the resulting Scala-B, which retains strong generalization ability, can be downloaded from this [link](https://drive.google.com/file/d/1KPJK_rucC8ovQPe2TDDeKt0HBflQ0mrq/view?usp=drive_link).
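
The uniform slicing from the TL;DR, applied to a single linear layer such as the slimmable classification head mentioned above, can be sketched as follows. This is not the code in `models_scala.py`. It assumes that a sub-network with width ratio $r$ keeps the leading `round(r * d)` channels of each weight matrix, which is the usual convention in slimmable networks. `width_grid` and `slice_linear` are hypothetical helper names used only for illustration. `width_grid` enumerates the width ratios from the smallest bound $s$ to the largest bound $l$ at granularity $\epsilon$, and its length reproduces the $X$ column of the training table.

```python
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


def width_grid(smallest: float, largest: float, granularity: float) -> List[float]:
    """Hypothetical helper: width ratios s, s + eps, ..., l (both ends inclusive)."""
    steps = round((largest - smallest) / granularity)
    return [smallest + i * granularity for i in range(steps + 1)]


def slice_linear(layer: nn.Linear, x: torch.Tensor,
                 in_ratio: float, out_ratio: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: run `layer` using only its leading channels.

    Keeps the first round(in_ratio * in_features) input channels and the first
    round(out_ratio * out_features) output channels of the weight matrix.
    For a slimmable classification head, out_ratio stays 1.0 so that the
    number of classes does not change.
    """
    in_dim = round(layer.in_features * in_ratio)
    out_dim = round(layer.out_features * out_ratio)
    weight = layer.weight[:out_dim, :in_dim]
    bias: Optional[torch.Tensor] = layer.bias[:out_dim] if layer.bias is not None else None
    # x may be full-width or already sliced; only its first in_dim channels are used.
    return F.linear(x[..., :in_dim], weight, bias)


if __name__ == "__main__":
    # Number of sub-networks X for s = 0.25, l = 1.0 at each granularity
    # (prints 25, 13, 7 and 4, matching the training table).
    for eps in (0.03125, 0.0625, 0.125, 0.25):
        print(eps, len(width_grid(0.25, 1.0, eps)))

    # A 768-dimensional head sliced at r = 0.5 reads only 384 input features,
    # mirroring the ViT-B (r=0.5) -> ViT-S example.
    head = nn.Linear(768, 1000)
    feats = torch.randn(2, 768)
    logits = slice_linear(head, feats, in_ratio=0.5)
    print(logits.shape)  # torch.Size([2, 1000])
```

Under this leading-channel assumption, every narrower sub-network reuses a prefix of the wider one's weights. That is what would let a single checkpoint be evaluated at any `eval_ratio`.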

## Training Scala on ImageNet-1K
1. In `run.sh`, replace `IMAGENET_LOCATION` with the path to the ImageNet-1K dataset.
2. Specify the smallest slicing bound $s$ (`smallest_ratio`), the largest slicing bound $l$ (`largest_ratio`), and the slicing granularity $\epsilon$ (`granularity`). Together they determine the number of sub-networks $X = (l - s)/\epsilon + 1$:

| $s$ | $l$ | $\epsilon$ | $X$ |
| ---- | ---- | ---- | ---- |
| 0.25 | 1.0 | 0.03125 | 25 |
| 0.25 | 1.0 | 0.0625 | 13 |
| 0.25 | 1.0 | 0.125 | 7 |
| 0.25 | 1.0 | 0.25 | 4 |

3. Run `bash run.sh`.


## Flexible inference at different width ratios
1. In `eval.sh`, replace `IMAGENET_LOCATION` with the dataset path and `MODEL_PATH` with the path to the pretrained checkpoint.
2. Set the width ratio with `eval_ratio` in `eval.sh`.
3. Run `bash eval.sh`.

--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

"""
3Augment implementation
Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
and timm DA (https://github.com/rwightman/pytorch-image-models)
"""
import torch
from torchvision import transforms

from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor

import numpy as np
from torchvision import datasets, transforms
import random


from PIL import ImageFilter, ImageOps
import torchvision.transforms.functional as TF


class GaussianBlur(object):
    """
    Apply Gaussian Blur to the PIL image.
    """
    def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        do_it = random.random() <= self.prob
        if not do_it:
            return img

        img = img.filter(
            ImageFilter.GaussianBlur(
                radius=random.uniform(self.radius_min, self.radius_max)
            )
        )
        return img

class Solarization(object):
    """
    Apply Solarization to the PIL image.
    """
    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img

class gray_scale(object):
    """
    Apply 3-channel grayscale conversion to the PIL image.
    """
    def __init__(self, p=0.2):
        self.p = p
        self.transf = transforms.Grayscale(3)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img



class horizontal_flip(object):
    """
    Apply a horizontal flip to the PIL image.
    """
    def __init__(self, p=0.2, activate_pred=False):
        self.p = p
        self.transf = transforms.RandomHorizontalFlip(p=1.0)

    def __call__(self, img):
        if random.random() < self.p:
            return self.transf(img)
        else:
            return img



def new_data_aug_generator(args = None):
    img_size = args.input_size
    remove_random_resized_crop = args.src
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    primary_tfl = []
    scale=(0.08, 1.0)
    interpolation='bicubic'
    if remove_random_resized_crop:
        primary_tfl = [
            transforms.Resize(img_size, interpolation=3),
            transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
            transforms.RandomHorizontalFlip()
        ]
    else:
        primary_tfl = [
            RandomResizedCropAndInterpolation(
                img_size, scale=scale, interpolation=interpolation),
            transforms.RandomHorizontalFlip()
        ]


    secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
                                              Solarization(p=1.0),
                                              GaussianBlur(p=1.0)])]

    if args.color_jitter is not None and not args.color_jitter==0:
        secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
    final_tfl = [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ]
    return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)

--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int(args.input_size / args.eval_crop_ratio)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)

--------------------------------------------------------------------------------
/engine_scala.py:
--------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Train and eval functions used in main_scala.py
"""
import math
import sys
from typing import Iterable, Optional
import torch.nn as nn
import torch
from torch.nn import functional as F
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
import random
from losses import DistillationLoss
import utils


def train_one_epoch(full_warm_epoch: int, distillation_type: str, ce_coefficient: float,
                    largest_ratio: float, smallest_ratio: float, granularity: float,
                    ce_type: str, distill_type: str, transfer_type: str, token_type: str,
                    model: torch.nn.Module, criterion: DistillationLoss,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    set_training_mode=True, args = None):
    model.train(set_training_mode)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10

    if args.cosub:
        criterion = torch.nn.BCEWithLogitsLoss()

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        if args.cosub:
            samples = torch.cat((samples,samples),dim=0)

        if args.bce_loss:
            targets = targets.gt(0.0).type(targets.dtype)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            if epoch < full_warm_epoch:
                outputs = model(samples, largest_ratio)
                if not args.cosub:
                    loss_full = criterion(samples, outputs, targets)
                else:
                    loss_full = 0.25 * criterion(outputs[0], targets)
                    loss_full = loss_full + 0.25 * criterion(outputs[1], targets)
                    loss_full = loss_full + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                    loss_full = loss_full + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

                loss_scaler.scale(loss_full).backward()
                loss_scaler.step(optimizer)
                loss_scaler.update()

                loss = loss_full
                loss_value = loss.item()
                loss_full_value = loss_full.item()
                loss_3q_value = 0.0
                loss_2q_value = 0.0
                loss_1q_value = 0.0
                loss_3q_ce_value = 0.0
                loss_2q_ce_value = 0.0
                loss_1q_ce_value = 0.0
            else:
                ############## full model ##############
                outputs = model(samples, largest_ratio)
                if not args.cosub:
                    loss_full = criterion(samples, outputs, targets)
                else:
                    loss_full = 0.25 * criterion(outputs[0], targets)
                    loss_full = loss_full + 0.25 * criterion(outputs[1], targets)
                    loss_full = loss_full + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
                    loss_full = loss_full + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())

                loss_scaler.scale(loss_full).backward()

                middle_ratio = (largest_ratio + smallest_ratio) / 2
                width_3q = float(random.randint(middle_ratio//granularity, largest_ratio//granularity-1) * granularity)
                width_2q = float(random.randint(smallest_ratio//granularity+1, middle_ratio//granularity-1) * granularity)

                if token_type == 'dist_token':
                    token = int(1)
                elif token_type == 'cls_token':
                    token = int(0)

                if ce_type == 'one_hot':
                    ce_targets = targets
                elif ce_type == 'teacher':
                    ce_targets = outputs[0].detach().argmax(dim=1)

                ce_loss = torch.nn.CrossEntropyLoss()

                ############## 3q model ##############
                output_3q = model(samples, width_3q)
                if distillation_type == 'none':
                    if distill_type == 'hard':
                        loss_3q = F.cross_entropy(output_3q, outputs.detach().argmax(dim=1))
                    else:
                        loss_3q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_3q), nn.Softmax(dim=1)(outputs.detach()))
                    loss_3q_ce = ce_coefficient * ce_loss(output_3q, ce_targets)
                    loss_3q = loss_3q + loss_3q_ce
                else:
                    if distill_type == 'hard':
                        loss_3q = F.cross_entropy(output_3q[token], outputs[token].detach().argmax(dim=1))
                    else:
                        loss_3q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_3q[token]), nn.Softmax(dim=1)(outputs[token].detach()))
                    loss_3q_ce = ce_coefficient * ce_loss(output_3q[0], ce_targets)
                    loss_3q = loss_3q + loss_3q_ce

                loss_scaler.scale(loss_3q).backward()


                ############## 2q model ##############
                if transfer_type == 'US':
                    teacher_2q = outputs
                else:
                    teacher_2q = output_3q
                output_2q = model(samples, width_2q)
                if distillation_type == 'none':
                    if distill_type == 'hard':
                        loss_2q = F.cross_entropy(output_2q, teacher_2q.detach().argmax(dim=1))
                    else:
                        loss_2q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_2q), nn.Softmax(dim=1)(teacher_2q.detach()))
                    loss_2q_ce = ce_coefficient * ce_loss(output_2q, ce_targets)
                    loss_2q = loss_2q + loss_2q_ce
                else:
                    if distill_type == 'hard':
                        loss_2q = F.cross_entropy(output_2q[token], teacher_2q[token].detach().argmax(dim=1))
                    else:
                        loss_2q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_2q[token]), nn.Softmax(dim=1)(teacher_2q[token].detach()))
                    loss_2q_ce = ce_coefficient * ce_loss(output_2q[0], ce_targets)
                    loss_2q = loss_2q + loss_2q_ce

                loss_scaler.scale(loss_2q).backward()


                ############## 1q model ##############
                if transfer_type == 'US':
                    teacher_1q = outputs
                else:
                    teacher_1q = output_2q
                output_1q = model(samples, smallest_ratio)
                if distillation_type == 'none':
                    if distill_type == 'hard':
                        loss_1q = F.cross_entropy(output_1q, teacher_1q.detach().argmax(dim=1))
                    else:
                        loss_1q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_1q), nn.Softmax(dim=1)(teacher_1q.detach()))
                    loss_1q_ce = ce_coefficient * ce_loss(output_1q, ce_targets)
                    loss_1q = loss_1q + loss_1q_ce
                else:
                    if distill_type == 'hard':
                        loss_1q = F.cross_entropy(output_1q[token], teacher_1q[token].detach().argmax(dim=1))
                    else:
                        loss_1q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_1q[token]), nn.Softmax(dim=1)(teacher_1q[token].detach()))
                    loss_1q_ce = ce_coefficient * ce_loss(output_1q[0], ce_targets)
                    loss_1q = loss_1q + loss_1q_ce

                loss_scaler.scale(loss_1q).backward()
                loss_scaler.step(optimizer)
                loss_scaler.update()

                loss = loss_full + loss_3q + loss_2q + loss_1q
                loss_value = loss.item()
                loss_full_value = loss_full.item()
                loss_3q_value = loss_3q.item()
                loss_2q_value = loss_2q.item()
                loss_1q_value = loss_1q.item()
                loss_3q_ce_value = loss_3q_ce.item()
                loss_2q_ce_value = loss_2q_ce.item()
                loss_1q_ce_value = loss_1q_ce.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # # this attribute is added by timm on one optimizer (adahessian)
        # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        # loss_scaler(loss, optimizer, clip_grad=max_norm,
        #             parameters=model.parameters(), create_graph=is_second_order)

        torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(loss_full=loss_full_value)
        metric_logger.update(loss_3q=loss_3q_value)
        metric_logger.update(loss_3q_ce=loss_3q_ce_value)
        metric_logger.update(loss_2q=loss_2q_value)
        metric_logger.update(loss_2q_ce=loss_2q_ce_value)
        metric_logger.update(loss_1q=loss_1q_value)
        metric_logger.update(loss_1q_ce=loss_1q_ce_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(data_loader, model, device, scale_ratio):
    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast():
            output = model(images, scale_ratio)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = images.shape[0]
        metric_logger.update(loss=loss.item())
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
### eval scala on deit_s at width ratio of 0.5

python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py --eval \
--data-path IMAGENET_LOCATION \
--model deit_small_distilled_patch16_224_scala \
--batch-size 256 --eval_ratio 0.5 \
--resume MODEL_PATH;

--------------------------------------------------------------------------------
/fig/dino.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/dino.png

--------------------------------------------------------------------------------
/fig/gran_bound.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/gran_bound.png

--------------------------------------------------------------------------------
/fig/hybrid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/hybrid.png

--------------------------------------------------------------------------------
/fig/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/intro.png

--------------------------------------------------------------------------------
/fig/meta.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/meta.png

--------------------------------------------------------------------------------
/fig/neu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/neu.png

--------------------------------------------------------------------------------
/fig/slim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/slim.png

--------------------------------------------------------------------------------
/fig/snap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/snap.png

--------------------------------------------------------------------------------
/fig/transfer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/transfer.png

--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Implements the knowledge distillation loss
"""
import torch
from torch.nn import functional as F


class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

    def forward(self, inputs, outputs, labels):
        """
        Args:
            inputs: The original inputs that are fed to the teacher model
            outputs: the outputs of the model to be trained. It is expected to be
                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
                in the first position and the distillation predictions as the second output
            labels: the labels for the base criterion
        """
        outputs_kd = None
        if not isinstance(outputs, torch.Tensor):
            # assume that the model outputs a tuple of [outputs, outputs_kd]
            outputs, outputs_kd = outputs
        base_loss = self.base_criterion(outputs, labels)
        if self.distillation_type == 'none':
            return base_loss

        if outputs_kd is None:
            raise ValueError("When knowledge distillation is enabled, the model is "
                             "expected to return a Tuple[Tensor, Tensor] with the output of the "
                             "class_token and the dist_token")
        # don't backprop through the teacher
        with torch.no_grad():
            teacher_outputs = self.teacher_model(inputs)

        if self.distillation_type == 'soft':
            T = self.tau
            # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
            # with slight modifications
            distillation_loss = F.kl_div(
                F.log_softmax(outputs_kd / T, dim=1),
                # We provide the teacher's targets in log probability because we use log_target=True
                # (as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
                # but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
#but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both. 59 | F.log_softmax(teacher_outputs / T, dim=1), 60 | reduction='sum', 61 | log_target=True 62 | ) * (T * T) / outputs_kd.numel() 63 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior. 64 | #But we also experiments output_kd.size(0) 65 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details 66 | elif self.distillation_type == 'hard': 67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1)) 68 | 69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha 70 | return loss 71 | -------------------------------------------------------------------------------- /main_scala.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import argparse 4 | import datetime 5 | import numpy as np 6 | import time 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import json 10 | 11 | from pathlib import Path 12 | 13 | from timm.data import Mixup 14 | from timm.models import create_model 15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 16 | from timm.optim import create_optimizer 17 | from timm.utils import get_state_dict, ModelEma 18 | import random 19 | from datasets import build_dataset 20 | from engine_scala import train_one_epoch, evaluate 21 | from losses import DistillationLoss 22 | from samplers import RASampler 23 | from scheduler import create_scheduler 24 | from augment import new_data_aug_generator 25 | 26 | import models_scala 27 | import utils 28 | 29 | 30 | def get_args_parser(): 31 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False) 32 | parser.add_argument('--batch-size', default=64, type=int) 33 | parser.add_argument('--epochs', default=300, type=int) 34 | 
parser.add_argument('--bce-loss', action='store_true') 35 | parser.add_argument('--unscale-lr', action='store_true') 36 | 37 | parser.add_argument('--full_warm_epoch', type=int, default=10) 38 | parser.add_argument('--smallest_ratio', type=float, default=0.25) 39 | parser.add_argument('--largest_ratio', type=float, default=1.0) 40 | parser.add_argument('--granularity', type=float, default=0.03125) 41 | parser.add_argument('--ce_coefficient', type=float, default=1.0) 42 | parser.add_argument('--eval_ratio', type=float, default=1.0) 43 | parser.add_argument('--ce_type', default='one_hot', type=str) 44 | parser.add_argument('--distill_type', default='soft', type=str) 45 | parser.add_argument('--transfer_type', default='progressive', type=str) 46 | parser.add_argument('--token_type', default='cls_token', type=str) 47 | 48 | # Model parameters 49 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL', 50 | help='Name of model to train') 51 | parser.add_argument('--input-size', default=224, type=int, help='image input size') 52 | 53 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 54 | help='Dropout rate (default: 0.)') 55 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 56 | help='Drop path rate (default: 0.1)') 57 | 58 | parser.add_argument('--model-ema', action='store_true') 59 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema') 60 | parser.set_defaults(model_ema=True) 61 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='') 62 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='') 63 | 64 | # Optimizer parameters 65 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 66 | help='Optimizer (default: "adamw")') 67 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', 68 | help='Optimizer Epsilon (default: 1e-8)') 69 | 
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 70 | help='Optimizer Betas (default: None, use opt default)') 71 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 72 | help='Clip gradient norm (default: None, no clipping)') 73 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 74 | help='SGD momentum (default: 0.9)') 75 | parser.add_argument('--weight-decay', type=float, default=0.05, 76 | help='weight decay (default: 0.05)') 77 | # Learning rate schedule parameters 78 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 79 | help='LR scheduler (default: "cosine")') 80 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 81 | help='learning rate (default: 5e-4)') 82 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 83 | help='learning rate noise on/off epoch percentages') 84 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 85 | help='learning rate noise limit percent (default: 0.67)') 86 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 87 | help='learning rate noise std-dev (default: 1.0)') 88 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 89 | help='warmup learning rate (default: 1e-6)') 90 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 91 | help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-5)') 92 | 93 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N', 94 | help='epoch interval to decay LR') 95 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', 96 | help='epochs to warmup LR, if scheduler supports') 97 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 98 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 99 | parser.add_argument('--patience-epochs', type=int, 
default=10, metavar='N', 100 | help='patience epochs for Plateau LR scheduler (default: 10)') 101 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 102 | help='LR decay rate (default: 0.1)') 103 | 104 | # Augmentation parameters 105 | parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT', 106 | help='Color jitter factor (default: 0.3)') 107 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 108 | help='Use AutoAugment policy. "v0" or "original". ' 109 | '(default: rand-m9-mstd0.5-inc1)') 110 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') 111 | parser.add_argument('--train-interpolation', type=str, default='bicubic', 112 | help='Training interpolation (random, bilinear, bicubic; default: "bicubic")') 113 | 114 | parser.add_argument('--repeated-aug', action='store_true') 115 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug') 116 | parser.set_defaults(repeated_aug=True) 117 | 118 | parser.add_argument('--train-mode', action='store_true') 119 | parser.add_argument('--no-train-mode', action='store_false', dest='train_mode') 120 | parser.set_defaults(train_mode=True) 121 | 122 | parser.add_argument('--ThreeAugment', action='store_true') #3augment 123 | 124 | parser.add_argument('--src', action='store_true') #simple random crop 125 | 126 | # * Random Erase params 127 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 128 | help='Random erase prob (default: 0.25)') 129 | parser.add_argument('--remode', type=str, default='pixel', 130 | help='Random erase mode (default: "pixel")') 131 | parser.add_argument('--recount', type=int, default=1, 132 | help='Random erase count (default: 1)') 133 | parser.add_argument('--resplit', action='store_true', default=False, 134 | help='Do not random erase first (clean) augmentation split') 135 | 136 | # * Mixup params 137 | 
parser.add_argument('--mixup', type=float, default=0.8, 138 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)') 139 | parser.add_argument('--cutmix', type=float, default=1.0, 140 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') 141 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 142 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 143 | parser.add_argument('--mixup-prob', type=float, default=1.0, 144 | help='Probability of performing mixup or cutmix when either/both is enabled') 145 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 146 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 147 | parser.add_argument('--mixup-mode', type=str, default='batch', 148 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 149 | 150 | # Distillation parameters 151 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL', 152 | help='Name of the teacher model used for distillation (default: "regnety_160")') 153 | parser.add_argument('--teacher-path', type=str, default='') 154 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="") 155 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="") 156 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="") 157 | 158 | # * Cosub params 159 | parser.add_argument('--cosub', action='store_true') 160 | 161 | # * Finetuning params 162 | parser.add_argument('--finetune', default='', help='finetune from checkpoint') 163 | parser.add_argument('--attn-only', action='store_true') 164 | 165 | # Dataset parameters 166 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str, 167 | help='dataset path') 168 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'], 169 | type=str, help='dataset type')
170 | parser.add_argument('--inat-category', default='name', 171 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'], 172 | type=str, help='semantic granularity') 173 | 174 | parser.add_argument('--output_dir', default='', 175 | help='path where to save, empty for no saving') 176 | parser.add_argument('--device', default='cuda', 177 | help='device to use for training / testing') 178 | parser.add_argument('--seed', default=0, type=int) 179 | parser.add_argument('--resume', default='', help='resume from checkpoint') 180 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 181 | help='start epoch') 182 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 183 | parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation") 184 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation') 185 | parser.add_argument('--num_workers', default=10, type=int) 186 | parser.add_argument('--pin-mem', action='store_true', 187 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 188 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem', 189 | help='') 190 | parser.set_defaults(pin_mem=True) 191 | 192 | # distributed training parameters 193 | parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training') 194 | parser.add_argument('--world_size', default=1, type=int, 195 | help='number of distributed processes') 196 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 197 | return parser 198 | 199 | 200 | def main(args): 201 | utils.init_distributed_mode(args) 202 | 203 | print(args) 204 | 205 | if args.distillation_type != 'none' and args.finetune and not args.eval: 206 | raise NotImplementedError("Finetuning with distillation not yet
supported") 207 | 208 | device = torch.device(args.device) 209 | 210 | # fix the seed for reproducibility 211 | seed = args.seed + utils.get_rank() 212 | torch.manual_seed(seed) 213 | np.random.seed(seed) 214 | random.seed(seed) 215 | 216 | cudnn.benchmark = True 217 | 218 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args) 219 | dataset_val, _ = build_dataset(is_train=False, args=args) 220 | 221 | if args.distributed: 222 | num_tasks = utils.get_world_size() 223 | global_rank = utils.get_rank() 224 | if args.repeated_aug: 225 | sampler_train = RASampler( 226 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 227 | ) 228 | else: 229 | sampler_train = torch.utils.data.DistributedSampler( 230 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 231 | ) 232 | if args.dist_eval: 233 | if len(dataset_val) % num_tasks != 0: 234 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' 235 | 'This will slightly alter validation results as extra duplicate entries are added to achieve ' 236 | 'equal num of samples per-process.') 237 | sampler_val = torch.utils.data.DistributedSampler( 238 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False) 239 | else: 240 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 241 | else: 242 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 243 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 244 | 245 | data_loader_train = torch.utils.data.DataLoader( 246 | dataset_train, sampler=sampler_train, 247 | batch_size=args.batch_size, 248 | num_workers=args.num_workers, 249 | pin_memory=args.pin_mem, 250 | drop_last=True, 251 | ) 252 | if args.ThreeAugment: 253 | data_loader_train.dataset.transform = new_data_aug_generator(args) 254 | 255 | data_loader_val = torch.utils.data.DataLoader( 256 | dataset_val, sampler=sampler_val, 257 | batch_size=int(1.5 * args.batch_size), 258 | 
num_workers=args.num_workers, 259 | pin_memory=args.pin_mem, 260 | drop_last=False 261 | ) 262 | 263 | mixup_fn = None 264 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 265 | if mixup_active: 266 | mixup_fn = Mixup( 267 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 268 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 269 | label_smoothing=args.smoothing, num_classes=args.nb_classes) 270 | 271 | print(f"Creating model: {args.model}") 272 | model = create_model( 273 | args.model, 274 | pretrained=False, 275 | num_classes=args.nb_classes, 276 | drop_rate=args.drop, 277 | drop_path_rate=args.drop_path, 278 | drop_block_rate=None, 279 | img_size=args.input_size, 280 | smallest_ratio=args.smallest_ratio, 281 | largest_ratio=args.largest_ratio, 282 | ) 283 | 284 | 285 | if args.finetune: 286 | if args.finetune.startswith('https'): 287 | checkpoint = torch.hub.load_state_dict_from_url( 288 | args.finetune, map_location='cpu', check_hash=True) 289 | else: 290 | checkpoint = torch.load(args.finetune, map_location='cpu') 291 | 292 | checkpoint_model = checkpoint['model'] 293 | state_dict = model.state_dict() 294 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']: 295 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape: 296 | print(f"Removing key {k} from pretrained checkpoint") 297 | del checkpoint_model[k] 298 | 299 | # interpolate position embedding 300 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 301 | embedding_size = pos_embed_checkpoint.shape[-1] 302 | num_patches = model.patch_embed.num_patches 303 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 304 | # height (== width) for the checkpoint position embedding 305 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 306 | # height (== width) for the new position embedding 307 | new_size = int(num_patches ** 
0.5) 308 | # class_token and dist_token are kept unchanged 309 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 310 | # only the position tokens are interpolated 311 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 312 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 313 | pos_tokens = torch.nn.functional.interpolate( 314 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 315 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 316 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 317 | checkpoint_model['pos_embed'] = new_pos_embed 318 | 319 | model.load_state_dict(checkpoint_model, strict=False) 320 | 321 | if args.attn_only: 322 | for name_p,p in model.named_parameters(): 323 | if '.attn.' in name_p: 324 | p.requires_grad = True 325 | else: 326 | p.requires_grad = False 327 | try: 328 | model.head.weight.requires_grad = True 329 | model.head.bias.requires_grad = True 330 | except: 331 | model.fc.weight.requires_grad = True 332 | model.fc.bias.requires_grad = True 333 | try: 334 | model.pos_embed.requires_grad = True 335 | except: 336 | print('no position encoding') 337 | try: 338 | for p in model.patch_embed.parameters(): 339 | p.requires_grad = False 340 | except: 341 | print('no patch embed') 342 | 343 | model.to(device) 344 | 345 | model_ema = None 346 | if args.model_ema: 347 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 348 | model_ema = ModelEma( 349 | model, 350 | decay=args.model_ema_decay, 351 | device='cpu' if args.model_ema_force_cpu else '', 352 | resume='') 353 | 354 | model_without_ddp = model 355 | if args.distributed: 356 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 357 | model_without_ddp = model.module 358 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 359 | print('number of 
params:', n_parameters) 360 | if not args.unscale_lr: 361 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 362 | args.lr = linear_scaled_lr 363 | optimizer = create_optimizer(args, model_without_ddp) 364 | # loss_scaler = NativeScaler() 365 | loss_scaler = torch.cuda.amp.GradScaler() 366 | 367 | lr_scheduler, _ = create_scheduler(args, optimizer) 368 | 369 | criterion = LabelSmoothingCrossEntropy() 370 | 371 | if mixup_active: 372 | # smoothing is handled with mixup label transform 373 | criterion = SoftTargetCrossEntropy() 374 | elif args.smoothing: 375 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 376 | else: 377 | criterion = torch.nn.CrossEntropyLoss() 378 | 379 | if args.bce_loss: 380 | criterion = torch.nn.BCEWithLogitsLoss() 381 | 382 | teacher_model = None 383 | if args.distillation_type != 'none': 384 | assert args.teacher_path, 'need to specify teacher-path when using distillation' 385 | print(f"Creating teacher model: {args.teacher_model}") 386 | teacher_model = create_model( 387 | args.teacher_model, 388 | pretrained=False, 389 | num_classes=args.nb_classes, 390 | # global_pool='avg', 391 | ) 392 | if args.teacher_path.startswith('https'): 393 | checkpoint = torch.hub.load_state_dict_from_url( 394 | args.teacher_path, map_location='cpu', check_hash=True) 395 | else: 396 | checkpoint = torch.load(args.teacher_path, map_location='cpu') 397 | teacher_model.load_state_dict(checkpoint['model']) 398 | teacher_model.to(device) 399 | teacher_model.eval() 400 | 401 | # wrap the criterion in our custom DistillationLoss, which 402 | # just dispatches to the original criterion if args.distillation_type is 'none' 403 | criterion = DistillationLoss( 404 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau 405 | ) 406 | 407 | output_dir = Path(args.output_dir) 408 | if args.resume: 409 | if args.resume.startswith('https'): 410 | checkpoint = 
torch.hub.load_state_dict_from_url( 411 | args.resume, map_location='cpu', check_hash=True) 412 | else: 413 | checkpoint = torch.load(args.resume, map_location='cpu') 414 | model_without_ddp.load_state_dict(checkpoint['model']) 415 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 416 | optimizer.load_state_dict(checkpoint['optimizer']) 417 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 418 | args.start_epoch = checkpoint['epoch'] + 1 419 | if args.model_ema: 420 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema']) 421 | if 'scaler' in checkpoint: 422 | loss_scaler.load_state_dict(checkpoint['scaler']) 423 | lr_scheduler.step(args.start_epoch) 424 | if args.eval: 425 | test_stats = evaluate(data_loader_val, model, device, args.eval_ratio) 426 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 427 | return 428 | 429 | print(f"Start training for {args.epochs} epochs") 430 | start_time = time.time() 431 | max_accuracy = 0.0 432 | for epoch in range(args.start_epoch, args.epochs): 433 | if args.distributed: 434 | data_loader_train.sampler.set_epoch(epoch) 435 | 436 | train_stats = train_one_epoch( 437 | args.full_warm_epoch, args.distillation_type, args.ce_coefficient, 438 | args.largest_ratio, args.smallest_ratio, args.granularity, 439 | args.ce_type, args.distill_type, args.transfer_type, args.token_type, 440 | model, criterion, data_loader_train, 441 | optimizer, device, epoch, loss_scaler, 442 | args.clip_grad, model_ema, mixup_fn, 443 | set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning 444 | args = args, 445 | ) 446 | 447 | lr_scheduler.step(epoch) 448 | if args.output_dir: 449 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 450 | for checkpoint_path in checkpoint_paths: 451 | utils.save_on_master({ 452 | 'model': model_without_ddp.state_dict(), 453 | 
'optimizer': optimizer.state_dict(), 454 | 'lr_scheduler': lr_scheduler.state_dict(), 455 | 'epoch': epoch, 456 | 'model_ema': get_state_dict(model_ema), 457 | 'scaler': loss_scaler.state_dict(), 458 | 'args': args, 459 | }, checkpoint_path) 460 | 461 | 462 | test_stats = evaluate(data_loader_val, model, device, 1.0) 463 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%") 464 | 465 | if max_accuracy < test_stats["acc1"]: 466 | max_accuracy = test_stats["acc1"] 467 | if args.output_dir: 468 | checkpoint_paths = [output_dir / 'best_checkpoint.pth'] 469 | for checkpoint_path in checkpoint_paths: 470 | utils.save_on_master({ 471 | 'model': model_without_ddp.state_dict(), 472 | 'optimizer': optimizer.state_dict(), 473 | 'lr_scheduler': lr_scheduler.state_dict(), 474 | 'epoch': epoch, 475 | 'model_ema': get_state_dict(model_ema), 476 | 'scaler': loss_scaler.state_dict(), 477 | 'args': args, 478 | }, checkpoint_path) 479 | 480 | print(f'Max accuracy: {max_accuracy:.2f}%') 481 | 482 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 483 | **{f'test_{k}': v for k, v in test_stats.items()}, 484 | 'epoch': epoch, 485 | 'n_parameters': n_parameters} 486 | 487 | 488 | if args.output_dir and utils.is_main_process(): 489 | with (output_dir / "log.txt").open("a") as f: 490 | f.write(json.dumps(log_stats) + "\n") 491 | 492 | if epoch == args.epochs - 1: 493 | range_num = int((1-args.smallest_ratio) / args.granularity) 494 | for i in range(range_num): 495 | scale_ratio = args.smallest_ratio + i * args.granularity 496 | test_stats_end = evaluate(data_loader_val, model, device, scale_ratio) 497 | print(f"Width ratio: {scale_ratio}, Accuracy of the network on the {len(dataset_val)} test images: {test_stats_end['acc1']:.1f}%") 498 | 499 | log_stats_end = {'Width ratio': scale_ratio, 500 | 'Acc1': test_stats_end['acc1']} 501 | 502 | if args.output_dir and utils.is_main_process(): 503 | with (output_dir / "log.txt").open("a") 
as f: 504 | f.write(json.dumps(log_stats_end) + "\n") 505 | 506 | total_time = time.time() - start_time 507 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 508 | print('Training time {}'.format(total_time_str)) 509 | 510 | 511 | if __name__ == '__main__': 512 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()]) 513 | args = parser.parse_args() 514 | if args.output_dir: 515 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 516 | main(args) 517 | -------------------------------------------------------------------------------- /models_scala.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | from functools import partial 4 | from collections import OrderedDict 5 | from typing import Callable, List, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from enum import Enum 12 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 13 | from timm.models.helpers import build_model_with_cfg, named_apply, checkpoint_seq 14 | from timm.models.layers import DropPath, trunc_normal_, lecun_normal_, to_2tuple 15 | from timm.models.registry import register_model 16 | 17 | _logger = logging.getLogger(__name__) 18 | 19 | 20 | def _assert(condition: bool, message: str): 21 | assert condition, message 22 | 23 | 24 | class Format(str, Enum): 25 | NCHW = 'NCHW' 26 | NHWC = 'NHWC' 27 | NCL = 'NCL' 28 | NLC = 'NLC' 29 | 30 | 31 | def _cfg(url='', **kwargs): 32 | return { 33 | 'url': url, 34 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 35 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 36 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 37 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 38 | **kwargs 39 | } 40 | 
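The width-slicing rule shared by `Attention`, `LayerScale`, `Mlp`, `Block` and `PatchEmbed` below can be summarized in a few lines. The following is a minimal pure-Python sketch, not code from this repository (the helpers `channel_window` and `eval_ratio_grid` are hypothetical), that reports which channel window a layer of width `dim` reads at a given ratio: the smallest and largest sub-networks take the leading `int(ratio * dim)` channels (`w[:n]`), every intermediate ratio takes the trailing ones (`w[-n:]`). It also reproduces the ratio grid swept by the final-epoch evaluation loop in `main_scala.py`, assuming the default `--smallest_ratio 0.25` and `--granularity 0.03125`.

```python
from typing import List, Tuple


def channel_window(ratio: float, dim: int,
                   smallest_ratio: float = 0.25,
                   largest_ratio: float = 1.0) -> Tuple[int, int]:
    """Hypothetical helper: the [start, stop) channel window used at `ratio`.

    Mirrors the rule in the layers below: the two extreme widths slice from
    the front (w[:n]); every other width slices from the back (w[-n:]).
    """
    n = int(ratio * dim)
    if ratio == smallest_ratio or ratio == largest_ratio:
        return 0, n
    return dim - n, dim


def eval_ratio_grid(smallest_ratio: float = 0.25,
                    granularity: float = 0.03125) -> List[float]:
    """Hypothetical helper: width ratios visited by the final-epoch sweep."""
    range_num = int((1 - smallest_ratio) / granularity)
    return [smallest_ratio + i * granularity for i in range(range_num)]


if __name__ == '__main__':
    dim = 384  # example embedding width, chosen for illustration
    for r in eval_ratio_grid()[:3] + [1.0]:
        start, stop = channel_window(r, dim)
        print(f"ratio={r:.5f}: channels [{start}, {stop}) -> width {stop - start}")
```

With these defaults the sweep covers 24 ratios from 0.25 to 0.96875, so the full width 1.0 is not part of it (it is evaluated separately each epoch). The sketch also shows that on a 768-wide layer the 0.5 sub-network reads channels 384 to 767, disjoint from the 0.25 sub-network's channels 0 to 191. The fused `qkv` projection in `Attention` follows the same front/back rule with `3 * n` output rows.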
41 | 42 | class Attention(nn.Module): 43 | def __init__(self, smallest_ratio, largest_ratio, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): 44 | super().__init__() 45 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 46 | 47 | self.smallest_ratio = smallest_ratio 48 | self.largest_ratio = largest_ratio 49 | 50 | self.num_heads = num_heads 51 | head_dim = dim // num_heads 52 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 53 | self.scale = head_dim ** -0.5 54 | 55 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 56 | self.attn_drop = nn.Dropout(attn_drop) 57 | self.proj = nn.Linear(dim, dim) 58 | self.proj_drop = nn.Dropout(proj_drop) 59 | 60 | self.dim = dim 61 | 62 | def forward(self, x, ratio): 63 | dim_channels = int(ratio * self.dim) 64 | 65 | B, N, C = x.shape 66 | 67 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 68 | weight_qkv = self.qkv.weight[:dim_channels*3, :dim_channels] 69 | bias_qkv = self.qkv.bias[:dim_channels*3] 70 | else: 71 | weight_qkv = self.qkv.weight[-dim_channels*3:, -dim_channels:] 72 | bias_qkv = self.qkv.bias[-dim_channels*3:] 73 | qkv = F.linear(input=x, weight=weight_qkv, bias=bias_qkv) 74 | qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 75 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 76 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 77 | 78 | attn = (q @ k.transpose(-2, -1)) * self.scale 79 | attn = attn.softmax(dim=-1) 80 | attn = self.attn_drop(attn) 81 | 82 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 83 | 84 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 85 | weight_proj = self.proj.weight[:dim_channels, :dim_channels] 86 | bias_proj = self.proj.bias[:dim_channels] 87 | else: 88 | weight_proj = self.proj.weight[-dim_channels:, -dim_channels:] 89 | bias_proj = 
self.proj.bias[-dim_channels:] 90 | x = F.linear(input=x, weight=weight_proj, bias=bias_proj) 91 | x = self.proj_drop(x) 92 | return x 93 | 94 | 95 | 96 | class LayerScale(nn.Module): 97 | def __init__(self, smallest_ratio, largest_ratio, dim, init_values=1e-5, inplace=False): 98 | super().__init__() 99 | 100 | self.smallest_ratio = smallest_ratio 101 | self.largest_ratio = largest_ratio 102 | 103 | self.inplace = inplace 104 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 105 | 106 | self.dim = dim 107 | 108 | def forward(self, x, ratio): 109 | dim_channels = int(ratio * self.dim) 110 | 111 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 112 | gamma = self.gamma[:dim_channels] # slice the Parameter itself (not .data) so gamma keeps receiving gradients 113 | else: 114 | gamma = self.gamma[-dim_channels:] 115 | if self.inplace: 116 | x = x.mul_(gamma) 117 | else: 118 | x = x * gamma # assign the result; the bare expression discarded the scaling 119 | 120 | return x 121 | 122 | 123 | 124 | class Mlp(nn.Module): 125 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 126 | """ 127 | def __init__( 128 | self, 129 | smallest_ratio, 130 | largest_ratio, 131 | in_features, 132 | hidden_features=None, 133 | out_features=None, 134 | act_layer=nn.GELU, 135 | norm_layer=None, 136 | bias=True, 137 | drop=0., 138 | use_conv=False, 139 | ): 140 | super().__init__() 141 | out_features = out_features or in_features 142 | hidden_features = hidden_features or in_features 143 | bias = to_2tuple(bias) 144 | drop_probs = to_2tuple(drop) 145 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear 146 | 147 | self.smallest_ratio = smallest_ratio 148 | self.largest_ratio = largest_ratio 149 | 150 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) 151 | self.act = act_layer() 152 | self.drop1 = nn.Dropout(drop_probs[0]) 153 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() 154 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) 155 | self.drop2 = nn.Dropout(drop_probs[1]) 156 
| 157 | self.in_features = in_features 158 | self.hidden_features = hidden_features 159 | self.out_features = out_features 160 | 161 | def forward(self, x, ratio): 162 | in_channels = int(ratio * self.in_features) 163 | hidden_channels = int(ratio * self.hidden_features) 164 | out_channels = int(ratio * self.out_features) 165 | 166 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 167 | weight1 = self.fc1.weight[:hidden_channels, :in_channels] 168 | bias1 = self.fc1.bias[:hidden_channels] 169 | else: 170 | weight1 = self.fc1.weight[-hidden_channels:, -in_channels:] 171 | bias1 = self.fc1.bias[-hidden_channels:] 172 | x = F.linear(input=x, weight=weight1, bias=bias1) 173 | x = self.act(x) 174 | x = self.drop1(x) 175 | 176 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 177 | weight2 = self.fc2.weight[:out_channels, :hidden_channels] 178 | bias2 = self.fc2.bias[:out_channels] 179 | else: 180 | weight2 = self.fc2.weight[-out_channels:, -hidden_channels:] 181 | bias2 = self.fc2.bias[-out_channels:] 182 | x = F.linear(input=x, weight=weight2, bias=bias2) 183 | x = self.drop2(x) 184 | 185 | return x 186 | 187 | 188 | class Block(nn.Module): 189 | 190 | def __init__( 191 | self, smallest_ratio, largest_ratio, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, 192 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 193 | super().__init__() 194 | self.smallest_ratio = smallest_ratio 195 | self.largest_ratio = largest_ratio 196 | self.norm1 = norm_layer(dim) 197 | self.attn = Attention(smallest_ratio, largest_ratio, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 198 | self.ls1 = LayerScale(smallest_ratio, largest_ratio, dim, init_values=init_values) if init_values else nn.Identity() 199 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 200 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 201 | 202 | self.norm2 = norm_layer(dim) 203 | self.mlp = Mlp(smallest_ratio, largest_ratio, in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) 204 | self.ls2 = LayerScale(smallest_ratio, largest_ratio, dim, init_values=init_values) if init_values else nn.Identity() 205 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() 206 | 207 | self.dim = dim 208 | self.init_values = init_values 209 | 210 | def forward(self, x, ratio): 211 | dim_channels = int(ratio * self.dim) 212 | 213 | residual = x 214 | 215 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 216 | weight_norm1 = self.norm1.weight[:dim_channels] 217 | bias_norm1 = self.norm1.bias[:dim_channels] 218 | else: 219 | weight_norm1 = self.norm1.weight[-dim_channels:] 220 | bias_norm1 = self.norm1.bias[-dim_channels:] 221 | x = F.layer_norm(x, [dim_channels], weight_norm1, bias_norm1) 222 | 223 | x = self.attn(x, ratio) 224 | if self.init_values: 225 | x = self.ls1(x, ratio) 226 | else: 227 | x = self.ls1(x) 228 | x = residual + self.drop_path1(x) 229 | 230 | residual = x 231 | 232 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 233 | weight_norm2 = self.norm2.weight[:dim_channels] 234 | bias_norm2 = self.norm2.bias[:dim_channels] 235 | else: 236 | weight_norm2 = self.norm2.weight[-dim_channels:] 237 | bias_norm2 = self.norm2.bias[-dim_channels:] 238 | x = F.layer_norm(x, [dim_channels], weight_norm2, bias_norm2) 239 | 240 | x = self.mlp(x, ratio) 241 | if self.init_values: 242 | x = self.ls2(x, ratio) 243 | else: 244 | x = self.ls2(x) 245 | x = residual + self.drop_path2(x) 246 | return x 247 | 248 | 249 | 250 | 251 | class PatchEmbed(nn.Module): 252 | """ 2D Image to Patch Embedding 253 | """ 254 | output_fmt: Format 255 | 256 | def __init__( 257 | self, 258 | smallest_ratio: float = 0.25, 259 | largest_ratio: float = 1.0, 260 | img_size: Optional[int] = 224, 261 | patch_size: int = 16, 262 | in_chans: 
int = 3, 263 | embed_dim: int = 768, 264 | norm_layer: Optional[Callable] = None, 265 | flatten: bool = True, 266 | output_fmt: Optional[str] = None, 267 | bias: bool = True, 268 | strict_img_size: bool = True, 269 | ): 270 | super().__init__() 271 | self.smallest_ratio = smallest_ratio 272 | self.largest_ratio = largest_ratio 273 | self.stride = patch_size 274 | self.patch_size = to_2tuple(patch_size) 275 | if img_size is not None: 276 | self.img_size = to_2tuple(img_size) 277 | self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) 278 | self.num_patches = self.grid_size[0] * self.grid_size[1] 279 | else: 280 | self.img_size = None 281 | self.grid_size = None 282 | self.num_patches = None 283 | 284 | if output_fmt is not None: 285 | self.flatten = False 286 | self.output_fmt = Format(output_fmt) 287 | else: 288 | # flatten spatial dim and transpose to channels last, kept for bwd compat 289 | self.flatten = flatten 290 | self.output_fmt = Format.NCHW 291 | self.strict_img_size = strict_img_size 292 | 293 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) 294 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 295 | 296 | self.norm_layer = norm_layer 297 | self.embed_dim = embed_dim 298 | 299 | def forward(self, x, ratio): 300 | embed_dim_channels = int(ratio*self.embed_dim) 301 | 302 | B, C, H, W = x.shape 303 | if self.img_size is not None: 304 | if self.strict_img_size: 305 | _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).") 306 | _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).") 307 | else: 308 | _assert( 309 | H % self.patch_size[0] == 0, 310 | f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})." 311 | ) 312 | _assert( 313 | W % self.patch_size[1] == 0, 314 | f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." 
315 | ) 316 | 317 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 318 | weight_proj = self.proj.weight[:embed_dim_channels, :, :, :] 319 | bias_proj = self.proj.bias[:embed_dim_channels] 320 | else: 321 | weight_proj = self.proj.weight[-embed_dim_channels:, :, :, :] 322 | bias_proj = self.proj.bias[-embed_dim_channels:] 323 | x = F.conv2d(x, weight_proj, bias_proj, self.stride) 324 | 325 | if self.flatten: 326 | x = x.flatten(2).transpose(1, 2) # NCHW -> NLC 327 | elif self.output_fmt != Format.NCHW: 328 | x = nchw_to(x, self.output_fmt) 329 | 330 | if self.norm_layer: 331 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 332 | weight_norm = self.norm.weight[:embed_dim_channels] 333 | bias_norm = self.norm.bias[:embed_dim_channels] 334 | else: 335 | weight_norm = self.norm.weight[-embed_dim_channels:] 336 | bias_norm = self.norm.bias[-embed_dim_channels:] 337 | x = F.layer_norm(x, [embed_dim_channels], weight_norm, bias_norm) 338 | else: 339 | x = self.norm(x) 340 | 341 | return x 342 | 343 | 344 | class VisionTransformer_scala(nn.Module): 345 | """ Vision Transformer 346 | 347 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 348 | - https://arxiv.org/abs/2010.11929 349 | """ 350 | 351 | def __init__( 352 | self, img_size=224, smallest_ratio=0.25, largest_ratio=1.0, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', 353 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, 354 | class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 355 | weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): 356 | """ 357 | Args: 358 | img_size (int, tuple): input image size 359 | patch_size (int, tuple): patch size 360 | in_chans (int): number of input channels 361 | num_classes (int): number of classes for classification head 362 | global_pool (str): type of 
global pooling for final sequence (default: 'token') 363 | embed_dim (int): embedding dimension 364 | depth (int): depth of transformer 365 | num_heads (int): number of attention heads 366 | mlp_ratio (float): ratio of mlp hidden dim to embedding dim 367 | qkv_bias (bool): enable bias for qkv if True 368 | init_values (float): layer-scale init values 369 | class_token (bool): use class token 370 | fc_norm (Optional[bool]): pre-fc norm after pool; if None, enabled when global_pool == 'avg' (default: None) 371 | drop_rate (float): dropout rate 372 | attn_drop_rate (float): attention dropout rate 373 | drop_path_rate (float): stochastic depth rate 374 | weight_init (str): weight init scheme 375 | embed_layer (nn.Module): patch embedding layer 376 | norm_layer (nn.Module): normalization layer 377 | act_layer (nn.Module): MLP activation layer 378 | """ 379 | super().__init__() 380 | assert global_pool in ('', 'avg', 'token') 381 | assert class_token or global_pool != 'token' 382 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm 383 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 384 | act_layer = act_layer or nn.GELU 385 | 386 | self.num_classes = num_classes 387 | self.smallest_ratio = smallest_ratio 388 | self.largest_ratio = largest_ratio 389 | self.global_pool = global_pool 390 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 391 | self.num_prefix_tokens = 1 if class_token else 0 392 | self.no_embed_class = no_embed_class 393 | self.grad_checkpointing = False 394 | 395 | self.patch_embed = embed_layer( 396 | smallest_ratio=smallest_ratio, largest_ratio=largest_ratio, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 397 | num_patches = self.patch_embed.num_patches 398 | 399 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 400 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens 401 |
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) 402 | self.pos_drop = nn.Dropout(p=drop_rate) 403 | 404 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 405 | self.blocks = nn.Sequential(*[ 406 | block_fn( 407 | smallest_ratio, largest_ratio, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, 408 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 409 | for i in range(depth)]) 410 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 411 | 412 | # Classifier Head 413 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 414 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 415 | 416 | self.use_fc_norm = use_fc_norm 417 | 418 | if weight_init != 'skip': 419 | self.init_weights(weight_init) 420 | 421 | def init_weights(self, mode=''): 422 | assert mode in ('jax', 'jax_nlhb', 'moco', '') 423 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
424 | trunc_normal_(self.pos_embed, std=.02) 425 | if self.cls_token is not None: 426 | nn.init.normal_(self.cls_token, std=1e-6) 427 | named_apply(get_init_weights_vit(mode, head_bias), self) 428 | 429 | def _init_weights(self, m): 430 | # this fn left here for compat with downstream users 431 | init_weights_vit_timm(m) 432 | 433 | @torch.jit.ignore() 434 | def load_pretrained(self, checkpoint_path, prefix=''): 435 | _load_weights(self, checkpoint_path, prefix) 436 | 437 | @torch.jit.ignore 438 | def no_weight_decay(self): 439 | return {'pos_embed', 'cls_token', 'dist_token'} 440 | 441 | @torch.jit.ignore 442 | def group_matcher(self, coarse=False): 443 | return dict( 444 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed 445 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] 446 | ) 447 | 448 | @torch.jit.ignore 449 | def set_grad_checkpointing(self, enable=True): 450 | self.grad_checkpointing = enable 451 | 452 | @torch.jit.ignore 453 | def get_classifier(self): 454 | return self.head 455 | 456 | def reset_classifier(self, num_classes: int, global_pool=None): 457 | self.num_classes = num_classes 458 | if global_pool is not None: 459 | assert global_pool in ('', 'avg', 'token') 460 | self.global_pool = global_pool 461 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 462 | 463 | def _pos_embed(self, x, ratio): 464 | if self.no_embed_class: 465 | # deit-3, updated JAX (big vision) 466 | # position embedding does not overlap with class token, add then concat 467 | if ratio == self.smallest_ratio: 468 | pos_embed = self.pos_embed[:,:,:int(ratio*self.num_features)] 469 | x = x + pos_embed 470 | if self.cls_token is not None: 471 | cls_token = self.cls_token[:,:,:int(ratio*self.num_features)] 472 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1) 473 | elif ratio == self.largest_ratio: 474 | x = x + self.pos_embed 475 | if self.cls_token is not None: 476 | x = 
torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 477 | else: 478 | pos_embed = self.pos_embed[:,:,-int(ratio*self.num_features):] 479 | x = x + pos_embed 480 | if self.cls_token is not None: 481 | cls_token = self.cls_token[:,:,-int(ratio*self.num_features):] 482 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1) 483 | else: 484 | # original timm, JAX, and deit vit impl 485 | # pos_embed has entry for class token, concat then add 486 | if ratio == self.smallest_ratio: 487 | if self.cls_token is not None: 488 | cls_token = self.cls_token[:,:,:int(ratio*self.num_features)] 489 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1) 490 | pos_embed = self.pos_embed[:,:,:int(ratio*self.num_features)] 491 | x = x + pos_embed 492 | elif ratio == self.largest_ratio: 493 | if self.cls_token is not None: 494 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) 495 | x = x + self.pos_embed 496 | else: 497 | if self.cls_token is not None: 498 | cls_token = self.cls_token[:,:,-int(ratio*self.num_features):] 499 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1) 500 | pos_embed = self.pos_embed[:,:,-int(ratio*self.num_features):] 501 | x = x + pos_embed 502 | 503 | return self.pos_drop(x) 504 | 505 | def forward_features(self, x, ratio): 506 | x = self.patch_embed(x, ratio) 507 | x = self._pos_embed(x, ratio) 508 | for blk in self.blocks: 509 | if self.grad_checkpointing and not torch.jit.is_scripting(): 510 | x = checkpoint_seq(blk, x, ratio) 511 | else: 512 | x = blk(x, ratio) 513 | 514 | if self.use_fc_norm: 515 | x = self.norm(x) 516 | else: 517 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 518 | weight_norm = self.norm.weight[:int(ratio*self.embed_dim)] 519 | bias_norm = self.norm.bias[:int(ratio*self.embed_dim)] 520 | else: 521 | weight_norm = self.norm.weight[-int(ratio*self.embed_dim):] 522 | bias_norm = self.norm.bias[-int(ratio*self.embed_dim):] 523 | x = F.layer_norm(x, 
[int(ratio*self.embed_dim)], weight_norm, bias_norm) 524 | return x 525 | 526 | def forward_head(self, x, ratio, pre_logits: bool = False): 527 | if self.global_pool: 528 | x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] 529 | 530 | if self.use_fc_norm: 531 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 532 | weight_fc_norm = self.fc_norm.weight[:int(ratio*self.embed_dim)] 533 | bias_fc_norm = self.fc_norm.bias[:int(ratio*self.embed_dim)] 534 | else: 535 | weight_fc_norm = self.fc_norm.weight[-int(ratio*self.embed_dim):] 536 | bias_fc_norm = self.fc_norm.bias[-int(ratio*self.embed_dim):] 537 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_fc_norm, bias_fc_norm) 538 | else: 539 | x = self.fc_norm(x) 540 | 541 | if pre_logits: 542 | pass # return the pooled, normalized features without applying the classifier head 543 | else: 544 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 545 | weight_head = self.head.weight[:, :int(ratio*self.embed_dim)] 546 | bias_head = self.head.bias[:] 547 | else: 548 | weight_head = self.head.weight[:, -int(ratio*self.embed_dim):] 549 | bias_head = self.head.bias[:] 550 | x = F.linear(input=x, weight=weight_head, bias=bias_head) 551 | 552 | return x 553 | 554 | def forward(self, x, ratio): 555 | x = self.forward_features(x, ratio) 556 | x = self.forward_head(x, ratio) 557 | return x 558 | 559 | 560 | def init_weights_vit_timm(module: nn.Module, name: str = ''): 561 | """ ViT weight initialization, original timm impl (for reproducibility) """ 562 | if isinstance(module, nn.Linear): 563 | trunc_normal_(module.weight, std=.02) 564 | if module.bias is not None: 565 | nn.init.zeros_(module.bias) 566 | elif hasattr(module, 'init_weights'): 567 | module.init_weights() 568 | 569 | 570 | def get_init_weights_vit(mode='jax', head_bias: float = 0.): 571 | if 'jax' in mode: 572 | return partial(init_weights_vit_jax, head_bias=head_bias) 573 | elif 'moco' in mode: 574 | return init_weights_vit_moco 575 | else: 576 | return
init_weights_vit_timm 577 | 578 | 579 | 580 | 581 | class DistilledVisionTransformer_scala(VisionTransformer_scala): 582 | def __init__(self, *args, **kwargs): 583 | super().__init__(*args, **kwargs) 584 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 585 | num_patches = self.patch_embed.num_patches 586 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 587 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 588 | 589 | trunc_normal_(self.dist_token, std=.02) 590 | trunc_normal_(self.pos_embed, std=.02) 591 | self.head_dist.apply(self._init_weights) 592 | 593 | def forward_features(self, x, ratio): 594 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 595 | # with slight modifications to add the dist_token 596 | B = x.shape[0] 597 | x = self.patch_embed(x, ratio) 598 | 599 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 600 | cls_token = self.cls_token[:,:,:int(ratio*self.num_features)] 601 | dist_token = self.dist_token[:,:,:int(ratio*self.num_features)] 602 | x = torch.cat((cls_token.expand(B, -1, -1), dist_token.expand(B, -1, -1), x), dim=1) 603 | pos_embed = self.pos_embed[:,:,:int(ratio*self.num_features)] 604 | x = x + pos_embed 605 | else: 606 | cls_token = self.cls_token[:,:,-int(ratio*self.num_features):] 607 | dist_token = self.dist_token[:,:,-int(ratio*self.num_features):] 608 | x = torch.cat((cls_token.expand(B, -1, -1), dist_token.expand(B, -1, -1), x), dim=1) 609 | pos_embed = self.pos_embed[:,:,-int(ratio*self.num_features):] 610 | x = x + pos_embed 611 | 612 | x = self.pos_drop(x) 613 | 614 | for blk in self.blocks: 615 | x = blk(x, ratio) 616 | 617 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 618 | weight_norm = self.norm.weight[:int(ratio*self.embed_dim)] 619 | bias_norm = self.norm.bias[:int(ratio*self.embed_dim)] 620 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_norm, bias_norm) 621 | else:
622 | weight_norm = self.norm.weight[-int(ratio*self.embed_dim):] 623 | bias_norm = self.norm.bias[-int(ratio*self.embed_dim):] 624 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_norm, bias_norm) 625 | 626 | return x[:, 0], x[:, 1] 627 | 628 | def forward(self, x, ratio): 629 | x, x_dist = self.forward_features(x, ratio) 630 | 631 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 632 | weight_head = self.head.weight[:, :int(ratio*self.embed_dim)] 633 | bias_head = self.head.bias[:] 634 | x = F.linear(input=x, weight=weight_head, bias=bias_head) 635 | else: 636 | weight_head = self.head.weight[:, -int(ratio*self.embed_dim):] 637 | bias_head = self.head.bias[:] 638 | x = F.linear(input=x, weight=weight_head, bias=bias_head) 639 | 640 | if ratio == self.smallest_ratio or ratio == self.largest_ratio: 641 | weight_head_dist = self.head_dist.weight[:, :int(ratio*self.embed_dim)] 642 | bias_head_dist = self.head_dist.bias[:] 643 | x_dist = F.linear(input=x_dist, weight=weight_head_dist, bias=bias_head_dist) 644 | else: 645 | weight_head_dist = self.head_dist.weight[:, -int(ratio*self.embed_dim):] 646 | bias_head_dist = self.head_dist.bias[:] 647 | x_dist = F.linear(input=x_dist, weight=weight_head_dist, bias=bias_head_dist) 648 | if self.training: 649 | return x, x_dist 650 | else: 651 | # during inference, return the average of both classifier predictions 652 | return (x + x_dist) / 2 653 | 654 | 655 | 656 | @register_model 657 | def deit_so_tiny_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 658 | model = VisionTransformer_scala( 659 | patch_size=16, embed_dim=96, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 660 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 661 | model.default_cfg = _cfg() 662 | return model 663 | 664 | 665 | @register_model 666 | def deit_tiny_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 667 | model = VisionTransformer_scala( 668 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 669 |
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 670 | model.default_cfg = _cfg() 671 | return model 672 | 673 | 674 | @register_model 675 | def deit_small_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 676 | model = VisionTransformer_scala( 677 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 678 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 679 | model.default_cfg = _cfg() 680 | return model 681 | 682 | 683 | @register_model 684 | def deit_base_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 685 | model = VisionTransformer_scala( 686 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 687 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 688 | model.default_cfg = _cfg() 689 | return model 690 | 691 | 692 | @register_model 693 | def deit_tiny_distilled_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 694 | model = DistilledVisionTransformer_scala( 695 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 696 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 697 | model.default_cfg = _cfg() 698 | return model 699 | 700 | 701 | @register_model 702 | def deit_small_distilled_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 703 | model = DistilledVisionTransformer_scala( 704 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 705 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 706 | model.default_cfg = _cfg() 707 | return model 708 | 709 | 710 | @register_model 711 | def deit_base_distilled_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs): 712 | model = DistilledVisionTransformer_scala( 713 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 714 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 715 | model.default_cfg = _cfg() 716 | return model 717 | 718 | 719 | 
@register_model 720 | def deit_base_patch16_384_scala(pretrained=False, pretrained_cfg=None, **kwargs): 721 | model = VisionTransformer_scala( 722 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 723 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 724 | model.default_cfg = _cfg() 725 | return model 726 | 727 | 728 | @register_model 729 | def deit_base_distilled_patch16_384_scala(pretrained=False, pretrained_cfg=None, **kwargs): 730 | model = DistilledVisionTransformer_scala( 731 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 732 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 733 | model.default_cfg = _cfg() 734 | return model 735 | 736 | 737 | 738 | if __name__ == "__main__": 739 | import argparse 740 | from fvcore.nn import FlopCountAnalysis, parameter_count_table 741 | parser = argparse.ArgumentParser(description='FLOPs counting for SCALA ViT models') 742 | args = parser.parse_args() 743 | 744 | args.num_classes = 1000 745 | with torch.no_grad(): 746 | model = deit_small_distilled_patch16_224_scala() 747 | 748 | # for name, param in model.named_parameters(): 749 | # print(name) 750 | 751 | tensor = (torch.rand(1, 3, 224, 224), 0.5) 752 | flops = FlopCountAnalysis(model, tensor) 753 | print("GFLOPs: ", flops.total()/1e9) -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | ### scala on deit_s for 100 epochs 2 | 3 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py \ 4 | --batch-size 256 --epochs 100 --data-path IMAGENET_LOCATION \ 5 | --aa rand-m1-mstd0.5-inc1 --no-repeated-aug --lr 2e-3 --warmup-epochs 3 \ 6 | --teacher-model deit_small_patch16_224 --distillation-type hard \ 7 | --teacher-path https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth \ 8 | --model
deit_small_distilled_patch16_224_scala \ 9 | --smallest_ratio 0.25 --largest_ratio 1.0 --granularity 0.0625 --distill_type soft \ 10 | --transfer_type progressive --token_type dist_token --ce_coefficient 1.0 \ 11 | --full_warm_epoch 10 --output_dir log/deit_small_distilled_scala; 12 | 13 | 14 | 15 | ### scala on deit_b for 300 epochs 16 | 17 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py \ 18 | --batch-size 128 --epochs 300 --data-path IMAGENET_LOCATION \ 19 | --aa rand-m4-mstd0.5-inc1 --no-repeated-aug --lr 1e-3 --warmup-epochs 5 \ 20 | --teacher-model deit_base_patch16_224 --distillation-type hard \ 21 | --teacher-path https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth \ 22 | --model deit_base_distilled_patch16_224_scala \ 23 | --smallest_ratio 0.25 --largest_ratio 1.0 --granularity 0.0625 --distill_type soft \ 24 | --transfer_type progressive --token_type dist_token --ce_coefficient 1.0 \ 25 | --full_warm_epoch 10 --output_dir log/deit_base_distilled_scala; -------------------------------------------------------------------------------- /samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | import torch 4 | import torch.distributed as dist 5 | import math 6 | 7 | 8 | class RASampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset for distributed training, 10 | with repeated augmentation.
11 | It ensures that each augmented version of a sample will be visible to a 12 | different process (GPU). 13 | Heavily based on torch.utils.data.DistributedSampler 14 | """ 15 | 16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3): 17 | if num_replicas is None: 18 | if not dist.is_available(): 19 | raise RuntimeError("Requires distributed package to be available") 20 | num_replicas = dist.get_world_size() 21 | if rank is None: 22 | if not dist.is_available(): 23 | raise RuntimeError("Requires distributed package to be available") 24 | rank = dist.get_rank() 25 | if num_repeats < 1: 26 | raise ValueError("num_repeats should be greater than 0") 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.num_repeats = num_repeats 31 | self.epoch = 0 32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas)) 33 | self.total_size = self.num_samples * self.num_replicas 34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas)) 35 | self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas)) # per-replica share of the dataset size rounded down to a multiple of 256 36 | self.shuffle = shuffle 37 | 38 | def __iter__(self): 39 | if self.shuffle: 40 | # deterministically shuffle based on epoch 41 | g = torch.Generator() 42 | g.manual_seed(self.epoch) 43 | indices = torch.randperm(len(self.dataset), generator=g) 44 | else: 45 | indices = torch.arange(start=0, end=len(self.dataset)) 46 | 47 | # add extra samples to make it evenly divisible 48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist() 49 | padding_size: int = self.total_size - len(indices) 50 | if padding_size > 0: 51 | indices += indices[:padding_size] 52 | assert len(indices) == self.total_size 53 | 54 | # subsample 55 | indices = indices[self.rank:self.total_size:self.num_replicas] 56 | assert len(indices) == self.num_samples 57 | 58 | return
iter(indices[:self.num_selected_samples]) 59 | 60 | def __len__(self): 61 | return self.num_selected_samples 62 | 63 | def set_epoch(self, epoch): 64 | self.epoch = epoch 65 | -------------------------------------------------------------------------------- /scheduler.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from abc import ABC 3 | from typing import Any, Dict, Optional 4 | import torch 5 | from typing import List, Union 6 | from torch.optim import Optimizer 7 | import logging 8 | import math 9 | import numpy as np 10 | import torch 11 | 12 | 13 | _logger = logging.getLogger(__name__) 14 | 15 | 16 | def scheduler_kwargs(cfg): 17 | """ cfg/argparse to kwargs helper 18 | Convert scheduler args in argparse args or cfg (.dot) like object to keyword args. 19 | """ 20 | eval_metric = getattr(cfg, 'eval_metric', 'top1') 21 | plateau_mode = 'min' if 'loss' in eval_metric else 'max' 22 | kwargs = dict( 23 | sched=cfg.sched, 24 | num_epochs=getattr(cfg, 'epochs', 100), 25 | decay_epochs=getattr(cfg, 'decay_epochs', 30), 26 | decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]), 27 | warmup_epochs=getattr(cfg, 'warmup_epochs', 5), 28 | warm_epoch=getattr(cfg, 'warm_epoch', 10), 29 | warm_factor=getattr(cfg, 'warm_factor', 2.0), 30 | cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0), 31 | patience_epochs=getattr(cfg, 'patience_epochs', 10), 32 | decay_rate=getattr(cfg, 'decay_rate', 0.1), 33 | min_lr=getattr(cfg, 'min_lr', 0.), 34 | warmup_lr=getattr(cfg, 'warmup_lr', 1e-5), 35 | warmup_prefix=getattr(cfg, 'warmup_prefix', False), 36 | noise=getattr(cfg, 'lr_noise', None), 37 | noise_pct=getattr(cfg, 'lr_noise_pct', 0.67), 38 | noise_std=getattr(cfg, 'lr_noise_std', 1.), 39 | noise_seed=getattr(cfg, 'seed', 42), 40 | cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.), 41 | cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1), 42 | cycle_limit=getattr(cfg, 'lr_cycle_limit', 1), 43 | k_decay=getattr(cfg, 'lr_k_decay', 
1.0), 44 | plateau_mode=plateau_mode, 45 | step_on_epochs=not getattr(cfg, 'sched_on_updates', False), 46 | ) 47 | return kwargs 48 | 49 | 50 | def create_scheduler( 51 | args, 52 | optimizer: Optimizer, 53 | updates_per_epoch: int = 0, 54 | ): 55 | return create_scheduler_v2( 56 | optimizer=optimizer, 57 | **scheduler_kwargs(args), 58 | updates_per_epoch=updates_per_epoch, 59 | ) 60 | 61 | 62 | def create_scheduler_v2( 63 | optimizer: Optimizer, 64 | sched: str = 'cosine', 65 | num_epochs: int = 300, 66 | decay_epochs: int = 90, 67 | decay_milestones: List[int] = (90, 180, 270), 68 | cooldown_epochs: int = 0, 69 | patience_epochs: int = 10, 70 | decay_rate: float = 0.1, 71 | min_lr: float = 0, 72 | warmup_lr: float = 1e-5, 73 | warm_epoch: int = 0, 74 | warm_factor: float = 1.0, 75 | warmup_epochs: int = 0, 76 | warmup_prefix: bool = False, 77 | noise: Union[float, List[float]] = None, 78 | noise_pct: float = 0.67, 79 | noise_std: float = 1., 80 | noise_seed: int = 42, 81 | cycle_mul: float = 1., 82 | cycle_decay: float = 0.1, 83 | cycle_limit: int = 1, 84 | k_decay: float = 1.0, 85 | plateau_mode: str = 'max', 86 | step_on_epochs: bool = True, 87 | updates_per_epoch: int = 0, 88 | ): 89 | t_initial = num_epochs 90 | warmup_t = warmup_epochs 91 | decay_t = decay_epochs 92 | cooldown_t = cooldown_epochs 93 | 94 | if not step_on_epochs: 95 | assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches' 96 | t_initial = t_initial * updates_per_epoch 97 | warmup_t = warmup_t * updates_per_epoch 98 | decay_t = decay_t * updates_per_epoch 99 | decay_milestones = [d * updates_per_epoch for d in decay_milestones] 100 | cooldown_t = cooldown_t * updates_per_epoch 101 | 102 | # warmup args 103 | warmup_args = dict( 104 | warmup_lr_init=warmup_lr, 105 | warmup_t=warmup_t, 106 | warm_epoch = warm_epoch, 107 | warm_factor = warm_factor, 108 | warmup_prefix=warmup_prefix, 109 | ) 110 | 111 | # setup noise args for supporting schedulers 112 | if 
noise is not None: 113 | if isinstance(noise, (list, tuple)): 114 | noise_range = [n * t_initial for n in noise] 115 | if len(noise_range) == 1: 116 | noise_range = noise_range[0] 117 | else: 118 | noise_range = noise * t_initial 119 | else: 120 | noise_range = None 121 | noise_args = dict( 122 | noise_range_t=noise_range, 123 | noise_pct=noise_pct, 124 | noise_std=noise_std, 125 | noise_seed=noise_seed, 126 | ) 127 | 128 | # setup cycle args for supporting schedulers 129 | cycle_args = dict( 130 | cycle_mul=cycle_mul, 131 | cycle_decay=cycle_decay, 132 | cycle_limit=cycle_limit, 133 | ) 134 | # only cosine decay is built here: `sched`, `decay_epochs`, `decay_milestones`, `decay_rate`, `patience_epochs` and `plateau_mode` are accepted but unused 135 | lr_scheduler = CosineLRScheduler( 136 | optimizer, 137 | t_initial=t_initial, 138 | lr_min=min_lr, 139 | t_in_epochs=step_on_epochs, 140 | **cycle_args, 141 | **warmup_args, 142 | **noise_args, 143 | k_decay=k_decay, 144 | ) 145 | 146 | 147 | if hasattr(lr_scheduler, 'get_cycle_length'): 148 | # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown 149 | t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t 150 | if step_on_epochs: 151 | num_epochs = t_with_cycles_and_cooldown 152 | else: 153 | num_epochs = t_with_cycles_and_cooldown // updates_per_epoch 154 | 155 | return lr_scheduler, num_epochs 156 | 157 | 158 | 159 | 160 | class Scheduler(ABC): 161 | """ Parameter Scheduler Base Class 162 | A scheduler base class that can be used to schedule any optimizer parameter groups. 163 | 164 | Unlike the built-in PyTorch schedulers, this is intended to be consistently called 165 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 166 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 167 | 168 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
169 | 170 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 171 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 172 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 173 | 174 | Based on ideas from: 175 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 176 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 177 | """ 178 | 179 | def __init__( 180 | self, 181 | optimizer: torch.optim.Optimizer, 182 | param_group_field: str, 183 | t_in_epochs: bool = True, 184 | noise_range_t=None, 185 | noise_type='normal', 186 | noise_pct=0.67, 187 | noise_std=1.0, 188 | noise_seed=None, 189 | initialize: bool = True, 190 | ) -> None: 191 | self.optimizer = optimizer 192 | self.param_group_field = param_group_field 193 | self._initial_param_group_field = f"initial_{param_group_field}" 194 | if initialize: 195 | for i, group in enumerate(self.optimizer.param_groups): 196 | if param_group_field not in group: 197 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 198 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 199 | else: 200 | for i, group in enumerate(self.optimizer.param_groups): 201 | if self._initial_param_group_field not in group: 202 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 203 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 204 | self.metric = None # any point to having this for all? 
        self.t_in_epochs = t_in_epochs
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = noise_seed if noise_seed is not None else 42
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)

    @abc.abstractmethod
    def _get_lr(self, t: int) -> float:
        pass

    def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
        proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
        if not proceed:
            return None
        return self._get_lr(t)

    def step(self, epoch: int, metric: float = None) -> None:
        self.metric = metric
        values = self._get_values(epoch, on_epoch=True)
        if values is not None:
            values = self._add_noise(values, epoch)
            self.update_groups(values)

    def step_update(self, num_updates: int, metric: float = None):
        self.metric = metric
        values = self._get_values(num_updates, on_epoch=False)
        if values is not None:
            values = self._add_noise(values, num_updates)
            self.update_groups(values)

    def update_groups(self, values):
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            if 'lr_scale' in param_group:
                param_group[self.param_group_field] = value * param_group['lr_scale']
            else:
                param_group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        if self._is_apply_noise(t):
            noise = self._calculate_noise(t)
            lrs = [v + v * noise for v in lrs]
        return lrs

    def _is_apply_noise(self, t) -> bool:
        """Return True if scheduler in noise range."""
        apply_noise = False
        if self.noise_range_t is not None:
            if isinstance(self.noise_range_t, (list, tuple)):
                apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
            else:
                apply_noise = t >= self.noise_range_t
        return apply_noise

    def _calculate_noise(self, t) -> float:
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)
        if self.noise_type == 'normal':
            while True:
                # resample if noise out of percent limit, brute force but shouldn't spin much
                noise = torch.randn(1, generator=g).item()
                if abs(noise) < self.noise_pct:
                    return noise
        else:
            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
        return noise


class CosineLRScheduler(Scheduler):
    """
    Cosine decay with restarts.
    This is described in the paper https://arxiv.org/abs/1608.03983.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py

    k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
    """

    def __init__(
            self,
            optimizer: torch.optim.Optimizer,
            t_initial: int,
            lr_min: float = 0.,
            cycle_mul: float = 1.,
            cycle_decay: float = 1.,
            cycle_limit: int = 1,
            warmup_t=0,
            warm_epoch=0,
            warm_factor=1.0,
            warmup_lr_init=0,
            warmup_prefix=False,
            t_in_epochs=True,
            noise_range_t=None,
            noise_pct=0.67,
            noise_std=1.0,
            noise_seed=42,
            k_decay=1.0,
            initialize=True,
    ) -> None:
        super().__init__(
            optimizer,
            param_group_field="lr",
            t_in_epochs=t_in_epochs,
            noise_range_t=noise_range_t,
            noise_pct=noise_pct,
            noise_std=noise_std,
            noise_seed=noise_seed,
            initialize=initialize,
        )

        assert t_initial > 0
        assert lr_min >= 0
        if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
            _logger.warning(
                "Cosine annealing scheduler will have no effect on the learning "
                "rate since t_initial = t_mul = eta_mul = 1.")
        self.t_initial = t_initial
        self.lr_min = lr_min
        self.cycle_mul = cycle_mul
        self.cycle_decay = cycle_decay
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warm_epoch = warm_epoch
        self.warm_factor = warm_factor
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.k_decay = k_decay
        if self.warmup_t:
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        if t < self.warmup_t:
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                t = t - self.warmup_t

            if self.cycle_mul != 1:
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
                t_i = self.cycle_mul ** i * self.t_initial
                t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
            else:
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            gamma = self.cycle_decay ** i
            lr_max_values = [v * gamma for v in self.base_values]
            k = self.k_decay

            if i < self.cycle_limit:
                lrs = [
                    self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
                    for lr_max in lr_max_values
                ]
            else:
                lrs = [self.lr_min for _ in self.base_values]

        if t < self.warm_epoch:
            lrs = [lr / self.warm_factor for lr in lrs]

        return lrs

    def get_cycle_length(self, cycles=0):
        cycles = max(1, cycles or self.cycle_limit)
        if self.cycle_mul == 1.0:
            return self.t_initial * cycles
        else:
            return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
"""
Misc functions, including distributed helpers.

Mostly copy-paste from torchvision references.
"""
import io
import os
import time
from collections import defaultdict, deque
import datetime

import torch
import torch.distributed as dist


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def _load_checkpoint_for_ema(model_ema, checkpoint):
    """
    Workaround for ModelEma._load_checkpoint to accept an already-loaded object
    """
    mem_file = io.BytesIO()
    torch.save({'state_dict_ema':checkpoint}, mem_file)
    mem_file.seek(0)
    model_ema._load_checkpoint(mem_file)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
--------------------------------------------------------------------------------
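You can check the per-step learning rate that `CosineLRScheduler._get_lr` in scheduler.py above returns without building an optimizer. The following is a minimal standalone sketch for a single parameter group. It assumes `cycle_mul == 1`, `k_decay == 1`, `warmup_prefix=False`, and the default `warm_epoch=0`, so there is no extra division by `warm_factor`. The function `cosine_lr` is hypothetical and not part of this repository.

```python
import math


def cosine_lr(t, base_lr, t_initial, lr_min=0.0, warmup_t=0, warmup_lr_init=0.0,
              cycle_decay=1.0, cycle_limit=1):
    """Hypothetical standalone version of the cosine schedule for one param group.

    Assumes cycle_mul == 1, k_decay == 1, warmup_prefix == False, warm_epoch == 0.
    """
    if t < warmup_t:
        # Linear warmup from warmup_lr_init towards base_lr.
        step = (base_lr - warmup_lr_init) / warmup_t
        return warmup_lr_init + t * step
    i = t // t_initial            # index of the current cycle
    t_curr = t - t_initial * i    # position inside the current cycle
    if i >= cycle_limit:
        return lr_min             # past the last allowed cycle: stay at the floor
    lr_max = base_lr * cycle_decay ** i  # each restart peaks lower by cycle_decay
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_initial))
```

Because `warmup_prefix` defaults to `False`, the warmup steps are not subtracted from `t`. The cosine phase therefore begins at `t_curr = warmup_t` rather than at its peak, which mirrors the scheduler code above.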
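When `cycle_mul != 1`, `_get_lr` above locates the current cycle in closed form. Cycle `j` lasts `t_initial * cycle_mul**j` steps, so cycle `i` starts at the geometric-series sum `t_initial * (1 - cycle_mul**i) / (1 - cycle_mul)`. The `math.log` call inverts that sum. The sketch below copies the closed form into a standalone function and checks it against a step-by-step walk over the cycles. Both function names are illustrative, not repository APIs.

```python
import math


def cycle_position(t, t_initial, cycle_mul):
    """Closed-form cycle lookup, copied from the cycle_mul != 1 branch of _get_lr.

    Returns (cycle index i, cycle length t_i, offset t_curr inside the cycle).
    """
    i = math.floor(math.log(1 - t / t_initial * (1 - cycle_mul), cycle_mul))
    t_i = cycle_mul ** i * t_initial
    t_curr = t - (1 - cycle_mul ** i) / (1 - cycle_mul) * t_initial
    return i, t_i, t_curr


def cycle_position_brute_force(t, t_initial, cycle_mul):
    """Hypothetical reference: advance one cycle at a time until t falls inside it."""
    i, start, length = 0, 0.0, float(t_initial)
    while t >= start + length:
        start += length
        length *= cycle_mul
        i += 1
    return i, length, t - start
```

Two caveats follow from the algebra rather than from any documented behavior. First, with `cycle_mul < 1` the cycle lengths sum to `t_initial / (1 - cycle_mul)`. For `t` at or beyond that value the argument of `math.log` is no longer positive, so `math.log` raises `ValueError`. Second, at exact cycle boundaries the floating-point logarithm may land just below an integer, so `math.floor` can occasionally pick the previous cycle.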
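`SmoothedValue` in utils.py above tracks two statistics at once. A bounded `deque` feeds the windowed median and mean, and a running `total`/`count` feeds the global average. The pure-Python analogue below shows how the two diverge, assuming no torch tensors and no distributed synchronization; `WindowedMeter` is a hypothetical name. It uses `statistics.median_low` because `torch.median` returns the lower of the two middle values for an even number of elements.

```python
import statistics
from collections import deque


class WindowedMeter:
    """Hypothetical pure-Python analogue of SmoothedValue (no torch, no distributed sync)."""

    def __init__(self, window_size=20):
        self.window = deque(maxlen=window_size)  # only the most recent values
        self.total = 0.0  # weighted running sum over every update
        self.count = 0    # weighted number of samples over every update

    def update(self, value, n=1):
        self.window.append(value)  # one window entry per call, regardless of n
        self.count += n
        self.total += value * n

    @property
    def median(self):
        # Lower median, matching torch.median for an even number of elements.
        return statistics.median_low(self.window)

    @property
    def avg(self):
        return sum(self.window) / len(self.window)

    @property
    def global_avg(self):
        return self.total / self.count
```

`update(value, n)` appends `value` to the window once but weights the global average by `n` (for example, a batch size), mirroring `SmoothedValue.update`. In the original, `synchronize_between_processes` all-reduces only `count` and `total`. After synchronization, `global_avg` therefore reflects all ranks while `median` and `avg` remain per-rank, as its docstring warns.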