├── swin_v2.PNG
├── run_train.sh
├── run_eval.sh
├── transforms.py
├── configs
│   └── swinv2_base_patch4_window7_224.yaml
├── stat_define.py
├── droppath.py
├── utils.py
├── losses.py
├── README.md
├── random_erasing.py
├── config.py
├── port_weights
│   ├── load_pytorch_weights_384.py
│   ├── load_pytorch_weights.py
│   └── load_pytorch_weights_large_384.py
├── modification.md
├── datasets.py
├── mixup.py
├── auto_augment.py
├── LICENSE
├── main_single_gpu.py
├── main_multi_gpu.py
└── swin_transformer.py
/swin_v2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nku-shengzheliu/PaddlePaddle-Swin-Transformer-V2/HEAD/swin_v2.PNG
--------------------------------------------------------------------------------
/run_train.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python main_single_gpu.py \
3 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
4 | -dataset='imagenet2012' \
5 | -batch_size=4 \
6 | -data_path='/dataset/imagenet'
7 |
--------------------------------------------------------------------------------
/run_eval.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 \
2 | python main_single_gpu.py \
3 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
4 | -dataset='imagenet2012' \
5 | -batch_size=32 \
6 | -data_path='/dataset/imagenet' \
7 | -eval \
8 | -pretrained='./swinv2_base_patch4_window7_224'
9 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | import random
2 | import paddle
3 | import paddle.nn
4 | import paddle.vision.transforms as T
5 |
6 |
7 | class RandomHorizontalFlip():
8 | def __init__(self, p=0.5):
9 | self.p = p
10 |
11 | def __call__(self, image):
12 | if random.random() < self.p:
13 | return T.hflip(image)
14 | return image
15 |
--------------------------------------------------------------------------------
/configs/swinv2_base_patch4_window7_224.yaml:
--------------------------------------------------------------------------------
1 | DATA:
2 | IMAGE_SIZE: 224
3 | CROP_PCT: 0.90
4 | MODEL:
5 | TYPE: swin
6 | NAME: swin_base_patch4_window7_224
7 | DROP_PATH: 0.5
8 | TRANS:
9 | EMBED_DIM: 128
10 | STAGE_DEPTHS: [2, 2, 18, 2]
11 | NUM_HEADS: [4, 8, 16, 32]
12 | WINDOW_SIZE: 7
13 | PATCH_SIZE: 4
14 | EXTRA_NORM: False
15 |
16 |
--------------------------------------------------------------------------------
/stat_define.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import paddle
4 | from config import get_config
5 | from swin_transformer import build_swin as build_model
6 |
7 | def count_gelu(layer, input, output):
8 | activation_flops = 8
9 | x = input[0]
10 | num = x.numel()
11 | layer.total_ops += num * activation_flops
12 |
13 |
14 | def count_softmax(layer, input, output):
15 |     softmax_flops = 5 # max/subtract, exp, sum, divide
16 | x = input[0]
17 | num = x.numel()
18 | layer.total_ops += num * softmax_flops
19 |
20 |
21 | def count_layernorm(layer, input, output):
22 | layer_norm_flops = 5 # get mean (sum), get variance (square and sum), scale(multiply)
23 | x = input[0]
24 | num = x.numel()
25 | layer.total_ops += num * layer_norm_flops
26 |
27 |
28 | cfg = './configs/swinv2_base_patch4_window7_224.yaml'
29 | input_size = (1, 3, 224, 224)
30 | config = get_config(cfg)
31 | model = build_model(config)
32 |
33 | custom_ops = {paddle.nn.GELU: count_gelu,
34 | paddle.nn.LayerNorm: count_layernorm,
35 | paddle.nn.Softmax: count_softmax,
36 | }
37 | print(os.path.basename(cfg))
38 | paddle.flops(model,
39 | input_size=input_size,
40 | custom_ops=custom_ops,
41 | print_detail=False)
42 |
43 |
44 | #for cfg in glob.glob('./configs/*.yaml'):
45 | # #cfg = './configs/swin_base_patch4_window7_224.yaml'
46 | # input_size = (1, 3, int(cfg[-8:-5]), int(cfg[-8:-5]))
47 | # config = get_config(cfg)
48 | # model = build_model(config)
49 | #
50 | #
51 | # custom_ops = {paddle.nn.GELU: count_gelu,
52 | # paddle.nn.LayerNorm: count_layernorm,
53 | # paddle.nn.Softmax: count_softmax,
54 | # }
55 | # print(os.path.basename(cfg))
56 | # paddle.flops(model,
57 | # input_size=input_size,
58 | # custom_ops=custom_ops,
59 | # print_detail=False)
60 | # print('-----------')
61 |
--------------------------------------------------------------------------------
/droppath.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Droppath, reimplement from https://github.com/yueatsprograms/Stochastic_Depth
17 | """
18 |
19 | import numpy as np
20 | import paddle
21 | import paddle.nn as nn
22 |
23 |
24 | class DropPath(nn.Layer):
25 | """DropPath class"""
26 | def __init__(self, drop_prob=None):
27 | super(DropPath, self).__init__()
28 | self.drop_prob = drop_prob
29 |
30 |     def drop_path(self, inputs):
31 |         """drop path op
32 |         Args:
33 |             inputs: tensor with arbitrary shape
34 |         Note:
35 |             uses self.drop_prob; inputs are returned unchanged when it is 0 or in eval mode
36 |         Returns:
37 |             output: output tensor after drop path
38 |         """
39 | # if prob is 0 or eval mode, return original input
40 | if self.drop_prob == 0. or not self.training:
41 | return inputs
42 | keep_prob = 1 - self.drop_prob
43 | keep_prob = paddle.to_tensor(keep_prob, dtype='float32')
44 | shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1) # shape=(N, 1, 1, 1)
45 | random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype)
46 | random_tensor = random_tensor.floor() # mask
47 | output = inputs.divide(keep_prob) * random_tensor # divide is to keep same output expectation
48 | return output
49 |
50 | def forward(self, inputs):
51 | return self.drop_path(inputs)
52 |
53 |
54 | #def main():
55 | # tmp = paddle.to_tensor(np.random.rand(8, 16, 8, 8), dtype='float32')
56 | # dp = DropPath(0.5)
57 | # out = dp(tmp)
58 | # print(out)
59 | #
60 | #if __name__ == "__main__":
61 | # main()
62 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """utils for ViT
16 |
17 | Contains AverageMeter for monitoring, get_exclude_from_decay_fn for training
18 | and WarmupCosineScheduler for training
19 |
20 | """
21 |
22 | import math
23 | from paddle.optimizer.lr import LRScheduler
24 |
25 |
26 | class AverageMeter():
27 | """ Meter for monitoring losses"""
28 | def __init__(self):
29 | self.avg = 0
30 | self.sum = 0
31 | self.cnt = 0
32 | self.reset()
33 |
34 | def reset(self):
35 | """reset all values to zeros"""
36 | self.avg = 0
37 | self.sum = 0
38 | self.cnt = 0
39 |
40 | def update(self, val, n=1):
41 | """update avg by val and n, where val is the avg of n values"""
42 | self.sum += val * n
43 | self.cnt += n
44 | self.avg = self.sum / self.cnt
45 |
46 |
47 |
48 | def get_exclude_from_weight_decay_fn(exclude_list=[]):
49 | """ Set params with no weight decay during the training
50 |
51 | For certain params, e.g., positional encoding in ViT, weight decay
52 |     may not be needed during training; this method is used to find
53 |     these params.
54 |
55 | Args:
56 | exclude_list: a list of params names which need to exclude
57 | from weight decay.
58 | Returns:
59 | exclude_from_weight_decay_fn: a function returns True if param
60 | will be excluded from weight decay
61 | """
62 | if len(exclude_list) == 0:
63 | exclude_from_weight_decay_fn = None
64 | else:
65 | def exclude_fn(param):
66 | for name in exclude_list:
67 | if param.endswith(name):
68 | return False
69 | return True
70 | exclude_from_weight_decay_fn = exclude_fn
71 | return exclude_from_weight_decay_fn
72 |
73 |
74 | class WarmupCosineScheduler(LRScheduler):
75 | """Warmup Cosine Scheduler
76 |
77 |     First apply linear warmup, then apply a cosine decay schedule.
78 |     Linearly increase learning rate from "warmup_start_lr" to "start_lr" over "warmup_epochs",
79 |     then decrease learning rate from "start_lr" to "end_lr" following a cosine curve over the
80 |     remaining "total_epochs - warmup_epochs"
81 |
82 | Attributes:
83 | learning_rate: the starting learning rate (without warmup), not used here!
84 | warmup_start_lr: warmup starting learning rate
85 | start_lr: the starting learning rate (without warmup)
86 | end_lr: the ending learning rate after whole loop
87 | warmup_epochs: # of epochs for warmup
88 | total_epochs: # of total epochs (include warmup)
89 | """
90 | def __init__(self,
91 | learning_rate,
92 | warmup_start_lr,
93 | start_lr,
94 | end_lr,
95 | warmup_epochs,
96 | total_epochs,
97 | cycles=0.5,
98 | last_epoch=-1,
99 | verbose=False):
100 | """init WarmupCosineScheduler """
101 | self.warmup_epochs = warmup_epochs
102 | self.total_epochs = total_epochs
103 | self.warmup_start_lr = warmup_start_lr
104 | self.start_lr = start_lr
105 | self.end_lr = end_lr
106 | self.cycles = cycles
107 | super(WarmupCosineScheduler, self).__init__(learning_rate, last_epoch, verbose)
108 |
109 | def get_lr(self):
110 | """ return lr value """
111 | if self.last_epoch < self.warmup_epochs:
112 | val = (self.start_lr - self.warmup_start_lr) * float(
113 | self.last_epoch)/float(self.warmup_epochs) + self.warmup_start_lr
114 | return val
115 |
116 | progress = float(self.last_epoch - self.warmup_epochs) / float(
117 | max(1, self.total_epochs - self.warmup_epochs))
118 | val = max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
119 | val = max(0.0, val * (self.start_lr - self.end_lr) + self.end_lr)
120 | return val
121 |
--------------------------------------------------------------------------------
/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ Implement Loss functions """
16 | import paddle
17 | import paddle.nn as nn
18 | import paddle.nn.functional as F
19 |
20 |
21 | class LabelSmoothingCrossEntropyLoss(nn.Layer):
22 | """ cross entropy loss for label smoothing
23 | Args:
24 | smoothing: float, smoothing rate
25 | x: tensor, predictions (before softmax) with shape [N, num_classes]
26 | target: tensor, target label with shape [N]
27 | Return:
28 | loss: float, cross entropy loss value
29 | """
30 | def __init__(self, smoothing=0.1):
31 | super().__init__()
32 | assert 0 <= smoothing < 1.0
33 | self.smoothing = smoothing
34 | self.confidence = 1 - smoothing
35 |
36 | def forward(self, x, target):
37 | log_probs = F.log_softmax(x) # [N, num_classes]
38 | # target_index is used to get prob for each of the N samples
39 | target_index = paddle.zeros([x.shape[0], 2], dtype='int64') # [N, 2]
40 | target_index[:, 0] = paddle.arange(x.shape[0])
41 | target_index[:, 1] = target
42 |
43 | nll_loss = -log_probs.gather_nd(index=target_index) # index: [N]
44 | smooth_loss = -log_probs.mean(axis=-1)
45 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
46 | return loss.mean()
47 |
48 |
49 | class SoftTargetCrossEntropyLoss(nn.Layer):
50 | """ cross entropy loss for soft target
51 | Args:
52 | x: tensor, predictions (before softmax) with shape [N, num_classes]
53 | target: tensor, soft target with shape [N, num_classes]
54 | Returns:
55 | loss: float, the mean loss value
56 | """
57 | def __init__(self):
58 | super().__init__()
59 |
60 | def forward(self, x, target):
61 | loss = paddle.sum(-target * F.log_softmax(x, axis=-1), axis=-1)
62 | return loss.mean()
63 |
64 |
65 | class DistillationLoss(nn.Layer):
66 | """Distillation loss function
67 |     This layer includes the original loss (criterion) and an extra
68 |     distillation loss (criterion), which computes the loss, with
69 |     different type options, between the current model and
70 |     a teacher model used as its supervision.
71 |
72 | Args:
73 | base_criterion: nn.Layer, the original criterion
74 | teacher_model: nn.Layer, the teacher model as supervision
75 | distillation_type: str, one of ['none', 'soft', 'hard']
76 | alpha: float, ratio of base loss (* (1-alpha))
77 | and distillation loss( * alpha)
78 |         tau: float, temperature in distillation
79 | """
80 | def __init__(self,
81 | base_criterion,
82 | teacher_model,
83 | distillation_type,
84 | alpha,
85 | tau):
86 | super().__init__()
87 | assert distillation_type in ['none', 'soft', 'hard']
88 | self.base_criterion = base_criterion
89 | self.teacher_model = teacher_model
90 | self.type = distillation_type
91 | self.alpha = alpha
92 | self.tau = tau
93 |
94 | def forward(self, inputs, outputs, targets):
95 | """
96 | Args:
97 |             inputs: tensor, the original model inputs
98 |             outputs: tensor, the outputs of the model
99 |             outputs_kd: tensor, the distillation outputs of the model,
100 | this is usually obtained by a separate branch
101 | in the last layer of the model
102 | targets: tensor, the labels for the base criterion
103 | """
104 | outputs, outputs_kd = outputs[0], outputs[1]
105 | base_loss = self.base_criterion(outputs, targets)
106 | if self.type == 'none':
107 | return base_loss
108 |
109 | with paddle.no_grad():
110 | teacher_outputs = self.teacher_model(inputs)
111 |
112 | if self.type == 'soft':
113 | distillation_loss = F.kl_div(
114 | F.log_softmax(outputs_kd / self.tau, axis=1),
115 | F.log_softmax(teacher_outputs / self.tau, axis=1),
116 | reduction='sum') * (self.tau * self.tau) / outputs_kd.numel()
117 | elif self.type == 'hard':
118 | distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(axis=1))
119 |
120 | loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
121 | return loss
122 |
123 |
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Swin Transformer V2: Scaling Up Capacity and Resolution, [arxiv](https://arxiv.org/pdf/2111.09883)
2 |
3 | PaddlePaddle training/validation code and pretrained models for **Swin Transformer V2**.
4 |
5 | The official pytorch implementation is [here](https://github.com/microsoft/Swin-Transformer).
6 |
7 | This implementation is developed by [PaddleViT](https://github.com/BR-IDL/PaddleViT.git).
8 |
9 |
10 | ![swin_v2](./swin_v2.PNG)
11 |
12 | Comparison of the WindowAttention module between Swin Transformer V1 and Swin Transformer V2
13 |
14 | ## Update
15 |
16 | * Update (2021-11-27): Completed the modification of the WindowAttention module according to the original paper:
17 | - [x] post-norm configuration
18 | - [x] scaled cosine attention
19 | - [x] log-spaced continuous relative position bias
20 |
21 | ## Code modification explanation
22 |
23 | The code modification explanation is [here](./modification.md)
24 |
25 | ## Models trained from scratch using PaddleViT
26 |
27 | | Model | Acc@1 | Acc@5 | #Params | FLOPs | Image Size | Crop_pct | Interpolation | Link |
28 | |-------------------------------|-------|-------|---------|--------|------------|----------|---------------|--------------|
29 | | swin_b_224 | | | 88.9M | 15.3G | 224 | 0.9 | Log-CPB | coming soon |
30 |
31 | > *The results are evaluated on the ImageNet2012 validation set.
32 |
33 |
34 | ## Requirements
35 | - Python>=3.6
36 | - yaml>=0.2.5
37 | - [PaddlePaddle](https://www.paddlepaddle.org.cn/documentation/docs/en/install/index_en.html)>=2.1.0
38 | - [yacs](https://github.com/rbgirshick/yacs)>=0.1.8
39 |
40 | ## Data
41 | The ImageNet2012 dataset is used, organized in the following folder structure:
42 | ```
43 | │imagenet/
44 | ├──train/
45 | │ ├── n01440764
46 | │ │ ├── n01440764_10026.JPEG
47 | │ │ ├── n01440764_10027.JPEG
48 | │ │ ├── ......
49 | │ ├── ......
50 | ├──val/
51 | │ ├── ILSVRC2012_val_00000293.JPEG
52 | │ ├── ILSVRC2012_val_00002138.JPEG
53 | │ ├── ......
54 | ```
55 |
56 | ## Usage
57 | To use the model with pretrained weights, download the `.pdparams` weight file and update the related file paths in the following python script. The model config files are located in `./configs/`.
58 |
59 | For example, assume the downloaded weight file is stored in `./swinv2_base_patch4_window7_224.pdparams`; to use the `swinv2_base_patch4_window7_224` model in python:
60 | ```python
61 | import paddle
62 | from config import get_config
63 | from swin_transformer import build_swin as build_model
64 | # config files in ./configs/
65 | config = get_config('./configs/swinv2_base_patch4_window7_224.yaml')
66 | model = build_model(config)
67 | # load pretrained weights from the downloaded .pdparams file
68 | model_state_dict = paddle.load('./swinv2_base_patch4_window7_224.pdparams')
69 | model.set_dict(model_state_dict)
70 | ```
71 |
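A minimal inference sketch following the snippet above (the random tensor stands in for a preprocessed 224x224 image batch):

```python
import paddle
model.eval()
x = paddle.randn([1, 3, 224, 224])  # [batch, channels, height, width]
with paddle.no_grad():
    logits = model(x)               # [1, 1000] class logits
print(paddle.argmax(logits, axis=-1).numpy())
```
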
72 | ## Evaluation
73 | To evaluate Swin Transformer V2 model performance on ImageNet2012 with a single GPU, run the following script from the command line:
74 | ```shell
75 | sh run_eval.sh
76 | ```
77 | or
78 | ```shell
79 | CUDA_VISIBLE_DEVICES=0 \
80 | python main_single_gpu.py \
81 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
82 | -dataset='imagenet2012' \
83 | -batch_size=16 \
84 | -data_path='/dataset/imagenet' \
85 | -eval \
86 | -pretrained='./swinv2_base_patch4_window7_224'
87 | ```
88 |
89 |
90 |
91 | Run evaluation using multiple GPUs:
92 |
93 |
94 |
95 | ```shell
96 | sh run_eval_multi.sh
97 | ```
98 | or
99 | ```shell
100 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
101 | python main_multi_gpu.py \
102 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
103 | -dataset='imagenet2012' \
104 | -batch_size=16 \
105 | -data_path='/dataset/imagenet' \
106 | -eval \
107 | -pretrained='./swinv2_base_patch4_window7_224'
108 | ```
109 |
110 |
111 |
112 |
113 | ## Training
114 | To train the Swin Transformer V2 model on ImageNet2012 with a single GPU, run the following script from the command line:
115 | ```shell
116 | sh run_train.sh
117 | ```
118 | or
119 | ```shell
120 | CUDA_VISIBLE_DEVICES=0 \
121 | python main_single_gpu.py \
122 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
123 | -dataset='imagenet2012' \
124 | -batch_size=32 \
125 | -data_path='/dataset/imagenet'
126 | ```
127 |
128 |
129 |
130 |
131 | Run training using multiple GPUs:
132 |
133 |
134 |
135 | ```shell
136 | sh run_train_multi.sh
137 | ```
138 | or
139 | ```shell
140 | CUDA_VISIBLE_DEVICES=0,1,2,3 \
141 | python main_multi_gpu.py \
142 | -cfg='./configs/swinv2_base_patch4_window7_224.yaml' \
143 | -dataset='imagenet2012' \
144 | -batch_size=16 \
145 | -data_path='/dataset/imagenet'
146 | ```
147 |
148 |
149 |
150 | ## Reference
151 | ```
152 | @article{liu2021swin,
153 | title={Swin Transformer V2: Scaling Up Capacity and Resolution},
154 | author={Liu, Ze and Hu, Han and Lin, Yutong and Yao, Zhuliang and Xie, Zhenda and Wei, Yixuan and Ning, Jia and Cao, Yue and Zhang, Zheng and Dong, Li and others},
155 | journal={arXiv preprint arXiv:2111.09883},
156 | year={2021}
157 | }
158 | ```
159 |
160 |
--------------------------------------------------------------------------------
/random_erasing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Random Erasing for image tensor"""
16 |
17 | import random
18 | import math
19 | import paddle
20 |
21 |
22 | def _get_pixels(per_pixel, rand_color, patch_size, dtype="float32"):
23 | if per_pixel:
24 | return paddle.normal(shape=patch_size).astype(dtype)
25 | elif rand_color:
26 | return paddle.normal(shape=(patch_size[0], 1, 1)).astype(dtype)
27 | else:
28 | return paddle.zeros((patch_size[0], 1, 1)).astype(dtype)
29 |
30 |
31 | class RandomErasing(object):
32 | """
33 | Args:
34 | prob: probability of performing random erasing
35 | min_area: Minimum percentage of erased area wrt input image area
36 | max_area: Maximum percentage of erased area wrt input image area
37 | min_aspect: Minimum aspect ratio of earsed area
38 | max_aspect: Maximum aspect ratio of earsed area
39 | mode: pixel color mode, in ['const', 'rand', 'pixel']
40 | 'const' - erase block is constant valued 0 for all channels
41 | 'rand' - erase block is valued random color (same per-channel)
42 | 'pixel' - erase block is vauled random color per pixel
43 | min_count: Minimum # of ereasing blocks per image.
44 | max_count: Maximum # of ereasing blocks per image. Area per box is scaled by count
45 | per-image count is randomly chosen between min_count to max_count
46 | """
47 | def __init__(self, prob=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
48 | mode='const', min_count=1, max_count=None, num_splits=0):
49 | self.prob = prob
50 | self.min_area = min_area
51 | self.max_area = max_area
52 | max_aspect = max_aspect or 1 / min_aspect
53 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
54 | self.min_count = min_count
55 | self.max_count = max_count or min_count
56 | self.num_splits = num_splits
57 | mode = mode.lower()
58 | self.rand_color = False
59 | self.per_pixel = False
60 | if mode == "rand":
61 | self.rand_color = True
62 | elif mode == "pixel":
63 | self.per_pixel = True
64 | else:
65 | assert not mode or mode == "const"
66 |
67 | def _erase(self, img, chan, img_h, img_w, dtype):
68 | if random.random() > self.prob:
69 | return
70 | area = img_h * img_w
71 | count = self.min_count if self.min_count == self.max_count else \
72 | random.randint(self.min_count, self.max_count)
73 | for _ in range(count):
74 | for attempt in range(10):
75 | target_area = random.uniform(self.min_area, self.max_area) * area / count
76 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
77 | h = int(round(math.sqrt(target_area * aspect_ratio)))
78 | w = int(round(math.sqrt(target_area / aspect_ratio)))
79 | if w < img_w and h < img_h:
80 | top = random.randint(0, img_h - h)
81 | left = random.randint(0, img_w - w)
82 | img[:, top:top+h, left:left+w] = _get_pixels(
83 | self.per_pixel, self.rand_color, (chan, h, w),
84 | dtype=dtype)
85 | break
86 |
87 | def __call__(self, input):
88 | if len(input.shape) == 3:
89 | self._erase(input, *input.shape, input.dtype)
90 | else:
91 | batch_size, chan, img_h, img_w = input.shape
92 | batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
93 | for i in range(batch_start, batch_size):
94 | self._erase(input[i], chan, img_h, img_w, input.dtype)
95 | return input
96 |
97 |
98 |
99 | #def main():
100 | # re = RandomErasing(prob=1.0, min_area=0.2, max_area=0.6, mode='rand')
101 | # #re = RandomErasing(prob=1.0, min_area=0.2, max_area=0.6, mode='const')
102 | # #re = RandomErasing(prob=1.0, min_area=0.2, max_area=0.6, mode='pixel')
103 | # import PIL.Image as Image
104 | # import numpy as np
105 | # paddle.set_device('cpu')
106 | # img = paddle.to_tensor(np.asarray(Image.open('./lenna.png'))).astype('float32')
107 | # img = img / 255.0
108 | # img = paddle.transpose(img, [2, 0, 1])
109 | # new_img = re(img)
110 | # new_img = new_img * 255.0
111 | # new_img = paddle.transpose(new_img, [1, 2, 0])
112 | # new_img = new_img.cpu().numpy()
113 | # new_img = Image.fromarray(new_img.astype('uint8'))
114 | # new_img.save('./res.png')
115 | #
116 | #
117 | #
118 | #if __name__ == "__main__":
119 | # main()
120 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Configuration
16 |
17 | Configuration for data, model architecture, training, etc.
18 | Config can be set by a .yaml file or by argparse (limited usage)
19 |
20 |
21 | """
22 |
23 | import os
24 | from yacs.config import CfgNode as CN
25 | import yaml
26 |
27 | _C = CN()
28 | _C.BASE = ['']
29 |
30 | # data settings
31 | _C.DATA = CN()
32 | _C.DATA.BATCH_SIZE = 8 #1024 batch_size for single GPU
33 | _C.DATA.BATCH_SIZE_EVAL = 8 #1024 batch_size for single GPU
34 | _C.DATA.DATA_PATH = '/dataset/imagenet/' # path to dataset
35 | _C.DATA.DATASET = 'imagenet2012' # dataset name
36 | _C.DATA.IMAGE_SIZE = 224 # input image size
37 | _C.DATA.CROP_PCT = 0.9 # input image scale ratio, scale is applied before centercrop in eval mode
38 | _C.DATA.NUM_WORKERS = 8 # number of data loading threads
39 |
40 | # model settings
41 | _C.MODEL = CN()
42 | _C.MODEL.TYPE = 'Swin'
43 | _C.MODEL.NAME = 'Swin'
44 | _C.MODEL.RESUME = None
45 | _C.MODEL.PRETRAINED = None
46 | _C.MODEL.NUM_CLASSES = 1000
47 | _C.MODEL.DROPOUT = 0.0
48 | _C.MODEL.ATTENTION_DROPOUT = 0.0
49 | _C.MODEL.DROP_PATH = 0.1
50 |
51 | # transformer settings
52 | _C.MODEL.TRANS = CN()
53 | _C.MODEL.TRANS.PATCH_SIZE = 4 # image_size = patch_size x window_size x num_windows
54 | _C.MODEL.TRANS.WINDOW_SIZE = 7
55 | _C.MODEL.TRANS.IN_CHANNELS = 3
56 | _C.MODEL.TRANS.EMBED_DIM = 96 # same as HIDDEN_SIZE in ViT
57 | _C.MODEL.TRANS.STAGE_DEPTHS = [2, 2, 6, 2]
58 | _C.MODEL.TRANS.NUM_HEADS = [3, 6, 12, 24]
59 | _C.MODEL.TRANS.MLP_RATIO = 4.
60 | _C.MODEL.TRANS.QKV_BIAS = True
61 | _C.MODEL.TRANS.QK_SCALE = None
62 | _C.MODEL.TRANS.APE = False # absolute positional embeddings
63 | _C.MODEL.TRANS.PATCH_NORM = True
64 | _C.MODEL.TRANS.EXTRA_NORM = False
65 |
66 | # training settings
67 | _C.TRAIN = CN()
68 | _C.TRAIN.LAST_EPOCH = 0
69 | _C.TRAIN.NUM_EPOCHS = 300
70 | _C.TRAIN.WARMUP_EPOCHS = 20
71 | _C.TRAIN.WEIGHT_DECAY = 0.05
72 | _C.TRAIN.BASE_LR = 5e-4
73 | _C.TRAIN.WARMUP_START_LR = 5e-7
74 | _C.TRAIN.END_LR = 5e-6
75 | _C.TRAIN.GRAD_CLIP = 5.0
76 | _C.TRAIN.ACCUM_ITER = 1
77 |
78 | _C.TRAIN.LR_SCHEDULER = CN()
79 | _C.TRAIN.LR_SCHEDULER.NAME = 'warmupcosine'
80 | _C.TRAIN.LR_SCHEDULER.MILESTONES = "30, 60, 90" # only used in StepLRScheduler
81 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 # only used in StepLRScheduler
82 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 # only used in StepLRScheduler
83 |
84 | _C.TRAIN.OPTIMIZER = CN()
85 | _C.TRAIN.OPTIMIZER.NAME = 'AdamW'
86 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
87 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) # for adamW
88 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
89 |
90 | # train augmentation
91 | _C.TRAIN.MIXUP_ALPHA = 0.8
92 | _C.TRAIN.CUTMIX_ALPHA = 1.0
93 | _C.TRAIN.CUTMIX_MINMAX = None
94 | _C.TRAIN.MIXUP_PROB = 1.0
95 | _C.TRAIN.MIXUP_SWITCH_PROB = 0.5
96 | _C.TRAIN.MIXUP_MODE = 'batch'
97 |
98 | _C.TRAIN.SMOOTHING = 0.1
99 | _C.TRAIN.COLOR_JITTER = 0.4
100 | _C.TRAIN.AUTO_AUGMENT = True #'rand-m9-mstd0.5-inc1'
101 |
102 | _C.TRAIN.RANDOM_ERASE_PROB = 0.25
103 | _C.TRAIN.RANDOM_ERASE_MODE = 'pixel'
104 | _C.TRAIN.RANDOM_ERASE_COUNT = 1
105 | _C.TRAIN.RANDOM_ERASE_SPLIT = False
106 |
107 | # augmentation
108 | _C.AUG = CN()
109 | _C.AUG.COLOR_JITTER = 0.4 # color jitter factor
110 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
111 | _C.AUG.RE_PROB = 0.25 # random earse prob
112 | _C.AUG.RE_MODE = 'pixel' # random earse mode
113 | _C.AUG.RE_COUNT = 1 # random earse count
114 | _C.AUG.MIXUP = 0.8 # mixup alpha, enabled if >0
115 | _C.AUG.CUTMIX = 1.0 # cutmix alpha, enabled if >0
116 | _C.AUG.CUTMIX_MINMAX = None # cutmix min/max ratio, overrides alpha
117 | _C.AUG.MIXUP_PROB = 1.0 # prob of mixup or cutmix when either/both is enabled
118 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 # prob of switching cutmix when both mixup and cutmix enabled
119 | _C.AUG.MIXUP_MODE = 'batch' #how to apply mixup/curmix params, per 'batch', 'pair', or 'elem'
120 |
121 | # misc
122 | _C.SAVE = "./output"
123 | _C.TAG = "default"
124 | _C.SAVE_FREQ = 1 # freq to save chpt
125 | _C.REPORT_FREQ = 50 # freq to logging info
126 | _C.VALIDATE_FREQ = 10 # freq to do validation
127 | _C.SEED = 42
128 | _C.EVAL = False # run evaluation only
129 | _C.AMP = False
130 | _C.LOCAL_RANK = 0
131 | _C.NGPUS = -1
132 |
133 |
134 | def _update_config_from_file(config, cfg_file):
135 | config.defrost()
136 | with open(cfg_file, 'r') as infile:
137 | yaml_cfg = yaml.load(infile, Loader=yaml.FullLoader)
138 | for cfg in yaml_cfg.setdefault('BASE', ['']):
139 | if cfg:
140 | _update_config_from_file(
141 | config, os.path.join(os.path.dirname(cfg_file), cfg)
142 | )
143 | print('merging config from {}'.format(cfg_file))
144 | config.merge_from_file(cfg_file)
145 | config.freeze()
146 |
147 | def update_config(config, args):
148 | """Update config by ArgumentParser
149 | Args:
150 | args: ArgumentParser contains options
151 | Return:
152 | config: updated config
153 | """
154 | if args.cfg:
155 | _update_config_from_file(config, args.cfg)
156 | config.defrost()
157 | if args.dataset:
158 | config.DATA.DATASET = args.dataset
159 | if args.batch_size:
160 | config.DATA.BATCH_SIZE = args.batch_size
161 | if args.image_size:
162 | config.DATA.IMAGE_SIZE = args.image_size
163 | if args.data_path:
164 | config.DATA.DATA_PATH = args.data_path
165 | if args.ngpus:
166 | config.NGPUS = args.ngpus
167 | if args.eval:
168 | config.EVAL = True
169 | config.DATA.BATCH_SIZE_EVAL = args.batch_size
170 | if args.pretrained:
171 | config.MODEL.PRETRAINED = args.pretrained
172 | if args.resume:
173 | config.MODEL.RESUME = args.resume
174 | if args.last_epoch:
175 | config.TRAIN.LAST_EPOCH = args.last_epoch
176 |     if args.amp: # AMP is only used during training
177 |         # enable AMP unless running evaluation only
178 |         config.AMP = not config.EVAL
179 |
180 | #config.freeze()
181 | return config
182 |
183 |
184 | def get_config(cfg_file=None):
185 | """Return a clone of config or load from yaml file"""
186 | config = _C.clone()
187 | if cfg_file:
188 | _update_config_from_file(config, cfg_file)
189 | return config
190 |
--------------------------------------------------------------------------------
/port_weights/load_pytorch_weights_384.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import numpy as np
17 | import paddle
18 | import torch
19 | import timm
20 | from swin_transformer import *
21 | from config import *
22 |
23 |
24 | config = get_config('./configs/swin_base_patch4_window12_384.yaml')
25 | print(config)
26 |
27 |
28 | def print_model_named_params(model):
29 | print('----------------------------------')
30 | for name, param in model.named_parameters():
31 | print(name, param.shape)
32 | print('----------------------------------')
33 |
34 |
35 | def print_model_named_buffers(model):
36 | print('----------------------------------')
37 | for name, param in model.named_buffers():
38 | print(name, param.shape)
39 | print('----------------------------------')
40 |
41 |
42 | def torch_to_paddle_mapping():
43 | mapping = [
44 | ('patch_embed.proj', 'patch_embedding.patch_embed'),
45 | ('patch_embed.norm', 'patch_embedding.norm'),
46 | ]
47 |
48 | # torch 'layers' to paddle 'stages'
49 | depths = config.MODEL.TRANS.STAGE_DEPTHS
50 | num_stages = len(depths)
51 | for stage_idx in range(num_stages):
52 | pp_s_prefix = f'stages.{stage_idx}.blocks'
53 | th_s_prefix = f'layers.{stage_idx}.blocks'
54 | for block_idx in range(depths[stage_idx]):
55 | th_b_prefix = f'{th_s_prefix}.{block_idx}'
56 | pp_b_prefix = f'{pp_s_prefix}.{block_idx}'
57 | layer_mapping = [
58 | (f'{th_b_prefix}.norm1', f'{pp_b_prefix}.norm1'),
59 | (f'{th_b_prefix}.attn.relative_position_bias_table', f'{pp_b_prefix}.attn.relative_position_bias_table'),
60 | (f'{th_b_prefix}.attn.qkv', f'{pp_b_prefix}.attn.qkv'),
61 | (f'{th_b_prefix}.attn.proj', f'{pp_b_prefix}.attn.proj'),
62 | (f'{th_b_prefix}.norm2', f'{pp_b_prefix}.norm2'),
63 | (f'{th_b_prefix}.mlp.fc1', f'{pp_b_prefix}.mlp.fc1'),
64 | (f'{th_b_prefix}.mlp.fc2', f'{pp_b_prefix}.mlp.fc2'),
65 | ]
66 | mapping.extend(layer_mapping)
67 | # stage downsample: last stage does not have downsample ops
68 | if stage_idx < num_stages - 1:
69 | mapping.extend([
70 | (f'layers.{stage_idx}.downsample.reduction.weight', f'stages.{stage_idx}.downsample.reduction.weight'),
71 | (f'layers.{stage_idx}.downsample.norm', f'stages.{stage_idx}.downsample.norm')])
72 |
73 | mapping.extend([
74 | ('norm', 'norm'),
75 | ('head', 'fc')])
76 | return mapping
77 |
78 |
79 |
80 | def convert(torch_model, paddle_model):
81 | def _set_value(th_name, pd_name, no_transpose=False):
82 | th_shape = th_params[th_name].shape
83 | pd_shape = tuple(pd_params[pd_name].shape) # paddle shape default type is list
84 | #assert th_shape == pd_shape, f'{th_shape} != {pd_shape}'
85 | print(f'set {th_name} {th_shape} to {pd_name} {pd_shape}')
86 | value = th_params[th_name].data.numpy()
87 | if len(value.shape) == 2:
88 | if not no_transpose:
89 | value = value.transpose((1, 0))
90 | pd_params[pd_name].set_value(value)
91 |
92 | # 1. get paddle and torch model parameters
93 | pd_params = {}
94 | th_params = {}
95 | for name, param in paddle_model.named_parameters():
96 | pd_params[name] = param
97 | for name, param in torch_model.named_parameters():
98 | th_params[name] = param
99 |
100 | for name, param in paddle_model.named_buffers():
101 | pd_params[name] = param
102 | for name, param in torch_model.named_buffers():
103 | th_params[name] = param
104 |
105 | # 2. get name mapping pairs
106 | mapping = torch_to_paddle_mapping()
107 | # 3. set torch param values to paddle params: may need to transpose weights
108 | for th_name, pd_name in mapping:
109 | if th_name in th_params.keys(): # nn.Parameters
110 | if th_name.endswith('relative_position_bias_table'):
111 | _set_value(th_name, pd_name, no_transpose=True)
112 | else:
113 | _set_value(th_name, pd_name)
114 | else: # weight & bias
115 | th_name_w = f'{th_name}.weight'
116 | pd_name_w = f'{pd_name}.weight'
117 | _set_value(th_name_w, pd_name_w)
118 |
119 | if f'{th_name}.bias' in th_params.keys():
120 | th_name_b = f'{th_name}.bias'
121 | pd_name_b = f'{pd_name}.bias'
122 | _set_value(th_name_b, pd_name_b)
123 |
124 | return paddle_model
125 |
126 |
127 | def main():
128 |
129 | paddle.set_device('cpu')
130 | paddle_model = build_swin(config)
131 | paddle_model.eval()
132 |
133 | print_model_named_params(paddle_model)
134 | print_model_named_buffers(paddle_model)
135 |
136 | print('+++++++++++++++++++++++++++++++++++')
137 | device = torch.device('cpu')
138 | torch_model = timm.create_model('swin_base_patch4_window12_384', pretrained=True)
139 | torch_model = torch_model.to(device)
140 | torch_model.eval()
141 |
142 | print_model_named_params(torch_model)
143 | print_model_named_buffers(torch_model)
144 |
145 | # convert weights
146 | paddle_model = convert(torch_model, paddle_model)
147 |
148 | # check correctness
149 | x = np.random.randn(2, 3, 384, 384).astype('float32')
150 | x_paddle = paddle.to_tensor(x)
151 | x_torch = torch.Tensor(x).to(device)
152 |
153 | out_torch = torch_model(x_torch)
154 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
155 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
156 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
157 | out_paddle = paddle_model(x_paddle)
158 |
159 | out_torch = out_torch.data.cpu().numpy()
160 | out_paddle = out_paddle.cpu().numpy()
161 |
162 | print(out_torch.shape, out_paddle.shape)
163 | print(out_torch[0, 0:20])
164 | print(out_paddle[0, 0:20])
165 | assert np.allclose(out_torch, out_paddle, atol = 1e-4)
166 |
167 | # save weights for paddle model
168 | model_path = os.path.join('./swin_base_patch4_window12_384.pdparams')
169 | paddle.save(paddle_model.state_dict(), model_path)
170 |
171 |
172 |
173 | #tmp = np.random.randn(1, 56, 128, 128).astype('float32')
174 | #xp = paddle.to_tensor(tmp)
175 | #xt = torch.Tensor(tmp).to(device)
176 | #xps = paddle.roll(xp, shifts=(-3, -3), axis=(1,2))
177 | #xts = torch.roll(xt,shifts=(-3, -3), dims=(1,2))
178 | #xps = xps.cpu().numpy()
179 | #xts = xts.data.cpu().numpy()
180 | #assert np.allclose(xps, xts, atol=1e-4)
181 |
182 | if __name__ == "__main__":
183 | main()
184 |
--------------------------------------------------------------------------------
/port_weights/load_pytorch_weights.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import numpy as np
17 | import paddle
18 | import torch
19 | import timm
20 | from swin_transformer import *
21 | from config import *
22 |
23 |
24 | config = get_config('./configs/swin_base_patch4_window7_224.yaml')
25 | print(config)
26 |
27 |
28 | def print_model_named_params(model):
29 | print('----------------------------------')
30 | for name, param in model.named_parameters():
31 | print(name, param.shape)
32 | print('----------------------------------')
33 |
34 |
35 | def print_model_named_buffers(model):
36 | print('----------------------------------')
37 | for name, param in model.named_buffers():
38 | print(name, param.shape)
39 | print('----------------------------------')
40 |
41 |
42 | def torch_to_paddle_mapping():
43 | mapping = [
44 | ('patch_embed.proj', 'patch_embedding.patch_embed'),
45 | ('patch_embed.norm', 'patch_embedding.norm'),
46 | ]
47 |
48 | # torch 'layers' to paddle 'stages'
49 | depths = config.MODEL.TRANS.STAGE_DEPTHS
50 | num_stages = len(depths)
51 | for stage_idx in range(num_stages):
52 | pp_s_prefix = f'stages.{stage_idx}.blocks'
53 | th_s_prefix = f'layers.{stage_idx}.blocks'
54 | for block_idx in range(depths[stage_idx]):
55 | th_b_prefix = f'{th_s_prefix}.{block_idx}'
56 | pp_b_prefix = f'{pp_s_prefix}.{block_idx}'
57 | layer_mapping = [
58 | (f'{th_b_prefix}.norm1', f'{pp_b_prefix}.norm1'),
59 | (f'{th_b_prefix}.attn.relative_position_bias_table', f'{pp_b_prefix}.attn.relative_position_bias_table'),
60 | (f'{th_b_prefix}.attn.qkv', f'{pp_b_prefix}.attn.qkv'),
61 | (f'{th_b_prefix}.attn.proj', f'{pp_b_prefix}.attn.proj'),
62 | (f'{th_b_prefix}.norm2', f'{pp_b_prefix}.norm2'),
63 | (f'{th_b_prefix}.mlp.fc1', f'{pp_b_prefix}.mlp.fc1'),
64 | (f'{th_b_prefix}.mlp.fc2', f'{pp_b_prefix}.mlp.fc2'),
65 | ]
66 | mapping.extend(layer_mapping)
67 | # stage downsample: last stage does not have downsample ops
68 | if stage_idx < num_stages - 1:
69 | mapping.extend([
70 | (f'layers.{stage_idx}.downsample.reduction.weight', f'stages.{stage_idx}.downsample.reduction.weight'),
71 | (f'layers.{stage_idx}.downsample.norm', f'stages.{stage_idx}.downsample.norm')])
72 |
73 | mapping.extend([
74 | ('norm', 'norm'),
75 | ('head', 'fc')])
76 | return mapping
77 |
78 |
79 |
80 | def convert(torch_model, paddle_model):
81 | def _set_value(th_name, pd_name, no_transpose=False):
82 | th_shape = th_params[th_name].shape
83 | pd_shape = tuple(pd_params[pd_name].shape) # paddle shape default type is list
84 | #assert th_shape == pd_shape, f'{th_shape} != {pd_shape}'
85 | print(f'set {th_name} {th_shape} to {pd_name} {pd_shape}')
86 | value = th_params[th_name].data.numpy()
87 | if len(value.shape) == 2:
88 | if not no_transpose:
89 | value = value.transpose((1, 0))
90 | pd_params[pd_name].set_value(value)
91 |
92 | # 1. get paddle and torch model parameters
93 | pd_params = {}
94 | th_params = {}
95 | for name, param in paddle_model.named_parameters():
96 | pd_params[name] = param
97 | for name, param in torch_model.named_parameters():
98 | th_params[name] = param
99 |
100 | for name, param in paddle_model.named_buffers():
101 | pd_params[name] = param
102 | for name, param in torch_model.named_buffers():
103 | th_params[name] = param
104 |
105 | # 2. get name mapping pairs
106 | mapping = torch_to_paddle_mapping()
107 | # 3. set torch param values to paddle params: may need to transpose weights
108 | for th_name, pd_name in mapping:
109 | if th_name in th_params.keys(): # nn.Parameters
110 | if th_name.endswith('relative_position_bias_table'):
111 | _set_value(th_name, pd_name, no_transpose=True)
112 | else:
113 | _set_value(th_name, pd_name)
114 | else: # weight & bias
115 | th_name_w = f'{th_name}.weight'
116 | pd_name_w = f'{pd_name}.weight'
117 | _set_value(th_name_w, pd_name_w)
118 |
119 | if f'{th_name}.bias' in th_params.keys():
120 | th_name_b = f'{th_name}.bias'
121 | pd_name_b = f'{pd_name}.bias'
122 | _set_value(th_name_b, pd_name_b)
123 |
124 | return paddle_model
125 |
126 |
127 |
128 |
129 |
130 | def main():
131 |
132 | paddle.set_device('cpu')
133 | paddle_model = build_swin(config)
134 | paddle_model.eval()
135 |
136 | print_model_named_params(paddle_model)
137 | print_model_named_buffers(paddle_model)
138 |
139 | print('+++++++++++++++++++++++++++++++++++')
140 | device = torch.device('cpu')
141 | torch_model = timm.create_model('swin_base_patch4_window7_224', pretrained=True)
142 | torch_model = torch_model.to(device)
143 | torch_model.eval()
144 | print_model_named_params(torch_model)
145 | print_model_named_buffers(torch_model)
146 |
147 | # convert weights
148 | paddle_model = convert(torch_model, paddle_model)
149 |
150 | # check correctness
151 | x = np.random.randn(2, 3, 224, 224).astype('float32')
152 | x_paddle = paddle.to_tensor(x)
153 | x_torch = torch.Tensor(x).to(device)
154 |
155 | out_torch = torch_model(x_torch)
156 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
157 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
158 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
159 | out_paddle = paddle_model(x_paddle)
160 |
161 | out_torch = out_torch.data.cpu().numpy()
162 | out_paddle = out_paddle.cpu().numpy()
163 |
164 | print(out_torch.shape, out_paddle.shape)
165 | print(out_torch[0, 0:20])
166 | print(out_paddle[0, 0:20])
167 | assert np.allclose(out_torch, out_paddle, atol = 1e-4)
168 |
169 | # save weights for paddle model
170 | model_path = os.path.join('./swin_base_patch4_window7_224.pdparams')
171 | paddle.save(paddle_model.state_dict(), model_path)
172 |
173 |
174 |
175 | #tmp = np.random.randn(1, 56, 128, 128).astype('float32')
176 | #xp = paddle.to_tensor(tmp)
177 | #xt = torch.Tensor(tmp).to(device)
178 | #xps = paddle.roll(xp, shifts=(-3, -3), axis=(1,2))
179 | #xts = torch.roll(xt,shifts=(-3, -3), dims=(1,2))
180 | #xps = xps.cpu().numpy()
181 | #xts = xts.data.cpu().numpy()
182 | #assert np.allclose(xps, xts, atol=1e-4)
183 |
184 | if __name__ == "__main__":
185 | main()
186 |
--------------------------------------------------------------------------------
/port_weights/load_pytorch_weights_large_384.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import numpy as np
17 | import paddle
18 | import torch
19 | import timm
20 | from swin_transformer import *
21 | from config import *
22 |
23 |
24 | config = get_config('./configs/swin_large_patch4_window12_384.yaml')
25 | print(config)
26 |
27 |
28 | def print_model_named_params(model):
29 | print('----------------------------------')
30 | for name, param in model.named_parameters():
31 | print(name, param.shape)
32 | print('----------------------------------')
33 |
34 |
35 | def print_model_named_buffers(model):
36 | print('----------------------------------')
37 | for name, param in model.named_buffers():
38 | print(name, param.shape)
39 | print('----------------------------------')
40 |
41 |
42 | def torch_to_paddle_mapping():
43 | mapping = [
44 | ('patch_embed.proj', 'patch_embedding.patch_embed'),
45 | ('patch_embed.norm', 'patch_embedding.norm'),
46 | ]
47 |
48 | # torch 'layers' to paddle 'stages'
49 | depths = config.MODEL.TRANS.STAGE_DEPTHS
50 | num_stages = len(depths)
51 | for stage_idx in range(num_stages):
52 | pp_s_prefix = f'stages.{stage_idx}.blocks'
53 | th_s_prefix = f'layers.{stage_idx}.blocks'
54 | for block_idx in range(depths[stage_idx]):
55 | th_b_prefix = f'{th_s_prefix}.{block_idx}'
56 | pp_b_prefix = f'{pp_s_prefix}.{block_idx}'
57 | layer_mapping = [
58 | (f'{th_b_prefix}.norm1', f'{pp_b_prefix}.norm1'),
59 | (f'{th_b_prefix}.attn.relative_position_bias_table', f'{pp_b_prefix}.attn.relative_position_bias_table'),
60 | (f'{th_b_prefix}.attn.qkv', f'{pp_b_prefix}.attn.qkv'),
61 | (f'{th_b_prefix}.attn.proj', f'{pp_b_prefix}.attn.proj'),
62 | (f'{th_b_prefix}.norm2', f'{pp_b_prefix}.norm2'),
63 | (f'{th_b_prefix}.mlp.fc1', f'{pp_b_prefix}.mlp.fc1'),
64 | (f'{th_b_prefix}.mlp.fc2', f'{pp_b_prefix}.mlp.fc2'),
65 | ]
66 | mapping.extend(layer_mapping)
67 | # stage downsample: last stage does not have downsample ops
68 | if stage_idx < num_stages - 1:
69 | mapping.extend([
70 | (f'layers.{stage_idx}.downsample.reduction.weight', f'stages.{stage_idx}.downsample.reduction.weight'),
71 | (f'layers.{stage_idx}.downsample.norm', f'stages.{stage_idx}.downsample.norm')])
72 |
73 | mapping.extend([
74 | ('norm', 'norm'),
75 | ('head', 'fc')])
76 | return mapping
77 |
78 |
79 |
80 | def convert(torch_model, paddle_model):
81 | def _set_value(th_name, pd_name, no_transpose=False):
82 | th_shape = th_params[th_name].shape
83 | pd_shape = tuple(pd_params[pd_name].shape) # paddle shape default type is list
84 | #assert th_shape == pd_shape, f'{th_shape} != {pd_shape}'
85 | print(f'set {th_name} {th_shape} to {pd_name} {pd_shape}')
86 | value = th_params[th_name].data.numpy()
87 | if len(value.shape) == 2:
88 | if not no_transpose:
89 | value = value.transpose((1, 0))
90 | pd_params[pd_name].set_value(value)
91 |
92 | # 1. get paddle and torch model parameters
93 | pd_params = {}
94 | th_params = {}
95 | for name, param in paddle_model.named_parameters():
96 | pd_params[name] = param
97 | for name, param in torch_model.named_parameters():
98 | th_params[name] = param
99 |
100 | for name, param in paddle_model.named_buffers():
101 | pd_params[name] = param
102 | for name, param in torch_model.named_buffers():
103 | th_params[name] = param
104 |
105 | # 2. get name mapping pairs
106 | mapping = torch_to_paddle_mapping()
107 | # 3. set torch param values to paddle params: may need to transpose weights
108 | for th_name, pd_name in mapping:
109 | if th_name in th_params.keys(): # nn.Parameters
110 | if th_name.endswith('relative_position_bias_table'):
111 | _set_value(th_name, pd_name, no_transpose=True)
112 | else:
113 | _set_value(th_name, pd_name)
114 | else: # weight & bias
115 | th_name_w = f'{th_name}.weight'
116 | pd_name_w = f'{pd_name}.weight'
117 | _set_value(th_name_w, pd_name_w)
118 |
119 | if f'{th_name}.bias' in th_params.keys():
120 | th_name_b = f'{th_name}.bias'
121 | pd_name_b = f'{pd_name}.bias'
122 | _set_value(th_name_b, pd_name_b)
123 |
124 | return paddle_model
125 |
126 |
127 |
128 |
129 |
130 | def main():
131 |
132 | paddle.set_device('cpu')
133 | paddle_model = build_swin(config)
134 | paddle_model.eval()
135 |
136 | print_model_named_params(paddle_model)
137 | print_model_named_buffers(paddle_model)
138 |
139 | print('+++++++++++++++++++++++++++++++++++')
140 | device = torch.device('cpu')
141 | torch_model = timm.create_model('swin_large_patch4_window12_384', pretrained=True)
142 | torch_model = torch_model.to(device)
143 | torch_model.eval()
144 | print_model_named_params(torch_model)
145 | print_model_named_buffers(torch_model)
146 |
147 | # convert weights
148 | paddle_model = convert(torch_model, paddle_model)
149 |
150 | # check correctness
151 | x = np.random.randn(2, 3, 384, 384).astype('float32')
152 | x_paddle = paddle.to_tensor(x)
153 | x_torch = torch.Tensor(x).to(device)
154 |
155 | out_torch = torch_model(x_torch)
156 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
157 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
158 | print('|||||||||||||||||||||||||||||||||||||||||||||||||||')
159 | out_paddle = paddle_model(x_paddle)
160 |
161 | out_torch = out_torch.data.cpu().numpy()
162 | out_paddle = out_paddle.cpu().numpy()
163 |
164 | print(out_torch.shape, out_paddle.shape)
165 | print(out_torch[0, 0:20])
166 | print(out_paddle[0, 0:20])
167 | assert np.allclose(out_torch, out_paddle, atol = 1e-4)
168 |
169 | # save weights for paddle model
170 | model_path = os.path.join('./swin_large_patch4_window12_384.pdparams')
171 | paddle.save(paddle_model.state_dict(), model_path)
172 |
173 |
174 |
175 | #tmp = np.random.randn(1, 56, 128, 128).astype('float32')
176 | #xp = paddle.to_tensor(tmp)
177 | #xt = torch.Tensor(tmp).to(device)
178 | #xps = paddle.roll(xp, shifts=(-3, -3), axis=(1,2))
179 | #xts = torch.roll(xt,shifts=(-3, -3), dims=(1,2))
180 | #xps = xps.cpu().numpy()
181 | #xts = xts.data.cpu().numpy()
182 | #assert np.allclose(xps, xts, atol=1e-4)
183 |
184 | if __name__ == "__main__":
185 | main()
186 |
--------------------------------------------------------------------------------
/modification.md:
--------------------------------------------------------------------------------
1 | # Code Modification Notes
2 |
3 | A note up front: my own understanding of the Swin Transformer code may not be fully thorough, so if anything in these modifications is wrong, corrections are very welcome! I also hope to take this opportunity to exchange ideas.
4 |
5 | ## Model Architecture
6 |
7 | Compared with V1, Swin Transformer V2 introduces three changes, all concentrated in the `WindowAttention` module of `swin_transformer.py`:
8 |
9 | * change pre-norm to post-norm
10 | * change dot-product attention to cosine attention, adding a scaling parameter $\tau$
11 | * replace the directly learned relative position bias with a continuous relative position bias, and change the linear relative coordinates to log-spaced coordinates
12 |
13 | ## 1. Post-norm
14 |
15 | Directly reorder the code in `SwinTransformerBlock` of `swin_transformer.py`: move `self.norm1(x)` and `self.norm2(x)` to after the attention and MLP operations respectively, but before the shortcut addition. For example:
16 |
17 | ```python
18 | # x = self.norm2(x) # Swin-T v1, pre-norm
19 | x = self.mlp(x) # [bs,H*W,C]
20 | x = self.norm2(x) # Swin-T v2, post-norm
21 | if self.drop_path is not None:
22 | x = h + self.drop_path(x)
23 | else:
24 | x = h + x
25 | ```
26 |
27 | Note that the code additionally adds `self.norm3`, corresponding to this statement in the paper:
28 |
29 | > For SwinV2-H and SwinV2-G, we further introduce a layer normalization unit on the main branch every 6 layers.
30 |
31 | For large models, an extra layer norm is applied after every 6 `SwinTransformerBlock`s. This can be enabled via the `EXTRA_NORM` option in the **config**; a sketch of how it could be wired up follows below.
32 |
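As a sketch of how this could look (the class and variable names here are illustrative, not the repo's actual implementation):

```python
import paddle.nn as nn

class StageWithExtraNorm(nn.Layer):  # hypothetical wrapper around one stage
    def __init__(self, blocks, dim, extra_norm=False):
        super().__init__()
        self.blocks = nn.LayerList(blocks)
        self.extra_norm = extra_norm
        if extra_norm:
            self.extra_norms = nn.LayerList(
                [nn.LayerNorm(dim) for _ in range(len(blocks) // 6)])

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            x = block(x)
            # extra LayerNorm on the main branch every 6 layers (SwinV2-H/G)
            if self.extra_norm and (i + 1) % 6 == 0:
                x = self.extra_norms[(i + 1) // 6 - 1](x)
        return x
```
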
33 | ## 2. Attention computation
34 |
35 | ### 2.1 Dot product attention
36 |
37 | The original Swin Transformer computes self-attention as:
38 | $$
39 | \text { Attention }(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V
40 | $$
41 | The dot-product attention inside the softmax corresponds to the following code in the `WindowAttention` module:
42 |
43 | ```python
44 | qkv = self.qkv(x).chunk(3, axis=-1)
45 | q, k, v = map(self.transpose_multihead, qkv)
46 | q = q * self.scale # i.e., 1/sqrt(d)
47 | attn = paddle.matmul(q, k, transpose_y=True)
48 | ```
49 |
50 | ### 2.2 Scaled cosine attention
51 |
52 | V2 proposes scaled cosine attention, computed as:
53 | $$
54 | \operatorname{Sim}\left(\mathbf{q}_{i}, \mathbf{k}_{j}\right)=\cos \left(\mathbf{q}_{i}, \mathbf{k}_{j}\right) / \tau+B_{i j}
55 | $$
56 | Here $\tau$ is a learnable parameter, different for each head of each layer, and its minimum value is clipped to 0.01.
57 |
58 | The code changes are as follows:
59 |
60 | First, define $\tau$ in `__init__`:
61 |
62 | ```python
63 | # Swin-T v2, Scaled cosine attention
64 | self.tau = paddle.create_parameter(
65 | shape = [num_heads, window_size[0]*window_size[1], window_size[0]*window_size[1]],
66 | dtype='float32',
67 | default_initializer=paddle.nn.initializer.Constant(1))
68 | ```
69 |
70 | Then, in `forward`:
71 |
72 | ```python
73 | qkv = self.qkv(x).chunk(3, axis=-1) # {list:3}
74 | q, k, v = map(self.transpose_multihead, qkv) # [bs*num_window=1*64,4,49,32] -> [bs*num_window=1*16,8,49,32]-> [bs*num_window=1*4,16,49,32]->[bs*num_window=1*1,32,49,32]
75 |
76 | # Swin-T v2, Scaled cosine attention, Eq.(2)
77 | qk = paddle.matmul(q, k, transpose_y=True) # [bs*num_window=1*64,num_heads=4,49,49] -> [bs*num_window=1*16,num_heads=8,49,49] -> [bs*num_window=1*4,num_heads=16,49,49] -> [bs*num_window=1*1,num_heads=32,49,49]
78 | q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
79 | k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
80 | attn = qk/paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
81 | attn = attn/paddle.clip(self.tau.unsqueeze(0), min=0.01)
82 | ```
83 |
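The `q2`/`k2` terms above are simply the L2 norms of `q` and `k`, so the same cosine attention (up to the small clipping at 1e-6) can also be written by L2-normalizing `q` and `k` first; a self-contained sketch with toy shapes:

```python
import paddle
import paddle.nn.functional as F

# toy shapes: [batch*num_windows, num_heads, tokens, head_dim]
q = paddle.randn([4, 4, 49, 32])
k = paddle.randn([4, 4, 49, 32])
tau = paddle.ones([4, 49, 49])  # temperature, one map per head (toy values)

# cos(q_i, k_j) = <q_i / ||q_i||, k_j / ||k_j||>
q_n = F.normalize(q, axis=-1)
k_n = F.normalize(k, axis=-1)
attn = paddle.matmul(q_n, k_n, transpose_y=True)       # [4, 4, 49, 49]
attn = attn / paddle.clip(tau.unsqueeze(0), min=0.01)  # scale by tau
```
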
84 | ## 3. Log-spaced CPB strategy
85 |
86 | ### 3.1 Continuous relative position bias
87 |
88 | When transferring a trained model to a higher resolution and a larger window size, the authors found that directly expanding the relative position bias by bicubic interpolation degrades performance considerably, as shown in the first row of Table 1 in the paper. V2 therefore adopts a **continuous relative position bias**. Here I take "continuous" to mean that a small network (e.g., two fully connected layers with a ReLU in between) learns the bias for each relative coordinate, relying on the small network's ability to generalize to larger window sizes (my understanding here is not fully solid and needs further study).
89 |
90 | * Original model code:
91 |
92 | First, define `relative_position_bias_table` in the `__init__` method of `WindowAttention`, and compute `relative_position_index` from the window size of the current block:
93 |
94 | ```python
95 | self.relative_position_bias_table = paddle.create_parameter(
96 | shape=[(2 * window_size[0] -1) * (2 * window_size[1] - 1), num_heads],
97 | dtype='float32',
98 | default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
99 |
100 | # relative position index for each token inside window
101 | coords_h = paddle.arange(self.window_size[0])
102 | coords_w = paddle.arange(self.window_size[1])
103 | coords = paddle.stack(paddle.meshgrid([coords_h, coords_w])) # [2, window_h, window_w]
104 | coords_flatten = paddle.flatten(coords, 1) # [2, window_h * window_w]
105 | # 2, window_h * window_w, window_h * window_w
106 | relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
107 | # window_h*window_w, window_h*window_w, 2
108 | relative_coords = relative_coords.transpose([1, 2, 0])
109 | relative_coords[:, :, 0] += self.window_size[0] - 1
110 | relative_coords[:, :, 1] += self.window_size[1] - 1
111 | relative_coords[:, :, 0] *= 2* self.window_size[1] - 1
112 | # [window_size * window_size, window_size*window_size]
113 | relative_position_index = relative_coords.sum(-1)
114 | self.register_buffer("relative_position_index", relative_position_index)
115 | ```
116 |
117 | In `forward`, it is called as follows:
118 |
119 | ```python
120 | def get_relative_pos_bias_from_pos_index(self):
121 | table = self.relative_position_bias_table # N x num_heads
122 | # index is a tensor
123 | index = self.relative_position_index.reshape([-1]) # window_h*window_w * window_h*window_w
124 | # NOTE: paddle does NOT support indexing Tensor by a Tensor
125 | relative_position_bias = paddle.index_select(x=table, index=index)
126 | return relative_position_bias
127 | def forward(......):
128 | ......
129 | relative_position_bias = relative_position_bias.transpose([2, 0, 1])
130 | attn = attn + relative_position_bias.unsqueeze(0)
131 | ......
132 | ```
133 |
134 | * Corresponding code in V2:
135 |
136 | In `__init__`:
137 |
138 | ```python
139 | ## Swin-T v2, small meta network, Eq.(3)
140 | self.cpb = Mlp_Relu(in_features=2, # delta x, delta y
141 | hidden_features=512, # TODO: hidden dims
142 | out_features=self.num_heads,
143 | dropout=dropout)
144 | ```
145 |
146 | A remaining open point is the hidden-layer dimension; I set it to 512 here (a sketch of `Mlp_Relu` is given below). The computation of the relative-coordinate index is covered in the next subsection.
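
The definition of `Mlp_Relu` is not shown in this note; a minimal sketch consistent with how it is constructed and called here (Linear → ReLU → Linear, mapping the 2-d relative coordinate to one bias per head) would be:

```python
import paddle.nn as nn

class Mlp_Relu(nn.Layer):
    """Sketch of the 2-layer meta network G in Eq.(3)."""
    def __init__(self, in_features, hidden_features, out_features, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)   # (dx, dy) -> hidden
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features)  # hidden -> one bias per head
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.act(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
```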
147 |
148 | In `forward`:
149 |
150 | ```python
151 | def get_continuous_relative_position_bias(self):
152 | # The continuous position bias approach adopts a small meta network on the relative coordinates
153 | continuous_relative_position_bias = self.cpb(self.log_relative_position_index)
154 | return continuous_relative_position_bias
155 | def forward(......):
156 | ......
157 | ## Swin-T v2
158 | relative_position_bias = self.get_continuous_relative_position_bias()
159 | relative_position_bias = relative_position_bias.reshape(
160 | [self.window_size[0] * self.window_size[1],
161 | self.window_size[0] * self.window_size[1],
162 | -1])
163 |
164 | # nH, window_h*window_w, window_h*window_w
165 | relative_position_bias = relative_position_bias.transpose([2, 0, 1])
166 | attn = attn + relative_position_bias.unsqueeze(0)
167 | ......
168 | ```
169 |
170 | ### 3.2 Log-spaced coordinates
171 |
172 | In addition, the authors note:
173 |
174 | > When transferred across largely varied window sizes, there will be a large portion of relative coordinate range requiring extrapolation.
175 |
176 | Computing the relative position bias between patches with the original linearly spaced coordinates means that, when the model is transferred to a larger window size, a large portion of the coordinate range has to be extrapolated. Hence the proposal:
177 |
178 | > we propose to use the log-spaced coordinates instead of the original linear-spaced ones
179 |
180 | The log-spaced coordinates correspond to Eq.(4) in the paper:
181 | $$
182 | \begin{aligned}
183 | &\widehat{\Delta x}=\operatorname{sign}(x) \cdot \log (1+|\Delta x|) \\
184 | &\widehat{\Delta y}=\operatorname{sign}(y) \cdot \log (1+|\Delta y|)
185 | \end{aligned}
186 | $$
187 | However, I believe the arguments of $\operatorname{sign}(\cdot)$ should be $\Delta x$ and $\Delta y$; the modified code accordingly:
188 |
189 | ```python
190 | # relative position index for each token inside window
191 | coords_h = paddle.arange(self.window_size[0])
192 | coords_w = paddle.arange(self.window_size[1])
193 | coords = paddle.stack(paddle.meshgrid([coords_h, coords_w])) # [2, window_h, window_w]
194 | coords_flatten = paddle.flatten(coords, 1) # [2, window_h * window_w]
195 | # 2, window_h * window_w, window_h * window_w
196 | relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
197 | # window_h*window_w, window_h*window_w, 2
198 | relative_coords = relative_coords.transpose([1, 2, 0])
199 |
200 | ## Swin-T v2, log-spaced coordinates, Eq.(4)
201 | log_relative_position_index = paddle.multiply(relative_coords.cast(dtype='float32').sign(),
202 | paddle.log((relative_coords.cast(dtype='float32').abs()+1)))
203 | self.register_buffer("log_relative_position_index", log_relative_position_index)
204 | ```
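
A quick check of why this helps: transferring from an 8×8 to a 16×16 window, the linear coordinates must be extrapolated from $[-7, 7]$ to $[-15, 15]$, i.e. by $8/7 \approx 1.14\times$ the trained range, whereas the log-spaced coordinates only extend by about $0.33\times$:

```python
import numpy as np

lin_old, lin_new = 7, 15                      # max |delta| per axis: 8x8 vs 16x16 window
log_old, log_new = np.log(1 + lin_old), np.log(1 + lin_new)

print((lin_new - lin_old) / lin_old)          # 1.142... extrapolation ratio, linear
print((log_new - log_old) / log_old)          # 0.333... extrapolation ratio, log-spaced
```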
205 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Dataset related classes and methods for ViT training and validation
17 | Cifar10, Cifar100 and ImageNet2012 are supported
18 | """
19 |
20 | import os
21 | import math
22 | from paddle.io import Dataset
23 | from paddle.io import DataLoader
24 | from paddle.io import DistributedBatchSampler
25 | from paddle.vision import transforms
26 | from paddle.vision import datasets
27 | from paddle.vision import image_load
28 | from auto_augment import auto_augment_policy_original
29 | from auto_augment import AutoAugment
30 | from transforms import RandomHorizontalFlip
31 | from random_erasing import RandomErasing
32 |
33 | class ImageNet2012Dataset(Dataset):
34 | """Build ImageNet2012 dataset
35 |
36 | This class gets train/val imagenet datasets, which load transformed data and labels.
37 |
38 | Attributes:
39 | file_folder: path where imagenet images are stored
40 | transform: preprocessing ops to apply on image
41 | img_path_list: list of full path of images in whole dataset
42 | label_list: list of labels of whole dataset
43 | """
44 |
45 | def __init__(self, file_folder, mode="train", transform=None):
46 | """Init ImageNet2012 Dataset with dataset file path, mode(train/val), and transform"""
47 | super(ImageNet2012Dataset, self).__init__()
48 | assert mode in ["train", "val"]
49 | self.file_folder = file_folder
50 | self.transform = transform
51 | self.img_path_list = []
52 | self.label_list = []
53 | self.mode = mode
54 |
55 | if mode == "train":
56 | self.list_file = os.path.join(self.file_folder, "Annotations", "CLS-LOC", "train.txt")
57 | else:
58 | self.list_file = os.path.join(self.file_folder, "Annotations", "CLS-LOC", "val.txt")
59 |
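# each line of the list file is "<relative/image/path> <integer_label>", parsed below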
60 | with open(self.list_file, 'r') as infile:
61 | for line in infile:
62 | img_path = line.strip().split()[0]
63 | img_label = int(line.strip().split()[1])
64 | self.img_path_list.append(os.path.join(self.file_folder, "Data", "CLS-LOC", self.mode, img_path))
65 | self.label_list.append(img_label)
66 | print(f'----- Imagenet2012 image {mode} list len = {len(self.label_list)}')
67 |
68 | def __len__(self):
69 | return len(self.label_list)
70 |
71 | def __getitem__(self, index):
72 | data = image_load(self.img_path_list[index]).convert('RGB')
73 | data = self.transform(data)
74 | label = self.label_list[index]
75 |
76 | return data, label
77 |
78 |
79 | def get_train_transforms(config):
80 | """ Get training transforms
81 |
82 | For training, a RandomResizedCrop is applied, then normalization is applied with
83 | ImageNet mean and std. The input pixel values must be rescaled to [0, 1.]
84 | Output is converted to tensor
85 |
86 | Args:
87 | config: configs contains IMAGE_SIZE, see config.py for details
88 | Returns:
89 | transforms_train: training transforms
90 | """
91 |
92 | aug_op_list = []
93 | # STEP1: random crop and resize
94 | aug_op_list.append(
95 | transforms.RandomResizedCrop((config.DATA.IMAGE_SIZE, config.DATA.IMAGE_SIZE),
96 | scale=(0.05, 1.0), interpolation='bicubic'))
97 | # STEP2: auto_augment or color jitter
98 | if config.TRAIN.AUTO_AUGMENT:
99 | policy = auto_augment_policy_original()
100 | auto_augment = AutoAugment(policy)
101 | aug_op_list.append(auto_augment)
102 | else:
103 | jitter = (float(config.TRAIN.COLOR_JITTER), ) * 3
104 | aug_op_list.append(transforms.ColorJitter(*jitter))
105 | # STEP3: other ops
106 | aug_op_list.append(transforms.ToTensor())
107 | aug_op_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
108 | # STEP4: random erasing
109 | if config.TRAIN.RANDOM_ERASE_PROB > 0.:
110 | random_erasing = RandomErasing(prob=config.TRAIN.RANDOM_ERASE_PROB,
111 | mode=config.TRAIN.RANDOM_ERASE_MODE,
112 | max_count=config.TRAIN.RANDOM_ERASE_COUNT,
113 | num_splits=config.TRAIN.RANDOM_ERASE_SPLIT)
114 | aug_op_list.append(random_erasing)
115 | # Final: compose transforms and return
116 | transforms_train = transforms.Compose(aug_op_list)
117 | return transforms_train
118 |
119 |
120 | def get_val_transforms(config):
121 | """ Get training transforms
122 |
123 | For validation, image is first Resize then CenterCrop to image_size.
124 | Then normalization is applied with [0.5, 0.5, 0.5] mean and std.
125 | The input pixel values must be rescaled to [0, 1.]
126 | Outputs is converted to tensor
127 |
128 | Args:
129 | config: configs contains IMAGE_SIZE, see config.py for details
130 | Returns:
131 | transforms_train: training transforms
132 | """
133 |
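# e.g., IMAGE_SIZE=224 with CROP_PCT=0.90 -> shorter side resized to floor(224/0.90)=248, then center-crop to 224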
134 | scale_size = int(math.floor(config.DATA.IMAGE_SIZE / config.DATA.CROP_PCT))
135 | transforms_val = transforms.Compose([
136 | transforms.Resize(scale_size, interpolation='bicubic'),
137 | transforms.CenterCrop((config.DATA.IMAGE_SIZE, config.DATA.IMAGE_SIZE)),
138 | transforms.ToTensor(),
139 | #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
140 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
141 | ])
142 | return transforms_val
143 |
144 |
145 | def get_dataset(config, mode='train'):
146 | """ Get dataset from config and mode (train/val)
147 |
148 | Returns the related dataset object according to configs and mode(train/val)
149 |
150 | Args:
151 | config: configs contains dataset related settings. see config.py for details
152 | Returns:
153 | dataset: dataset object
154 | """
155 |
156 | assert mode in ['train', 'val']
157 | if config.DATA.DATASET == "cifar10":
158 | if mode == 'train':
159 | dataset = datasets.Cifar10(mode=mode, transform=get_train_transforms(config))
160 | else:
161 | dataset = datasets.Cifar10(mode=mode, transform=get_val_transforms(config))
162 | elif config.DATA.DATASET == "cifar100":
163 | if mode == 'train':
164 | dataset = datasets.Cifar100(mode=mode, transform=get_train_transforms(config))
165 | else:
166 | dataset = datasets.Cifar100(mode=mode, transform=get_val_transforms(config))
167 | elif config.DATA.DATASET == "imagenet2012":
168 | if mode == 'train':
169 | dataset = ImageNet2012Dataset(config.DATA.DATA_PATH,
170 | mode=mode,
171 | transform=get_train_transforms(config))
172 | else:
173 | dataset = ImageNet2012Dataset(config.DATA.DATA_PATH,
174 | mode=mode,
175 | transform=get_val_transforms(config))
176 | else:
177 | raise NotImplementedError(
178 | f"[{config.DATA.DATASET}] Only cifar10, cifar100, imagenet2012 are supported now")
179 | return dataset
180 |
181 |
182 | def get_dataloader(config, dataset, mode='train', multi_process=False):
183 | """Get dataloader with config, dataset, mode as input, allows multiGPU settings.
184 |
185 | The multi-GPU loader is implemented as a DistributedBatchSampler.
186 |
187 | Args:
188 | config: see config.py for details
189 | dataset: paddle.io.dataset object
190 | mode: train/val
191 | multi_process: if True, use DistributedBatchSampler to support multi-processing
192 | Returns:
193 | dataloader: paddle.io.DataLoader object.
194 | """
195 |
196 | if mode == 'train':
197 | batch_size = config.DATA.BATCH_SIZE
198 | else:
199 | batch_size = config.DATA.BATCH_SIZE_EVAL
200 |
201 | if multi_process is True:
202 | sampler = DistributedBatchSampler(dataset,
203 | batch_size=batch_size,
204 | shuffle=(mode == 'train'))
205 | dataloader = DataLoader(dataset,
206 | batch_sampler=sampler,
207 | num_workers=config.DATA.NUM_WORKERS)
208 | else:
209 | dataloader = DataLoader(dataset,
210 | batch_size=batch_size,
211 | num_workers=config.DATA.NUM_WORKERS,
212 | shuffle=(mode == 'train'))
213 | return dataloader
214 |
--------------------------------------------------------------------------------
/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """mixup and cutmix for batch data"""
16 | import numpy as np
17 | import paddle
18 |
19 |
20 | def rand_bbox(image_shape, lam, count=None):
21 | """ CutMix bbox by lam value
22 | Generate 1 random bbox by value lam. lam is the cut size rate.
23 | The cut_size is computed by sqrt(1-lam) * image_size.
24 |
25 | Args:
26 | image_shape: tuple/list, image height and width
27 | lam: float, cutmix lambda value
28 | count: int, number of bbox to generate
29 | """
30 | image_h, image_w = image_shape[-2:]
31 | cut_rate = np.sqrt(1. - lam)
32 | cut_h = int(cut_rate * image_h)
33 | cut_w = int(cut_rate * image_w)
34 |
35 | # get random bbox center
36 | cy = np.random.randint(0, image_h, size=count)
37 | cx = np.random.randint(0, image_w, size=count)
38 |
39 | # get bbox coords
40 | bbox_x1 = np.clip(cx - cut_w // 2, 0, image_w)
41 | bbox_y1 = np.clip(cy - cut_h // 2, 0, image_h)
42 | bbox_x2 = np.clip(cx + cut_w // 2, 0, image_w)
43 | bbox_y2 = np.clip(cy + cut_h // 2, 0, image_h)
44 |
45 | # NOTE: in paddle, tensor indexing e.g., a[x1:x2],
46 | # if x1 == x2, paddle will raise a ValueError,
47 | # while in pytorch, it will return [] tensor
48 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2
49 |
50 |
51 | def rand_bbox_minmax(image_shape, minmax, count=None):
52 | """ CutMix bbox by min and max value
53 | Generate 1 random bbox by min and max percentage values.
54 | Minmax is a tuple/list of min and max percentage values
55 | applied to the image width and height.
56 |
57 | Args:
58 | image_shape: tuple/list, image height and width
59 | minmax: tuple/list, min and max percentage values of image size
60 | count: int, number of bbox to generate
61 | """
62 | assert len(minmax) == 2
63 | image_h, image_w = image_shape[-2:]
64 | min_ratio = minmax[0]
65 | max_ratio = minmax[1]
66 | cut_h = np.random.randint(int(image_h * min_ratio), int(image_h * max_ratio), size=count)
67 | cut_w = np.random.randint(int(image_w * min_ratio), int(image_w * max_ratio), size=count)
68 |
69 | bbox_x1 = np.random.randint(0, image_w - cut_w, size=count)
70 | bbox_y1 = np.random.randint(0, image_h - cut_h, size=count)
71 | bbox_x2 = bbox_x1 + cut_w
72 | bbox_y2 = bbox_y1 + cut_h
73 |
74 | return bbox_x1, bbox_y1, bbox_x2, bbox_y2
75 |
76 |
77 | def cutmix_generate_bbox_adjust_lam(image_shape, lam, minmax=None, correct_lam=True, count=None):
78 | """Generate bbox and apply correction for lambda
79 | If the minmax is None, apply the standard cutmix by lam value,
80 | If the minmax is set, apply the cutmix by min and max percentage values.
81 |
82 | Args:
83 | image_shape: tuple/list, image height and width
84 | lam: float, cutmix lambda value
85 | minmax: tuple/list, min and max percentage values of image size
86 | correct_lam: bool, if True, correct the lam value by the generated bbox
87 | count: int, number of bbox to generate
88 | """
89 | if minmax is not None:
90 | bbox_x1, bbox_y1, bbox_x2, bbox_y2 = rand_bbox_minmax(image_shape, minmax, count)
91 | else:
92 | bbox_x1, bbox_y1, bbox_x2, bbox_y2 = rand_bbox(image_shape, lam, count)
93 |
94 | if correct_lam or minmax is not None:
95 | image_h, image_w = image_shape[-2:]
96 | bbox_area = (bbox_y2 - bbox_y1) * (bbox_x2 - bbox_x1)
97 | lam = 1. - bbox_area / float(image_h * image_w)
98 | return (bbox_x1, bbox_y1, bbox_x2, bbox_y2), lam
99 |
100 |
101 | def one_hot(x, num_classes, on_value=1., off_value=0.):
102 | """ Generate one-hot vector for label smoothing
103 | Args:
104 | x: tensor, contains label/class indices
105 | num_classes: int, num of classes (len of the one-hot vector)
106 | on_value: float, the vector value at label index, default=1.
107 | off_value: float, the vector value at non-label indices, default=0.
108 | Returns:
109 | one_hot: tensor, tensor with on value at label index and off value
110 | at non-label indices.
111 | """
112 | x = x.reshape_([-1, 1])
113 | x_smoothed = paddle.full((x.shape[0], num_classes), fill_value=off_value)
114 | for i in range(x.shape[0]):
115 | x_smoothed[i, x[i]] = on_value
116 | return x_smoothed
117 |
118 |
119 | def mixup_one_hot(label, num_classes, lam=1., smoothing=0.):
120 | """ mixup and label smoothing in batch
121 | label smoothing is firstly applied, then
122 | mixup is applied by mixing the batch and its flip,
123 | with a mixup rate.
124 |
125 | Args:
126 | label: tensor, label tensor with shape [N], contains the class indices
127 | num_classes: int, num of all classes
128 | lam: float, mixup rate, default=1.0
129 | smoothing: float, label smoothing rate
130 | """
131 | off_value = smoothing / num_classes
132 | on_value = 1. - smoothing + off_value
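# e.g., smoothing=0.1, num_classes=1000 -> off_value=1e-4, on_value=0.9001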
133 | y1 = one_hot(label, num_classes, on_value, off_value)
134 | y2 = one_hot(label.flip(axis=[0]), num_classes, on_value, off_value)
135 | return y2 * (1 - lam) + y1 * lam
136 |
137 |
138 | class Mixup:
139 | """Mixup class
140 | Args:
141 | mixup_alpha: float, mixup alpha for beta distribution, default=1.0,
142 | cutmix_alpha: float, cutmix alpha for beta distribution, default=0.0,
143 | cutmix_minmax: list/tuple, min and max value for cutmix ratio, default=None,
144 | prob: float, probability of applying mixup/cutmix to a batch, default=1.0,
145 | switch_prob: float, prob of switching mixup and cutmix, default=0.5,
146 | mode: string, mixup mode; only 'batch' is currently supported, default='batch',
147 | correct_lam: bool, if True, apply correction of lam, default=True,
148 | label_smoothing: float, label smoothing rate, default=0.1,
149 | num_classes: int, num of classes, default=1000
150 | """
151 | def __init__(self,
152 | mixup_alpha=1.0,
153 | cutmix_alpha=0.0,
154 | cutmix_minmax=None,
155 | prob=1.0,
156 | switch_prob=0.5,
157 | mode='batch',
158 | correct_lam=True,
159 | label_smoothing=0.1,
160 | num_classes=1000):
161 | self.mixup_alpha = mixup_alpha
162 | self.cutmix_alpha = cutmix_alpha
163 | self.cutmix_minmax = cutmix_minmax
164 | if cutmix_minmax is not None:
165 | assert len(cutmix_minmax) == 2
166 | self.cutmix_alpha = 1.0
167 | self.mix_prob = prob
168 | self.switch_prob = switch_prob
169 | self.label_smoothing = label_smoothing
170 | self.num_classes = num_classes
171 | self.mode = mode
172 | self.correct_lam = correct_lam
173 | assert mode == 'batch', 'Now only batch mode is supported!'
174 |
175 | def __call__(self, x, target):
176 | assert x.shape[0] % 2 == 0, "Batch size should be even"
177 | lam = self._mix_batch(x)
178 | target = mixup_one_hot(target, self.num_classes, lam, self.label_smoothing)
179 | return x, target
180 |
181 | def get_params(self):
182 | """Decide to use cutmix or regular mixup by sampling and
183 | sample lambda for mixup
184 | """
185 | lam = 1.
186 | use_cutmix = False
187 | use_mixup = np.random.rand() < self.mix_prob
188 | if use_mixup:
189 | if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
190 | use_cutmix = np.random.rand() < self.switch_prob
191 | alpha = self.cutmix_alpha if use_cutmix else self.mixup_alpha
192 | lam_mix = np.random.beta(alpha, alpha)
193 | elif self.mixup_alpha == 0. and self.cutmix_alpha > 0.:
194 | use_cutmix=True
195 | lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
196 | elif self.mixup_alpha > 0. and self.cutmix_alpha == 0.:
197 | lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
198 | else:
199 | raise ValueError('mixup_alpha and cutmix_alpha cannot be all 0')
200 | lam = float(lam_mix)
201 | return lam, use_cutmix
202 |
203 | def _mix_batch(self, x):
204 | """mixup/cutmix by adding batch data and its flipped version"""
205 | lam, use_cutmix = self.get_params()
206 | if lam == 1.:
207 | return lam
208 | if use_cutmix:
209 | (bbox_x1, bbox_y1, bbox_x2, bbox_y2), lam = cutmix_generate_bbox_adjust_lam(
210 | x.shape,
211 | lam,
212 | minmax=self.cutmix_minmax,
213 | correct_lam=self.correct_lam)
214 |
215 | # NOTE: in paddle, tensor indexing e.g., a[x1:x2],
216 | # if x1 == x2, paddle will raise a ValueError,
217 | # but in pytorch, it will return [] tensor without errors
218 | if int(bbox_x1) != int(bbox_x2) and int(bbox_y1) != int(bbox_y2):
219 | x[:, :, int(bbox_x1): int(bbox_x2), int(bbox_y1): int(bbox_y2)] = x.flip(axis=[0])[
220 | :, :, int(bbox_x1): int(bbox_x2), int(bbox_y1): int(bbox_y2)]
221 | else:
222 | x_flipped = x.flip(axis=[0])
223 | x_flipped = x_flipped * (1 - lam)
224 | x.set_value(x * (lam) + x_flipped)
225 | return lam
226 |
--------------------------------------------------------------------------------
/auto_augment.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 |
14 | """Auto Augmentation"""
15 |
16 | import random
17 | import numpy as np
18 | from PIL import Image, ImageEnhance, ImageOps
19 |
20 |
21 | def auto_augment_policy_original():
22 | """ImageNet auto augment policy"""
23 | policy = [
24 | [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
25 | [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
26 | [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
27 | [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
28 | [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
29 | [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
30 | [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
31 | [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
32 | [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
33 | [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
34 | [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
35 | [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
36 | [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
37 | [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
38 | [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
39 | [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
40 | [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
41 | [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
42 | [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
43 | [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
44 | [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
45 | [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
46 | [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
47 | [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
48 | [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
49 | ]
50 | policy = [[SubPolicy(*args) for args in subpolicy] for subpolicy in policy]
51 | return policy
52 |
53 |
54 | class AutoAugment():
55 | """Auto Augment
56 | Randomly choose a tuple of augment ops from a list of policy
57 | Then apply the tuple of augment ops to input image
58 | """
59 | def __init__(self, policy):
60 | self.policy = policy
61 |
62 | def __call__(self, image, policy_idx=None):
63 | if policy_idx is None:
64 | policy_idx = random.randint(0, len(self.policy)-1)
65 |
66 | sub_policy = self.policy[policy_idx]
67 | for op in sub_policy:
68 | image = op(image)
69 | return image
70 |
71 |
72 | class SubPolicy:
73 | """Subpolicy
74 | Read augment name and magnitude, apply augment with probability
75 | Args:
76 | op_name: str, augment operation name
77 | prob: float, if prob > random prob, apply augment
78 | magnitude_idx: int, index of magnitude in preset magnitude ranges
79 | """
80 | def __init__(self, op_name, prob, magnitude_idx):
81 | # ranges of operations' magnitude
82 | ranges = {
83 | 'ShearX': np.linspace(0, 0.3, 10), # [-0.3, 0.3] (by random negative)
84 | 'ShearY': np.linspace(0, 0.3, 10), # [-0.3, 0.3] (by random negative)
85 | 'TranslateX': np.linspace(0, 150 / 331, 10), #[-0.45, 0.45] (by random negative)
86 | 'TranslateY': np.linspace(0, 150 / 331, 10), #[-0.45, 0.45] (by random negative)
87 | 'Rotate': np.linspace(0, 30, 10), #[-30, 30] (by random negative)
88 | 'Color': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
89 | 'Posterize': np.round(np.linspace(8, 4, 10), 0).astype(int), # bits in [4, 8]
90 | 'Solarize': np.linspace(256, 0, 10), #[0, 256]
91 | 'Contrast': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
92 | 'Sharpness': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
93 | 'Brightness': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
94 | 'AutoContrast': [0] * 10, # no range
95 | 'Equalize': [0] * 10, # no range
96 | 'Invert': [0] * 10, # no range
97 | }
98 |
99 | # augmentation operations
100 | # Lambda is not pickleable for DDP
101 | #image_ops = {
102 | # 'ShearX': lambda image, magnitude: shear_x(image, magnitude),
103 | # 'ShearY': lambda image, magnitude: shear_y(image, magnitude),
104 | # 'TranslateX': lambda image, magnitude: translate_x(image, magnitude),
105 | # 'TranslateY': lambda image, magnitude: translate_y(image, magnitude),
106 | # 'Rotate': lambda image, magnitude: rotate(image, magnitude),
107 | # 'AutoContrast': lambda image, magnitude: auto_contrast(image, magnitude),
108 | # 'Invert': lambda image, magnitude: invert(image, magnitude),
109 | # 'Equalize': lambda image, magnitude: equalize(image, magnitude),
110 | # 'Solarize': lambda image, magnitude: solarize(image, magnitude),
111 | # 'Posterize': lambda image, magnitude: posterize(image, magnitude),
112 | # 'Contrast': lambda image, magnitude: contrast(image, magnitude),
113 | # 'Color': lambda image, magnitude: color(image, magnitude),
114 | # 'Brightness': lambda image, magnitude: brightness(image, magnitude),
115 | # 'Sharpness': lambda image, magnitude: sharpness(image, magnitude),
116 | #}
117 | image_ops = {
118 | 'ShearX': shear_x,
119 | 'ShearY': shear_y,
120 | 'TranslateX': translate_x_relative,
121 | 'TranslateY': translate_y_relative,
122 | 'Rotate': rotate,
123 | 'AutoContrast': auto_contrast,
124 | 'Invert': invert,
125 | 'Equalize': equalize,
126 | 'Solarize': solarize,
127 | 'Posterize': posterize,
128 | 'Contrast': contrast,
129 | 'Color': color,
130 | 'Brightness': brightness,
131 | 'Sharpness': sharpness,
132 | }
133 |
134 | self.prob = prob
135 | self.magnitude = ranges[op_name][magnitude_idx]
136 | self.op = image_ops[op_name]
137 |
138 | def __call__(self, image):
139 | if self.prob > random.random():
140 | image = self.op(image, self.magnitude)
141 | return image
142 |
143 |
144 | # PIL Image transforms
145 | # https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.transform
146 | def shear_x(image, magnitude, fillcolor=(128, 128, 128)):
147 | factor = magnitude * random.choice([-1, 1]) # random negative
148 | return image.transform(image.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), fillcolor=fillcolor)
149 |
150 |
151 | def shear_y(image, magnitude, fillcolor=(128, 128, 128)):
152 | factor = magnitude * random.choice([-1, 1]) # random negative
153 | return image.transform(image.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), fillcolor=fillcolor)
154 |
155 |
156 | def translate_x_relative(image, magnitude, fillcolor=(128, 128, 128)):
157 | pixels = magnitude * image.size[0]
158 | pixels = pixels * random.choice([-1, 1]) # random negative
159 | return image.transform(image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), fillcolor=fillcolor)
160 |
161 |
162 | def translate_y_relative(image, magnitude, fillcolor=(128, 128, 128)):
163 | pixels = magnitude * image.size[0]
164 | pixels = pixels * random.choice([-1, 1]) # random negative
165 | return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), fillcolor=fillcolor)
166 |
167 |
168 | def translate_x_absolute(image, magnitude, fillcolor=(128, 128, 128)):
169 | magnitude = magnitude * random.choice([-1, 1]) # random negative
170 | return image.transform(image.size, Image.AFFINE, (1, 0, magnitude, 0, 1, 0), fillcolor=fillcolor)
171 |
172 |
173 | def translate_y_absolute(image, magnitude, fillcolor=(128, 128, 128)):
174 | magnitude = magnitude * random.choice([-1, 1]) # random negative
175 | return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude), fillcolor=fillcolor)
176 |
177 |
178 | def rotate(image, magnitude):
179 | rot = image.convert("RGBA").rotate(magnitude)
180 | return Image.composite(rot,
181 | Image.new('RGBA', rot.size, (128, ) * 4),
182 | rot).convert(image.mode)
183 |
184 |
185 | def auto_contrast(image, magnitude=None):
186 | return ImageOps.autocontrast(image)
187 |
188 |
189 | def invert(image, magnitude=None):
190 | return ImageOps.invert(image)
191 |
192 |
193 | def equalize(image, magnitude=None):
194 | return ImageOps.equalize(image)
195 |
196 |
197 | def solarize(image, magnitude):
198 | return ImageOps.solarize(image, magnitude)
199 |
200 |
201 | def posterize(image, magnitude):
202 | return ImageOps.posterize(image, magnitude)
203 |
204 |
205 | def contrast(image, magnitude):
206 | magnitude = magnitude * random.choice([-1, 1]) # random negative
207 | return ImageEnhance.Contrast(image).enhance(1 + magnitude)
208 |
209 |
210 | def color(image, magnitude):
211 | magnitude = magnitude * random.choice([-1, 1]) # random negative
212 | return ImageEnhance.Color(image).enhance(1 + magnitude)
213 |
214 |
215 | def brightness(image, magnitude):
216 | magnitude = magnitude * random.choice([-1, 1]) # random negative
217 | return ImageEnhance.Brightness(image).enhance(1 + magnitude)
218 |
219 |
220 | def sharpness(image, magnitude):
221 | magnitude = magnitude * random.choice([-1, 1]) # random negative
222 | return ImageEnhance.Sharpness(image).enhance(1 + magnitude)
223 |
224 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/main_single_gpu.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Swin training/validation using single GPU """
17 |
18 | import sys
19 | import os
20 | import time
21 | import logging
22 | import argparse
23 | import random
24 | import numpy as np
25 | import warnings
26 | warnings.filterwarnings('ignore')
27 |
28 | import paddle
29 | import paddle.nn as nn
30 | import paddle.nn.functional as F
31 | from datasets import get_dataloader
32 | from datasets import get_dataset
33 |
34 | from utils import AverageMeter
35 | from utils import WarmupCosineScheduler
36 | from utils import get_exclude_from_weight_decay_fn
37 | from config import get_config
38 | from config import update_config
39 | from mixup import Mixup
40 | from losses import LabelSmoothingCrossEntropyLoss
41 | from losses import SoftTargetCrossEntropyLoss
42 | from losses import DistillationLoss
43 | from swin_transformer import build_swin as build_model
44 |
45 |
46 |
47 | def get_arguments():
48 | """return argumeents, this will overwrite the config after loading yaml file"""
49 | parser = argparse.ArgumentParser('Swin')
50 | parser.add_argument('-cfg', type=str, default='/home/ubuntu13/lsz/code/S-T-V2/PaddleViT/image_classification/SwinTransformerV2/configs/swinv2_base_patch4_window7_224.yaml')
51 | parser.add_argument('-dataset', type=str, default='imagenet2012')
52 | parser.add_argument('-batch_size', type=int, default=48)
53 | parser.add_argument('-image_size', type=int, default=None)
54 | parser.add_argument('-data_path', type=str, default='/home/ubuntu13/lsz/dataset/ILSVRC')
55 | parser.add_argument('-ngpus', type=int, default=None)
56 | parser.add_argument('-pretrained', type=str, default=None)
57 | parser.add_argument('-resume', type=str, default=None)
58 | parser.add_argument('-last_epoch', type=int, default=None)
59 | parser.add_argument('-eval', action='store_true')
60 | parser.add_argument('-amp', action='store_true', default=True)
61 | arguments = parser.parse_args()
62 | return arguments
63 |
64 |
65 | def get_logger(filename, logger_name=None):
66 | """set logging file and format
67 | Args:
68 | filename: str, full path of the logger file to write
69 | logger_name: str, the logger name, e.g., 'master_logger', 'local_logger'
70 | Return:
71 | logger: python logger
72 | """
73 | log_format = "%(asctime)s %(message)s"
74 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
75 | format=log_format, datefmt="%m%d %I:%M:%S %p")
76 | # different name is needed when creating multiple logger in one process
77 | logger = logging.getLogger(logger_name)
78 | fh = logging.FileHandler(os.path.join(filename))
79 | fh.setFormatter(logging.Formatter(log_format))
80 | logger.addHandler(fh)
81 | return logger
82 |
83 |
84 | def train(dataloader,
85 | model,
86 | criterion,
87 | optimizer,
88 | epoch,
89 | total_epochs,
90 | total_batch,
91 | debug_steps=100,
92 | accum_iter=1,
93 | mixup_fn=None,
94 | amp=False,
95 | logger=None):
96 | """Training for one epoch
97 | Args:
98 | dataloader: paddle.io.DataLoader, dataloader instance
99 | model: nn.Layer, a ViT model
100 | criterion: nn.criterion
101 | epoch: int, current epoch
102 | total_epochs: int, total num of epochs
103 | total_batch: int, total num of batches for one epoch
104 | debug_steps: int, num of iters to log info, default: 100
105 | accum_iter: int, num of iters for accumulating gradients, default: 1
106 | mixup_fn: Mixup, mixup instance, default: None
107 | amp: bool, if True, use mix precision training, default: False
108 | logger: logger for logging, default: None
109 | Returns:
110 | train_loss_meter.avg: float, average loss on current process/gpu
111 | train_acc_meter.avg: float, average top1 accuracy on current process/gpu
112 | train_time: float, training time
113 | """
114 | model.train()
115 | train_loss_meter = AverageMeter()
116 | train_acc_meter = AverageMeter()
117 |
118 | if amp is True:
119 | scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
120 | time_st = time.time()
121 |
122 | for batch_id, data in enumerate(dataloader):
123 | image = data[0]
124 | label = data[1]
125 | label_orig = label.clone()
126 |
127 | if mixup_fn is not None:
128 | image, label = mixup_fn(image, label_orig)
129 |
130 | if amp is True: # mixed precision training
131 | with paddle.amp.auto_cast():
132 | output = model(image)
133 | loss = criterion(output, label)
134 | scaled = scaler.scale(loss)
135 | scaled.backward()
136 | if ((batch_id +1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
137 | scaler.minimize(optimizer, scaled)
138 | optimizer.clear_grad()
139 | else: # full precision training
140 | output = model(image)
141 | loss = criterion(output, label)
142 | #NOTE: division may be needed depending on the loss function
143 | # Here no division is needed:
144 | # default 'reduction' param in nn.CrossEntropyLoss is set to 'mean'
145 | #loss = loss / accum_iter
146 | loss.backward()
147 |
148 | if ((batch_id +1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
149 | optimizer.step()
150 | optimizer.clear_grad()
151 |
152 | pred = F.softmax(output)
153 | if mixup_fn:
154 | acc = paddle.metric.accuracy(pred, label_orig)
155 | else:
156 | acc = paddle.metric.accuracy(pred, label_orig.unsqueeze(1))
157 |
158 | batch_size = image.shape[0]
159 | train_loss_meter.update(loss.numpy()[0], batch_size)
160 | train_acc_meter.update(acc.numpy()[0], batch_size)
161 |
162 | if logger and batch_id % debug_steps == 0:
163 | logger.info(
164 | f"Epoch[{epoch:03d}/{total_epochs:03d}], " +
165 | f"Step[{batch_id:04d}/{total_batch:04d}], " +
166 | f"Avg Loss: {train_loss_meter.avg:.4f}, " +
167 | f"Avg Acc: {train_acc_meter.avg:.4f}")
168 |
169 | train_time = time.time() - time_st
170 | return train_loss_meter.avg, train_acc_meter.avg, train_time
171 |
172 |
173 | def validate(dataloader, model, criterion, total_batch, debug_steps=100, logger=None):
174 | """Validation for whole dataset
175 | Args:
176 | dataloader: paddle.io.DataLoader, dataloader instance
177 | model: nn.Layer, a ViT model
178 | criterion: nn.criterion
179 | total_batch: int, total num of batches for one epoch
180 | debug_steps: int, num of iters to log info, default: 100
181 | logger: logger for logging, default: None
182 | Returns:
183 | val_loss_meter.avg: float, average loss on current process/gpu
184 | val_acc1_meter.avg: float, average top1 accuracy on current process/gpu
185 | val_acc5_meter.avg: float, average top5 accuracy on current process/gpu
186 | val_time: float, validation time
187 | """
188 | model.eval()
189 | val_loss_meter = AverageMeter()
190 | val_acc1_meter = AverageMeter()
191 | val_acc5_meter = AverageMeter()
192 | time_st = time.time()
193 |
194 | with paddle.no_grad():
195 | for batch_id, data in enumerate(dataloader):
196 | image = data[0]
197 | label = data[1]
198 |
199 | output = model(image)
200 | loss = criterion(output, label)
201 |
202 | pred = F.softmax(output)
203 | acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1))
204 | acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5)
205 |
206 | batch_size = image.shape[0]
207 | val_loss_meter.update(loss.numpy()[0], batch_size)
208 | val_acc1_meter.update(acc1.numpy()[0], batch_size)
209 | val_acc5_meter.update(acc5.numpy()[0], batch_size)
210 |
211 | if logger and batch_id % debug_steps == 0:
212 | logger.info(
213 | f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
214 | f"Avg Loss: {val_loss_meter.avg:.4f}, " +
215 | f"Avg Acc@1: {val_acc1_meter.avg:.4f}, " +
216 | f"Avg Acc@5: {val_acc5_meter.avg:.4f}")
217 |
218 | val_time = time.time() - time_st
219 | return val_loss_meter.avg, val_acc1_meter.avg, val_acc5_meter.avg, val_time
220 |
221 |
222 | def main():
223 | # STEP 0: Preparation
224 | # config is updated by: (1) config.py, (2) yaml file, (3) arguments
225 | arguments = get_arguments()
226 | config = get_config()
227 | config = update_config(config, arguments)
228 | # set output folder
229 | if not config.EVAL:
230 | config.SAVE = '{}/train-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
231 | else:
232 | config.SAVE = '{}/eval-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
233 | if not os.path.exists(config.SAVE):
234 | os.makedirs(config.SAVE, exist_ok=True)
235 | last_epoch = config.TRAIN.LAST_EPOCH
236 | seed = config.SEED
237 | paddle.seed(seed)
238 | np.random.seed(seed)
239 | random.seed(seed)
240 | logger = get_logger(filename=os.path.join(config.SAVE, 'log.txt'))
241 | logger.info(f'\n{config}')
242 |
243 | # STEP 1: Create model
244 | model = build_model(config)
245 |
246 | # STEP 2: Create train and val dataloader
247 | dataset_train = get_dataset(config, mode='train')
248 | dataset_val = get_dataset(config, mode='val')
249 | dataloader_train = get_dataloader(config, dataset_train, 'train', False)
250 | dataloader_val = get_dataloader(config, dataset_val, 'val', False)
251 |
252 | # STEP 3: Define Mixup function
253 | mixup_fn = None
254 | if config.TRAIN.MIXUP_PROB > 0 or config.TRAIN.CUTMIX_ALPHA > 0 or config.TRAIN.CUTMIX_MINMAX is not None:
255 | mixup_fn = Mixup(mixup_alpha=config.TRAIN.MIXUP_ALPHA,
256 | cutmix_alpha=config.TRAIN.CUTMIX_ALPHA,
257 | cutmix_minmax=config.TRAIN.CUTMIX_MINMAX,
258 | prob=config.TRAIN.MIXUP_PROB,
259 | switch_prob=config.TRAIN.MIXUP_SWITCH_PROB,
260 | mode=config.TRAIN.MIXUP_MODE,
261 | label_smoothing=config.TRAIN.SMOOTHING)
262 |
263 | # STEP 4: Define criterion
264 | if config.TRAIN.MIXUP_PROB > 0.:
265 | criterion = SoftTargetCrossEntropyLoss()
266 | elif config.TRAIN.SMOOTHING:
267 | criterion = LabelSmoothingCrossEntropyLoss()
268 | else:
269 | criterion = nn.CrossEntropyLoss()
270 | # only use cross entropy for val
271 | criterion_val = nn.CrossEntropyLoss()
272 |
273 | # STEP 5: Define optimizer and lr_scheduler
274 | # set lr according to batch size and world size (hacked from official code)
275 | linear_scaled_lr = (config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE) / 512.0
276 | linear_scaled_warmup_start_lr = (config.TRAIN.WARMUP_START_LR * config.DATA.BATCH_SIZE) / 512.0
277 | linear_scaled_end_lr = (config.TRAIN.END_LR * config.DATA.BATCH_SIZE) / 512.0
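# e.g., with a hypothetical BASE_LR=5e-4 and BATCH_SIZE=48: 5e-4 * 48 / 512 = 4.6875e-5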
278 |
279 | if config.TRAIN.ACCUM_ITER > 1:
280 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUM_ITER
281 | linear_scaled_warmup_start_lr = linear_scaled_warmup_start_lr * config.TRAIN.ACCUM_ITER
282 | linear_scaled_end_lr = linear_scaled_end_lr * config.TRAIN.ACCUM_ITER
283 |
284 | config.TRAIN.BASE_LR = linear_scaled_lr
285 | config.TRAIN.WARMUP_START_LR = linear_scaled_warmup_start_lr
286 | config.TRAIN.END_LR = linear_scaled_end_lr
287 |
288 | scheduler = None
289 | if config.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
290 | scheduler = WarmupCosineScheduler(learning_rate=config.TRAIN.BASE_LR,
291 | warmup_start_lr=config.TRAIN.WARMUP_START_LR,
292 | start_lr=config.TRAIN.BASE_LR,
293 | end_lr=config.TRAIN.END_LR,
294 | warmup_epochs=config.TRAIN.WARMUP_EPOCHS,
295 | total_epochs=config.TRAIN.NUM_EPOCHS,
296 | last_epoch=config.TRAIN.LAST_EPOCH,
297 | )
298 | elif config.TRAIN.LR_SCHEDULER.NAME == "cosine":
299 | scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.TRAIN.BASE_LR,
300 | T_max=config.TRAIN.NUM_EPOCHS,
301 | last_epoch=last_epoch)
302 | elif config.TRAIN.LR_SCHEDULER.NAME == "multi-step":
303 | milestones = [int(v.strip()) for v in config.TRAIN.LR_SCHEDULER.MILESTONES.split(",")]
304 | scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=config.TRAIN.BASE_LR,
305 | milestones=milestones,
306 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
307 | last_epoch=last_epoch)
308 | else:
309 | logger.fatal(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
310 | raise NotImplementedError(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
311 |
312 | if config.TRAIN.OPTIMIZER.NAME == "SGD":
313 | if config.TRAIN.GRAD_CLIP:
314 | clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
315 | else:
316 | clip = None
317 | optimizer = paddle.optimizer.Momentum(
318 | parameters=model.parameters(),
319 | learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
320 | weight_decay=config.TRAIN.WEIGHT_DECAY,
321 | momentum=config.TRAIN.OPTIMIZER.MOMENTUM,
322 | grad_clip=clip)
323 | elif config.TRAIN.OPTIMIZER.NAME == "AdamW":
324 | if config.TRAIN.GRAD_CLIP:
325 | clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
326 | else:
327 | clip = None
328 | optimizer = paddle.optimizer.AdamW(
329 | parameters=model.parameters(),
330 | learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
331 | beta1=config.TRAIN.OPTIMIZER.BETAS[0],
332 | beta2=config.TRAIN.OPTIMIZER.BETAS[1],
333 | weight_decay=config.TRAIN.WEIGHT_DECAY,
334 | epsilon=config.TRAIN.OPTIMIZER.EPS,
335 | grad_clip=clip,
336 | apply_decay_param_fun=get_exclude_from_weight_decay_fn([
337 | 'absolute_pos_embed', 'relative_position_bias_table']),
338 | )
339 | else:
340 | logger.fatal(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
341 | raise NotImplementedError(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
342 |
343 | # STEP 6: Load pretrained model or load resume model and optimizer states
344 | if config.MODEL.PRETRAINED:
345 | if (config.MODEL.PRETRAINED).endswith('.pdparams'):
346 | raise ValueError(f'{config.MODEL.PRETRAINED} should not contain the .pdparams extension (it is appended automatically)')
347 | assert os.path.isfile(config.MODEL.PRETRAINED + '.pdparams')
348 | model_state = paddle.load(config.MODEL.PRETRAINED+'.pdparams')
349 | model.set_dict(model_state)
350 | logger.info(f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}")
351 |
352 | if config.MODEL.RESUME:
353 | assert os.path.isfile(config.MODEL.RESUME+'.pdparams')
354 | assert os.path.isfile(config.MODEL.RESUME+'.pdopt')
355 | model_state = paddle.load(config.MODEL.RESUME+'.pdparams')
356 | model.set_dict(model_state)
357 | opt_state = paddle.load(config.MODEL.RESUME+'.pdopt')
358 | optimizer.set_state_dict(opt_state)
359 | logger.info(
360 | f"----- Resume: Load model and optmizer from {config.MODEL.RESUME}")
361 |
362 | # STEP 7: Validation (eval mode)
363 | if config.EVAL:
364 | logger.info('----- Start Validating')
365 | val_loss, val_acc1, val_acc5, val_time = validate(
366 | dataloader=dataloader_val,
367 | model=model,
368 | criterion=criterion_val,
369 | total_batch=len(dataloader_val),
370 | debug_steps=config.REPORT_FREQ,
371 | logger=logger)
372 | logger.info(f"Validation Loss: {val_loss:.4f}, " +
373 | f"Validation Acc@1: {val_acc1:.4f}, " +
374 | f"Validation Acc@5: {val_acc5:.4f}, " +
375 | f"time: {val_time:.2f}")
376 | return
377 |
378 | # STEP 8: Start training and validation (train mode)
379 | logger.info(f"Start training from epoch {last_epoch+1}.")
380 | for epoch in range(last_epoch+1, config.TRAIN.NUM_EPOCHS+1):
381 | # train
382 | logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.6f}")
383 | train_loss, train_acc, train_time = train(dataloader=dataloader_train,
384 | model=model,
385 | criterion=criterion,
386 | optimizer=optimizer,
387 | epoch=epoch,
388 | total_epochs=config.TRAIN.NUM_EPOCHS,
389 | total_batch=len(dataloader_train),
390 | debug_steps=config.REPORT_FREQ,
391 | accum_iter=config.TRAIN.ACCUM_ITER,
392 | mixup_fn=mixup_fn,
393 | amp=config.AMP,
394 | logger=logger)
395 | scheduler.step()
396 | logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
397 | f"Train Loss: {train_loss:.4f}, " +
398 | f"Train Acc: {train_acc:.4f}, " +
399 | f"time: {train_time:.2f}")
400 | # validation
401 | if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
402 | logger.info(f'----- Validation after Epoch: {epoch}')
403 | val_loss, val_acc1, val_acc5, val_time = validate(
404 | dataloader=dataloader_val,
405 | model=model,
406 | criterion=criterion_val,
407 | total_batch=len(dataloader_val),
408 | debug_steps=config.REPORT_FREQ,
409 | logger=logger)
410 | logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
411 | f"Validation Loss: {val_loss:.4f}, " +
412 | f"Validation Acc@1: {val_acc1:.4f}, " +
413 | f"Validation Acc@5: {val_acc5:.4f}, " +
414 | f"time: {val_time:.2f}")
415 | # model save
416 | if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
417 | model_path = os.path.join(
418 | config.SAVE, f"{config.MODEL.TYPE}-Epoch-{epoch}-Loss-{train_loss}")
419 | paddle.save(model.state_dict(), model_path + '.pdparams')
420 | paddle.save(optimizer.state_dict(), model_path + '.pdopt')
421 | logger.info(f"----- Save model: {model_path}.pdparams")
422 | logger.info(f"----- Save optim: {model_path}.pdopt")
423 |
424 |
425 | if __name__ == "__main__":
426 | main()
427 |
--------------------------------------------------------------------------------
/main_multi_gpu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Swin training/validation using multiple GPU """
16 |
17 | import sys
18 | import os
19 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
20 | import time
21 | import logging
22 | import argparse
23 | import random
24 | import numpy as np
25 | import paddle
26 | import paddle.nn as nn
27 | import paddle.nn.functional as F
28 | import paddle.distributed as dist
29 | from datasets import get_dataloader
30 | from datasets import get_dataset
31 | from utils import AverageMeter
32 | from utils import WarmupCosineScheduler
33 | from utils import get_exclude_from_weight_decay_fn
34 | from config import get_config
35 | from config import update_config
36 | from mixup import Mixup
37 | from losses import LabelSmoothingCrossEntropyLoss
38 | from losses import SoftTargetCrossEntropyLoss
39 | from losses import DistillationLoss
40 | from swin_transformer import build_swin as build_model
41 |
42 |
43 | def get_arguments():
44 | """return argumeents, this will overwrite the config after loading yaml file"""
45 | parser = argparse.ArgumentParser('Swin')
46 | parser.add_argument('-cfg', type=str, default='/home/ubuntu13/lsz/code/S-T-V2/PaddleViT/image_classification/SwinTransformerV2/configs/swinv2_base_patch4_window7_224.yaml')
47 | parser.add_argument('-dataset', type=str, default='imagenet2012')
48 | parser.add_argument('-batch_size', type=int, default=100)
49 | parser.add_argument('-image_size', type=int, default=None)
50 | parser.add_argument('-data_path', type=str, default='/home/ubuntu13/lsz/dataset/ILSVRC')
51 | parser.add_argument('-ngpus', type=int, default=None)
52 | parser.add_argument('-pretrained', type=str, default=None)
53 | parser.add_argument('-resume', type=str, default=None)
54 | parser.add_argument('-last_epoch', type=int, default=None)
55 | parser.add_argument('-eval', action='store_true')
56 | parser.add_argument('-amp', action='store_true', default=True)  # NOTE: with default=True this flag cannot disable AMP
57 | arguments = parser.parse_args()
58 | return arguments
59 |
60 |
61 | def get_logger(filename, logger_name=None):
62 | """set logging file and format
63 | Args:
64 | filename: str, full path of the logger file to write
65 | logger_name: str, the logger name, e.g., 'master_logger', 'local_logger'
66 | Return:
67 | logger: python logger
68 | """
69 | log_format = "%(asctime)s %(message)s"
70 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
71 | format=log_format, datefmt="%m%d %I:%M:%S %p")
72 | # a different name is needed when creating multiple loggers in one process
73 | logger = logging.getLogger(logger_name)
74 | fh = logging.FileHandler(os.path.join(filename))
75 | fh.setFormatter(logging.Formatter(log_format))
76 | logger.addHandler(fh)
77 | return logger
78 |
79 |
80 | def train(dataloader,
81 | model,
82 | criterion,
83 | optimizer,
84 | epoch,
85 | total_epochs,
86 | total_batch,
87 | debug_steps=100,
88 | accum_iter=1,
89 | mixup_fn=None,
90 | amp=False,
91 | local_logger=None,
92 | master_logger=None):
93 | """Training for one epoch
94 | Args:
95 | dataloader: paddle.io.DataLoader, dataloader instance
96 | model: nn.Layer, a ViT model
97 | criterion: nn.Layer, loss function
98 | epoch: int, current epoch
99 | total_epochs: int, total num of epochs
100 | total_batch: int, total num of batches for one epoch
101 | debug_steps: int, num of iters to log info, default: 100
102 | accum_iter: int, num of iters for accumulating gradients, default: 1
103 | mixup_fn: Mixup, mixup instance, default: None
104 | amp: bool, if True, use mix precision training, default: False
105 | local_logger: logger for local process/gpu, default: None
106 | master_logger: logger for main process, default: None
107 | Returns:
108 | train_loss_meter.avg: float, average loss on current process/gpu
109 | train_acc_meter.avg: float, average top1 accuracy on current process/gpu
110 | master_train_loss_meter.avg: float, average loss on all processes/gpus
111 | master_train_acc_meter.avg: float, average top1 accuracy on all processes/gpus
112 | train_time: float, training time
113 | """
114 | model.train()
115 | train_loss_meter = AverageMeter()
116 | train_acc_meter = AverageMeter()
117 | master_train_loss_meter = AverageMeter()
118 | master_train_acc_meter = AverageMeter()
119 |
120 | if amp is True:
121 | scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
122 | time_st = time.time()
123 |
124 | for batch_id, data in enumerate(dataloader):
125 | image = data[0]
126 | label = data[1]
127 | label_orig = label.clone()
128 |
129 | if mixup_fn is not None:
130 | image, label = mixup_fn(image, label_orig)
131 |
132 | if amp is True: # mixed precision training
133 | with paddle.amp.auto_cast():
134 | output = model(image)
135 | loss = criterion(output, label)
136 | scaled = scaler.scale(loss)
137 | scaled.backward()
138 | if ((batch_id + 1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
139 | scaler.minimize(optimizer, scaled)
140 | optimizer.clear_grad()
141 | else: # full precision training
142 | output = model(image)
143 | loss = criterion(output, label)
144 | #NOTE: division may be needed depending on the loss function
145 | # Here no division is needed:
146 | # default 'reduction' param in nn.CrossEntropyLoss is set to 'mean'
147 | #loss = loss / accum_iter
148 | loss.backward()
149 |
150 | if ((batch_id + 1) % accum_iter == 0) or (batch_id + 1 == len(dataloader)):
151 | optimizer.step()
152 | optimizer.clear_grad()
153 |
154 | pred = F.softmax(output)
155 | if mixup_fn:
156 | acc = paddle.metric.accuracy(pred, label_orig)
157 | else:
158 | acc = paddle.metric.accuracy(pred, label_orig.unsqueeze(1))
159 |
160 | batch_size = paddle.to_tensor(image.shape[0])
161 |
162 | # sync from other gpus for overall loss and acc
163 | master_loss = loss.clone()
164 | master_acc = acc.clone()
165 | master_batch_size = batch_size.clone()
166 | dist.all_reduce(master_loss)
167 | dist.all_reduce(master_acc)
168 | dist.all_reduce(master_batch_size)
169 | master_loss = master_loss / dist.get_world_size()
170 | master_acc = master_acc / dist.get_world_size()
171 | master_train_loss_meter.update(master_loss.numpy()[0], master_batch_size.numpy()[0])
172 | master_train_acc_meter.update(master_acc.numpy()[0], master_batch_size.numpy()[0])
173 |
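# Worked example of the sync above (illustrative values): with world_size = 2
# and per-GPU losses 1.0 and 3.0, dist.all_reduce sums them to 4.0 on every
# rank; dividing by dist.get_world_size() gives the global mean 2.0, while the
# summed batch size is kept as-is and used as the meter weight.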
174 | train_loss_meter.update(loss.numpy()[0], batch_size.numpy()[0])
175 | train_acc_meter.update(acc.numpy()[0], batch_size.numpy()[0])
176 |
177 | if batch_id % debug_steps == 0:
178 | if local_logger:
179 | local_logger.info(
180 | f"Epoch[{epoch:03d}/{total_epochs:03d}], " +
181 | f"Step[{batch_id:04d}/{total_batch:04d}], " +
182 | f"Avg Loss: {train_loss_meter.avg:.4f}, " +
183 | f"Avg Acc: {train_acc_meter.avg:.4f}")
184 | if master_logger and dist.get_rank() == 0:
185 | master_logger.info(
186 | f"Epoch[{epoch:03d}/{total_epochs:03d}], " +
187 | f"Step[{batch_id:04d}/{total_batch:04d}], " +
188 | f"Avg Loss: {master_train_loss_meter.avg:.4f}, " +
189 | f"Avg Acc: {master_train_acc_meter.avg:.4f}")
190 |
191 | train_time = time.time() - time_st
192 | return (train_loss_meter.avg,
193 | train_acc_meter.avg,
194 | master_train_loss_meter.avg,
195 | master_train_acc_meter.avg,
196 | train_time)
197 |
198 |
199 | def validate(dataloader,
200 | model,
201 | criterion,
202 | total_batch,
203 | debug_steps=100,
204 | local_logger=None,
205 | master_logger=None):
206 | """Validation for whole dataset
207 | Args:
208 | dataloader: paddle.io.DataLoader, dataloader instance
209 | model: nn.Layer, a ViT model
210 | criterion: nn.Layer, loss function
211 | total_batch: int, total num of batches, for logging
212 | debug_steps: int, num of iters to log info, default: 100
213 | local_logger: logger for local process/gpu, default: None
214 | master_logger: logger for main process, default: None
215 | Returns:
216 | val_loss_meter.avg: float, average loss on current process/gpu
217 | val_acc1_meter.avg: float, average top1 accuracy on current process/gpu
218 | val_acc5_meter.avg: float, average top5 accuracy on current process/gpu
219 | master_val_loss_meter.avg: float, average loss on all processes/gpus
220 | master_val_acc1_meter.avg: float, average top1 accuracy on all processes/gpus
221 | master_val_acc5_meter.avg: float, average top5 accuracy on all processes/gpus
222 | val_time: float, validation time
223 | """
224 | model.eval()
225 | val_loss_meter = AverageMeter()
226 | val_acc1_meter = AverageMeter()
227 | val_acc5_meter = AverageMeter()
228 | master_val_loss_meter = AverageMeter()
229 | master_val_acc1_meter = AverageMeter()
230 | master_val_acc5_meter = AverageMeter()
231 | time_st = time.time()
232 |
233 | with paddle.no_grad():
234 | for batch_id, data in enumerate(dataloader):
235 | image = data[0]
236 | label = data[1]
237 |
238 | output = model(image)
239 | loss = criterion(output, label)
240 |
241 | pred = F.softmax(output)
242 | acc1 = paddle.metric.accuracy(pred, label.unsqueeze(1))
243 | acc5 = paddle.metric.accuracy(pred, label.unsqueeze(1), k=5)
244 |
245 | batch_size = paddle.to_tensor(image.shape[0])
246 |
247 | master_loss = loss.clone()
248 | master_acc1 = acc1.clone()
249 | master_acc5 = acc5.clone()
250 | master_batch_size = batch_size.clone()
251 |
252 | dist.all_reduce(master_loss)
253 | dist.all_reduce(master_acc1)
254 | dist.all_reduce(master_acc5)
255 | dist.all_reduce(master_batch_size)
256 | master_loss = master_loss / dist.get_world_size()
257 | master_acc1 = master_acc1 / dist.get_world_size()
258 | master_acc5 = master_acc5 / dist.get_world_size()
259 |
260 | master_val_loss_meter.update(master_loss.numpy()[0], master_batch_size.numpy()[0])
261 | master_val_acc1_meter.update(master_acc1.numpy()[0], master_batch_size.numpy()[0])
262 | master_val_acc5_meter.update(master_acc5.numpy()[0], master_batch_size.numpy()[0])
263 |
264 | val_loss_meter.update(loss.numpy()[0], batch_size.numpy()[0])
265 | val_acc1_meter.update(acc1.numpy()[0], batch_size.numpy()[0])
266 | val_acc5_meter.update(acc5.numpy()[0], batch_size.numpy()[0])
267 |
268 | if batch_id % debug_steps == 0:
269 | if local_logger:
270 | local_logger.info(
271 | f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
272 | f"Avg Loss: {val_loss_meter.avg:.4f}, " +
273 | f"Avg Acc@1: {val_acc1_meter.avg:.4f}, " +
274 | f"Avg Acc@5: {val_acc5_meter.avg:.4f}")
275 | if master_logger and dist.get_rank() == 0:
276 | master_logger.info(
277 | f"Val Step[{batch_id:04d}/{total_batch:04d}], " +
278 | f"Avg Loss: {master_val_loss_meter.avg:.4f}, " +
279 | f"Avg Acc@1: {master_val_acc1_meter.avg:.4f}, " +
280 | f"Avg Acc@5: {master_val_acc5_meter.avg:.4f}")
281 | val_time = time.time() - time_st
282 | return (val_loss_meter.avg,
283 | val_acc1_meter.avg,
284 | val_acc5_meter.avg,
285 | master_val_loss_meter.avg,
286 | master_val_acc1_meter.avg,
287 | master_val_acc5_meter.avg,
288 | val_time)
289 |
290 |
291 | def main_worker(*args):
292 | # STEP 0: Preparation
293 | config = args[0]
294 | dist.init_parallel_env()
295 | last_epoch = config.TRAIN.LAST_EPOCH
296 | world_size = dist.get_world_size()
297 | local_rank = dist.get_rank()
298 | seed = config.SEED + local_rank
299 | paddle.seed(seed)
300 | np.random.seed(seed)
301 | random.seed(seed)
302 | # logger for each process/gpu
303 | local_logger = get_logger(
304 | filename=os.path.join(config.SAVE, 'log_{}.txt'.format(local_rank)),
305 | logger_name='local_logger')
306 | # overall logger
307 | if local_rank == 0:
308 | master_logger = get_logger(
309 | filename=os.path.join(config.SAVE, 'log.txt'),
310 | logger_name='master_logger')
311 | master_logger.info(f'\n{config}')
312 | else:
313 | master_logger = None
314 | local_logger.info(f'----- world_size = {world_size}, local_rank = {local_rank}')
315 | if local_rank == 0:
316 | master_logger.info(f'----- world_size = {world_size}, local_rank = {local_rank}')
317 |
318 | # STEP 1: Create model
319 | model = build_model(config)
320 | model = paddle.DataParallel(model)
321 |
322 | # STEP 2: Create train and val dataloader
323 | dataset_train, dataset_val = args[1], args[2]
324 | dataloader_train = get_dataloader(config, dataset_train, 'train', True)
325 | dataloader_val = get_dataloader(config, dataset_val, 'test', True)
326 | total_batch_train = len(dataloader_train)
327 | total_batch_val = len(dataloader_val)
328 | local_logger.info(f'----- Total # of train batch (single gpu): {total_batch_train}')
329 | local_logger.info(f'----- Total # of val batch (single gpu): {total_batch_val}')
330 | if local_rank == 0:
331 | master_logger.info(f'----- Total # of train batch (single gpu): {total_batch_train}')
332 | master_logger.info(f'----- Total # of val batch (single gpu): {total_batch_val}')
333 |
334 | # STEP 3: Define Mixup function
335 | mixup_fn = None
336 | if config.TRAIN.MIXUP_PROB > 0 or config.TRAIN.CUTMIX_ALPHA > 0 or config.TRAIN.CUTMIX_MINMAX is not None:
337 | mixup_fn = Mixup(mixup_alpha=config.TRAIN.MIXUP_ALPHA,
338 | cutmix_alpha=config.TRAIN.CUTMIX_ALPHA,
339 | cutmix_minmax=config.TRAIN.CUTMIX_MINMAX,
340 | prob=config.TRAIN.MIXUP_PROB,
341 | switch_prob=config.TRAIN.MIXUP_SWITCH_PROB,
342 | mode=config.TRAIN.MIXUP_MODE,
343 | label_smoothing=config.TRAIN.SMOOTHING)
344 |
345 | # STEP 4: Define criterion
346 | if config.TRAIN.MIXUP_PROB > 0.:
347 | criterion = SoftTargetCrossEntropyLoss()
348 | elif config.TRAIN.SMOOTHING:
349 | criterion = LabelSmoothingCrossEntropyLoss()
350 | else:
351 | criterion = nn.CrossEntropyLoss()
352 | # only use cross entropy for val
353 | criterion_val = nn.CrossEntropyLoss()
354 |
355 | # STEP 5: Define optimizer and lr_scheduler
356 | # set lr according to batch size and world size (hacked from official code)
357 | linear_scaled_lr = (config.TRAIN.BASE_LR *
358 | config.DATA.BATCH_SIZE * dist.get_world_size()) / 512.0
359 | linear_scaled_warmup_start_lr = (config.TRAIN.WARMUP_START_LR *
360 | config.DATA.BATCH_SIZE * dist.get_world_size()) / 512.0
361 | linear_scaled_end_lr = (config.TRAIN.END_LR *
362 | config.DATA.BATCH_SIZE * dist.get_world_size()) / 512.0
363 |
364 | if config.TRAIN.ACCUM_ITER > 1:
365 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUM_ITER
366 | linear_scaled_warmup_start_lr = linear_scaled_warmup_start_lr * config.TRAIN.ACCUM_ITER
367 | linear_scaled_end_lr = linear_scaled_end_lr * config.TRAIN.ACCUM_ITER
368 |
369 | config.TRAIN.BASE_LR = linear_scaled_lr
370 | config.TRAIN.WARMUP_START_LR = linear_scaled_warmup_start_lr
371 | config.TRAIN.END_LR = linear_scaled_end_lr
372 |
373 | scheduler = None
374 | if config.TRAIN.LR_SCHEDULER.NAME == "warmupcosine":
375 | scheduler = WarmupCosineScheduler(learning_rate=config.TRAIN.BASE_LR,
376 | warmup_start_lr=config.TRAIN.WARMUP_START_LR,
377 | start_lr=config.TRAIN.BASE_LR,
378 | end_lr=config.TRAIN.END_LR,
379 | warmup_epochs=config.TRAIN.WARMUP_EPOCHS,
380 | total_epochs=config.TRAIN.NUM_EPOCHS,
381 | last_epoch=config.TRAIN.LAST_EPOCH,
382 | )
383 | elif config.TRAIN.LR_SCHEDULER.NAME == "cosine":
384 | scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=config.TRAIN.BASE_LR,
385 | T_max=config.TRAIN.NUM_EPOCHS,
386 | last_epoch=last_epoch)
387 | elif config.TRAIN.LR_SCHEDULER.NAME == "multi-step":
388 | milestones = [int(v.strip()) for v in config.TRAIN.LR_SCHEDULER.MILESTONES.split(",")]
389 | scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=config.TRAIN.BASE_LR,
390 | milestones=milestones,
391 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
392 | last_epoch=last_epoch)
393 | else:
394 | local_logger.fatal(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
395 | if local_rank == 0:
396 | master_logger.fatal(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
397 | raise NotImplementedError(f"Unsupported Scheduler: {config.TRAIN.LR_SCHEDULER}.")
398 |
399 | if config.TRAIN.OPTIMIZER.NAME == "SGD":
400 | if config.TRAIN.GRAD_CLIP:
401 | clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
402 | else:
403 | clip = None
404 | optimizer = paddle.optimizer.Momentum(
405 | parameters=model.parameters(),
406 | learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
407 | weight_decay=config.TRAIN.WEIGHT_DECAY,
408 | momentum=config.TRAIN.OPTIMIZER.MOMENTUM,
409 | grad_clip=clip)
410 | elif config.TRAIN.OPTIMIZER.NAME == "AdamW":
411 | if config.TRAIN.GRAD_CLIP:
412 | clip = paddle.nn.ClipGradByGlobalNorm(config.TRAIN.GRAD_CLIP)
413 | else:
414 | clip = None
415 | optimizer = paddle.optimizer.AdamW(
416 | parameters=model.parameters(),
417 | learning_rate=scheduler if scheduler is not None else config.TRAIN.BASE_LR,
418 | beta1=config.TRAIN.OPTIMIZER.BETAS[0],
419 | beta2=config.TRAIN.OPTIMIZER.BETAS[1],
420 | weight_decay=config.TRAIN.WEIGHT_DECAY,
421 | epsilon=config.TRAIN.OPTIMIZER.EPS,
422 | grad_clip=clip,
423 | apply_decay_param_fun=get_exclude_from_weight_decay_fn([
424 | 'absolute_pos_embed', 'relative_position_bias_table']),
425 | )
426 | else:
427 | local_logger.fatal(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
428 | if local_rank == 0:
429 | master_logger.fatal(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
430 | raise NotImplementedError(f"Unsupported Optimizer: {config.TRAIN.OPTIMIZER.NAME}.")
431 |
432 | # STEP 6: Load pretrained model or load resume model and optimizer states
433 | if config.MODEL.PRETRAINED:
434 | if (config.MODEL.PRETRAINED).endswith('.pdparams'):
435 | raise ValueError(f'{config.MODEL.PRETRAINED} should not contain the .pdparams extension (it is appended automatically)')
436 | assert os.path.isfile(config.MODEL.PRETRAINED + '.pdparams')
437 | model_state = paddle.load(config.MODEL.PRETRAINED+'.pdparams')
438 | model.set_dict(model_state)
439 | local_logger.info(f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}")
440 | if local_rank == 0:
441 | master_logger.info(
442 | f"----- Pretrained: Load model state from {config.MODEL.PRETRAINED}")
443 |
444 | if config.MODEL.RESUME:
445 | assert os.path.isfile(config.MODEL.RESUME+'.pdparams')
446 | assert os.path.isfile(config.MODEL.RESUME+'.pdopt')
447 | model_state = paddle.load(config.MODEL.RESUME+'.pdparams')
448 | model.set_dict(model_state)
449 | opt_state = paddle.load(config.MODEL.RESUME+'.pdopt')
450 | optimizer.set_state_dict(opt_state)
451 | local_logger.info(
452 | f"----- Resume Training: Load model and optmizer from {config.MODEL.RESUME}")
453 | if local_rank == 0:
454 | master_logger.info(
455 | f"----- Resume Training: Load model and optmizer from {config.MODEL.RESUME}")
456 |
457 | # STEP 7: Validation (eval mode)
458 | if config.EVAL:
459 | local_logger.info('----- Start Validating')
460 | if local_rank == 0:
461 | master_logger.info('----- Start Validating')
462 | val_loss, val_acc1, val_acc5, avg_loss, avg_acc1, avg_acc5, val_time = validate(
463 | dataloader=dataloader_val,
464 | model=model,
465 | criterion=criterion_val,
466 | total_batch=total_batch_val,
467 | debug_steps=config.REPORT_FREQ,
468 | local_logger=local_logger,
469 | master_logger=master_logger)
470 | local_logger.info(f"Validation Loss: {val_loss:.4f}, " +
471 | f"Validation Acc@1: {val_acc1:.4f}, " +
472 | f"Validation Acc@5: {val_acc5:.4f}, " +
473 | f"time: {val_time:.2f}")
474 | if local_rank == 0:
475 | master_logger.info(f"Validation Loss: {avg_loss:.4f}, " +
476 | f"Validation Acc@1: {avg_acc1:.4f}, " +
477 | f"Validation Acc@5: {avg_acc5:.4f}, " +
478 | f"time: {val_time:.2f}")
479 | return
480 |
481 | # STEP 8: Start training and validation (train mode)
482 | local_logger.info(f"Start training from epoch {last_epoch+1}.")
483 | if local_rank == 0:
484 | master_logger.info(f"Start training from epoch {last_epoch+1}.")
485 | for epoch in range(last_epoch+1, config.TRAIN.NUM_EPOCHS+1):
486 | # train
487 | local_logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.6f}")
488 | if local_rank == 0:
489 | master_logger.info(f"Now training epoch {epoch}. LR={optimizer.get_lr():.6f}")
490 | train_loss, train_acc, avg_loss, avg_acc, train_time = train(
491 | dataloader=dataloader_train,
492 | model=model,
493 | criterion=criterion,
494 | optimizer=optimizer,
495 | epoch=epoch,
496 | total_epochs=config.TRAIN.NUM_EPOCHS,
497 | total_batch=total_batch_train,
498 | debug_steps=config.REPORT_FREQ,
499 | accum_iter=config.TRAIN.ACCUM_ITER,
500 | mixup_fn=mixup_fn,
501 | amp=config.AMP,
502 | local_logger=local_logger,
503 | master_logger=master_logger)
504 |
505 | scheduler.step()
506 |
507 | local_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
508 | f"Train Loss: {train_loss:.4f}, " +
509 | f"Train Acc: {train_acc:.4f}, " +
510 | f"time: {train_time:.2f}")
511 | if local_rank == 0:
512 | master_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
513 | f"Train Loss: {avg_loss:.4f}, " +
514 | f"Train Acc: {avg_acc:.4f}, " +
515 | f"time: {train_time:.2f}")
516 |
517 | # validation
518 | if epoch % config.VALIDATE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
519 | local_logger.info(f'----- Validation after Epoch: {epoch}')
520 | if local_rank == 0:
521 | master_logger.info(f'----- Validation after Epoch: {epoch}')
522 | val_loss, val_acc1, val_acc5, avg_loss, avg_acc1, avg_acc5, val_time = validate(
523 | dataloader=dataloader_val,
524 | model=model,
525 | criterion=criterion_val,
526 | total_batch=total_batch_val,
527 | debug_steps=config.REPORT_FREQ,
528 | local_logger=local_logger,
529 | master_logger=master_logger)
530 | local_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
531 | f"Validation Loss: {val_loss:.4f}, " +
532 | f"Validation Acc@1: {val_acc1:.4f}, " +
533 | f"Validation Acc@5: {val_acc5:.4f}, " +
534 | f"time: {val_time:.2f}")
535 | if local_rank == 0:
536 | master_logger.info(f"----- Epoch[{epoch:03d}/{config.TRAIN.NUM_EPOCHS:03d}], " +
537 | f"Validation Loss: {avg_loss:.4f}, " +
538 | f"Validation Acc@1: {avg_acc1:.4f}, " +
539 | f"Validation Acc@5: {avg_acc5:.4f}, " +
540 | f"time: {val_time:.2f}")
541 | # model save
542 | if local_rank == 0:
543 | if epoch % config.SAVE_FREQ == 0 or epoch == config.TRAIN.NUM_EPOCHS:
544 | model_path = os.path.join(
545 | config.SAVE, f"{config.MODEL.TYPE}-Epoch-{epoch}-Loss-{train_loss}")
546 | paddle.save(model.state_dict(), model_path + '.pdparams')
547 | paddle.save(optimizer.state_dict(), model_path + '.pdopt')
548 | master_logger.info(f"----- Save model: {model_path}.pdparams")
549 | master_logger.info(f"----- Save optim: {model_path}.pdopt")
550 |
551 |
552 | def main():
553 | # config is updated by: (1) config.py, (2) yaml file, (3) arguments
554 | arguments = get_arguments()
555 | config = get_config()
556 | config = update_config(config, arguments)
557 |
558 | # set output folder
559 | if not config.EVAL:
560 | config.SAVE = '{}/train-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
561 | else:
562 | config.SAVE = '{}/eval-{}'.format(config.SAVE, time.strftime('%Y%m%d-%H-%M-%S'))
563 |
564 | if not os.path.exists(config.SAVE):
565 | os.makedirs(config.SAVE, exist_ok=True)
566 |
567 | # get dataset and start DDP
568 | dataset_train = get_dataset(config, mode='train')
569 | dataset_val = get_dataset(config, mode='val')
570 | config.NGPUS = len(paddle.static.cuda_places()) if config.NGPUS == -1 else config.NGPUS
571 | dist.spawn(main_worker, args=(config, dataset_train, dataset_val, ), nprocs=config.NGPUS)
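# e.g. with config.NGPUS = 4 this spawns 4 worker processes, each executing
# main_worker(config, dataset_train, dataset_val) on its own GPU; ranks and
# the process group are set up inside main_worker via dist.init_parallel_env().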
572 |
573 |
574 | if __name__ == "__main__":
575 | main()
576 |
--------------------------------------------------------------------------------
/swin_transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Implement Transformer Class for Swin Transformer V2
17 | """
18 |
19 | import paddle
20 | import paddle.nn as nn
21 | from droppath import DropPath
22 |
23 |
24 |
25 |
26 | class Identity(nn.Layer):
27 | """ Identity layer
28 |
29 | The output of this layer is the input without any change.
30 | Use this layer to avoid if condition in some forward methods
31 |
32 | """
33 | def __init__(self):
34 | super(Identity, self).__init__()
35 | def forward(self, x):
36 | return x
37 |
38 |
39 | class PatchEmbedding(nn.Layer):
40 | """Patch Embeddings
41 |
42 | Apply patch embeddings on input images. Embeddings is implemented using a Conv2D op.
43 |
44 | Attributes:
45 | image_size: int, input image size, default: 224
46 | patch_size: int, size of patch, default: 4
47 | in_channels: int, input image channels, default: 3
48 | embed_dim: int, embedding dimension, default: 96
49 | """
50 |
51 | def __init__(self, image_size=224, patch_size=4, in_channels=3, embed_dim=96):
52 | super().__init__()
53 | image_size = (image_size, image_size) # TODO: add to_2tuple
54 | patch_size = (patch_size, patch_size)
55 | patches_resolution = [image_size[0]//patch_size[0], image_size[1]//patch_size[1]]
56 | self.image_size = image_size
57 | self.patch_size = patch_size
58 | self.patches_resolution = patches_resolution
59 | self.num_patches = patches_resolution[0] * patches_resolution[1]
60 | self.in_channels = in_channels
61 | self.embed_dim = embed_dim
62 | self.patch_embed = nn.Conv2D(in_channels=in_channels,
63 | out_channels=embed_dim,
64 | kernel_size=patch_size,
65 | stride=patch_size)
66 |
67 | w_attr, b_attr = self._init_weights_layernorm()
68 | self.norm = nn.LayerNorm(embed_dim,
69 | weight_attr=w_attr,
70 | bias_attr=b_attr)
71 |
72 | def _init_weights_layernorm(self):
73 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
74 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
75 | return weight_attr, bias_attr
76 |
77 | def forward(self, x):
78 | x = self.patch_embed(x) # [batch, embed_dim, h, w] h,w = patch_resolution
79 | x = x.flatten(start_axis=2, stop_axis=-1) # [batch, embed_dim, h*w] h*w = num_patches
80 | x = x.transpose([0, 2, 1]) # [batch, h*w, embed_dim]
81 | x = self.norm(x) # [batch, num_patches, embed_dim]
82 | return x
83 |
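# Shape walk-through for the default 224x224 / patch_size=4 setting (the
# values follow directly from the code above): Conv2D maps [B, 3, 224, 224]
# to [B, 96, 56, 56]; flatten + transpose give [B, 56*56=3136, 96], which is
# what LayerNorm sees. A minimal sanity check (kept commented so the module
# does nothing at import time):
#
#   emb = PatchEmbedding(image_size=224, patch_size=4, in_channels=3, embed_dim=96)
#   out = emb(paddle.randn([1, 3, 224, 224]))
#   assert out.shape == [1, 3136, 96]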
84 |
85 | class PatchMerging(nn.Layer):
86 | """ Patch Merging class
87 |
88 | Merge multiple patches into one patch and keep the output dim.
89 | Specifically, merge adjacent 2x2 patches (each of dim C) into 1 patch.
90 | The concatenated dim 4*C is reduced to 2*C by a linear layer.
91 |
92 | Attributes:
93 | input_resolution: tuple of ints, the size of input
94 | dim: dimension of single patch
95 | reduction: nn.Linear which maps 4C to 2C dim
96 | norm: nn.LayerNorm, applied after linear layer.
97 | """
98 |
99 | def __init__(self, input_resolution, dim):
100 | super(PatchMerging, self).__init__()
101 | self.input_resolution = input_resolution
102 | self.dim = dim
103 | w_attr_1, b_attr_1 = self._init_weights()
104 | self.reduction = nn.Linear(4 * dim,
105 | 2 * dim,
106 | weight_attr=w_attr_1,
107 | bias_attr=False)
108 |
109 | w_attr_2, b_attr_2 = self._init_weights_layernorm()
110 | self.norm = nn.LayerNorm(4*dim,
111 | weight_attr=w_attr_2,
112 | bias_attr=b_attr_2)
113 |
114 | def _init_weights_layernorm(self):
115 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
116 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
117 | return weight_attr, bias_attr
118 |
119 | def _init_weights(self):
120 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
121 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
122 | return weight_attr, bias_attr
123 |
124 | def forward(self, x):
125 | h, w = self.input_resolution
126 | b, _, c = x.shape
127 | x = x.reshape([b, h, w, c])
128 |
129 | x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
130 | x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
131 | x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
132 | x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
133 | x = paddle.concat([x0, x1, x2, x3], -1) #[B, H/2, W/2, 4*C]
134 | x = x.reshape([b, -1, 4*c]) # [B, H/2*W/2, 4*C]
135 |
136 | x = self.norm(x)
137 | x = self.reduction(x)
138 |
139 | return x
140 |
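# Worked shapes for a 56x56 grid with C=96 (follows from the code above):
# input [B, 3136, 96] is viewed as [B, 56, 56, 96]; the four strided slices
# are each [B, 28, 28, 96]; concatenation and reshape give [B, 784, 384],
# and the reduction Linear maps 4*C=384 -> 2*C=192.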
141 |
142 | class Mlp(nn.Layer):
143 | """ MLP module
144 |
145 | Impl using nn.Linear and activation is GELU, dropout is applied.
146 | Ops: fc -> act -> dropout -> fc -> dropout
147 |
148 | Attributes:
149 | fc1: nn.Linear
150 | fc2: nn.Linear
151 | act: GELU
152 | dropout: nn.Dropout applied after fc1
153 | (the same instance is reused after fc2)
154 | """
155 |
156 | def __init__(self, in_features, hidden_features, dropout):
157 | super(Mlp, self).__init__()
158 | w_attr_1, b_attr_1 = self._init_weights()
159 | self.fc1 = nn.Linear(in_features,
160 | hidden_features,
161 | weight_attr=w_attr_1,
162 | bias_attr=b_attr_1)
163 |
164 | w_attr_2, b_attr_2 = self._init_weights()
165 | self.fc2 = nn.Linear(hidden_features,
166 | in_features,
167 | weight_attr=w_attr_2,
168 | bias_attr=b_attr_2)
169 | self.act = nn.GELU()
170 | self.dropout = nn.Dropout(dropout)
171 |
172 | def _init_weights(self):
173 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
174 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
175 | return weight_attr, bias_attr
176 |
177 | def forward(self, x):
178 | x = self.fc1(x)
179 | x = self.act(x)
180 | x = self.dropout(x)
181 | x = self.fc2(x)
182 | x = self.dropout(x)
183 | return x
184 |
185 | class Mlp_Relu(nn.Layer):
186 | """ MLP module
187 |
188 | Impl using nn.Linear and activation is ReLU, dropout is applied.
189 | Ops: fc -> act -> dropout -> fc -> dropout
190 |
191 | Attributes:
192 | fc1: nn.Linear
193 | fc2: nn.Linear
194 | act: ReLU
195 | dropout: nn.Dropout applied after fc1
196 | (the same instance is reused after fc2)
197 | """
198 |
199 | def __init__(self, in_features, hidden_features, out_features, dropout):
200 | super(Mlp_Relu, self).__init__()
201 | w_attr_1, b_attr_1 = self._init_weights()
202 | self.fc1 = nn.Linear(in_features,
203 | hidden_features,
204 | weight_attr=w_attr_1,
205 | bias_attr=b_attr_1)
206 |
207 | w_attr_2, b_attr_2 = self._init_weights()
208 | self.fc2 = nn.Linear(hidden_features,
209 | out_features,
210 | weight_attr=w_attr_2,
211 | bias_attr=b_attr_2)
212 | self.act = nn.ReLU()
213 | self.dropout = nn.Dropout(dropout)
214 |
215 | def _init_weights(self):
216 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
217 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
218 | return weight_attr, bias_attr
219 |
220 | def forward(self, x):
221 | x = self.fc1(x)
222 | x = self.act(x)
223 | x = self.dropout(x)
224 | x = self.fc2(x)
225 | x = self.dropout(x)
226 | return x
227 |
228 |
229 | class WindowAttention(nn.Layer):
230 | """Window based multihead attention, with relative position bias.
231 |
232 | Both shifted window and non-shifted window are supported.
233 |
234 | Attributes:
235 | dim: int, input dimension (channels)
236 | window_size: int, height and width of the window
237 | num_heads: int, number of attention heads
238 | qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True
239 | qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
240 | attention_dropout: float, dropout of attention
241 | dropout: float, dropout for output
242 | """
243 |
244 | def __init__(self,
245 | dim,
246 | window_size,
247 | num_heads,
248 | qkv_bias=True,
249 | qk_scale=None,
250 | attention_dropout=0.,
251 | dropout=0.):
252 | super(WindowAttention, self).__init__()
253 | self.window_size = window_size
254 | self.num_heads = num_heads
255 | self.dim = dim
256 | self.dim_head = dim // num_heads
257 | self.scale = qk_scale or self.dim_head ** -0.5
258 |
259 | self.relative_position_bias_table = paddle.create_parameter(
260 | shape=[(2 * window_size[0] -1) * (2 * window_size[1] - 1), num_heads],
261 | dtype='float32',
262 | default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
263 |
264 | # relative position index for each token inside window
265 | coords_h = paddle.arange(self.window_size[0])
266 | coords_w = paddle.arange(self.window_size[1])
267 | coords = paddle.stack(paddle.meshgrid([coords_h, coords_w])) # [2, window_h, window_w]
268 | coords_flatten = paddle.flatten(coords, 1) # [2, window_h * window_w]
269 | # 2, window_h * window_w, window_h * window_w
270 | relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
271 | # window_h*window_w, window_h*window_w, 2
272 | relative_coords = relative_coords.transpose([1, 2, 0])
273 |
274 | ## Swin-T v1
275 | # relative_coords[:, :, 0] += self.window_size[0] - 1
276 | # relative_coords[:, :, 1] += self.window_size[1] - 1
277 | # relative_coords[:, :, 0] *= 2* self.window_size[1] - 1
278 | # relative_position_index = relative_coords.sum(-1) # [window_size * window_size, window_size*window_size]
279 | # self.register_buffer("relative_position_index", relative_position_index)
280 |
281 | ## Swin-T v2, log-spaced coordinates, Eq.(4)
282 | log_relative_position_index = paddle.multiply(relative_coords.cast(dtype='float32').sign(),
283 | paddle.log((relative_coords.cast(dtype='float32').abs()+1)))
284 | self.register_buffer("log_relative_position_index", log_relative_position_index)
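# e.g. a raw relative offset of +3 maps to sign(3) * log(1 + |3|) ~= 1.386
# and -3 to ~= -1.386, while +10 maps to ~= 2.398: large offsets are
# compressed, which is what lets the learned bias transfer across window
# sizes in Swin-T v2.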
285 | ## Swin-T v2, small meta network, Eq.(3)
286 | self.cpb = Mlp_Relu(in_features=2, # delta x, delta y
287 | hidden_features=512, # TODO: hidden dims
288 | out_features=self.num_heads,
289 | dropout=dropout)
290 |
291 | w_attr_1, b_attr_1 = self._init_weights()
292 | self.qkv = nn.Linear(dim,
293 | dim * 3,
294 | weight_attr=w_attr_1,
295 | bias_attr=b_attr_1 if qkv_bias else False)
296 |
297 | self.attn_dropout = nn.Dropout(attention_dropout)
298 |
299 | w_attr_2, b_attr_2 = self._init_weights()
300 | self.proj = nn.Linear(dim,
301 | dim,
302 | weight_attr=w_attr_2,
303 | bias_attr=b_attr_2)
304 | self.proj_dropout = nn.Dropout(dropout)
305 | self.softmax = nn.Softmax(axis=-1)
306 |
307 | # Swin-T v2, Scaled cosine attention
308 | self.tau = paddle.create_parameter(
309 | shape = [num_heads, window_size[0]*window_size[1], window_size[0]*window_size[1]],
310 | dtype='float32',
311 | default_initializer=paddle.nn.initializer.Constant(1))
312 |
313 | def _init_weights(self):
314 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
315 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
316 | return weight_attr, bias_attr
317 |
318 | def transpose_multihead(self, x):
319 | new_shape = x.shape[:-1] + [self.num_heads, self.dim_head]
320 | x = x.reshape(new_shape)
321 | x = x.transpose([0, 2, 1, 3])
322 | return x
323 |
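# e.g. for the first stage of the base config (embed_dim=128, 4 heads, 7x7
# windows; illustrative numbers that follow from dim_head = dim // num_heads):
# each qkv chunk has shape [B*nW, 49, 128], and transpose_multihead reshapes
# it to [B*nW, 4, 49, 32] so attention runs independently per head.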
324 | def get_relative_pos_bias_from_pos_index(self):
325 | # NOTE: Swin-T v1 path, unused in v2; it requires the commented-out relative_position_index buffer above
326 | # relative_position_bias_table is a ParamBase object, see https://github.com/PaddlePaddle/Paddle/blob/067f558c59b34dd6d8626aad73e9943cf7f5960f/python/paddle/fluid/framework.py#L5727
327 | table = self.relative_position_bias_table # N x num_heads
328 | # index is a tensor
329 | index = self.relative_position_index.reshape([-1]) # window_h*window_w * window_h*window_w
330 | # NOTE: paddle does NOT support indexing Tensor by a Tensor
331 | relative_position_bias = paddle.index_select(x=table, index=index)
332 | return relative_position_bias
333 |
334 | def get_continuous_relative_position_bias(self):
335 | # The continuous position bias approach adopts a small meta network on the relative coordinates
336 | continuous_relative_position_bias = self.cpb(self.log_relative_position_index)
337 | return continuous_relative_position_bias
338 |
339 | def forward(self, x, mask=None):
340 | qkv = self.qkv(x).chunk(3, axis=-1) # {list:3}
341 | q, k, v = map(self.transpose_multihead, qkv) # [bs*num_window=1*64,4,49,32] -> [bs*num_window=1*16,8,49,32]-> [bs*num_window=1*4,16,49,32]->[bs*num_window=1*1,32,49,32]
342 |
343 | # Swin-T v2, Scaled cosine attention
344 | qk = paddle.matmul(q, k, transpose_y=True) # [bs*num_window=1*64,num_heads=4,49,49] -> [bs*num_window=1*16,num_heads=8,49,49] -> [bs*num_window=1*4,num_heads=16,49,49] -> [bs*num_window=1*1,num_heads=32,49,49]
345 | q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)
346 | k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)
347 | attn = qk/paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
348 | attn = attn/paddle.clip(self.tau.unsqueeze(0), min=0.01)
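# i.e. attn[b, h, i, j] = cos(q_i, k_j) / tau[h, i, j]: qk is the raw dot
# product, q2/k2 hold the L2 norms of q and k, and both the norm product
# (min 1e-6) and the learned temperature tau (min 0.01) are clipped to keep
# the divisions numerically stable.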
349 |
350 | ## Swin-T v1
351 | # relative_position_bias = self.get_relative_pos_bias_from_pos_index() #[2401,num_heads=4]->[2401,8]->[2401,16]->[2401,32]
352 | ## Swin-T v2
353 | relative_position_bias = self.get_continuous_relative_position_bias()
354 | relative_position_bias = relative_position_bias.reshape(
355 | [self.window_size[0] * self.window_size[1],
356 | self.window_size[0] * self.window_size[1],
357 | -1]) # [49,49,num_heads=4]->[49,49,8]->[49,49,16]->[49,49,32]
358 |
359 | # nH, window_h*window_w, window_h*window_w
360 | relative_position_bias = relative_position_bias.transpose([2, 0, 1]) # [bs*num_window=1*64,49,49]->[1*16,49,49]->[1*4,49,49]->[1*1,49,49]
361 | attn = attn + relative_position_bias.unsqueeze(0)
362 |
363 | if mask is not None:
364 | nW = mask.shape[0]
365 | attn = attn.reshape(
366 | [x.shape[0] // nW, nW, self.num_heads, x.shape[1], x.shape[1]])
367 | attn += mask.unsqueeze(1).unsqueeze(0)
368 | attn = attn.reshape([-1, self.num_heads, x.shape[1], x.shape[1]])
369 | attn = self.softmax(attn)
370 | else:
371 | attn = self.softmax(attn)
372 |
373 | attn = self.attn_dropout(attn) # [bs*num_window=1*64,4,49,49]->[1*16,8,49,49]->[1*4,16,49,49]->[1*1,32,49,49]
374 |
375 | z = paddle.matmul(attn, v) # [bs*num_window=1*64,4,49,32]->[1*16,8,49,32]->[1*4,16,49,32]->[1*1,32,49,32]
376 | z = z.transpose([0, 2, 1, 3])
377 | new_shape = z.shape[:-2] + [self.dim]
378 | z = z.reshape(new_shape)
379 | z = self.proj(z)
380 | z = self.proj_dropout(z) # [512,49,96]->[128,49,192]->[32,49,384]->[8,49,768]
381 |
382 | return z
383 |
384 |
385 | def windows_partition(x, window_size):
386 | """ partite windows into window_size x window_size
387 | Args:
388 | x: Tensor, shape=[b, h, w, c]
389 | window_size: int, window size
390 | Returns:
391 | x: Tensor, shape=[num_windows*b, window_size, window_size, c]
392 | """
393 |
394 | B, H, W, C = x.shape
395 | x = x.reshape([B, H//window_size, window_size, W//window_size, window_size, C]) # [bs,num_window,window_size,num_window,window_size,C]
396 | x = x.transpose([0, 1, 3, 2, 4, 5]) # [bs,num_window,num_window,window_size,window_size,C]
397 | x = x.reshape([-1, window_size, window_size, C]) #(bs*num_windows,window_size, window_size, C)
398 |
399 | return x
400 |
401 |
402 | def windows_reverse(windows, window_size, H, W):
403 | """ Window reverse
404 | Args:
405 | windows: (n_windows * B, window_size, window_size, C)
406 | window_size: (int) window size
407 | H: (int) height of image
408 | W: (int) width of image
409 |
410 | Returns:
411 | x: (B, H, W, C)
412 | """
413 |
414 | B = int(windows.shape[0] / (H * W / window_size / window_size))
415 | x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1]) # [bs,num_window,num_window,window_size,window_size,C]
416 | x = x.transpose([0, 1, 3, 2, 4, 5]) # [bs,num_window,window_size,num_window,window_size,C]
417 | x = x.reshape([B, H, W, -1]) #(bs,num_windows*window_size, num_windows*window_size, C)
418 | return x
419 |
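# Round-trip example (illustrative): with B=1, H=W=56 and window_size=7,
# windows_partition turns [1, 56, 56, C] into 8*8=64 windows of shape
# [64, 7, 7, C]; windows_reverse([64, 7, 7, C], 7, 56, 56) restores
# [1, 56, 56, C], so the two functions are exact inverses.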
420 |
421 | class SwinTransformerBlock(nn.Layer):
422 | """Swin transformer block
423 |
424 | Contains window multi head self attention, droppath, mlp, norm and residual.
425 |
426 | Attributes:
427 | dim: int, input dimension (channels)
428 | input_resolution: tuple, input resolution
429 | num_heads: int, number of attention heads
430 | window_size: int, window size, default: 7
431 | shift_size: int, shift size for SW-MSA, default: 0
432 | mlp_ratio: float, ratio of mlp hidden dim and input embedding dim, default: 4.
433 | qkv_bias: bool, if True, enable learnable bias to q,k,v, default: True
434 | qk_scale: float, override default qk scale head_dim**-0.5 if set, default: None
435 | dropout: float, dropout for output, default: 0.
436 | attention_dropout: float, dropout of attention, default: 0.
437 | droppath: float, drop path rate, default: 0.
438 | """
439 |
440 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
441 | mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0., extra_norm=False,
442 | attention_dropout=0., droppath=0.):
443 | super(SwinTransformerBlock, self).__init__()
444 | self.dim = dim
445 | self.extra_norm = extra_norm # Swin-T v2, introduce a LN unit on the main branch every 6 layers
446 | self.input_resolution = input_resolution
447 | self.num_heads = num_heads
448 | self.window_size = window_size
449 | self.shift_size = shift_size
450 | self.mlp_ratio = mlp_ratio
451 | if min(self.input_resolution) <= self.window_size:
452 | self.shift_size = 0
453 | self.window_size = min(self.input_resolution)
454 |
455 | w_attr_1, b_attr_1 = self._init_weights_layernorm()
456 | self.norm1 = nn.LayerNorm(dim,
457 | weight_attr=w_attr_1,
458 | bias_attr=b_attr_1)
459 |
460 | self.attn = WindowAttention(dim,
461 | window_size=(self.window_size, self.window_size),
462 | num_heads=num_heads,
463 | qkv_bias=qkv_bias,
464 | qk_scale=qk_scale,
465 | attention_dropout=attention_dropout,
466 | dropout=dropout)
467 | self.drop_path = DropPath(droppath) if droppath > 0. else None
468 |
469 | w_attr_2, b_attr_2 = self._init_weights_layernorm()
470 | self.norm2 = nn.LayerNorm(dim,
471 | weight_attr=w_attr_2,
472 | bias_attr=b_attr_2)
473 |
474 | self.mlp = Mlp(in_features=dim,
475 | hidden_features=int(dim*mlp_ratio),
476 | dropout=dropout)
477 | if extra_norm:
478 | # Swin-T v2, introduce a LN unit on the main branch every 6 layers
479 | w_attr_3, b_attr_3 = self._init_weights_layernorm()
480 | self.norm3 = nn.LayerNorm(dim,
481 | weight_attr=w_attr_3,
482 | bias_attr=b_attr_3)
483 |
484 | if self.shift_size > 0:
485 | H, W = self.input_resolution
486 | img_mask = paddle.zeros((1, H, W, 1))
487 | h_slices = (slice(0, -self.window_size),
488 | slice(-self.window_size, -self.shift_size),
489 | slice(-self.shift_size, None))
490 | w_slices = (slice(0, -self.window_size),
491 | slice(-self.window_size, -self.shift_size),
492 | slice(-self.shift_size, None))
493 | cnt = 0
494 | for h in h_slices:
495 | for w in w_slices:
496 | img_mask[:, h, w, :] = cnt
497 | cnt += 1
498 |
499 | mask_windows = windows_partition(img_mask, self.window_size)
500 | mask_windows = mask_windows.reshape((-1, self.window_size * self.window_size))
501 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
502 | attn_mask = paddle.where(attn_mask != 0,
503 | paddle.ones_like(attn_mask) * float(-100.0),
504 | attn_mask)
505 | attn_mask = paddle.where(attn_mask == 0,
506 | paddle.zeros_like(attn_mask),
507 | attn_mask)
508 | else:
509 | attn_mask = None
510 |
511 | self.register_buffer("attn_mask", attn_mask)
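# The mask marks pairs of positions that came from different image regions
# after the cyclic shift: those entries get -100.0, which after the additive
# step in WindowAttention.forward drives the softmax weight to ~exp(-100) ~ 0,
# so shifted windows never attend across region boundaries.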
512 |
513 | def _init_weights_layernorm(self):
514 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
515 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
516 | return weight_attr, bias_attr
517 |
518 | def forward(self, x):
519 | H, W = self.input_resolution
520 | B, L, C = x.shape
521 | h = x
522 | # x = self.norm1(x) # Swin-T v1, pre-norm
523 |
524 | new_shape = [B, H, W, C]
525 | x = x.reshape(new_shape) # [bs,H,W,C]
526 |
527 | if self.shift_size > 0:
528 | shifted_x = paddle.roll(x,
529 | shifts=(-self.shift_size, -self.shift_size),
530 | axis=(1, 2)) # [bs,H,W,C]
531 | else:
532 | shifted_x = x
533 |
534 | x_windows = windows_partition(shifted_x, self.window_size) # [bs*num_windows,7,7,C]
535 | x_windows = x_windows.reshape([-1, self.window_size * self.window_size, C]) # [bs*num_windows,7*7,C]
536 |
537 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # [bs*num_windows,7*7,C]
538 | attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C]) # [bs*num_windows,7,7,C]
539 |
540 | shifted_x = windows_reverse(attn_windows, self.window_size, H, W) # [bs,H,W,C]
541 |
542 | # reverse cyclic shift
543 | if self.shift_size > 0:
544 | x = paddle.roll(shifted_x,
545 | shifts=(self.shift_size, self.shift_size),
546 | axis=(1, 2))
547 | else:
548 | x = shifted_x
549 |
550 | x = x.reshape([B, H*W, C]) # [bs,H*W,C]
551 | x = self.norm1(x) # Swin-T v2, post-norm
552 |
553 | if self.drop_path is not None:
554 | x = h + self.drop_path(x)
555 | else:
556 | x = h + x
557 | h = x # [bs,H*W,C]
558 | # x = self.norm2(x) # Swin-T v1, pre-norm
559 | x = self.mlp(x) # [bs,H*W,C]
560 | x = self.norm2(x) # Swin-T v2, post-norm
561 | if self.drop_path is not None:
562 | x = h + self.drop_path(x)
563 | else:
564 | x = h + x
565 |
566 | if self.extra_norm: # Swin-T v2
567 | x = self.norm3(x)
568 |
569 | return x
570 |
571 |
572 | class SwinTransformerStage(nn.Layer):
573 | """Stage layers for swin transformer
574 |
575 | Stage layers contains a number of Transformer blocks and an optional
576 | patch merging layer, patch merging is not applied after last stage
577 |
578 | Attributes:
579 | dim: int, embedding dimension
580 | input_resolution: tuple, input resolution
581 | depth: int, num of blocks in this stage
582 | blocks: nn.LayerList, contains SwinTransformerBlocks for one stage
583 | downsample: PatchMerging, patch merging layer, none if last stage
584 | """
585 | def __init__(self, dim, input_resolution, depth, num_heads, window_size,
586 | mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0.,
587 | attention_dropout=0., droppath=0., downsample=None, sum_depth=None):
588 | super(SwinTransformerStage, self).__init__()
589 | self.dim = dim
590 | self.input_resolution = input_resolution
591 | self.depth = depth
592 |
593 | self.blocks = nn.LayerList()
594 | for i in range(depth):
595 | self.blocks.append(
596 | SwinTransformerBlock(
597 | dim=dim, input_resolution=input_resolution,
598 | num_heads=num_heads, window_size=window_size,
599 | shift_size=0 if (i % 2 == 0) else window_size // 2,
600 | mlp_ratio=mlp_ratio,
601 | extra_norm=sum_depth is not None and (i + sum_depth + 1) % 6 == 0, # Swin-T v2
602 | qkv_bias=qkv_bias, qk_scale=qk_scale,
603 | dropout=dropout, attention_dropout=attention_dropout,
604 | droppath=droppath[i] if isinstance(droppath, list) else droppath))
605 |
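# Example of the extra_norm schedule (derived from the expression above):
# with STAGE_DEPTHS = [2, 2, 18, 2], the global 1-based block index is
# i + sum_depth + 1, so the extra LayerNorm lands on blocks 6, 12, 18 and 24,
# i.e. every 6th block counted across all stages.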
606 | if downsample is not None:
607 | self.downsample = downsample(input_resolution, dim=dim)
608 | else:
609 | self.downsample = None
610 |
611 | def forward(self, x):
612 | for block in self.blocks:
613 | x = block(x) # [bs,56*56,96] -> [bs,28*28,96*2] -> [bs,14*14,96*4] -> [bs,7*7,96*8]
614 | if self.downsample is not None:
615 | x = self.downsample(x) # [bs,28*28,96*2] -> [bs,14*14,96*4] -> [bs,7*7,96*8]
616 |
617 | return x
618 |
619 |
620 | class SwinTransformer(nn.Layer):
621 | """SwinTransformer class
622 |
623 | Attributes:
624 | num_classes: int, num of image classes
625 | num_stages: int, num of stages contains patch merging and Swin blocks
626 | depths: list of int, num of Swin blocks in each stage
627 | num_heads: list of int, num of attention heads for each stage
628 | embed_dim: int, output dimension of patch embedding
629 | num_features: int, output dimension of whole network before classifier
630 | mlp_ratio: float, hidden dimension of mlp layer is mlp_ratio * mlp input dim
631 | qkv_bias: bool, if True, set qkv layers have bias enabled
632 | qk_scale: float, scale factor for qk.
633 | ape: bool, if True, set to use absolute positional embeddings
634 | window_size: int, size of patch window for inputs
635 | dropout: float, dropout rate for linear layer
636 | dropout_attn: float, dropout rate for attention
637 | patch_embedding: PatchEmbedding, patch embedding instance
638 | patch_resolution: tuple, number of patches in row and column
639 | position_dropout: nn.Dropout, dropout op for position embedding
640 | stages: SwinTransformerStage, stage instances.
641 | norm: nn.LayerNorm, norm layer applied after transformer
642 | avgpool: nn.AdaptiveAvgPool1D, pooling layer before classifier
643 | fc: nn.Linear, classifier op.
644 | """
645 | def __init__(self,
646 | image_size=224,
647 | patch_size=4,
648 | in_channels=3,
649 | num_classes=1000,
650 | embed_dim=96,
651 | depths=[2, 2, 6, 2],
652 | num_heads=[3, 6, 12, 24],
653 | window_size=7,
654 | mlp_ratio=4.,
655 | qkv_bias=True,
656 | qk_scale=None,
657 | dropout=0.,
658 | attention_dropout=0.,
659 | droppath=0.,
660 | ape=False,
661 | extra_norm=False):
662 | super(SwinTransformer, self).__init__()
663 |
664 | self.num_classes = num_classes
665 | self.num_stages = len(depths)
666 | self.embed_dim = embed_dim
667 | self.num_features = int(self.embed_dim * 2 ** (self.num_stages - 1))
668 | self.mlp_ratio = mlp_ratio
669 | self.ape = ape
670 |
671 | self.patch_embedding = PatchEmbedding(image_size=image_size,
672 | patch_size=patch_size,
673 | in_channels=in_channels,
674 | embed_dim=embed_dim)
675 | num_patches = self.patch_embedding.num_patches
676 | self.patches_resolution = self.patch_embedding.patches_resolution
677 |
678 |
679 | if self.ape:
680 | self.absolute_positional_embedding = paddle.nn.ParameterList([
681 | paddle.create_parameter(
682 | shape=[1, num_patches, self.embed_dim], dtype='float32',
683 | default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))])
684 |
685 | self.position_dropout = nn.Dropout(dropout)
686 |
687 | depth_decay = [x.item() for x in paddle.linspace(0, droppath, sum(depths))]
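# e.g. with droppath=0.5 and depths=[2, 2, 18, 2] (24 blocks total),
# depth_decay holds 24 values spaced linearly from 0.0 to 0.5, so stochastic
# depth is disabled in the first block and strongest in the last one.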
688 |
689 | self.stages = nn.LayerList()
690 | for stage_idx in range(self.num_stages):
691 | stage = SwinTransformerStage(
692 | dim=int(self.embed_dim * 2 ** stage_idx),
693 | input_resolution=(
694 | self.patches_resolution[0] // (2 ** stage_idx),
695 | self.patches_resolution[1] // (2 ** stage_idx)),
696 | depth=depths[stage_idx],
697 | sum_depth=sum(depths[:stage_idx]) if extra_norm else None, # Swin-T v2
698 | num_heads=num_heads[stage_idx],
699 | window_size=window_size,
700 | mlp_ratio=mlp_ratio,
701 | qkv_bias=qkv_bias,
702 | qk_scale=qk_scale,
703 | dropout=dropout,
704 | attention_dropout=attention_dropout,
705 | droppath=depth_decay[
706 | sum(depths[:stage_idx]):sum(depths[:stage_idx+1])],
707 | downsample=PatchMerging if (
708 | stage_idx < self.num_stages-1) else None,
709 | )
710 | self.stages.append(stage)
711 |
712 | w_attr_1, b_attr_1 = self._init_weights_layernorm()
713 | self.norm = nn.LayerNorm(self.num_features,
714 | weight_attr=w_attr_1,
715 | bias_attr=b_attr_1)
716 |
717 | self.avgpool = nn.AdaptiveAvgPool1D(1)
718 | w_attr_2, b_attr_2 = self._init_weights()
719 | self.fc = nn.Linear(self.num_features,
720 | self.num_classes,
721 | weight_attr=w_attr_2,
722 | bias_attr=b_attr_2)
723 |
724 | def _init_weights_layernorm(self):
725 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
726 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
727 | return weight_attr, bias_attr
728 |
729 | def _init_weights(self):
730 | weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
731 | bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
732 | return weight_attr, bias_attr
733 |
734 | def forward_features(self, x):
735 | x = self.patch_embedding(x) # [bs,H*W,96]
736 | if self.ape:
737 | x = x + self.absolute_positional_embedding[0]  # ParameterList holds a single parameter
738 | x = self.position_dropout(x) # [bs,H*W,96]
739 |
740 | for stage in self.stages:
741 | x = stage(x) # [bs,784,192],[bs,196,384],[bs,49,768],[bs,49,768]
742 |
743 | x = self.norm(x) # [bs,49,768]
744 | x = x.transpose([0, 2, 1])
745 | x = self.avgpool(x) # [bs,768,1]
746 | x = x.flatten(1) # [bs,768]
747 | return x
748 |
749 | def forward(self, x):
750 | x = self.forward_features(x) # [bs,768]
751 | x = self.fc(x) # [bs,1000]
752 | return x
753 |
754 |
755 | def build_swin(config):
756 | model = SwinTransformer(
757 | image_size=config.DATA.IMAGE_SIZE,
758 | patch_size=config.MODEL.TRANS.PATCH_SIZE,
759 | in_channels=config.MODEL.TRANS.IN_CHANNELS,
760 | embed_dim=config.MODEL.TRANS.EMBED_DIM,
761 | num_classes=config.MODEL.NUM_CLASSES,
762 | depths=config.MODEL.TRANS.STAGE_DEPTHS,
763 | num_heads=config.MODEL.TRANS.NUM_HEADS,
764 | mlp_ratio=config.MODEL.TRANS.MLP_RATIO,
765 | qkv_bias=config.MODEL.TRANS.QKV_BIAS,
766 | qk_scale=config.MODEL.TRANS.QK_SCALE,
767 | ape=config.MODEL.TRANS.APE,
768 | window_size=config.MODEL.TRANS.WINDOW_SIZE,
769 | dropout=config.MODEL.DROPOUT,
770 | attention_dropout=config.MODEL.ATTENTION_DROPOUT,
771 | droppath=config.MODEL.DROP_PATH,
772 | extra_norm=config.MODEL.TRANS.EXTRA_NORM)
773 | return model
774 |
775 | if __name__ == '__main__':
776 | from main_single_gpu import get_arguments
777 | from config import get_config
778 | from config import update_config
779 | arguments = get_arguments()
780 | config = get_config()
781 | config = update_config(config, arguments)
782 |
783 | model = build_swin(config)
784 | image = paddle.randn([1, 3, 224, 224])
785 | output = model(image)
786 | print(output.shape)
787 |
--------------------------------------------------------------------------------