├── swin_v2.PNG
├── run_train.sh
├── run_eval.sh
├── transforms.py
├── configs
│   └── swinv2_base_patch4_window7_224.yaml
├── stat_define.py
├── droppath.py
├── utils.py
├── losses.py
├── README.md
├── random_erasing.py
├── config.py
├── port_weights
│   ├── load_pytorch_weights_384.py
│   ├── load_pytorch_weights.py
│   └── load_pytorch_weights_large_384.py
├── modification.md
├── datasets.py
├── mixup.py
├── auto_augment.py
├── LICENSE
├── main_single_gpu.py
├── main_multi_gpu.py
└── swin_transformer.py

/swin_v2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nku-shengzheliu/PaddlePaddle-Swin-Transformer-V2/HEAD/swin_v2.PNG
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python main_single_gpu.py \
3 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
4 | -dataset='imagenet2012' \
5 | -batch_size=4 \
6 | -data_path='/dataset/imagenet'
7 | 
--------------------------------------------------------------------------------
/run_eval.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python main_single_gpu.py \
3 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
4 | -dataset='imagenet2012' \
5 | -batch_size=32 \
6 | -data_path='/dataset/imagenet' \
7 | -eval \
8 | -pretrained='./swinv2_base_patch4_window7_224'
9 | 
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | import paddle.vision.transforms as T
4 | 
5 | 
6 | class RandomHorizontalFlip():
7 |     def __init__(self, p=0.5):
8 |         self.p = p
9 | 
10 |     def __call__(self, image):
11 |         if random.random() < self.p:
12 |             return T.hflip(image)
13 |         return image
14 | 
--------------------------------------------------------------------------------
/configs/swinv2_base_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 |   IMAGE_SIZE: 224
3 |   CROP_PCT: 0.90
4 | MODEL:
5 |   TYPE: swin
6 |   NAME: swinv2_base_patch4_window7_224
7 |   DROP_PATH: 0.5
8 |   TRANS:
9 |     EMBED_DIM: 128
10 |     STAGE_DEPTHS: [2, 2, 18, 2]
11 |     NUM_HEADS: [4, 8, 16, 32]
12 |     WINDOW_SIZE: 7
13 |     PATCH_SIZE: 4
14 |     EXTRA_NORM: False
15 | 
16 | 
--------------------------------------------------------------------------------
/stat_define.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import paddle
4 | from config import get_config
5 | from swin_transformer import build_swin as build_model
6 | 
7 | def count_gelu(layer, input, output):
8 |     activation_flops = 8
9 |     x = input[0]
10 |     num = x.numel()
11 |     layer.total_ops += num * activation_flops
12 | 
13 | 
14 | def count_softmax(layer, input, output):
15 |     softmax_flops = 5  # max, subtract, exp, sum, divide
16 |     x = input[0]
17 |     num = x.numel()
18 |     layer.total_ops += num * softmax_flops
19 | 
20 | 
21 | def count_layernorm(layer, input, output):
22 |     layer_norm_flops = 5  # get mean (sum), get variance (square and sum), scale (multiply)
23 |     x = input[0]
24 |     num = x.numel()
25 |     layer.total_ops += num * layer_norm_flops
26 | 
27 | 
28 | cfg = './configs/swinv2_base_patch4_window7_224.yaml'
29 | input_size = (1, 3, 224, 224)
30 | config = 
get_config(cfg) 31 | model = build_model(config) 32 | 33 | custom_ops = {paddle.nn.GELU: count_gelu, 34 | paddle.nn.LayerNorm: count_layernorm, 35 | paddle.nn.Softmax: count_softmax, 36 | } 37 | print(os.path.basename(cfg)) 38 | paddle.flops(model, 39 | input_size=input_size, 40 | custom_ops=custom_ops, 41 | print_detail=False) 42 | 43 | 44 | #for cfg in glob.glob('./configs/*.yaml'): 45 | # #cfg = './configs/swin_base_patch4_window7_224.yaml' 46 | # input_size = (1, 3, int(cfg[-8:-5]), int(cfg[-8:-5])) 47 | # config = get_config(cfg) 48 | # model = build_model(config) 49 | # 50 | # 51 | # custom_ops = {paddle.nn.GELU: count_gelu, 52 | # paddle.nn.LayerNorm: count_layernorm, 53 | # paddle.nn.Softmax: count_softmax, 54 | # } 55 | # print(os.path.basename(cfg)) 56 | # paddle.flops(model, 57 | # input_size=input_size, 58 | # custom_ops=custom_ops, 59 | # print_detail=False) 60 | # print('-----------') 61 | -------------------------------------------------------------------------------- /droppath.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Droppath, reimplement from https://github.com/yueatsprograms/Stochastic_Depth 17 | """ 18 | 19 | import numpy as np 20 | import paddle 21 | import paddle.nn as nn 22 | 23 | 24 | class DropPath(nn.Layer): 25 | """DropPath class""" 26 | def __init__(self, drop_prob=None): 27 | super(DropPath, self).__init__() 28 | self.drop_prob = drop_prob 29 | 30 | def drop_path(self, inputs): 31 | """drop path op 32 | Args: 33 | input: tensor with arbitrary shape 34 | drop_prob: float number of drop path probability, default: 0.0 35 | training: bool, if current mode is training, default: False 36 | Returns: 37 | output: output tensor after drop path 38 | """ 39 | # if prob is 0 or eval mode, return original input 40 | if self.drop_prob == 0. or not self.training: 41 | return inputs 42 | keep_prob = 1 - self.drop_prob 43 | keep_prob = paddle.to_tensor(keep_prob, dtype='float32') 44 | shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1) # shape=(N, 1, 1, 1) 45 | random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype) 46 | random_tensor = random_tensor.floor() # mask 47 | output = inputs.divide(keep_prob) * random_tensor # divide is to keep same output expectation 48 | return output 49 | 50 | def forward(self, inputs): 51 | return self.drop_path(inputs) 52 | 53 | 54 | #def main(): 55 | # tmp = paddle.to_tensor(np.random.rand(8, 16, 8, 8), dtype='float32') 56 | # dp = DropPath(0.5) 57 | # out = dp(tmp) 58 | # print(out) 59 | # 60 | #if __name__ == "__main__": 61 | # main() 62 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved. 
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """utils for ViT
16 | 
17 | Contains AverageMeter for monitoring, get_exclude_from_weight_decay_fn for training,
18 | and WarmupCosineScheduler for training.
19 | 
20 | """
21 | 
22 | import math
23 | from paddle.optimizer.lr import LRScheduler
24 | 
25 | 
26 | class AverageMeter():
27 |     """Meter for monitoring losses"""
28 |     def __init__(self):
29 |         self.avg = 0
30 |         self.sum = 0
31 |         self.cnt = 0
32 |         self.reset()
33 | 
34 |     def reset(self):
35 |         """reset all values to zeros"""
36 |         self.avg = 0
37 |         self.sum = 0
38 |         self.cnt = 0
39 | 
40 |     def update(self, val, n=1):
41 |         """update avg by val and n, where val is the avg of n values"""
42 |         self.sum += val * n
43 |         self.cnt += n
44 |         self.avg = self.sum / self.cnt
45 | 
46 | 
47 | 
48 | def get_exclude_from_weight_decay_fn(exclude_list=[]):
49 |     """Set params with no weight decay during the training
50 | 
51 |     For certain params, e.g., the positional encoding in ViT, weight decay
52 |     may not be needed during training; this method is used to find
53 |     these params.
54 | 
55 |     Args:
56 |         exclude_list: a list of parameter names to exclude
57 |             from weight decay.
58 |     Returns:
59 |         exclude_from_weight_decay_fn: a function that returns False if the
60 |             param should be excluded from weight decay (True means decay applies)
61 |     """
62 |     if len(exclude_list) == 0:
63 |         exclude_from_weight_decay_fn = None
64 |     else:
65 |         def exclude_fn(param):
66 |             for name in exclude_list:
67 |                 if param.endswith(name):
68 |                     return False
69 |             return True
70 |         exclude_from_weight_decay_fn = exclude_fn
71 |     return exclude_from_weight_decay_fn
72 | 
73 | 
74 | class WarmupCosineScheduler(LRScheduler):
75 |     """Warmup Cosine Scheduler
76 | 
77 |     First applies linear warmup, then a cosine decay schedule:
78 |     the learning rate increases linearly from "warmup_start_lr" to "start_lr" over "warmup_epochs",
79 |     then decreases from "start_lr" to "end_lr" following a cosine curve over the remaining
80 |     "total_epochs - warmup_epochs".
81 | 
82 |     Attributes:
83 |         learning_rate: the starting learning rate (without warmup); not used directly here
84 |         warmup_start_lr: warmup starting learning rate
85 |         start_lr: the starting learning rate (without warmup)
86 |         end_lr: the ending learning rate after the whole schedule
87 |         warmup_epochs: # of epochs for warmup
88 |         total_epochs: # of total epochs (including warmup)
89 |     """
90 |     def __init__(self,
91 |                  learning_rate,
92 |                  warmup_start_lr,
93 |                  start_lr,
94 |                  end_lr,
95 |                  warmup_epochs,
96 |                  total_epochs,
97 |                  cycles=0.5,
98 |                  last_epoch=-1,
99 |                  verbose=False):
100 |         """init WarmupCosineScheduler"""
101 |         self.warmup_epochs = warmup_epochs
102 |         self.total_epochs = total_epochs
103 |         self.warmup_start_lr = warmup_start_lr
104 |         self.start_lr = start_lr
105 |         self.end_lr = end_lr
106 |         self.cycles = cycles
107 |         super(WarmupCosineScheduler, self).__init__(learning_rate, last_epoch, verbose)
108 | 
109 |     def get_lr(self):
110 |         """return lr value"""
111 |         if self.last_epoch < self.warmup_epochs:
112 |             val = (self.start_lr - self.warmup_start_lr) * float(
113 |                 self.last_epoch) / float(self.warmup_epochs) + self.warmup_start_lr
114 |             return val
115 | 
116 |         progress = float(self.last_epoch - self.warmup_epochs) / float(
117 |             max(1, self.total_epochs - self.warmup_epochs))
118 |         val = max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
119 |         val = max(0.0, val * (self.start_lr - self.end_lr) + self.end_lr)
120 |         return val
121 | 
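122 | 
123 | # A minimal usage sketch showing how WarmupCosineScheduler and
124 | # get_exclude_from_weight_decay_fn might be wired into an AdamW optimizer;
125 | # the stand-in model and the per-epoch stepping below are illustrative
126 | # assumptions, not taken from this repo's training scripts (the hyper-params
127 | # mirror the defaults in config.py). Kept commented out, following the
128 | # commented-example convention used in the other files:
129 | #
130 | #def main():
131 | #    import paddle
132 | #    model = paddle.nn.Linear(8, 8)  # stand-in for a real model
133 | #    scheduler = WarmupCosineScheduler(learning_rate=5e-4,
134 | #                                      warmup_start_lr=5e-7,
135 | #                                      start_lr=5e-4,
136 | #                                      end_lr=5e-6,
137 | #                                      warmup_epochs=20,
138 | #                                      total_epochs=300)
139 | #    optimizer = paddle.optimizer.AdamW(
140 | #        learning_rate=scheduler,
141 | #        parameters=model.parameters(),
142 | #        weight_decay=0.05,
143 | #        # returning True means weight decay IS applied; biases are excluded here
144 | #        apply_decay_param_fun=get_exclude_from_weight_decay_fn(['bias']))
145 | #    for epoch in range(300):
146 | #        # ... run one training epoch with optimizer, then step once per epoch
147 | #        scheduler.step()
148 | #
149 | #if __name__ == "__main__":
150 | #    main()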
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """ Implement Loss functions """
16 | import paddle
17 | import paddle.nn as nn
18 | import paddle.nn.functional as F
19 | 
20 | 
21 | class LabelSmoothingCrossEntropyLoss(nn.Layer):
22 |     """ cross entropy loss with label smoothing
23 |     Args:
24 |         smoothing: float, smoothing rate
25 |         x: tensor, predictions (before softmax) with shape [N, num_classes]
26 |         target: tensor, target label with shape [N]
27 |     Return:
28 |         loss: float, cross entropy loss value
29 |     """
30 |     def __init__(self, smoothing=0.1):
31 |         super().__init__()
32 |         assert 0 <= smoothing < 1.0
33 |         self.smoothing = smoothing
34 |         self.confidence = 1 - smoothing
35 | 
36 |     def forward(self, x, target):
37 |         log_probs = F.log_softmax(x)  # [N, num_classes]
38 |         # target_index is used to get the prob for each of the N samples
39 |         target_index = paddle.zeros([x.shape[0], 2], dtype='int64')  # [N, 2]
40 |         target_index[:, 0] = paddle.arange(x.shape[0])
41 |         target_index[:, 1] = target
42 | 
43 |         nll_loss = -log_probs.gather_nd(index=target_index)  # index: [N]
44 |         smooth_loss = -log_probs.mean(axis=-1)
45 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
46 |         return loss.mean()
47 | 
48 | 
49 | class SoftTargetCrossEntropyLoss(nn.Layer):
50 |     """ cross entropy loss for soft target
51 |     Args:
52 |         x: tensor, predictions (before softmax) with shape [N, num_classes]
53 |         target: tensor, soft target with shape [N, num_classes]
54 |     Returns:
55 |         loss: float, the mean loss value
56 |     """
57 |     def __init__(self):
58 |         super().__init__()
59 | 
60 |     def forward(self, x, target):
61 |         loss = paddle.sum(-target * F.log_softmax(x, axis=-1), axis=-1)
62 |         return loss.mean()
63 | 
64 | 
65 | class DistillationLoss(nn.Layer):
66 |     """Distillation loss function
67 |     This layer combines the original loss (criterion) with an extra
68 |     distillation loss (criterion), computed, with several type options,
69 |     between the current model and
70 |     a teacher model used as its supervision.
71 | 
72 |     Args:
73 |         base_criterion: nn.Layer, the original criterion
74 |         teacher_model: nn.Layer, the teacher model as supervision
75 |         distillation_type: str, one of ['none', 'soft', 'hard']
76 |         alpha: float, ratio of base loss (* (1-alpha))
77 |             and distillation loss (* alpha)
78 |         tau: float, temperature in distillation
79 |     """
80 |     def __init__(self,
81 |                  base_criterion,
82 |                  teacher_model,
83 |                  distillation_type,
84 |                  alpha,
85 |                  tau):
86 |         super().__init__()
87 |         assert distillation_type in ['none', 'soft', 'hard']
88 |         self.base_criterion = base_criterion
89 |         self.teacher_model = teacher_model
90 |         self.type = distillation_type
91 |         self.alpha = alpha
92 |         self.tau = tau
93 | 
94 |     def forward(self, inputs, outputs, targets):
95 |         """
96 |         Args:
97 |             inputs: tensor, the original model inputs
98 |             outputs: tensor, the outputs of the model
99 |             outputs_kd: tensor, the distillation outputs of the model,
100 |                 this is usually obtained by a separate branch
101 |                 in the last layer of the model
102 |             targets: tensor, the labels for the base criterion
103 |         """
104 |         outputs, outputs_kd = outputs[0], outputs[1]
105 |         base_loss = self.base_criterion(outputs, targets)
106 |         if self.type == 'none':
107 |             return base_loss
108 | 
109 |         with paddle.no_grad():
110 |             teacher_outputs = self.teacher_model(inputs)
111 | 
112 |         if self.type == 'soft':
113 |             distillation_loss = F.kl_div(
114 |                 F.log_softmax(outputs_kd / self.tau, axis=1),
115 |                 F.log_softmax(teacher_outputs / self.tau, axis=1),
116 |                 reduction='sum') * (self.tau * self.tau) / outputs_kd.numel()
117 |         elif self.type == 'hard':
118 |             distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(axis=1))
119 | 
120 |         loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
121 |         return loss
122 | 
123 | 
124 | 
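125 | 
126 | # A minimal sanity-check sketch for the losses above; the shapes and random
127 | # inputs are illustrative assumptions. Kept commented out, following the
128 | # commented-example convention used in the other files of this repo:
129 | #
130 | #def main():
131 | #    paddle.seed(0)
132 | #    logits = paddle.randn([4, 10])       # [N, num_classes]
133 | #    labels = paddle.randint(0, 10, [4])  # [N]
134 | #    criterion = LabelSmoothingCrossEntropyLoss(smoothing=0.1)
135 | #    print(criterion(logits, labels))
136 | #    # soft targets, e.g., as produced by mixup/cutmix
137 | #    soft_labels = F.softmax(paddle.randn([4, 10]), axis=-1)
138 | #    criterion = SoftTargetCrossEntropyLoss()
139 | #    print(criterion(logits, soft_labels))
140 | #
141 | #if __name__ == "__main__":
142 | #    main()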
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Swin Transformer V2: Scaling Up Capacity and Resolution, [arxiv](https://arxiv.org/pdf/2111.09883)
2 | 
3 | PaddlePaddle training/validation code and pretrained models for **Swin Transformer V2**.
4 | 
5 | The official pytorch implementation is [here](https://github.com/microsoft/Swin-Transformer).
6 | 
7 | This implementation is developed by [PaddleViT](https://github.com/BR-IDL/PaddleViT.git).
8 | 
9 | ![Comparison of the WindowAttention module between Swin Transformer V1 and Swin Transformer V2](./swin_v2.PNG)
10 | 
11 | *Comparison of the WindowAttention module between Swin Transformer V1 and Swin Transformer V2*
12 | 
13 | 
14 | ## Update
15 | 
16 | * Update (2021-11-27): Completed the modification of the WindowAttention module according to the original paper
17 |     - [x] post-norm configuration
18 |     - [x] scaled cosine attention
19 |     - [x] log-spaced continuous relative position bias
20 | 
21 | ## Code modification explanation
22 | 
23 | The code modification explanation is [here](./modification.md).
24 | 
25 | ## Models trained from scratch using PaddleViT
26 | 
27 | | Model      | Acc@1 | Acc@5 | #Params | FLOPs | Image Size | Crop_pct | Interpolation | Link        |
28 | |------------|-------|-------|---------|-------|------------|----------|---------------|-------------|
29 | | swin_b_224 |       |       | 88.9M   | 15.3G | 224        | 0.9      | Log-CPB       | coming soon |
30 | 
31 | > *The results are evaluated on the ImageNet2012 validation set.
32 | 
33 | 
34 | ## Requirements
35 | - Python>=3.6
36 | - yaml>=0.2.5
37 | - [PaddlePaddle](https://www.paddlepaddle.org.cn/documentation/docs/en/install/index_en.html)>=2.1.0
38 | - [yacs](https://github.com/rbgirshick/yacs)>=0.1.8
39 | 
40 | ## Data
41 | The ImageNet2012 dataset is used with the following folder structure:
42 | ```
43 | │imagenet/
44 | ├──train/
45 | │  ├── n01440764
46 | │  │   ├── n01440764_10026.JPEG
47 | │  │   ├── n01440764_10027.JPEG
48 | │  │   ├── ......
49 | │  ├── ......
50 | ├──val/
51 | │  ├── ILSVRC2012_val_00000293.JPEG
52 | │  ├── ILSVRC2012_val_00002138.JPEG
53 | │  ├── ......
54 | ```
55 | 
56 | ## Usage
57 | To use the model with pretrained weights, download the `.pdparams` weight file and change the related file paths in the following python scripts. The model config files are located in `./configs/`.
58 | 
59 | For example, assume the downloaded weight file is stored in `./swinv2_base_patch4_window7_224.pdparams`; to use the `swinv2_base_patch4_window7_224` model in python:
60 | ```python
61 | import paddle
62 | from config import get_config
63 | from swin_transformer import build_swin as build_model
64 | # config files in ./configs/
65 | config = get_config('./configs/swinv2_base_patch4_window7_224.yaml')
66 | # build model
67 | model = build_model(config)
68 | # load pretrained weights; the .pdparams extension is NOT needed
69 | model_state_dict = paddle.load('./swinv2_base_patch4_window7_224')
70 | model.set_dict(model_state_dict)
71 | ```
72 | 
73 | ## Evaluation
74 | To evaluate the Swin Transformer V2 model on ImageNet2012 with a single GPU, run the following script from the command line:
75 | ```shell
76 | sh run_eval.sh
77 | ```
78 | or
79 | ```shell
80 | CUDA_VISIBLE_DEVICES=0 \
81 | python main_single_gpu.py \
82 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
83 | -dataset='imagenet2012' \
84 | -batch_size=16 \
85 | -data_path='/dataset/imagenet' \
86 | -eval \
87 | -pretrained='./swinv2_base_patch4_window7_224'
88 | ```
89 | 
90 | 
91 | Run evaluation using multiple GPUs:
92 | 
93 | 
94 | 
95 | ```shell
96 | sh run_eval_multi.sh
97 | ```
98 | or
99 | ```shell
100 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
101 | python main_multi_gpu.py \
102 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
103 | -dataset='imagenet2012' \
104 | -batch_size=16 \
105 | -data_path='/dataset/imagenet' \
106 | -eval \
107 | -pretrained='./swinv2_base_patch4_window7_224'
108 | ```
109 | 
110 | 
111 | 
112 | 
113 | ## Training
114 | To train the Swin Transformer V2 model on ImageNet2012 with a single GPU, run the following script from the command line:
115 | ```shell
116 | sh run_train.sh
117 | ```
118 | or
119 | ```shell
120 | CUDA_VISIBLE_DEVICES=0 \
121 | python main_single_gpu.py \
122 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
123 | -dataset='imagenet2012' \
124 | -batch_size=32 \
125 | -data_path='/dataset/imagenet'
126 | ```
127 | 
128 | 
129 | 
130 | 
131 | Run training using multiple GPUs:
132 | 
133 | 
134 | 
135 | ```shell
136 | sh run_train_multi.sh
137 | ```
138 | or
139 | ```shell
140 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
141 | python main_multi_gpu.py \
142 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
143 | -dataset='imagenet2012' \
144 | -batch_size=16 \
145 | -data_path='/dataset/imagenet'
146 | ```
147 | 
148 | 
149 | 
150 | ## Reference
151 | ```
152 | @article{liu2021swin,
153 |   title={Swin Transformer V2: Scaling Up Capacity and Resolution},
154 |   author={Liu, Ze and Hu, Han and Lin, Yutong and Yao, Zhuliang and Xie, Zhenda and Wei, Yixuan and Ning, Jia and Cao, Yue and Zhang, Zheng and Dong, Li and others},
155 |   journal={arXiv preprint arXiv:2111.09883},
156 |   year={2021}
157 | }
158 | ```
159 | 
160 | 
--------------------------------------------------------------------------------
/random_erasing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Random Erasing for image tensor"""
16 | 
17 | import random
18 | import math
19 | import paddle
20 | 
21 | 
22 | def _get_pixels(per_pixel, rand_color, patch_size, dtype="float32"):
23 |     if per_pixel:
24 |         return paddle.normal(shape=patch_size).astype(dtype)
25 |     elif rand_color:
26 |         return paddle.normal(shape=(patch_size[0], 1, 1)).astype(dtype)
27 |     else:
28 |         return paddle.zeros((patch_size[0], 1, 1)).astype(dtype)
29 | 
30 | 
31 | class RandomErasing(object):
32 |     """
33 |     Args:
34 |         prob: probability of performing random erasing
35 |         min_area: minimum percentage of erased area wrt input image area
36 |         max_area: maximum percentage of erased area wrt input image area
37 |         min_aspect: minimum aspect ratio of erased area
38 |         max_aspect: maximum aspect ratio of erased area
39 |         mode: pixel color mode, one of ['const', 'rand', 'pixel']
40 |             'const' - erase block is constant valued 0 for all channels
41 |             'rand'  - erase block is valued random color (same per channel)
42 |             'pixel' - erase block is valued random color per pixel
43 |         min_count: minimum # of erasing blocks per image
44 |         max_count: maximum # of erasing blocks per image.
Area per box is scaled by count 45 | per-image count is randomly chosen between min_count to max_count 46 | """ 47 | def __init__(self, prob=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, 48 | mode='const', min_count=1, max_count=None, num_splits=0): 49 | self.prob = prob 50 | self.min_area = min_area 51 | self.max_area = max_area 52 | max_aspect = max_aspect or 1 / min_aspect 53 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 54 | self.min_count = min_count 55 | self.max_count = max_count or min_count 56 | self.num_splits = num_splits 57 | mode = mode.lower() 58 | self.rand_color = False 59 | self.per_pixel = False 60 | if mode == "rand": 61 | self.rand_color = True 62 | elif mode == "pixel": 63 | self.per_pixel = True 64 | else: 65 | assert not mode or mode == "const" 66 | 67 | def _erase(self, img, chan, img_h, img_w, dtype): 68 | if random.random() > self.prob: 69 | return 70 | area = img_h * img_w 71 | count = self.min_count if self.min_count == self.max_count else \ 72 | random.randint(self.min_count, self.max_count) 73 | for _ in range(count): 74 | for attempt in range(10): 75 | target_area = random.uniform(self.min_area, self.max_area) * area / count 76 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 77 | h = int(round(math.sqrt(target_area * aspect_ratio))) 78 | w = int(round(math.sqrt(target_area / aspect_ratio))) 79 | if w < img_w and h < img_h: 80 | top = random.randint(0, img_h - h) 81 | left = random.randint(0, img_w - w) 82 | img[:, top:top+h, left:left+w] = _get_pixels( 83 | self.per_pixel, self.rand_color, (chan, h, w), 84 | dtype=dtype) 85 | break 86 | 87 | def __call__(self, input): 88 | if len(input.shape) == 3: 89 | self._erase(input, *input.shape, input.dtype) 90 | else: 91 | batch_size, chan, img_h, img_w = input.shape 92 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 93 | for i in range(batch_start, batch_size): 94 | self._erase(input[i], chan, img_h, img_w, input.dtype) 95 | return input 96 | 97 | 98 | 99 | #def main(): 100 | # re = RandomErasing(prob=1.0, min_area=0.2, max_area=0.6, mode='rand') 101 | # #re = RandomErasing(prob=1.0, min_area=0.2, max_area=0.6, mode='const') 102 | # #re = RandomErasing(prob=1.0, min_area=0.2, max_area=0.6, mode='pixel') 103 | # import PIL.Image as Image 104 | # import numpy as np 105 | # paddle.set_device('cpu') 106 | # img = paddle.to_tensor(np.asarray(Image.open('./lenna.png'))).astype('float32') 107 | # img = img / 255.0 108 | # img = paddle.transpose(img, [2, 0, 1]) 109 | # new_img = re(img) 110 | # new_img = new_img * 255.0 111 | # new_img = paddle.transpose(new_img, [1, 2, 0]) 112 | # new_img = new_img.cpu().numpy() 113 | # new_img = Image.fromarray(new_img.astype('uint8')) 114 | # new_img.save('./res.png') 115 | # 116 | # 117 | # 118 | #if __name__ == "__main__": 119 | # main() 120 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Configuration
16 | 
17 | Configuration for data, model architecture, training, etc.
18 | Config can be set by a .yaml file or by argparser (limited usage).
19 | 
20 | 
21 | """
22 | 
23 | import os
24 | from yacs.config import CfgNode as CN
25 | import yaml
26 | 
27 | _C = CN()
28 | _C.BASE = ['']
29 | 
30 | # data settings
31 | _C.DATA = CN()
32 | _C.DATA.BATCH_SIZE = 8  #1024 batch_size for single GPU
33 | _C.DATA.BATCH_SIZE_EVAL = 8  #1024 batch_size for single GPU
34 | _C.DATA.DATA_PATH = '/dataset/imagenet/'  # path to dataset
35 | _C.DATA.DATASET = 'imagenet2012'  # dataset name
36 | _C.DATA.IMAGE_SIZE = 224  # input image size
37 | _C.DATA.CROP_PCT = 0.9  # input image scale ratio, scale is applied before centercrop in eval mode
38 | _C.DATA.NUM_WORKERS = 8  # number of data loading threads
39 | 
40 | # model settings
41 | _C.MODEL = CN()
42 | _C.MODEL.TYPE = 'Swin'
43 | _C.MODEL.NAME = 'Swin'
44 | _C.MODEL.RESUME = None
45 | _C.MODEL.PRETRAINED = None
46 | _C.MODEL.NUM_CLASSES = 1000
47 | _C.MODEL.DROPOUT = 0.0
48 | _C.MODEL.ATTENTION_DROPOUT = 0.0
49 | _C.MODEL.DROP_PATH = 0.1
50 | 
51 | # transformer settings
52 | _C.MODEL.TRANS = CN()
53 | _C.MODEL.TRANS.PATCH_SIZE = 4  # image_size = patch_size x window_size x num_windows
54 | _C.MODEL.TRANS.WINDOW_SIZE = 7
55 | _C.MODEL.TRANS.IN_CHANNELS = 3
56 | _C.MODEL.TRANS.EMBED_DIM = 96  # same as HIDDEN_SIZE in ViT
57 | _C.MODEL.TRANS.STAGE_DEPTHS = [2, 2, 6, 2]
58 | _C.MODEL.TRANS.NUM_HEADS = [3, 6, 12, 24]
59 | _C.MODEL.TRANS.MLP_RATIO = 4.
60 | _C.MODEL.TRANS.QKV_BIAS = True
61 | _C.MODEL.TRANS.QK_SCALE = None
62 | _C.MODEL.TRANS.APE = False  # absolute positional embeddings
63 | _C.MODEL.TRANS.PATCH_NORM = True
64 | _C.MODEL.TRANS.EXTRA_NORM = False
65 | 
66 | # training settings
67 | _C.TRAIN = CN()
68 | _C.TRAIN.LAST_EPOCH = 0
69 | _C.TRAIN.NUM_EPOCHS = 300
70 | _C.TRAIN.WARMUP_EPOCHS = 20
71 | _C.TRAIN.WEIGHT_DECAY = 0.05
72 | _C.TRAIN.BASE_LR = 5e-4
73 | _C.TRAIN.WARMUP_START_LR = 5e-7
74 | _C.TRAIN.END_LR = 5e-6
75 | _C.TRAIN.GRAD_CLIP = 5.0
76 | _C.TRAIN.ACCUM_ITER = 1
77 | 
78 | _C.TRAIN.LR_SCHEDULER = CN()
79 | _C.TRAIN.LR_SCHEDULER.NAME = 'warmupcosine'
80 | _C.TRAIN.LR_SCHEDULER.MILESTONES = "30, 60, 90"  # only used in StepLRScheduler
81 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30  # only used in StepLRScheduler
82 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1  # only used in StepLRScheduler
83 | 
84 | _C.TRAIN.OPTIMIZER = CN()
85 | _C.TRAIN.OPTIMIZER.NAME = 'AdamW'
86 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
87 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)  # for AdamW
88 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
89 | 
90 | # train augmentation
91 | _C.TRAIN.MIXUP_ALPHA = 0.8
92 | _C.TRAIN.CUTMIX_ALPHA = 1.0
93 | _C.TRAIN.CUTMIX_MINMAX = None
94 | _C.TRAIN.MIXUP_PROB = 1.0
95 | _C.TRAIN.MIXUP_SWITCH_PROB = 0.5
96 | _C.TRAIN.MIXUP_MODE = 'batch'
97 | 
98 | _C.TRAIN.SMOOTHING = 0.1
99 | _C.TRAIN.COLOR_JITTER = 0.4
100 | _C.TRAIN.AUTO_AUGMENT = True  #'rand-m9-mstd0.5-inc1'
101 | 
102 | _C.TRAIN.RANDOM_ERASE_PROB = 0.25
103 | _C.TRAIN.RANDOM_ERASE_MODE = 'pixel'
104 | _C.TRAIN.RANDOM_ERASE_COUNT = 1
105 | _C.TRAIN.RANDOM_ERASE_SPLIT = False
106 | 
107 | # augmentation
108 | _C.AUG = CN()
109 | _C.AUG.COLOR_JITTER = 0.4  # color jitter factor
110 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
111 | _C.AUG.RE_PROB = 0.25  # random erase prob
112 | _C.AUG.RE_MODE = 'pixel'  # random erase mode
113 | _C.AUG.RE_COUNT = 1  # random erase count
114 | _C.AUG.MIXUP = 0.8  # mixup alpha, enabled if >0
115 | _C.AUG.CUTMIX = 1.0  # cutmix alpha, enabled if >0
116 | _C.AUG.CUTMIX_MINMAX = None  # cutmix min/max ratio, overrides alpha
117 | _C.AUG.MIXUP_PROB = 1.0  # prob of mixup or cutmix when either/both is enabled
118 | _C.AUG.MIXUP_SWITCH_PROB = 0.5  # prob of switching to cutmix when both mixup and cutmix are enabled
119 | _C.AUG.MIXUP_MODE = 'batch'  # how to apply mixup/cutmix params: per 'batch', 'pair', or 'elem'
120 | 
121 | # misc
122 | _C.SAVE = "./output"
123 | _C.TAG = "default"
124 | _C.SAVE_FREQ = 1  # frequency to save checkpoints
125 | _C.REPORT_FREQ = 50  # frequency to log info
126 | _C.VALIDATE_FREQ = 10  # frequency to run validation
127 | _C.SEED = 42
128 | _C.EVAL = False  # run evaluation only
129 | _C.AMP = False
130 | _C.LOCAL_RANK = 0
131 | _C.NGPUS = -1
132 | 
133 | 
134 | def _update_config_from_file(config, cfg_file):
135 |     config.defrost()
136 |     with open(cfg_file, 'r') as infile:
137 |         yaml_cfg = yaml.load(infile, Loader=yaml.FullLoader)
138 |     for cfg in yaml_cfg.setdefault('BASE', ['']):
139 |         if cfg:
140 |             _update_config_from_file(
141 |                 config, os.path.join(os.path.dirname(cfg_file), cfg)
142 |             )
143 |     print('merging config from {}'.format(cfg_file))
144 |     config.merge_from_file(cfg_file)
145 |     config.freeze()
146 | 
147 | def update_config(config, args):
148 |     """Update config by ArgumentParser
149 |     Args:
150 |         args: ArgumentParser contains options
151 |     Return:
152 |         config: updated config
153 |     """
154 |     if args.cfg:
155 |         _update_config_from_file(config, args.cfg)
156 |     config.defrost()
157 |     if args.dataset:
158 |         config.DATA.DATASET = args.dataset
159 |     if 
args.batch_size: 160 | config.DATA.BATCH_SIZE = args.batch_size 161 | if args.image_size: 162 | config.DATA.IMAGE_SIZE = args.image_size 163 | if args.data_path: 164 | config.DATA.DATA_PATH = args.data_path 165 | if args.ngpus: 166 | config.NGPUS = args.ngpus 167 | if args.eval: 168 | config.EVAL = True 169 | config.DATA.BATCH_SIZE_EVAL = args.batch_size 170 | if args.pretrained: 171 | config.MODEL.PRETRAINED = args.pretrained 172 | if args.resume: 173 | config.MODEL.RESUME = args.resume 174 | if args.last_epoch: 175 | config.TRAIN.LAST_EPOCH = args.last_epoch 176 | if args.amp: # only during training 177 | if config.EVAL is True: 178 | config.AMP = False 179 | 180 | #config.freeze() 181 | return config 182 | 183 | 184 | def get_config(cfg_file=None): 185 | """Return a clone of config or load from yaml file""" 186 | config = _C.clone() 187 | if cfg_file: 188 | _update_config_from_file(config, cfg_file) 189 | return config 190 | -------------------------------------------------------------------------------- /port_weights/load_pytorch_weights_384.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import argparse 16 | import numpy as np 17 | import paddle 18 | import torch 19 | import timm 20 | from swin_transformer import * 21 | from config import * 22 | 23 | 24 | config = get_config('./configs/swin_base_patch4_window12_384.yaml') 25 | print(config) 26 | 27 | 28 | def print_model_named_params(model): 29 | print('----------------------------------') 30 | for name, param in model.named_parameters(): 31 | print(name, param.shape) 32 | print('----------------------------------') 33 | 34 | 35 | def print_model_named_buffers(model): 36 | print('----------------------------------') 37 | for name, param in model.named_buffers(): 38 | print(name, param.shape) 39 | print('----------------------------------') 40 | 41 | 42 | def torch_to_paddle_mapping(): 43 | mapping = [ 44 | ('patch_embed.proj', 'patch_embedding.patch_embed'), 45 | ('patch_embed.norm', 'patch_embedding.norm'), 46 | ] 47 | 48 | # torch 'layers' to paddle 'stages' 49 | depths = config.MODEL.TRANS.STAGE_DEPTHS 50 | num_stages = len(depths) 51 | for stage_idx in range(num_stages): 52 | pp_s_prefix = f'stages.{stage_idx}.blocks' 53 | th_s_prefix = f'layers.{stage_idx}.blocks' 54 | for block_idx in range(depths[stage_idx]): 55 | th_b_prefix = f'{th_s_prefix}.{block_idx}' 56 | pp_b_prefix = f'{pp_s_prefix}.{block_idx}' 57 | layer_mapping = [ 58 | (f'{th_b_prefix}.norm1', f'{pp_b_prefix}.norm1'), 59 | (f'{th_b_prefix}.attn.relative_position_bias_table', f'{pp_b_prefix}.attn.relative_position_bias_table'), 60 | (f'{th_b_prefix}.attn.qkv', f'{pp_b_prefix}.attn.qkv'), 61 | (f'{th_b_prefix}.attn.proj', f'{pp_b_prefix}.attn.proj'), 62 | (f'{th_b_prefix}.norm2', f'{pp_b_prefix}.norm2'), 63 | (f'{th_b_prefix}.mlp.fc1', f'{pp_b_prefix}.mlp.fc1'), 64 | (f'{th_b_prefix}.mlp.fc2', f'{pp_b_prefix}.mlp.fc2'), 65 | ] 66 | mapping.extend(layer_mapping) 67 | # stage downsample: last stage does not have downsample ops 68 | if stage_idx < num_stages - 1: 69 | mapping.extend([ 70 | (f'layers.{stage_idx}.downsample.reduction.weight', f'stages.{stage_idx}.downsample.reduction.weight'), 71 | (f'layers.{stage_idx}.downsample.norm', f'stages.{stage_idx}.downsample.norm')]) 72 | 73 | mapping.extend([ 74 | ('norm', 'norm'), 75 | ('head', 'fc')]) 76 | return mapping 77 | 78 | 79 | 80 | def convert(torch_model, paddle_model): 81 | def _set_value(th_name, pd_name, no_transpose=False): 82 | th_shape = th_params[th_name].shape 83 | pd_shape = tuple(pd_params[pd_name].shape) # paddle shape default type is list 84 | #assert th_shape == pd_shape, f'{th_shape} != {pd_shape}' 85 | print(f'set {th_name} {th_shape} to {pd_name} {pd_shape}') 86 | value = th_params[th_name].data.numpy() 87 | if len(value.shape) == 2: 88 | if not no_transpose: 89 | value = value.transpose((1, 0)) 90 | pd_params[pd_name].set_value(value) 91 | 92 | # 1. get paddle and torch model parameters 93 | pd_params = {} 94 | th_params = {} 95 | for name, param in paddle_model.named_parameters(): 96 | pd_params[name] = param 97 | for name, param in torch_model.named_parameters(): 98 | th_params[name] = param 99 | 100 | for name, param in paddle_model.named_buffers(): 101 | pd_params[name] = param 102 | for name, param in torch_model.named_buffers(): 103 | th_params[name] = param 104 | 105 | # 2. get name mapping pairs 106 | mapping = torch_to_paddle_mapping() 107 | # 3. 
set torch param values to paddle params: may need transpose on weights
108 |     for th_name, pd_name in mapping:
109 |         if th_name in th_params.keys():  # nn.Parameters
110 |             if th_name.endswith('relative_position_bias_table'):
111 |                 _set_value(th_name, pd_name, no_transpose=True)
112 |             else:
113 |                 _set_value(th_name, pd_name)
114 |         else:  # weight & bias
115 |             th_name_w = f'{th_name}.weight'
116 |             pd_name_w = f'{pd_name}.weight'
117 |             _set_value(th_name_w, pd_name_w)
118 | 
119 |             if f'{th_name}.bias' in th_params.keys():
120 |                 th_name_b = f'{th_name}.bias'
121 |                 pd_name_b = f'{pd_name}.bias'
122 |                 _set_value(th_name_b, pd_name_b)
123 | 
124 |     return paddle_model
125 | 
126 | 
127 | def main():
128 | 
129 |     paddle.set_device('cpu')
130 |     paddle_model = build_swin(config)
131 |     paddle_model.eval()
132 | 
133 |     print_model_named_params(paddle_model)
134 |     print_model_named_buffers(paddle_model)
135 | 
136 |     print('+++++++++++++++++++++++++++++++++++')
137 |     device = torch.device('cpu')
138 |     torch_model = timm.create_model('swin_base_patch4_window12_384', pretrained=True)
139 |     torch_model = torch_model.to(device)
140 |     torch_model.eval()
141 | 
142 |     print_model_named_params(torch_model)
143 |     print_model_named_buffers(torch_model)
144 | 
145 |     # convert weights
146 |     paddle_model = convert(torch_model, paddle_model)
147 | 
148 |     # check correctness
149 |     x = np.random.randn(2, 3, 384, 384).astype('float32')
150 |     x_paddle = paddle.to_tensor(x)
151 |     x_torch = torch.Tensor(x).to(device)
152 | 
153 |     out_torch = torch_model(x_torch)
154 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
155 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
156 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
157 |     out_paddle = paddle_model(x_paddle)
158 | 
159 |     out_torch = out_torch.data.cpu().numpy()
160 |     out_paddle = out_paddle.cpu().numpy()
161 | 
162 |     print(out_torch.shape, out_paddle.shape)
163 |     print(out_torch[0, 0:20])
164 |     print(out_paddle[0, 0:20])
165 |     assert np.allclose(out_torch, out_paddle, atol=1e-4)
166 | 
167 |     # save weights for paddle model
168 |     model_path = os.path.join('./swin_base_patch4_window12_384.pdparams')
169 |     paddle.save(paddle_model.state_dict(), model_path)
170 | 
171 | 
172 | 
173 | #tmp = np.random.randn(1, 56, 128, 128).astype('float32')
174 | #xp = paddle.to_tensor(tmp)
175 | #xt = torch.Tensor(tmp).to(device)
176 | #xps = paddle.roll(xp, shifts=(-3, -3), axis=(1,2))
177 | #xts = torch.roll(xt, shifts=(-3, -3), dims=(1,2))
178 | #xps = xps.cpu().numpy()
179 | #xts = xts.data.cpu().numpy()
180 | #assert np.allclose(xps, xts, atol=1e-4)
181 | 
182 | if __name__ == "__main__":
183 |     main()
184 | 
--------------------------------------------------------------------------------
/port_weights/load_pytorch_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 15 | import argparse 16 | import numpy as np 17 | import paddle 18 | import torch 19 | import timm 20 | from swin_transformer import * 21 | from config import * 22 | 23 | 24 | config = get_config('./configs/swin_base_patch4_window7_224.yaml') 25 | print(config) 26 | 27 | 28 | def print_model_named_params(model): 29 | print('----------------------------------') 30 | for name, param in model.named_parameters(): 31 | print(name, param.shape) 32 | print('----------------------------------') 33 | 34 | 35 | def print_model_named_buffers(model): 36 | print('----------------------------------') 37 | for name, param in model.named_buffers(): 38 | print(name, param.shape) 39 | print('----------------------------------') 40 | 41 | 42 | def torch_to_paddle_mapping(): 43 | mapping = [ 44 | ('patch_embed.proj', 'patch_embedding.patch_embed'), 45 | ('patch_embed.norm', 'patch_embedding.norm'), 46 | ] 47 | 48 | # torch 'layers' to paddle 'stages' 49 | depths = config.MODEL.TRANS.STAGE_DEPTHS 50 | num_stages = len(depths) 51 | for stage_idx in range(num_stages): 52 | pp_s_prefix = f'stages.{stage_idx}.blocks' 53 | th_s_prefix = f'layers.{stage_idx}.blocks' 54 | for block_idx in range(depths[stage_idx]): 55 | th_b_prefix = f'{th_s_prefix}.{block_idx}' 56 | pp_b_prefix = f'{pp_s_prefix}.{block_idx}' 57 | layer_mapping = [ 58 | (f'{th_b_prefix}.norm1', f'{pp_b_prefix}.norm1'), 59 | (f'{th_b_prefix}.attn.relative_position_bias_table', f'{pp_b_prefix}.attn.relative_position_bias_table'), 60 | (f'{th_b_prefix}.attn.qkv', f'{pp_b_prefix}.attn.qkv'), 61 | (f'{th_b_prefix}.attn.proj', f'{pp_b_prefix}.attn.proj'), 62 | (f'{th_b_prefix}.norm2', f'{pp_b_prefix}.norm2'), 63 | (f'{th_b_prefix}.mlp.fc1', f'{pp_b_prefix}.mlp.fc1'), 64 | (f'{th_b_prefix}.mlp.fc2', f'{pp_b_prefix}.mlp.fc2'), 65 | ] 66 | mapping.extend(layer_mapping) 67 | # stage downsample: last stage does not have downsample ops 68 | if stage_idx < num_stages - 1: 69 | mapping.extend([ 70 | (f'layers.{stage_idx}.downsample.reduction.weight', f'stages.{stage_idx}.downsample.reduction.weight'), 71 | (f'layers.{stage_idx}.downsample.norm', f'stages.{stage_idx}.downsample.norm')]) 72 | 73 | mapping.extend([ 74 | ('norm', 'norm'), 75 | ('head', 'fc')]) 76 | return mapping 77 | 78 | 79 | 80 | def convert(torch_model, paddle_model): 81 | def _set_value(th_name, pd_name, no_transpose=False): 82 | th_shape = th_params[th_name].shape 83 | pd_shape = tuple(pd_params[pd_name].shape) # paddle shape default type is list 84 | #assert th_shape == pd_shape, f'{th_shape} != {pd_shape}' 85 | print(f'set {th_name} {th_shape} to {pd_name} {pd_shape}') 86 | value = th_params[th_name].data.numpy() 87 | if len(value.shape) == 2: 88 | if not no_transpose: 89 | value = value.transpose((1, 0)) 90 | pd_params[pd_name].set_value(value) 91 | 92 | # 1. get paddle and torch model parameters 93 | pd_params = {} 94 | th_params = {} 95 | for name, param in paddle_model.named_parameters(): 96 | pd_params[name] = param 97 | for name, param in torch_model.named_parameters(): 98 | th_params[name] = param 99 | 100 | for name, param in paddle_model.named_buffers(): 101 | pd_params[name] = param 102 | for name, param in torch_model.named_buffers(): 103 | th_params[name] = param 104 | 105 | # 2. get name mapping pairs 106 | mapping = torch_to_paddle_mapping() 107 | # 3. 
set torch param values to paddle params: may need transpose on weights
108 |     for th_name, pd_name in mapping:
109 |         if th_name in th_params.keys():  # nn.Parameters
110 |             if th_name.endswith('relative_position_bias_table'):
111 |                 _set_value(th_name, pd_name, no_transpose=True)
112 |             else:
113 |                 _set_value(th_name, pd_name)
114 |         else:  # weight & bias
115 |             th_name_w = f'{th_name}.weight'
116 |             pd_name_w = f'{pd_name}.weight'
117 |             _set_value(th_name_w, pd_name_w)
118 | 
119 |             if f'{th_name}.bias' in th_params.keys():
120 |                 th_name_b = f'{th_name}.bias'
121 |                 pd_name_b = f'{pd_name}.bias'
122 |                 _set_value(th_name_b, pd_name_b)
123 | 
124 |     return paddle_model
125 | 
126 | 
127 | 
128 | 
129 | 
130 | def main():
131 | 
132 |     paddle.set_device('cpu')
133 |     paddle_model = build_swin(config)
134 |     paddle_model.eval()
135 | 
136 |     print_model_named_params(paddle_model)
137 |     print_model_named_buffers(paddle_model)
138 | 
139 |     print('+++++++++++++++++++++++++++++++++++')
140 |     device = torch.device('cpu')
141 |     torch_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
142 |     torch_model = torch_model.to(device)
143 |     torch_model.eval()
144 |     print_model_named_params(torch_model)
145 |     print_model_named_buffers(torch_model)
146 | 
147 |     # convert weights
148 |     paddle_model = convert(torch_model, paddle_model)
149 | 
150 |     # check correctness
151 |     x = np.random.randn(2, 3, 224, 224).astype('float32')
152 |     x_paddle = paddle.to_tensor(x)
153 |     x_torch = torch.Tensor(x).to(device)
154 | 
155 |     out_torch = torch_model(x_torch)
156 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
157 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
158 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
159 |     out_paddle = paddle_model(x_paddle)
160 | 
161 |     out_torch = out_torch.data.cpu().numpy()
162 |     out_paddle = out_paddle.cpu().numpy()
163 | 
164 |     print(out_torch.shape, out_paddle.shape)
165 |     print(out_torch[0, 0:20])
166 |     print(out_paddle[0, 0:20])
167 |     assert np.allclose(out_torch, out_paddle, atol=1e-4)
168 | 
169 |     # save weights for paddle model
170 |     model_path = os.path.join('./swin_base_patch4_window7_224.pdparams')
171 |     paddle.save(paddle_model.state_dict(), model_path)
172 | 
173 | 
174 | 
175 | #tmp = np.random.randn(1, 56, 128, 128).astype('float32')
176 | #xp = paddle.to_tensor(tmp)
177 | #xt = torch.Tensor(tmp).to(device)
178 | #xps = paddle.roll(xp, shifts=(-3, -3), axis=(1,2))
179 | #xts = torch.roll(xt, shifts=(-3, -3), dims=(1,2))
180 | #xps = xps.cpu().numpy()
181 | #xts = xts.data.cpu().numpy()
182 | #assert np.allclose(xps, xts, atol=1e-4)
183 | 
184 | if __name__ == "__main__":
185 |     main()
186 | 
--------------------------------------------------------------------------------
/port_weights/load_pytorch_weights_large_384.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import numpy as np 17 | import paddle 18 | import torch 19 | import timm 20 | from swin_transformer import * 21 | from config import * 22 | 23 | 24 | config = get_config('./configs/swin_large_patch4_window12_384.yaml') 25 | print(config) 26 | 27 | 28 | def print_model_named_params(model): 29 | print('----------------------------------') 30 | for name, param in model.named_parameters(): 31 | print(name, param.shape) 32 | print('----------------------------------') 33 | 34 | 35 | def print_model_named_buffers(model): 36 | print('----------------------------------') 37 | for name, param in model.named_buffers(): 38 | print(name, param.shape) 39 | print('----------------------------------') 40 | 41 | 42 | def torch_to_paddle_mapping(): 43 | mapping = [ 44 | ('patch_embed.proj', 'patch_embedding.patch_embed'), 45 | ('patch_embed.norm', 'patch_embedding.norm'), 46 | ] 47 | 48 | # torch 'layers' to paddle 'stages' 49 | depths = config.MODEL.TRANS.STAGE_DEPTHS 50 | num_stages = len(depths) 51 | for stage_idx in range(num_stages): 52 | pp_s_prefix = f'stages.{stage_idx}.blocks' 53 | th_s_prefix = f'layers.{stage_idx}.blocks' 54 | for block_idx in range(depths[stage_idx]): 55 | th_b_prefix = f'{th_s_prefix}.{block_idx}' 56 | pp_b_prefix = f'{pp_s_prefix}.{block_idx}' 57 | layer_mapping = [ 58 | (f'{th_b_prefix}.norm1', f'{pp_b_prefix}.norm1'), 59 | (f'{th_b_prefix}.attn.relative_position_bias_table', f'{pp_b_prefix}.attn.relative_position_bias_table'), 60 | (f'{th_b_prefix}.attn.qkv', f'{pp_b_prefix}.attn.qkv'), 61 | (f'{th_b_prefix}.attn.proj', f'{pp_b_prefix}.attn.proj'), 62 | (f'{th_b_prefix}.norm2', f'{pp_b_prefix}.norm2'), 63 | (f'{th_b_prefix}.mlp.fc1', f'{pp_b_prefix}.mlp.fc1'), 64 | (f'{th_b_prefix}.mlp.fc2', f'{pp_b_prefix}.mlp.fc2'), 65 | ] 66 | mapping.extend(layer_mapping) 67 | # stage downsample: last stage does not have downsample ops 68 | if stage_idx < num_stages - 1: 69 | mapping.extend([ 70 | (f'layers.{stage_idx}.downsample.reduction.weight', f'stages.{stage_idx}.downsample.reduction.weight'), 71 | (f'layers.{stage_idx}.downsample.norm', f'stages.{stage_idx}.downsample.norm')]) 72 | 73 | mapping.extend([ 74 | ('norm', 'norm'), 75 | ('head', 'fc')]) 76 | return mapping 77 | 78 | 79 | 80 | def convert(torch_model, paddle_model): 81 | def _set_value(th_name, pd_name, no_transpose=False): 82 | th_shape = th_params[th_name].shape 83 | pd_shape = tuple(pd_params[pd_name].shape) # paddle shape default type is list 84 | #assert th_shape == pd_shape, f'{th_shape} != {pd_shape}' 85 | print(f'set {th_name} {th_shape} to {pd_name} {pd_shape}') 86 | value = th_params[th_name].data.numpy() 87 | if len(value.shape) == 2: 88 | if not no_transpose: 89 | value = value.transpose((1, 0)) 90 | pd_params[pd_name].set_value(value) 91 | 92 | # 1. get paddle and torch model parameters 93 | pd_params = {} 94 | th_params = {} 95 | for name, param in paddle_model.named_parameters(): 96 | pd_params[name] = param 97 | for name, param in torch_model.named_parameters(): 98 | th_params[name] = param 99 | 100 | for name, param in paddle_model.named_buffers(): 101 | pd_params[name] = param 102 | for name, param in torch_model.named_buffers(): 103 | th_params[name] = param 104 | 105 | # 2. get name mapping pairs 106 | mapping = torch_to_paddle_mapping() 107 | # 3. 
set torch param values to paddle params: may need transpose on weights
108 |     for th_name, pd_name in mapping:
109 |         if th_name in th_params.keys():  # nn.Parameters
110 |             if th_name.endswith('relative_position_bias_table'):
111 |                 _set_value(th_name, pd_name, no_transpose=True)
112 |             else:
113 |                 _set_value(th_name, pd_name)
114 |         else:  # weight & bias
115 |             th_name_w = f'{th_name}.weight'
116 |             pd_name_w = f'{pd_name}.weight'
117 |             _set_value(th_name_w, pd_name_w)
118 | 
119 |             if f'{th_name}.bias' in th_params.keys():
120 |                 th_name_b = f'{th_name}.bias'
121 |                 pd_name_b = f'{pd_name}.bias'
122 |                 _set_value(th_name_b, pd_name_b)
123 | 
124 |     return paddle_model
125 | 
126 | 
127 | 
128 | 
129 | 
130 | def main():
131 | 
132 |     paddle.set_device('cpu')
133 |     paddle_model = build_swin(config)
134 |     paddle_model.eval()
135 | 
136 |     print_model_named_params(paddle_model)
137 |     print_model_named_buffers(paddle_model)
138 | 
139 |     print('+++++++++++++++++++++++++++++++++++')
140 |     device = torch.device('cpu')
141 |     torch_model = timm.create_model('swin_large_patch4_window12_384', pretrained=True)
142 |     torch_model = torch_model.to(device)
143 |     torch_model.eval()
144 |     print_model_named_params(torch_model)
145 |     print_model_named_buffers(torch_model)
146 | 
147 |     # convert weights
148 |     paddle_model = convert(torch_model, paddle_model)
149 | 
150 |     # check correctness
151 |     x = np.random.randn(2, 3, 384, 384).astype('float32')
152 |     x_paddle = paddle.to_tensor(x)
153 |     x_torch = torch.Tensor(x).to(device)
154 | 
155 |     out_torch = torch_model(x_torch)
156 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
157 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
158 |     print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
159 |     out_paddle = paddle_model(x_paddle)
160 | 
161 |     out_torch = out_torch.data.cpu().numpy()
162 |     out_paddle = out_paddle.cpu().numpy()
163 | 
164 |     print(out_torch.shape, out_paddle.shape)
165 |     print(out_torch[0, 0:20])
166 |     print(out_paddle[0, 0:20])
167 |     assert np.allclose(out_torch, out_paddle, atol=1e-4)
168 | 
169 |     # save weights for paddle model
170 |     model_path = os.path.join('./swin_large_patch4_window12_384.pdparams')
171 |     paddle.save(paddle_model.state_dict(), model_path)
172 | 
173 | 
174 | 
175 | #tmp = np.random.randn(1, 56, 128, 128).astype('float32')
176 | #xp = paddle.to_tensor(tmp)
177 | #xt = torch.Tensor(tmp).to(device)
178 | #xps = paddle.roll(xp, shifts=(-3, -3), axis=(1,2))
179 | #xts = torch.roll(xt, shifts=(-3, -3), dims=(1,2))
180 | #xps = xps.cpu().numpy()
181 | #xts = xts.data.cpu().numpy()
182 | #assert np.allclose(xps, xts, atol=1e-4)
183 | 
184 | if __name__ == "__main__":
185 |     main()
186 | 
--------------------------------------------------------------------------------
/modification.md:
--------------------------------------------------------------------------------
1 | # Code modification notes
2 | 
3 | A note up front: my own understanding of the Swin Transformer code is limited and may not be completely thorough, so if anything in these modifications is wrong, corrections are very welcome! I also hope to take this opportunity to exchange ideas.
4 | 
5 | ## Model Architecture
6 | 
7 | Compared with V1, the three changes proposed by Swin Transformer V2 are concentrated in the `WindowAttention` module of `swin_transformer.py`:
8 | 
9 | * change pre-norm to post-norm
10 | * change the dot-product attention to cosine attention, and add a scaling parameter $\tau$
11 | * replace the directly learned relative position bias with a continuous relative position bias, and change the linear relative coordinates to log-spaced coordinates
12 | 
13 | ## 1. Post-norm
14 | 
15 | Directly reorder the code in `SwinTransformerBlock` of `swin_transformer.py`: move `self.norm1(x)` and `self.norm2(x)` to after the attention and MLP operations, right before the shortcut addition, e.g.:
16 | 
17 | ```python
18 | # x = self.norm2(x) # Swin-T v1, pre-norm
19 | x = self.mlp(x)  # [bs,H*W,C]
20 | x = self.norm2(x)  # Swin-T v2, post-norm
21 | if self.drop_path is not None:
22 |     x = h + self.drop_path(x)
23 | else:
24 |     x = h + x
25 | ```
26 | 
27 | Note that the code additionally adds `self.norm3`, corresponding to this passage of the paper:
28 | 
29 | > For SwinV2-H and SwinV2-G, we further introduce a layer normalization unit on the main branch every 6 layers.
30 | 
31 | For large models, an extra layer norm is applied after every 6 `SwinTransformerBlock`s. It can be enabled through the `EXTRA_NORM` option in the **config**.
32 | 
33 | ## 2. Attention computation
34 | 
35 | ### 2.1 Dot product attention
36 | 
37 | The self-attention of the original Swin Transformer:
38 | $$
39 | \text{Attention}(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V
40 | $$
41 | The dot-product term inside the SoftMax corresponds to the following code in the `WindowAttention` module:
42 | 
43 | ```python
44 | qkv = self.qkv(x).chunk(3, axis=-1)
45 | q, k, v = map(self.transpose_multihead, qkv)
46 | q = q * self.scale  # scale = 1/sqrt(d)
47 | attn = paddle.matmul(q, k, transpose_y=True)
48 | ```
49 | 
50 | ### 2.2 Scaled cosine attention
51 | 
52 | The scaled cosine attention proposed in V2:
53 | $$
54 | \operatorname{Sim}\left(\mathbf{q}_{i}, \mathbf{k}_{j}\right)=\cos \left(\mathbf{q}_{i}, \mathbf{k}_{j}\right) / \tau+B_{i j}
55 | $$
56 | Here $\tau$ is a learnable parameter that differs per layer and per head, clamped to a minimum value of 0.01.
57 | 
58 | The code changes are as follows.
59 | 
60 | First, define $\tau$ in `__init__`:
61 | 
62 | ```python
63 | # Swin-T v2, Scaled cosine attention
64 | self.tau = paddle.create_parameter(
65 |     shape=[num_heads, window_size[0]*window_size[1], window_size[0]*window_size[1]],
66 |     dtype='float32',
67 |     default_initializer=paddle.nn.initializer.Constant(1))
68 | ```
69 | 
70 | Then in `forward`:
71 | 
72 | ```python
73 | qkv = self.qkv(x).chunk(3, axis=-1)  # {list:3}
74 | q, k, v = map(self.transpose_multihead, qkv)  # [bs*num_window=1*64,4,49,32] -> [bs*num_window=1*16,8,49,32] -> [bs*num_window=1*4,16,49,32] -> [bs*num_window=1*1,32,49,32]
75 | 
76 | # Swin-T v2, Scaled cosine attention, Eq.(2)
77 | qk = paddle.matmul(q, k, transpose_y=True)  # [bs*num_window=1*64,num_heads=4,49,49] -> [bs*num_window=1*16,num_heads=8,49,49] -> [bs*num_window=1*4,num_heads=16,49,49] -> [bs*num_window=1*1,num_heads=32,49,49]
78 | q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
79 | k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
80 | attn = qk / paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
81 | attn = attn / paddle.clip(self.tau.unsqueeze(0), min=0.01)
82 | ```
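83 | 
84 | As a quick sanity check of the cosine part of the expression above (a standalone sketch with random tensors; the shapes are illustrative and follow the comments in the snippet), it should match the dot product of L2-normalized `q` and `k`:
85 | 
86 | ```python
87 | import paddle
88 | import paddle.nn.functional as F
89 | 
90 | paddle.seed(0)
91 | q = paddle.randn([2, 4, 49, 32])  # [bs*num_window, num_heads, tokens, head_dim]
92 | k = paddle.randn([2, 4, 49, 32])
93 | 
94 | # cosine attention as implemented above (before dividing by tau and adding B)
95 | qk = paddle.matmul(q, k, transpose_y=True)
96 | q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
97 | k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
98 | attn = qk / paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
99 | 
100 | # reference: dot product of L2-normalized q and k
101 | ref = paddle.matmul(F.normalize(q, axis=-1), F.normalize(k, axis=-1), transpose_y=True)
102 | print(float((attn - ref).abs().max()))  # ~0; every entry lies in [-1, 1]
103 | ```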
52 |
53 | ## 2. Attention Computation
54 |
55 | ### 2.1 Dot product attention
56 |
57 | The original Swin Transformer computes self-attention as:
58 | $$
59 | \text{Attention}(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V
60 | $$
61 | The dot-product term inside the Softmax corresponds to the following code in the `WindowAttention` module:
62 |
63 | ```python
64 | qkv = self.qkv(x).chunk(3, axis=-1)
65 | q, k, v = map(self.transpose_multihead, qkv)
66 | q = q * self.scale # i.e., multiply by 1/sqrt(d)
67 | attn = paddle.matmul(q, k, transpose_y=True)
68 | ```
69 |
70 | ### 2.2 Scaled cosine attention
71 |
72 | V2 instead proposes scaled cosine attention:
73 | $$
74 | \operatorname{Sim}\left(\mathbf{q}_{i}, \mathbf{k}_{j}\right)=\cos \left(\mathbf{q}_{i}, \mathbf{k}_{j}\right) / \tau+B_{i j}
75 | $$
76 | where $\tau$ is a learnable parameter, different for each layer and each head, and clipped to a minimum value of 0.01.
77 |
78 | The code changes are as follows.
79 |
80 | First, define $\tau$ in `__init__`:
81 |
82 | ```python
83 | # Swin-T v2, Scaled cosine attention
84 | self.tau = paddle.create_parameter(
85 |     shape=[num_heads, window_size[0]*window_size[1], window_size[0]*window_size[1]],
86 |     dtype='float32',
87 |     default_initializer=paddle.nn.initializer.Constant(1))
88 | ```
89 |
90 | Then, in `forward`:
91 |
92 | ```python
93 | qkv = self.qkv(x).chunk(3, axis=-1)  # {list:3}
94 | q, k, v = map(self.transpose_multihead, qkv)  # [bs*num_window=1*64,4,49,32] -> [bs*num_window=1*16,8,49,32] -> [bs*num_window=1*4,16,49,32] -> [bs*num_window=1*1,32,49,32]
95 |
96 | # Swin-T v2, Scaled cosine attention, Eq.(2)
97 | qk = paddle.matmul(q, k, transpose_y=True)  # [bs*num_window=1*64,num_heads=4,49,49] -> [bs*num_window=1*16,num_heads=8,49,49] -> [bs*num_window=1*4,num_heads=16,49,49] -> [bs*num_window=1*1,num_heads=32,49,49]
98 | q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
99 | k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
100 | attn = qk/paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
101 | attn = attn/paddle.clip(self.tau.unsqueeze(0), min=0.01)
102 | ```
103 |
104 | ## 3. Log-Spaced CPB Strategy
105 |
106 | ### 3.1 Continuous relative position bias
107 |
108 | When transferring a trained model to a higher resolution and a larger window size, the authors found that directly expanding the relative position bias by bicubic interpolation hurts performance considerably, as shown in the first row of Table 1 of the paper. V2 therefore adopts a **continuous relative position bias**. My reading is that "continuous" means a small network (e.g., two fully connected layers with a ReLU in between) learns the bias for each relative coordinate, and the generalization of this small network is what adapts the bias to larger window sizes (my understanding here is not fully solid yet and needs further study).
109 |
110 | * Original model code:
111 |
112 | First, the `__init__` method of `WindowAttention` defines `relative_position_bias_table` and computes `relative_position_index` from the window size of the current block:
113 |
114 | ```python
115 | self.relative_position_bias_table = paddle.create_parameter(
116 |     shape=[(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads],
117 |     dtype='float32',
118 |     default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
119 |
120 | # relative position index for each token inside window
121 | coords_h = paddle.arange(self.window_size[0])
122 | coords_w = paddle.arange(self.window_size[1])
123 | coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))  # [2, window_h, window_w]
124 | coords_flatten = paddle.flatten(coords, 1)  # [2, window_h * window_w]
125 | # 2, window_h * window_w, window_h * window_w
126 | relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
127 | # window_h*window_w, window_h*window_w, 2
128 | relative_coords = relative_coords.transpose([1, 2, 0])
129 | relative_coords[:, :, 0] += self.window_size[0] - 1
130 | relative_coords[:, :, 1] += self.window_size[1] - 1
131 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
132 | # [window_size * window_size, window_size*window_size]
133 | relative_position_index = relative_coords.sum(-1)
134 | self.register_buffer("relative_position_index", relative_position_index)
135 | ```
136 |
137 | In `forward`, it is used as follows:
138 |
139 | ```python
140 | def get_relative_pos_bias_from_pos_index(self):
141 |     table = self.relative_position_bias_table  # N x num_heads
142 |     # index is a tensor
143 |     index = self.relative_position_index.reshape([-1])  # window_h*window_w * window_h*window_w
144 |     # NOTE: paddle does NOT support indexing Tensor by a Tensor
145 |     relative_position_bias = paddle.index_select(x=table, index=index)
146 |     return relative_position_bias
147 | def forward(......):
148 |     ......
149 |     relative_position_bias = relative_position_bias.transpose([2, 0, 1])
150 |     attn = attn + relative_position_bias.unsqueeze(0)
151 |     ......
152 | ```
153 |
154 | * Corresponding V2 code:
155 |
156 | In `__init__`:
157 |
158 | ```python
159 | ## Swin-T v2, small meta network, Eq.(3)
160 | self.cpb = Mlp_Relu(in_features=2,  # delta x, delta y
161 |                     hidden_features=512,  # TODO: hidden dims
162 |                     out_features=self.num_heads,
163 |                     dropout=dropout)
164 | ```
165 |
166 | One remaining question is what width to use for the hidden layer; I set it to 512 here. The computation of the relative coordinate index is described in the next subsection. Since `Mlp_Relu` itself is not shown above, a sketch of it follows.
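167 |
168 | A minimal sketch of what `Mlp_Relu` can look like, assuming it is simply two `nn.Linear` layers with a ReLU in between that map a 2-d relative coordinate to one bias value per head (the exact definition in `swin_transformer.py` may differ in details):
169 |
170 | ```python
171 | import paddle.nn as nn
172 |
173 | class Mlp_Relu(nn.Layer):
174 |     """Small meta network for the continuous relative position bias, Eq.(3)."""
175 |     def __init__(self, in_features, hidden_features, out_features, dropout=0.):
176 |         super().__init__()
177 |         self.fc1 = nn.Linear(in_features, hidden_features)
178 |         self.act = nn.ReLU()
179 |         self.fc2 = nn.Linear(hidden_features, out_features)
180 |         self.dropout = nn.Dropout(dropout)
181 |
182 |     def forward(self, x):
183 |         # x: [window_h*window_w, window_h*window_w, 2] log-spaced (dx, dy)
184 |         x = self.act(self.fc1(x))
185 |         x = self.dropout(x)
186 |         x = self.fc2(x)  # one bias value per attention head
187 |         return x
188 | ```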
189 |
190 | In `forward`:
191 |
192 | ```python
193 | def get_continuous_relative_position_bias(self):
194 |     # The continuous position bias approach adopts a small meta network on the relative coordinates
195 |     continuous_relative_position_bias = self.cpb(self.log_relative_position_index)
196 |     return continuous_relative_position_bias
197 | def forward(......):
198 |     ......
199 |     ## Swin-T v2
200 |     relative_position_bias = self.get_continuous_relative_position_bias()
201 |     relative_position_bias = relative_position_bias.reshape(
202 |         [self.window_size[0] * self.window_size[1],
203 |          self.window_size[0] * self.window_size[1],
204 |          -1])
205 |
206 |     # nH, window_h*window_w, window_h*window_w
207 |     relative_position_bias = relative_position_bias.transpose([2, 0, 1])
208 |     attn = attn + relative_position_bias.unsqueeze(0)
209 |     ......
210 | ```
211 |
212 | ### 3.2 Log-spaced coordinates
213 |
214 | In addition, the authors note:
215 |
216 | > When transferred across largely varied window sizes, there will be a large portion of relative coordinate range requiring extrapolation.
217 |
218 | With the original linear encoding of the relative position offsets between patches, transferring the model to a larger window size makes the range of coordinates that must be covered grow in large steps, so much of it has to be extrapolated. The authors therefore propose:
219 |
220 | > we propose to use the log-spaced coordinates instead of the original linear-spaced ones
221 |
222 | The log-spaced coordinates correspond to Eq. (4) in the paper:
223 | $$
224 | \begin{aligned}
225 | &\widehat{\Delta x}=\operatorname{sign}(x) \cdot \log (1+|\Delta x|) \\
226 | &\widehat{\Delta y}=\operatorname{sign}(y) \cdot \log (1+|\Delta y|)
227 | \end{aligned}
228 | $$
229 | However, I believe the argument of $\operatorname{sign}(\cdot)$ should be $\Delta x$ and $\Delta y$; the modified code is:
230 |
231 | ```python
232 | # relative position index for each token inside window
233 | coords_h = paddle.arange(self.window_size[0])
234 | coords_w = paddle.arange(self.window_size[1])
235 | coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))  # [2, window_h, window_w]
236 | coords_flatten = paddle.flatten(coords, 1)  # [2, window_h * window_w]
237 | # 2, window_h * window_w, window_h * window_w
238 | relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
239 | # window_h*window_w, window_h*window_w, 2
240 | relative_coords = relative_coords.transpose([1, 2, 0])
241 |
242 | ## Swin-T v2, log-spaced coordinates, Eq.(4)
243 | log_relative_position_index = paddle.multiply(relative_coords.cast(dtype='float32').sign(),
244 |                                               paddle.log((relative_coords.cast(dtype='float32').abs()+1)))
245 | self.register_buffer("log_relative_position_index", log_relative_position_index)
246 | ```
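247 |
248 | To see the benefit numerically, compare the coordinate ranges when transferring from an 8x8 window (one-axis offsets in [-7, 7]) to a 16x16 window (offsets in [-15, 15]). This is a standalone illustration, not repo code:
249 |
250 | ```python
251 | import numpy as np
252 |
253 | # Largest one-axis relative offset: 7 for an 8x8 window, 15 for a 16x16 window.
254 | linear_8, linear_16 = 7.0, 15.0
255 | log_8 = np.log(1 + linear_8)    # ~2.08
256 | log_16 = np.log(1 + linear_16)  # ~2.77
257 |
258 | # Fraction of the target coordinate range lying outside the source range,
259 | # i.e., how far the bias meta network must extrapolate after the transfer.
260 | print((linear_16 - linear_8) / linear_16)  # ~0.53 with linear coordinates
261 | print((log_16 - log_8) / log_16)           # ~0.25 with log-spaced coordinates
262 | ```
263 |
264 | Log spacing roughly halves the fraction of the range that must be extrapolated, which is the motivation the paper gives for Eq. (4).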
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Dataset related classes and methods for ViT training and validation
17 | Cifar10, Cifar100 and ImageNet2012 are supported
18 | """
19 |
20 | import os
21 | import math
22 | from paddle.io import Dataset
23 | from paddle.io import DataLoader
24 | from paddle.io import DistributedBatchSampler
25 | from paddle.vision import transforms
26 | from paddle.vision import datasets
27 | from paddle.vision import image_load
28 | from auto_augment import auto_augment_policy_original
29 | from auto_augment import AutoAugment
30 | from transforms import RandomHorizontalFlip
31 | from random_erasing import RandomErasing
32 |
33 | class ImageNet2012Dataset(Dataset):
34 |     """Build ImageNet2012 dataset
35 |
36 |     This class gets the train/val imagenet dataset, which loads transformed data and labels.
37 |
38 |     Attributes:
39 |         file_folder: path where imagenet images are stored
40 |         transform: preprocessing ops to apply on image
41 |         img_path_list: list of full path of images in whole dataset
42 |         label_list: list of labels of whole dataset
43 |     """
44 |
45 |     def __init__(self, file_folder, mode="train", transform=None):
46 |         """Init ImageNet2012 Dataset with dataset file path, mode(train/val), and transform"""
47 |         super(ImageNet2012Dataset, self).__init__()
48 |         assert mode in ["train", "val"]
49 |         self.file_folder = file_folder
50 |         self.transform = transform
51 |         self.img_path_list = []
52 |         self.label_list = []
53 |         self.mode = mode
54 |
55 |         if mode == "train":
56 |             self.list_file = os.path.join(self.file_folder, "Annotations", "CLS-LOC", "train.txt")
57 |         else:
58 |             self.list_file = os.path.join(self.file_folder, "Annotations", "CLS-LOC", "val.txt")
59 |
60 |         with open(self.list_file, 'r') as infile:
61 |             for line in infile:
62 |                 img_path = line.strip().split()[0]
63 |                 img_label = int(line.strip().split()[1])
64 |                 self.img_path_list.append(os.path.join(self.file_folder, "Data", "CLS-LOC", self.mode, img_path))
65 |                 self.label_list.append(img_label)
66 |         print(f'----- Imagenet2012 image {mode} list len = {len(self.label_list)}')
67 |
68 |     def __len__(self):
69 |         return len(self.label_list)
70 |
71 |     def __getitem__(self, index):
72 |         data = image_load(self.img_path_list[index]).convert('RGB')
73 |         data = self.transform(data)
74 |         label = self.label_list[index]
75 |
76 |         return data, label
77 |
78 |
79 | def get_train_transforms(config):
80 |     """ Get training transforms
81 |
82 |     For training, a RandomResizedCrop is applied first, then auto-augment or color
83 |     jitter, then normalization with ImageNet mean and std. The input pixel values
84 |     must be rescaled to [0, 1.] and the output is converted to a tensor.
85 |
86 |     Args:
87 |         config: configs contains IMAGE_SIZE, see config.py for details
88 |     Returns:
89 |         transforms_train: training transforms
90 |     """
91 |
92 |     aug_op_list = []
93 |     # STEP1: random crop and resize
94 |     aug_op_list.append(
95 |         transforms.RandomResizedCrop((config.DATA.IMAGE_SIZE, config.DATA.IMAGE_SIZE),
96 |                                      scale=(0.05, 1.0), interpolation='bicubic'))
97 |     # STEP2: auto_augment or color jitter
98 |     if config.TRAIN.AUTO_AUGMENT:
99 |         policy = auto_augment_policy_original()
100 |         auto_augment = AutoAugment(policy)
101 |         aug_op_list.append(auto_augment)
102 |     else:
103 |         jitter = (float(config.TRAIN.COLOR_JITTER), ) * 3
104 |         aug_op_list.append(transforms.ColorJitter(*jitter))
105 |     # STEP3: other ops
106 |     aug_op_list.append(transforms.ToTensor())
107 |     aug_op_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
108 |     # STEP4: random erasing
109 |     if config.TRAIN.RANDOM_ERASE_PROB > 0.:
110 |         random_erasing = RandomErasing(prob=config.TRAIN.RANDOM_ERASE_PROB,
111 |                                        mode=config.TRAIN.RANDOM_ERASE_MODE,
112 |                                        max_count=config.TRAIN.RANDOM_ERASE_COUNT,
113 |                                        num_splits=config.TRAIN.RANDOM_ERASE_SPLIT)
114 |         aug_op_list.append(random_erasing)
115 |     # Final: compose transforms and return
116 |     transforms_train = transforms.Compose(aug_op_list)
117 |     return transforms_train
118 |
119 |
120 | def get_val_transforms(config):
121 |     """ Get validation transforms
122 |
123 |     For validation, the image is first resized, then center-cropped to image_size,
124 |     and normalization is applied with ImageNet mean and std.
125 |     The input pixel values must be rescaled to [0, 1.]
126 |     and the output is converted to a tensor.
127 |
128 |     Args:
129 |         config: configs contains IMAGE_SIZE, see config.py for details
130 |     Returns:
131 |         transforms_val: validation transforms
132 |     """
133 |
134 |     scale_size = int(math.floor(config.DATA.IMAGE_SIZE / config.DATA.CROP_PCT))
135 |     transforms_val = transforms.Compose([
136 |         transforms.Resize(scale_size, interpolation='bicubic'),
137 |         transforms.CenterCrop((config.DATA.IMAGE_SIZE, config.DATA.IMAGE_SIZE)),
138 |         transforms.ToTensor(),
139 |         #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
140 |         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
141 |     ])
142 |     return transforms_val
143 |
144 |
145 | def get_dataset(config, mode='train'):
146 |     """ Get dataset from config and mode (train/val)
147 |
148 |     Returns the related dataset object according to configs and mode(train/val)
149 |
150 |     Args:
151 |         config: configs contains dataset related settings. see config.py for details
152 |     Returns:
153 |         dataset: dataset object
154 |     """
155 |
156 |     assert mode in ['train', 'val']
157 |     if config.DATA.DATASET == "cifar10":
158 |         if mode == 'train':
159 |             dataset = datasets.Cifar10(mode=mode, transform=get_train_transforms(config))
160 |         else:
161 |             dataset = datasets.Cifar10(mode=mode, transform=get_val_transforms(config))
162 |     elif config.DATA.DATASET == "cifar100":
163 |         if mode == 'train':
164 |             dataset = datasets.Cifar100(mode=mode, transform=get_train_transforms(config))
165 |         else:
166 |             dataset = datasets.Cifar100(mode=mode, transform=get_val_transforms(config))
167 |     elif config.DATA.DATASET == "imagenet2012":
168 |         if mode == 'train':
169 |             dataset = ImageNet2012Dataset(config.DATA.DATA_PATH,
170 |                                           mode=mode,
171 |                                           transform=get_train_transforms(config))
172 |         else:
173 |             dataset = ImageNet2012Dataset(config.DATA.DATA_PATH,
174 |                                           mode=mode,
175 |                                           transform=get_val_transforms(config))
176 |     else:
177 |         raise NotImplementedError(
178 |             f"[{config.DATA.DATASET}] Only cifar10, cifar100, imagenet2012 are supported now")
179 |     return dataset
180 |
181 |
182 | def get_dataloader(config, dataset, mode='train', multi_process=False):
183 |     """Get dataloader with config, dataset, mode as input, allows multiGPU settings.
184 |
185 |     The multi-GPU loader is implemented with a DistributedBatchSampler.
186 |
187 |     Args:
188 |         config: see config.py for details
189 |         dataset: paddle.io.dataset object
190 |         mode: train/val
191 |         multi_process: if True, use DistributedBatchSampler to support multi-processing
192 |     Returns:
193 |         dataloader: paddle.io.DataLoader object.
194 |     """
195 |
196 |     if mode == 'train':
197 |         batch_size = config.DATA.BATCH_SIZE
198 |     else:
199 |         batch_size = config.DATA.BATCH_SIZE_EVAL
200 |
201 |     if multi_process is True:
202 |         sampler = DistributedBatchSampler(dataset,
203 |                                           batch_size=batch_size,
204 |                                           shuffle=(mode == 'train'))
205 |         dataloader = DataLoader(dataset,
206 |                                 batch_sampler=sampler,
207 |                                 num_workers=config.DATA.NUM_WORKERS)
208 |     else:
209 |         dataloader = DataLoader(dataset,
210 |                                 batch_size=batch_size,
211 |                                 num_workers=config.DATA.NUM_WORKERS,
212 |                                 shuffle=(mode == 'train'))
213 |     return dataloader
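214 |
215 |
216 | # Commented usage sketch (in the spirit of the commented examples elsewhere in
217 | # this repo, e.g. droppath.py); the paths below are assumptions:
218 | #
219 | # With IMAGE_SIZE=224 and CROP_PCT=0.90 (see the config yaml), validation images
220 | # are resized to floor(224 / 0.90) = 248, then center-cropped to 224.
221 | #
222 | # from config import get_config
223 | # config = get_config('./configs/swinv2_base_patch4_window7_224.yaml')
224 | # config.DATA.DATA_PATH = '/dataset/imagenet'  # adjust to your local path
225 | # dataset_val = get_dataset(config, mode='val')
226 | # dataloader_val = get_dataloader(config, dataset_val, mode='val', multi_process=False)
227 | # for images, labels in dataloader_val:
228 | #     print(images.shape)  # e.g. [batch_size, 3, 224, 224]
229 | #     break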
--------------------------------------------------------------------------------
/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """mixup and cutmix for batch data"""
16 | import numpy as np
17 | import paddle
18 |
19 |
20 | def rand_bbox(image_shape, lam, count=None):
21 |     """ CutMix bbox by lam value
22 |     Generate 1 random bbox by value lam. lam is the cut size rate.
23 |     The cut_size is computed by sqrt(1-lam) * image_size.
24 |
25 |     Args:
26 |         image_shape: tuple/list, image height and width
27 |         lam: float, cutmix lambda value
28 |         count: int, number of bbox to generate
29 |     """
30 |     image_h, image_w = image_shape[-2:]
31 |     cut_rate = np.sqrt(1. - lam)
32 |     cut_h = int(cut_rate * image_h)
33 |     cut_w = int(cut_rate * image_w)
34 |
35 |     # get random bbox center
36 |     cy = np.random.randint(0, image_h, size=count)
37 |     cx = np.random.randint(0, image_w, size=count)
38 |
39 |     # get bbox coords
40 |     bbox_x1 = np.clip(cx - cut_w // 2, 0, image_w)
41 |     bbox_y1 = np.clip(cy - cut_h // 2, 0, image_h)
42 |     bbox_x2 = np.clip(cx + cut_w // 2, 0, image_w)
43 |     bbox_y2 = np.clip(cy + cut_h // 2, 0, image_h)
44 |
45 |     # NOTE: in paddle, tensor indexing e.g., a[x1:x2],
46 |     # if x1 == x2, paddle will raise a ValueError,
47 |     # while in pytorch, it will return [] tensor
48 |     return bbox_x1, bbox_y1, bbox_x2, bbox_y2
49 |
50 |
51 | def rand_bbox_minmax(image_shape, minmax, count=None):
52 |     """ CutMix bbox by min and max value
53 |     Generate 1 random bbox by min and max percentage values.
54 |     Minmax is a tuple/list of min and max percentage values
55 |     applied to the image width and height.
56 |
57 |     Args:
58 |         image_shape: tuple/list, image height and width
59 |         minmax: tuple/list, min and max percentage values of image size
60 |         count: int, number of bbox to generate
61 |     """
62 |     assert len(minmax) == 2
63 |     image_h, image_w = image_shape[-2:]
64 |     min_ratio = minmax[0]
65 |     max_ratio = minmax[1]
66 |     cut_h = np.random.randint(int(image_h * min_ratio), int(image_h * max_ratio), size=count)
67 |     cut_w = np.random.randint(int(image_w * min_ratio), int(image_w * max_ratio), size=count)
68 |
69 |     bbox_x1 = np.random.randint(0, image_w - cut_w, size=count)
70 |     bbox_y1 = np.random.randint(0, image_h - cut_h, size=count)
71 |     bbox_x2 = bbox_x1 + cut_w
72 |     bbox_y2 = bbox_y1 + cut_h
73 |
74 |     return bbox_x1, bbox_y1, bbox_x2, bbox_y2
75 |
76 |
77 | def cutmix_generate_bbox_adjust_lam(image_shape, lam, minmax=None, correct_lam=True, count=None):
78 |     """Generate bbox and apply correction for lambda
79 |     If the minmax is None, apply the standard cutmix by lam value,
80 |     If the minmax is set, apply the cutmix by min and max percentage values.
81 |
82 |     Args:
83 |         image_shape: tuple/list, image height and width
84 |         lam: float, cutmix lambda value
85 |         minmax: tuple/list, min and max percentage values of image size
86 |         correct_lam: bool, if True, correct the lam value by the generated bbox
87 |         count: int, number of bbox to generate
88 |     """
89 |     if minmax is not None:
90 |         bbox_x1, bbox_y1, bbox_x2, bbox_y2 = rand_bbox_minmax(image_shape, minmax, count)
91 |     else:
92 |         bbox_x1, bbox_y1, bbox_x2, bbox_y2 = rand_bbox(image_shape, lam, count)
93 |
94 |     if correct_lam or minmax is not None:
95 |         image_h, image_w = image_shape[-2:]
96 |         bbox_area = (bbox_y2 - bbox_y1) * (bbox_x2 - bbox_x1)
97 |         lam = 1. - bbox_area / float(image_h * image_w)
98 |     return (bbox_x1, bbox_y1, bbox_x2, bbox_y2), lam
99 |
100 |
101 | def one_hot(x, num_classes, on_value=1., off_value=0.):
102 |     """ Generate one-hot vector for label smoothing
103 |     Args:
104 |         x: tensor, contains label/class indices
105 |         num_classes: int, num of classes (len of the one-hot vector)
106 |         on_value: float, the vector value at label index, default=1.
107 |         off_value: float, the vector value at non-label indices, default=0.
108 |     Returns:
109 |         one_hot: tensor, tensor with on value at label index and off value
110 |         at non-label indices.
111 |     """
112 |     x = x.reshape_([-1, 1])
113 |     x_smoothed = paddle.full((x.shape[0], num_classes), fill_value=off_value)
114 |     for i in range(x.shape[0]):
115 |         x_smoothed[i, x[i]] = on_value
116 |     return x_smoothed
117 |
118 |
119 | def mixup_one_hot(label, num_classes, lam=1., smoothing=0.):
120 |     """ mixup and label smoothing in batch
121 |     label smoothing is applied first, then
122 |     mixup is applied by mixing the batch and its flip,
123 |     with a mixup rate.
124 |
125 |     Args:
126 |         label: tensor, label tensor with shape [N], contains the class indices
127 |         num_classes: int, num of all classes
128 |         lam: float, mixup rate, default=1.0
129 |         smoothing: float, label smoothing rate
130 |     """
131 |     off_value = smoothing / num_classes
132 |     on_value = 1. - smoothing + off_value
133 |     y1 = one_hot(label, num_classes, on_value, off_value)
134 |     y2 = one_hot(label.flip(axis=[0]), num_classes, on_value, off_value)
135 |     return y2 * (1 - lam) + y1 * lam
136 |
137 |
138 | class Mixup:
139 |     """Mixup class
140 |     Args:
141 |         mixup_alpha: float, mixup alpha for beta distribution, default=1.0,
142 |         cutmix_alpha: float, cutmix alpha for beta distribution, default=0.0,
143 |         cutmix_minmax: list/tuple, min and max value for cutmix ratio, default=None,
144 |         prob: float, probability of applying mixup/cutmix, default=1.0,
145 |         switch_prob: float, prob of switching mixup and cutmix, default=0.5,
146 |         mode: string, mixup mode, now only 'batch' is supported, default='batch',
147 |         correct_lam: bool, if True, apply correction of lam, default=True,
148 |         label_smoothing: float, label smoothing rate, default=0.1,
149 |         num_classes: int, num of classes, default=1000
150 |     """
151 |     def __init__(self,
152 |                  mixup_alpha=1.0,
153 |                  cutmix_alpha=0.0,
154 |                  cutmix_minmax=None,
155 |                  prob=1.0,
156 |                  switch_prob=0.5,
157 |                  mode='batch',
158 |                  correct_lam=True,
159 |                  label_smoothing=0.1,
160 |                  num_classes=1000):
161 |         self.mixup_alpha = mixup_alpha
162 |         self.cutmix_alpha = cutmix_alpha
163 |         self.cutmix_minmax = cutmix_minmax
164 |         if cutmix_minmax is not None:
165 |             assert len(cutmix_minmax) == 2
166 |             self.cutmix_alpha = 1.0
167 |         self.mix_prob = prob
168 |         self.switch_prob = switch_prob
169 |         self.label_smoothing = label_smoothing
170 |         self.num_classes = num_classes
171 |         self.mode = mode
172 |         self.correct_lam = correct_lam
173 |         assert mode == 'batch', 'Now only batch mode is supported!'
174 |
175 |     def __call__(self, x, target):
176 |         assert x.shape[0] % 2 == 0, "Batch size should be even"
177 |         lam = self._mix_batch(x)
178 |         target = mixup_one_hot(target, self.num_classes, lam, self.label_smoothing)
179 |         return x, target
180 |
181 |     def get_params(self):
182 |         """Decide to use cutmix or regular mixup by sampling and
183 |         sample lambda for mixup
184 |         """
185 |         lam = 1.
186 |         use_cutmix = False
187 |         use_mixup = np.random.rand() < self.mix_prob
188 |         if use_mixup:
189 |             if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
190 |                 use_cutmix = np.random.rand() < self.switch_prob
191 |                 alpha = self.cutmix_alpha if use_cutmix else self.mixup_alpha
192 |                 lam_mix = np.random.beta(alpha, alpha)
193 |             elif self.mixup_alpha == 0. and self.cutmix_alpha > 0.:
194 |                 use_cutmix = True
195 |                 lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
196 |             elif self.mixup_alpha > 0. and self.cutmix_alpha == 0.:
197 |                 lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
198 |             else:
199 |                 raise ValueError('mixup_alpha and cutmix_alpha cannot both be 0')
200 |             lam = float(lam_mix)
201 |         return lam, use_cutmix
202 |
203 |     def _mix_batch(self, x):
204 |         """mixup/cutmix by adding batch data and its flipped version"""
205 |         lam, use_cutmix = self.get_params()
206 |         if lam == 1.:
207 |             return lam
208 |         if use_cutmix:
209 |             (bbox_x1, bbox_y1, bbox_x2, bbox_y2), lam = cutmix_generate_bbox_adjust_lam(
210 |                 x.shape,
211 |                 lam,
212 |                 minmax=self.cutmix_minmax,
213 |                 correct_lam=self.correct_lam)
214 |
215 |             # NOTE: in paddle, tensor indexing e.g., a[x1:x2],
216 |             # if x1 == x2, paddle will raise a ValueError,
217 |             # but in pytorch, it will return [] tensor without errors
218 |             if int(bbox_x1) != int(bbox_x2) and int(bbox_y1) != int(bbox_y2):
219 |                 x[:, :, int(bbox_x1): int(bbox_x2), int(bbox_y1): int(bbox_y2)] = x.flip(axis=[0])[
220 |                     :, :, int(bbox_x1): int(bbox_x2), int(bbox_y1): int(bbox_y2)]
221 |         else:
222 |             x_flipped = x.flip(axis=[0])
223 |             x_flipped = x_flipped * (1 - lam)
224 |             x.set_value(x * (lam) + x_flipped)
225 |         return lam
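226 |
227 |
228 | # Commented usage sketch (assumed values; follows the repo's commented-example style):
229 | #
230 | # import paddle
231 | # mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)
232 | # images = paddle.rand([8, 3, 224, 224])   # batch size must be even
233 | # labels = paddle.randint(0, 1000, [8])
234 | # images, soft_targets = mixup_fn(images, labels)
235 | # print(soft_targets.shape)                # [8, 1000]: smoothed, mixed targets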
--------------------------------------------------------------------------------
/auto_augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 |
14 | """Auto Augmentation"""
15 |
16 | import random
17 | import numpy as np
18 | from PIL import Image, ImageEnhance, ImageOps
19 |
20 |
21 | def auto_augment_policy_original():
22 |     """ImageNet auto augment policy"""
23 |     policy = [
24 |         [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
25 |         [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
26 |         [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
27 |         [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
28 |         [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
29 |         [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
30 |         [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
31 |         [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
32 |         [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
33 |         [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
34 |         [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
35 |         [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
36 |         [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
37 |         [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
38 |         [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
39 |         [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
40 |         [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
41 |         [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
42 |         [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
43 |         [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
44 |         [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
45 |         [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
46 |         [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
47 |         [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
48 |         [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
49 |     ]
50 |     policy = [[SubPolicy(*args) for args in subpolicy] for subpolicy in policy]
51 |     return policy
52 |
53 |
54 | class AutoAugment():
55 |     """Auto Augment
56 |     Randomly choose a tuple of augment ops from a list of policies,
57 |     then apply the tuple of augment ops to the input image
58 |     """
59 |     def __init__(self, policy):
60 |         self.policy = policy
61 |
62 |     def __call__(self, image, policy_idx=None):
63 |         if policy_idx is None:
64 |             policy_idx = random.randint(0, len(self.policy)-1)
65 |
66 |         sub_policy = self.policy[policy_idx]
67 |         for op in sub_policy:
68 |             image = op(image)
69 |         return image
70 |
71 |
72 | class SubPolicy:
73 |     """Subpolicy
74 |     Read augment name and magnitude, apply augment with probability
75 |     Args:
76 |         op_name: str, augment operation name
77 |         prob: float, if prob > random prob, apply augment
78 |         magnitude_idx: int, index of magnitude in preset magnitude ranges
79 |     """
80 |     def __init__(self, op_name, prob, magnitude_idx):
81 |         # ranges of operations' magnitude
82 |         ranges = {
83 |             'ShearX': np.linspace(0, 0.3, 10),  # [-0.3, 0.3] (by random negative)
84 |             'ShearY': np.linspace(0, 0.3, 10),  # [-0.3, 0.3] (by random negative)
85 |             'TranslateX': np.linspace(0, 150 / 331, 10),  # [-0.45, 0.45] (by random negative)
86 |             'TranslateY': np.linspace(0, 150 / 331, 10),  # [-0.45, 0.45] (by random negative)
87 |             'Rotate': np.linspace(0, 30, 10),  # [-30, 30] (by random negative)
88 |             'Color': np.linspace(0, 0.9, 10),  # [-0.9, 0.9] (by random negative)
89 |             'Posterize': np.round(np.linspace(8, 4, 10), 0).astype(int),  # [4, 8]
90 |             'Solarize': np.linspace(256, 0, 10),  # [0, 256]
91 |             'Contrast': np.linspace(0, 0.9, 10),  # [-0.9, 0.9] (by random negative)
92 |             'Sharpness': np.linspace(0, 0.9, 10),  # [-0.9, 0.9] (by random negative)
93 |             'Brightness': np.linspace(0, 0.9, 10),  # [-0.9, 0.9] (by random negative)
94 |             'AutoContrast': [0] * 10,  # no range
95 |             'Equalize': [0] * 10,  # no range
96 |             'Invert': [0] * 10,  # no range
97 |         }
98 |
99 |         # augmentation operations
100 |         # Lambda is not pickleable for DDP
101 |
#image_ops = { 102 | # 'ShearX': lambda image, magnitude: shear_x(image, magnitude), 103 | # 'ShearY': lambda image, magnitude: shear_y(image, magnitude), 104 | # 'TranslateX': lambda image, magnitude: translate_x(image, magnitude), 105 | # 'TranslateY': lambda image, magnitude: translate_y(image, magnitude), 106 | # 'Rotate': lambda image, magnitude: rotate(image, magnitude), 107 | # 'AutoContrast': lambda image, magnitude: auto_contrast(image, magnitude), 108 | # 'Invert': lambda image, magnitude: invert(image, magnitude), 109 | # 'Equalize': lambda image, magnitude: equalize(image, magnitude), 110 | # 'Solarize': lambda image, magnitude: solarize(image, magnitude), 111 | # 'Posterize': lambda image, magnitude: posterize(image, magnitude), 112 | # 'Contrast': lambda image, magnitude: contrast(image, magnitude), 113 | # 'Color': lambda image, magnitude: color(image, magnitude), 114 | # 'Brightness': lambda image, magnitude: brightness(image, magnitude), 115 | # 'Sharpness': lambda image, magnitude: sharpness(image, magnitude), 116 | #} 117 | image_ops = { 118 | 'ShearX': shear_x, 119 | 'ShearY': shear_y, 120 | 'TranslateX': translate_x_relative, 121 | 'TranslateY': translate_y_relative, 122 | 'Rotate': rotate, 123 | 'AutoContrast': auto_contrast, 124 | 'Invert': invert, 125 | 'Equalize': equalize, 126 | 'Solarize': solarize, 127 | 'Posterize': posterize, 128 | 'Contrast': contrast, 129 | 'Color': color, 130 | 'Brightness': brightness, 131 | 'Sharpness': sharpness, 132 | } 133 | 134 | self.prob = prob 135 | self.magnitude = ranges[op_name][magnitude_idx] 136 | self.op = image_ops[op_name] 137 | 138 | def __call__(self, image): 139 | if self.prob > random.random(): 140 | image = self.op(image, self.magnitude) 141 | return image 142 | 143 | 144 | # PIL Image transforms 145 | # https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.transform 146 | def shear_x(image, magnitude, fillcolor=(128, 128, 128)): 147 | factor = magnitude * random.choice([-1, 1]) # random negative 148 | return image.transform(image.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), fillcolor=fillcolor) 149 | 150 | 151 | def shear_y(image, magnitude, fillcolor=(128, 128, 128)): 152 | factor = magnitude * random.choice([-1, 1]) # random negative 153 | return image.transform(image.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), fillcolor=fillcolor) 154 | 155 | 156 | def translate_x_relative(image, magnitude, fillcolor=(128, 128, 128)): 157 | pixels = magnitude * image.size[0] 158 | pixels = pixels * random.choice([-1, 1]) # random negative 159 | return image.transform(image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), fillcolor=fillcolor) 160 | 161 | 162 | def translate_y_relative(image, magnitude, fillcolor=(128, 128, 128)): 163 | pixels = magnitude * image.size[0] 164 | pixels = pixels * random.choice([-1, 1]) # random negative 165 | return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), fillcolor=fillcolor) 166 | 167 | 168 | def translate_x_absolute(image, magnitude, fillcolor=(128, 128, 128)): 169 | magnitude = magnitude * random.choice([-1, 1]) # random negative 170 | return image.transform(image.size, Image.AFFINE, (1, 0, magnitude, 0, 1, 0), fillcolor=fillcolor) 171 | 172 | 173 | def translate_y_absolute(image, magnitude, fillcolor=(128, 128, 128)): 174 | magnitude = magnitude * random.choice([-1, 1]) # random negative 175 | return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude), fillcolor=fillcolor) 176 | 177 | 178 | def rotate(image, magnitude): 179 | rot = 
image.convert("RGBA").rotate(magnitude) 180 | return Image.composite(rot, 181 | Image.new('RGBA', rot.size, (128, ) * 4), 182 | rot).convert(image.mode) 183 | 184 | 185 | def auto_contrast(image, magnitude=None): 186 | return ImageOps.autocontrast(image) 187 | 188 | 189 | def invert(image, magnitude=None): 190 | return ImageOps.invert(image) 191 | 192 | 193 | def equalize(image, magnitude=None): 194 | return ImageOps.equalize(image) 195 | 196 | 197 | def solarize(image, magnitude): 198 | return ImageOps.solarize(image, magnitude) 199 | 200 | 201 | def posterize(image, magnitude): 202 | return ImageOps.posterize(image, magnitude) 203 | 204 | 205 | def contrast(image, magnitude): 206 | magnitude = magnitude * random.choice([-1, 1]) # random negative 207 | return ImageEnhance.Contrast(image).enhance(1 + magnitude) 208 | 209 | 210 | def color(image, magnitude): 211 | magnitude = magnitude * random.choice([-1, 1]) # random negative 212 | return ImageEnhance.Color(image).enhance(1 + magnitude) 213 | 214 | 215 | def brightness(image, magnitude): 216 | magnitude = magnitude * random.choice([-1, 1]) # random negative 217 | return ImageEnhance.Brightness(image).enhance(1 + magnitude) 218 | 219 | 220 | def sharpness(image, magnitude): 221 | magnitude = magnitude * random.choice([-1, 1]) # random negative 222 | return ImageEnhance.Sharpness(image).enhance(1 + magnitude) 223 | 224 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /main_single_gpu.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Swin training/validation using single GPU """
17 |
18 | import sys
19 | import os
20 | import time
21 | import logging
22 | import argparse
23 | import random
24 | import numpy as np
25 | import warnings
26 | warnings.filterwarnings('ignore')
27 |
28 | import paddle
29 | import paddle.nn as nn
30 | import paddle.nn.functional as F
31 | from datasets import get_dataloader
32 | from datasets import get_dataset
33 |
34 | from utils import AverageMeter
35 | from utils import WarmupCosineScheduler
36 | from utils import get_exclude_from_weight_decay_fn
37 | from config import get_config
38 | from config import update_config
39 | from mixup import Mixup
40 | from losses import LabelSmoothingCrossEntropyLoss
41 | from losses import SoftTargetCrossEntropyLoss
42 | from losses import DistillationLoss
43 | from swin_transformer import build_swin as build_model
44 |
45 |
46 |
47 | def get_arguments():
48 |     """return arguments, these will overwrite the config after loading the yaml file"""
49 |     parser = argparse.ArgumentParser('Swin')
50 |     parser.add_argument('-cfg', type=str, default='/home/ubuntu13/lsz/code/S-T-V2/PaddleViT/image_classification/SwinTransformerV2/configs/swinv2_base_patch4_window7_224.yaml')
51 |     parser.add_argument('-dataset', type=str, default='imagenet2012')
52 |     parser.add_argument('-batch_size', type=int, default=48)
53 |     parser.add_argument('-image_size', type=int, default=None)
54 |     parser.add_argument('-data_path', type=str, default='/home/ubuntu13/lsz/dataset/ILSVRC')
55 |     parser.add_argument('-ngpus', type=int, default=None)
56 |     parser.add_argument('-pretrained', type=str, default=None)
57 |     parser.add_argument('-resume', type=str, default=None)
58 |     parser.add_argument('-last_epoch', type=int, default=None)
59 |     parser.add_argument('-eval', action='store_true')
60 |     parser.add_argument('-amp', action='store_true', default=True)
61 |     arguments = parser.parse_args()
62 |     return arguments
63 |
64 |
65 | def get_logger(filename, logger_name=None):
66 |     """set logging file and format
67 |     Args:
68 |         filename: str, full path of the logger file to write
69 |         logger_name: str, the logger name, e.g., 'master_logger', 'local_logger'
70 |     Return:
71 |         logger: python logger
72 |     """
73 |     log_format = "%(asctime)s %(message)s"
74 |     logging.basicConfig(stream=sys.stdout, level=logging.INFO,
75 |                         format=log_format, datefmt="%m%d %I:%M:%S %p")
76 |     # a different name is needed when creating multiple loggers in one process
77 |     logger = logging.getLogger(logger_name)
78 |     fh = logging.FileHandler(os.path.join(filename))
79 |     fh.setFormatter(logging.Formatter(log_format))
80 |     logger.addHandler(fh)
81 |     return logger
82 |
83 |
84 | def train(dataloader,
85 |           model,
86 |           criterion,
87 |           optimizer,
88 |           epoch,
89 |           total_epochs,
90 |           total_batch,
91 |           debug_steps=100,
92 |           accum_iter=1,
93 |           mixup_fn=None,
94 |           amp=False,
95 |           logger=None):
96 |     """Training for one epoch
97 |     Args:
98 |         dataloader: paddle.io.DataLoader, dataloader instance
99 |         model: nn.Layer, a ViT model
100 |         criterion: nn.criterion
101 |         epoch: int, current epoch
102 |         total_epochs: int, total num of epochs
103 |         total_batch: int, total num of batches for one epoch
104 |         debug_steps: int, num of iters to log info, default: 100
105 |         accum_iter: int, num of iters for accumulating gradients, default: 1
106 |         mixup_fn: Mixup, mixup instance, default: None
107 |         amp: bool, if True, use mixed precision training, default: False
108 |         logger: logger for logging, default: None
109 |     Returns:
110 |         train_loss_meter.avg: float, average loss on current process/gpu
111 |         train_acc_meter.avg: float, average top1 accuracy on current process/gpu
112 |         train_time: float, training time
113 |     """
114 |     model.train()
115 |     train_loss_meter = AverageMeter()
116 |     train_acc_meter = AverageMeter()
117 |
118 |     if amp is True:
119 |         scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
120 |     time_st = time.time()
121 |
122 |     for batch_id, data in enumerate(dataloader):
123 |         image = data[0]
124 |         label = data[1]
125 |         label_orig = label.clone()
126 |
127 |         if mixup_fn is not None:
128 |             image, label = mixup_fn(image, label_orig)
129 |
130 |         if amp is True: # mixed precision training
131 |             with paddle.amp.auto_cast():
132 |                 output = model(image)
133 |                 loss = criterion(output, label)
134 |             scaled = scaler.scale(loss)
135 |             scaled.backward()
136 |             if ((batch_id + 1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
137 |                 scaler.minimize(optimizer, scaled)
138 |                 optimizer.clear_grad()
139 |         else: # full precision training
140 |             output = model(image)
141 |             loss = criterion(output, label)
142 |             #NOTE: division may be needed depending on the loss function
143 |             # Here no division is needed:
144 |             # default 'reduction' param in nn.CrossEntropyLoss is set to 'mean'
145 |             #loss = loss / accum_iter
146 |             loss.backward()
147 |
148 |             if ((batch_id + 1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
149 |                 optimizer.step()
150 |                 optimizer.clear_grad()
151 |
152 |         pred = F.softmax(output)
153 |         if mixup_fn:
154 |             acc = paddle.metric.accuracy(pred, label_orig)
155 |         else:
156 |             acc = paddle.metric.accuracy(pred, label_orig.unsqueeze(1))
157 |
158 |         batch_size = image.shape[0]
159 |         train_loss_meter.update(loss.numpy()[0], batch_size)
160 |         train_acc_meter.update(acc.numpy()[0], batch_size)
161 |
162 |         if logger and batch_id % debug_steps == 0:
163 |             logger.info(
164 |                 f"Epoch[{epoch:03d}/{total_epochs:03d}], " +
165 |                 f"Step[{batch_id:04d}/{total_batch:04d}], " +
166 |                 f"Avg Loss: {train_loss_meter.avg:.4f}, " +
167 |                 f"Avg Acc: {train_acc_meter.avg:.4f}")
168 |
169 |     train_time = time.time() - time_st
170 |     return train_loss_meter.avg, train_acc_meter.avg, train_time
171 |
172 |
173 | def validate(dataloader, model, criterion, total_batch, debug_steps=100, logger=None):
174 |     """Validation for whole dataset
175 |     Args:
176 |         dataloader: paddle.io.DataLoader, dataloader instance
177 |         model: nn.Layer, a ViT model
178 |         criterion: nn.criterion
179 |         total_batch: int, total num of batches for one epoch
180 |         debug_steps: int, num of iters to log info, default: 100
181 |         logger: logger for logging, default: None
182 |     Returns:
183 |         val_loss_meter.avg: float, average loss on current process/gpu
184 |         val_acc1_meter.avg: float, average top1 accuracy on current process/gpu
185 |         val_acc5_meter.avg: float, average top5 accuracy on current process/gpu
186 |         val_time: float, validation time
187 |     """
188 |     model.eval()
189 |     val_loss_meter = AverageMeter()
190 |     val_acc1_meter = AverageMeter()
191 |     val_acc5_meter = AverageMeter()
192 |     time_st = time.time()
193 |
194 |     with
paddle.no_grad(): 195 | for batch_id, data in enumerate(dataloader): 196 | image = data[0] 197 | label = data[1] 198 | 199 | output = model(image) 200 | loss = criterion(output, label) 201 | 202 | pred = F.softmax(output) 203 | acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1)) 204 | acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5) 205 | 206 | batch_size = image.shape[0] 207 | val_loss_meter.update(loss.numpy()[0], batch_size) 208 | val_acc1_meter.update(acc1.numpy()[0], batch_size) 209 | val_acc5_meter.update(acc5.numpy()[0], batch_size) 210 | 211 | if logger and batch_id % debug_steps == 0: 212 | logger.info( 213 | f"Val Step[{batch_id:04d}/{total_batch:04d}], " + 214 | f"Avg Loss: {val_loss_meter.avg:.4f}, " + 215 | f"Avg Acc@1: {val_acc1_meter.avg:.4f}, " + 216 | f"Avg Acc@5: {val_acc5_meter.avg:.4f}") 217 | 218 | val_time = time.time() - time_st 219 | return val_loss_meter.avg, val_acc1_meter.avg, val_acc5_meter.avg, val_time 220 | 221 | 222 | def main(): 223 | # STEP 0: Preparation 224 | # config is updated by: (1) config.py, (2) yaml file, (3) arguments 225 | arguments = get_arguments() 226 | config = get_config() 227 | config = update_config(config, arguments) 228 | # set output folder 229 | if not config.EVAL: 230 | config.SAVE = '{}/train-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S')) 231 | else: 232 | config.SAVE = '{}/eval-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S')) 233 | if not os.path.exists(config.SAVE): 234 | os.makedirs(config.SAVE, exist_ok=True) 235 | last_epoch = config.TRAIN.LAST_EPOCH 236 | seed = config.SEED 237 | paddle.seed(seed) 238 | np.random.seed(seed) 239 | random.seed(seed) 240 | logger = get_logger(filename=os.path.join(config.SAVE, 'log.txt')) 241 | logger.info(f'\n{config}') 242 | 243 | # STEP 1: Create model 244 | model = build_model(config) 245 | 246 | # STEP 2: Create train and val dataloader 247 | dataset_train = get_dataset(config, mode='train') 248 | dataset_val = get_dataset(config, mode='val') 249 | dataloader_train = get_dataloader(config, dataset_train, 'train', False) 250 | dataloader_val = get_dataloader(config, dataset_val, 'val', False) 251 | 252 | # STEP 3: Define Mixup function 253 | mixup_fn = None 254 | if config.TRAIN.MIXUP_PROB > 0 or config.TRAIN.CUTMIX_ALPHA > 0 or config.TRAIN.CUTMIX_MINMAX is not None: 255 | mixup_fn = Mixup(mixup_alpha=config.TRAIN.MIXUP_ALPHA, 256 | cutmix_alpha=config.TRAIN.CUTMIX_ALPHA, 257 | cutmix_minmax=config.TRAIN.CUTMIX_MINMAX, 258 | prob=config.TRAIN.MIXUP_PROB, 259 | switch_prob=config.TRAIN.MIXUP_SWITCH_PROB, 260 | mode=config.TRAIN.MIXUP_MODE, 261 | label_smoothing=config.TRAIN.SMOOTHING) 262 | 263 | # STEP 4: Define criterion 264 | if config.TRAIN.MIXUP_PROB > 0.: 265 | criterion = SoftTargetCrossEntropyLoss() 266 | elif config.TRAIN.SMOOTHING: 267 | criterion = LabelSmoothingCrossEntropyLoss() 268 | else: 269 | criterion = nn.CrossEntropyLoss() 270 | # only use cross entropy for val 271 | criterion_val = nn.CrossEntropyLoss() 272 | 273 | # STEP 5: Define optimizer and lr_scheduler 274 | # set lr according to batch size and world size (hacked from official code) 275 | linear_scaled_lr = (config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE) / 512.0 276 | linear_scaled_warmup_start_lr = (config.TRAIN.WARMUP_START_LR * config.DATA.BATCH_SIZE) / 512.0 277 | linear_scaled_end_lr = (config.TRAIN.END_LR * config.DATA.BATCH_SIZE) / 512.0 278 | 279 | if config.TRAIN.ACCUM_ITER > 1: 280 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUM_ITER 281 | 
        linear_scaled_warmup_start_lr = linear_scaled_warmup_start_lr * config.TRAIN.ACCUM_ITER
282 |         linear_scaled_end_lr = linear_scaled_end_lr * config.TRAIN.ACCUM_ITER
283 |
284 |     config.TRAIN.BASE_LR = linear_scaled_lr
285 |     config.TRAIN.WARMUP_START_LR = linear_scaled_warmup_start_lr
286 |     config.TRAIN.END_LR = linear_scaled_end_lr
287 |
288 |     scheduler = None
289 |     if config.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
290 |         scheduler = WarmupCosineScheduler(learning_rate=config.TRAIN.BASE_LR,
291 |                                           warmup_start_lr=config.TRAIN.WARMUP_START_LR,
292 |                                           start_lr=config.TRAIN.BASE_LR,
293 |                                           end_lr=config.TRAIN.END_LR,
294 |                                           warmup_epochs=config.TRAIN.WARMUP_EPOCHS,
295 |                                           total_epochs=config.TRAIN.NUM_EPOCHS,
296 |                                           last_epoch=config.TRAIN.LAST_EPOCH,
297 |                                           )
298 |     elif config.TRAIN.LR_SCHEDULER.NAME == "cosine":
299 |         scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.TRAIN.BASE_LR,
300 |                                                              T_max=config.TRAIN.NUM_EPOCHS,
301 |                                                              last_epoch=last_epoch)
302 |     elif config.TRAIN.LR_SCHEDULER.NAME == "multi-step":
303 |         milestones = [int(v.strip()) for v in config.TRAIN.LR_SCHEDULER.MILESTONES.split(",")]
304 |         scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=config.TRAIN.BASE_LR,
305 |                                                        milestones=milestones,
306 |                                                        gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
307 |                                                        last_epoch=last_epoch)
308 |     else:
309 |         logger.fatal(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER.NAME}.")
310 |         raise NotImplementedError(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER.NAME}.")
311 |
312 |     if config.TRAIN.OPTIMIZER.NAME == "SGD":
313 |         if config.TRAIN.GRAD_CLIP:
314 |             clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
315 |         else:
316 |             clip = None
317 |         optimizer = paddle.optimizer.Momentum(
318 |             parameters=model.parameters(),
319 |             learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
320 |             weight_decay=config.TRAIN.WEIGHT_DECAY,
321 |             momentum=config.TRAIN.OPTIMIZER.MOMENTUM,
322 |             grad_clip=clip)
323 |     elif config.TRAIN.OPTIMIZER.NAME == "AdamW":
324 |         if config.TRAIN.GRAD_CLIP:
325 |             clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
326 |         else:
327 |             clip = None
328 |         optimizer = paddle.optimizer.AdamW(
329 |             parameters=model.parameters(),
330 |             learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
331 |             beta1=config.TRAIN.OPTIMIZER.BETAS[0],
332 |             beta2=config.TRAIN.OPTIMIZER.BETAS[1],
333 |             weight_decay=config.TRAIN.WEIGHT_DECAY,
334 |             epsilon=config.TRAIN.OPTIMIZER.EPS,
335 |             grad_clip=clip,
336 |             apply_decay_param_fun=get_exclude_from_weight_decay_fn([
337 |                 'absolute_pos_embed', 'relative_position_bias_table']),
338 |             )
339 |     else:
340 |         logger.fatal(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
341 |         raise NotImplementedError(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
342 |
343 |     # STEP 6: Load pretrained model or load resume model and optimizer states
344 |     if config.MODEL.PRETRAINED:
345 |         if (config.MODEL.PRETRAINED).endswith('.pdparams'):
346 |             raise ValueError(f'{config.MODEL.PRETRAINED} should not contain .pdparams')
347 |         assert os.path.isfile(config.MODEL.PRETRAINED + '.pdparams') is True
348 |         model_state = paddle.load(config.MODEL.PRETRAINED+'.pdparams')
349 |         model.set_dict(model_state)
350 |         logger.info(f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}")
351 |
352 |     if config.MODEL.RESUME:
353 |         assert os.path.isfile(config.MODEL.RESUME+'.pdparams') is True
354 |         assert os.path.isfile(config.MODEL.RESUME+'.pdopt') is True
355 |         model_state = paddle.load(config.MODEL.RESUME+'.pdparams')
356 |         model.set_dict(model_state)
357 |         opt_state = paddle.load(config.MODEL.RESUME+'.pdopt')
358 |         optimizer.set_state_dict(opt_state)
359 |         logger.info(
360 |             f"----- Resume: Load model and optimizer from {config.MODEL.RESUME}")
361 |
362 |     # STEP 7: Validation (eval mode)
363 |     if config.EVAL:
364 |         logger.info('----- Start Validating')
365 |         val_loss, val_acc1, val_acc5, val_time = validate(
366 |             dataloader=dataloader_val,
367 |             model=model,
368 |             criterion=criterion_val,
369 |             total_batch=len(dataloader_val),
370 |             debug_steps=config.REPORT_FREQ,
371 |             logger=logger)
372 |         logger.info(f"Validation Loss: {val_loss:.4f}, " +
373 |                     f"Validation Acc@1: {val_acc1:.4f}, " +
374 |                     f"Validation Acc@5: {val_acc5:.4f}, " +
375 |                     f"time: {val_time:.2f}")
376 |         return
377 |
378 |     # STEP 8: Start training and validation (train mode)
379 |     logger.info(f"Start training from epoch {last_epoch+1}.")
380 |     for epoch in range(last_epoch+1, config.TRAIN.NUM_EPOCHS+1):
381 |         # train
382 |         logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.6f}")
383 |         train_loss, train_acc, train_time = train(dataloader=dataloader_train,
384 |                                                   model=model,
385 |                                                   criterion=criterion,
386 |                                                   optimizer=optimizer,
387 |                                                   epoch=epoch,
388 |                                                   total_epochs=config.TRAIN.NUM_EPOCHS,
389 |                                                   total_batch=len(dataloader_train),
390 |                                                   debug_steps=config.REPORT_FREQ,
391 |                                                   accum_iter=config.TRAIN.ACCUM_ITER,
392 |                                                   mixup_fn=mixup_fn,
393 |                                                   amp=config.AMP,
394 |                                                   logger=logger)
395 |         scheduler.step()
396 |         logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
397 |                     f"Train Loss: {train_loss:.4f}, " +
398 |                     f"Train Acc: {train_acc:.4f}, " +
399 |                     f"time: {train_time:.2f}")
400 |         # validation
401 |         if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
402 |             logger.info(f'----- Validation after Epoch: {epoch}')
403 |             val_loss, val_acc1, val_acc5, val_time = validate(
404 |                 dataloader=dataloader_val,
405 |                 model=model,
406 |                 criterion=criterion_val,
407 |                 total_batch=len(dataloader_val),
408 |                 debug_steps=config.REPORT_FREQ,
409 |                 logger=logger)
410 |             logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
411 |                         f"Validation Loss: {val_loss:.4f}, " +
412 |                         f"Validation Acc@1: {val_acc1:.4f}, " +
413 |                         f"Validation Acc@5: {val_acc5:.4f}, " +
414 |                         f"time: {val_time:.2f}")
415 |         # model save
416 |         if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
417 |             model_path = os.path.join(
418 |                 config.SAVE, f"{config.MODEL.TYPE}-Epoch-{epoch}-Loss-{train_loss}")
419 |             paddle.save(model.state_dict(), model_path + '.pdparams')
420 |             paddle.save(optimizer.state_dict(), model_path + '.pdopt')
421 |             logger.info(f"----- Save model: {model_path}.pdparams")
422 |             logger.info(f"----- Save optim: {model_path}.pdopt")
423 |
424 |
425 | if __name__ == "__main__":
426 |     main()
427 |
--------------------------------------------------------------------------------
/main_multi_gpu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """Swin training/validation using multiple GPU """
16 | 
17 | import sys
18 | import os
19 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
20 | import time
21 | import logging
22 | import argparse
23 | import random
24 | import numpy as np
25 | import paddle
26 | import paddle.nn as nn
27 | import paddle.nn.functional as F
28 | import paddle.distributed as dist
29 | from datasets import get_dataloader
30 | from datasets import get_dataset
31 | from utils import AverageMeter
32 | from utils import WarmupCosineScheduler
33 | from utils import get_exclude_from_weight_decay_fn
34 | from config import get_config
35 | from config import update_config
36 | from mixup import Mixup
37 | from losses import LabelSmoothingCrossEntropyLoss
38 | from losses import SoftTargetCrossEntropyLoss
39 | from losses import DistillationLoss
40 | from swin_transformer import build_swin as build_model
41 | 
42 | 
43 | def get_arguments():
44 |     """return arguments, this will overwrite the config after loading yaml file"""
45 |     parser = argparse.ArgumentParser('Swin')
46 |     parser.add_argument('-cfg', type=str, default='/home/ubuntu13/lsz/code/S-T-V2/PaddleViT/image_classification/SwinTransformerV2/configs/swinv2_base_patch4_window7_224.yaml')
47 |     parser.add_argument('-dataset', type=str, default='imagenet2012')
48 |     parser.add_argument('-batch_size', type=int, default=100)
49 |     parser.add_argument('-image_size', type=int, default=None)
50 |     parser.add_argument('-data_path', type=str, default='/home/ubuntu13/lsz/dataset/ILSVRC')
51 |     parser.add_argument('-ngpus', type=int, default=None)
52 |     parser.add_argument('-pretrained', type=str, default=None)
53 |     parser.add_argument('-resume', type=str, default=None)
54 |     parser.add_argument('-last_epoch', type=int, default=None)
55 |     parser.add_argument('-eval', action='store_true')
56 |     parser.add_argument('-amp', action='store_true', default=True)
57 |     arguments = parser.parse_args()
58 |     return arguments
59 | 
60 | 
61 | def get_logger(filename, logger_name=None):
62 |     """set logging file and format
63 |     Args:
64 |         filename: str, full path of the logger file to write
65 |         logger_name: str, the logger name, e.g., 'master_logger', 'local_logger'
66 |     Return:
67 |         logger: python logger
68 |     """
69 |     log_format = "%(asctime)s %(message)s"
70 |     logging.basicConfig(stream=sys.stdout, level=logging.INFO,
71 |                         format=log_format, datefmt="%m%d %I:%M:%S %p")
72 |     # different name is needed when creating multiple logger in one process
73 |     logger = logging.getLogger(logger_name)
74 |     fh = logging.FileHandler(os.path.join(filename))
75 |     fh.setFormatter(logging.Formatter(log_format))
76 |     logger.addHandler(fh)
77 |     return logger
78 | 
79 | 
80 | def train(dataloader,
81 |           model,
82 |           criterion,
83 |           optimizer,
84 |           epoch,
85 |           total_epochs,
86 |           total_batch,
87 |           debug_steps=100,
88 |           accum_iter=1,
89 |           mixup_fn=None,
90 |           amp=False,
91 |           local_logger=None,
92 |           master_logger=None):
93 |     """Training for one epoch
94 |     Args:
95 |         dataloader: paddle.io.DataLoader, dataloader instance
96 |         model: nn.Layer, a ViT model
97 |         criterion: nn.Layer, loss function
98 |         epoch: int, current epoch
99 |         total_epochs: int, total num of epochs
100 |         total_batch: int, total num of batches for one epoch
101 |         debug_steps: int, num of iters to log info, default: 100
102 |         accum_iter: int, num of iters for accumulating gradients, default: 1
103 |         mixup_fn: Mixup, mixup instance, default: None
104 |         amp: bool, if True, use mix precision training, default: False
105 |         local_logger: logger for local process/gpu, default: None
106 |         master_logger: logger for main process, default: None
107 |     Returns:
108 |         train_loss_meter.avg: float, average loss on current process/gpu
109 |         train_acc_meter.avg: float, average top1 accuracy on current process/gpu
110 |         master_train_loss_meter.avg: float, average loss on all processes/gpus
111 |         master_train_acc_meter.avg: float, average top1 accuracy on all processes/gpus
112 |         train_time: float, training time
113 |     """
114 |     model.train()
115 |     train_loss_meter = AverageMeter()
116 |     train_acc_meter = AverageMeter()
117 |     master_train_loss_meter = AverageMeter()
118 |     master_train_acc_meter = AverageMeter()
119 | 
120 |     if amp is True:
121 |         scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
122 |     time_st = time.time()
123 | 
124 |     for batch_id, data in enumerate(dataloader):
125 |         image = data[0]
126 |         label = data[1]
127 |         label_orig = label.clone()
128 | 
129 |         if mixup_fn is not None:
130 |             image, label = mixup_fn(image, label_orig)
131 | 
132 |         if amp is True:  # mixed precision training
133 |             with paddle.amp.auto_cast():
134 |                 output = model(image)
135 |                 loss = criterion(output, label)  # the (image, output, label) signature is only for DistillationLoss
136 |             scaled = scaler.scale(loss)
137 |             scaled.backward()
138 |             if ((batch_id + 1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
139 |                 scaler.minimize(optimizer, scaled)
140 |                 optimizer.clear_grad()
141 |         else:  # full precision training
142 |             output = model(image)
143 |             loss = criterion(output, label)
144 |             # NOTE: division may be needed depending on the loss function
145 |             # Here no division is needed:
146 |             # default 'reduction' param in nn.CrossEntropyLoss is set to 'mean'
147 |             # loss = loss / accum_iter
148 |             loss.backward()
149 | 
150 |             if ((batch_id + 1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
151 |                 optimizer.step()
152 |                 optimizer.clear_grad()
153 | 
154 |         pred = F.softmax(output)
155 |         if mixup_fn:
156 |             acc = paddle.metric.accuracy(pred, label_orig)
157 |         else:
158 |             acc = paddle.metric.accuracy(pred, label_orig.unsqueeze(1))
159 | 
160 |         batch_size = paddle.to_tensor(image.shape[0])
161 | 
162 |         # sync from other gpus for overall loss and acc
163 |         master_loss = loss.clone()
164 |         master_acc = acc.clone()
165 |         master_batch_size = batch_size.clone()
166 |         dist.all_reduce(master_loss)
167 |         dist.all_reduce(master_acc)
168 |         dist.all_reduce(master_batch_size)
169 |         master_loss = master_loss / dist.get_world_size()
170 |         master_acc = master_acc / dist.get_world_size()
171 |         master_train_loss_meter.update(master_loss.numpy()[0], master_batch_size.numpy()[0])
172 |         master_train_acc_meter.update(master_acc.numpy()[0], master_batch_size.numpy()[0])
173 | 
174 |         train_loss_meter.update(loss.numpy()[0], batch_size.numpy()[0])
175 |         train_acc_meter.update(acc.numpy()[0], batch_size.numpy()[0])
176 | 
177 |         if batch_id % debug_steps == 0:
178 |             if local_logger:
179 |                 local_logger.info(
180 |                     f"Epoch[{epoch:03d}/{total_epochs:03d}], " +
181 |                     f"Step[{batch_id:04d}/{total_batch:04d}], " +
182 |                     f"Avg Loss: {train_loss_meter.avg:.4f}, " +
183 |                     f"Avg Acc: {train_acc_meter.avg:.4f}")
184 |             if master_logger and dist.get_rank() == 0:
185 |                 master_logger.info(
186 |                     f"Epoch[{epoch:03d}/{total_epochs:03d}], " +
187 |                     f"Step[{batch_id:04d}/{total_batch:04d}], " +
188 |                     f"Avg Loss: {master_train_loss_meter.avg:.4f}, " +
189 |                     f"Avg Acc: {master_train_acc_meter.avg:.4f}")
190 | 
191 |     train_time = time.time() - time_st
192 |     return (train_loss_meter.avg,
193 |             train_acc_meter.avg,
194 |             master_train_loss_meter.avg,
195 |             master_train_acc_meter.avg,
196 |             train_time)
197 | 
198 | 
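# NOTE (editorial sketch, not part of the original file): the
# (batch_id + 1) % accum_iter test in train() above yields one optimizer step
# per accum_iter micro-batches. Assuming accum_iter=4 and batch_size=32
# (hypothetical values), the effective batch per update is 4 * 32 = 128:
#     batch_id: 0    1    2    3    | 4    5    6    7    | ...
#     step?   : no   no   no   YES  | no   no   no   YES  | ...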
199 | def validate(dataloader,
200 |              model,
201 |              criterion,
202 |              total_batch,
203 |              debug_steps=100,
204 |              local_logger=None,
205 |              master_logger=None):
206 |     """Validation for whole dataset
207 |     Args:
208 |         dataloader: paddle.io.DataLoader, dataloader instance
209 |         model: nn.Layer, a ViT model
210 |         criterion: nn.Layer, loss function
211 |         total_batch: int, total num of batches, for logging
212 |         debug_steps: int, num of iters to log info, default: 100
213 |         local_logger: logger for local process/gpu, default: None
214 |         master_logger: logger for main process, default: None
215 |     Returns:
216 |         val_loss_meter.avg: float, average loss on current process/gpu
217 |         val_acc1_meter.avg: float, average top1 accuracy on current process/gpu
218 |         val_acc5_meter.avg: float, average top5 accuracy on current process/gpu
219 |         master_val_loss_meter.avg: float, average loss on all processes/gpus
220 |         master_val_acc1_meter.avg: float, average top1 accuracy on all processes/gpus
221 |         master_val_acc5_meter.avg: float, average top5 accuracy on all processes/gpus
222 |         val_time: float, validation time
223 |     """
224 |     model.eval()
225 |     val_loss_meter = AverageMeter()
226 |     val_acc1_meter = AverageMeter()
227 |     val_acc5_meter = AverageMeter()
228 |     master_val_loss_meter = AverageMeter()
229 |     master_val_acc1_meter = AverageMeter()
230 |     master_val_acc5_meter = AverageMeter()
231 |     time_st = time.time()
232 | 
233 |     with paddle.no_grad():
234 |         for batch_id, data in enumerate(dataloader):
235 |             image = data[0]
236 |             label = data[1]
237 | 
238 |             output = model(image)
239 |             loss = criterion(output, label)
240 | 
241 |             pred = F.softmax(output)
242 |             acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1))
243 |             acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5)
244 | 
245 |             batch_size = paddle.to_tensor(image.shape[0])
246 | 
247 |             master_loss = loss.clone()
248 |             master_acc1 = acc1.clone()
249 |             master_acc5 = acc5.clone()
250 |             master_batch_size = batch_size.clone()
251 | 
252 |             dist.all_reduce(master_loss)
253 |             dist.all_reduce(master_acc1)
254 |             dist.all_reduce(master_acc5)
255 |             dist.all_reduce(master_batch_size)
256 |             master_loss = master_loss / dist.get_world_size()
257 |             master_acc1 = master_acc1 / dist.get_world_size()
258 |             master_acc5 = master_acc5 / dist.get_world_size()
259 | 
260 |             master_val_loss_meter.update(master_loss.numpy()[0], master_batch_size.numpy()[0])
261 |             master_val_acc1_meter.update(master_acc1.numpy()[0], master_batch_size.numpy()[0])
262 |             master_val_acc5_meter.update(master_acc5.numpy()[0], master_batch_size.numpy()[0])
263 | 
264 |             val_loss_meter.update(loss.numpy()[0], batch_size.numpy()[0])
265 |             val_acc1_meter.update(acc1.numpy()[0], batch_size.numpy()[0])
266 |             val_acc5_meter.update(acc5.numpy()[0], batch_size.numpy()[0])
267 | 
268 |             if batch_id % debug_steps == 0:
269 |                 if local_logger:
270 |                     local_logger.info(
271 |                         f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
272 |                         f"Avg Loss: {val_loss_meter.avg:.4f}, " +
273 |                         f"Avg Acc@1: {val_acc1_meter.avg:.4f}, " +
274 |                         f"Avg Acc@5: {val_acc5_meter.avg:.4f}")
275 |                 if master_logger and dist.get_rank() == 0:
276 |                     master_logger.info(
277 |                         f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
278 |                         f"Avg Loss: {master_val_loss_meter.avg:.4f}, " +
279 |                         f"Avg Acc@1: {master_val_acc1_meter.avg:.4f}, " +
280 |                         f"Avg Acc@5: {master_val_acc5_meter.avg:.4f}")
281 |     val_time = time.time() - time_st
282 |     return (val_loss_meter.avg,
283 |             val_acc1_meter.avg,
284 |             val_acc5_meter.avg,
285 |             master_val_loss_meter.avg,
286 |             master_val_acc1_meter.avg,
287 |             master_val_acc5_meter.avg,
288 |             val_time)
289 | 
290 | 
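# NOTE (editorial sketch, not part of the original file): validate() above
# averages metrics across processes: each rank contributes its local values
# via dist.all_reduce (a sum), then divides by world_size. E.g. with 2 GPUs
# reporting local losses 0.8 and 1.0, every rank ends up with
# master_loss = (0.8 + 1.0) / 2 = 0.9.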
291 | def main_worker(*args):
292 |     # STEP 0: Preparation
293 |     config = args[0]
294 |     dist.init_parallel_env()
295 |     last_epoch = config.TRAIN.LAST_EPOCH
296 |     world_size = dist.get_world_size()
297 |     local_rank = dist.get_rank()
298 |     seed = config.SEED + local_rank
299 |     paddle.seed(seed)
300 |     np.random.seed(seed)
301 |     random.seed(seed)
302 |     # logger for each process/gpu
303 |     local_logger = get_logger(
304 |         filename=os.path.join(config.SAVE, 'log_{}.txt'.format(local_rank)),
305 |         logger_name='local_logger')
306 |     # overall logger
307 |     if local_rank == 0:
308 |         master_logger = get_logger(
309 |             filename=os.path.join(config.SAVE, 'log.txt'),
310 |             logger_name='master_logger')
311 |         master_logger.info(f'\n{config}')
312 |     else:
313 |         master_logger = None
314 |     local_logger.info(f'----- world_size = {world_size}, local_rank = {local_rank}')
315 |     if local_rank == 0:
316 |         master_logger.info(f'----- world_size = {world_size}, local_rank = {local_rank}')
317 | 
318 |     # STEP 1: Create model
319 |     model = build_model(config)
320 |     model = paddle.DataParallel(model)
321 | 
322 |     # STEP 2: Create train and val dataloader
323 |     dataset_train, dataset_val = args[1], args[2]
324 |     dataloader_train = get_dataloader(config, dataset_train, 'train', True)
325 |     dataloader_val = get_dataloader(config, dataset_val, 'test', True)
326 |     total_batch_train = len(dataloader_train)
327 |     total_batch_val = len(dataloader_val)
328 |     local_logger.info(f'----- Total # of train batch (single gpu): {total_batch_train}')
329 |     local_logger.info(f'----- Total # of val batch (single gpu): {total_batch_val}')
330 |     if local_rank == 0:
331 |         master_logger.info(f'----- Total # of train batch (single gpu): {total_batch_train}')
332 |         master_logger.info(f'----- Total # of val batch (single gpu): {total_batch_val}')
333 | 
334 |     # STEP 3: Define Mixup function
335 |     mixup_fn = None
336 |     if config.TRAIN.MIXUP_PROB > 0 or config.TRAIN.CUTMIX_ALPHA > 0 or config.TRAIN.CUTMIX_MINMAX is not None:
337 |         mixup_fn = Mixup(mixup_alpha=config.TRAIN.MIXUP_ALPHA,
338 |                          cutmix_alpha=config.TRAIN.CUTMIX_ALPHA,
339 |                          cutmix_minmax=config.TRAIN.CUTMIX_MINMAX,
340 |                          prob=config.TRAIN.MIXUP_PROB,
341 |                          switch_prob=config.TRAIN.MIXUP_SWITCH_PROB,
342 |                          mode=config.TRAIN.MIXUP_MODE,
343 |                          label_smoothing=config.TRAIN.SMOOTHING)
344 | 
345 |     # STEP 4: Define criterion
346 |     if config.TRAIN.MIXUP_PROB > 0.:
347 |         criterion = SoftTargetCrossEntropyLoss()
348 |     elif config.TRAIN.SMOOTHING:
349 |         criterion = LabelSmoothingCrossEntropyLoss()
350 |     else:
351 |         criterion = nn.CrossEntropyLoss()
352 |     # only use cross entropy for val
353 |     criterion_val = nn.CrossEntropyLoss()
354 | 
355 |     # STEP 5: Define optimizer and lr_scheduler
356 |     # set lr according to batch size and world size (hacked from official code)
357 |     linear_scaled_lr = (config.TRAIN.BASE_LR *
358 |                         config.DATA.BATCH_SIZE * dist.get_world_size()) / 512.0
359 |     linear_scaled_warmup_start_lr = (config.TRAIN.WARMUP_START_LR *
360 |                                      config.DATA.BATCH_SIZE * dist.get_world_size()) / 512.0
361 |     linear_scaled_end_lr = (config.TRAIN.END_LR *
362 |                             config.DATA.BATCH_SIZE * dist.get_world_size()) / 512.0
363 | 
364 |     if config.TRAIN.ACCUM_ITER > 1:
365 |         linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUM_ITER
366 |         linear_scaled_warmup_start_lr = linear_scaled_warmup_start_lr * config.TRAIN.ACCUM_ITER
367 |         linear_scaled_end_lr = linear_scaled_end_lr * config.TRAIN.ACCUM_ITER
368 | 
369 |     config.TRAIN.BASE_LR = linear_scaled_lr
370 |     config.TRAIN.WARMUP_START_LR = linear_scaled_warmup_start_lr
371 |     config.TRAIN.END_LR = linear_scaled_end_lr
372 | 
373 |     scheduler = None
374 |     if config.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
375 |         scheduler = WarmupCosineScheduler(learning_rate=config.TRAIN.BASE_LR,
376 |                                           warmup_start_lr=config.TRAIN.WARMUP_START_LR,
377 |                                           start_lr=config.TRAIN.BASE_LR,
378 |                                           end_lr=config.TRAIN.END_LR,
379 |                                           warmup_epochs=config.TRAIN.WARMUP_EPOCHS,
380 |                                           total_epochs=config.TRAIN.NUM_EPOCHS,
381 |                                           last_epoch=config.TRAIN.LAST_EPOCH,
382 |                                           )
383 |     elif config.TRAIN.LR_SCHEDULER.NAME == "cosine":
384 |         scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.TRAIN.BASE_LR,
385 |                                                              T_max=config.TRAIN.NUM_EPOCHS,
386 |                                                              last_epoch=last_epoch)
387 |     elif config.TRAIN.LR_SCHEDULER.NAME == "multi-step":
388 |         milestones = [int(v.strip()) for v in config.TRAIN.LR_SCHEDULER.MILESTONES.split(",")]
389 |         scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=config.TRAIN.BASE_LR,
390 |                                                        milestones=milestones,
391 |                                                        gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
392 |                                                        last_epoch=last_epoch)
393 |     else:
394 |         local_logger.fatal(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
395 |         if local_rank == 0:
396 |             master_logger.fatal(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
397 |         raise NotImplementedError(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
398 | 
399 |     if config.TRAIN.OPTIMIZER.NAME == "SGD":
400 |         if config.TRAIN.GRAD_CLIP:
401 |             clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
402 |         else:
403 |             clip = None
404 |         optimizer = paddle.optimizer.Momentum(
405 |             parameters=model.parameters(),
406 |             learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
407 |             weight_decay=config.TRAIN.WEIGHT_DECAY,
408 |             momentum=config.TRAIN.OPTIMIZER.MOMENTUM,
409 |             grad_clip=clip)
410 |     elif config.TRAIN.OPTIMIZER.NAME == "AdamW":
411 |         if config.TRAIN.GRAD_CLIP:
412 |             clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
413 |         else:
414 |             clip = None
415 |         optimizer = paddle.optimizer.AdamW(
416 |             parameters=model.parameters(),
417 |             learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
418 |             beta1=config.TRAIN.OPTIMIZER.BETAS[0],
419 |             beta2=config.TRAIN.OPTIMIZER.BETAS[1],
420 |             weight_decay=config.TRAIN.WEIGHT_DECAY,
421 |             epsilon=config.TRAIN.OPTIMIZER.EPS,
422 |             grad_clip=clip,
423 |             apply_decay_param_fun=get_exclude_from_weight_decay_fn([
424 |                 'absolute_pos_embed', 'relative_position_bias_table']),
425 |             )
426 |     else:
427 |         local_logger.fatal(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
428 |         if local_rank == 0:
429 |             master_logger.fatal(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
430 |         raise NotImplementedError(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
431 | 
432 |     # STEP 6: Load pretrained model / load resume model and optimizer states
433 |     if config.MODEL.PRETRAINED:
434 |         if (config.MODEL.PRETRAINED).endswith('.pdparams'):
435 |             raise ValueError(f'{config.MODEL.PRETRAINED} should not contain .pdparams')
436 |         assert os.path.isfile(config.MODEL.PRETRAINED + '.pdparams') is True
437 |         model_state = paddle.load(config.MODEL.PRETRAINED+'.pdparams')
438 |         model.set_dict(model_state)
439 |         local_logger.info(f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}")
440 |         if local_rank == 0:
441 |             master_logger.info(
442 |                 f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}")
443 | 
444 |     if config.MODEL.RESUME:
445 |         assert os.path.isfile(config.MODEL.RESUME+'.pdparams') is True
446 |         assert os.path.isfile(config.MODEL.RESUME+'.pdopt') is True
447 |         model_state = paddle.load(config.MODEL.RESUME+'.pdparams')
448 |         model.set_dict(model_state)
449 |         opt_state = paddle.load(config.MODEL.RESUME+'.pdopt')
450 |         optimizer.set_state_dict(opt_state)
451 |         local_logger.info(
452 |             f"----- Resume Training: Load model and optimizer from {config.MODEL.RESUME}")
453 |         if local_rank == 0:
454 |             master_logger.info(
455 |                 f"----- Resume Training: Load model and optimizer from {config.MODEL.RESUME}")
456 | 
457 |     # STEP 7: Validation (eval mode)
458 |     if config.EVAL:
459 |         local_logger.info('----- Start Validating')
460 |         if local_rank == 0:
461 |             master_logger.info('----- Start Validating')
462 |         val_loss, val_acc1, val_acc5, avg_loss, avg_acc1, avg_acc5, val_time = validate(
463 |             dataloader=dataloader_val,
464 |             model=model,
465 |             criterion=criterion_val,
466 |             total_batch=total_batch_val,
467 |             debug_steps=config.REPORT_FREQ,
468 |             local_logger=local_logger,
469 |             master_logger=master_logger)
470 |         local_logger.info(f"Validation Loss: {val_loss:.4f}, " +
471 |                           f"Validation Acc@1: {val_acc1:.4f}, " +
472 |                           f"Validation Acc@5: {val_acc5:.4f}, " +
473 |                           f"time: {val_time:.2f}")
474 |         if local_rank == 0:
475 |             master_logger.info(f"Validation Loss: {avg_loss:.4f}, " +
476 |                                f"Validation Acc@1: {avg_acc1:.4f}, " +
477 |                                f"Validation Acc@5: {avg_acc5:.4f}, " +
478 |                                f"time: {val_time:.2f}")
479 |         return
480 | 
481 |     # STEP 8: Start training and validation (train mode)
482 |     local_logger.info(f"Start training from epoch {last_epoch+1}.")
483 |     if local_rank == 0:
484 |         master_logger.info(f"Start training from epoch {last_epoch+1}.")
485 |     for epoch in range(last_epoch+1, config.TRAIN.NUM_EPOCHS+1):
486 |         # train
487 |         local_logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.6f}")
488 |         if local_rank == 0:
489 |             master_logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.6f}")
490 |         train_loss, train_acc, avg_loss, avg_acc, train_time = train(
491 |             dataloader=dataloader_train,
492 |             model=model,
493 |             criterion=criterion,
494 |             optimizer=optimizer,
495 |             epoch=epoch,
496 |             total_epochs=config.TRAIN.NUM_EPOCHS,
497 |             total_batch=total_batch_train,
498 |             debug_steps=config.REPORT_FREQ,
499 |             accum_iter=config.TRAIN.ACCUM_ITER,
500 |             mixup_fn=mixup_fn,
501 |             amp=config.AMP,
502 |             local_logger=local_logger,
503 |             master_logger=master_logger)
504 | 
505 |         scheduler.step()
506 | 
507 |         local_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
508 |                           f"Train Loss: {train_loss:.4f}, " +
509 |                           f"Train Acc: {train_acc:.4f}, " +
510 |                           f"time: {train_time:.2f}")
511 |         if local_rank == 0:
512 |             master_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
513 |                                f"Train Loss: {avg_loss:.4f}, " +
514 |                                f"Train Acc: {avg_acc:.4f}, " +
515 |                                f"time: {train_time:.2f}")
516 | 
517 |         # validation
518 |         if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
519 |             local_logger.info(f'----- Validation after Epoch: {epoch}')
520 |             if local_rank == 0:
521 |                 master_logger.info(f'----- Validation after Epoch: {epoch}')
522 |             val_loss, val_acc1, val_acc5, avg_loss, avg_acc1, avg_acc5, val_time = validate(
523 |                 dataloader=dataloader_val,
524 |                 model=model,
525 |                 criterion=criterion_val,
526 |                 total_batch=total_batch_val,
527 |                 debug_steps=config.REPORT_FREQ,
528 |                 local_logger=local_logger,
529 |                 master_logger=master_logger)
530 |             local_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
531 |                               f"Validation Loss: {val_loss:.4f}, " +
532 |                               f"Validation Acc@1: {val_acc1:.4f}, " +
533 |                               f"Validation Acc@5: {val_acc5:.4f}, " +
534 |                               f"time: {val_time:.2f}")
535 |             if local_rank == 0:
536 |                 master_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
537 |                                    f"Validation Loss: {avg_loss:.4f}, " +
538 |                                    f"Validation Acc@1: {avg_acc1:.4f}, " +
539 |                                    f"Validation Acc@5: {avg_acc5:.4f}, " +
540 |                                    f"time: {val_time:.2f}")
541 |         # model save
542 |         if local_rank == 0:
543 |             if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
544 |                 model_path = os.path.join(
545 |                     config.SAVE, f"{config.MODEL.TYPE}-Epoch-{epoch}-Loss-{train_loss}")
546 |                 paddle.save(model.state_dict(), model_path + '.pdparams')
547 |                 paddle.save(optimizer.state_dict(), model_path + '.pdopt')
548 |                 master_logger.info(f"----- Save model: {model_path}.pdparams")
549 |                 master_logger.info(f"----- Save optim: {model_path}.pdopt")
550 | 
551 | 
552 | def main():
553 |     # config is updated by: (1) config.py, (2) yaml file, (3) arguments
554 |     arguments = get_arguments()
555 |     config = get_config()
556 |     config = update_config(config, arguments)
557 | 
558 |     # set output folder
559 |     if not config.EVAL:
560 |         config.SAVE = '{}/train-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
561 |     else:
562 |         config.SAVE = '{}/eval-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
563 | 
564 |     if not os.path.exists(config.SAVE):
565 |         os.makedirs(config.SAVE, exist_ok=True)
566 | 
567 |     # get dataset and start DDP
568 |     dataset_train = get_dataset(config, mode='train')
569 |     dataset_val = get_dataset(config, mode='val')
570 |     config.NGPUS = len(paddle.static.cuda_places()) if config.NGPUS == -1 else config.NGPUS
571 |     dist.spawn(main_worker, args=(config, dataset_train, dataset_val, ), nprocs=config.NGPUS)
572 | 
573 | 
574 | if __name__ == "__main__":
575 |     main()
576 | 
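# NOTE (editorial sketch, not part of the original file): an example launch,
# assuming the config and dataset paths below exist on your machine;
# dist.spawn above starts one worker process per GPU:
#
#   CUDA_VISIBLE_DEVICES=0,1 python main_multi_gpu.py \
#       -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
#       -dataset='imagenet2012' \
#       -batch_size=32 \
#       -data_path='/dataset/imagenet' \
#       -ngpus=2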
--------------------------------------------------------------------------------
/swin_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | """
16 | Implement Transformer Class for Swin Transformer V2
17 | """
18 | 
19 | # from types import TracebackType  # unused import
20 | import paddle
21 | # from paddle.framework import dtype  # unused import
22 | import paddle.nn as nn
23 | from droppath import DropPath
24 | 
25 | 
26 | class Identity(nn.Layer):
27 |     """ Identity layer
28 | 
29 |     The output of this layer is the input without any change.
30 |     Use this layer to avoid if condition in some forward methods
31 | 
32 |     """
33 |     def __init__(self):
34 |         super(Identity, self).__init__()
35 |     def forward(self, x):
36 |         return x
37 | 
38 | 
39 | class PatchEmbedding(nn.Layer):
40 |     """Patch Embeddings
41 | 
42 |     Apply patch embeddings on input images. Embedding is implemented using a Conv2D op.
43 | 
44 |     Attributes:
45 |         image_size: int, input image size, default: 224
46 |         patch_size: int, size of patch, default: 4
47 |         in_channels: int, input image channels, default: 3
48 |         embed_dim: int, embedding dimension, default: 96
49 |     """
50 | 
51 |     def __init__(self, image_size=224, patch_size=4, in_channels=3, embed_dim=96):
52 |         super().__init__()
53 |         image_size = (image_size, image_size)  # TODO: add to_2tuple
54 |         patch_size = (patch_size, patch_size)
55 |         patches_resolution = [image_size[0]//patch_size[0], image_size[1]//patch_size[1]]
56 |         self.image_size = image_size
57 |         self.patch_size = patch_size
58 |         self.patches_resolution = patches_resolution
59 |         self.num_patches = patches_resolution[0] * patches_resolution[1]
60 |         self.in_channels = in_channels
61 |         self.embed_dim = embed_dim
62 |         self.patch_embed = nn.Conv2D(in_channels=in_channels,
63 |                                      out_channels=embed_dim,
64 |                                      kernel_size=patch_size,
65 |                                      stride=patch_size)
66 | 
67 |         w_attr, b_attr = self._init_weights_layernorm()
68 |         self.norm = nn.LayerNorm(embed_dim,
69 |                                  weight_attr=w_attr,
70 |                                  bias_attr=b_attr)
71 | 
72 |     def _init_weights_layernorm(self):
73 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
74 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
75 |         return weight_attr, bias_attr
76 | 
77 |     def forward(self, x):
78 |         x = self.patch_embed(x)  # [batch, embed_dim, h, w] h,w = patch_resolution
79 |         x = x.flatten(start_axis=2, stop_axis=-1)  # [batch, embed_dim, h*w] h*w = num_patches
80 |         x = x.transpose([0, 2, 1])  # [batch, h*w, embed_dim]
81 |         x = self.norm(x)  # [batch, num_patches, embed_dim]
82 |         return x
83 | 
84 | 
85 | class PatchMerging(nn.Layer):
86 |     """ Patch Merging class
87 | 
88 |     Merge multiple patches into one patch and keep the output dim.
89 |     Specifically, merge adjacent 2x2 patches (dim=C) into 1 patch.
90 |     The concat dim 4*C is rescaled to 2*C
91 | 
92 |     Attributes:
93 |         input_resolution: tuple of ints, the size of input
94 |         dim: dimension of single patch
95 |         reduction: nn.Linear which maps 4C to 2C dim
96 |         norm: nn.LayerNorm, applied after linear layer.
97 |     """
98 | 
99 |     def __init__(self, input_resolution, dim):
100 |         super(PatchMerging, self).__init__()
101 |         self.input_resolution = input_resolution
102 |         self.dim = dim
103 |         w_attr_1, b_attr_1 = self._init_weights()
104 |         self.reduction = nn.Linear(4 * dim,
105 |                                    2 * dim,
106 |                                    weight_attr=w_attr_1,
107 |                                    bias_attr=False)
108 | 
109 |         w_attr_2, b_attr_2 = self._init_weights_layernorm()
110 |         self.norm = nn.LayerNorm(4*dim,
111 |                                  weight_attr=w_attr_2,
112 |                                  bias_attr=b_attr_2)
113 | 
114 |     def _init_weights_layernorm(self):
115 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
116 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
117 |         return weight_attr, bias_attr
118 | 
119 |     def _init_weights(self):
120 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
121 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
122 |         return weight_attr, bias_attr
123 | 
124 |     def forward(self, x):
125 |         h, w = self.input_resolution
126 |         b, _, c = x.shape
127 |         x = x.reshape([b, h, w, c])
128 | 
129 |         x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
130 |         x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
131 |         x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
132 |         x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
133 |         x = paddle.concat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
134 |         x = x.reshape([b, -1, 4*c])  # [B, H/2*W/2, 4*C]
135 | 
136 |         x = self.norm(x)
137 |         x = self.reduction(x)
138 | 
139 |         return x
140 | 
141 | 
142 | class Mlp(nn.Layer):
143 |     """ MLP module
144 | 
145 |     Impl using nn.Linear and activation is GELU, dropout is applied.
146 |     Ops: fc -> act -> dropout -> fc -> dropout
147 | 
148 |     Attributes:
149 |         fc1: nn.Linear
150 |         fc2: nn.Linear
151 |         act: GELU
152 |         dropout1: dropout after fc1
153 |         dropout2: dropout after fc2
154 |     """
155 | 
156 |     def __init__(self, in_features, hidden_features, dropout):
157 |         super(Mlp, self).__init__()
158 |         w_attr_1, b_attr_1 = self._init_weights()
159 |         self.fc1 = nn.Linear(in_features,
160 |                              hidden_features,
161 |                              weight_attr=w_attr_1,
162 |                              bias_attr=b_attr_1)
163 | 
164 |         w_attr_2, b_attr_2 = self._init_weights()
165 |         self.fc2 = nn.Linear(hidden_features,
166 |                              in_features,
167 |                              weight_attr=w_attr_2,
168 |                              bias_attr=b_attr_2)
169 |         self.act = nn.GELU()
170 |         self.dropout = nn.Dropout(dropout)
171 | 
172 |     def _init_weights(self):
173 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
174 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
175 |         return weight_attr, bias_attr
176 | 
177 |     def forward(self, x):
178 |         x = self.fc1(x)
179 |         x = self.act(x)
180 |         x = self.dropout(x)
181 |         x = self.fc2(x)
182 |         x = self.dropout(x)
183 |         return x
184 | 
185 | class Mlp_Relu(nn.Layer):
186 |     """ MLP module
187 | 
188 |     Impl using nn.Linear and activation is ReLU, dropout is applied.
189 |     Ops: fc -> act -> dropout -> fc -> dropout
190 | 
191 |     Attributes:
192 |         fc1: nn.Linear
193 |         fc2: nn.Linear
194 |         act: ReLU
195 |         dropout1: dropout after fc1
196 |         dropout2: dropout after fc2
197 |     """
198 | 
199 |     def __init__(self, in_features, hidden_features, out_features, dropout):
200 |         super(Mlp_Relu, self).__init__()
201 |         w_attr_1, b_attr_1 = self._init_weights()
202 |         self.fc1 = nn.Linear(in_features,
203 |                              hidden_features,
204 |                              weight_attr=w_attr_1,
205 |                              bias_attr=b_attr_1)
206 | 
207 |         w_attr_2, b_attr_2 = self._init_weights()
208 |         self.fc2 = nn.Linear(hidden_features,
209 |                              out_features,
210 |                              weight_attr=w_attr_2,
211 |                              bias_attr=b_attr_2)
212 |         self.act = nn.ReLU()
213 |         self.dropout = nn.Dropout(dropout)
214 | 
215 |     def _init_weights(self):
216 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
217 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
218 |         return weight_attr, bias_attr
219 | 
220 |     def forward(self, x):
221 |         x = self.fc1(x)
222 |         x = self.act(x)
223 |         x = self.dropout(x)
224 |         x = self.fc2(x)
225 |         x = self.dropout(x)
226 |         return x
227 | 
228 | 
229 | class WindowAttention(nn.Layer):
230 |     """Window based multihead attention, with relative position bias.
231 | 
232 |     Both shifted window and non-shifted window are supported.
233 | 
234 |     Attributes:
235 |         dim: int, input dimension (channels)
236 |         window_size: int, height and width of the window
237 |         num_heads: int, number of attention heads
238 |         qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True
239 |         qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
240 |         attention_dropout: float, dropout of attention
241 |         dropout: float, dropout for output
242 |     """
243 | 
244 |     def __init__(self,
245 |                  dim,
246 |                  window_size,
247 |                  num_heads,
248 |                  qkv_bias=True,
249 |                  qk_scale=None,
250 |                  attention_dropout=0.,
251 |                  dropout=0.):
252 |         super(WindowAttention, self).__init__()
253 |         self.window_size = window_size
254 |         self.num_heads = num_heads
255 |         self.dim = dim
256 |         self.dim_head = dim // num_heads
257 |         self.scale = qk_scale or self.dim_head ** -0.5
258 | 
259 |         self.relative_position_bias_table = paddle.create_parameter(
260 |             shape=[(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads],
261 |             dtype='float32',
262 |             default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
263 | 
264 |         # relative position index for each token inside window
265 |         coords_h = paddle.arange(self.window_size[0])
266 |         coords_w = paddle.arange(self.window_size[1])
267 |         coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))  # [2, window_h, window_w]
268 |         coords_flatten = paddle.flatten(coords, 1)  # [2, window_h * window_w]
269 |         # 2, window_h * window_w, window_h * window_w
270 |         relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
271 |         # window_h*window_w, window_h*window_w, 2
272 |         relative_coords = relative_coords.transpose([1, 2, 0])
273 | 
274 |         ## Swin-T v1
275 |         # relative_coords[:, :, 0] += self.window_size[0] - 1
276 |         # relative_coords[:, :, 1] += self.window_size[1] - 1
277 |         # relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
278 |         # relative_position_index = relative_coords.sum(-1)  # [window_size * window_size, window_size*window_size]
279 |         # self.register_buffer("relative_position_index", relative_position_index)
280 | 
281 |         ## Swin-T v2, log-spaced coordinates, Eq.(4)
282 |         log_relative_position_index = paddle.multiply(relative_coords.cast(dtype='float32').sign(),
283 |                                                       paddle.log(relative_coords.cast(dtype='float32').abs() + 1))
284 |         self.register_buffer("log_relative_position_index", log_relative_position_index)
285 |         ## Swin-T v2, small meta network, Eq.(3)
286 |         self.cpb = Mlp_Relu(in_features=2,  # delta x, delta y
287 |                             hidden_features=512,  # TODO: hidden dims
288 |                             out_features=self.num_heads,
289 |                             dropout=dropout)
290 | 
291 |         w_attr_1, b_attr_1 = self._init_weights()
292 |         self.qkv = nn.Linear(dim,
293 |                              dim * 3,
294 |                              weight_attr=w_attr_1,
295 |                              bias_attr=b_attr_1 if qkv_bias else False)
296 | 
297 |         self.attn_dropout = nn.Dropout(attention_dropout)
298 | 
299 |         w_attr_2, b_attr_2 = self._init_weights()
300 |         self.proj = nn.Linear(dim,
301 |                               dim,
302 |                               weight_attr=w_attr_2,
303 |                               bias_attr=b_attr_2)
304 |         self.proj_dropout = nn.Dropout(dropout)
305 |         self.softmax = nn.Softmax(axis=-1)
306 | 
307 |         # Swin-T v2, Scaled cosine attention
308 |         self.tau = paddle.create_parameter(
309 |             shape=[num_heads, window_size[0]*window_size[1], window_size[0]*window_size[1]],
310 |             dtype='float32',
311 |             default_initializer=paddle.nn.initializer.Constant(1))
312 | 
313 |     def _init_weights(self):
314 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
315 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
316 |         return weight_attr, bias_attr
317 | 
318 |     def transpose_multihead(self, x):
319 |         new_shape = x.shape[:-1] + [self.num_heads, self.dim_head]
320 |         x = x.reshape(new_shape)
321 |         x = x.transpose([0, 2, 1, 3])
322 |         return x
323 | 
324 |     def get_relative_pos_bias_from_pos_index(self):
325 |         # relative_position_bias_table is a ParamBase object (Swin-T v1 path; needs the v1 relative_position_index buffer above)
326 |         # https://github.com/PaddlePaddle/Paddle/blob/067f558c59b34dd6d8626aad73e9943cf7f5960f/python/paddle/fluid/framework.py#L5727
327 |         table = self.relative_position_bias_table  # N x num_heads
328 |         # index is a tensor
329 |         index = self.relative_position_index.reshape([-1])  # window_h*window_w * window_h*window_w
330 |         # NOTE: paddle does NOT support indexing Tensor by a Tensor
331 |         relative_position_bias = paddle.index_select(x=table, index=index)
332 |         return relative_position_bias
333 | 
334 |     def get_continuous_relative_position_bias(self):
335 |         # The continuous position bias approach adopts a small meta network on the relative coordinates
336 |         continuous_relative_position_bias = self.cpb(self.log_relative_position_index)
337 |         return continuous_relative_position_bias
338 | 
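    # NOTE (editorial sketch, not part of the original file): Eq.(4) compresses
    # relative offsets before they enter the cpb meta network, so larger windows
    # extrapolate better. For window_size=7 the raw offsets lie in [-6, 6], and
    # sign(d) * log(1 + |d|) maps them to roughly [-1.95, 1.95]:
    #     d = 1  ->  log(2) ~  0.693
    #     d = 6  ->  log(7) ~  1.946
    #     d = -6 -> -log(7) ~ -1.946
    # The cpb MLP (2 -> 512 -> num_heads) then turns each (dx, dy) pair into a
    # per-head bias, replacing the v1 lookup table.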
339 |     def forward(self, x, mask=None):
340 |         qkv = self.qkv(x).chunk(3, axis=-1)  # {list:3}
341 |         q, k, v = map(self.transpose_multihead, qkv)  # [bs*num_window=1*64,4,49,32] -> [bs*num_window=1*16,8,49,32] -> [bs*num_window=1*4,16,49,32] -> [bs*num_window=1*1,32,49,32]
342 | 
343 |         # Swin-T v2, Scaled cosine attention
344 |         qk = paddle.matmul(q, k, transpose_y=True)  # [bs*num_window=1*64,num_heads=4,49,49] -> [bs*num_window=1*16,num_heads=8,49,49] -> [bs*num_window=1*4,num_heads=16,49,49] -> [bs*num_window=1*1,num_heads=32,49,49]
345 |         q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
346 |         k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
347 |         attn = qk / paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
348 |         attn = attn / paddle.clip(self.tau.unsqueeze(0), min=0.01)
349 | 
350 |         ## Swin-T v1
351 |         # relative_position_bias = self.get_relative_pos_bias_from_pos_index()  # [2401,num_heads=4]->[2401,8]->[2401,16]->[2401,32]
352 |         ## Swin-T v2
353 |         relative_position_bias = self.get_continuous_relative_position_bias()
354 |         relative_position_bias = relative_position_bias.reshape(
355 |             [self.window_size[0] * self.window_size[1],
356 |              self.window_size[0] * self.window_size[1],
357 |              -1])  # [49,49,num_heads=4]->[49,49,8]->[49,49,16]->[49,49,32]
358 | 
359 |         # nH, window_h*window_w, window_h*window_w
360 |         relative_position_bias = relative_position_bias.transpose([2, 0, 1])  # [4,49,49]->[8,49,49]->[16,49,49]->[32,49,49]
361 |         attn = attn + relative_position_bias.unsqueeze(0)
362 | 
363 |         if mask is not None:
364 |             nW = mask.shape[0]
365 |             attn = attn.reshape(
366 |                 [x.shape[0] // nW, nW, self.num_heads, x.shape[1], x.shape[1]])
367 |             attn += mask.unsqueeze(1).unsqueeze(0)
368 |             attn = attn.reshape([-1, self.num_heads, x.shape[1], x.shape[1]])
369 |             attn = self.softmax(attn)
370 |         else:
371 |             attn = self.softmax(attn)
372 | 
373 |         attn = self.attn_dropout(attn)  # [bs*num_window=1*64,4,49,49]->[1*16,8,49,49]->[1*4,16,49,49]->[1*1,32,49,49]
374 | 
375 |         z = paddle.matmul(attn, v)  # [bs*num_window=1*64,4,49,32]->[1*16,8,49,32]->[1*4,16,49,32]->[1*1,32,49,32]
376 |         z = z.transpose([0, 2, 1, 3])
377 |         new_shape = z.shape[:-2] + [self.dim]
378 |         z = z.reshape(new_shape)
379 |         z = self.proj(z)
380 |         z = self.proj_dropout(z)  # [512,49,96]->[128,49,192]->[32,49,384]->[8,49,768]
381 | 
382 |         return z
383 | 
384 | 
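# NOTE (editorial sketch, not part of the original file): the scaled cosine
# attention above replaces v1's dot-product logits with cos(q, k) / tau.
# A minimal numeric check with single 2-d vectors:
#     q = [1, 0], k = [1, 1]
#     q.k = 1;  |q| = 1;  |k| = sqrt(2)
#     cos(q, k) = 1 / sqrt(2) ~ 0.707
# With tau initialized to 1 the logit is ~0.707; since cosine lies in [-1, 1],
# logits are bounded by 1/tau, which is what stabilizes training at scale.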
385 | def windows_partition(x, window_size):
386 |     """ partition x into windows of size window_size x window_size
387 |     Args:
388 |         x: Tensor, shape=[b, h, w, c]
389 |         window_size: int, window size
390 |     Returns:
391 |         x: Tensor, shape=[num_windows*b, window_size, window_size, c]
392 |     """
393 | 
394 |     B, H, W, C = x.shape
395 |     x = x.reshape([B, H//window_size, window_size, W//window_size, window_size, C])  # [bs,num_window,window_size,num_window,window_size,C]
396 |     x = x.transpose([0, 1, 3, 2, 4, 5])  # [bs,num_window,num_window,window_size,window_size,C]
397 |     x = x.reshape([-1, window_size, window_size, C])  # (bs*num_windows, window_size, window_size, C)
398 | 
399 |     return x
400 | 
401 | 
402 | def windows_reverse(windows, window_size, H, W):
403 |     """ Window reverse
404 |     Args:
405 |         windows: (n_windows * B, window_size, window_size, C)
406 |         window_size: (int) window size
407 |         H: (int) height of image
408 |         W: (int) width of image
409 | 
410 |     Returns:
411 |         x: (B, H, W, C)
412 |     """
413 | 
414 |     B = int(windows.shape[0] / (H * W / window_size / window_size))
415 |     x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])  # [bs,num_window,num_window,window_size,window_size,C]
416 |     x = x.transpose([0, 1, 3, 2, 4, 5])  # [bs,num_window,window_size,num_window,window_size,C]
417 |     x = x.reshape([B, H, W, -1])  # (bs, num_windows*window_size, num_windows*window_size, C)
418 |     return x
419 | 
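# NOTE (editorial sketch, not part of the original file): windows_partition
# and windows_reverse are exact inverses (pure reshape/transpose), e.g.:
#
#     x = paddle.randn([2, 56, 56, 96])        # [B, H, W, C]
#     w = windows_partition(x, 7)              # [2*8*8, 7, 7, 96] = 128 windows
#     y = windows_reverse(w, 7, 56, 56)        # [2, 56, 56, 96]
#     assert bool((x == y).all())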
420 | 
421 | class SwinTransformerBlock(nn.Layer):
422 |     """Swin transformer block
423 | 
424 |     Contains window multi head self attention, droppath, mlp, norm and residual.
425 | 
426 |     Attributes:
427 |         dim: int, input dimension (channels)
428 |         input_resolution: int, input resolution
429 |         num_heads: int, number of attention heads
430 |         window_size: int, window size, default: 7
431 |         shift_size: int, shift size for SW-MSA, default: 0
432 |         mlp_ratio: float, ratio of mlp hidden dim and input embedding dim, default: 4.
433 |         qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True
434 |         qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
435 |         dropout: float, dropout for output, default: 0.
436 |         attention_dropout: float, dropout of attention, default: 0.
437 |         droppath: float, drop path rate, default: 0.
438 |     """
439 | 
440 |     def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
441 |                  mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0., extra_norm=False,
442 |                  attention_dropout=0., droppath=0.):
443 |         super(SwinTransformerBlock, self).__init__()
444 |         self.dim = dim
445 |         self.extra_norm = extra_norm  # Swin-T v2, introduce a LN unit on the main branch every 6 layers
446 |         self.input_resolution = input_resolution
447 |         self.num_heads = num_heads
448 |         self.window_size = window_size
449 |         self.shift_size = shift_size
450 |         self.mlp_ratio = mlp_ratio
451 |         if min(self.input_resolution) <= self.window_size:
452 |             self.shift_size = 0
453 |             self.window_size = min(self.input_resolution)
454 | 
455 |         w_attr_1, b_attr_1 = self._init_weights_layernorm()
456 |         self.norm1 = nn.LayerNorm(dim,
457 |                                   weight_attr=w_attr_1,
458 |                                   bias_attr=b_attr_1)
459 | 
460 |         self.attn = WindowAttention(dim,
461 |                                     window_size=(self.window_size, self.window_size),
462 |                                     num_heads=num_heads,
463 |                                     qkv_bias=qkv_bias,
464 |                                     qk_scale=qk_scale,
465 |                                     attention_dropout=attention_dropout,
466 |                                     dropout=dropout)
467 |         self.drop_path = DropPath(droppath) if droppath > 0. else None
468 | 
469 |         w_attr_2, b_attr_2 = self._init_weights_layernorm()
470 |         self.norm2 = nn.LayerNorm(dim,
471 |                                   weight_attr=w_attr_2,
472 |                                   bias_attr=b_attr_2)
473 | 
474 |         self.mlp = Mlp(in_features=dim,
475 |                        hidden_features=int(dim*mlp_ratio),
476 |                        dropout=dropout)
477 |         if extra_norm:
478 |             # Swin-T v2, introduce a LN unit on the main branch every 6 layers
479 |             w_attr_3, b_attr_3 = self._init_weights_layernorm()
480 |             self.norm3 = nn.LayerNorm(dim,
481 |                                       weight_attr=w_attr_3,
482 |                                       bias_attr=b_attr_3)
483 | 
484 |         if self.shift_size > 0:
485 |             H, W = self.input_resolution
486 |             img_mask = paddle.zeros((1, H, W, 1))
487 |             h_slices = (slice(0, -self.window_size),
488 |                         slice(-self.window_size, -self.shift_size),
489 |                         slice(-self.shift_size, None))
490 |             w_slices = (slice(0, -self.window_size),
491 |                         slice(-self.window_size, -self.shift_size),
492 |                         slice(-self.shift_size, None))
493 |             cnt = 0
494 |             for h in h_slices:
495 |                 for w in w_slices:
496 |                     img_mask[:, h, w, :] = cnt
497 |                     cnt += 1
498 | 
499 |             mask_windows = windows_partition(img_mask, self.window_size)
500 |             mask_windows = mask_windows.reshape((-1, self.window_size * self.window_size))
501 |             attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
502 |             attn_mask = paddle.where(attn_mask != 0,
503 |                                      paddle.ones_like(attn_mask) * float(-100.0),
504 |                                      attn_mask)
505 |             attn_mask = paddle.where(attn_mask == 0,
506 |                                      paddle.zeros_like(attn_mask),
507 |                                      attn_mask)
508 |         else:
509 |             attn_mask = None
510 | 
511 |         self.register_buffer("attn_mask", attn_mask)
512 | 
513 |     def _init_weights_layernorm(self):
514 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
515 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
516 |         return weight_attr, bias_attr
517 | 
518 |     def forward(self, x):
519 |         H, W = self.input_resolution
520 |         B, L, C = x.shape
521 |         h = x
522 |         # x = self.norm1(x)  # Swin-T v1, pre-norm
523 | 
524 |         new_shape = [B, H, W, C]
525 |         x = x.reshape(new_shape)  # [bs,H,W,C]
526 | 
527 |         if self.shift_size > 0:
528 |             shifted_x = paddle.roll(x,
529 |                                     shifts=(-self.shift_size, -self.shift_size),
530 |                                     axis=(1, 2))  # [bs,H,W,C]
531 |         else:
532 |             shifted_x = x
533 | 
534 |         x_windows = windows_partition(shifted_x, self.window_size)  # [bs*num_windows,7,7,C]
535 |         x_windows = x_windows.reshape([-1, self.window_size * self.window_size, C])  # [bs*num_windows,7*7,C]
536 | 
537 |         attn_windows = self.attn(x_windows, mask=self.attn_mask)  # [bs*num_windows,7*7,C]
538 |         attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])  # [bs*num_windows,7,7,C]
539 | 
540 |         shifted_x = windows_reverse(attn_windows, self.window_size, H, W)  # [bs,H,W,C]
541 | 
542 |         # reverse cyclic shift
543 |         if self.shift_size > 0:
544 |             x = paddle.roll(shifted_x,
545 |                             shifts=(self.shift_size, self.shift_size),
546 |                             axis=(1, 2))
547 |         else:
548 |             x = shifted_x
549 | 
550 |         x = x.reshape([B, H*W, C])  # [bs,H*W,C]
551 |         x = self.norm1(x)  # Swin-T v2, post-norm
552 | 
553 |         if self.drop_path is not None:
554 |             x = h + self.drop_path(x)
555 |         else:
556 |             x = h + x
557 |         h = x  # [bs,H*W,C]
558 |         # x = self.norm2(x)  # Swin-T v1, pre-norm
559 |         x = self.mlp(x)  # [bs,H*W,C]
560 |         x = self.norm2(x)  # Swin-T v2, post-norm
561 |         if self.drop_path is not None:
562 |             x = h + self.drop_path(x)
563 |         else:
564 |             x = h + x
565 | 
566 |         if self.extra_norm:  # Swin-T v2
567 |             x = self.norm3(x)
568 | 
569 |         return x
570 | 
571 | 
572 | class SwinTransformerStage(nn.Layer):
573 |     """Stage layers for swin transformer
574 | 
575 |     Stage layers contains a number of Transformer blocks and an optional
576 |     patch merging layer, patch merging is not applied after last stage
577 | 
578 |     Attributes:
579 |         dim: int, embedding dimension
580 |         input_resolution: tuple, input resolution
581 |         depth: list, num of blocks in each stage
582 |         blocks: nn.LayerList, contains SwinTransformerBlocks for one stage
583 |         downsample: PatchMerging, patch merging layer, none if last stage
584 |     """
585 |     def __init__(self, dim, input_resolution, depth, num_heads, window_size,
586 |                  mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0.,
587 |                  attention_dropout=0., droppath=0., downsample=None, sum_depth=None):
588 |         super(SwinTransformerStage, self).__init__()
589 |         self.dim = dim
590 |         self.input_resolution = input_resolution
591 |         self.depth = depth
592 | 
593 |         self.blocks = nn.LayerList()
594 |         for i in range(depth):
595 |             self.blocks.append(
596 |                 SwinTransformerBlock(
597 |                     dim=dim, input_resolution=input_resolution,
598 |                     num_heads=num_heads, window_size=window_size,
599 |                     shift_size=0 if (i % 2 == 0) else window_size // 2,
600 |                     mlp_ratio=mlp_ratio,
601 |                     extra_norm=sum_depth is not None and (i + sum_depth + 1) % 6 == 0,  # Swin-T v2
602 |                     qkv_bias=qkv_bias, qk_scale=qk_scale,
603 |                     dropout=dropout, attention_dropout=attention_dropout,
604 |                     droppath=droppath[i] if isinstance(droppath, list) else droppath))
605 | 
606 |         if downsample is not None:
607 |             self.downsample = downsample(input_resolution, dim=dim)
608 |         else:
609 |             self.downsample = None
610 | 
611 |     def forward(self, x):
612 |         for block in self.blocks:
613 |             x = block(x)  # [bs,56*56,96] -> [bs,28*28,96*2] -> [bs,14*14,96*4] -> [bs,7*7,96*8]
614 |         if self.downsample is not None:
615 |             x = self.downsample(x)  # [bs,28*28,96*2] -> [bs,14*14,96*4] -> [bs,7*7,96*8]
616 | 
617 |         return x
618 | 
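# NOTE (editorial sketch, not part of the original file): with extra_norm
# enabled, the (i + sum_depth + 1) % 6 == 0 test above places norm3 after
# every 6th block, counted globally across stages. For the default depths
# [2, 2, 6, 2] the global block indices run 0..11, so only the blocks with
# global index 5 and 11 (the final block of the network) get the extra
# LayerNorm.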
619 | 
620 | class SwinTransformer(nn.Layer):
621 |     """SwinTransformer class
622 | 
623 |     Attributes:
624 |         num_classes: int, num of image classes
625 |         num_stages: int, num of stages contains patch merging and Swin blocks
626 |         depths: list of int, num of Swin blocks in each stage
627 |         num_heads: int, num of heads in attention module
628 |         embed_dim: int, output dimension of patch embedding
629 |         num_features: int, output dimension of whole network before classifier
630 |         mlp_ratio: float, hidden dimension of mlp layer is mlp_ratio * mlp input dim
631 |         qkv_bias: bool, if True, set qkv layers have bias enabled
632 |         qk_scale: float, scale factor for qk.
633 |         ape: bool, if True, set to use absolute positional embeddings
634 |         window_size: int, size of patch window for inputs
635 |         dropout: float, dropout rate for linear layer
636 |         dropout_attn: float, dropout rate for attention
637 |         patch_embedding: PatchEmbedding, patch embedding instance
638 |         patch_resolution: tuple, number of patches in row and column
639 |         position_dropout: nn.Dropout, dropout op for position embedding
640 |         stages: SwinTransformerStage, stage instances.
641 |         norm: nn.LayerNorm, norm layer applied after transformer
642 |         avgpool: nn.AveragePool2D, pooling layer before classifier
643 |         fc: nn.Linear, classifier op.
644 |     """
645 |     def __init__(self,
646 |                  image_size=224,
647 |                  patch_size=4,
648 |                  in_channels=3,
649 |                  num_classes=1000,
650 |                  embed_dim=96,
651 |                  depths=[2, 2, 6, 2],
652 |                  num_heads=[3, 6, 12, 24],
653 |                  window_size=7,
654 |                  mlp_ratio=4.,
655 |                  qkv_bias=True,
656 |                  qk_scale=None,
657 |                  dropout=0.,
658 |                  attention_dropout=0.,
659 |                  droppath=0.,
660 |                  ape=False,
661 |                  extra_norm=False):
662 |         super(SwinTransformer, self).__init__()
663 | 
664 |         self.num_classes = num_classes
665 |         self.num_stages = len(depths)
666 |         self.embed_dim = embed_dim
667 |         self.num_features = int(self.embed_dim * 2 ** (self.num_stages - 1))
668 |         self.mlp_ratio = mlp_ratio
669 |         self.ape = ape
670 | 
671 |         self.patch_embedding = PatchEmbedding(image_size=image_size,
672 |                                               patch_size=patch_size,
673 |                                               in_channels=in_channels,
674 |                                               embed_dim=embed_dim)
675 |         num_patches = self.patch_embedding.num_patches
676 |         self.patches_resolution = self.patch_embedding.patches_resolution
677 | 
678 | 
679 |         if self.ape:
680 |             self.absolute_positional_embedding = paddle.nn.ParameterList([
681 |                 paddle.create_parameter(
682 |                     shape=[1, num_patches, self.embed_dim], dtype='float32',
683 |                     default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))])
684 | 
685 |         self.position_dropout = nn.Dropout(dropout)
686 | 
687 |         depth_decay = [x.item() for x in paddle.linspace(0, droppath, sum(depths))]
688 | 
689 |         self.stages = nn.LayerList()
690 |         for stage_idx in range(self.num_stages):
691 |             stage = SwinTransformerStage(
692 |                 dim=int(self.embed_dim * 2 ** stage_idx),
693 |                 input_resolution=(
694 |                     self.patches_resolution[0] // (2 ** stage_idx),
695 |                     self.patches_resolution[1] // (2 ** stage_idx)),
696 |                 depth=depths[stage_idx],
697 |                 sum_depth=sum(depths[:stage_idx]) if extra_norm else None,  # Swin-T v2
698 |                 num_heads=num_heads[stage_idx],
699 |                 window_size=window_size,
700 |                 mlp_ratio=mlp_ratio,
701 |                 qkv_bias=qkv_bias,
702 |                 qk_scale=qk_scale,
703 |                 dropout=dropout,
704 |                 attention_dropout=attention_dropout,
705 |                 droppath=depth_decay[
706 |                     sum(depths[:stage_idx]):sum(depths[:stage_idx+1])],
707 |                 downsample=PatchMerging if (
708 |                     stage_idx < self.num_stages-1) else None,
709 |                 )
710 |             self.stages.append(stage)
711 | 
712 |         w_attr_1, b_attr_1 = self._init_weights_layernorm()
713 |         self.norm = nn.LayerNorm(self.num_features,
714 |                                  weight_attr=w_attr_1,
715 |                                  bias_attr=b_attr_1)
716 | 
717 |         self.avgpool = nn.AdaptiveAvgPool1D(1)
718 |         w_attr_2, b_attr_2 = self._init_weights()
719 |         self.fc = nn.Linear(self.num_features,
720 |                             self.num_classes,
721 |                             weight_attr=w_attr_2,
722 |                             bias_attr=b_attr_2)
723 | 
724 |     def _init_weights_layernorm(self):
725 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
726 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
727 |         return weight_attr, bias_attr
728 | 
729 |     def _init_weights(self):
730 |         weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
731 |         bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
732 |         return weight_attr, bias_attr
733 | 
734 |     def forward_features(self, x):
735 |         x = self.patch_embedding(x)  # [bs,H*W,96]
736 |         if self.ape:
737 |             x = x + self.absolute_positional_embedding[0]  # ParameterList holds a single parameter tensor
738 |         x = self.position_dropout(x)  # [bs,H*W,96]
739 | 
740 |         for stage in self.stages:
741 |             x = stage(x)  # [bs,784,192],[bs,196,384],[bs,49,768],[bs,49,768]
742 | 
743 |         x = self.norm(x)  # [bs,49,768]
744 |         x = x.transpose([0, 2, 1])
745 |         x = self.avgpool(x)  # [bs,768,1]
746 |         x = x.flatten(1)  # [bs,768]
747 |         return x
748 | 
749 |     def forward(self, x):
750 |         x = self.forward_features(x)  # [bs,768]
751 |         x = self.fc(x)  # [bs,1000]
752 |         return x
753 | 
754 | 
755 | def build_swin(config):
756 |     model = SwinTransformer(
757 |         image_size=config.DATA.IMAGE_SIZE,
758 |         patch_size=config.MODEL.TRANS.PATCH_SIZE,
759 |         in_channels=config.MODEL.TRANS.IN_CHANNELS,
760 |         embed_dim=config.MODEL.TRANS.EMBED_DIM,
761 |         num_classes=config.MODEL.NUM_CLASSES,
762 |         depths=config.MODEL.TRANS.STAGE_DEPTHS,
763 |         num_heads=config.MODEL.TRANS.NUM_HEADS,
764 |         mlp_ratio=config.MODEL.TRANS.MLP_RATIO,
765 |         qkv_bias=config.MODEL.TRANS.QKV_BIAS,
766 |         qk_scale=config.MODEL.TRANS.QK_SCALE,
767 |         ape=config.MODEL.TRANS.APE,
768 |         window_size=config.MODEL.TRANS.WINDOW_SIZE,
769 |         dropout=config.MODEL.DROPOUT,
770 |         attention_dropout=config.MODEL.ATTENTION_DROPOUT,
771 |         droppath=config.MODEL.DROP_PATH,
772 |         extra_norm=config.MODEL.TRANS.EXTRA_NORM)
773 |     return model
774 | 
775 | if __name__ == '__main__':
776 |     from main_single_gpu import get_arguments
777 |     from config import get_config
778 |     from config import update_config
779 |     arguments = get_arguments()
780 |     config = get_config()
781 |     config = update_config(config, arguments)
782 | 
783 |     model = build_swin(config)
784 |     image = paddle.randn([1, 3, 224, 224])
785 |     output = model(image)
786 |     print(output.shape)
787 | 
--------------------------------------------------------------------------------