├── .gitattributes
├── LICENSE
├── README.md
├── augment.py
├── datasets.py
├── engine_scala.py
├── eval.sh
├── fig
│   ├── dino.png
│   ├── gran_bound.png
│   ├── hybrid.png
│   ├── intro.png
│   ├── meta.png
│   ├── neu.png
│   ├── slim.png
│   ├── snap.png
│   └── transfer.png
├── losses.py
├── main_scala.py
├── models_scala.py
├── run.sh
├── samplers.py
├── scheduler.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 BeSpontaneous
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Slicing Vision Transformer for Flexible Inference (NeurIPS 2024)
2 |
3 |
8 |
9 | Primary contact: [Yitian Zhang](mailto:markcheung9248@gmail.com)
10 |
11 |
12 |

13 |
14 |
15 |
16 | ## TL;DR
17 | - `Background`: ViTs share the same architecture and differ only in embedding dimension, so a large ViT can represent smaller models by uniformly slicing the weight matrix at each layer, e.g., ViT-B (r=0.5) equals ViT-S (see the sketch after this list).
18 | - `Target`: A broad slicing bound to ensure the diversity of sub-networks; a fine-grained slicing granularity to ensure a large number of sub-networks; uniform slicing to align with the inherent design of ViTs, which vary in width.
19 | - `Contribution`:
20 | - (1) Detailed analysis of the slimmable ability between different architectures
21 | - (2) Propose Scala to learn slimmable representation for flexible inference
22 |
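For intuition, here is a minimal sketch of the slicing idea (not the repository's full implementation; see `Attention.forward` in `models_scala.py` for the real slicing logic):

```python
import torch
import torch.nn.functional as F

# A layer evaluated at width ratio r keeps only the first
# int(r * dim) input/output channels of its weight matrix.
def sliced_linear(layer: torch.nn.Linear, x: torch.Tensor, r: float) -> torch.Tensor:
    out_dim = int(r * layer.out_features)
    in_dim = int(r * layer.in_features)
    return F.linear(x, layer.weight[:out_dim, :in_dim], layer.bias[:out_dim])

layer = torch.nn.Linear(768, 768)   # ViT-B embedding dimension
x = torch.randn(1, 197, 384)        # activations already at r = 0.5
y = sliced_linear(layer, x, 0.5)    # acts like a ViT-S (dim 384) layer
```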
23 |
24 | ## Requirements
25 | - python 3.7
26 | - pytorch 1.8.1
27 | - torchvision 0.9.1
28 | - timm 0.3.2
29 |
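For example, the dependencies can be installed with pip (a sketch assuming a Python 3.7 environment; the exact CUDA build of PyTorch may differ):

```bash
pip install torch==1.8.1 torchvision==0.9.1 timm==0.3.2
```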
30 |
31 | ## Datasets
32 | Please follow the instruction of [DeiT](https://github.com/facebookresearch/deit/blob/main/README_deit.md#data-preparation) to prepare the ImageNet-1K dataset.
33 |
34 |
35 | ## Pretrained Models
36 | Here we provide pretrained Scala models built on top of DeiT-S, trained on ImageNet-1K for 100 epochs:
37 |
38 | | Model | Acc1. ($r=0.25$) | Acc1. ($r=0.50$) | Acc1. ($r=0.75$) | Acc1. ($r=1.00$) |
39 | | ---- | ---- | ---- | ---- | ---- |
40 | | Separate Training | 45.8% | 65.1% | 70.7% | 75.0% |
41 | | [Scala-S (X=25)](https://drive.google.com/file/d/1-xQFweDA3MUTslDyfs5zqvdRhuRtSN99/view?usp=drive_link) | 58.4% | 67.8% | 73.1% | 76.2% |
42 | | [Scala-S (X=13)](https://drive.google.com/file/d/1D2KZ5_1VAKB8_NTCH35Xu8IdsIT5i3hB/view?usp=drive_link) | 58.7% | 68.3% | 73.3% | 76.1% |
43 | | [Scala-S (X=7)](https://drive.google.com/file/d/1DtA21C6VL4Qe8joHXl8yaaLrQ7mbHnEZ/view?usp=drive_link) | 59.8% | 70.3% | 74.2% | 76.5% |
44 | | [Scala-S (X=4)](https://drive.google.com/file/d/1ZBzFeaMYubr4lBajiyO4QYpzDAG7bp6i/view?usp=drive_link) | 59.8% | 72.0% | 75.6% | 76.7% |
45 |
46 | We also provide Scala models built on top of DeiT-B, trained on ImageNet-1K for 300 epochs:
47 |
48 | | Model | Acc1. ($r=0.25$) | Acc1. ($r=0.50$) | Acc1. ($r=0.75$) | Acc1. ($r=1.00$) |
49 | | ---- | ---- | ---- | ---- | ---- |
50 | | Separate Training | 72.2% | 79.9% | 81.0% | 81.8% |
51 | | [Scala-B (X=13)](https://drive.google.com/file/d/1g58ace9cfFUoooqP6n1Xy0mWqntSGhxE/view?usp=drive_link) | 75.3% | 79.3% | 81.2% | 82.0% |
52 | | [Scala-B (X=7)](https://drive.google.com/file/d/1LIgPj8TAzmrFvJcQS_QmIyUy8CTDDNeF/view?usp=drive_link) | 75.3% | 79.7% | 81.4% | 82.0% |
53 | | [Scala-B (X=4)](https://drive.google.com/file/d/1Usy-LevoYqAXdggUT-jvRiWY41Bw93hf/view?usp=drive_link) | 75.6% | 80.9% | 81.9% | 82.2% |
54 |
55 |
56 | ## Results
57 |
58 | - Slicing Granularity and Bound
59 |
60 |

61 |
62 |
63 | - Application on Hybrid and Lightweight structures
64 |
65 |

66 |
67 |
68 | - Slimmable Ability across Architectures
69 |
70 |

71 |
72 |
73 | - Transferability
74 |   - **Can the slimmable representation be transferred to downstream tasks?** We first pre-train on ImageNet-1K for 300 epochs and then conduct linear probing on the video recognition dataset UCF101. We make the classification head slimmable as well to fit features of various dimensions (a sketch of this idea appears at the end of this subsection), and the results demonstrate the strong transferability of the slimmable representation.
75 |
76 |

77 |
78 |
79 |   - **Can the generalization ability be maintained in the slimmable representation?** When leveraging the vision foundation model DINOv2 as the teacher network, we follow the prior work [Proteus](https://github.com/BeSpontaneous/Proteus-pytorch) and remove all Cross-Entropy losses during training to alleviate the dataset bias issue and inherit the strong generalization ability of the teacher network. The results are shown in the table, and the delivered Scala-B with strong generalization ability can be downloaded from this [link](https://drive.google.com/file/d/1KPJK_rucC8ovQPe2TDDeKt0HBflQ0mrq/view?usp=drive_link).
80 |
81 |
82 |
83 |

84 |
85 |
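As a hypothetical sketch of the slimmable classification head mentioned above (names and dimensions are illustrative, not the repository's exact code), the head keeps every class logit and slices only its input side to match the feature dimension produced at width ratio r:

```python
import torch
import torch.nn.functional as F

class SlimmableHead(torch.nn.Module):
    """Linear-probe head whose input features are sliced to int(r * embed_dim)."""
    def __init__(self, embed_dim: int = 768, num_classes: int = 101):
        super().__init__()
        self.fc = torch.nn.Linear(embed_dim, num_classes)

    def forward(self, feats: torch.Tensor, r: float) -> torch.Tensor:
        in_dim = int(r * self.fc.in_features)
        # keep all class outputs, slice only the input feature dimension
        return F.linear(feats, self.fc.weight[:, :in_dim], self.fc.bias)
```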
86 |
87 | ## Training Scala on ImageNet-1K
88 | 1. Specify the directory of datasets with `IMAGENET_LOCATION` in `run.sh`.
89 | 2. Specify the smallest slicing bound $s$ (`smallest_ratio`), the largest slicing bound $l$ (`largest_ratio`), and the slicing granularity $\epsilon$ (`granularity`), which together determine the number of sub-networks $X = (l - s)/\epsilon + 1$ (a sample launch command is sketched below).
90 |
91 | | $s$ | $l$ | $\epsilon$ | $X$ |
92 | | ---- | ---- | ---- | ---- |
93 | | 0.25 | 1.0 | 0.03125 | 25 |
94 | | 0.25 | 1.0 | 0.0625 | 13 |
95 | | 0.25 | 1.0 | 0.125 | 7 |
96 | | 0.25 | 1.0 | 0.25 | 4 |
97 |
98 | 3. Run `bash run.sh`.
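A sketch of what the launch command in `run.sh` might look like, modeled on `eval.sh` (paths, GPU count, and epoch budget are placeholders; the flag names come from `main_scala.py`):

```bash
python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py \
    --data-path IMAGENET_LOCATION \
    --model deit_small_distilled_patch16_224_scala \
    --batch-size 256 --epochs 100 \
    --smallest_ratio 0.25 --largest_ratio 1.0 --granularity 0.03125 \
    --output_dir OUTPUT_DIR;
```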
99 |
100 |
101 |
102 | ## Flexible inference at different width ratios
103 | 1. Specify the directory of datasets with `IMAGENET_LOCATION` and the pretrained model with `MODEL_PATH` in `eval.sh`.
104 | 2. Specify the width ratio with `eval_ratio` in `eval.sh`.
105 | 3. Run `bash eval.sh`.
--------------------------------------------------------------------------------
/augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | """
5 | 3Augment implementation
6 | Data-augmentation (DA) based on dino DA (https://github.com/facebookresearch/dino)
7 | and timm DA (https://github.com/rwightman/pytorch-image-models)
8 | """
9 | import torch
10 | from torchvision import transforms
11 |
12 | from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
13 |
14 | import numpy as np
15 |
16 | import random
17 |
18 |
19 |
20 | from PIL import ImageFilter, ImageOps
21 | import torchvision.transforms.functional as TF
22 |
23 |
24 | class GaussianBlur(object):
25 | """
26 | Apply Gaussian Blur to the PIL image.
27 | """
28 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.):
29 | self.prob = p
30 | self.radius_min = radius_min
31 | self.radius_max = radius_max
32 |
33 | def __call__(self, img):
34 | do_it = random.random() <= self.prob
35 | if not do_it:
36 | return img
37 |
38 | img = img.filter(
39 | ImageFilter.GaussianBlur(
40 | radius=random.uniform(self.radius_min, self.radius_max)
41 | )
42 | )
43 | return img
44 |
45 | class Solarization(object):
46 | """
47 | Apply Solarization to the PIL image.
48 | """
49 | def __init__(self, p=0.2):
50 | self.p = p
51 |
52 | def __call__(self, img):
53 | if random.random() < self.p:
54 | return ImageOps.solarize(img)
55 | else:
56 | return img
57 |
58 | class gray_scale(object):
59 | """
60 | Apply Grayscale to the PIL image.
61 | """
62 | def __init__(self, p=0.2):
63 | self.p = p
64 | self.transf = transforms.Grayscale(3)
65 |
66 | def __call__(self, img):
67 | if random.random() < self.p:
68 | return self.transf(img)
69 | else:
70 | return img
71 |
72 |
73 |
74 | class horizontal_flip(object):
75 | """
76 | Apply random horizontal flip to the PIL image.
77 | """
78 | def __init__(self, p=0.2,activate_pred=False):
79 | self.p = p
80 | self.transf = transforms.RandomHorizontalFlip(p=1.0)
81 |
82 | def __call__(self, img):
83 | if random.random() < self.p:
84 | return self.transf(img)
85 | else:
86 | return img
87 |
88 |
89 |
90 | def new_data_aug_generator(args = None):
91 | img_size = args.input_size
92 | remove_random_resized_crop = args.src
93 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
94 | primary_tfl = []
95 | scale=(0.08, 1.0)
96 | interpolation='bicubic'
97 | if remove_random_resized_crop:
98 | primary_tfl = [
99 | transforms.Resize(img_size, interpolation=3),
100 | transforms.RandomCrop(img_size, padding=4,padding_mode='reflect'),
101 | transforms.RandomHorizontalFlip()
102 | ]
103 | else:
104 | primary_tfl = [
105 | RandomResizedCropAndInterpolation(
106 | img_size, scale=scale, interpolation=interpolation),
107 | transforms.RandomHorizontalFlip()
108 | ]
109 |
110 |
111 | secondary_tfl = [transforms.RandomChoice([gray_scale(p=1.0),
112 | Solarization(p=1.0),
113 | GaussianBlur(p=1.0)])]
114 |
115 | if args.color_jitter is not None and not args.color_jitter==0:
116 | secondary_tfl.append(transforms.ColorJitter(args.color_jitter, args.color_jitter, args.color_jitter))
117 | final_tfl = [
118 | transforms.ToTensor(),
119 | transforms.Normalize(
120 | mean=torch.tensor(mean),
121 | std=torch.tensor(std))
122 | ]
123 | return transforms.Compose(primary_tfl+secondary_tfl+final_tfl)
124 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import os
4 | import json
5 |
6 | from torchvision import datasets, transforms
7 | from torchvision.datasets.folder import ImageFolder, default_loader
8 |
9 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
10 | from timm.data import create_transform
11 |
12 |
13 | class INatDataset(ImageFolder):
14 | def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
15 | category='name', loader=default_loader):
16 | self.transform = transform
17 | self.loader = loader
18 | self.target_transform = target_transform
19 | self.year = year
20 | # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
21 | path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
22 | with open(path_json) as json_file:
23 | data = json.load(json_file)
24 |
25 | with open(os.path.join(root, 'categories.json')) as json_file:
26 | data_catg = json.load(json_file)
27 |
28 | path_json_for_targeter = os.path.join(root, f"train{year}.json")
29 |
30 | with open(path_json_for_targeter) as json_file:
31 | data_for_targeter = json.load(json_file)
32 |
33 | targeter = {}
34 | indexer = 0
35 | for elem in data_for_targeter['annotations']:
36 | king = []
37 | king.append(data_catg[int(elem['category_id'])][category])
38 | if king[0] not in targeter.keys():
39 | targeter[king[0]] = indexer
40 | indexer += 1
41 | self.nb_classes = len(targeter)
42 |
43 | self.samples = []
44 | for elem in data['images']:
45 | cut = elem['file_name'].split('/')
46 | target_current = int(cut[2])
47 | path_current = os.path.join(root, cut[0], cut[2], cut[3])
48 |
49 | categors = data_catg[target_current]
50 | target_current_true = targeter[categors[category]]
51 | self.samples.append((path_current, target_current_true))
52 |
53 | # __getitem__ and __len__ inherited from ImageFolder
54 |
55 |
56 | def build_dataset(is_train, args):
57 | transform = build_transform(is_train, args)
58 |
59 | if args.data_set == 'CIFAR':
60 | dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
61 | nb_classes = 100
62 | elif args.data_set == 'IMNET':
63 | root = os.path.join(args.data_path, 'train' if is_train else 'val')
64 | dataset = datasets.ImageFolder(root, transform=transform)
65 | nb_classes = 1000
66 | elif args.data_set == 'INAT':
67 | dataset = INatDataset(args.data_path, train=is_train, year=2018,
68 | category=args.inat_category, transform=transform)
69 | nb_classes = dataset.nb_classes
70 | elif args.data_set == 'INAT19':
71 | dataset = INatDataset(args.data_path, train=is_train, year=2019,
72 | category=args.inat_category, transform=transform)
73 | nb_classes = dataset.nb_classes
74 |
75 | return dataset, nb_classes
76 |
77 |
78 | def build_transform(is_train, args):
79 | resize_im = args.input_size > 32
80 | if is_train:
81 | # this should always dispatch to transforms_imagenet_train
82 | transform = create_transform(
83 | input_size=args.input_size,
84 | is_training=True,
85 | color_jitter=args.color_jitter,
86 | auto_augment=args.aa,
87 | interpolation=args.train_interpolation,
88 | re_prob=args.reprob,
89 | re_mode=args.remode,
90 | re_count=args.recount,
91 | )
92 | if not resize_im:
93 | # replace RandomResizedCropAndInterpolation with
94 | # RandomCrop
95 | transform.transforms[0] = transforms.RandomCrop(
96 | args.input_size, padding=4)
97 | return transform
98 |
99 | t = []
100 | if resize_im:
101 | size = int(args.input_size / args.eval_crop_ratio)
102 | t.append(
103 | transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
104 | )
105 | t.append(transforms.CenterCrop(args.input_size))
106 |
107 | t.append(transforms.ToTensor())
108 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
109 | return transforms.Compose(t)
110 |
--------------------------------------------------------------------------------
/engine_scala.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Train and eval functions used in main.py
5 | """
6 | import math
7 | import sys
8 | from typing import Iterable, Optional
9 | import torch.nn as nn
10 | import torch
11 | from torch.nn import functional as F
12 | from timm.data import Mixup
13 | from timm.utils import accuracy, ModelEma
14 | import random
15 | from losses import DistillationLoss
16 | import utils
17 |
18 |
19 | def train_one_epoch(full_warm_epoch: int, distillation_type: str, ce_coefficient: float,
20 | largest_ratio: float, smallest_ratio: float, granularity: float,
21 | ce_type: str, distill_type: str, transfer_type: str, token_type: str,
22 | model: torch.nn.Module, criterion: DistillationLoss,
23 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
24 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
25 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
26 | set_training_mode=True, args = None):
27 | model.train(set_training_mode)
28 | metric_logger = utils.MetricLogger(delimiter=" ")
29 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
30 | header = 'Epoch: [{}]'.format(epoch)
31 | print_freq = 10
32 |
33 | if args.cosub:
34 | criterion = torch.nn.BCEWithLogitsLoss()
35 |
36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
37 | samples = samples.to(device, non_blocking=True)
38 | targets = targets.to(device, non_blocking=True)
39 |
40 | if mixup_fn is not None:
41 | samples, targets = mixup_fn(samples, targets)
42 |
43 | if args.cosub:
44 | samples = torch.cat((samples,samples),dim=0)
45 |
46 | if args.bce_loss:
47 | targets = targets.gt(0.0).type(targets.dtype)
48 |
49 | optimizer.zero_grad()
50 |
51 | with torch.cuda.amp.autocast():
52 | if epoch < full_warm_epoch:
53 | outputs = model(samples, largest_ratio)
54 | if not args.cosub:
55 | loss_full = criterion(samples, outputs, targets)
56 | else:
57 | loss_full = 0.25 * criterion(outputs[0], targets)
58 | loss_full = loss_full + 0.25 * criterion(outputs[1], targets)
59 | loss_full = loss_full + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
60 | loss_full = loss_full + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())
61 |
62 | loss_scaler.scale(loss_full).backward()
63 | loss_scaler.step(optimizer)
64 | loss_scaler.update()
65 |
66 | loss = loss_full
67 | loss_value = loss.item()
68 | loss_full_value = loss_full.item()
69 | loss_3q_value = 0.0
70 | loss_2q_value = 0.0
71 | loss_1q_value = 0.0
72 | loss_3q_ce_value = 0.0
73 | loss_2q_ce_value = 0.0
74 | loss_1q_ce_value = 0.0
75 | else:
76 | ############## full model ##############
77 | outputs = model(samples, largest_ratio)
78 | if not args.cosub:
79 | loss_full = criterion(samples, outputs, targets)
80 | else:
81 | loss_full = 0.25 * criterion(outputs[0], targets)
82 | loss_full = loss_full + 0.25 * criterion(outputs[1], targets)
83 | loss_full = loss_full + 0.25 * criterion(outputs[0], outputs[1].detach().sigmoid())
84 | loss_full = loss_full + 0.25 * criterion(outputs[1], outputs[0].detach().sigmoid())
85 |
86 | loss_scaler.scale(loss_full).backward()
87 |
88 | middle_ratio = (largest_ratio + smallest_ratio) / 2
89 | width_3q = float(random.randint(middle_ratio//granularity, largest_ratio//granularity-1) * granularity)  # random sub-width in [middle, largest)
90 | width_2q = float(random.randint(smallest_ratio//granularity+1, middle_ratio//granularity-1) * granularity)  # random sub-width in (smallest, middle)
91 |
92 | if token_type == 'dist_token':
93 | token = 1
94 | elif token_type == 'cls_token':
95 | token = 0
96 |
97 | if ce_type == 'one_hot':
98 | ce_targets = targets
99 | elif ce_type == 'teacher':
100 | ce_targets = outputs[0].detach().argmax(dim=1)
101 |
102 | ce_loss = torch.nn.CrossEntropyLoss()
103 |
104 | ############## 3q model ##############
105 | output_3q = model(samples, width_3q)
106 | if distillation_type == 'none':
107 | if distill_type == 'hard':
108 | loss_3q = F.cross_entropy(output_3q, outputs.detach().argmax(dim=1))
109 | else:
110 | loss_3q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_3q), nn.Softmax(dim=1)(outputs.detach()))
111 | loss_3q_ce = ce_coefficient * ce_loss(output_3q, ce_targets)
112 | loss_3q = loss_3q + loss_3q_ce
113 | else:
114 | if distill_type == 'hard':
115 | loss_3q = F.cross_entropy(output_3q[token], outputs[token].detach().argmax(dim=1))
116 | else:
117 | loss_3q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_3q[token]), nn.Softmax(dim=1)(outputs[token].detach()))
118 | loss_3q_ce = ce_coefficient * ce_loss(output_3q[0], ce_targets)
119 | loss_3q = loss_3q + loss_3q_ce
120 |
121 | loss_scaler.scale(loss_3q).backward()
122 |
123 |
124 | ############## 2q model ##############
125 | if transfer_type == 'US':  # 'US' distills every subnet from the full model; otherwise each subnet learns from the next larger one
126 | teacher_2q = outputs
127 | else:
128 | teacher_2q = output_3q
129 | output_2q = model(samples, width_2q)
130 | if distillation_type == 'none':
131 | if distill_type == 'hard':
132 | loss_2q = F.cross_entropy(output_2q, teacher_2q.detach().argmax(dim=1))
133 | else:
134 | loss_2q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_2q), nn.Softmax(dim=1)(teacher_2q.detach()))
135 | loss_2q_ce = ce_coefficient * ce_loss(output_2q, ce_targets)
136 | loss_2q = loss_2q + loss_2q_ce
137 | else:
138 | if distill_type == 'hard':
139 | loss_2q = F.cross_entropy(output_2q[token], teacher_2q[token].detach().argmax(dim=1))
140 | else:
141 | loss_2q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_2q[token]), nn.Softmax(dim=1)(teacher_2q[token].detach()))
142 | loss_2q_ce = ce_coefficient * ce_loss(output_2q[0], ce_targets)
143 | loss_2q = loss_2q + loss_2q_ce
144 |
145 | loss_scaler.scale(loss_2q).backward()
146 |
147 |
148 | ############## 1q model ##############
149 | if transfer_type == 'US':
150 | teacher_1q = outputs
151 | else:
152 | teacher_1q = output_2q
153 | output_1q = model(samples, smallest_ratio)
154 | if distillation_type == 'none':
155 | if distill_type == 'hard':
156 | loss_1q = F.cross_entropy(output_1q, teacher_1q.detach().argmax(dim=1))
157 | else:
158 | loss_1q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_1q), nn.Softmax(dim=1)(teacher_1q.detach()))
159 | loss_1q_ce = ce_coefficient * ce_loss(output_1q, ce_targets)
160 | loss_1q = loss_1q + loss_1q_ce
161 | else:
162 | if distill_type == 'hard':
163 | loss_1q = F.cross_entropy(output_1q[token], teacher_1q[token].detach().argmax(dim=1))
164 | else:
165 | loss_1q = nn.KLDivLoss(reduction='batchmean')(nn.LogSoftmax(dim=1)(output_1q[token]), nn.Softmax(dim=1)(teacher_1q[token].detach()))
166 | loss_1q_ce = ce_coefficient * ce_loss(output_1q[0], ce_targets)
167 | loss_1q = loss_1q + loss_1q_ce
168 |
169 | loss_scaler.scale(loss_1q).backward()
170 | loss_scaler.step(optimizer)
171 | loss_scaler.update()
172 |
173 | loss = loss_full + loss_3q + loss_2q + loss_1q
174 | loss_value = loss.item()
175 | loss_full_value = loss_full.item()
176 | loss_3q_value = loss_3q.item()
177 | loss_2q_value = loss_2q.item()
178 | loss_1q_value = loss_1q.item()
179 | loss_3q_ce_value = loss_3q_ce.item()
180 | loss_2q_ce_value = loss_2q_ce.item()
181 | loss_1q_ce_value = loss_1q_ce.item()
182 |
183 | if not math.isfinite(loss_value):
184 | print("Loss is {}, stopping training".format(loss_value))
185 | sys.exit(1)
186 |
187 | # # this attribute is added by timm on one optimizer (adahessian)
188 | # is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
189 | # loss_scaler(loss, optimizer, clip_grad=max_norm,
190 | # parameters=model.parameters(), create_graph=is_second_order)
191 |
192 | torch.cuda.synchronize()
193 | if model_ema is not None:
194 | model_ema.update(model)
195 |
196 | metric_logger.update(loss=loss_value)
197 | metric_logger.update(loss_full=loss_full_value)
198 | metric_logger.update(loss_3q=loss_3q_value)
199 | metric_logger.update(loss_3q_ce=loss_3q_ce_value)
200 | metric_logger.update(loss_2q=loss_2q_value)
201 | metric_logger.update(loss_2q_ce=loss_2q_ce_value)
202 | metric_logger.update(loss_1q=loss_1q_value)
203 | metric_logger.update(loss_1q_ce=loss_1q_ce_value)
204 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
205 |
206 | # gather the stats from all processes
207 | metric_logger.synchronize_between_processes()
208 | print("Averaged stats:", metric_logger)
209 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
210 |
211 |
212 | @torch.no_grad()
213 | def evaluate(data_loader, model, device, scale_ratio):
214 | criterion = torch.nn.CrossEntropyLoss()
215 |
216 | metric_logger = utils.MetricLogger(delimiter=" ")
217 | header = 'Test:'
218 |
219 | # switch to evaluation mode
220 | model.eval()
221 |
222 | for images, target in metric_logger.log_every(data_loader, 10, header):
223 | images = images.to(device, non_blocking=True)
224 | target = target.to(device, non_blocking=True)
225 |
226 | # compute output
227 | with torch.cuda.amp.autocast():
228 | output = model(images, scale_ratio)
229 | loss = criterion(output, target)
230 |
231 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
232 |
233 | batch_size = images.shape[0]
234 | metric_logger.update(loss=loss.item())
235 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
236 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
237 | # gather the stats from all processes
238 | metric_logger.synchronize_between_processes()
239 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
240 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
241 |
242 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
243 |
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | ### eval scala on deit_s at width ratio of 0.5
2 |
3 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py --eval \
4 | --data-path IMAGENET_LOCATION \
5 | --model deit_small_distilled_patch16_224_scala \
6 | --batch-size 256 --eval_ratio 0.5 \
7 | --resume MODEL_PATH;
--------------------------------------------------------------------------------
/fig/dino.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/dino.png
--------------------------------------------------------------------------------
/fig/gran_bound.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/gran_bound.png
--------------------------------------------------------------------------------
/fig/hybrid.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/hybrid.png
--------------------------------------------------------------------------------
/fig/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/intro.png
--------------------------------------------------------------------------------
/fig/meta.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/meta.png
--------------------------------------------------------------------------------
/fig/neu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/neu.png
--------------------------------------------------------------------------------
/fig/slim.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/slim.png
--------------------------------------------------------------------------------
/fig/snap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/snap.png
--------------------------------------------------------------------------------
/fig/transfer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeSpontaneous/Scala-pytorch/3fe74cdc3c13f229cc8cdb9c78749795a55f6aef/fig/transfer.png
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Implements the knowledge distillation loss
5 | """
6 | import torch
7 | from torch.nn import functional as F
8 |
9 |
10 | class DistillationLoss(torch.nn.Module):
11 | """
12 | This module wraps a standard criterion and adds an extra knowledge distillation loss by
13 | taking a teacher model prediction and using it as additional supervision.
14 | """
15 | def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
16 | distillation_type: str, alpha: float, tau: float):
17 | super().__init__()
18 | self.base_criterion = base_criterion
19 | self.teacher_model = teacher_model
20 | assert distillation_type in ['none', 'soft', 'hard']
21 | self.distillation_type = distillation_type
22 | self.alpha = alpha
23 | self.tau = tau
24 |
25 | def forward(self, inputs, outputs, labels):
26 | """
27 | Args:
28 | inputs: The original inputs that are fed to the teacher model
29 | outputs: the outputs of the model to be trained. It is expected to be
30 | either a Tensor, or a Tuple[Tensor, Tensor], with the original output
31 | in the first position and the distillation predictions as the second output
32 | labels: the labels for the base criterion
33 | """
34 | outputs_kd = None
35 | if not isinstance(outputs, torch.Tensor):
36 | # assume that the model outputs a tuple of [outputs, outputs_kd]
37 | outputs, outputs_kd = outputs
38 | base_loss = self.base_criterion(outputs, labels)
39 | if self.distillation_type == 'none':
40 | return base_loss
41 |
42 | if outputs_kd is None:
43 | raise ValueError("When knowledge distillation is enabled, the model is "
44 | "expected to return a Tuple[Tensor, Tensor] with the output of the "
45 | "class_token and the dist_token")
46 | # don't backprop through the teacher
47 | with torch.no_grad():
48 | teacher_outputs = self.teacher_model(inputs)
49 |
50 | if self.distillation_type == 'soft':
51 | T = self.tau
52 | # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
53 | # with slight modifications
54 | distillation_loss = F.kl_div(
55 | F.log_softmax(outputs_kd / T, dim=1),
56 | #We provide the teacher's targets in log probability because we use log_target=True
57 | #(as recommended in pytorch https://github.com/pytorch/pytorch/blob/9324181d0ac7b4f7949a574dbc3e8be30abe7041/torch/nn/functional.py#L2719)
58 | #but it is possible to give just the probabilities and set log_target=False. In our experiments we tried both.
59 | F.log_softmax(teacher_outputs / T, dim=1),
60 | reduction='sum',
61 | log_target=True
62 | ) * (T * T) / outputs_kd.numel()
63 | #We divide by outputs_kd.numel() to have the legacy PyTorch behavior.
64 | #But we also experimented with output_kd.size(0)
65 | #see issue 61(https://github.com/facebookresearch/deit/issues/61) for more details
66 | elif self.distillation_type == 'hard':
67 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))
68 |
69 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
70 | return loss
71 |
--------------------------------------------------------------------------------
/main_scala.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import argparse
4 | import datetime
5 | import numpy as np
6 | import time
7 | import torch
8 | import torch.backends.cudnn as cudnn
9 | import json
10 |
11 | from pathlib import Path
12 |
13 | from timm.data import Mixup
14 | from timm.models import create_model
15 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16 | from timm.optim import create_optimizer
17 | from timm.utils import get_state_dict, ModelEma
18 | import random
19 | from datasets import build_dataset
20 | from engine_scala import train_one_epoch, evaluate
21 | from losses import DistillationLoss
22 | from samplers import RASampler
23 | from scheduler import create_scheduler
24 | from augment import new_data_aug_generator
25 |
26 | import models_scala
27 | import utils
28 |
29 |
30 | def get_args_parser():
31 | parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
32 | parser.add_argument('--batch-size', default=64, type=int)
33 | parser.add_argument('--epochs', default=300, type=int)
34 | parser.add_argument('--bce-loss', action='store_true')
35 | parser.add_argument('--unscale-lr', action='store_true')
36 |
37 | parser.add_argument('--full_warm_epoch', type=int, default=10)
38 | parser.add_argument('--smallest_ratio', type=float, default=0.25)
39 | parser.add_argument('--largest_ratio', type=float, default=1.0)
40 | parser.add_argument('--granularity', type=float, default=0.03125)
41 | parser.add_argument('--ce_coefficient', type=float, default=1.0)
42 | parser.add_argument('--eval_ratio', type=float, default=1.0)
43 | parser.add_argument('--ce_type', default='one_hot', type=str)
44 | parser.add_argument('--distill_type', default='soft', type=str)
45 | parser.add_argument('--transfer_type', default='progressive', type=str)
46 | parser.add_argument('--token_type', default='cls_token', type=str)
47 |
48 | # Model parameters
49 | parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
50 | help='Name of model to train')
51 | parser.add_argument('--input-size', default=224, type=int, help='images input size')
52 |
53 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
54 | help='Dropout rate (default: 0.)')
55 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
56 | help='Drop path rate (default: 0.1)')
57 |
58 | parser.add_argument('--model-ema', action='store_true')
59 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
60 | parser.set_defaults(model_ema=True)
61 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
62 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
63 |
64 | # Optimizer parameters
65 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
66 | help='Optimizer (default: "adamw")')
67 | parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
68 | help='Optimizer Epsilon (default: 1e-8)')
69 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
70 | help='Optimizer Betas (default: None, use opt default)')
71 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
72 | help='Clip gradient norm (default: None, no clipping)')
73 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
74 | help='SGD momentum (default: 0.9)')
75 | parser.add_argument('--weight-decay', type=float, default=0.05,
76 | help='weight decay (default: 0.05)')
77 | # Learning rate schedule parameters
78 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
79 | help='LR scheduler (default: "cosine")')
80 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
81 | help='learning rate (default: 5e-4)')
82 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
83 | help='learning rate noise on/off epoch percentages')
84 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
85 | help='learning rate noise limit percent (default: 0.67)')
86 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
87 | help='learning rate noise std-dev (default: 1.0)')
88 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
89 | help='warmup learning rate (default: 1e-6)')
90 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
91 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
92 |
93 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
94 | help='epoch interval to decay LR')
95 | parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
96 | help='epochs to warmup LR, if scheduler supports')
97 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
98 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
99 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
100 | help='patience epochs for Plateau LR scheduler (default: 10)')
101 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
102 | help='LR decay rate (default: 0.1)')
103 |
104 | # Augmentation parameters
105 | parser.add_argument('--color-jitter', type=float, default=0.3, metavar='PCT',
106 | help='Color jitter factor (default: 0.3)')
107 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
108 | help='Use AutoAugment policy. "v0" or "original" '
109 | '(default: rand-m9-mstd0.5-inc1)')
110 | parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
111 | parser.add_argument('--train-interpolation', type=str, default='bicubic',
112 | help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
113 |
114 | parser.add_argument('--repeated-aug', action='store_true')
115 | parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
116 | parser.set_defaults(repeated_aug=True)
117 |
118 | parser.add_argument('--train-mode', action='store_true')
119 | parser.add_argument('--no-train-mode', action='store_false', dest='train_mode')
120 | parser.set_defaults(train_mode=True)
121 |
122 | parser.add_argument('--ThreeAugment', action='store_true') #3augment
123 |
124 | parser.add_argument('--src', action='store_true') #simple random crop
125 |
126 | # * Random Erase params
127 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
128 | help='Random erase prob (default: 0.25)')
129 | parser.add_argument('--remode', type=str, default='pixel',
130 | help='Random erase mode (default: "pixel")')
131 | parser.add_argument('--recount', type=int, default=1,
132 | help='Random erase count (default: 1)')
133 | parser.add_argument('--resplit', action='store_true', default=False,
134 | help='Do not random erase first (clean) augmentation split')
135 |
136 | # * Mixup params
137 | parser.add_argument('--mixup', type=float, default=0.8,
138 | help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
139 | parser.add_argument('--cutmix', type=float, default=1.0,
140 | help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
141 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
142 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
143 | parser.add_argument('--mixup-prob', type=float, default=1.0,
144 | help='Probability of performing mixup or cutmix when either/both is enabled')
145 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
146 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
147 | parser.add_argument('--mixup-mode', type=str, default='batch',
148 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
149 |
150 | # Distillation parameters
151 | parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
152 | help='Name of teacher model to train (default: "regnety_160")')
153 | parser.add_argument('--teacher-path', type=str, default='')
154 | parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
155 | parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
156 | parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
157 |
158 | # * Cosub params
159 | parser.add_argument('--cosub', action='store_true')
160 |
161 | # * Finetuning params
162 | parser.add_argument('--finetune', default='', help='finetune from checkpoint')
163 | parser.add_argument('--attn-only', action='store_true')
164 |
165 | # Dataset parameters
166 | parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', type=str,
167 | help='dataset path')
168 | parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
169 | type=str, help='Image Net dataset path')
170 | parser.add_argument('--inat-category', default='name',
171 | choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
172 | type=str, help='semantic granularity')
173 |
174 | parser.add_argument('--output_dir', default='',
175 | help='path where to save, empty for no saving')
176 | parser.add_argument('--device', default='cuda',
177 | help='device to use for training / testing')
178 | parser.add_argument('--seed', default=0, type=int)
179 | parser.add_argument('--resume', default='', help='resume from checkpoint')
180 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
181 | help='start epoch')
182 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
183 | parser.add_argument('--eval-crop-ratio', default=0.875, type=float, help="Crop ratio for evaluation")
184 | parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
185 | parser.add_argument('--num_workers', default=10, type=int)
186 | parser.add_argument('--pin-mem', action='store_true',
187 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
188 | parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
189 | help='')
190 | parser.set_defaults(pin_mem=True)
191 |
192 | # distributed training parameters
193 | parser.add_argument('--distributed', action='store_true', default=False, help='Enabling distributed training')
194 | parser.add_argument('--world_size', default=1, type=int,
195 | help='number of distributed processes')
196 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
197 | return parser
198 |
199 |
200 | def main(args):
201 | utils.init_distributed_mode(args)
202 |
203 | print(args)
204 |
205 | if args.distillation_type != 'none' and args.finetune and not args.eval:
206 | raise NotImplementedError("Finetuning with distillation not yet supported")
207 |
208 | device = torch.device(args.device)
209 |
210 | # fix the seed for reproducibility
211 | seed = args.seed + utils.get_rank()
212 | torch.manual_seed(seed)
213 | np.random.seed(seed)
214 | random.seed(seed)
215 |
216 | cudnn.benchmark = True
217 |
218 | dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
219 | dataset_val, _ = build_dataset(is_train=False, args=args)
220 |
221 | if args.distributed:
222 | num_tasks = utils.get_world_size()
223 | global_rank = utils.get_rank()
224 | if args.repeated_aug:
225 | sampler_train = RASampler(
226 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
227 | )
228 | else:
229 | sampler_train = torch.utils.data.DistributedSampler(
230 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
231 | )
232 | if args.dist_eval:
233 | if len(dataset_val) % num_tasks != 0:
234 | print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
235 | 'This will slightly alter validation results as extra duplicate entries are added to achieve '
236 | 'equal num of samples per-process.')
237 | sampler_val = torch.utils.data.DistributedSampler(
238 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
239 | else:
240 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
241 | else:
242 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
243 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
244 |
245 | data_loader_train = torch.utils.data.DataLoader(
246 | dataset_train, sampler=sampler_train,
247 | batch_size=args.batch_size,
248 | num_workers=args.num_workers,
249 | pin_memory=args.pin_mem,
250 | drop_last=True,
251 | )
252 | if args.ThreeAugment:
253 | data_loader_train.dataset.transform = new_data_aug_generator(args)
254 |
255 | data_loader_val = torch.utils.data.DataLoader(
256 | dataset_val, sampler=sampler_val,
257 | batch_size=int(1.5 * args.batch_size),
258 | num_workers=args.num_workers,
259 | pin_memory=args.pin_mem,
260 | drop_last=False
261 | )
262 |
263 | mixup_fn = None
264 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
265 | if mixup_active:
266 | mixup_fn = Mixup(
267 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
268 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
269 | label_smoothing=args.smoothing, num_classes=args.nb_classes)
270 |
271 | print(f"Creating model: {args.model}")
272 | model = create_model(
273 | args.model,
274 | pretrained=False,
275 | num_classes=args.nb_classes,
276 | drop_rate=args.drop,
277 | drop_path_rate=args.drop_path,
278 | drop_block_rate=None,
279 | img_size=args.input_size,
280 | smallest_ratio=args.smallest_ratio,
281 | largest_ratio=args.largest_ratio,
282 | )
283 |
284 |
285 | if args.finetune:
286 | if args.finetune.startswith('https'):
287 | checkpoint = torch.hub.load_state_dict_from_url(
288 | args.finetune, map_location='cpu', check_hash=True)
289 | else:
290 | checkpoint = torch.load(args.finetune, map_location='cpu')
291 |
292 | checkpoint_model = checkpoint['model']
293 | state_dict = model.state_dict()
294 | for k in ['head.weight', 'head.bias', 'head_dist.weight', 'head_dist.bias']:
295 | if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
296 | print(f"Removing key {k} from pretrained checkpoint")
297 | del checkpoint_model[k]
298 |
299 | # interpolate position embedding
300 | pos_embed_checkpoint = checkpoint_model['pos_embed']
301 | embedding_size = pos_embed_checkpoint.shape[-1]
302 | num_patches = model.patch_embed.num_patches
303 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
304 | # height (== width) for the checkpoint position embedding
305 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
306 | # height (== width) for the new position embedding
307 | new_size = int(num_patches ** 0.5)
308 | # class_token and dist_token are kept unchanged
309 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
310 | # only the position tokens are interpolated
311 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
312 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
313 | pos_tokens = torch.nn.functional.interpolate(
314 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
315 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
316 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
317 | checkpoint_model['pos_embed'] = new_pos_embed
318 |
319 | model.load_state_dict(checkpoint_model, strict=False)
320 |
321 | if args.attn_only:
322 | for name_p,p in model.named_parameters():
323 | if '.attn.' in name_p:
324 | p.requires_grad = True
325 | else:
326 | p.requires_grad = False
327 | try:
328 | model.head.weight.requires_grad = True
329 | model.head.bias.requires_grad = True
330 | except:
331 | model.fc.weight.requires_grad = True
332 | model.fc.bias.requires_grad = True
333 | try:
334 | model.pos_embed.requires_grad = True
335 | except:
336 | print('no position encoding')
337 | try:
338 | for p in model.patch_embed.parameters():
339 | p.requires_grad = False
340 | except:
341 | print('no patch embed')
342 |
343 | model.to(device)
344 |
345 | model_ema = None
346 | if args.model_ema:
347 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
348 | model_ema = ModelEma(
349 | model,
350 | decay=args.model_ema_decay,
351 | device='cpu' if args.model_ema_force_cpu else '',
352 | resume='')
353 |
354 | model_without_ddp = model
355 | if args.distributed:
356 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
357 | model_without_ddp = model.module
358 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
359 | print('number of params:', n_parameters)
360 | if not args.unscale_lr:
361 | linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
362 | args.lr = linear_scaled_lr
363 | optimizer = create_optimizer(args, model_without_ddp)
364 | # loss_scaler = NativeScaler()
365 | loss_scaler = torch.cuda.amp.GradScaler()
366 |
367 | lr_scheduler, _ = create_scheduler(args, optimizer)
368 |
369 | criterion = LabelSmoothingCrossEntropy()
370 |
371 | if mixup_active:
372 | # smoothing is handled with mixup label transform
373 | criterion = SoftTargetCrossEntropy()
374 | elif args.smoothing:
375 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
376 | else:
377 | criterion = torch.nn.CrossEntropyLoss()
378 |
379 | if args.bce_loss:
380 | criterion = torch.nn.BCEWithLogitsLoss()
381 |
382 | teacher_model = None
383 | if args.distillation_type != 'none':
384 | assert args.teacher_path, 'need to specify teacher-path when using distillation'
385 | print(f"Creating teacher model: {args.teacher_model}")
386 | teacher_model = create_model(
387 | args.teacher_model,
388 | pretrained=False,
389 | num_classes=args.nb_classes,
390 | # global_pool='avg',
391 | )
392 | if args.teacher_path.startswith('https'):
393 | checkpoint = torch.hub.load_state_dict_from_url(
394 | args.teacher_path, map_location='cpu', check_hash=True)
395 | else:
396 | checkpoint = torch.load(args.teacher_path, map_location='cpu')
397 | teacher_model.load_state_dict(checkpoint['model'])
398 | teacher_model.to(device)
399 | teacher_model.eval()
400 |
401 | # wrap the criterion in our custom DistillationLoss, which
402 | # just dispatches to the original criterion if args.distillation_type is 'none'
403 | criterion = DistillationLoss(
404 | criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
405 | )
406 |
407 | output_dir = Path(args.output_dir)
408 | if args.resume:
409 | if args.resume.startswith('https'):
410 | checkpoint = torch.hub.load_state_dict_from_url(
411 | args.resume, map_location='cpu', check_hash=True)
412 | else:
413 | checkpoint = torch.load(args.resume, map_location='cpu')
414 | model_without_ddp.load_state_dict(checkpoint['model'])
415 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
416 | optimizer.load_state_dict(checkpoint['optimizer'])
417 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
418 | args.start_epoch = checkpoint['epoch'] + 1
419 | if args.model_ema:
420 | utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
421 | if 'scaler' in checkpoint:
422 | loss_scaler.load_state_dict(checkpoint['scaler'])
423 | lr_scheduler.step(args.start_epoch)
424 | if args.eval:
425 | test_stats = evaluate(data_loader_val, model, device, args.eval_ratio)
426 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
427 | return
428 |
429 | print(f"Start training for {args.epochs} epochs")
430 | start_time = time.time()
431 | max_accuracy = 0.0
432 | for epoch in range(args.start_epoch, args.epochs):
433 | if args.distributed:
434 | data_loader_train.sampler.set_epoch(epoch)
435 |
436 | train_stats = train_one_epoch(
437 | args.full_warm_epoch, args.distillation_type, args.ce_coefficient,
438 | args.largest_ratio, args.smallest_ratio, args.granularity,
439 | args.ce_type, args.distill_type, args.transfer_type, args.token_type,
440 | model, criterion, data_loader_train,
441 | optimizer, device, epoch, loss_scaler,
442 | args.clip_grad, model_ema, mixup_fn,
443 | set_training_mode=args.train_mode, # keep in eval mode for deit finetuning / train mode for training and deit III finetuning
444 | args = args,
445 | )
446 |
447 | lr_scheduler.step(epoch)
448 | if args.output_dir:
449 | checkpoint_paths = [output_dir / 'checkpoint.pth']
450 | for checkpoint_path in checkpoint_paths:
451 | utils.save_on_master({
452 | 'model': model_without_ddp.state_dict(),
453 | 'optimizer': optimizer.state_dict(),
454 | 'lr_scheduler': lr_scheduler.state_dict(),
455 | 'epoch': epoch,
456 | 'model_ema': get_state_dict(model_ema),
457 | 'scaler': loss_scaler.state_dict(),
458 | 'args': args,
459 | }, checkpoint_path)
460 |
461 |
462 | test_stats = evaluate(data_loader_val, model, device, 1.0)
463 | print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
464 |
465 | if max_accuracy < test_stats["acc1"]:
466 | max_accuracy = test_stats["acc1"]
467 | if args.output_dir:
468 | checkpoint_paths = [output_dir / 'best_checkpoint.pth']
469 | for checkpoint_path in checkpoint_paths:
470 | utils.save_on_master({
471 | 'model': model_without_ddp.state_dict(),
472 | 'optimizer': optimizer.state_dict(),
473 | 'lr_scheduler': lr_scheduler.state_dict(),
474 | 'epoch': epoch,
475 | 'model_ema': get_state_dict(model_ema),
476 | 'scaler': loss_scaler.state_dict(),
477 | 'args': args,
478 | }, checkpoint_path)
479 |
480 | print(f'Max accuracy: {max_accuracy:.2f}%')
481 |
482 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
483 | **{f'test_{k}': v for k, v in test_stats.items()},
484 | 'epoch': epoch,
485 | 'n_parameters': n_parameters}
486 |
487 |
488 | if args.output_dir and utils.is_main_process():
489 | with (output_dir / "log.txt").open("a") as f:
490 | f.write(json.dumps(log_stats) + "\n")
491 |
492 | if epoch == args.epochs - 1:
493 | range_num = int((args.largest_ratio - args.smallest_ratio) / args.granularity)
494 | for i in range(range_num):
495 | scale_ratio = args.smallest_ratio + i * args.granularity
496 | test_stats_end = evaluate(data_loader_val, model, device, scale_ratio)
497 | print(f"Width ratio: {scale_ratio}, Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
498 |
499 | log_stats_end = {'Width ratio': scale_ratio,
500 | 'Acc1': test_stats_end['acc1']}
501 |
502 | if args.output_dir and utils.is_main_process():
503 | with (output_dir / "log.txt").open("a") as f:
504 | f.write(json.dumps(log_stats_end) + "\n")
505 |
506 | total_time = time.time() - start_time
507 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
508 | print('Training time {}'.format(total_time_str))
509 |
510 |
511 | if __name__ == '__main__':
512 | parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
513 | args = parser.parse_args()
514 | if args.output_dir:
515 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
516 | main(args)
517 |
--------------------------------------------------------------------------------
/models_scala.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | from functools import partial
4 | from collections import OrderedDict
5 | from typing import Callable, List, Optional, Tuple, Union
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from enum import Enum
12 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
13 | from timm.models.helpers import build_model_with_cfg, named_apply, checkpoint_seq
14 | from timm.models.layers import DropPath, trunc_normal_, lecun_normal_, to_2tuple
15 | from timm.models.registry import register_model
16 |
17 | _logger = logging.getLogger(__name__)
18 |
19 |
20 | def _assert(condition: bool, message: str):
21 | assert condition, message
22 |
23 |
24 | class Format(str, Enum):
25 | NCHW = 'NCHW'
26 | NHWC = 'NHWC'
27 | NCL = 'NCL'
28 | NLC = 'NLC'
29 |
30 |
31 | def _cfg(url='', **kwargs):
32 | return {
33 | 'url': url,
34 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
35 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
36 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
37 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
38 | **kwargs
39 | }
40 |
41 |
42 | class Attention(nn.Module):
43 | def __init__(self, smallest_ratio, largest_ratio, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
44 | super().__init__()
45 | assert dim % num_heads == 0, 'dim should be divisible by num_heads'
46 |
47 | self.smallest_ratio = smallest_ratio
48 | self.largest_ratio = largest_ratio
49 |
50 | self.num_heads = num_heads
51 | head_dim = dim // num_heads
52 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
53 | self.scale = head_dim ** -0.5
54 |
55 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
56 | self.attn_drop = nn.Dropout(attn_drop)
57 | self.proj = nn.Linear(dim, dim)
58 | self.proj_drop = nn.Dropout(proj_drop)
59 |
60 | self.dim = dim
61 |
62 | def forward(self, x, ratio):
63 | dim_channels = int(ratio * self.dim)
64 |
65 | B, N, C = x.shape
66 |
67 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
68 | weight_qkv = self.qkv.weight[:dim_channels*3, :dim_channels]
69 | bias_qkv = self.qkv.bias[:dim_channels*3]
70 | else:
71 | weight_qkv = self.qkv.weight[-dim_channels*3:, -dim_channels:]
72 | bias_qkv = self.qkv.bias[-dim_channels*3:]
73 | qkv = F.linear(input=x, weight=weight_qkv, bias=bias_qkv)
74 | qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
75 | # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
76 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
77 |
78 | attn = (q @ k.transpose(-2, -1)) * self.scale
79 | attn = attn.softmax(dim=-1)
80 | attn = self.attn_drop(attn)
81 |
82 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
83 |
84 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
85 | weight_proj = self.proj.weight[:dim_channels, :dim_channels]
86 | bias_proj = self.proj.bias[:dim_channels]
87 | else:
88 | weight_proj = self.proj.weight[-dim_channels:, -dim_channels:]
89 | bias_proj = self.proj.bias[-dim_channels:]
90 | x = F.linear(input=x, weight=weight_proj, bias=bias_proj)
91 | x = self.proj_drop(x)
92 | return x
93 |
94 |
95 |
96 | class LayerScale(nn.Module):
97 | def __init__(self, smallest_ratio, largest_ratio, dim, init_values=1e-5, inplace=False):
98 | super().__init__()
99 |
100 | self.smallest_ratio = smallest_ratio
101 | self.largest_ratio = largest_ratio
102 |
103 | self.inplace = inplace
104 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
105 |
106 | self.dim = dim
107 |
108 | def forward(self, x, ratio):
109 | dim_channels = int(ratio * self.dim)
110 |
111 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
112 |             gamma = self.gamma[:dim_channels]  # slice the Parameter directly so it stays in the autograd graph
113 |         else:
114 |             gamma = self.gamma[-dim_channels:]
115 |         if self.inplace:
116 |             x = x.mul_(gamma)
117 |         else:
118 |             x = x * gamma  # fixed: the original `x * gamma` discarded the result
119 |
120 | return x
121 |
122 |
123 |
124 | class Mlp(nn.Module):
125 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks
126 | """
127 | def __init__(
128 | self,
129 | smallest_ratio,
130 | largest_ratio,
131 | in_features,
132 | hidden_features=None,
133 | out_features=None,
134 | act_layer=nn.GELU,
135 | norm_layer=None,
136 | bias=True,
137 | drop=0.,
138 | use_conv=False,
139 | ):
140 | super().__init__()
141 | out_features = out_features or in_features
142 | hidden_features = hidden_features or in_features
143 | bias = to_2tuple(bias)
144 | drop_probs = to_2tuple(drop)
145 | linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
146 |
147 | self.smallest_ratio = smallest_ratio
148 | self.largest_ratio = largest_ratio
149 |
150 | self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
151 | self.act = act_layer()
152 | self.drop1 = nn.Dropout(drop_probs[0])
153 | self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
154 | self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
155 | self.drop2 = nn.Dropout(drop_probs[1])
156 |
157 | self.in_features = in_features
158 | self.hidden_features = hidden_features
159 | self.out_features = out_features
160 |
161 | def forward(self, x, ratio):
162 | in_channels = int(ratio * self.in_features)
163 | hidden_channels = int(ratio * self.hidden_features)
164 | out_channels = int(ratio * self.out_features)
165 |
166 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
167 | weight1 = self.fc1.weight[:hidden_channels, :in_channels]
168 | bias1 = self.fc1.bias[:hidden_channels]
169 | else:
170 | weight1 = self.fc1.weight[-hidden_channels:, -in_channels:]
171 | bias1 = self.fc1.bias[-hidden_channels:]
172 | x = F.linear(input=x, weight=weight1, bias=bias1)
173 | x = self.act(x)
174 | x = self.drop1(x)
175 |
176 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
177 | weight2 = self.fc2.weight[:out_channels, :hidden_channels]
178 | bias2 = self.fc2.bias[:out_channels]
179 | else:
180 | weight2 = self.fc2.weight[-out_channels:, -hidden_channels:]
181 | bias2 = self.fc2.bias[-out_channels:]
182 | x = F.linear(input=x, weight=weight2, bias=bias2)
183 | x = self.drop2(x)
184 |
185 | return x
186 |
187 |
188 | class Block(nn.Module):
189 |
190 | def __init__(
191 | self, smallest_ratio, largest_ratio, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
192 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
193 | super().__init__()
194 | self.smallest_ratio = smallest_ratio
195 | self.largest_ratio = largest_ratio
196 | self.norm1 = norm_layer(dim)
197 | self.attn = Attention(smallest_ratio, largest_ratio, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
198 | self.ls1 = LayerScale(smallest_ratio, largest_ratio, dim, init_values=init_values) if init_values else nn.Identity()
199 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
200 | self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
201 |
202 | self.norm2 = norm_layer(dim)
203 | self.mlp = Mlp(smallest_ratio, largest_ratio, in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
204 | self.ls2 = LayerScale(smallest_ratio, largest_ratio, dim, init_values=init_values) if init_values else nn.Identity()
205 | self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
206 |
207 | self.dim = dim
208 | self.init_values = init_values
209 |
210 | def forward(self, x, ratio):
211 | dim_channels = int(ratio * self.dim)
212 |
213 | residual = x
214 |
215 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
216 | weight_norm1 = self.norm1.weight[:dim_channels]
217 | bias_norm1 = self.norm1.bias[:dim_channels]
218 | else:
219 | weight_norm1 = self.norm1.weight[-dim_channels:]
220 | bias_norm1 = self.norm1.bias[-dim_channels:]
221 | x = F.layer_norm(x, [dim_channels], weight_norm1, bias_norm1)
222 |
223 | x = self.attn(x, ratio)
224 | if self.init_values:
225 | x = self.ls1(x, ratio)
226 | else:
227 | x = self.ls1(x)
228 | x = residual + self.drop_path1(x)
229 |
230 | residual = x
231 |
232 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
233 | weight_norm2 = self.norm2.weight[:dim_channels]
234 | bias_norm2 = self.norm2.bias[:dim_channels]
235 | else:
236 | weight_norm2 = self.norm2.weight[-dim_channels:]
237 | bias_norm2 = self.norm2.bias[-dim_channels:]
238 | x = F.layer_norm(x, [dim_channels], weight_norm2, bias_norm2)
239 |
240 | x = self.mlp(x, ratio)
241 | if self.init_values:
242 | x = self.ls2(x, ratio)
243 | else:
244 | x = self.ls2(x)
245 | x = residual + self.drop_path2(x)
246 | return x
247 |
248 |
249 |
250 |
251 | class PatchEmbed(nn.Module):
252 | """ 2D Image to Patch Embedding
253 | """
254 | output_fmt: Format
255 |
256 | def __init__(
257 | self,
258 | smallest_ratio: float = 0.25,
259 | largest_ratio: float = 1.0,
260 | img_size: Optional[int] = 224,
261 | patch_size: int = 16,
262 | in_chans: int = 3,
263 | embed_dim: int = 768,
264 | norm_layer: Optional[Callable] = None,
265 | flatten: bool = True,
266 | output_fmt: Optional[str] = None,
267 | bias: bool = True,
268 | strict_img_size: bool = True,
269 | ):
270 | super().__init__()
271 | self.smallest_ratio = smallest_ratio
272 | self.largest_ratio = largest_ratio
273 | self.stride = patch_size
274 | self.patch_size = to_2tuple(patch_size)
275 | if img_size is not None:
276 | self.img_size = to_2tuple(img_size)
277 | self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
278 | self.num_patches = self.grid_size[0] * self.grid_size[1]
279 | else:
280 | self.img_size = None
281 | self.grid_size = None
282 | self.num_patches = None
283 |
284 | if output_fmt is not None:
285 | self.flatten = False
286 | self.output_fmt = Format(output_fmt)
287 | else:
288 | # flatten spatial dim and transpose to channels last, kept for bwd compat
289 | self.flatten = flatten
290 | self.output_fmt = Format.NCHW
291 | self.strict_img_size = strict_img_size
292 |
293 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
294 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
295 |
296 | self.norm_layer = norm_layer
297 | self.embed_dim = embed_dim
298 |
299 | def forward(self, x, ratio):
300 | embed_dim_channels = int(ratio*self.embed_dim)
301 |
302 | B, C, H, W = x.shape
303 | if self.img_size is not None:
304 | if self.strict_img_size:
305 | _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
306 | _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
307 | else:
308 | _assert(
309 | H % self.patch_size[0] == 0,
310 | f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
311 | )
312 | _assert(
313 | W % self.patch_size[1] == 0,
314 | f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
315 | )
316 |
317 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
318 | weight_proj = self.proj.weight[:embed_dim_channels, :, :, :]
319 | bias_proj = self.proj.bias[:embed_dim_channels]
320 | else:
321 | weight_proj = self.proj.weight[-embed_dim_channels:, :, :, :]
322 | bias_proj = self.proj.bias[-embed_dim_channels:]
323 | x = F.conv2d(x, weight_proj, bias_proj, self.stride)
324 |
325 | if self.flatten:
326 | x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
327 | elif self.output_fmt != Format.NCHW:
328 |             x = nchw_to(x, self.output_fmt)  # NOTE: nchw_to is not defined in this file (timm remnant); unreachable with the default NCHW output
329 |
330 | if self.norm_layer:
331 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
332 | weight_norm = self.norm.weight[:embed_dim_channels]
333 | bias_norm = self.norm.bias[:embed_dim_channels]
334 | else:
335 | weight_norm = self.norm.weight[-embed_dim_channels:]
336 | bias_norm = self.norm.bias[-embed_dim_channels:]
337 | x = F.layer_norm(x, [embed_dim_channels], weight_norm, bias_norm)
338 | else:
339 | x = self.norm(x)
340 |
341 | return x
342 |
343 |
344 | class VisionTransformer_scala(nn.Module):
345 | """ Vision Transformer
346 |
347 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
348 | - https://arxiv.org/abs/2010.11929
349 | """
350 |
351 | def __init__(
352 | self, img_size=224, smallest_ratio=0.25, largest_ratio=1.0, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
353 | embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
354 | class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
355 | weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
356 | """
357 | Args:
358 | img_size (int, tuple): input image size
359 | patch_size (int, tuple): patch size
360 | in_chans (int): number of input channels
361 | num_classes (int): number of classes for classification head
362 | global_pool (str): type of global pooling for final sequence (default: 'token')
363 | embed_dim (int): embedding dimension
364 | depth (int): depth of transformer
365 | num_heads (int): number of attention heads
366 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
367 | qkv_bias (bool): enable bias for qkv if True
368 | init_values: (float): layer-scale init values
369 | class_token (bool): use class token
370 | fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
371 | drop_rate (float): dropout rate
372 | attn_drop_rate (float): attention dropout rate
373 | drop_path_rate (float): stochastic depth rate
374 | weight_init (str): weight init scheme
375 | embed_layer (nn.Module): patch embedding layer
376 | norm_layer: (nn.Module): normalization layer
377 | act_layer: (nn.Module): MLP activation layer
378 | """
379 | super().__init__()
380 | assert global_pool in ('', 'avg', 'token')
381 | assert class_token or global_pool != 'token'
382 | use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
383 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
384 | act_layer = act_layer or nn.GELU
385 |
386 | self.num_classes = num_classes
387 | self.smallest_ratio = smallest_ratio
388 | self.largest_ratio = largest_ratio
389 | self.global_pool = global_pool
390 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
391 | self.num_prefix_tokens = 1 if class_token else 0
392 | self.no_embed_class = no_embed_class
393 | self.grad_checkpointing = False
394 |
395 | self.patch_embed = embed_layer(
396 | smallest_ratio=smallest_ratio, largest_ratio=largest_ratio, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
397 | num_patches = self.patch_embed.num_patches
398 |
399 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
400 | embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
401 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
402 | self.pos_drop = nn.Dropout(p=drop_rate)
403 |
404 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
405 | self.blocks = nn.Sequential(*[
406 | block_fn(
407 | smallest_ratio, largest_ratio, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
408 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
409 | for i in range(depth)])
410 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
411 |
412 | # Classifier Head
413 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
414 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
415 |
416 | self.use_fc_norm = use_fc_norm
417 |
418 | if weight_init != 'skip':
419 | self.init_weights(weight_init)
420 |
421 | def init_weights(self, mode=''):
422 | assert mode in ('jax', 'jax_nlhb', 'moco', '')
423 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
424 | trunc_normal_(self.pos_embed, std=.02)
425 | if self.cls_token is not None:
426 | nn.init.normal_(self.cls_token, std=1e-6)
427 | named_apply(get_init_weights_vit(mode, head_bias), self)
428 |
429 | def _init_weights(self, m):
430 | # this fn left here for compat with downstream users
431 | init_weights_vit_timm(m)
432 |
433 | @torch.jit.ignore()
434 | def load_pretrained(self, checkpoint_path, prefix=''):
435 |         _load_weights(self, checkpoint_path, prefix)  # NOTE: _load_weights is not defined in this file (timm remnant)
436 |
437 | @torch.jit.ignore
438 | def no_weight_decay(self):
439 | return {'pos_embed', 'cls_token', 'dist_token'}
440 |
441 | @torch.jit.ignore
442 | def group_matcher(self, coarse=False):
443 | return dict(
444 | stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
445 | blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
446 | )
447 |
448 | @torch.jit.ignore
449 | def set_grad_checkpointing(self, enable=True):
450 | self.grad_checkpointing = enable
451 |
452 | @torch.jit.ignore
453 | def get_classifier(self):
454 | return self.head
455 |
456 | def reset_classifier(self, num_classes: int, global_pool=None):
457 | self.num_classes = num_classes
458 | if global_pool is not None:
459 | assert global_pool in ('', 'avg', 'token')
460 | self.global_pool = global_pool
461 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
462 |
463 | def _pos_embed(self, x, ratio):
464 | if self.no_embed_class:
465 | # deit-3, updated JAX (big vision)
466 | # position embedding does not overlap with class token, add then concat
467 | if ratio == self.smallest_ratio:
468 | pos_embed = self.pos_embed[:,:,:int(ratio*self.num_features)]
469 | x = x + pos_embed
470 | if self.cls_token is not None:
471 | cls_token = self.cls_token[:,:,:int(ratio*self.num_features)]
472 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
473 | elif ratio == self.largest_ratio:
474 | x = x + self.pos_embed
475 | if self.cls_token is not None:
476 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
477 | else:
478 | pos_embed = self.pos_embed[:,:,-int(ratio*self.num_features):]
479 | x = x + pos_embed
480 | if self.cls_token is not None:
481 | cls_token = self.cls_token[:,:,-int(ratio*self.num_features):]
482 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
483 | else:
484 | # original timm, JAX, and deit vit impl
485 | # pos_embed has entry for class token, concat then add
486 | if ratio == self.smallest_ratio:
487 | if self.cls_token is not None:
488 | cls_token = self.cls_token[:,:,:int(ratio*self.num_features)]
489 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
490 | pos_embed = self.pos_embed[:,:,:int(ratio*self.num_features)]
491 | x = x + pos_embed
492 | elif ratio == self.largest_ratio:
493 | if self.cls_token is not None:
494 | x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
495 | x = x + self.pos_embed
496 | else:
497 | if self.cls_token is not None:
498 | cls_token = self.cls_token[:,:,-int(ratio*self.num_features):]
499 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
500 | pos_embed = self.pos_embed[:,:,-int(ratio*self.num_features):]
501 | x = x + pos_embed
502 |
503 | return self.pos_drop(x)
504 |
505 | def forward_features(self, x, ratio):
506 | x = self.patch_embed(x, ratio)
507 | x = self._pos_embed(x, ratio)
508 | for blk in self.blocks:
509 | if self.grad_checkpointing and not torch.jit.is_scripting():
510 |                 x = checkpoint_seq(blk, x, ratio)  # NOTE: checkpoint_seq would treat ratio as its `every` argument; grad checkpointing is off by default
511 | else:
512 | x = blk(x, ratio)
513 |
514 | if self.use_fc_norm:
515 | x = self.norm(x)
516 | else:
517 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
518 | weight_norm = self.norm.weight[:int(ratio*self.embed_dim)]
519 | bias_norm = self.norm.bias[:int(ratio*self.embed_dim)]
520 | else:
521 | weight_norm = self.norm.weight[-int(ratio*self.embed_dim):]
522 | bias_norm = self.norm.bias[-int(ratio*self.embed_dim):]
523 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_norm, bias_norm)
524 | return x
525 |
526 | def forward_head(self, x, ratio, pre_logits: bool = False):
527 | if self.global_pool:
528 | x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
529 |
530 | if self.use_fc_norm:
531 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
532 | weight_fc_norm = self.fc_norm.weight[:int(ratio*self.embed_dim)]
533 | bias_fc_norm = self.fc_norm.bias[:int(ratio*self.embed_dim)]
534 | else:
535 | weight_fc_norm = self.fc_norm.weight[-int(ratio*self.embed_dim):]
536 | bias_fc_norm = self.fc_norm.bias[-int(ratio*self.embed_dim):]
537 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_fc_norm, bias_fc_norm)
538 | else:
539 | x = self.fc_norm(x)
540 |
541 | if pre_logits:
542 |             x = x  # return pre-logit features unchanged
543 | else:
544 | if ratio == self.smallest_ratio or ratio == self.largest_ratio:
545 | weight_head = self.head.weight[:, :int(ratio*self.embed_dim)]
546 | bias_head = self.head.bias[:]
547 | else:
548 | weight_head = self.head.weight[:, -int(ratio*self.embed_dim):]
549 | bias_head = self.head.bias[:]
550 | x = F.linear(input=x, weight=weight_head, bias=bias_head)
551 |
552 | return x
553 |
554 | def forward(self, x, ratio):
555 | x = self.forward_features(x, ratio)
556 | x = self.forward_head(x, ratio)
557 | return x
558 |
559 |
560 | def init_weights_vit_timm(module: nn.Module, name: str = ''):
561 | """ ViT weight initialization, original timm impl (for reproducibility) """
562 | if isinstance(module, nn.Linear):
563 | trunc_normal_(module.weight, std=.02)
564 | if module.bias is not None:
565 | nn.init.zeros_(module.bias)
566 | elif hasattr(module, 'init_weights'):
567 | module.init_weights()
568 |
569 |
570 | def get_init_weights_vit(mode='jax', head_bias: float = 0.):
571 | if 'jax' in mode:
572 |         return partial(init_weights_vit_jax, head_bias=head_bias)  # NOTE: the jax/moco init fns are not defined in this file; only the default timm init is used
573 | elif 'moco' in mode:
574 | return init_weights_vit_moco
575 | else:
576 | return init_weights_vit_timm
577 |
578 |
579 |
580 |
581 | class DistilledVisionTransformer_scala(VisionTransformer_scala):
582 | def __init__(self, *args, **kwargs):
583 | super().__init__(*args, **kwargs)
584 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
585 | num_patches = self.patch_embed.num_patches
586 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
587 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
588 |
589 | trunc_normal_(self.dist_token, std=.02)
590 | trunc_normal_(self.pos_embed, std=.02)
591 | self.head_dist.apply(self._init_weights)
592 |
593 | def forward_features(self, x, ratio):
594 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
595 | # with slight modifications to add the dist_token
596 | B = x.shape[0]
597 | x = self.patch_embed(x, ratio)
598 |
599 |         if ratio == self.smallest_ratio or ratio == self.largest_ratio:
600 | cls_token = self.cls_token[:,:,:int(ratio*self.num_features)]
601 | dist_token = self.dist_token[:,:,:int(ratio*self.num_features)]
602 | x = torch.cat((cls_token.expand(B, -1, -1), dist_token.expand(B, -1, -1), x), dim=1)
603 | pos_embed = self.pos_embed[:,:,:int(ratio*self.num_features)]
604 | x = x + pos_embed
605 | else:
606 | cls_token = self.cls_token[:,:,-int(ratio*self.num_features):]
607 | dist_token = self.dist_token[:,:,-int(ratio*self.num_features):]
608 | x = torch.cat((cls_token.expand(B, -1, -1), dist_token.expand(B, -1, -1), x), dim=1)
609 | pos_embed = self.pos_embed[:,:,-int(ratio*self.num_features):]
610 | x = x + pos_embed
611 |
612 | x = self.pos_drop(x)
613 |
614 | for blk in self.blocks:
615 | x = blk(x, ratio)
616 |
617 |         if ratio == self.smallest_ratio or ratio == self.largest_ratio:
618 | weight_norm = self.norm.weight[:int(ratio*self.embed_dim)]
619 | bias_norm = self.norm.bias[:int(ratio*self.embed_dim)]
620 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_norm, bias_norm)
621 | else:
622 | weight_norm = self.norm.weight[-int(ratio*self.embed_dim):]
623 | bias_norm = self.norm.bias[-int(ratio*self.embed_dim):]
624 | x = F.layer_norm(x, [int(ratio*self.embed_dim)], weight_norm, bias_norm)
625 |
626 | return x[:, 0], x[:, 1]
627 |
628 | def forward(self, x, ratio):
629 | x, x_dist = self.forward_features(x, ratio)
630 |
631 |         if ratio == self.smallest_ratio or ratio == self.largest_ratio:
632 | weight_head = self.head.weight[:, :int(ratio*self.embed_dim)]
633 | bias_head = self.head.bias[:]
634 | x = F.linear(input=x, weight=weight_head, bias=bias_head)
635 | else:
636 | weight_head = self.head.weight[:, -int(ratio*self.embed_dim):]
637 | bias_head = self.head.bias[:]
638 | x = F.linear(input=x, weight=weight_head, bias=bias_head)
639 |
640 |         if ratio == self.smallest_ratio or ratio == self.largest_ratio:
641 | weight_head_dist = self.head_dist.weight[:, :int(ratio*self.embed_dim)]
642 | bias_head_dist = self.head_dist.bias[:]
643 | x_dist = F.linear(input=x_dist, weight=weight_head_dist, bias=bias_head_dist)
644 | else:
645 | weight_head_dist = self.head_dist.weight[:, -int(ratio*self.embed_dim):]
646 | bias_head_dist = self.head_dist.bias[:]
647 | x_dist = F.linear(input=x_dist, weight=weight_head_dist, bias=bias_head_dist)
648 | if self.training:
649 | return x, x_dist
650 | else:
651 | # during inference, return the average of both classifier predictions
652 | return (x + x_dist) / 2
653 |
654 |
655 |
656 | @register_model
657 | def deit_so_tiny_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
658 | model = VisionTransformer_scala(
659 | patch_size=16, embed_dim=96, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
660 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
661 | model.default_cfg = _cfg()
662 | return model
663 |
664 |
665 | @register_model
666 | def deit_tiny_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
667 | model = VisionTransformer_scala(
668 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
669 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
670 | model.default_cfg = _cfg()
671 | return model
672 |
673 |
674 | @register_model
675 | def deit_small_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
676 | model = VisionTransformer_scala(
677 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
678 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
679 | model.default_cfg = _cfg()
680 | return model
681 |
682 |
683 | @register_model
684 | def deit_base_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
685 | model = VisionTransformer_scala(
686 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
687 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
688 | model.default_cfg = _cfg()
689 | return model
690 |
691 |
692 | @register_model
693 | def deit_tiny_distilled_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
694 | model = DistilledVisionTransformer_scala(
695 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
696 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
697 | model.default_cfg = _cfg()
698 | return model
699 |
700 |
701 | @register_model
702 | def deit_small_distilled_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
703 | model = DistilledVisionTransformer_scala(
704 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
705 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
706 | model.default_cfg = _cfg()
707 | return model
708 |
709 |
710 | @register_model
711 | def deit_base_distilled_patch16_224_scala(pretrained=False, pretrained_cfg=None, **kwargs):
712 | model = DistilledVisionTransformer_scala(
713 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
714 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
715 | model.default_cfg = _cfg()
716 | return model
717 |
718 |
719 | @register_model
720 | def deit_base_patch16_384_scala(pretrained=False, pretrained_cfg=None, **kwargs):
721 | model = VisionTransformer_scala(
722 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
723 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
724 | model.default_cfg = _cfg()
725 | return model
726 |
727 |
728 | @register_model
729 | def deit_base_distilled_patch16_384_scala(pretrained=False, pretrained_cfg=None, **kwargs):
730 | model = DistilledVisionTransformer_scala(
731 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
732 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
733 | model.default_cfg = _cfg()
734 | return model
735 |
736 |
737 |
738 | if __name__ == "__main__":
739 | import argparse
740 | from fvcore.nn import FlopCountAnalysis, parameter_count_table
741 |     parser = argparse.ArgumentParser(description='Scala FLOPs counting')
742 | args = parser.parse_args()
743 |
744 | args.num_classes = 1000
745 | with torch.no_grad():
746 | model = deit_small_distilled_patch16_224_scala()
747 |
748 | # for name, param in model.named_parameters():
749 | # print(name)
750 |
751 | tensor = (torch.rand(1, 3, 224, 224), 0.5)
752 | flops = FlopCountAnalysis(model, tensor)
753 | print("FLOPs: ", flops.total()/1e9)
--------------------------------------------------------------------------------
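Every module in `models_scala.py` takes an extra `ratio` argument and slices its own weight matrices on the fly, so a single set of parameters serves every width; note the convention that the smallest and the full sub-network read the leading channels of each matrix while intermediate widths read the trailing ones. A minimal usage sketch (assumes the file imports cleanly, i.e. a timm version that provides the helpers imported at its top; no pretrained weights loaded):

```python
import torch
from models_scala import deit_small_patch16_224_scala

model = deit_small_patch16_224_scala()  # embed_dim=384; a ratio keeps int(ratio * 384) channels
model.eval()
x = torch.rand(2, 3, 224, 224)
with torch.no_grad():
    for ratio in (0.25, 0.5, 0.75, 1.0):
        logits = model(x, ratio)            # same weights, narrower slices
        print(ratio, tuple(logits.shape))   # always (2, 1000): only the head input is sliced
```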
/run.sh:
--------------------------------------------------------------------------------
1 | ### scala on deit_s for 100 epochs
2 |
3 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py \
4 | --batch-size 256 --epochs 100 --data-path IMAGENET_LOCATION \
5 | --aa rand-m1-mstd0.5-inc1 --no-repeated-aug --lr 2e-3 --warmup-epochs 3 \
6 | --teacher-model deit_small_patch16_224 --distillation-type hard \
7 | --teacher-path https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth \
8 | --model deit_small_distilled_patch16_224_scala \
9 | --smallest_ratio 0.25 --largest_ratio 1.0 --granularity 0.0625 --distill_type soft \
10 | --transfer_type progressive --token_type dist_token --ce_coefficient 1.0 \
11 | --full_warm_epoch 10 --output_dir log/deit_small_distilled_scala;
12 |
13 |
14 |
15 | ### scala on deit_b for 300 epochs
16 |
17 | python -m torch.distributed.launch --nproc_per_node=4 --use_env main_scala.py \
18 | --batch-size 128 --epochs 300 --data-path IMAGENET_LOCATION \
19 | --aa rand-m4-mstd0.5-inc1 --no-repeated-aug --lr 1e-3 --warmup-epochs 5 \
20 | --teacher-model deit_base_patch16_224 --distillation-type hard \
21 | --teacher-path https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth \
22 | --model deit_base_distilled_patch16_224_scala \
23 | --smallest_ratio 0.25 --largest_ratio 1.0 --granularity 0.0625 --distill_type soft \
24 | --transfer_type progressive --token_type dist_token --ce_coefficient 1.0 \
25 | --full_warm_epoch 10 --output_dir log/deit_base_distilled_scala;
--------------------------------------------------------------------------------
/samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.distributed as dist
5 | import math
6 |
7 |
8 | class RASampler(torch.utils.data.Sampler):
9 |     """Sampler that restricts data loading to a subset of the dataset for distributed training,
10 |     with repeated augmentation.
11 |     It ensures that each augmented version of a sample is visible to a
12 |     different process (GPU).
13 | Heavily based on torch.utils.data.DistributedSampler
14 | """
15 |
16 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, num_repeats: int = 3):
17 | if num_replicas is None:
18 | if not dist.is_available():
19 | raise RuntimeError("Requires distributed package to be available")
20 | num_replicas = dist.get_world_size()
21 | if rank is None:
22 | if not dist.is_available():
23 | raise RuntimeError("Requires distributed package to be available")
24 | rank = dist.get_rank()
25 | if num_repeats < 1:
26 | raise ValueError("num_repeats should be greater than 0")
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.num_repeats = num_repeats
31 | self.epoch = 0
32 | self.num_samples = int(math.ceil(len(self.dataset) * self.num_repeats / self.num_replicas))
33 | self.total_size = self.num_samples * self.num_replicas
34 | # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
35 |         self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))  # floor to a multiple of 256, then split across replicas
36 | self.shuffle = shuffle
37 |
38 | def __iter__(self):
39 | if self.shuffle:
40 | # deterministically shuffle based on epoch
41 | g = torch.Generator()
42 | g.manual_seed(self.epoch)
43 | indices = torch.randperm(len(self.dataset), generator=g)
44 | else:
45 | indices = torch.arange(start=0, end=len(self.dataset))
46 |
47 | # add extra samples to make it evenly divisible
48 | indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0).tolist()
49 | padding_size: int = self.total_size - len(indices)
50 | if padding_size > 0:
51 | indices += indices[:padding_size]
52 | assert len(indices) == self.total_size
53 |
54 | # subsample
55 | indices = indices[self.rank:self.total_size:self.num_replicas]
56 | assert len(indices) == self.num_samples
57 |
58 | return iter(indices[:self.num_selected_samples])
59 |
60 | def __len__(self):
61 | return self.num_selected_samples
62 |
63 | def set_epoch(self, epoch):
64 | self.epoch = epoch
65 |
--------------------------------------------------------------------------------
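The subsampling at the end of `RASampler.__iter__` interleaves the repeated copies across ranks. A toy walk-through under stated assumptions (8 samples, 2 replicas, 3 repeats, no shuffling); it mirrors the code above rather than importing it:

```python
import math
import torch

dataset_len, num_replicas, num_repeats = 8, 2, 3
num_samples = int(math.ceil(dataset_len * num_repeats / num_replicas))  # 12 per rank
total_size = num_samples * num_replicas                                 # 24
indices = torch.repeat_interleave(torch.arange(dataset_len), repeats=num_repeats, dim=0).tolist()
indices += indices[:total_size - len(indices)]                          # pad to total_size
for rank in range(num_replicas):
    print(rank, indices[rank:total_size:num_replicas])
# rank 0 -> [0, 0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7]
# rank 1 -> [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 7, 7]
# The augmented copies of each sample are spread over the two GPUs; the real
# sampler additionally truncates to num_selected_samples per epoch.
```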
/scheduler.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import logging
3 | import math
4 | from abc import ABC
5 | from typing import Any, Dict, List, Optional, Union
6 |
7 | import torch
8 | from torch.optim import Optimizer
9 |
10 |
11 |
12 |
13 | _logger = logging.getLogger(__name__)
14 |
15 |
16 | def scheduler_kwargs(cfg):
17 | """ cfg/argparse to kwargs helper
18 | Convert scheduler args in argparse args or cfg (.dot) like object to keyword args.
19 | """
20 | eval_metric = getattr(cfg, 'eval_metric', 'top1')
21 | plateau_mode = 'min' if 'loss' in eval_metric else 'max'
22 | kwargs = dict(
23 | sched=cfg.sched,
24 | num_epochs=getattr(cfg, 'epochs', 100),
25 | decay_epochs=getattr(cfg, 'decay_epochs', 30),
26 | decay_milestones=getattr(cfg, 'decay_milestones', [30, 60]),
27 | warmup_epochs=getattr(cfg, 'warmup_epochs', 5),
28 | warm_epoch=getattr(cfg, 'warm_epoch', 10),
29 | warm_factor=getattr(cfg, 'warm_factor', 2.0),
30 | cooldown_epochs=getattr(cfg, 'cooldown_epochs', 0),
31 | patience_epochs=getattr(cfg, 'patience_epochs', 10),
32 | decay_rate=getattr(cfg, 'decay_rate', 0.1),
33 | min_lr=getattr(cfg, 'min_lr', 0.),
34 | warmup_lr=getattr(cfg, 'warmup_lr', 1e-5),
35 | warmup_prefix=getattr(cfg, 'warmup_prefix', False),
36 | noise=getattr(cfg, 'lr_noise', None),
37 | noise_pct=getattr(cfg, 'lr_noise_pct', 0.67),
38 | noise_std=getattr(cfg, 'lr_noise_std', 1.),
39 | noise_seed=getattr(cfg, 'seed', 42),
40 | cycle_mul=getattr(cfg, 'lr_cycle_mul', 1.),
41 | cycle_decay=getattr(cfg, 'lr_cycle_decay', 0.1),
42 | cycle_limit=getattr(cfg, 'lr_cycle_limit', 1),
43 | k_decay=getattr(cfg, 'lr_k_decay', 1.0),
44 | plateau_mode=plateau_mode,
45 | step_on_epochs=not getattr(cfg, 'sched_on_updates', False),
46 | )
47 | return kwargs
48 |
49 |
50 | def create_scheduler(
51 | args,
52 | optimizer: Optimizer,
53 | updates_per_epoch: int = 0,
54 | ):
55 | return create_scheduler_v2(
56 | optimizer=optimizer,
57 | **scheduler_kwargs(args),
58 | updates_per_epoch=updates_per_epoch,
59 | )
60 |
61 |
62 | def create_scheduler_v2(
63 | optimizer: Optimizer,
64 | sched: str = 'cosine',
65 | num_epochs: int = 300,
66 | decay_epochs: int = 90,
67 | decay_milestones: List[int] = (90, 180, 270),
68 | cooldown_epochs: int = 0,
69 | patience_epochs: int = 10,
70 | decay_rate: float = 0.1,
71 | min_lr: float = 0,
72 | warmup_lr: float = 1e-5,
73 | warm_epoch: int = 0,
74 | warm_factor: float = 1.0,
75 | warmup_epochs: int = 0,
76 | warmup_prefix: bool = False,
77 | noise: Union[float, List[float]] = None,
78 | noise_pct: float = 0.67,
79 | noise_std: float = 1.,
80 | noise_seed: int = 42,
81 | cycle_mul: float = 1.,
82 | cycle_decay: float = 0.1,
83 | cycle_limit: int = 1,
84 | k_decay: float = 1.0,
85 | plateau_mode: str = 'max',
86 | step_on_epochs: bool = True,
87 | updates_per_epoch: int = 0,
88 | ):
89 | t_initial = num_epochs
90 | warmup_t = warmup_epochs
91 | decay_t = decay_epochs
92 | cooldown_t = cooldown_epochs
93 |
94 | if not step_on_epochs:
95 | assert updates_per_epoch > 0, 'updates_per_epoch must be set to number of dataloader batches'
96 | t_initial = t_initial * updates_per_epoch
97 | warmup_t = warmup_t * updates_per_epoch
98 | decay_t = decay_t * updates_per_epoch
99 | decay_milestones = [d * updates_per_epoch for d in decay_milestones]
100 | cooldown_t = cooldown_t * updates_per_epoch
101 |
102 | # warmup args
103 | warmup_args = dict(
104 | warmup_lr_init=warmup_lr,
105 | warmup_t=warmup_t,
106 |         warm_epoch=warm_epoch,
107 |         warm_factor=warm_factor,
108 | warmup_prefix=warmup_prefix,
109 | )
110 |
111 | # setup noise args for supporting schedulers
112 | if noise is not None:
113 | if isinstance(noise, (list, tuple)):
114 | noise_range = [n * t_initial for n in noise]
115 | if len(noise_range) == 1:
116 | noise_range = noise_range[0]
117 | else:
118 | noise_range = noise * t_initial
119 | else:
120 | noise_range = None
121 | noise_args = dict(
122 | noise_range_t=noise_range,
123 | noise_pct=noise_pct,
124 | noise_std=noise_std,
125 | noise_seed=noise_seed,
126 | )
127 |
128 | # setup cycle args for supporting schedulers
129 | cycle_args = dict(
130 | cycle_mul=cycle_mul,
131 | cycle_decay=cycle_decay,
132 | cycle_limit=cycle_limit,
133 | )
134 |
135 |     lr_scheduler = CosineLRScheduler(  # NOTE: the sched argument is ignored; this repo only builds a cosine schedule
136 | optimizer,
137 | t_initial=t_initial,
138 | lr_min=min_lr,
139 | t_in_epochs=step_on_epochs,
140 | **cycle_args,
141 | **warmup_args,
142 | **noise_args,
143 | k_decay=k_decay,
144 | )
145 |
146 |
147 | if hasattr(lr_scheduler, 'get_cycle_length'):
148 | # for cycle based schedulers (cosine, tanh, poly) recalculate total epochs w/ cycles & cooldown
149 | t_with_cycles_and_cooldown = lr_scheduler.get_cycle_length() + cooldown_t
150 | if step_on_epochs:
151 | num_epochs = t_with_cycles_and_cooldown
152 | else:
153 | num_epochs = t_with_cycles_and_cooldown // updates_per_epoch
154 |
155 | return lr_scheduler, num_epochs
156 |
157 |
158 |
159 |
160 | class Scheduler(ABC):
161 | """ Parameter Scheduler Base Class
162 | A scheduler base class that can be used to schedule any optimizer parameter groups.
163 |
164 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called
165 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value
166 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value
167 |
168 | The schedulers built on this should try to remain as stateless as possible (for simplicity).
169 |
170 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch'
171 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training
172 | code and explicitly passed in to the schedulers on the corresponding step or step_update call.
173 |
174 | Based on ideas from:
175 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler
176 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers
177 | """
178 |
179 | def __init__(
180 | self,
181 | optimizer: torch.optim.Optimizer,
182 | param_group_field: str,
183 | t_in_epochs: bool = True,
184 | noise_range_t=None,
185 | noise_type='normal',
186 | noise_pct=0.67,
187 | noise_std=1.0,
188 | noise_seed=None,
189 | initialize: bool = True,
190 | ) -> None:
191 | self.optimizer = optimizer
192 | self.param_group_field = param_group_field
193 | self._initial_param_group_field = f"initial_{param_group_field}"
194 | if initialize:
195 | for i, group in enumerate(self.optimizer.param_groups):
196 | if param_group_field not in group:
197 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
198 | group.setdefault(self._initial_param_group_field, group[param_group_field])
199 | else:
200 | for i, group in enumerate(self.optimizer.param_groups):
201 | if self._initial_param_group_field not in group:
202 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
203 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
204 | self.metric = None # any point to having this for all?
205 | self.t_in_epochs = t_in_epochs
206 | self.noise_range_t = noise_range_t
207 | self.noise_pct = noise_pct
208 | self.noise_type = noise_type
209 | self.noise_std = noise_std
210 | self.noise_seed = noise_seed if noise_seed is not None else 42
211 | self.update_groups(self.base_values)
212 |
213 | def state_dict(self) -> Dict[str, Any]:
214 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
215 |
216 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
217 | self.__dict__.update(state_dict)
218 |
219 | @abc.abstractmethod
220 | def _get_lr(self, t: int) -> float:
221 | pass
222 |
223 | def _get_values(self, t: int, on_epoch: bool = True) -> Optional[float]:
224 | proceed = (on_epoch and self.t_in_epochs) or (not on_epoch and not self.t_in_epochs)
225 | if not proceed:
226 | return None
227 | return self._get_lr(t)
228 |
229 | def step(self, epoch: int, metric: float = None) -> None:
230 | self.metric = metric
231 | values = self._get_values(epoch, on_epoch=True)
232 | if values is not None:
233 | values = self._add_noise(values, epoch)
234 | self.update_groups(values)
235 |
236 | def step_update(self, num_updates: int, metric: float = None):
237 | self.metric = metric
238 | values = self._get_values(num_updates, on_epoch=False)
239 | if values is not None:
240 | values = self._add_noise(values, num_updates)
241 | self.update_groups(values)
242 |
243 | def update_groups(self, values):
244 | if not isinstance(values, (list, tuple)):
245 | values = [values] * len(self.optimizer.param_groups)
246 | for param_group, value in zip(self.optimizer.param_groups, values):
247 | if 'lr_scale' in param_group:
248 | param_group[self.param_group_field] = value * param_group['lr_scale']
249 | else:
250 | param_group[self.param_group_field] = value
251 |
252 | def _add_noise(self, lrs, t):
253 | if self._is_apply_noise(t):
254 | noise = self._calculate_noise(t)
255 | lrs = [v + v * noise for v in lrs]
256 | return lrs
257 |
258 | def _is_apply_noise(self, t) -> bool:
259 | """Return True if scheduler in noise range."""
260 | apply_noise = False
261 | if self.noise_range_t is not None:
262 | if isinstance(self.noise_range_t, (list, tuple)):
263 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
264 | else:
265 | apply_noise = t >= self.noise_range_t
266 | return apply_noise
267 |
268 | def _calculate_noise(self, t) -> float:
269 | g = torch.Generator()
270 | g.manual_seed(self.noise_seed + t)
271 | if self.noise_type == 'normal':
272 | while True:
273 | # resample if noise out of percent limit, brute force but shouldn't spin much
274 | noise = torch.randn(1, generator=g).item()
275 | if abs(noise) < self.noise_pct:
276 | return noise
277 | else:
278 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
279 | return noise
280 |
281 |
282 |
283 |
284 | class CosineLRScheduler(Scheduler):
285 | """
286 | Cosine decay with restarts.
287 | This is described in the paper https://arxiv.org/abs/1608.03983.
288 |
289 | Inspiration from
290 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
291 |
292 | k-decay option based on `k-decay: A New Method For Learning Rate Schedule` - https://arxiv.org/abs/2004.05909
293 | """
294 |
295 | def __init__(
296 | self,
297 | optimizer: torch.optim.Optimizer,
298 | t_initial: int,
299 | lr_min: float = 0.,
300 | cycle_mul: float = 1.,
301 | cycle_decay: float = 1.,
302 | cycle_limit: int = 1,
303 | warmup_t=0,
304 | warm_epoch=0,
305 | warm_factor=1.0,
306 | warmup_lr_init=0,
307 | warmup_prefix=False,
308 | t_in_epochs=True,
309 | noise_range_t=None,
310 | noise_pct=0.67,
311 | noise_std=1.0,
312 | noise_seed=42,
313 | k_decay=1.0,
314 | initialize=True,
315 | ) -> None:
316 | super().__init__(
317 | optimizer,
318 | param_group_field="lr",
319 | t_in_epochs=t_in_epochs,
320 | noise_range_t=noise_range_t,
321 | noise_pct=noise_pct,
322 | noise_std=noise_std,
323 | noise_seed=noise_seed,
324 | initialize=initialize,
325 | )
326 |
327 | assert t_initial > 0
328 | assert lr_min >= 0
329 | if t_initial == 1 and cycle_mul == 1 and cycle_decay == 1:
330 | _logger.warning(
331 | "Cosine annealing scheduler will have no effect on the learning "
332 | "rate since t_initial = t_mul = eta_mul = 1.")
333 | self.t_initial = t_initial
334 | self.lr_min = lr_min
335 | self.cycle_mul = cycle_mul
336 | self.cycle_decay = cycle_decay
337 | self.cycle_limit = cycle_limit
338 | self.warmup_t = warmup_t
339 | self.warm_epoch = warm_epoch
340 | self.warm_factor = warm_factor
341 | self.warmup_lr_init = warmup_lr_init
342 | self.warmup_prefix = warmup_prefix
343 | self.k_decay = k_decay
344 | if self.warmup_t:
345 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
346 | super().update_groups(self.warmup_lr_init)
347 | else:
348 | self.warmup_steps = [1 for _ in self.base_values]
349 |
350 | def _get_lr(self, t):
351 | if t < self.warmup_t:
352 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
353 | else:
354 | if self.warmup_prefix:
355 | t = t - self.warmup_t
356 |
357 | if self.cycle_mul != 1:
358 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.cycle_mul), self.cycle_mul))
359 | t_i = self.cycle_mul ** i * self.t_initial
360 | t_curr = t - (1 - self.cycle_mul ** i) / (1 - self.cycle_mul) * self.t_initial
361 | else:
362 | i = t // self.t_initial
363 | t_i = self.t_initial
364 | t_curr = t - (self.t_initial * i)
365 |
366 | gamma = self.cycle_decay ** i
367 | lr_max_values = [v * gamma for v in self.base_values]
368 | k = self.k_decay
369 |
370 | if i < self.cycle_limit:
371 | lrs = [
372 | self.lr_min + 0.5 * (lr_max - self.lr_min) * (1 + math.cos(math.pi * t_curr ** k / t_i ** k))
373 | for lr_max in lr_max_values
374 | ]
375 | else:
376 | lrs = [self.lr_min for _ in self.base_values]
377 |
378 |         if t < self.warm_epoch:  # run at lr / warm_factor for the first warm_epoch epochs
379 | lrs = [lr / self.warm_factor for lr in lrs]
380 |
381 | return lrs
382 |
383 | def get_cycle_length(self, cycles=0):
384 | cycles = max(1, cycles or self.cycle_limit)
385 | if self.cycle_mul == 1.0:
386 | return self.t_initial * cycles
387 | else:
388 | return int(math.floor(-self.t_initial * (self.cycle_mul ** cycles - 1) / (1 - self.cycle_mul)))
--------------------------------------------------------------------------------
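A minimal sketch of the schedule defined above, showing cosine decay plus the repo-specific `warm_epoch`/`warm_factor` damping (assumes this file is importable as `scheduler`; the hyperparameter values are illustrative):

```python
import torch
from scheduler import CosineLRScheduler

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-3)
sched = CosineLRScheduler(opt, t_initial=100, lr_min=1e-5,
                          warmup_t=3, warmup_lr_init=1e-5,
                          warm_epoch=10, warm_factor=2.0)
for epoch in (0, 3, 9, 10, 50, 99):
    sched.step(epoch)   # epochs below warm_epoch run at lr / warm_factor
    print(epoch, opt.param_groups[0]['lr'])
```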
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Misc functions, including distributed helpers.
5 |
6 | Mostly copy-paste from torchvision references.
7 | """
8 | import io
9 | import os
10 | import time
11 | from collections import defaultdict, deque
12 | import datetime
13 |
14 | import torch
15 | import torch.distributed as dist
16 |
17 |
18 | class SmoothedValue(object):
19 | """Track a series of values and provide access to smoothed values over a
20 | window or the global series average.
21 | """
22 |
23 | def __init__(self, window_size=20, fmt=None):
24 | if fmt is None:
25 | fmt = "{median:.4f} ({global_avg:.4f})"
26 | self.deque = deque(maxlen=window_size)
27 | self.total = 0.0
28 | self.count = 0
29 | self.fmt = fmt
30 |
31 | def update(self, value, n=1):
32 | self.deque.append(value)
33 | self.count += n
34 | self.total += value * n
35 |
36 | def synchronize_between_processes(self):
37 | """
38 | Warning: does not synchronize the deque!
39 | """
40 | if not is_dist_avail_and_initialized():
41 | return
42 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
43 | dist.barrier()
44 | dist.all_reduce(t)
45 | t = t.tolist()
46 | self.count = int(t[0])
47 | self.total = t[1]
48 |
49 | @property
50 | def median(self):
51 | d = torch.tensor(list(self.deque))
52 | return d.median().item()
53 |
54 | @property
55 | def avg(self):
56 | d = torch.tensor(list(self.deque), dtype=torch.float32)
57 | return d.mean().item()
58 |
59 | @property
60 | def global_avg(self):
61 | return self.total / self.count
62 |
63 | @property
64 | def max(self):
65 | return max(self.deque)
66 |
67 | @property
68 | def value(self):
69 | return self.deque[-1]
70 |
71 | def __str__(self):
72 | return self.fmt.format(
73 | median=self.median,
74 | avg=self.avg,
75 | global_avg=self.global_avg,
76 | max=self.max,
77 | value=self.value)
78 |
79 |
80 | class MetricLogger(object):
81 | def __init__(self, delimiter="\t"):
82 | self.meters = defaultdict(SmoothedValue)
83 | self.delimiter = delimiter
84 |
85 | def update(self, **kwargs):
86 | for k, v in kwargs.items():
87 | if isinstance(v, torch.Tensor):
88 | v = v.item()
89 | assert isinstance(v, (float, int))
90 | self.meters[k].update(v)
91 |
92 | def __getattr__(self, attr):
93 | if attr in self.meters:
94 | return self.meters[attr]
95 | if attr in self.__dict__:
96 | return self.__dict__[attr]
97 | raise AttributeError("'{}' object has no attribute '{}'".format(
98 | type(self).__name__, attr))
99 |
100 | def __str__(self):
101 | loss_str = []
102 | for name, meter in self.meters.items():
103 | loss_str.append(
104 | "{}: {}".format(name, str(meter))
105 | )
106 | return self.delimiter.join(loss_str)
107 |
108 | def synchronize_between_processes(self):
109 | for meter in self.meters.values():
110 | meter.synchronize_between_processes()
111 |
112 | def add_meter(self, name, meter):
113 | self.meters[name] = meter
114 |
115 | def log_every(self, iterable, print_freq, header=None):
116 | i = 0
117 | if not header:
118 | header = ''
119 | start_time = time.time()
120 | end = time.time()
121 | iter_time = SmoothedValue(fmt='{avg:.4f}')
122 | data_time = SmoothedValue(fmt='{avg:.4f}')
123 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
124 | log_msg = [
125 | header,
126 | '[{0' + space_fmt + '}/{1}]',
127 | 'eta: {eta}',
128 | '{meters}',
129 | 'time: {time}',
130 | 'data: {data}'
131 | ]
132 | if torch.cuda.is_available():
133 | log_msg.append('max mem: {memory:.0f}')
134 | log_msg = self.delimiter.join(log_msg)
135 | MB = 1024.0 * 1024.0
136 | for obj in iterable:
137 | data_time.update(time.time() - end)
138 | yield obj
139 | iter_time.update(time.time() - end)
140 | if i % print_freq == 0 or i == len(iterable) - 1:
141 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
142 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
143 | if torch.cuda.is_available():
144 | print(log_msg.format(
145 | i, len(iterable), eta=eta_string,
146 | meters=str(self),
147 | time=str(iter_time), data=str(data_time),
148 | memory=torch.cuda.max_memory_allocated() / MB))
149 | else:
150 | print(log_msg.format(
151 | i, len(iterable), eta=eta_string,
152 | meters=str(self),
153 | time=str(iter_time), data=str(data_time)))
154 | i += 1
155 | end = time.time()
156 | total_time = time.time() - start_time
157 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
158 | print('{} Total time: {} ({:.4f} s / it)'.format(
159 | header, total_time_str, total_time / len(iterable)))
160 |
161 |
162 | def _load_checkpoint_for_ema(model_ema, checkpoint):
163 | """
164 | Workaround for ModelEma._load_checkpoint to accept an already-loaded object
165 | """
166 | mem_file = io.BytesIO()
167 | torch.save({'state_dict_ema':checkpoint}, mem_file)
168 | mem_file.seek(0)
169 | model_ema._load_checkpoint(mem_file)
170 |
171 |
172 | def setup_for_distributed(is_master):
173 | """
174 | This function disables printing when not in master process
175 | """
176 | import builtins as __builtin__
177 | builtin_print = __builtin__.print
178 |
179 | def print(*args, **kwargs):
180 | force = kwargs.pop('force', False)
181 | if is_master or force:
182 | builtin_print(*args, **kwargs)
183 |
184 | __builtin__.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 | if not dist.is_available():
189 | return False
190 | if not dist.is_initialized():
191 | return False
192 | return True
193 |
194 |
195 | def get_world_size():
196 | if not is_dist_avail_and_initialized():
197 | return 1
198 | return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 | if not is_dist_avail_and_initialized():
203 | return 0
204 | return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 | return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 | if is_main_process():
213 | torch.save(*args, **kwargs)
214 |
215 |
216 | def init_distributed_mode(args):
217 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
218 | args.rank = int(os.environ["RANK"])
219 | args.world_size = int(os.environ['WORLD_SIZE'])
220 | args.gpu = int(os.environ['LOCAL_RANK'])
221 | elif 'SLURM_PROCID' in os.environ:
222 | args.rank = int(os.environ['SLURM_PROCID'])
223 | args.gpu = args.rank % torch.cuda.device_count()
224 | else:
225 | print('Not using distributed mode')
226 | args.distributed = False
227 | return
228 |
229 | args.distributed = True
230 |
231 | torch.cuda.set_device(args.gpu)
232 | args.dist_backend = 'nccl'
233 | print('| distributed init (rank {}): {}'.format(
234 | args.rank, args.dist_url), flush=True)
235 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
236 | world_size=args.world_size, rank=args.rank)
237 | torch.distributed.barrier()
238 | setup_for_distributed(args.rank == 0)
239 |
--------------------------------------------------------------------------------
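A tiny single-process demo of the logging helpers in `utils.py` (CPU only; assumes the file is importable as `utils`):

```python
import utils

metric_logger = utils.MetricLogger(delimiter="  ")
for step in metric_logger.log_every(range(100), print_freq=25, header='Epoch: [0]'):
    metric_logger.update(loss=1.0 / (step + 1), lr=2e-3)  # any float metrics
metric_logger.synchronize_between_processes()  # no-op without torch.distributed
print('Averaged stats:', metric_logger)
```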